blob: e1dae8c01c9e394526ef77db3f6b509d85d30d5a (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
require 'lmptb.lmseqreader'
-- Manual check of nerv.LMSeqReader: build a vocabulary from a text file, read
-- that same file back batch by batch, and print each input/label pair.
local printf = nerv.printf
local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text"
--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"

local batch_size = 5
local seq_size = 3

-- The vocabulary is built from the same file that the reader consumes below.
local vocab = nerv.LMVocab()
vocab:build_file(test_fn)

local reader = nerv.LMSeqReader({}, batch_size, seq_size, vocab)
reader:open_file(test_fn)

-- One sub-table per sequence position; the same tables are passed to every
-- get_batch call (presumably the reader fills them in place -- confirm against
-- LMSeqReader).
local input, label = {}, {}
for step = 1, seq_size do
    input[step], label[step] = {}, {}
end

while true do
    -- Only an explicit `false` from get_batch ends the loop.
    if reader:get_batch(input, label) == false then
        break
    end
    -- One output line per batch row, listing the pair at each sequence position.
    for row = 1, batch_size do
        for step = 1, seq_size do
            printf("%s[L(%s)] ", input[step][row], label[step][row]) --vocab:get_word_str(input[step][row]).id
        end
        printf("\n")
    end
    -- Blank line between consecutive batches.
    printf("\n")
end
|