aboutsummaryrefslogblamecommitdiff
path: root/nerv/examples/asr_trainer.lua
blob: 6bdf57c073694e2dc710b3c29556b80941f2a06c (plain) (tree)
1
2
3
4
5
6
7
8
9
10

             
                                    
                                            
                  


                                              

                                  

                                                         

                                   

                                                           
       




                                                           





                                                                        







                                                                                

                            
                                                                      
                                 
                     
                                                   
                            

                                  



                                              
                                      

                                        
                             
                        
               
                            


                                                 


                                                             



                                                                           
                                                              




                                                
               
                                                          



                                                           







                                                                                    
                      

                                        



                                                         
                              

                                








                                                                     
           
                                                               



                            
















                                                 































                                                                          






                                  








                            
                 



                       


                                       


                                              








                                                                                   







                                         
              








                                                                             



                                                   

                                   


                                                                                      
                             
 
             







                                                                

















































                                                                                        
                                 
   
require 'lfs'
require 'pl'
local function build_trainer(ifname)
    local host_param_repo = nerv.ParamRepo()
    local mat_type
    local src_loc_type
    local train_loc_type
    host_param_repo:import(ifname, nil, gconf)
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
        src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    else
        mat_type = gconf.cumat_type
        src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    end
    local param_repo = host_param_repo:copy(train_loc_type)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()

    network = nerv.Network("nt", gconf, {network = network})
    network:init(gconf.batch_size, 1)
    global_transf = nerv.Network("gt", gconf, {network = global_transf})
    global_transf:init(gconf.batch_size, 1)

    --- Run a single sweep over the utterances listed in `scp_file`.
    -- @param prefix filename prefix for the exported params (CV pass only);
    --        nil skips writing back
    -- @param scp_file the .scp list of utterances to process
    -- @param bp true for a training (back-propagation) pass, false for a
    --        cross-validation (forward-only) pass
    -- @param rebind_param_repo optional ParamRepo to swap in before the pass
    -- @return accuracy, the host-side ParamRepo, and the exported filename
    --         (nil unless a CV pass with a non-nil prefix wrote one)
    local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
        -- rebind the params if necessary
        if rebind_param_repo then
            host_param_repo = rebind_param_repo
            param_repo = host_param_repo:copy(train_loc_type)
            layer_repo:rebind(param_repo)
            rebind_param_repo = nil -- drop the extra reference early
        end
        gconf.randomize = bp
        -- build buffer
        local buffer = make_buffer(make_readers(scp_file, layer_repo))
        -- initialize the network
        gconf.cnt = 0
        -- NOTE(review): err_input was an accidental global; declared local.
        -- It is the (constant) gradient fed to the loss layer each mini-batch.
        local err_input = {mat_type(gconf.batch_size, 1)}
        err_input[1]:fill(1)
        network:epoch_init()
        global_transf:epoch_init()
        for data in buffer.get_data, buffer do
            -- print stat periodically (every 1000 mini-batches)
            gconf.cnt = gconf.cnt + 1
            if gconf.cnt == 1000 then
                print_stat(layer_repo)
                mat_type.print_profile()
                mat_type.clear_profile()
                gconf.cnt = 0
                -- break
            end
            local input = {}
--            if gconf.cnt == 1000 then break end
            for i, e in ipairs(input_order) do
                local id = e.id
                if data[id] == nil then
                    nerv.error("input data %s not found", id)
                end
                -- apply the global transform network when the input asks for it
                local transformed
                if e.global_transf then
                    transformed = nerv.speech_utils.global_transf(data[id],
                                        global_transf,
                                        gconf.frm_ext or 0, 0,
                                        gconf)
                else
                    transformed = data[id]
                end
                table.insert(input, transformed)
            end
            local output = {mat_type(gconf.batch_size, 1)}
            -- NOTE(review): err_output was an accidental global; declared local.
            -- One gradient buffer per input, matching each input's shape.
            local err_output = {}
            for i = 1, #input do
                table.insert(err_output, input[i]:create())
            end
            network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
                                    new_seq = {},
                                    do_train = bp,
                                    input = {input},
                                    output = {output},
                                    err_input = {err_input},
                                    err_output = {err_output}})
            network:propagate()
            if bp then
                network:back_propagate()
                network:update()
            end
            -- collect garbage in-time to save GPU memory
            collectgarbage("collect")
        end
        print_stat(layer_repo)
        mat_type.print_profile()
        mat_type.clear_profile()
        local fname
        if (not bp) then
            -- CV pass: copy params back to the host before exporting
            host_param_repo = param_repo:copy(src_loc_type)
            if prefix ~= nil then
                nerv.info("writing back...")
                fname = string.format("%s_cv%.3f.nerv",
                                    prefix, get_accuracy(layer_repo))
                host_param_repo:export(fname, nil)
            end
        end
        return get_accuracy(layer_repo), host_param_repo, fname
    end
    return<