aboutsummaryrefslogblamecommitdiff
path: root/nerv/examples/asr_trainer.lua
blob: 645f1ef90081eaef0e61ae99add167630b36a6ab (plain) (tree)
1
2
3
4
5
6
7
8
9
10

             
                                    
                                            
                  


                                              

                                  

                                                         

                                   

                                                           
       




                                                           

                                                            
                                                    
                                                                        
                                                          
 







                                                                                

                            
                                                                      
                                 
                     







                                                                  

                                  
                                           


                                      
                                      

                                        
                             
                        
               
                            
                                 


                                                 
                                         

                                                             

                                       
                                       






                                                                                   
                    



                                                                   
                   
                                                      
                                                
               

                                                               
                                                  



                                                             
                               
                      

                                        



                                                         
                              

                                








                                                                     
           
                                                               



                            

                                                 

                                             


                                            
                                                          


                                  




                                                                                       
                                

               































                                                                          






                                  


                          
                   





                            
                 



                       


                                       


                                              








                                                                                   







                                         
              








                                                                             



                                                   

                                   


                                                                                      
                             
 
             







                                                                

















































                                                                                        
                                 
   
require 'lfs'
require 'pl'
-- Build and return an `iterative_trainer` closure for one network.
-- Relies on hooks defined by the network config file loaded earlier
-- (make_layer_repo, get_network, get_global_transf, get_input_order,
-- make_buffer, make_readers, print_stat, get_accuracy) and on the
-- global `gconf` settings table.
-- ifname: file name of the initial parameter repo to import.
local function build_trainer(ifname)
    local host_param_repo = nerv.ParamRepo()
    local mat_type
    local src_loc_type
    local train_loc_type
    host_param_repo:import(ifname, nil, gconf)
    -- Choose matrix type and parameter location: CPU training keeps
    -- everything on the host, GPU training copies params to the device.
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
        src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
    else
        mat_type = gconf.cumat_type
        src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
        train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
    end
    local param_repo = host_param_repo:copy(train_loc_type)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()

    -- Wrap the raw graphs into runnable nerv.Network instances sized for
    -- the configured batch/chunk dimensions.
    network = nerv.Network("nt", gconf, {network = network})
    network:init(gconf.batch_size, gconf.chunk_size)
    global_transf = nerv.Network("gt", gconf, {network = global_transf})
    global_transf:init(gconf.batch_size, gconf.chunk_size)

    -- Run one full pass over `scp_file`.
    --   prefix: if non-nil on a CV pass, export params to
    --           "<prefix>_cv<accuracy>.nerv"
    --   bp: true = training pass (back-propagate + update);
    --       false = cross-validation pass (forward only)
    --   rebind_param_repo: optional host-side repo to swap in before the
    --       pass (used by the caller to roll back rejected params)
    -- Returns: accuracy, the host-side param repo, and the exported file
    --          name (nil unless params were written back).
    local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
        -- rebind the params if necessary
        if rebind_param_repo then
            host_param_repo = rebind_param_repo
            param_repo = host_param_repo:copy(train_loc_type)
            layer_repo:rebind(param_repo)
            rebind_param_repo = nil
        end
        -- shuffle data only when training
        gconf.randomize = bp
        -- build buffer
        local buffer = make_buffer(make_readers(scp_file, layer_repo))
        -- initialize the network
        gconf.cnt = 0
        -- err_input: all-ones matrices (one per chunk step) fed as the
        -- top-level error signal; output: scratch matrices for the
        -- network's single output port.
        local err_input = {{}}
        local output = {{}}
        for i = 1, gconf.chunk_size do
            local mini_batch = mat_type(gconf.batch_size, 1)
            mini_batch:fill(1)
            table.insert(err_input[1], mini_batch)
            table.insert(output[1], mat_type(gconf.batch_size, 1))
        end
        network:epoch_init()
        global_transf:epoch_init()
        for d in buffer.get_data, buffer do
            -- print stat periodically (every 1000 mini-batches)
            gconf.cnt = gconf.cnt + 1
            if gconf.cnt == 1000 then
                print_stat(layer_repo)
                mat_type.print_profile()
                mat_type.clear_profile()
                gconf.cnt = 0
                -- break
            end
            local input = {}
            local err_output = {}
