From b547dd2a30e91ce124d50f763997070ea67c6f7e Mon Sep 17 00:00:00 2001
From: Qi Liu
Date: Tue, 15 Mar 2016 13:17:45 +0800
Subject: add mask on softmax_ce

---
 nerv/examples/network_debug/network.lua |  8 ++++----
 nerv/layer/softmax_ce.lua               |  4 +++-
 nerv/lib/matrix/generic/cukernel.cu     | 22 ++++++++++++++++++++++
 nerv/lib/matrix/generic/cumatrix.c      | 14 +++++++++++++-
 nerv/lib/matrix/generic/cumatrix.h      |  2 ++
 nerv/lib/matrix/generic/mmatrix.c       | 23 +++++++++++++++++++++++
 nerv/lib/matrix/generic/mmatrix.h       |  2 ++
 nerv/matrix/generic/cumatrix.c          |  1 +
 nerv/matrix/generic/matrix.c            | 12 ++++++++++++
 nerv/matrix/generic/mmatrix.c           |  1 +
 10 files changed, 83 insertions(+), 6 deletions(-)

diff --git a/nerv/examples/network_debug/network.lua b/nerv/examples/network_debug/network.lua
index 1841d21..386c3b0 100644
--- a/nerv/examples/network_debug/network.lua
+++ b/nerv/examples/network_debug/network.lua
@@ -72,12 +72,12 @@ function nn:process(data, do_train, reader)
         for t = 1, self.gconf.chunk_size do
             local tmp = info.output[t][1]:new_to_host()
             for i = 1, self.gconf.batch_size do
-                if t <= info.seq_length[i] then
-                    total_err = total_err + math.log10(math.exp(tmp[i - 1][0]))
-                    total_frame = total_frame + 1
-                end
+                total_err = total_err + math.log10(math.exp(tmp[i - 1][0]))
             end
         end
+        for i = 1, self.gconf.batch_size do
+            total_frame = total_frame + info.seq_length[i]
+        end
 
         timer:toc('IO')
         timer:tic('network')
diff --git a/nerv/layer/softmax_ce.lua b/nerv/layer/softmax_ce.lua
index 7b4a80c..acd4ee6 100644
--- a/nerv/layer/softmax_ce.lua
+++ b/nerv/layer/softmax_ce.lua
@@ -61,14 +61,16 @@ function SoftmaxCELayer:propagate(input, output, t)
     end
     ce:mul_elem(ce, label)
     ce = ce:rowsum()
+    ce:set_values_by_mask(self.gconf.mask[t], 0)
     if output[1] ~= nil then
         output[1]:copy_from(ce)
     end
     -- add total ce
     self.total_ce = self.total_ce - ce:colsum()[0][0]
-    self.total_frames = self.total_frames + softmax:nrow()
+    self.total_frames = self.total_frames + self.gconf.mask[t]:colsum()[0][0]
     -- TODO: add colsame for uncompressed label
     if self.compressed then
+        classified:set_values_by_mask(self.gconf.mask[t], -1)
         self.total_correct = self.total_correct + classified:colsame(input[2])[0][0]
     end
 end
diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu
index 93121dc..4717209 100644
--- a/nerv/lib/matrix/generic/cukernel.cu
+++ b/nerv/lib/matrix/generic/cukernel.cu
@@ -324,6 +324,15 @@ __global__ void cudak_(rearrange_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
     b[j + i * stride] = a[j / step + (j % step) * orig_dim + i * stride];
 }
 
+__global__ void cudak_(set_values_by_mask)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+                                           int nrow, int ncol,
+                                           int astride, int bstride, double val) {
+    int j = blockIdx.x * blockDim.x + threadIdx.x;
+    int i = blockIdx.y * blockDim.y + threadIdx.y;
+    if (i >= nrow || j >= ncol || a[i * astride] != 0.0) return;
+    b[j + i * bstride] = val;
+}
+
 __global__ void cudak_(scale_rows_by_col)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                           int nrow, int ncol,
                                           int astride, int bstride) {
@@ -766,6 +775,19 @@ extern "C" {
         cudaStreamSynchronize(0);
     }
 
+    void cudak_(cuda_set_values_by_mask)(const Matrix *a, Matrix *b, double val) {
+        dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+        dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+                       CEIL_DIV(b->nrow, threadsPerBlock.y));
+        cudak_(set_values_by_mask)<<<numBlocks, threadsPerBlock>>> \
+            (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
+             b->nrow, b->ncol,
+             a->stride / sizeof(MATRIX_ELEM),
+             b->stride / sizeof(MATRIX_ELEM),
+             val);
+        cudaStreamSynchronize(0);
+    }
+
     void cudak_(cuda_scale_rows_by_row)(const Matrix *a, Matrix *b) {
         dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
         dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
diff --git a/nerv/lib/matrix/generic/cumatrix.c b/nerv/lib/matrix/generic/cumatrix.c
index 6d84663..bc5f285 100644
--- a/nerv/lib/matrix/generic/cumatrix.c
+++ b/nerv/lib/matrix/generic/cumatrix.c
@@ -515,7 +515,7 @@ void nerv_matrix_(prefixsum_row)(Matrix *a, const Matrix *b,
     NERV_SET_STATUS(status, NERV_NORMAL, 0);
 }
 
-void nerv_matrix_(diagonalize)(Matrix *a, CuContext * context, Status *status) {
+void nerv_matrix_(diagonalize)(Matrix *a, CuContext *context, Status *status) {
     if (a->nrow != a->ncol)
         NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0);
     PROFILE_START
@@ -524,6 +524,18 @@ void nerv_matrix_(diagonalize)(Matrix *a, CuContext *context, Status *status) {
     NERV_SET_STATUS(status, NERV_NORMAL, 0);
 }
 
+void nerv_matrix_(set_values_by_mask)(Matrix *a, const Matrix *b, double val,
+                                      CuContext *context, Status *status) {
+    if (a->nrow != b->nrow)
+        NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0);
+    if (b->ncol != 1)
+        NERV_EXIT_STATUS(status, MAT_COL_VECTOR_EXP, 0);
+    PROFILE_START
+    cudak_(cuda_set_values_by_mask)(b, a, val);
+    PROFILE_STOP
+    NERV_SET_STATUS(status, NERV_NORMAL, 0);
+}
+
 static void cuda_matrix_(free)(MATRIX_ELEM *ptr, CuContext *context, Status *status) {
     CUDA_SAFE_SYNC_CALL(cudaFree(ptr), status);
     NERV_SET_STATUS(status, NERV_NORMAL, 0);
diff --git a/nerv/lib/matrix/generic/cumatrix.h b/nerv/lib/matrix/generic/cumatrix.h
index de3a09e..79bfc76 100644
--- a/nerv/lib/matrix/generic/cumatrix.h
+++ b/nerv/lib/matrix/generic/cumatrix.h
@@ -35,6 +35,8 @@ void nerv_matrix_(fill)(Matrix *self, double val,
                         CuContext *context, Status *status);
 void nerv_matrix_(diagonalize)(Matrix *self,
                         CuContext *context, Status *status);
+void nerv_matrix_(set_values_by_mask)(Matrix *self, Matrix *mask, double val,
+                        CuContext *context, Status *status);
 void nerv_matrix_(copy_fromd)(Matrix *a, const Matrix *b,
                               int a_begin, int b_begin, int b_end,
                               CuContext *context, Status *status);
diff --git a/nerv/lib/matrix/generic/mmatrix.c b/nerv/lib/matrix/generic/mmatrix.c
index badddbd..e356de7 100644
--- a/nerv/lib/matrix/generic/mmatrix.c
+++ b/nerv/lib/matrix/generic/mmatrix.c
@@ -507,6 +507,29 @@ void nerv_matrix_(scale_rows_by_col)(Matrix *a, const Matrix *b,
     NERV_SET_STATUS(status, NERV_NORMAL, 0);
 }
 
+void nerv_matrix_(set_values_by_mask)(Matrix *a, const Matrix *b, double val,
+                                      MContext *context, Status *status) {
+    if (a->nrow != b->nrow)
+        NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0);
+    if (b->ncol != 1)
+        NERV_EXIT_STATUS(status, MAT_COL_VECTOR_EXP, 0);
+    int i, j;
+    size_t astride = a->stride, bstride = b->stride;
+    MATRIX_ELEM *arow = MATRIX_ELEM_PTR(a),
+                *brow = MATRIX_ELEM_PTR(b);
+    for (i = 0; i < a->nrow; i++)
+    {
+        if (brow[0] == 0.0)
+        {
+            for (j = 0; j < a->ncol; j++)
+                arow[j] = val;
+        }
+        arow = MATRIX_NEXT_ROW_PTR(arow, astride);
+        brow = MATRIX_NEXT_ROW_PTR(brow, bstride);
+    }
+    NERV_SET_STATUS(status, NERV_NORMAL, 0);
+}
+
 static void host_matrix_(free)(MATRIX_ELEM *ptr, MContext *context, Status *status) {
     free(ptr);
     NERV_SET_STATUS(status, NERV_NORMAL, 0);
diff --git a/nerv/lib/matrix/generic/mmatrix.h b/nerv/lib/matrix/generic/mmatrix.h
index 6d17c99..41c39f6 100644
--- a/nerv/lib/matrix/generic/mmatrix.h
+++ b/nerv/lib/matrix/generic/mmatrix.h
@@ -48,6 +48,8 @@ void nerv_matrix_(expand_frm)(Matrix *a, const Matrix *b,
                               int cont, MContext *context, Status *status);
 void nerv_matrix_(rearrange_frm)(Matrix *a, const Matrix *b, int step,
                                  MContext *context, Status *status);
+void nerv_matrix_(set_values_by_mask)(Matrix *a, const Matrix *b, double val,
+                                      MContext *context, Status *status);
 void nerv_matrix_(scale_rows_by_col)(Matrix *a, const Matrix *b,
                                      MContext *context, Status *status);
 void nerv_matrix_(scale_rows_by_row)(Matrix *a, const Matrix *b,
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c
index 0c90d39..9577fd5 100644
--- a/nerv/matrix/generic/cumatrix.c
+++ b/nerv/matrix/generic/cumatrix.c
@@ -268,6 +268,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
     {"scale_rows_by_col", nerv_matrix_(lua_scale_rows_by_col)},
     {"prefixsum_row", nerv_matrix_(lua_prefixsum_row)},
     {"diagonalize", nerv_matrix_(lua_diagonalize)},
+    {"set_values_by_mask", nerv_matrix_(lua_set_values_by_mask)},
 #ifdef __NERV_FUTURE_CUDA_7
     {"update_select_rows_by_rowidx", nerv_matrix_(lua_update_select_rows_by_rowidx)},
     {"update_select_rows_by_colidx", nerv_matrix_(lua_update_select_rows_by_colidx)},
diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c
index fe07585..3e91933 100644
--- a/nerv/matrix/generic/matrix.c
+++ b/nerv/matrix/generic/matrix.c
@@ -395,4 +395,16 @@ static int nerv_matrix_(lua_diagonalize)(lua_State *L) {
     return 0;
 }
 
+static int nerv_matrix_(lua_set_values_by_mask)(lua_State *L) {
+    Status status;
+    MATRIX_CONTEXT *context;
+    MATRIX_GET_CONTEXT(L, 4);
+    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+    Matrix *mask = luaT_checkudata(L, 2, nerv_matrix_(tname));
+    double val = luaL_checknumber(L, 3);
+    nerv_matrix_(set_values_by_mask)(a, mask, val, context, &status);
+    NERV_LUA_CHECK_STATUS(L, status);
+    return 0;
+}
+
 #endif
diff --git a/nerv/matrix/generic/mmatrix.c b/nerv/matrix/generic/mmatrix.c
index a5e5969..de1eaa3 100644
--- a/nerv/matrix/generic/mmatrix.c
+++ b/nerv/matrix/generic/mmatrix.c
@@ -117,6 +117,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
     {"clip", nerv_matrix_(lua_clip)},
     {"fill", nerv_matrix_(lua_fill)},
     {"diagonalize", nerv_matrix_(lua_diagonalize)},
+    {"set_values_by_mask", nerv_matrix_(lua_set_values_by_mask)},
     {"sigmoid", nerv_matrix_(lua_sigmoid)},
     {"sigmoid_grad", nerv_matrix_(lua_sigmoid_grad)},
     {"softmax", nerv_matrix_(lua_softmax)},
-- 
cgit v1.2.3
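
Usage note (not part of the patch). The new binding behaves the same on host
(mmatrix) and CUDA (cumatrix) matrices: the mask must be a single-column matrix
with one entry per row of the target, and every row whose mask entry is 0 is
overwritten with `val`, while rows with a nonzero mask entry are left alone.
A minimal sketch of the intended call, assuming the stock nerv host-matrix API
(constructor, fill, 0-based element indexing); the values here are illustrative:

    local m = nerv.MMatrixFloat(4, 3)
    m:fill(1)                            -- start with all-ones
    local mask = nerv.MMatrixFloat(4, 1) -- one mask entry per row of m
    mask[0][0] = 1                       -- nonzero: keep row 0
    mask[1][0] = 0                       -- zero: overwrite row 1
    mask[2][0] = 1                       -- nonzero: keep row 2
    mask[3][0] = 0                       -- zero: overwrite row 3
    m:set_values_by_mask(mask, -1)
    -- rows 1 and 3 of m are now all -1; rows 0 and 2 are unchanged

This is exactly how softmax_ce.lua uses it above: padded frames carry a 0 in
gconf.mask[t], so their per-frame cross-entropy is forced to 0 before colsum(),
and their `classified` entries are forced to -1 so they can never match a label
in colsame().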