aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua3
-rw-r--r--nerv/examples/lmptb/m-tests/tnn_test.lua2
2 files changed, 4 insertions, 1 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index f7e2539..3232b5a 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -105,6 +105,9 @@ function LMReader:get_batch(feeds)
labels_s[j][i] = self.vocab.null_token
end
if (inputs_s[j][i] ~= self.vocab.null_token) then
+ if (labels_s[j][i] == self.vocab.null_token) then
+ nerv.error("reader error : label is null while input is not null")
+ end
flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM)
got_new = true
st.store[st.head] = nil
diff --git a/nerv/examples/lmptb/m-tests/tnn_test.lua b/nerv/examples/lmptb/m-tests/tnn_test.lua
index 9c20914..ac9f570 100644
--- a/nerv/examples/lmptb/m-tests/tnn_test.lua
+++ b/nerv/examples/lmptb/m-tests/tnn_test.lua
@@ -165,7 +165,7 @@ function lm_process_file(global_conf, fn, tnn, do_train)
tnn.err_inputs_m[t][1]:fill(1)
for i = 1, global_conf.batch_size do
if (bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0) then
- tnn.err_inputs_m[t][1][i][0] = 0
+ tnn.err_inputs_m[t][1][i - 1][0] = 0
end
end
end