From 5761e68ec1b73ed867443fb2687739395f22f2f9 Mon Sep 17 00:00:00 2001 From: txh18 Date: Wed, 11 Nov 2015 10:41:14 +0800 Subject: got good result when batch_size=1, strange! --- nerv/examples/lmptb/m-tests/tnn_test.lua | 12 +++++++++--- nerv/examples/lmptb/rnn/tnn.lua | 5 +++-- nerv/layer/affine.lua | 9 ++++++++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/nerv/examples/lmptb/m-tests/tnn_test.lua b/nerv/examples/lmptb/m-tests/tnn_test.lua index 40e332c..a2c38f0 100644 --- a/nerv/examples/lmptb/m-tests/tnn_test.lua +++ b/nerv/examples/lmptb/m-tests/tnn_test.lua @@ -155,6 +155,9 @@ function lm_process_file(global_conf, fn, tnn, do_train) local next_log_wcn = global_conf.log_w_num + global_conf.fz = 0 + global_conf.fz2 = 0 + while (1) do local r, feeds @@ -198,7 +201,7 @@ function lm_process_file(global_conf, fn, tnn, do_train) next_log_wcn = next_log_wcn + global_conf.log_w_num printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date()) printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn")) - nerv.LMUtil.wait(1) + nerv.LMUtil.wait(0.1) end --[[ @@ -213,6 +216,9 @@ function lm_process_file(global_conf, fn, tnn, do_train) --break --debug end + + print("gconf.fz", global_conf.fz) + print("gconf.fz2", global_conf.fz2) printf("%s Displaying result:\n", global_conf.sche_log_pre) printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn")) @@ -232,14 +238,14 @@ valid_fn = data_dir .. '/ptb.valid.txt.adds' test_fn = data_dir .. '/ptb.test.txt.adds' global_conf = { - lrate = 1, wcost = 1e-6, momentum = 0, + lrate = 0.1, wcost = 1e-6, momentum = 0, cumat_type = nerv.CuMatrixFloat, mmat_type = nerv.MMatrixFloat, nn_act_default = 0, hidden_size = 200, chunk_size = 15, - batch_size = 10, + batch_size = 1, max_iter = 25, param_random = function() return (math.random() / 5 - 0.1) end, diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua index 10b501e..fc5321d 100644 --- a/nerv/examples/lmptb/rnn/tnn.lua +++ b/nerv/examples/lmptb/rnn/tnn.lua @@ -386,6 +386,7 @@ function TNN:propagate_dfs(ref, t) end ]]-- ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t], t) --propagate! + if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history for i = 1, self.batch_size do local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START) @@ -393,13 +394,13 @@ function TNN:propagate_dfs(ref, t) if (seq_start > 0 or seq_end > 0) then for p, conn in pairs(ref.o_conns_p) do if ((ref.o_conns_p[p].time > 0 and seq_end > 0) or (ref.o_conns_p[p].time < 0 and seq_start > 0)) then + self.gconf.fz2 = self.gconf.fz2 + 1 ref.outputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default) end end end end end - --set input flag for future layers for i = 1, #ref.dim_out do if (ref.outputs_b[t][i] == true) then @@ -501,13 +502,13 @@ function TNN:backpropagate_dfs(ref, t, do_update) if (seq_start > 0 or seq_end > 0) then for p, conn in pairs(ref.i_conns_p) do if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to zero + self.gconf.fz = self.gconf.fz + 1 ref.err_outputs_m[t][p][i - 1]:fill(0) end end end end end - for i = 1, #ref.dim_in do if (ref.err_outputs_b[t][i] == true) then diff --git a/nerv/layer/affine.lua b/nerv/layer/affine.lua index 015ec3f..0462383 100644 --- a/nerv/layer/affine.lua +++ b/nerv/layer/affine.lua @@ -31,7 +31,14 @@ function LinearTransParam:update(gradient) MatrixParam.update(self, gradient) local gconf = self.gconf -- weight decay - self.trans:add(self.trans, self.trans, 1.0, -gconf.lrate * gconf.wcost) + self.trans:add(self.trans, self.trans, 1.0, -gconf.lrate * gconf.wcost / gconf.batch_size) +end + +function BiasParam:update(gradient) + MatrixParam.update(self, gradient) + local gconf = self.gconf + -- weight decay + self.trans:add(self.trans, self.trans, 1.0, -gconf.lrate * gconf.wcost / gconf.batch_size) end function AffineLayer:__init(id, global_conf, layer_conf) -- cgit v1.2.3-70-g09d2