diff options
Diffstat (limited to 'nerv/examples/lmptb/m-tests/lmseqreader_test.lua')
-rw-r--r-- | nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua index bdea740..e1dae8c 100644 --- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua +++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua @@ -10,11 +10,20 @@ local batch_size = 5 local seq_size = 3 local reader = nerv.LMSeqReader({}, batch_size, seq_size, vocab) reader:open_file(test_fn) +local input = {} +local label = {} +for i = 1, seq_size, 1 do + input[i] = {} + label[i] = {} +end while (1) do - local list = reader:get_batch() - if (list == nil) then break end - for i = 1, batch_size, 1 do - printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id) + local r = reader:get_batch(input, label) + if (r == false) then break end + for j = 1, batch_size, 1 do + for i = 1, seq_size, 1 do + printf("%s[L(%s)] ", input[i][j], label[i][j]) --vocab:get_word_str(input[i][j]).id + end + printf("\n") end printf("\n") end |