aboutsummaryrefslogblamecommitdiff
path: root/nerv/examples/lmptb/lm_trainer.lua
blob: 6bd06bbe5aadde11e961f3cf5e7228e3e34d6c5e (plain) (tree)
1
2
3
4
5
6
7
8
9



                          
                    



                                              
                            
 






                                                      
                   




                                                                              
                
                                



                                                                                                                               



                                                                                                                         
                                
        

                                           
       
 
                                                                                             
                        
 

                                                                


                                                                                                
        
                             


                                                                          
                                                                                                   

                

                                                            
                      
                                                  



                                                  
    
                                
                                          
                                    
                                                                                   



                                                        
                                                  












                                                                                                                       
                                


                                         

                                                 

                                   
                                                   
                                       
                                                                              

                                                                                                      



                                                                   


                           





                                                                                                   


                                                            
                                                            

                   
                                                 
                                                               

                                                                                                                
                                                             
                                                                                                     

                                     








                                            




                                                                               


                                                                          



                 































































































































                                                                                                                                 
 
require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
--require 'tnn.init'
require 'lmptb.lmseqreader'

-- Trainer class object. Functions below are attached to this local table;
-- the class name string suggests it is also reachable as nerv.LMTrainer
-- (assumed from nerv.class naming convention -- confirm).
local LMTrainer = nerv.class("nerv.LMTrainer")

--local printf = nerv.printf

--The bias param update in nerv don't have wcost added
function nerv.BiasParam:update_by_gradient(gradient) 
    local gconf = self.gconf
    local l2 = 1 - gconf.lrate * gconf.wcost
    self:_update_by_gradient(gradient, l2, l2)
end

--Returns: LMResult
function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
    if p_conf == nil then
        p_conf = {}
    end
    local reader
    local r_conf
    local chunk_size, batch_size
    if p_conf.one_sen_report == true then --report log prob one by one sentence
        if do_train == true then
            nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange")
        end
        nerv.printf("lm_process_file_rnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n", 
                global_conf.max_sen_len)
        batch_size = 1 
        chunk_size = global_conf.max_sen_len
        r_conf["se_mode"] = true
    else
        batch_size = global_conf.batch_size
        chunk_size = global_conf.chunk_size
    end

    reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab, r_conf)
    reader:open_file(fn)

    local result = nerv.LMResult(global_conf, global_conf.vocab)
    result:init("rnn")
    if global_conf.dropout_rate ~= nil then
        nerv.info("LMTrainer.lm_process_file_rnn: dropout_rate is %f", global_conf.dropout_rate)
    end
        
    global_conf.timer:flush()
    tnn:flush_all() --caution: will also flush the inputs from the reader!

    local next_log_wcn = global_conf.log_w_num
    local neto_bakm = global_conf.mmat_type(batch_size, 1) --space backup matrix for network output

    while (1) do
        global_conf.timer:tic('most_out_loop_lmprocessfile')

        local r, feeds
        global_conf.timer:tic('tnn_beforeprocess')
        r, feeds = tnn:getfeed_from_reader(reader)
        if r == false then 
            break 
        end
    
        for t = 1, chunk_size do
            tnn.err_inputs_m[t][1]:fill(1)
            for i = 1, batch_size do
                if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
                    tnn.err_inputs_m[t][1][i - 1][0] = 0
                end
            end
        end
        global_conf.timer:toc('tnn_beforeprocess')

        --[[
        for j = 1, global_conf.chunk_size, 1 do
            for i = 1, global_conf.batch_size, 1 do
                printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i])   --vocab:get_word_str(input[i][j]).id
            end
            printf("\n")
        end
        printf("\n")
        ]]--

        tnn:net_propagate()
 
        if do_train == true then
            tnn:net_backpropagate(false) 
            tnn:net_backpropagate(true)
        end

        global_conf.timer:tic('tnn_afterprocess')
        local sen_logp = {}
        for t = 1, chunk_size, 1 do
            tnn.outputs_m[t][1]:copy_toh(neto_bakm)
            for i = 1, batch_size, 1 do
                if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
                    --result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
                    result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
                    if sen_logp[i] == nil then
                        sen_logp[i] = 0
                    end
                    sen_logp[i] = sen_logp[i] + neto_bakm[i - 1][0]
                end
            end            
        end
        if p_conf.one_sen_report == true then
            for i = 1, batch_size do
                nerv.printf("LMTrainer.lm_process_file_rnn: one_sen_report, %f\n", sen_logp[i])    
            end
        end

        tnn:move_right_to_nextmb({0}) --only copy for time 0
        global_conf.timer:toc('tnn_afterprocess')

        global_conf.timer:toc('most_out_loop_lmprocessfile')

        --print log
        if result["rnn"].cn_w > next_log_wcn then
            next_log_wcn = next_log_wcn + global_conf.log_w_num
            nerv.printf(