From 1499ef632a2b9d63d6f68da9f42401d4d141a9f6 Mon Sep 17 00:00:00 2001
From: txh18 <cloudygooseg@gmail.com>
Date: Sun, 8 Nov 2015 19:49:15 +0800
Subject: switch to SoftmaxCELayerT

Use the time-aware softmax CE layer in the PTB LM example: propagate,
back_propagate and update now receive the current frame index t, and the
CE error input is zeroed for frames that carry no label. Also raise the
logging interval from 10000 to 40000 words.
---
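Editor's note, placed above the diffstat where `git am` ignores it:
SoftmaxCELayerT is the time-aware variant of SoftmaxCELayer, which is why
the TNN now forwards the frame index t into propagate/back_propagate/update.
Below is a minimal sketch of the propagate signature this implies, assuming
nerv's usual Layer calling convention; the buffer names (softmax_t, ce_t)
are illustrative assumptions, not the actual implementation:

    function nerv.SoftmaxCELayerT:propagate(input, output, t)
        --input[1]: predictions, input[2]: labels, both for frame t
        local sm, ce = self.softmax_t[t], self.ce_t[t] --hypothetical per-frame buffers
        sm:softmax(input[1])              --softmax over the vocabulary
        ce:log_elem(sm)                   --elementwise log of the softmax
        ce:mul_elem(ce, input[2])         --keep only the label entries
        output[1]:copy_fromd(ce:rowsum()) --one value per sequence at frame t
    end
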
 nerv/examples/lmptb/m-tests/tnn_test.lua | 13 +++++++++++--
 nerv/examples/lmptb/rnn/tnn.lua          |  6 +++---
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/nerv/examples/lmptb/m-tests/tnn_test.lua b/nerv/examples/lmptb/m-tests/tnn_test.lua
index a778dea..7a8519e 100644
--- a/nerv/examples/lmptb/m-tests/tnn_test.lua
+++ b/nerv/examples/lmptb/m-tests/tnn_test.lua
@@ -82,7 +82,7 @@ function prepare_layers(global_conf, paramRepo)
             ["outputL"] = {{["ltp"] = "ltp_ho", ["bp"] = "bp_o"}, {["dim_in"] = {global_conf.hidden_size}, ["dim_out"] = {global_conf.vocab:size()}}},
         },
 
-        ["nerv.SoftmaxCELayer"] = {
+        ["nerv.SoftmaxCELayerT"] = {
             ["softmaxL"] = {{}, {["dim_in"] = {global_conf.vocab:size(), global_conf.vocab:size()}, ["dim_out"] = {1}}},
         },
     }
@@ -164,6 +164,15 @@ function lm_process_file(global_conf, fn, tnn, do_train)
 
         r, feeds = tnn:getFeedFromReader(reader)
         if (r == false) then break end
+
+        for t = 1, global_conf.chunk_size do
+            tnn.err_inputs_m[t][1]:fill(1) --by default every sequence contributes CE error at frame t
+            for i = 1, global_conf.batch_size do
+                if (bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0) then --band tests the flag; bor with a non-zero flag is never zero
+                    tnn.err_inputs_m[t][1][i - 1][0] = 0 --no label for sequence i at this frame; matrix rows are 0-indexed
+                end
+            end
+        end
 
         --[[
         for j = 1, global_conf.chunk_size, 1 do
@@ -242,7 +251,7 @@ global_conf = {
     valid_fn = valid_fn,
     test_fn = test_fn,
     sche_log_pre = "[SCHEDULER]:",
-    log_w_num = 10000, --give a message when log_w_num words have been processed
+    log_w_num = 40000, --log a progress message after every log_w_num processed words
     timer = nerv.Timer()
 }
 
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index 8c3963c..f470190 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -379,7 +379,7 @@ function TNN:propagate_dfs(ref, t)
             end
         end
     end
-    ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t]) --propagate!
+    ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t], t) --propagate, passing frame index t for time-aware layers
     if (bit.bor(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history
         for i = 1, self.batch_size do
             local seq_start = bit.bor(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START) 
@@ -495,10 +495,10 @@ function TNN:backpropagate_dfs(ref, t, do_update)
         end
     end
     if (do_update == false) then
-        ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t])
+        ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
     else 
         --print(ref.err_inputs_m[t][1])
-        ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t])
+        ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
     end
     for i = 1, #ref.dim_in do
         if (ref.err_outputs_b[t][i] == true) then
-- 
cgit v1.2.3-70-g09d2
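
A standalone sketch of the per-frame label masking added above, with
hypothetical flag values (in the TNN, nerv.TNN.FC.HAS_LABEL is a bit
flag and err_inputs_m[t][1] holds one CE error weight per sequence):

    local bit = require("bit")
    local FC = {SEQ_START = 1, SEQ_END = 2, HAS_LABEL = 4} --illustrative values
    local function label_weight(flags)
        --weight 1 when the frame carries a label, 0 otherwise;
        --note bit.band, not bit.bor: bor with a non-zero flag is never zero
        if bit.band(flags, FC.HAS_LABEL) == 0 then return 0 end
        return 1
    end
    print(label_weight(bit.bor(FC.SEQ_START, FC.HAS_LABEL))) --> 1
    print(label_weight(FC.SEQ_END)) --> 0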