summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2015-10-30 11:48:20 +0800
committertxh18 <[email protected]>2015-10-30 11:48:20 +0800
commiteac51ea6ba6fc4c2a39b9888e0109832feee564c (patch)
tree9f8466e6ffcdc1cece7790a6a0ad4b8fc0b73c2b
parent4d48cb10ba5bcd6e441a1919d61a64d0a6b4bee9 (diff)
finished a simple version of lmseqreader
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua3
-rw-r--r--nerv/examples/lmptb/m-tests/lmseqreader_test.lua17
2 files changed, 14 insertions, 6 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index 26dc3be..9396783 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -12,7 +12,7 @@ function LMReader:__init(global_conf, batch_size, seq_size, vocab)
self.fh = nil --file handle to read, nil means currently no file
self.batch_size = batch_size
self.seq_size = seq_size
- self.log_pre = "[LOG]LMFeeder:"
+ self.log_pre = "[LOG]LMSeqReader:"
self.vocab = vocab
self.streams = nil
end
@@ -66,7 +66,6 @@ end
function LMReader:get_batch(input, label)
local got_new = false
- local list = {}
for i = 1, self.seq_size, 1 do
local st = self.streams[i]
for j = 1, self.batch_size, 1 do
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
index bdea740..e1dae8c 100644
--- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -10,11 +10,20 @@ local batch_size = 5
local seq_size = 3
local reader = nerv.LMSeqReader({}, batch_size, seq_size, vocab)
reader:open_file(test_fn)
+local input = {}
+local label = {}
+for i = 1, seq_size, 1 do
+ input[i] = {}
+ label[i] = {}
+end
while (1) do
- local list = reader:get_batch()
- if (list == nil) then break end
- for i = 1, batch_size, 1 do
- printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id)
+ local r = reader:get_batch(input, label)
+ if (r == false) then break end
+ for j = 1, batch_size, 1 do
+ for i = 1, seq_size, 1 do
+ printf("%s[L(%s)] ", input[i][j], label[i][j]) --vocab:get_word_str(input[i][j]).id
+ end
+ printf("\n")
end
printf("\n")
end