--            if gconf.cnt == 1000 then break end
            -- Gather each declared input, optionally passing it through
            -- the global transform, and allocate matching error buffers.
            for i, e in ipairs(input_order) do
                local id = e.id
                if d.data[id] == nil then
                    nerv.error("input data %s not found", id)
                end
                local transformed = {}
                local err_output_i = {}
                if e.global_transf then
                    for _, mini_batch in ipairs(d.data[id]) do
                        table.insert(transformed,
                                        nerv.speech_utils.global_transf(mini_batch,
                                            global_transf,
                                            gconf.frm_ext or 0, 0,
                                            gconf))
                    end
                else
                    transformed = d.data[id]
                end
                for _, mini_batch in ipairs(transformed) do
                    table.insert(err_output_i, mini_batch:create())
                end
                table.insert(err_output, err_output_i)
                table.insert(input, transformed)
            end
            network:mini_batch_init({seq_length = d.seq_length,
                                    new_seq = d.new_seq,
                                    do_train = bp,
                                    input = input,
                                    output = output,
                                    err_input = err_input,
                                    err_output = err_output})
            network:propagate()
            if bp then
                network:back_propagate()
                network:update()
            end
            -- collect garbage in-time to save GPU memory
            collectgarbage("collect")
        end
        print_stat(layer_repo)
        mat_type.print_profile()
        mat_type.clear_profile()
        local fname
        -- Only CV passes copy params back to the host and (optionally)
        -- export them with the CV accuracy embedded in the file name.
        if (not bp) then
            host_param_repo = param_repo:copy(src_loc_type)
            if prefix ~= nil then
                nerv.info("writing back...")
                fname = string.format("%s_cv%.3f.nerv",
                                    prefix, get_accuracy(layer_repo))
                host_param_repo:export(fname, nil)
            end
        end
        return get_accuracy(layer_repo), host_param_repo, fname
    end
    return iterative_trainer
end

-- Merge trainer settings into the global `gconf` table.
-- Precedence: command-line option > network config file > trainer default
-- (defaults arrive pre-baked in the parsed `opts`). If --resume-from was
-- given, the saved training state file replaces `gconf` wholesale.
local function check_and_add_defaults(spec, opts)
    -- Look up the parsed option for a spec key (with '_' mapped to '-').
    local function lookup(key)
        local entry = opts[string.gsub(key, '_', '-')]
        return entry.val, entry.specified
    end
    local resume_file = lookup("resume_from")
    if resume_file then
        nerv.info("resuming from previous training state")
        gconf = dofile(resume_file)
        return
    end
    for key in pairs(spec) do
        local val, given = lookup(key)
        if not given and gconf[key] ~= nil then
            -- config file wins over the (unspecified) option default
            nerv.info("using setting in network config file: %s = %s", key, gconf[key])
        elseif val ~= nil then
            nerv.info("using setting in options: %s = %s", key, val)
            gconf[key] = val
        end
    end
end

