aboutsummaryrefslogblamecommitdiff
path: root/nerv/examples/asr_trainer.lua
blob: 645f1ef90081eaef0e61ae99add167630b36a6ab (plain) (tree)
1
2
3
4
5
6
7
8
9
10

             
                                    
                                            
                  


                                              

                                  

                                                         

                                   

                                                           
       




                                                           

                                                            
                                                    
                                                                        
                                                          
 







                                                                                

                            
                                                                      
                                 
                     







                                                                  

                                  
                                           


                                      
                                      

                                        
                             
                        
               
                            
                                 


                                                 
                                         

                                                             

                                       
                                       






                                                                                   
                    



                                                                   
                   
                                                      
                                                
               

                                                               
                                                  



                                                             
                               
                      

                                        



                                                         
                              

                                








                                                                     
           
                                                               



                            

                                                 

                                             


                                            
                                                          


                                  




                                                                                       
                                

               































                                                                          






                                  


                          
                   





                            
                 



                       


                                       


                                              








                                                                                   







                                         
              








                                                                             



                                                   

                                   


                                                                                      
                             
 
             







                                                                

















































                                                                                        
                                 
   
require 'lfs'
require 'pl'
local function build_trainer(ifname)
    local host_param_repo = nerv.ParamRepo()
    local mat_type
    local src_loc_type
    local train_loc_type
    host_param_repo:import(ifname, nil, gconf)
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
        src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    else
        mat_type = gconf.cumat_type
        src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    end
    local param_repo = host_param_repo:copy(train_loc_type)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()

    network = nerv.Network("nt", gconf, {network = network})
    network:init(gconf.batch_size, gconf.chunk_size)
    global_transf = nerv.Network("gt", gconf, {network = global_transf})
    global_transf:ini