-rw-r--r--	nerv/examples/lmptb/m-tests/tnn_test.lua	13
-rw-r--r--	nerv/examples/lmptb/rnn/tnn.lua	6
2 files changed, 14 insertions, 5 deletions
diff --git a/nerv/examples/lmptb/m-tests/tnn_test.lua b/nerv/examples/lmptb/m-tests/tnn_test.lua
index a778dea..7a8519e 100644
--- a/nerv/examples/lmptb/m-tests/tnn_test.lua
+++ b/nerv/examples/lmptb/m-tests/tnn_test.lua
@@ -82,7 +82,7 @@ function prepare_layers(global_conf, paramRepo)
         ["outputL"] = {{["ltp"] = "ltp_ho", ["bp"] = "bp_o"}, {["dim_in"] = {global_conf.hidden_size}, ["dim_out"] = {global_conf.vocab:size()}}},
     },
-    ["nerv.SoftmaxCELayer"] = {
+    ["nerv.SoftmaxCELayerT"] = {
         ["softmaxL"] = {{}, {["dim_in"] = {global_conf.vocab:size(), global_conf.vocab:size()}, ["dim_out"] = {1}}},
     },
 }
 
@@ -164,6 +164,15 @@ function lm_process_file(global_conf, fn, tnn, do_train)
 
         r, feeds = tnn:getFeedFromReader(reader)
         if (r == false) then break end
+
+        for t = 1, global_conf.chunk_size do
+            tnn.err_inputs_m[t][1]:fill(1)
+            for i = 1, global_conf.batch_size do
+                if (bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0) then
+                    tnn.err_inputs_m[t][1][i][0] = 0
+                end
+            end
+        end
 
         --[[
         for j = 1, global_conf.chunk_size, 1 do
@@ -242,7 +251,7 @@ global_conf = {
     valid_fn = valid_fn,
     test_fn = test_fn,
     sche_log_pre = "[SCHEDULER]:",
-    log_w_num = 10000, --give a message when log_w_num words have been processed
+    log_w_num = 40000, --give a message when log_w_num words have been processed
     timer = nerv.Timer()
 }
 
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index 8c3963c..f470190 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -379,7 +379,7 @@ function TNN:propagate_dfs(ref, t)
             end
         end
     end
-    ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t]) --propagate!
+    ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t], t) --propagate!
     if (bit.bor(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history
         for i = 1, self.batch_size do
             local seq_start = bit.bor(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
@@ -495,10 +495,10 @@
         end
     end
     if (do_update == false) then
-        ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t])
+        ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
     else
         --print(ref.err_inputs_m[t][1])
-        ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t])
+        ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
     end
     for i = 1, #ref.dim_in do
         if (ref.err_outputs_b[t][i] == true) then
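The new loop in tnn_test.lua builds a per-frame error weight for the CE layer: err_inputs_m[t][1] is filled with 1, then zeroed at every (t, i) whose flags lack HAS_LABEL, so frames without a label contribute nothing to the cross-entropy gradient. The test must use bit.band, since or-ing a nonzero flag constant into anything can never yield 0. A minimal runnable sketch of the flag idiom, assuming LuaJIT's bitop module; the flag values are illustrative stand-ins for nerv.TNN.FC:

    -- Stand-in flag values for illustration; the real ones live in nerv.TNN.FC.
    local bit = require("bit")
    local FC = {HAS_INPUT = 1, HAS_LABEL = 2, SEQ_START = 4, SEQ_END = 8}

    local function has_label(flags)
        -- band isolates the HAS_LABEL bit; zero means this frame carries no label
        return bit.band(flags, FC.HAS_LABEL) ~= 0
    end

    print(has_label(bit.bor(FC.SEQ_START, FC.HAS_INPUT, FC.HAS_LABEL))) --> true
    print(has_label(FC.SEQ_END))                                        --> false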
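In tnn.lua, all three layer call sites (propagate, back_propagate, update) now receive the current time step t. That is what the chunk-aware SoftmaxCELayerT relies on: a layer running inside a TNN can cache per-step state during the forward pass and read it back at the same index during BPTT. A hypothetical MockLayer (not NERV's actual Layer interface) makes the convention concrete:

    -- Hypothetical sketch of the time-indexed calling convention;
    -- this is not NERV's Layer API.
    local MockLayer = {}
    MockLayer.__index = MockLayer

    function MockLayer.new()
        return setmetatable({cache = {}}, MockLayer)
    end

    function MockLayer:propagate(input, output, t)
        -- forward: cache this step's activation under its own index
        self.cache[t] = input * 2
        output[t] = self.cache[t]
    end

    function MockLayer:back_propagate(err_in, err_out, input, output, t)
        -- backward: without t, the layer could not tell which cached
        -- activation belongs to the frame being differentiated
        err_out[t] = err_in * self.cache[t]
    end

    local layer = MockLayer.new()
    local outputs, errors = {}, {}
    for t = 1, 3 do
        layer:propagate(t, outputs, t)                 -- forward through the chunk
    end
    for t = 3, 1, -1 do
        layer:back_propagate(1, errors, t, outputs, t) -- BPTT in reverse order
    end
    print(errors[1], errors[2], errors[3])             --> 2  4  6

Running this prints 2 4 6: each backward step at time t sees exactly the activation that was cached at time t on the way forward.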