aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2015-11-11 20:24:34 +0800
committertxh18 <[email protected]>2015-11-11 20:24:34 +0800
commit73402335834c990dbe6a7729ace7a830ed2f91ae (patch)
treed53384e6da7efbdb635030603aab7ab35a78b006
parent5761e68ec1b73ed867443fb2687739395f22f2f9 (diff)
added a little debug info in reader
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua4
-rw-r--r--nerv/examples/lmptb/m-tests/tnn_test.lua8
2 files changed, 8 insertions, 4 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index d75167e..e0dcd95 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -24,6 +24,7 @@ function LMReader:open_file(fn)
nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
end
printf("%s opening file %s...\n", self.log_pre, fn)
+ print("batch_size:", self.batch_size, "chunk_size", self.chunk_size)
self.fh = io.open(fn, "r")
self.streams = {}
for i = 1, self.batch_size, 1 do
@@ -102,6 +103,9 @@ function LMReader:get_batch(feeds)
labels_s[j][i] = st.store[st.head + 1]
inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
else
+ if (inputs_s[j][i] ~= self.vocab.null_token) then
+ nerv.error("reader error : input not null but label is null_token")
+ end
labels_s[j][i] = self.vocab.null_token
end
if (inputs_s[j][i] ~= self.vocab.null_token) then
diff --git a/nerv/examples/lmptb/m-tests/tnn_test.lua b/nerv/examples/lmptb/m-tests/tnn_test.lua
index a2c38f0..c4890b6 100644
--- a/nerv/examples/lmptb/m-tests/tnn_test.lua
+++ b/nerv/examples/lmptb/m-tests/tnn_test.lua
@@ -238,15 +238,15 @@ valid_fn = data_dir .. '/ptb.valid.txt.adds'
test_fn = data_dir .. '/ptb.test.txt.adds'
global_conf = {
- lrate = 0.1, wcost = 1e-6, momentum = 0,
+ lrate = 1, wcost = 1e-6, momentum = 0,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.MMatrixFloat,
nn_act_default = 0,
hidden_size = 200,
- chunk_size = 15,
- batch_size = 1,
- max_iter = 25,
+ chunk_size = 5,
+ batch_size = 10,
+ max_iter = 20,
param_random = function() return (math.random() / 5 - 0.1) end,
train_fn = train_fn,