From a51498d2761714f4e034f036b84b4b89a88e9539 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Wed, 24 Feb 2016 16:16:03 +0800 Subject: update LSTM layer --- nerv/layer/lstm.lua | 11 +++++++---- nerv/layer/lstm_gate.lua | 7 +++++++ nerv/lib/matrix/generic/cukernel.cu | 18 ++++++++++++++++++ nerv/lib/matrix/generic/cumatrix.c | 9 +++++++++ nerv/lib/matrix/generic/cumatrix.h | 1 + nerv/lib/matrix/generic/mmatrix.c | 16 ++++++++++++++++ nerv/lib/matrix/generic/mmatrix.h | 1 + nerv/matrix/generic/cumatrix.c | 1 + nerv/matrix/generic/matrix.c | 8 ++++++++ nerv/matrix/generic/mmatrix.c | 1 + 10 files changed, 69 insertions(+), 4 deletions(-) (limited to 'nerv') diff --git a/nerv/layer/lstm.lua b/nerv/layer/lstm.lua index 500bd87..b0cfe08 100644 --- a/nerv/layer/lstm.lua +++ b/nerv/layer/lstm.lua @@ -19,7 +19,7 @@ function LSTMLayer:__init(id, global_conf, layer_conf) return self.id .. '.' .. str end local din1, din2, din3 = self.dim_in[1], self.dim_in[2], self.dim_in[3] - local dout1, dout2, dout3 = self.dim_out[1], self.dim_out[2], self.dim_out[3] + local dout1, dout2 = self.dim_out[1], self.dim_out[2] local layers = { ["nerv.CombinerLayer"] = { [ap("inputXDup")] = {{}, {dim_in = {din1}, @@ -49,11 +49,14 @@ function LSTMLayer:__init(id, global_conf, layer_conf) }, ["nerv.LSTMGateLayer"] = { [ap("forgetGateL")] = {{}, {dim_in = {din1, din2, din3}, - dim_out = {din3}, pr = pr}}, + dim_out = {din3}, pr = pr}, + param_type = {'N', 'N', 'D'}}, [ap("inputGateL")] = {{}, {dim_in = {din1, din2, din3}, - dim_out = {din3}, pr = pr}}, + dim_out = {din3}, pr = pr}, + param_tpye = {'N', 'N', 'D'}}, [ap("outputGateL")] = {{}, {dim_in = {din1, din2, din3}, - dim_out = {din3}, pr = pr}}, + dim_out = {din3}, pr = pr}, + param_type = {'N', 'N', 'D'}}, }, ["nerv.ElemMulLayer"] = { diff --git a/nerv/layer/lstm_gate.lua b/nerv/layer/lstm_gate.lua index 1963eba..8785b4f 100644 --- a/nerv/layer/lstm_gate.lua +++ b/nerv/layer/lstm_gate.lua @@ -5,12 +5,16 @@ function LSTMGateLayer:__init(id, global_conf, layer_conf) self.id = id self.dim_in = layer_conf.dim_in self.dim_out = layer_conf.dim_out + self.param_type = layer_conf.param_type self.gconf = global_conf for i = 1, #self.dim_in do self["ltp" .. i] = self:find_param("ltp" .. i, layer_conf, global_conf, nerv.LinearTransParam, {self.dim_in[i], self.dim_out[1]}) + if self.param_type[i] == 'D' then + self["ltp" .. i].trans:diagonalize() + end end self.bp = self:find_param("bp", layer_conf, global_conf, nerv.BiasParam, {1, self.dim_out[1]}) @@ -64,6 +68,9 @@ function LSTMGateLayer:update(bp_err, input, output) self.err_bakm:sigmoid_grad(bp_err[1], output[1]) for i = 1, #self.dim_in do self["ltp" .. i]:update_by_err_input(self.err_bakm, input[i]) + if self.param_type[i] == 'D' then + self["ltp" .. i].trans:diagonalize() + end end self.bp:update_by_gradient(self.err_bakm:colsum()) end diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu index 51e3b6a..311a6ce 100644 --- a/nerv/lib/matrix/generic/cukernel.cu +++ b/nerv/lib/matrix/generic/cukernel.cu @@ -250,6 +250,14 @@ __global__ void cudak_(fill)(MATRIX_ELEM *a, a[j + i * stride] = val; } +__global__ void cudak_(diagonalize)(MATRIX_ELEM *a, + int nrow, int ncol, int stride) { + int j = blockIdx.x * blockDim.x + threadIdx.x; + int i = blockIdx.y * blockDim.y + threadIdx.y; + if (i >= nrow || j >= ncol || i == j) return; + a[j + i * stride] = 0; +} + __global__ void cudak_(clip)(MATRIX_ELEM *a, int nrow, int ncol, int stride, double val_1, double val_2) { int j = blockIdx.x * blockDim.x + threadIdx.x; @@ -679,6 +687,16 @@ extern "C" { cudaStreamSynchronize(0); } + void cudak_(cuda_diagonalize)(Matrix *a) { + dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); + dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x), + CEIL_DIV(a->nrow, threadsPerBlock.y)); + cudak_(diagonalize)<<>> \ + (MATRIX_ELEM_PTR(a), a->nrow, a->ncol, + a->stride / sizeof(MATRIX_ELEM)); + cudaStreamSynchronize(0); + } + void cudak_(cuda_clip)(Matrix *a, double val_1, double val_2) { dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x), diff --git a/nerv/lib/matrix/generic/cumatrix.c b/nerv/lib/matrix/generic/cumatrix.c index 7b70607..1c74866 100644 --- a/nerv/lib/matrix/generic/cumatrix.c +++ b/nerv/lib/matrix/generic/cumatrix.c @@ -494,6 +494,15 @@ void nerv_matrix_(prefixsum_row)(Matrix *a, const Matrix *b, Status *status) { NERV_SET_STATUS(status, NERV_NORMAL, 0); } +void nerv_matrix_(diagonalize)(Matrix *a, Status *status) { + if (a->nrow != a->ncol) + NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0); + PROFILE_START + cudak_(cuda_diagonalize)(a); + PROFILE_STOP + NERV_SET_STATUS(status, NERV_NORMAL, 0); +} + static void cuda_matrix_(free)(MATRIX_ELEM *ptr, Status *status) { CUDA_SAFE_SYNC_CALL(cudaFree(ptr), status); NERV_SET_STATUS(status, NERV_NORMAL, 0); diff --git a/nerv/lib/matrix/generic/cumatrix.h b/nerv/lib/matrix/generic/cumatrix.h index f3c2df8..48d1f13 100644 --- a/nerv/lib/matrix/generic/cumatrix.h +++ b/nerv/lib/matrix/generic/cumatrix.h @@ -25,6 +25,7 @@ void nerv_matrix_(add_row)(Matrix *b, const Matrix *a, double beta, Status *status); void nerv_matrix_(clip)(Matrix *self, double val_1, double val_2, Status *status); void nerv_matrix_(fill)(Matrix *self, double val, Status *status); +void nerv_matrix_(diagonalize)(Matrix *self, Status *statut); void nerv_matrix_(copy_fromd)(Matrix *a, const Matrix *b, int a_begin, int b_begin, int b_end, Status *status); diff --git a/nerv/lib/matrix/generic/mmatrix.c b/nerv/lib/matrix/generic/mmatrix.c index fa1dc5f..3dabe0e 100644 --- a/nerv/lib/matrix/generic/mmatrix.c +++ b/nerv/lib/matrix/generic/mmatrix.c @@ -265,6 +265,22 @@ void nerv_matrix_(fill)(Matrix *self, double val, Status *status) { NERV_SET_STATUS(status, NERV_NORMAL, 0); } +void nerv_matrix_(diagonalize)(Matrix *self, Status *status) { + if (self->nrow != self->ncol) + NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0); + int i, j; + size_t astride = self->stride; + MATRIX_ELEM *arow = MATRIX_ELEM_PTR(self); + for (i = 0; i < self->nrow; i++) + { + for (j = 0; j < self->ncol; j++) + if (i != j) + arow[j] = 0; + arow = MATRIX_NEXT_ROW_PTR(arow, astride); + } + NERV_SET_STATUS(status, NERV_NORMAL, 0); +} + void nerv_matrix_(sigmoid)(Matrix *a, const Matrix *b, Status *status) { CHECK_SAME_DIMENSION(a, b, status); int i, j; diff --git a/nerv/lib/matrix/generic/mmatrix.h b/nerv/lib/matrix/generic/mmatrix.h index c54c4e5..2cbca47 100644 --- a/nerv/lib/matrix/generic/mmatrix.h +++ b/nerv/lib/matrix/generic/mmatrix.h @@ -23,6 +23,7 @@ void nerv_matrix_(add_row)(Matrix *b, const Matrix *a, double beta, Status *status); void nerv_matrix_(clip)(Matrix *self, double val_1, double val_2, Status *status); void nerv_matrix_(fill)(Matrix *self, double val, Status *status); +void nerv_matrix_(diagonalize)(Matrix *self, Status *status); void nerv_matrix_(copy_fromh)(Matrix *a, const Matrix *b, int a_begin, int b_begin, int b_end, Status *status); diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c index b706c21..f8b8038 100644 --- a/nerv/matrix/generic/cumatrix.c +++ b/nerv/matrix/generic/cumatrix.c @@ -240,6 +240,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"scale_rows_by_row", nerv_matrix_(lua_scale_rows_by_row)}, {"scale_rows_by_col", nerv_matrix_(lua_scale_rows_by_col)}, {"prefixsum_row", nerv_matrix_(lua_prefixsum_row)}, + {"diagonalize", nerv_matrix_(lua_diagonalize)}, #ifdef __NERV_FUTURE_CUDA_7 {"update_select_rows_by_rowidx", nerv_matrix_(lua_update_select_rows_by_rowidx)}, {"update_select_rows_by_colidx", nerv_matrix_(lua_update_select_rows_by_colidx)}, diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c index c1da774..3162ffb 100644 --- a/nerv/matrix/generic/matrix.c +++ b/nerv/matrix/generic/matrix.c @@ -338,4 +338,12 @@ static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) { return 0; } +static int nerv_matrix_(lua_diagonalize)(lua_State *L) { + Status status; + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + nerv_matrix_(diagonalize)(a, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + #endif diff --git a/nerv/matrix/generic/mmatrix.c b/nerv/matrix/generic/mmatrix.c index 93562d0..1665eff 100644 --- a/nerv/matrix/generic/mmatrix.c +++ b/nerv/matrix/generic/mmatrix.c @@ -107,6 +107,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"add_row", nerv_matrix_(lua_add_row)}, {"clip", nerv_matrix_(lua_clip)}, {"fill", nerv_matrix_(lua_fill)}, + {"diagonalize", nerv_matrix_(lua_diagonalize)}, {"sigmoid", nerv_matrix_(lua_sigmoid)}, {"sigmoid_grad", nerv_matrix_(lua_sigmoid_grad)}, {"softmax", nerv_matrix_(lua_softmax)}, -- cgit v1.2.3 From 9fb40644f5b7aa141ca23c93e6aa77c05a04c76c Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Wed, 24 Feb 2016 16:24:50 +0800 Subject: fix typo --- nerv/layer/lstm.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'nerv') diff --git a/nerv/layer/lstm.lua b/nerv/layer/lstm.lua index b0cfe08..a1833ef 100644 --- a/nerv/layer/lstm.lua +++ b/nerv/layer/lstm.lua @@ -53,7 +53,7 @@ function LSTMLayer:__init(id, global_conf, layer_conf) param_type = {'N', 'N', 'D'}}, [ap("inputGateL")] = {{}, {dim_in = {din1, din2, din3}, dim_out = {din3}, pr = pr}, - param_tpye = {'N', 'N', 'D'}}, + param_type = {'N', 'N', 'D'}}, [ap("outputGateL")] = {{}, {dim_in = {din1, din2, din3}, dim_out = {din3}, pr = pr}, param_type = {'N', 'N', 'D'}}, -- cgit v1.2.3 From 14c1997203e04838b1737716dc385e1aa08fe91f Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Wed, 24 Feb 2016 16:40:54 +0800 Subject: fix bug --- nerv/layer/lstm.lua | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'nerv') diff --git a/nerv/layer/lstm.lua b/nerv/layer/lstm.lua index a1833ef..d8eee71 100644 --- a/nerv/layer/lstm.lua +++ b/nerv/layer/lstm.lua @@ -49,14 +49,14 @@ function LSTMLayer:__init(id, global_conf, layer_conf) }, ["nerv.LSTMGateLayer"] = { [ap("forgetGateL")] = {{}, {dim_in = {din1, din2, din3}, - dim_out = {din3}, pr = pr}, - param_type = {'N', 'N', 'D'}}, + dim_out = {din3}, pr = pr, + param_type = {'N', 'N', 'D'}}}, [ap("inputGateL")] = {{}, {dim_in = {din1, din2, din3}, - dim_out = {din3}, pr = pr}, - param_type = {'N', 'N', 'D'}}, + dim_out = {din3}, pr = pr, + param_type = {'N', 'N', 'D'}}}, [ap("outputGateL")] = {{}, {dim_in = {din1, din2, din3}, - dim_out = {din3}, pr = pr}, - param_type = {'N', 'N', 'D'}}, + dim_out = {din3}, pr = pr, + param_type = {'N', 'N', 'D'}}}, }, ["nerv.ElemMulLayer"] = { -- cgit v1.2.3 From e2a9af061db485d4388902d738c9d8be3f94ab34 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Fri, 11 Mar 2016 20:11:00 +0800 Subject: add recipe and fix bugs --- nerv/Makefile | 2 +- nerv/examples/network_debug/config.lua | 62 +++++++++ nerv/examples/network_debug/main.lua | 45 ++++++ nerv/examples/network_debug/network.lua | 110 +++++++++++++++ nerv/examples/network_debug/reader.lua | 113 +++++++++++++++ nerv/examples/network_debug/select_linear.lua | 59 ++++++++ nerv/examples/network_debug/timer.lua | 33 +++++ nerv/examples/network_debug/tnn.lua | 136 ++++++++++++++++++ nerv/io/init.lua | 3 +- nerv/io/seq_buffer.lua | 0 nerv/layer/dropout.lua | 11 +- nerv/layer/graph.lua | 2 +- nerv/layer/lstm.lua | 191 +++++++++----------------- nerv/layer/rnn.lua | 20 +-- nerv/matrix/init.lua | 18 ++- 15 files changed, 662 insertions(+), 143 deletions(-) create mode 100644 nerv/examples/network_debug/config.lua create mode 100644 nerv/examples/network_debug/main.lua create mode 100644 nerv/examples/network_debug/network.lua create mode 100644 nerv/examples/network_debug/reader.lua create mode 100644 nerv/examples/network_debug/select_linear.lua create mode 100644 nerv/examples/network_debug/timer.lua create mode 100644 nerv/examples/network_debug/tnn.lua create mode 100644 nerv/io/seq_buffer.lua (limited to 'nerv') diff --git a/nerv/Makefile b/nerv/Makefile index 7921bd9..68465a1 100644 --- a/nerv/Makefile +++ b/nerv/Makefile @@ -44,7 +44,7 @@ LUA_LIBS := matrix/init.lua io/init.lua init.lua \ layer/elem_mul.lua layer/lstm.lua layer/lstm_gate.lua layer/dropout.lua layer/gru.lua \ layer/graph.lua layer/rnn.lua layer/duplicate.lua layer/identity.lua \ nn/init.lua nn/layer_repo.lua nn/param_repo.lua nn/network.lua \ - io/sgd_buffer.lua + io/sgd_buffer.lua io/seq_buffer.lua INCLUDE := -I $(LUA_INCDIR) -DLUA_USE_APICHECK CUDA_INCLUDE := -I $(CUDA_BASE)/include/ diff --git a/nerv/examples/network_debug/config.lua b/nerv/examples/network_debug/config.lua new file mode 100644 index 0000000..e20d5a9 --- /dev/null +++ b/nerv/examples/network_debug/config.lua @@ -0,0 +1,62 @@ +function get_global_conf() + local global_conf = { + lrate = 0.15, + wcost = 1e-5, + momentum = 0, + clip = 5, + cumat_type = nerv.CuMatrixFloat, + mmat_type = nerv.MMatrixFloat, + vocab_size = 10000, + nn_act_default = 0, + hidden_size = 300, + layer_num = 1, + chunk_size = 15, + batch_size = 20, + max_iter = 35, + param_random = function() return (math.random() / 5 - 0.1) end, + dropout_rate = 0.5, + timer = nerv.Timer(), + pr = nerv.ParamRepo(), + } + return global_conf +end + +function get_layers(global_conf) + local pr = global_conf.pr + local layers = { + ['nerv.LSTMLayer'] = {}, + ['nerv.DropoutLayer'] = {}, + ['nerv.SelectLinearLayer'] = { + ['select'] = {dim_in = {1}, dim_out = {global_conf.hidden_size}, vocab = global_conf.vocab_size, pr = pr}, + }, + ['nerv.AffineLayer'] = { + output = {dim_in = {global_conf.hidden_size}, dim_out = {global_conf.vocab_size}, pr = pr} + }, + ['nerv.SoftmaxCELayer'] = { + softmax = {dim_in = {global_conf.vocab_size, global_conf.vocab_size}, dim_out = {1}, compressed = true}, + }, + } + for i = 1, global_conf.layer_num do + layers['nerv.LSTMLayer']['lstm' .. i] = {dim_in = {global_conf.hidden_size}, dim_out = {global_conf.hidden_size}, pr = pr} + layers['nerv.DropoutLayer']['dropout' .. i] = {dim_in = {global_conf.hidden_size}, dim_out = {global_conf.hidden_size}} + end + return layers +end + +function get_connections(global_conf) + local connections = { + {'[1]', 'select[1]', 0}, + {'select[1]', 'lstm1[1]', 0}, + {'dropout' .. global_conf.layer_num .. '[1]', 'output[1]', 0}, + {'output[1]', 'softmax[1]', 0}, + {'[2]', 'softmax[2]', 0}, + {'softmax[1]', '[1]', 0}, + } + for i = 1, global_conf.layer_num do + table.insert(connections, {'lstm' .. i .. '[1]', 'dropout' .. i .. '[1]', 0}) + if i < 1 then + table.insert(connections, {'dropout' .. (i - 1) .. '[1]', 'lstm' .. i .. '[1]', 0}) + end + end + return connections +end diff --git a/nerv/examples/network_debug/main.lua b/nerv/examples/network_debug/main.lua new file mode 100644 index 0000000..790c404 --- /dev/null +++ b/nerv/examples/network_debug/main.lua @@ -0,0 +1,45 @@ +nerv.include('reader.lua') +nerv.include('timer.lua') +nerv.include('config.lua') +nerv.include(arg[1]) + +local global_conf = get_global_conf() +local timer = global_conf.timer + +timer:tic('IO') + +local data_path = 'examples/lmptb/PTBdata/' +local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.train.txt.adds') +local val_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds') + +local train_data = train_reader:get_all_batch(global_conf) +local val_data = val_reader:get_all_batch(global_conf) + +local layers = get_layers(global_conf) +local connections = get_connections(global_conf) + +local NN = nerv.NN(global_conf, train_data, val_data, layers, connections) + +timer:toc('IO') +timer:check('IO') +io.flush() + +timer:tic('global') +local best_cv = 1e10 +for i = 1, global_conf.max_iter do + timer:tic('Epoch' .. i) + local train_ppl, val_ppl = NN:epoch() + if val_ppl < best_cv then + best_cv = val_ppl + else + global_conf.lrate = global_conf.lrate / 2.0 + end + nerv.printf('Epoch %d: %f %f %f\n', i, global_conf.lrate, train_ppl, val_ppl) + timer:toc('Epoch' .. i) + timer:check('Epoch' .. i) + io.flush() +end +timer:toc('global') +timer:check('global') +timer:check('network') +timer:check('gc') diff --git a/nerv/examples/network_debug/network.lua b/nerv/examples/network_debug/network.lua new file mode 100644 index 0000000..5518e27 --- /dev/null +++ b/nerv/examples/network_debug/network.lua @@ -0,0 +1,110 @@ +nerv.include('select_linear.lua') + +local nn = nerv.class('nerv.NN') + +function nn:__init(global_conf, train_data, val_data, layers, connections) + self.gconf = global_conf + self.network = self:get_network(layers, connections) + self.train_data = self:get_data(train_data) + self.val_data = self:get_data(val_data) +end + +function nn:get_network(layers, connections) + local layer_repo = nerv.LayerRepo(layers, self.gconf.pr, self.gconf) + local graph = nerv.GraphLayer('graph', self.gconf, + {dim_in = {1, self.gconf.vocab_size}, dim_out = {1}, + layer_repo = layer_repo, connections = connections}) + local network = nerv.Network('network', self.gconf, + {network = graph, clip = self.gconf.clip}) + network:init(self.gconf.batch_size, self.gconf.chunk_size) + return network +end + +function nn:get_data(data) + local err_output = {} + local softmax_output = {} + local output = {} + for i = 1, self.gconf.chunk_size do + err_output[i] = self.gconf.cumat_type(self.gconf.batch_size, 1) + softmax_output[i] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.vocab_size) + output[i] = self.gconf.cumat_type(self.gconf.batch_size, 1) + end + local ret = {} + for i = 1, #data do + ret[i] = {} + ret[i].input = {} + ret[i].output = {} + ret[i].err_input = {} + ret[i].err_output = {} + for t = 1, self.gconf.chunk_size do + ret[i].input[t] = {} + ret[i].output[t] = {} + ret[i].err_input[t] = {} + ret[i].err_output[t] = {} + ret[i].input[t][1] = data[i].input[t] + ret[i].input[t][2] = data[i].output[t] + ret[i].output[t][1] = output[t] + local err_input = self.gconf.mmat_type(self.gconf.batch_size, 1) + for j = 1, self.gconf.batch_size do + if t <= data[i].seq_len[j] then + err_input[j - 1][0] = 1 + else + err_input[j - 1][0] = 0 + end + end + ret[i].err_input[t][1] = self.gconf.cumat_type.new_from_host(err_input) + ret[i].err_output[t][1] = err_output[t] + ret[i].err_output[t][2] = softmax_output[t] + end + ret[i].seq_length = data[i].seq_len + ret[i].new_seq = {} + for j = 1, self.gconf.batch_size do + if data[i].seq_start[j] then + table.insert(ret[i].new_seq, j) + end + end + end + return ret +end + +function nn:process(data, do_train) + local timer = self.gconf.timer + local total_err = 0 + local total_frame = 0 + self.network:epoch_init() + for id = 1, #data do + data[id].do_train = do_train + timer:tic('network') + self.network:mini_batch_init(data[id]) + self.network:propagate() + timer:toc('network') + for t = 1, self.gconf.chunk_size do + local tmp = data[id].output[t][1]:new_to_host() + for i = 1, self.gconf.batch_size do + if t <= data[id].seq_length[i] then + total_err = total_err + math.log10(math.exp(tmp[i - 1][0])) + total_frame = total_frame + 1 + end + end + end + if do_train then + timer:tic('network') + self.network:back_propagate() + self.network:update() + timer:toc('network') + end + timer:tic('gc') + collectgarbage('collect') + timer:toc('gc') + end + return math.pow(10, - total_err / total_frame) +end + +function nn:epoch() + local train_error = self:process(self.train_data, true) + local tmp = self.gconf.dropout_rate + self.gconf.dropout_rate = 0 + local val_error = self:process(self.val_data, false) + self.gconf.dropout_rate = tmp + return train_error, val_error +end diff --git a/nerv/examples/network_debug/reader.lua b/nerv/examples/network_debug/reader.lua new file mode 100644 index 0000000..d2624d3 --- /dev/null +++ b/nerv/examples/network_debug/reader.lua @@ -0,0 +1,113 @@ +local Reader = nerv.class('nerv.Reader') + +function Reader:__init(vocab_file, input_file) + self:get_vocab(vocab_file) + self:get_seq(input_file) +end + +function Reader:get_vocab(vocab_file) + local f = io.open(vocab_file, 'r') + local id = 0 + self.vocab = {} + while true do + local word = f:read() + if word == nil then + break + end + self.vocab[word] = id + id = id + 1 + end + self.size = id +end + +function Reader:split(s, t) + local ret = {} + for x in (s .. t):gmatch('(.-)' .. t) do + table.insert(ret, x) + end + return ret +end + +function Reader:get_seq(input_file) + local f = io.open(input_file, 'r') + self.seq = {} + while true do + local seq = f:read() + if seq == nil then + break + end + seq = self:split(seq, ' ') + local tmp = {} + for i = 1, #seq do + if seq[i] ~= '' then + table.insert(tmp, self.vocab[seq[i]]) + end + end + table.insert(self.seq, tmp) + end +end + +function Reader:get_in_out(id, pos) + return self.seq[id][pos], self.seq[id][pos + 1], pos + 1 == #self.seq[id] +end + +function Reader:get_all_batch(global_conf) + local data = {} + local pos = {} + local offset = 1 + for i = 1, global_conf.batch_size do + pos[i] = nil + end + while true do + --for i = 1, 100 do + local input = {} + local output = {} + for i = 1, global_conf.chunk_size do + input[i] = global_conf.mmat_type(global_conf.batch_size, 1) + input[i]:fill(global_conf.nn_act_default) + output[i] = global_conf.mmat_type(global_conf.batch_size, 1) + output[i]:fill(global_conf.nn_act_default) + end + local seq_start = {} + local seq_end = {} + local seq_len = {} + for i = 1, global_conf.batch_size do + seq_start[i] = false + seq_end[i] = false + seq_len[i] = 0 + end + local has_new = false + for i = 1, global_conf.batch_size do + if pos[i] == nil then + if offset < #self.seq then + seq_start[i] = true + pos[i] = {offset, 1} + offset = offset + 1 + end + end + if pos[i] ~= nil then + has_new = true + for j = 1, global_conf.chunk_size do + local final + input[j][i-1][0], output[j][i-1][0], final = self:get_in_out(pos[i][1], pos[i][2]) + seq_len[i] = j + if final then + seq_end[i] = true + pos[i] = nil + break + end + pos[i][2] = pos[i][2] + 1 + end + end + end + if not has_new then + break + end + for i = 1, global_conf.chunk_size do + input[i] = global_conf.cumat_type.new_from_host(input[i]) + output[i] = global_conf.cumat_type.new_from_host(output[i]) + end + table.insert(data, {input = input, output = output, seq_start = seq_start, seq_end = seq_end, seq_len = seq_len}) + end + return data +end diff --git a/nerv/examples/network_debug/select_linear.lua b/nerv/examples/network_debug/select_linear.lua new file mode 100644 index 0000000..91beedf --- /dev/null +++ b/nerv/examples/network_debug/select_linear.lua @@ -0,0 +1,59 @@ +local SL = nerv.class('nerv.SelectLinearLayer', 'nerv.Layer') + +--id: string +--global_conf: table +--layer_conf: table +--Get Parameters +function SL:__init(id, global_conf, layer_conf) + nerv.Layer.__init(self, id, global_conf, layer_conf) + + self.vocab = layer_conf.vocab + self.ltp = self:find_param("ltp", layer_conf, global_conf, nerv.LinearTransParam, {self.vocab, self.dim_out[1]}) --layer_conf.ltp + + self:check_dim_len(1, 1) +end + +--Check parameter +function SL:init(batch_size) + if (self.dim_in[1] ~= 1) then --one word id + nerv.error("mismatching dimensions of ltp and input") + end + if (self.dim_out[1] ~= self.ltp.trans:ncol()) then + nerv.error("mismatching dimensions of bp and output") + end + + self.batch_size = bath_size + self.ltp:train_init() +end + +function SL:update(bp_err, input, output) + --use this to produce reproducable result, don't forget to set the dropout to zero! + --for i = 1, input[1]:nrow(), 1 do + -- local word_vec = self.ltp.trans[input[1][i - 1][0]] + -- word_vec:add(word_vec, bp_err[1][i - 1], 1, - self.gconf.lrate / self.gconf.batch_size) + --end + + --I tried the update_select_rows kernel which uses atomicAdd, but it generates unreproducable result + self.ltp.trans:update_select_rows_by_colidx(bp_err[1], input[1], - self.gconf.lrate / self.gconf.batch_size, 0) + self.ltp.trans:add(self.ltp.trans, self.ltp.trans, 1.0, - self.gconf.lrate * self.gconf.wcost) +end + +function SL:propagate(input, output) + --for i = 0, input[1]:ncol() - 1, 1 do + -- if (input[1][0][i] > 0) then + -- output[1][i]:copy_fromd(self.ltp.trans[input[1][0][i]]) + -- else + -- output[1][i]:fill(0) + -- end + --end + output[1]:copy_rows_fromd_by_colidx(self.ltp.trans, input[1]) +end + +function SL:back_propagate(bp_err, next_bp_err, input, output) + --input is compressed, do nothing +end + +function SL:get_params() + local paramRepo = nerv.ParamRepo({self.ltp}) + return paramRepo +end diff --git a/nerv/examples/network_debug/timer.lua b/nerv/examples/network_debug/timer.lua new file mode 100644 index 0000000..2c54ca8 --- /dev/null +++ b/nerv/examples/network_debug/timer.lua @@ -0,0 +1,33 @@ +local Timer = nerv.class("nerv.Timer") + +function Timer:__init() + self.last = {} + self.rec = {} +end + +function Timer:tic(item) + self.last[item] = os.clock() +end + +function Timer:toc(item) + if (self.last[item] == nil) then + nerv.error("item not there") + end + if (self.rec[item] == nil) then + self.rec[item] = 0 + end + self.rec[item] = self.rec[item] + os.clock() - self.last[item] +end + +function Timer:check(item) + if self.rec[item]==nil then + nerv.error('item not there') + end + nerv.printf('"%s" lasts for %f secs.\n',item,self.rec[item]) +end + +function Timer:flush() + for key, value in pairs(self.rec) do + self.rec[key] = nil + end +end diff --git a/nerv/examples/network_debug/tnn.lua b/nerv/examples/network_debug/tnn.lua new file mode 100644 index 0000000..bf9f118 --- /dev/null +++ b/nerv/examples/network_debug/tnn.lua @@ -0,0 +1,136 @@ +nerv.include('select_linear.lua') + +local reader = nerv.class('nerv.TNNReader') + +function reader:__init(global_conf, data) + self.gconf = global_conf + self.offset = 0 + self.data = data +end + +function reader:get_batch(feeds) + self.offset = self.offset + 1 + if self.offset > #self.data then + return false + end + for i = 1, self.gconf.chunk_size do + feeds.inputs_m[i][1]:copy_from(self.data[self.offset].input[i]) + feeds.inputs_m[i][2]:copy_from(self.data[self.offset].output[i]:decompress(self.gconf.vocab_size)) + end + feeds.flags_now = self.data[self.offset].flags + feeds.flagsPack_now = self.data[self.offset].flagsPack + return true +end + +function reader:has_data(t, i) + return t <= self.data[self.offset].seq_len[i] +end + +function reader:get_err_input() + return self.data[self.offset].err_input +end + +local nn = nerv.class('nerv.NN') + +function nn:__init(global_conf, train_data, val_data, layers, connections) + self.gconf = global_conf + self.tnn = self:get_tnn(layers, connections) + self.train_data = self:get_data(train_data) + self.val_data = self:get_data(val_data) +end + +function nn:get_tnn(layers, connections) + self.gconf.dropout_rate = 0 + local layer_repo = nerv.LayerRepo(layers, self.gconf.pr, self.gconf) + local tnn = nerv.TNN('TNN', self.gconf, {dim_in = {1, self.gconf.vocab_size}, + dim_out = {1}, sub_layers = layer_repo, connections = connections, + clip = self.gconf.clip}) + tnn:init(self.gconf.batch_size, self.gconf.chunk_size) + return tnn +end + +function nn:get_data(data) + local ret = {} + for i = 1, #data do + ret[i] = {} + ret[i].input = data[i].input + ret[i].output = data[i].output + ret[i].flags = {} + ret[i].err_input = {} + for t = 1, self.gconf.chunk_size do + ret[i].flags[t] = {} + local err_input = self.gconf.mmat_type(self.gconf.batch_size, 1) + for j = 1, self.gconf.batch_size do + if t <= data[i].seq_len[j] then + ret[i].flags[t][j] = nerv.TNN.FC.SEQ_NORM + err_input[j - 1][0] = 1 + else + ret[i].flags[t][j] = 0 + err_input[j - 1][0] = 0 + end + end + ret[i].err_input[t] = self.gconf.cumat_type.new_from_host(err_input) + end + for j = 1, self.gconf.batch_size do + if data[i].seq_start[j] then + ret[i].flags[1][j] = bit.bor(ret[i].flags[1][j], nerv.TNN.FC.SEQ_START) + end + if data[i].seq_end[j] then + local t = data[i].seq_len[j] + ret[i].flags[t][j] = bit.bor(ret[i].flags[t][j], nerv.TNN.FC.SEQ_END) + end + end + ret[i].flagsPack = {} + for t = 1, self.gconf.chunk_size do + ret[i].flagsPack[t] = 0 + for j = 1, self.gconf.batch_size do + ret[i].flagsPack[t] = bit.bor(ret[i].flagsPack[t], ret[i].flags[t][j]) + end + end + ret[i].seq_len = data[i].seq_len + end + return ret +end + +function nn:process(data, do_train) + local total_err = 0 + local total_frame = 0 + local reader = nerv.TNNReader(self.gconf, data) + while true do + local r, _ = self.tnn:getfeed_from_reader(reader) + if not r then + break + end + if do_train then + self.gconf.dropout_rate = self.gconf.dropout + else + self.gconf.dropout_rate = 0 + end + self.tnn:net_propagate() + for t = 1, self.gconf.chunk_size do + local tmp = self.tnn.outputs_m[t][1]:new_to_host() + for i = 1, self.gconf.batch_size do + if reader:has_data(t, i) then + total_err = total_err + math.log10(math.exp(tmp[i - 1][0])) + total_frame = total_frame + 1 + end + end + end + if do_train then + local err_input = reader:get_err_input() + for i = 1, self.gconf.chunk_size do + self.tnn.err_inputs_m[i][1]:copy_from(err_input[i]) + end + self.tnn:net_backpropagate(false) + self.tnn:net_backpropagate(true) + end + collectgarbage('collect') + end + return math.pow(10, - total_err / total_frame) +end + +function nn:epoch() + local train_error = self:process(self.train_data, true) + local val_error = self:process(self.val_data, false) + return train_error, val_error +end diff --git a/nerv/io/init.lua b/nerv/io/init.lua index eb2e3e5..c36d850 100644 --- a/nerv/io/init.lua +++ b/nerv/io/init.lua @@ -52,8 +52,9 @@ function DataBuffer:__init(global_conf, buffer_conf) nerv.error_method_not_implemented() end -function DataBuffer:get_batch() +function DataBuffer:get_data() nerv.error_method_not_implemented() end nerv.include('sgd_buffer.lua') +nerv.include('seq_buffer.lua') diff --git a/nerv/io/seq_buffer.lua b/nerv/io/seq_buffer.lua new file mode 100644 index 0000000..e69de29 diff --git a/nerv/layer/dropout.lua b/nerv/layer/dropout.lua index 1a379c9..39a8963 100644 --- a/nerv/layer/dropout.lua +++ b/nerv/layer/dropout.lua @@ -2,8 +2,7 @@ local DropoutLayer = nerv.class("nerv.DropoutLayer", "nerv.Layer") function DropoutLayer:__init(id, global_conf, layer_conf) nerv.Layer.__init(self, id, global_conf, layer_conf) - self.rate = layer_conf.dropout_rate or global_conf.dropout_rate - if self.rate == nil then + if self.gconf.dropout_rate == nil then nerv.warning("[DropoutLayer:propagate] dropout rate is not set") end self:check_dim_len(1, 1) -- two inputs: nn output and label @@ -41,12 +40,12 @@ function DropoutLayer:propagate(input, output, t) if t == nil then t = 1 end - if self.rate then + if self.gconf.dropout_rate ~= 0 then self.mask[t]:rand_uniform() -- since we will lose a portion of the actvations, we multiply the -- activations by 1 / (1 - rate) to compensate - self.mask[t]:thres_mask(self.mask[t], self.rate, - 0, 1 / (1.0 - self.rate)) + self.mask[t]:thres_mask(self.mask[t], self.gconf.dropout_rate, + 0, 1 / (1.0 - self.gconf.dropout_rate)) output[1]:mul_elem(input[1], self.mask[t]) else output[1]:copy_fromd(input[1]) @@ -61,7 +60,7 @@ function DropoutLayer:back_propagate(bp_err, next_bp_err, input, output, t) if t == nil then t = 1 end - if self.rate then + if self.gconf.dropout_rate then next_bp_err[1]:mul_elem(bp_err[1], self.mask[t]) else next_bp_err[1]:copy_fromd(bp_err[1]) diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua index 5f42fca..68d5f51 100644 --- a/nerv/layer/graph.lua +++ b/nerv/layer/graph.lua @@ -112,7 +112,7 @@ function GraphLayer:graph_init(layer_repo, connections) end for i = 1, #ref.dim_out do if ref.outputs[i] == nil then - nerv.error('dangling output port %d os layer %s', i, id) + nerv.error('dangling output port %d of layer %s', i, id) end end end diff --git a/nerv/layer/lstm.lua b/nerv/layer/lstm.lua index 641d5dc..5dbcc20 100644 --- a/nerv/layer/lstm.lua +++ b/nerv/layer/lstm.lua @@ -1,144 +1,85 @@ -local LSTMLayer = nerv.class('nerv.LSTMLayer', 'nerv.Layer') +local LSTMLayer = nerv.class('nerv.LSTMLayer', 'nerv.GraphLayer') function LSTMLayer:__init(id, global_conf, layer_conf) - -- input1:x - -- input2:h - -- input3:c nerv.Layer.__init(self, id, global_conf, layer_conf) - -- prepare a DAGLayer to hold the lstm structure + self:check_dim_len(1, 1) + + local din = layer_conf.dim_in[1] + local dout = layer_conf.dim_out[1] + local pr = layer_conf.pr if pr == nil then pr = nerv.ParamRepo({}, self.loc_type) end - - local function ap(str) - return self.id .. '.' .. str - end - local din1, din2, din3 = self.dim_in[1], self.dim_in[2], self.dim_in[3] - local dout1, dout2, dout3 = self.dim_out[1], self.dim_out[2], self.dim_out[3] - local layers = { - ["nerv.CombinerLayer"] = { - [ap("inputXDup")] = {dim_in = {din1}, - dim_out = {din1, din1, din1, din1}, - lambda = {1}}, - [ap("inputHDup")] = {dim_in = {din2}, - dim_out = {din2, din2, din2, din2}, - lambda = {1}}, - - [ap("inputCDup")] = {dim_in = {din3}, - dim_out = {din3, din3, din3}, - lambda = {1}}, - - [ap("mainCDup")] = {dim_in = {din3, din3}, - dim_out = {din3, din3, din3}, - lambda = {1, 1}}, + local layers = { + ['nerv.CombinerLayer'] = { + mainCombine = {dim_in = {dout, dout}, dim_out = {dout}, lambda = {1, 1}}, }, - ["nerv.AffineLayer"] = { - [ap("mainAffineL")] = {dim_in = {din1, din2}, - dim_out = {dout1}, - pr = pr}, + ['nerv.DuplicateLayer'] = { + inputDup = {dim_in = {din}, dim_out = {din, din, din, din}}, + outputDup = {dim_in = {dout}, dim_out = {dout, dout, dout, dout, dout}}, + cellDup = {dim_in = {dout}, dim_out = {dout, dout, dout, dout, dout}}, }, - ["nerv.TanhLayer"] = { - [ap("mainTanhL")] = {dim_in = {dout1}, dim_out = {dout1}}, - [ap("outputTanhL")] = {dim_in = {dout1}, dim_out = {dout1}}, + ['nerv.AffineLayer'] = { + mainAffine = {dim_in = {din, dout}, dim_out = {dout}, pr = pr}, }, - ["nerv.LSTMGateLayer"] = { - [ap("forgetGateL")] = {dim_in = {din1, din2, din3}, - dim_out = {din3}, pr = pr}, - [ap("inputGateL")] = {dim_in = {din1, din2, din3}, - dim_out = {din3}, pr = pr}, - [ap("outputGateL")] = {dim_in = {din1, din2, din3}, - dim_out = {din3}, pr = pr}, - + ['nerv.TanhLayer'] = { + mainTanh = {dim_in = {dout}, dim_out = {dout}}, + outputTanh = {dim_in = {dout}, dim_out = {dout}}, }, - ["nerv.ElemMulLayer"] = { - [ap("inputGMulL")] = {dim_in = {din3, din3}, - dim_out = {din3}}, - [ap("forgetGMulL")] = {dim_in = {din3, din3}, - dim_out = {din3}}, - [ap("outputGMulL")] = {dim_in = {din3, din3}, - dim_out = {din3}}, + ['nerv.LSTMGateLayer'] = { + forgetGate = {dim_in = {din, dout, dout}, dim_out = {dout}, pr = pr}, + inputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, pr = pr}, + outputGate = {dim_in = {din, dout, dout}, dim_out = {dout}, pr = pr}, + }, + ['nerv.ElemMulLayer'] = { + inputGateMul = {dim_in = {dout, dout}, dim_out = {dout}}, + forgetGateMul = {dim_in = {dout, dout}, dim_out = {dout}}, + outputGateMul = {dim_in = {dout, dout}, dim_out = {dout}}, }, } - self.lrepo = nerv.LayerRepo(layers, pr, global_conf) - local connections = { - ["[1]"] = ap("inputXDup[1]"), - ["[2]"] = ap("inputHDup[1]"), - ["[3]"] = ap("inputCDup[1]"), - - [ap("inputXDup[1]")] = ap("mainAffineL[1]"), - [ap("inputHDup[1]")] = ap("mainAffineL[2]"), - [ap("mainAffineL[1]")] = ap("mainTanhL[1]"), - - [ap("inputXDup[2]")] = ap("inputGateL[1]"), - [ap("inputHDup[2]")] = ap("inputGateL[2]"), - [ap("inputCDup[1]")] = ap("inputGateL[3]"), - - [ap("inputXDup[3]")] = ap("forgetGateL[1]"), - [ap("inputHDup[3]")] = ap("forgetGateL[2]"), - [ap("inputCDup[2]")] = ap("forgetGateL[3]"), - - [ap("mainTanhL[1]")] = ap("inputGMulL[1]"), - [ap("inputGateL[1]")] = ap("inputGMulL[2]"), - - [ap("inputCDup[3]")] = ap("forgetGMulL[1]"), - [ap("forgetGateL[1]")] = ap("forgetGMulL[2]"), - - [ap("inputGMulL[1]")] = ap("mainCDup[1]"), - [ap("forgetGMulL[1]")] = ap("mainCDup[2]"), - - [ap("inputXDup[4]")] = ap("outputGateL[1]"), - [ap("inputHDup[4]")] = ap("outputGateL[2]"), - [ap("mainCDup[3]")] = ap("outputGateL[3]"), - - [ap("mainCDup[2]")] = "[2]", - [ap("mainCDup[1]")] = ap("outputTanhL[1]"), - - [ap("outputTanhL[1]")] = ap("outputGMulL[1]"), - [ap("outputGateL[1]")] = ap("outputGMulL[2]"), - - [ap("outputGMulL[1]")] = "[1]", + -- lstm input + {'[1]', 'inputDup[1]', 0}, + + -- input gate + {'inputDup[1]', 'inputGate[1]', 0}, + {'outputDup[1]', 'inputGate[2]', 1}, + {'cellDup[1]', 'inputGate[3]', 1}, + + -- forget gate + {'inputDup[2]', 'forgetGate[1]', 0}, + {'outputDup[2]', 'forgetGate[2]', 1}, + {'cellDup[2]', 'forgetGate[3]', 1}, + + -- lstm cell + {'forgetGate[1]', 'forgetGateMul[1]', 0}, + {'cellDup[3]', 'forgetGateMul[2]', 1}, + {'inputDup[3]', 'mainAffine[1]', 0}, + {'outputDup[3]', 'mainAffine[2]', 1}, + {'mainAffine[1]', 'mainTanh[1]', 0}, + {'inputGate[1]', 'inputGateMul[1]', 0}, + {'mainTanh[1]', 'inputGateMul[2]', 0}, + {'inputGateMul[1]', 'mainCombine[1]', 0}, + {'forgetGateMul[1]', 'mainCombine[2]', 0}, + {'mainCombine[1]', 'cellDup[1]', 0}, + + -- forget gate + {'inputDup[4]', 'outputGate[1]', 0}, + {'outputDup[4]', 'outputGate[2]', 1}, + {'cellDup[4]', 'outputGate[3]', 0}, + + -- lstm output + {'cellDup[5]', 'outputTanh[1]', 0}, + {'outputGate[1]', 'outputGateMul[1]', 0}, + {'outputTanh[1]', 'outputGateMul[2]', 0}, + {'outputGateMul[1]', 'outputDup[1]', 0}, + {'outputDup[5]', '[1]', 0}, } - self.dag = nerv.DAGLayer(self.id, global_conf, - {dim_in = self.dim_in, - dim_out = self.dim_out, - sub_layers = self.lrepo, - connections = connections}) - - self:check_dim_len(3, 2) -- x, h, c and h, c -end - -function LSTMLayer:bind_params() - local pr = layer_conf.pr - if pr == nil then - pr = nerv.ParamRepo({}, self.loc_type) - end - self.lrepo:rebind(pr) -end - -function LSTMLayer:init(batch_size, chunk_size) - self.dag:init(batch_size, chunk_size) -end - -function LSTMLayer:batch_resize(batch_size, chunk_size) - self.dag:batch_resize(batch_size, chunk_size) -end - -function LSTMLayer:update(bp_err, input, output, t) - self.dag:update(bp_err, input, output, t) -end - -function LSTMLayer:propagate(input, output, t) - self.dag:propagate(input, output, t) -end - -function LSTMLayer:back_propagate(bp_err, next_bp_err, input, output, t) - self.dag:back_propagate(bp_err, next_bp_err, input, output, t) -end -function LSTMLayer:get_params() - return self.dag:get_params() + self:add_prefix(layers, connections) + local layer_repo = nerv.LayerRepo(layers, pr, global_conf) + self:graph_init(layer_repo, connections) end diff --git a/nerv/layer/rnn.lua b/nerv/layer/rnn.lua index e59cf5b..0b5ccaa 100644 --- a/nerv/layer/rnn.lua +++ b/nerv/layer/rnn.lua @@ -4,6 +4,10 @@ function RNNLayer:__init(id, global_conf, layer_conf) nerv.Layer.__init(self, id, global_conf, layer_conf) self:check_dim_len(1, 1) + if layer_conf.activation == nil then + layer_conf.activation = 'nerv.SigmoidLayer' + end + local din = layer_conf.dim_in[1] local dout = layer_conf.dim_out[1] @@ -16,20 +20,20 @@ function RNNLayer:__init(id, global_conf, layer_conf) ['nerv.AffineLayer'] = { main = {dim_in = {din, dout}, dim_out = {dout}, pr = pr}, }, - ['nerv.SigmoidLayer'] = { - sigmoid = {dim_in = {dout}, dim_out = {dout}}, + [layers.activation] = { + activation = {dim_in = {dout}, dim_out = {dout}}, }, ['nerv.DuplicateLayer'] = { - dup = {dim_in = {dout}, dim_out = {dout, dout}}, - } + duplicate = {dim_in = {dout}, dim_out = {dout, dout}}, + }, } local connections = { {'[1]', 'main[1]', 0}, - {'main[1]', 'sigmoid[1]', 0}, - {'sigmoid[1]', 'dup[1]', 0}, - {'dup[1]', 'main[2]', 1}, - {'dup[2]', '[1]', 0}, + {'main[1]', 'activation[1]', 0}, + {'activation[1]', 'duplicate[1]', 0}, + {'duplicate[1]', 'main[2]', 1}, + {'duplicate[2]', '[1]', 0}, } self:add_prefix(layers, connections) diff --git a/nerv/matrix/init.lua b/nerv/matrix/init.lua index cf85004..722c780 100644 --- a/nerv/matrix/init.lua +++ b/nerv/matrix/init.lua @@ -40,7 +40,8 @@ end --- Assign each element in a matrix using the value returned by a callback `gen`. -- @param gen the callback used to generated the values in the matrix, to which -- the indices of row and column will be passed (e.g., `gen(i, j)`) -function nerv.Matrix:generate(gen) + +function nerv.Matrix:_generate(gen) if (self:dim() == 2) then for i = 0, self:nrow() - 1 do local row = self[i] @@ -55,6 +56,21 @@ function nerv.Matrix:generate(gen) end end +function nerv.Matrix:generate(gen) + local tmp + if nerv.is_type(self, 'nerv.CuMatrixFloat') then + tmp = nerv.MMatrixFloat(self:nrow(), self:ncol()) + elseif nerv.is_type(self, 'nerv.CuMatrixDouble') then + tmp = nerv.MMatrixDouble(self:nrow(), self:ncol()) + else + tmp = self + end + tmp:_generate(gen) + if nerv.is_type(self, 'nerv.CuMatrix') then + self:copy_fromh(tmp) + end +end + --- Create a fresh new matrix of the same matrix type (as `self`). -- @param nrow optional, the number of rows in the created matrix if specified, -- otherwise `self:nrow()` will be used -- cgit v1.2.3 From 48e209f519e528c298e3471362451d6b0485abb8 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Fri, 11 Mar 2016 21:41:37 +0800 Subject: fix bug --- nerv/examples/network_debug/reader.lua | 2 +- nerv/lib/matrix/generic/mmatrix.c | 2 +- nerv/matrix/generic/matrix.c | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) (limited to 'nerv') diff --git a/nerv/examples/network_debug/reader.lua b/nerv/examples/network_debug/reader.lua index d2624d3..b10baaf 100644 --- a/nerv/examples/network_debug/reader.lua +++ b/nerv/examples/network_debug/reader.lua @@ -59,7 +59,7 @@ function Reader:get_all_batch(global_conf) pos[i] = nil end while true do - --for i = 1, 100 do + -- for i = 1, 26 do local input = {} local output = {} for i = 1, global_conf.chunk_size do diff --git a/nerv/lib/matrix/generic/mmatrix.c b/nerv/lib/matrix/generic/mmatrix.c index 6272cbe..badddbd 100644 --- a/nerv/lib/matrix/generic/mmatrix.c +++ b/nerv/lib/matrix/generic/mmatrix.c @@ -274,7 +274,7 @@ void nerv_matrix_(fill)(Matrix *self, double val, NERV_SET_STATUS(status, NERV_NORMAL, 0); } -void nerv_matrix_(diagonalize)(Matrix *selfa, +void nerv_matrix_(diagonalize)(Matrix *self, MContext *context, Status *status) { if (self->nrow != self->ncol) NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0); diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c index b544dd9..fe07585 100644 --- a/nerv/matrix/generic/matrix.c +++ b/nerv/matrix/generic/matrix.c @@ -387,8 +387,10 @@ static int nerv_matrix_(lua_scale_rows_by_row)(lua_State *L) { static int nerv_matrix_(lua_diagonalize)(lua_State *L) { Status status; + MATRIX_CONTEXT *context; + MATRIX_GET_CONTEXT(L, 2); Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); - nerv_matrix_(diagonalize)(a, &status); + nerv_matrix_(diagonalize)(a, context, &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } -- cgit v1.2.3 From 2660af7f6a9ac243a8ad38bf3375ef0fd292bf52 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Fri, 11 Mar 2016 21:45:49 +0800 Subject: fix bug of dropout --- nerv/layer/dropout.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'nerv') diff --git a/nerv/layer/dropout.lua b/nerv/layer/dropout.lua index 39a8963..de0fb64 100644 --- a/nerv/layer/dropout.lua +++ b/nerv/layer/dropout.lua @@ -40,7 +40,7 @@ function DropoutLayer:propagate(input, output, t) if t == nil then t = 1 end - if self.gconf.dropout_rate ~= 0 then + if self.gconf.dropout_rate then self.mask[t]:rand_uniform() -- since we will lose a portion of the actvations, we multiply the -- activations by 1 / (1 - rate) to compensate -- cgit v1.2.3