-- nerv/examples/asr_trainer.lua
local function build_trainer(ifname)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, nil, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()
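    -- choose the matrix backend: host (CPU) matrices when gconf.use_cpu is set,
    -- CUDA matrices otherwise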
    local mat_type
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
    else
        mat_type = gconf.cumat_type
    end
    local iterative_trainer = function (prefix, scp_file, bp)
        gconf.randomize = bp
        -- build buffer
        local buffer = make_buffer(make_readers(scp_file, layer_repo))
        -- initialize the network
        network:init(gconf.batch_size)
        gconf.cnt = 0
        -- the initial error signal fed to back_propagate: a batch_size x 1 matrix of ones
        local err_input = {mat_type(gconf.batch_size, 1)}
        err_input[1]:fill(1)
        for data in buffer.get_data, buffer do
            -- print stats periodically
            gconf.cnt = gconf.cnt + 1
            if gconf.cnt == 1000 then
                print_stat(layer_repo)
                mat_type.print_profile()
                mat_type.clear_profile()
                gconf.cnt = 0
                -- break
            end
            local input = {}
--            if gconf.cnt == 1000 then break end
            for i, e in ipairs(input_order) do
                local id = e.id
                if data[id] == nil then
                    nerv.error("input data %s not found", id)
                end
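                -- apply the global feature transform to this input if the
                -- input-order entry requests it; otherwise use the data as-is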
                local transformed
                if e.global_transf then
                    transformed = nerv.speech_utils.global_transf(data[id],
                                        global_transf,
                                        gconf.frm_ext or 0, 0,
                                        gconf)
                else
                    transformed = data[id]
                end
                table.insert(input, transformed)
            end
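            -- allocate a batch_size x 1 output buffer for the network and,
            -- for each input, an error-output buffer of the same shape to
            -- receive gradients during back-propagation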
            local output = {mat_type(gconf.batch_size, 1)}
            local err_output = {}
            for i = 1, #input do
                table.insert(err_output, input[i]:create())
            end
            network:propagate(input, output)
            if bp then
                network:back_propagate(err_input, err_output, input, output)
                network:update(err_input, input, output)
            end
            -- collect garbage promptly to save GPU memory
            collectgarbage("collect")
        end
        print_stat(layer_repo)
        mat_type.print_profile()
        mat_type.clear_profile()
        if (not bp) and prefix ~= nil then
            nerv.info("writing back...")
            local fname = string.format("%s_cv%.3f.nerv",
                            prefix, get_accuracy(layer_repo))
            network:get_params():export(fname, nil)
        end
        return get_accuracy(layer_repo)
    end
    return iterative_trainer
end
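
-- A rough sketch (hypothetical file names, for illustration only) of how the
-- returned closure might be driven: one call per pass, with bp = true for the
-- training pass and bp = false for cross-validation:
--
--     local trainer = build_trainer("init_params.nerv")          -- hypothetical start params
--     local tr_accu = trainer(nil, gconf.tr_scp, true)           -- training pass, no export
--     local cv_accu = trainer("nnet_iter1", gconf.cv_scp, false) -- CV pass, exports params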

local function check_and_add_defaults(spec)
    for k, v in pairs(spec) do
        gconf[k] = opts[string.gsub(k, '_', '-')].val or gconf[k] or v
    end
end
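
-- check_and_add_defaults resolves each key with the precedence: explicit
-- command-line option (underscores mapped to dashes, e.g. batch_size ->
-- --batch-size) > value already present in gconf > the spec default.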

local function make_options(spec)
    local options = {}
    for k, v in pairs(spec) do
        table.insert(options,
                    {string.gsub(k, '_', '-'), nil, type(v), default = v})
    end
    return options
end
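
-- e.g. the spec entry batch_size = 256 yields the option descriptor
-- {"batch-size", nil, "number", default = 256}, which is what make_options
-- returns for each key.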

local function print_help(options)
    nerv.printf("Usage: <asr_trainer.lua> [options] network_config.lua\n")
    nerv.print_usage(options)
end

local function print_gconf()
    local key_maxlen = 0
    for k, v in pairs(gconf) do
        key_maxlen = math.max(key_maxlen, #k or 0)
    end
    local function pattern_gen()
        return string.format("%%-%ds = %%s\n", key_maxlen)
    end
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(pattern_gen(), "Key", "Value")
    for k, v in pairs(gconf) do
        nerv.printf(pattern_gen(), k or "", v or "")
    end
end
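
-- print_gconf produces an aligned two-column "key = value" listing, e.g.:
--     lrate      = 0.8
--     batch_size = 256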

local trainer_defaults = {
    lrate = 0.8,
    batch_size = 256,
    buffer_size = 81920,
    wcost = 1e-6,
    momentum = 0.9,
    start_halving_inc = 0.5,
    halving_factor = 0.6,
    end_halving_inc = 0.1,
    min_iter = 1,
    max_iter = 20,
    min_halving = 5,
    do_halving = false,
    tr_scp = nil,
    cv_scp = nil,
    cumat_type = nerv.CuMatrixFloat,
    mmat_type = nerv.MMat