aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lm_trainer.lua
blob: 58d5bfce651595770cca9f72e2a96b952f58dfcf (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
--require 'tnn.init'
require 'lmptb.lmseqreader'

--Class table for the LM trainer, created through nerv.class with the name
--"nerv.LMTrainer"; the trainer's functions are attached to this local alias.
local LMTrainer = nerv.class("nerv.LMTrainer")

--nerv's stock BiasParam update does not apply the weight cost (wcost).
--Override it so that the factor (1 - lrate * wcost) is passed as both
--scaling arguments of _update_by_gradient.
function nerv.BiasParam:update_by_gradient(gradient)
    local lrate, wcost = self.gconf.lrate, self.gconf.wcost
    local decay_factor = 1 - lrate * wcost
    self:_update_by_gradient(gradient, decay_factor, decay_factor)
end

--- Run the RNN language model over one text file.
-- Reads `fn` through an nerv.LMSeqReader and forward-propagates every
-- minibatch through `tnn`. When training, it also back-propagates. The
-- exponentiated outputs for every labelled position are accumulated into an
-- nerv.LMResult under the key "rnn".
-- @param global_conf table: reads batch_size, chunk_size, max_sen_len, vocab,
--   timer, mmat_type, log_w_num, sche_log_pre and (optionally) dropout_rate
-- @param fn string: path of the text file to process
-- @param tnn network object; uses getfeed_from_reader, net_propagate,
--   net_backpropagate, move_right_to_nextmb, flush_all, err_inputs_m, outputs_m
-- @param do_train boolean: when true, back-propagate after each forward pass
-- @param p_conf table|nil: optional settings; `one_sen_report = true` reads
--   one sentence at a time (batch size 1, chunk size max_sen_len)
-- @return LMResult with the accumulated statistics
function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
    p_conf = p_conf or {}

    -- Decide the minibatch geometry once. The reader, the label mask, the
    -- output buffer and the result loops below must all agree on it.
    -- NOTE(review): tnn itself must have been built for the same sizes;
    -- confirm that callers using one_sen_report construct it accordingly.
    local batch_size, chunk_size
    if p_conf.one_sen_report == true then --report log prob one by one sentence
        if do_train == true then
            nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange")
        end
        batch_size = 1
        chunk_size = global_conf.max_sen_len
    else
        batch_size = global_conf.batch_size
        chunk_size = global_conf.chunk_size
    end

    local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab)
    reader:open_file(fn)
    local result = nerv.LMResult(global_conf, global_conf.vocab)
    result:init("rnn")
    if global_conf.dropout_rate ~= nil then
        nerv.info("LMTrainer.lm_process_file_rnn: dropout_rate is %f", global_conf.dropout_rate)
    end

    local timer = global_conf.timer
    local null_token = global_conf.vocab.null_token

    timer:flush()
    tnn:flush_all() --caution: will also flush the inputs from the reader!

    local next_log_wcn = global_conf.log_w_num
    -- buffer that each step's network output is copied into (copy_toh) for reading
    local neto_bakm = global_conf.mmat_type(batch_size, 1)

    while true do
        timer:tic('most_out_loop_lmprocessfile')

        timer:tic('tnn_beforeprocess')
        local r, feeds = tnn:getfeed_from_reader(reader)
        if r == false then
            break
        end

        -- error-input mask: 1 everywhere, 0 where the HAS_LABEL flag is not set
        for t = 1, chunk_size do
            local err_in = tnn.err_inputs_m[t][1]
            local flags_t = feeds.flags_now[t]
            err_in:fill(1)
            for i = 1, batch_size do
                if bit.band(flags_t[i], nerv.TNN.FC.HAS_LABEL) == 0 then
                    err_in[i - 1][0] = 0
                end
            end
        end
        timer:toc('tnn_beforeprocess')

        tnn:net_propagate()
        if do_train == true then
            tnn:net_backpropagate(false)
            tnn:net_backpropagate(true)
        end

        timer:tic('tnn_afterprocess')
        for t = 1, chunk_size do
            tnn.outputs_m[t][1]:copy_toh(neto_bakm)
            local labels_t = feeds.labels_s[t]
            for i = 1, batch_size do
                local label = labels_t[i]
                if label ~= null_token then
                    -- outputs are exponentiated before recording (presumably log-probs)
                    result:add("rnn", label, math.exp(neto_bakm[i - 1][0]))
                end
            end
        end
        tnn:move_right_to_nextmb({0}) --only copy for time 0
        timer:toc('tnn_afterprocess')

        timer:toc('most_out_loop_lmprocessfile')

        -- periodic progress report; the timer is flushed after each report
        local words_done = result["rnn"].cn_w
        if words_done > next_log_wcn then
            next_log_wcn = next_log_wcn + global_conf.log_w_num
            nerv.printf("%s %d words processed %s.\n", global_conf.sche_log_pre, words_done, os.date())
            nerv.printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
            for key, value in pairs(timer.rec) do
                nerv.printf("\t [global_conf.timer]: time spent on %s:%.5f clock time\n", key, value)
            end
            timer:flush()
            nerv.LMUtil.wait(0.1)
        end

        -- full collection every minibatch; NOTE(review): costly, presumably
        -- kept to release per-minibatch objects promptly
        collectgarbage("collect")
    end

    nerv.printf("%s Displaying result:\n", global_conf.sche_log_pre)
    nerv.printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
    nerv.printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)

    return result
end