aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lm_trainer.lua
blob: 3c7078e586a3cd0477d1be6ae4dee7eee18e2186 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
--require 'tnn.init'
require 'lmptb.lmseqreader'

--Class table holding the LM training/evaluation routines of this file,
--registered under the name 'nerv.LMTrainer' through nerv.class.
local LMTrainer = nerv.class('nerv.LMTrainer')

--local printf = nerv.printf

--Override: nerv's stock BiasParam update does not take the weight cost
--(gconf.wcost) into account, so this version computes the factor
--1 - lrate * wcost and hands it to _update_by_gradient.
--  gradient: forwarded unchanged to self:_update_by_gradient.
--NOTE(review): the meaning of the two factor arguments is defined by
--_update_by_gradient (not visible here); both receive the same value.
function nerv.BiasParam:update_by_gradient(gradient)
    local conf = self.gconf
    local decay = 1 - conf.wcost * conf.lrate
    self:_update_by_gradient(gradient, decay, decay)
end

--Returns: LMResult
--Runs the recurrent network `tnn` over every minibatch that an
--nerv.LMSeqReader produces from file `fn`. Per-word statistics are accumulated
--into an nerv.LMResult initialised with the tag "rnn".
--Params:
--  global_conf: configuration table. Fields read here: batch_size, chunk_size,
--               max_sen_len, vocab, dropout_rate, timer, log_w_num,
--               mmat_type, sche_log_pre.
--  fn:          path of the file to process (passed to reader:open_file).
--  tnn:         the network; its inputs are filled by tnn:getfeed_from_reader.
--  do_train:    when exactly `true`, back-propagation is run after every
--               forward pass.
--  p_conf:      optional table. If p_conf.one_sen_report == true, batch_size
--               is forced to 1 and chunk_size to global_conf.max_sen_len, and
--               the summed output of each chunk is printed after processing it.
function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
    p_conf = p_conf or {}

    local chunk_size, batch_size
    if p_conf.one_sen_report == true then --report log prob one by one sentence
        if do_train == true then
            nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange")
        end
        nerv.printf("lm_process_file_rnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n",
                global_conf.max_sen_len)
        batch_size = 1
        chunk_size = global_conf.max_sen_len
    else
        batch_size = global_conf.batch_size
        chunk_size = global_conf.chunk_size
    end

    local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab)
    reader:open_file(fn)

    local result = nerv.LMResult(global_conf, global_conf.vocab)
    result:init("rnn")
    if global_conf.dropout_rate ~= nil then
        nerv.info("LMTrainer.lm_process_file_rnn: dropout_rate is %f", global_conf.dropout_rate)
    end

    global_conf.timer:flush()
    tnn:flush_all() --caution: will also flush the inputs from the reader!

    local null_token = global_conf.vocab.null_token
    local next_log_wcn = global_conf.log_w_num
    local neto_bakm = global_conf.mmat_type(batch_size, 1) --space backup matrix for network output

    while true do
        global_conf.timer:tic('most_out_loop_lmprocessfile')

        global_conf.timer:tic('tnn_beforeprocess')
        local r, feeds = tnn:getfeed_from_reader(reader)
        if r == false then
            break
        end

        --Set the error input to 1 everywhere, then zero it at (t, i) positions
        --whose flags lack HAS_LABEL (presumably so they contribute no error signal).
        for t = 1, chunk_size do
            local err_in = tnn.err_inputs_m[t][1]
            err_in:fill(1)
            for i = 1, batch_size do
                if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
                    err_in[i - 1][0] = 0
                end
            end
        end
        global_conf.timer:toc('tnn_beforeprocess')

        tnn:net_propagate()

        if do_train == true then
            --NOTE(review): the boolean flag's meaning is defined in TNN
            --(not visible here); both passes are required, in this order.
            tnn:net_backpropagate(false)
            tnn:net_backpropagate(true)
        end

        global_conf.timer:tic('tnn_afterprocess')
        --Outputs are copied into neto_bakm before being read. Each value is
        --exp'd before result:add and summed raw into sen_logp, so it appears
        --to be the log-probability of the label (TODO confirm).
        local sen_logp = {}
        for t = 1, chunk_size do
            tnn.outputs_m[t][1]:copy_toh(neto_bakm)
            for i = 1, batch_size do
                local label = feeds.labels_s[t][i]
                if label ~= null_token then
                    local logp = neto_bakm[i - 1][0]
                    result:add("rnn", label, math.exp(logp))
                    sen_logp[i] = (sen_logp[i] or 0) + logp
                end
            end
        end
        if p_conf.one_sen_report == true then
            for i = 1, batch_size do
                --A chunk with no labelled word leaves sen_logp[i] nil; report 0
                --instead of crashing on a nil argument to %f.
                nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report, %f\n", sen_logp[i] or 0)
            end
        end

        tnn:move_right_to_nextmb({0}) --only copy for time 0
        global_conf.timer:toc('tnn_afterprocess')

        global_conf.timer:toc('most_out_loop_lmprocessfile')

        --print log
        local cn_w = result["rnn"].cn_w
        if cn_w > next_log_wcn then
            --Advance the threshold past the current count. Otherwise a minibatch
            --containing more than log_w_num words leaves it permanently behind,
            --and the function logs and sleeps on every minibatch.
            --A non-positive log_w_num takes exactly one step, as before.
            repeat
                next_log_wcn = next_log_wcn + global_conf.log_w_num
            until next_log_wcn >= cn_w or global_conf.log_w_num <= 0
            nerv.printf("%s %d words processed %s.\n", global_conf.sche_log_pre, cn_w, os.date())
            nerv.printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
            for key, value in pairs(global_conf.timer.rec) do
                nerv.printf("\t [global_conf.timer]: time spent on %s:%.5f clock time\n", key, value)
            end
            global_conf.timer:flush()
            nerv.LMUtil.wait(0.1)
        end

        --Full collection after every minibatch (kept from the original).
        collectgarbage("collect")
    end

    nerv.printf("%s Displaying result:\n", global_conf.sche_log_pre)
    nerv.printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
    nerv.printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)

    return result
end