aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/m-tests/lmseqreader_test.lua')
-rw-r--r--nerv/examples/lmptb/m-tests/lmseqreader_test.lua17
1 files changed, 13 insertions, 4 deletions
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
index bdea740..e1dae8c 100644
--- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -10,11 +10,20 @@ local batch_size = 5
local seq_size = 3
local reader = nerv.LMSeqReader({}, batch_size, seq_size, vocab)
reader:open_file(test_fn)
+local input = {}
+local label = {}
+for i = 1, seq_size, 1 do
+ input[i] = {}
+ label[i] = {}
+end
while (1) do
- local list = reader:get_batch()
- if (list == nil) then break end
- for i = 1, batch_size, 1 do
- printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id)
+ local r = reader:get_batch(input, label)
+ if (r == false) then break end
+ for j = 1, batch_size, 1 do
+ for i = 1, seq_size, 1 do
+ printf("%s[L(%s)] ", input[i][j], label[i][j]) --vocab:get_word_str(input[i][j]).id
+ end
+ printf("\n")
end
printf("\n")
end