--nerv/examples/lmptb/lm_trainer.lua
require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
require 'rnn.init'
require 'lmptb.lmseqreader'
local bit = require 'bit' --LuaJIT bit library (assumed available); used for the flag checks below

local LMTrainer = nerv.class('nerv.LMTrainer')

local printf = nerv.printf

--Processes the corpus file fn with the (unrolled) network tnn, training on it when do_train is true
--Returns: LMResult
function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
    local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
    reader:open_file(fn)
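    --LMResult collects per-word probabilities (and derived statistics such as perplexity) under a named tag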
    local result = nerv.LMResult(global_conf, global_conf.vocab)
    result:init("rnn")
    
    tnn:flush_all() --caution: will also flush the inputs already fetched from the reader!

    local next_log_wcn = global_conf.log_w_num --word-count threshold at which to print the next progress log

    while (true) do
        local r, feeds

        r, feeds = tnn:getFeedFromReader(reader) --fetch the next minibatch of inputs and labels
        if (r == false) then break end --reader exhausted
    
        --error inputs act as per-frame error scales: 1 where a label is present,
        --0 for padding frames without one, masking their gradient contribution
        for t = 1, global_conf.chunk_size do
            tnn.err_inputs_m[t][1]:fill(1)
            for i = 1, global_conf.batch_size do
                if (bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0) then
                    tnn.err_inputs_m[t][1][i - 1][0] = 0 --nerv matrix rows/cols are 0-indexed
                end
            end
        end

        --[[ debug: dump the input words and labels of this minibatch
        for j = 1, global_conf.chunk_size, 1 do
            for i = 1, global_conf.batch_size, 1 do
                printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i])   --vocab:get_word_str(input[i][j]).id
            end
            printf("\n")
        end
        printf("\n")
        ]]--

        tnn:net_propagate() --forward pass through the whole unrolled chunk
 
        if (do_train == true) then
            tnn:net_backpropagate(false) --backward pass without parameter update
            tnn:net_backpropagate(true) --backward pass with parameter update
        end
 
        --accumulate the model's score for every labelled frame; outputs_m
        --holds log-probabilities, so math.exp() converts them to probabilities
        for t = 1, global_conf.chunk_size, 1 do
            for i = 1, global_conf.batch_size, 1 do
                if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
                    result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
                end
            end
        end
        if (result["rnn"].cn_w > next_log_wcn) then
            next_log_wcn = next_log_wcn + global_conf.log_w_num
            printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
            printf("\t%s log prob per sample: %f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
            nerv.LMUtil.wait(0.1) --brief pause, presumably to keep the log output readable
        end
         
        --[[ debug: dump the network outputs of this minibatch
        for t = 1, global_conf.chunk_size do
            print(tnn.outputs_m[t][1])
        end
        ]]--

        tnn:moveRightToNextMB() --carry the chunk-final hidden states over to the next minibatch

        collectgarbage("collect") --force a full GC cycle to keep memory usage in check

        --break --debug
    end
    
    printf("%s Displaying result:\n", global_conf.sche_log_pre)
    printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
    printf("%s Processing of %s ended.\n", global_conf.sche_log_pre, fn)
    
    return result
end
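
--[[
Usage sketch (hypothetical): a minimal illustration, assuming global_conf and
the TNN have been built elsewhere (e.g. in the accompanying main script);
gconf, tnn and the *_fn fields are stand-in names, not defined in this file:

    local result = nerv.LMTrainer.lm_process_file(gconf, gconf.train_fn, tnn, true) --one training pass
    local valid = nerv.LMTrainer.lm_process_file(gconf, gconf.valid_fn, tnn, false) --evaluation only
    nerv.printf("%s\n", valid:status("rnn"))
]]--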