-- Translate a table of default settings into an option-spec list for
-- nerv.parse_args: each key `foo_bar` becomes a "foo-bar" option whose
-- type and default are derived from its default value.
local function make_options(spec)
    local opt_list = {}
    for key, default_val in pairs(spec) do
        local flag = string.gsub(key, '_', '-')
        opt_list[#opt_list + 1] = {flag, nil, type(default_val), default = default_val}
    end
    return opt_list
end

-- Print the usage banner followed by the generated option descriptions.
local function print_help(options)
    local banner = "Usage: <asr_trainer.lua> [options] network_config.lua\n"
    nerv.printf(banner)
    nerv.print_usage(options)
end

-- Pretty-print every entry of the global `gconf` table, with the key
-- column padded to the longest key name.
local function print_gconf()
    local key_maxlen = 0
    for k, v in pairs(gconf) do
        -- keys from pairs() are never nil, so #k is always defined
        key_maxlen = math.max(key_maxlen, #k)
    end
    local function pattern_gen()
        return string.format("%%-%ds = %%s\n", key_maxlen)
    end
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(pattern_gen(), "Key", "Value")
    for k, v in pairs(gconf) do
        -- Fix: the original passed `v or ""`, which silently printed
        -- boolean false as an empty string and handed raw `true` to the
        -- "%s" spec (an error under Lua 5.1's string.format). tostring()
        -- renders every value faithfully.
        nerv.printf(pattern_gen(), k, tostring(v))
    end
end

-- Serialize the current `gconf` to `fname` as a loadable Lua chunk
-- ("return { ... }") so training can later be resumed via --resume-from.
local function dump_gconf(fname)
    -- Fix: io.open returns nil + message on failure; the original indexed
    -- the nil handle, raising an opaque error. assert() surfaces the real
    -- reason (e.g. permission denied) with the file name.
    local f = assert(io.open(fname, "w"))
    f:write("return ")
    f:write(table.tostring(gconf))
    f:close()
end

-- Trainer hyper-parameter defaults. Each key is also exposed as a
-- command-line option with '_' mapped to '-' (see make_options).
-- Precedence: command line > network config file > these defaults.
local trainer_defaults = {
    lrate = 0.8,             -- initial learning rate
    batch_size = 256,
    chunk_size = 1,          -- presumably sequence chunk length — confirm
    buffer_size = 81920,
    wcost = 1e-6,            -- weight cost (regularization strength)
    momentum = 0.9,
    start_halving_inc = 0.5, -- CV gain below this enables lrate halving
    halving_factor = 0.6,    -- lrate multiplier applied once halving starts
    end_halving_inc = 0.1,   -- CV gain below this (while halving) stops training
    cur_iter = 1,            -- first iteration number (handy when resuming)
    min_iter = 1,            -- never stop before this many iterations
    max_iter = 20,
    min_halving = 5,         -- earliest iteration at which halving may begin
    do_halving = false,
    cumat_tname = "nerv.CuMatrixFloat", -- device (GPU) matrix type name
    mmat_tname = "nerv.MMatrixFloat",   -- host (CPU) matrix type name
    debug = false,
}

-- Build the full command-line spec: one option per trainer default plus
-- bookkeeping options that are not hyper-parameters.
local options = make_options(trainer_defaults)
local extra_opt_spec = {
    {"tr-scp", nil, "string"},      -- training set scp (file list)
    {"cv-scp", nil, "string"},      -- cross-validation set scp
    {"resume-from", nil, "string"}, -- previously dumped training state
    {"help", "h", "boolean", default = false, desc = "show this help information"},
    {"dir", nil, "string", desc = "specify the working directory"},
}

table.extend(options, extra_opt_spec)

arg, opts = nerv.parse_args(arg, options)

-- The single positional argument is the network config file.
if #arg < 1 or opts["help"].val then
    print_help(options)
    return
end

-- Load the network config (defines gconf and the make_*/get_* hooks).
dofile(arg[1])

--[[

Rule: command-line option overrides network config overrides trainer default.
Note: a config key such as aaa_bbbb_cc can be overridden by passing
--aaa-bbbb-cc as a command-line argument.

]]--

check_and_add_defaults(trainer_defaults, opts)
-- Resolve the configured matrix type names into actual constructors.
gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
-- Fix: the original read `econf.use_cpu`, but no `econf` is defined
-- anywhere, so this line always raised "attempt to index a nil value".
-- Honor an optional use_cpu flag from the network config, defaulting to
-- GPU training.
gconf.use_cpu = gconf.use_cpu or false

local pf0 = gconf.initialized_param -- initial parameter file(s)
local date_pattern = "%Y%m%d%H%M%S"
local logfile_name = "log"
local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
local rebind_param_repo = nil -- set by the training loop on rollback

print_gconf()
-- Refuse to clobber an existing working directory.
if not lfs.mkdir(working_dir) then
    nerv.error("[asr_trainer] working directory already exists")
end
-- copy the network config
dir.copyfile(arg[1], working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))
path.chdir(working_dir)

-- Main training driver: an initial CV pass establishes the baseline
-- accuracy; each iteration then runs a training pass followed by a CV
-- pass, keeping the new params only when CV accuracy improves.
local trainer = build_trainer(pf0)
local pr_prev -- host param repo holding the last accepted params
gconf.accu_best, pr_prev = trainer(nil, gconf.cv_scp, false)
nerv.info("initial cross validation: %.3f", gconf.accu_best)
for i = gconf.cur_iter, gconf.max_iter do
    local stop = false
    gconf.cur_iter = i
    dump_gconf(string.format("iter_%d.meta", i))
    repeat -- trick to implement `continue` statement
        nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
        local accu_tr = trainer(nil, gconf.tr_scp, true, rebind_param_repo)
        -- Fix: clear the rollback request once the training pass above has
        -- consumed it. The original never reset this variable, so after a
        -- single rejected iteration every subsequent training pass kept
        -- rebinding to the stale repo, silently discarding any params
        -- accepted in between.
        rebind_param_repo = nil
        nerv.info("[TR] training set %d: %.3f", i, accu_tr)
        -- Param file prefix: basename of the first initial param file
        -- (directory and extension stripped) + timestamp + lrate/accuracy.
        local param_prefix = string.format("%s_%s_iter_%d_lr%f_tr%.3f",
                                string.gsub(
                                    (string.gsub(pf0[1], "(.*/)(.*)", "%2")),
                                    "(.*)%..*", "%1"),
                                os.date(date_pattern),
                                i, gconf.lrate,
                                accu_tr)
        local accu_new, pr_new, param_fname = trainer(param_prefix, gconf.cv_scp, false)
        nerv.info("[CV] cross validation %d: %.3f", i, accu_new)
        local accu_prev = gconf.accu_best
        if accu_new < gconf.accu_best then
            -- CV got worse: mark the exported file rejected and request a
            -- rollback to the previous best repo on the next training pass.
            nerv.info("rejecting the trained params, rollback to the previous one")
            file.move(param_fname, param_fname .. ".rejected")
            rebind_param_repo = pr_prev
            break -- `continue` equivalent
        else
            nerv.info("accepting the trained params")
            gconf.accu_best = accu_new
            pr_prev = pr_new
            gconf.initialized_param = {path.join(path.currentdir(), param_fname)}
        end
        -- Once halving has begun, stop when the CV improvement drops below
        -- end_halving_inc (but never before min_iter iterations).
        if gconf.do_halving and
            gconf.accu_best - accu_prev < gconf.end_halving_inc and
            i > gconf.min_iter then
            stop = true
            break
        end
        -- Begin halving the lrate when improvement falls below
        -- start_halving_inc, but not before iteration min_halving.
        if gconf.accu_best - accu_prev < gconf.start_halving_inc and
            i >= gconf.min_halving then
            gconf.do_halving = true
        end
        if gconf.do_halving then
            gconf.lrate = gconf.lrate * gconf.halving_factor
        end
    until true
    if stop then break end
--    nerv.Matrix.print_profile()
end