aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
blob: e1dae8c01c9e394526ef77db3f6b509d85d30d5a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
require 'lmptb.lmseqreader'

local printf = nerv.printf

-- Text file used both to build the vocabulary and as the reader's input.
-- The default is the original author's local path; set the environment
-- variable LMSEQREADER_TEST_FN to run against another file (for example
-- examples/lmptb/PTBdata/ptb.train.txt) without editing this script.
local DEFAULT_TEST_FN = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text"
local test_fn = os.getenv("LMSEQREADER_TEST_FN")
if test_fn == nil or test_fn == "" then
    test_fn = DEFAULT_TEST_FN
end

local vocab = nerv.LMVocab()
vocab:build_file(test_fn)

-- batch_size and seq_size are passed straight to nerv.LMSeqReader.
local batch_size = 5
local seq_size = 3
local reader = nerv.LMSeqReader({}, batch_size, seq_size, vocab)
reader:open_file(test_fn)

-- Output buffers handed to reader:get_batch: one sub-table per step
-- i = 1..seq_size. Presumably the reader fills entry [i][j] for column
-- j = 1..batch_size (TODO confirm against lmseqreader).
local input = {}
local label = {}
for i = 1, seq_size do
    input[i] = {}
    label[i] = {}
end
-- Dump every batch the reader produces. Each batch prints one line per
-- column j (1..batch_size) holding seq_size "input[L(label)]" pairs,
-- followed by a blank line separating batches.
while true do
    local r = reader:get_batch(input, label)
    -- Only an explicit false stops the loop. NOTE(review): if get_batch can
    -- also return nil at end of data, this would never terminate; confirm.
    if r == false then break end
    for j = 1, batch_size do
        for i = 1, seq_size do
            -- tostring() keeps the dump portable: PUC Lua 5.1's %s accepts only
            -- strings/numbers (assuming nerv.printf formats via string.format),
            -- so a nil or table slot would raise there instead of printing.
            -- Alternative kept from the original: vocab:get_word_str(input[i][j]).id
            printf("%s[L(%s)] ", tostring(input[i][j]), tostring(label[i][j]))
        end
        printf("\n")
    end
    printf("\n")
end