diff options
author | txh18 <[email protected]> | 2015-10-30 11:48:20 +0800 |
---|---|---|
committer | txh18 <[email protected]> | 2015-10-30 11:48:20 +0800 |
commit | eac51ea6ba6fc4c2a39b9888e0109832feee564c (patch) | |
tree | 9f8466e6ffcdc1cece7790a6a0ad4b8fc0b73c2b | |
parent | 4d48cb10ba5bcd6e441a1919d61a64d0a6b4bee9 (diff) |
finished a simple version of lmseqreader
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmseqreader.lua | 3 | ||||
-rw-r--r-- | nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 17 |
2 files changed, 14 insertions, 6 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index 26dc3be..9396783 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -12,7 +12,7 @@ function LMReader:__init(global_conf, batch_size, seq_size, vocab) self.fh = nil --file handle to read, nil means currently no file self.batch_size = batch_size self.seq_size = seq_size - self.log_pre = "[LOG]LMFeeder:" + self.log_pre = "[LOG]LMSeqReader:" self.vocab = vocab self.streams = nil end @@ -66,7 +66,6 @@ end function LMReader:get_batch(input, label) local got_new = false - local list = {} for i = 1, self.seq_size, 1 do local st = self.streams[i] for j = 1, self.batch_size, 1 do diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua index bdea740..e1dae8c 100644 --- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua +++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua @@ -10,11 +10,20 @@ local batch_size = 5 local seq_size = 3 local reader = nerv.LMSeqReader({}, batch_size, seq_size, vocab) reader:open_file(test_fn) +local input = {} +local label = {} +for i = 1, seq_size, 1 do + input[i] = {} + label[i] = {} +end while (1) do - local list = reader:get_batch() - if (list == nil) then break end - for i = 1, batch_size, 1 do - printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id) + local r = reader:get_batch(input, label) + if (r == false) then break end + for j = 1, batch_size, 1 do + for i = 1, seq_size, 1 do + printf("%s[L(%s)] ", input[i][j], label[i][j]) --vocab:get_word_str(input[i][j]).id + end + printf("\n") end printf("\n") end |