aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
blob: 3f99741329e5e8d43c6a28cdafea505abc395c03 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
require 'lmptb.lmseqreader'
require 'lmptb.lmutil'

local printf = nerv.printf

-- Text file used both to build the vocabulary and as the reader's input.
-- Set LMSEQREADER_TEST_FN to point at a different file; the default is the
-- original author's workspace path and only exists on that machine.
local test_fn = os.getenv("LMSEQREADER_TEST_FN")
    or "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text"
-- Alternative (larger) input, kept for reference:
--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"

-- Build the vocabulary from the same file the reader will consume.
local vocab = nerv.LMVocab()
vocab:build_file(test_fn)

local chunk_size = 15 -- number of time steps allocated in the feed buffers
local batch_size = 3  -- number of parallel rows per time step

-- Configuration handed to nerv.LMSeqReader.
local global_conf = {
    lrate = 1, wcost = 1e-6, momentum = 0,
    cumat_type = nerv.CuMatrixFloat,
    mmat_type = nerv.MMatrixFloat,

    hidden_size = 20,
    chunk_size = chunk_size,
    batch_size = batch_size,
    max_iter = 18,
    -- uniform value in [-0.1, 0.1)
    param_random = function() return (math.random() / 5 - 0.1) end,
    independent = true,

    -- train_fn is intentionally left unset: this reader test only reads test_fn.
    test_fn = test_fn,
    sche_log_pre = "[SCHEDULER]:",
    log_w_num = 10, --give a message when log_w_num words have been processed
    timer = nerv.Timer(),

    vocab = vocab
}

-- Open a sequence reader over test_fn with the same options table as before.
local reader_opts = {se_mode = true, same_io = true}
local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab, reader_opts)
reader:open_file(test_fn)

-- Feed buffers passed to reader:get_batch. For each time step there is:
--   * a pair of matrices: (batch_size x 1) and (batch_size x vocab size)
--   * an empty flags table
-- flagsPack_now starts empty; presumably the reader fills it (confirm).
local feeds = {
    flags_now = {},
    inputs_m = {},
    flagsPack_now = {},
}
local new_cumat = global_conf.cumat_type
for step = 1, chunk_size do
    local m_col = new_cumat(batch_size, 1)
    local m_vocab = new_cumat(batch_size, global_conf.vocab:size())
    feeds.inputs_m[step] = {m_col, m_vocab}
    feeds.flags_now[step] = {}
end
-- Fetch up to max_batches batches from the reader and print each one:
-- one line per time step, one "word[L(label)]Fflag" cell per batch row.
local max_batches = 5
for _ = 1, max_batches do
    local r = reader:get_batch(feeds)
    -- An explicit false return stops the dump (presumably end of input).
    if r == false then break end
    -- NOTE(review): inputs_s / labels_s are not created in this file;
    -- presumably get_batch populates them - confirm.
    for j = 1, chunk_size do
        for i = 1, batch_size do
            printf("%s[L(%s)]F%d ", feeds.inputs_s[j][i], feeds.labels_s[j][i], feeds.flags_now[j][i])
        end
        printf("\n")
    end
    printf("\n")
end
-- Label matches the field actually printed.
printf("reader.stat.al_sen_start %s\n", tostring(reader.stat.al_sen_start))