diff options
author | txh18 <[email protected]> | 2015-10-27 16:24:55 +0800 |
---|---|---|
committer | txh18 <[email protected]> | 2015-10-27 16:24:55 +0800 |
commit | 7c95640c95f1cc1d84b4d49fa97fd922748b88a7 (patch) | |
tree | e78fc611b8768ddf6e9191d597bf667e83b3353b | |
parent | ba8a1c9d5366c22b0b631f26ae1de7c5da2cbaeb (diff) |
added update_select_rows for select_linear:update speed-up
-rw-r--r-- | nerv/examples/lmptb/lmptb/layer/select_linear.lua | 13 | ||||
-rw-r--r-- | nerv/examples/lmptb/main.lua | 3 | ||||
-rw-r--r-- | nerv/lib/matrix/cukernel.h | 1 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/cukernel.cu | 20 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/cumatrix.c | 12 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/cumatrix.h | 1 | ||||
-rw-r--r-- | nerv/matrix/generic/cumatrix.c | 14 | ||||
-rw-r--r-- | nerv/nn/layer_dag.lua | 2 |
8 files changed, 60 insertions, 6 deletions
diff --git a/nerv/examples/lmptb/lmptb/layer/select_linear.lua b/nerv/examples/lmptb/lmptb/layer/select_linear.lua index e4afac4..d4cff0b 100644 --- a/nerv/examples/lmptb/lmptb/layer/select_linear.lua +++ b/nerv/examples/lmptb/lmptb/layer/select_linear.lua @@ -30,12 +30,13 @@ function SL:init(batch_size) end function SL:update(bp_err, input, output) - for i = 1, input[1]:ncol(), 1 do - if (input[1][0][i - 1] ~= 0) then - local word_vec = self.ltp.trans[input[1][0][i - 1]] - word_vec:add(word_vec, bp_err[1][i - 1], 1, - self.gconf.lrate / self.gconf.batch_size) - end - end + --for i = 1, input[1]:ncol(), 1 do + -- if (input[1][0][i - 1] ~= 0) then + -- local word_vec = self.ltp.trans[input[1][0][i - 1]] + --word_vec:add(word_vec, bp_err[1][i - 1], 1, - self.gconf.lrate / self.gconf.batch_size) + -- end + --end + self.ltp.trans:update_select_rows(bp_err[1], input[1], - self.gconf.lrate / self.gconf.batch_size, 0) end function SL:propagate(input, output) diff --git a/nerv/examples/lmptb/main.lua b/nerv/examples/lmptb/main.lua index 13d610e..d505456 100644 --- a/nerv/examples/lmptb/main.lua +++ b/nerv/examples/lmptb/main.lua @@ -17,6 +17,7 @@ function prepare_parameters(global_conf, first_time) ltp_ih = nerv.LinearTransParam("ltp_ih", global_conf) ltp_ih.trans = global_conf.cumat_type(global_conf.vocab:size() + 1, global_conf.hidden_size) --index 0 is for zero, others correspond to vocab index(starting from 1) ltp_ih.trans:generate(global_conf.param_random) + ltp_ih.trans[0]:fill(0) ltp_hh = nerv.LinearTransParam("ltp_hh", global_conf) ltp_hh.trans = global_conf.cumat_type(global_conf.hidden_size, global_conf.hidden_size) @@ -153,6 +154,7 @@ function propagateFile(global_conf, dagL, fn, config) local token_store = {} local hidden_store = {} local sigmoidL_ref = dagL.layers["sigmoidL1"] + local inputL_ref = dagL.layers["selectL1"] token_store[tnow] = feeder:get_batch() for i = 1, global_conf.bptt + 1 do @@ -209,6 +211,7 @@ function propagateFile(global_conf, dagL, fn, config) global_conf.timer:tic("dagL-update") dagL:update(dagL_err, dagL_input, dagL_output) global_conf.timer:toc("dagL-update") + inputL_ref.layer.ltp.trans[0]:fill(0) --afraid that this will be updated in select_linear:update end for i = 1, global_conf.batch_size, 1 do diff --git a/nerv/lib/matrix/cukernel.h b/nerv/lib/matrix/cukernel.h index 2126c6f..fffe0bc 100644 --- a/nerv/lib/matrix/cukernel.h +++ b/nerv/lib/matrix/cukernel.h @@ -13,6 +13,7 @@ void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max, const Matrix void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta); void cudak_(cuda_fill)(Matrix *a, double val); void cudak_(cuda_clip)(Matrix *a, double val_1, double val_2); +void cudak_(cuda_update_select_rows)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta); void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context); void cudak_(cuda_rearrange_frm)(const Matrix *a, Matrix *b, int step); void cudak_(cuda_scale_rows_by_row)(const Matrix *a, Matrix *b); diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu index 08feb59..6c8e64a 100644 --- a/nerv/lib/matrix/generic/cukernel.cu +++ b/nerv/lib/matrix/generic/cukernel.cu @@ -225,6 +225,15 @@ __global__ void cudak_(clip)(MATRIX_ELEM *a, a[j + i * stride] = val_1; } +__global__ void cudak_(update_select_rows)(MATRIX_ELEM *c, const MATRIX_ELEM *a, const MATRIX_ELEM *idx, + int nrow_a, int ncol_a, int stride_c, int stride_a, double alpha, double beta) { + int j = blockIdx.x * blockDim.x + threadIdx.x; + int i = blockIdx.y * blockDim.y + threadIdx.y; + if (i >= nrow_a || j >= ncol_a) return; + int i_c = lrintf(idx[i]); + c[j + i_c * stride_c] = c[j + i_c * stride_c] * (1 - beta * alpha) + a[j + i * stride_a] * alpha; +} + __global__ void cudak_(expand_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b, int nrow, int ncol, int enrow, int encol, @@ -540,6 +549,17 @@ extern "C" { a->stride / sizeof(MATRIX_ELEM), val_1, val_2); cudaStreamSynchronize(0); } + + void cudak_(cuda_update_select_rows)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta) { + dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); + dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x), + CEIL_DIV(a->nrow, threadsPerBlock.y)); + cudak_(update_select_rows)<<<numBlocks, threadsPerBlock>>> \ + (MATRIX_ELEM_PTR(c), MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(idx), + a->nrow, a->ncol, c->stride / sizeof(MATRIX_ELEM), + a->stride / sizeof(MATRIX_ELEM), alpha, beta); + cudaStreamSynchronize(0); + } void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context) { dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); diff --git a/nerv/lib/matrix/generic/cumatrix.c b/nerv/lib/matrix/generic/cumatrix.c index 770e503..2dc5899 100644 --- a/nerv/lib/matrix/generic/cumatrix.c +++ b/nerv/lib/matrix/generic/cumatrix.c @@ -359,6 +359,18 @@ void nerv_matrix_(copy_rows_fromd_by_idx)(Matrix *a, const Matrix *b, NERV_SET_STATUS(status, NERV_NORMAL, 0); } +void nerv_matrix_(update_select_rows)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta, Status *status) { + long nrow = a->nrow; + if (idx->nrow != 1) + NERV_EXIT_STATUS(status, MAT_IDX_VECTOR_EXP, 0); + if (a->ncol != c->ncol) + NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0); + PROFILE_START + cudak_(cuda_update_select_rows)(c, a, idx, alpha, beta); + PROFILE_STOP + NERV_SET_STATUS(status, NERV_NORMAL, 0); +} + void nerv_matrix_(expand_frm)(Matrix *a, const Matrix *b, int context, Status *status) { if (a->nrow != b->nrow) diff --git a/nerv/lib/matrix/generic/cumatrix.h b/nerv/lib/matrix/generic/cumatrix.h index 04e8c5a..21c29b7 100644 --- a/nerv/lib/matrix/generic/cumatrix.h +++ b/nerv/lib/matrix/generic/cumatrix.h @@ -42,6 +42,7 @@ void nerv_matrix_(copy_rows_fromh_by_idx)(Matrix *a, const Matrix *b, const Matrix *idx, int b_begin, Status *status); void nerv_matrix_(copy_rows_fromd_by_idx)(Matrix *a, const Matrix *b, const Matrix *idx, int b_begin, Status *status); +void nerv_matrix_(update_select_rows)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta, Status *status); void nerv_matrix_(expand_frm)(Matrix *a, const Matrix *b, int context, Status *status); diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c index 08cb4c2..623352e 100644 --- a/nerv/matrix/generic/cumatrix.c +++ b/nerv/matrix/generic/cumatrix.c @@ -291,6 +291,19 @@ static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) { return 0; } +static int nerv_matrix_(lua_update_select_rows)(lua_State *L) { + //Update c's select rows, i.e. c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha + Status status; + Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); + const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); + const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname)); + MATRIX_ELEM alpha = luaL_checknumber(L, 4); + MATRIX_ELEM beta = luaL_checknumber(L, 5); + nerv_matrix_(update_select_rows)(c, a, idx, alpha, beta, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"colsum", nerv_matrix_(lua_colsum)}, {"colsame", nerv_matrix_(lua_colsame)}, @@ -310,6 +323,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"add_row", nerv_matrix_(lua_add_row)}, {"clip", nerv_matrix_(lua_clip)}, {"fill", nerv_matrix_(lua_fill)}, + {"update_select_rows", nerv_matrix_(lua_update_select_rows)}, {"sigmoid", nerv_matrix_(lua_sigmoid)}, {"sigmoid_grad", nerv_matrix_(lua_sigmoid_grad)}, {"softmax", nerv_matrix_(lua_softmax)}, diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua index 91818d6..4904f4f 100644 --- a/nerv/nn/layer_dag.lua +++ b/nerv/nn/layer_dag.lua @@ -251,7 +251,9 @@ function DAGLayer:update(bp_err, input, output) -- print("update") for id, ref in pairs(self.queue) do -- print(ref.layer.id) + self.gconf.timer:tic("(update)"..ref.layer.id); ref.layer:update(ref.err_inputs, ref.inputs, ref.outputs) + self.gconf.timer:toc("(update)"..ref.layer.id); end end |