about summary refs log blame commit diff
path: root/nerv/examples/lmptb/lm_trainer.lua
blob: 9dfefe5dd74b5b44a475fb32d4baba075af10616 (plain) (tree)
1
2
3
4
5
6
7
8
9



                          
                    



                                              
                            
 






                                                      
                   




                                                                              
                     


                                                         
                                



                                                                                                                               
                                                                                                     
                                        
                                           
                                            
                                
        

                                           
       
 
                                                                                             
                        
 

                                                                


                                                                                                
        
                             
                                    


                                                                          
                                                                                                   

                                                                   
                

                                                            
                      
                                                  



                                                  
    
                                
                                          
                                    
                                                                                   



                                                        
                                                  












                                                                                                                       
                                


                                         

                                                 

                                   
                                                   
                                       
                                                                              

                                                                                                      



                                                                   


                           

                                             


                                                                                                             


               


                                                            
                                                            

                   
                                                 
                                                               

                                                                                                                
                                                             
                                                                                                     

                                     








                                            




                                                                               


                                                                          



                 







                                                                                


                                                         



                                                                                                                                 
                                                                                                       
                                        
                                           





                                            
                                                                                             








                                                                                                  
                                    




                                                                                                   

                                                                     








                                                            










































                                                                                                                       


                                                                                                                





                                                                                      

                                                                                                            































                                                                                                                  
 
require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
--require 'tnn.init'
require 'lmptb.lmseqreader'

--Class table for the language-model trainer, created via nerv.class under the
--class name "nerv.LMTrainer". The functions below are attached to this local.
local LMTrainer = nerv.class("nerv.LMTrainer")

--local printf = nerv.printf

--Override of nerv.BiasParam:update_by_gradient.
--The stock bias-parameter update in nerv doesn't have the wcost term added,
--so this version computes the factor (1 - lrate * wcost) from self.gconf and
--passes that same value as both trailing arguments of
--self:_update_by_gradient.
--NOTE(review): the exact role of those two trailing arguments is defined by
--_update_by_gradient, which is not visible here; confirm against nerv.
--gradient: forwarded unchanged to _update_by_gradient.
function nerv.BiasParam:update_by_gradient(gradient)
    local conf = self.gconf
    local shrink = 1 - conf.lrate * conf.wcost
    self:_update_by_gradient(gradient, shrink, shrink)
end

--Returns: LMResult
function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
    if p_conf == nil then
        p_conf = {}
    end
    local reader
    local r_conf = {}
    if p_conf.compressed_label ~= nil then
        r_conf.compressed_label = p_conf.compressed_label
    end
    local chunk_size, batch_size
    if p_conf.one_sen_report == true then --report log prob one by one sentence
        if do_train == true then
            nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange")
        end
        nerv.printf("lm_process_file_rnn: one_sen report mode, set chunk_size to max_sen_len(%d)\n", 
                global_conf.max_sen_len)
        batch_size