diff options
author | Qi Liu <[email protected]> | 2016-03-11 21:28:29 +0800 |
---|---|---|
committer | Qi Liu <[email protected]> | 2016-03-11 21:28:29 +0800 |
commit | 442e261a0f2cb8836e2859bd814a267cc8aa5db2 (patch) | |
tree | 112dd3932a8d23fc2e36f67c347f13bb2d19232a | |
parent | e2a9af061db485d4388902d738c9d8be3f94ab34 (diff) | |
parent | 14c1997203e04838b1737716dc385e1aa08fe91f (diff) |
update diagonlal lstm
-rw-r--r-- | nerv/layer/lstm.lua | 6 | ||||
-rw-r--r-- | nerv/layer/lstm_gate.lua | 7 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/cukernel.cu | 18 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/cumatrix.c | 9 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/cumatrix.h | 2 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/mmatrix.c | 17 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/mmatrix.h | 1 | ||||
-rw-r--r-- | nerv/matrix/generic/cumatrix.c | 1 | ||||
-rw-r--r-- | nerv/matrix/generic/matrix.c | 8 | ||||
-rw-r--r-- | nerv/matrix/generic/mmatrix.c | 1 |
10 files changed, 67 insertions, 3 deletions
diff --git a/nerv/layer/lstm.lua b/nerv/layer/lstm.lua index 5dbcc20..56f674a 100644 --- a/nerv/layer/lstm.lua +++ b/nerv/layer/lstm.lua @@ -29,9 +29,9 @@ function LSTMLayer:__init(id, global_conf, layer_conf) outputTanh = {dim_in = {dout}, dim_out = {dout}}, }, ['nerv.LSTMGateLayer'] = { - forgetGate = {dim_in = {din, dout, dout}, dim_out = {dout}, pr = pr}, - inputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, pr = pr}, - outputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, pr = pr}, + forgetGate = {dim_in = {din, dout, dout}, dim_out = {dout}, param_type = {'N', 'N', 'D'}, pr = pr}, + inputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, param_type = {'N', 'N', 'D'}, pr = pr}, + outputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, param_type = {'N', 'N', 'D'}, pr = pr}, }, ['nerv.ElemMulLayer'] = { inputGateMul = {dim_in = {dout, dout}, dim_out = {dout}}, diff --git a/nerv/layer/lstm_gate.lua b/nerv/layer/lstm_gate.lua index 7a27bab..e690721 100644 --- a/nerv/layer/lstm_gate.lua +++ b/nerv/layer/lstm_gate.lua @@ -3,6 +3,7 @@ local LSTMGateLayer = nerv.class('nerv.LSTMGateLayer', 'nerv.Layer') function LSTMGateLayer:__init(id, global_conf, layer_conf) nerv.Layer.__init(self, id, global_conf, layer_conf) + self.param_type = layer_conf.param_type self:check_dim_len(-1, 1) --accept multiple inputs self:bind_params() end @@ -12,6 +13,9 @@ function LSTMGateLayer:bind_params() self["ltp" .. i] = self:find_param("ltp" .. i, self.lconf, self.gconf, nerv.LinearTransParam, {self.dim_in[i], self.dim_out[1]}) + if self.param_type[i] == 'D' then + self["ltp" .. i].trans:diagonalize() + end end self.bp = self:find_param("bp", self.lconf, self.gconf, nerv.BiasParam, {1, self.dim_out[1]}) @@ -63,6 +67,9 @@ function LSTMGateLayer:update(bp_err, input, output) self.err_bakm:sigmoid_grad(bp_err[1], output[1]) for i = 1, #self.dim_in do self["ltp" .. i]:update_by_err_input(self.err_bakm, input[i]) + if self.param_type[i] == 'D' then + self["ltp" .. i].trans:diagonalize() + end end self.bp:update_by_gradient(self.err_bakm:colsum()) end diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu index 0e09cfa..93121dc 100644 --- a/nerv/lib/matrix/generic/cukernel.cu +++ b/nerv/lib/matrix/generic/cukernel.cu @@ -250,6 +250,14 @@ __global__ void cudak_(fill)(MATRIX_ELEM *a, a[j + i * stride] = val; } +__global__ void cudak_(diagonalize)(MATRIX_ELEM *a, + int nrow, int ncol, int stride) { + int j = blockIdx.x * blockDim.x + threadIdx.x; + int i = blockIdx.y * blockDim.y + threadIdx.y; + if (i >= nrow || j >= ncol || i == j) return; + a[j + i * stride] = 0; +} + __global__ void cudak_(clip)(MATRIX_ELEM *a, int nrow, int ncol, int stride, double val_1, double val_2) { int j = blockIdx.x * blockDim.x + threadIdx.x; @@ -678,6 +686,16 @@ extern "C" { cudaStreamSynchronize(0); } + void cudak_(cuda_diagonalize)(Matrix *a) { + dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); + dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x), + CEIL_DIV(a->nrow, threadsPerBlock.y)); + cudak_(diagonalize)<<<numBlocks, threadsPerBlock>>> \ + (MATRIX_ELEM_PTR(a), a->nrow, a->ncol, + a->stride / sizeof(MATRIX_ELEM)); + cudaStreamSynchronize(0); + } + void cudak_(cuda_clip)(Matrix *a, double val_1, double val_2) { dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x), diff --git a/nerv/lib/matrix/generic/cumatrix.c b/nerv/lib/matrix/generic/cumatrix.c index 6342d90..6d84663 100644 --- a/nerv/lib/matrix/generic/cumatrix.c +++ b/nerv/lib/matrix/generic/cumatrix.c @@ -515,6 +515,15 @@ void nerv_matrix_(prefixsum_row)(Matrix *a, const Matrix *b, NERV_SET_STATUS(status, NERV_NORMAL, 0); } +void nerv_matrix_(diagonalize)(Matrix *a, CuContext * context, Status *status) { + if (a->nrow != a->ncol) + NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0); + PROFILE_START + cudak_(cuda_diagonalize)(a); + PROFILE_STOP + NERV_SET_STATUS(status, NERV_NORMAL, 0); +} + static void cuda_matrix_(free)(MATRIX_ELEM *ptr, CuContext *context, Status *status) { CUDA_SAFE_SYNC_CALL(cudaFree(ptr), status); NERV_SET_STATUS(status, NERV_NORMAL, 0); diff --git a/nerv/lib/matrix/generic/cumatrix.h b/nerv/lib/matrix/generic/cumatrix.h index fe83b5d..de3a09e 100644 --- a/nerv/lib/matrix/generic/cumatrix.h +++ b/nerv/lib/matrix/generic/cumatrix.h @@ -33,6 +33,8 @@ void nerv_matrix_(clip)(Matrix *self, double val1, double val2, CuContext *context, Status *status); void nerv_matrix_(fill)(Matrix *self, double val, CuContext *context, Status *status); +void nerv_matrix_(diagonalize)(Matrix *self, + CuContext *context, Status *status); void nerv_matrix_(copy_fromd)(Matrix *a, const Matrix *b, int a_begin, int b_begin, int b_end, CuContext *context, Status *status); diff --git a/nerv/lib/matrix/generic/mmatrix.c b/nerv/lib/matrix/generic/mmatrix.c index fb99b53..6272cbe 100644 --- a/nerv/lib/matrix/generic/mmatrix.c +++ b/nerv/lib/matrix/generic/mmatrix.c @@ -274,6 +274,23 @@ void nerv_matrix_(fill)(Matrix *self, double val, NERV_SET_STATUS(status, NERV_NORMAL, 0); } +void nerv_matrix_(diagonalize)(Matrix *selfa, + MContext *context, Status *status) { + if (self->nrow != self->ncol) + NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0); + int i, j; + size_t astride = self->stride; + MATRIX_ELEM *arow = MATRIX_ELEM_PTR(self); + for (i = 0; i < self->nrow; i++) + { + for (j = 0; j < self->ncol; j++) + if (i != j) + arow[j] = 0; + arow = MATRIX_NEXT_ROW_PTR(arow, astride); + } + NERV_SET_STATUS(status, NERV_NORMAL, 0); +} + void nerv_matrix_(sigmoid)(Matrix *a, const Matrix *b, MContext *context, Status *status) { CHECK_SAME_DIMENSION(a, b, status); diff --git a/nerv/lib/matrix/generic/mmatrix.h b/nerv/lib/matrix/generic/mmatrix.h index 6e0589a..6d17c99 100644 --- a/nerv/lib/matrix/generic/mmatrix.h +++ b/nerv/lib/matrix/generic/mmatrix.h @@ -27,6 +27,7 @@ void nerv_matrix_(add_row)(Matrix *b, const Matrix *a, double beta, MContext *context, Status *status); void nerv_matrix_(clip)(Matrix *self, double val1, double val2, MContext *context, Status *status); +void nerv_matrix_(diagonalize)(Matrix *self, MContext *context, Status *status); void nerv_matrix_(fill)(Matrix *self, double val, MContext *context, Status *status); void nerv_matrix_(copy_fromh)(Matrix *a, const Matrix *b, int a_begin, int b_begin, int b_end, diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c index 00e4ee3..0c90d39 100644 --- a/nerv/matrix/generic/cumatrix.c +++ b/nerv/matrix/generic/cumatrix.c @@ -267,6 +267,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"scale_rows_by_row", nerv_matrix_(lua_scale_rows_by_row)}, {"scale_rows_by_col", nerv_matrix_(lua_scale_rows_by_col)}, {"prefixsum_row", nerv_matrix_(lua_prefixsum_row)}, + {"diagonalize", nerv_matrix_(lua_diagonalize)}, #ifdef __NERV_FUTURE_CUDA_7 {"update_select_rows_by_rowidx", nerv_matrix_(lua_update_select_rows_by_rowidx)}, {"update_select_rows_by_colidx", nerv_matrix_(lua_update_select_rows_by_colidx)}, diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c index 8c2f871..b544dd9 100644 --- a/nerv/matrix/generic/matrix.c +++ b/nerv/matrix/generic/matrix.c @@ -385,4 +385,12 @@ static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) { return 0; } +static int nerv_matrix_(lua_diagonalize)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + nerv_matrix_(diagonalize)(a, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + #endif diff --git a/nerv/matrix/generic/mmatrix.c b/nerv/matrix/generic/mmatrix.c index 1f37173..a5e5969 100644 --- a/nerv/matrix/generic/mmatrix.c +++ b/nerv/matrix/generic/mmatrix.c @@ -116,6 +116,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"add_row", nerv_matrix_(lua_add_row)}, {"clip", nerv_matrix_(lua_clip)}, {"fill", nerv_matrix_(lua_fill)}, + {"diagonalize", nerv_matrix_(lua_diagonalize)}, {"sigmoid", nerv_matrix_(lua_sigmoid)}, {"sigmoid_grad", nerv_matrix_(lua_sigmoid_grad)}, {"softmax", nerv_matrix_(lua_softmax)}, |