diff options
Diffstat (limited to 'fastnn/example/asgd_sds_trainer.lua')
-rw-r--r-- | fastnn/example/asgd_sds_trainer.lua | 328 |
1 files changed, 328 insertions, 0 deletions
diff --git a/fastnn/example/asgd_sds_trainer.lua b/fastnn/example/asgd_sds_trainer.lua new file mode 100644 index 0000000..44611b2 --- /dev/null +++ b/fastnn/example/asgd_sds_trainer.lua @@ -0,0 +1,328 @@ +package.path="/home/slhome/wd007/.luarocks/share/lua/5.1/?.lua;/home/slhome/wd007/.luarocks/share/lua/5.1/?/init.lua;/sgfs/users/wd007/src/nerv/install/share/lua/5.1/?.lua;/sgfs/users/wd007/src/nerv/install/share/lua/5.1/?/init.lua;"..package.path; +package.cpath="/home/slhome/wd007/.luarocks/lib/lua/5.1/?.so;/sgfs/users/wd007/src/nerv/install/lib/lua/5.1/?.so;"..package.cpath +local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1") +require 'nerv' + +require 'fastnn' +require 'libhtkio' +require 'threads' + +dofile("fastnn/fastnn_baseline.lua") + +env = string.format([[ +package.path="/home/slhome/wd007/.luarocks/share/lua/5.1/?.lua;/home/slhome/wd007/.luarocks/share/lua/5.1/?/init.lua;/sgfs/users/wd007/src/nerv/install/share/lua/5.1/?.lua;/sgfs/users/wd007/src/nerv/install/share/lua/5.1/?/init.lua;"..package.path; +package.cpath="/home/slhome/wd007/.luarocks/lib/lua/5.1/?.so;/sgfs/users/wd007/src/nerv/install/lib/lua/5.1/?.so;"..package.cpath +local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1") +]]) + + +train_thread_code = [[ +%s + +require 'nerv' +require 'fastnn' +dofile("fastnn/fastnn_baseline.lua") +os.execute("export MALLOC_CHECK_=0") + +local thread_idx = %d +local feat_repo_shareid = %d +local data_mutex_shareid = %d +local master_shareid = %d +local gpu_shareid = %d +local xent_shareid = %d +local batch_size = %d +local lrate = %f +local bp = %d +local scp_file = '%s' +local nnet_in = '%s' +local nnet_out = '%s' + +local share_mutex = threads.Mutex(data_mutex_shareid) +local share_master = fastnn.ModelSync(master_shareid) +local share_gpu = fastnn.CDevice(gpu_shareid) +local share_xent = fastnn.CXent(xent_shareid) + +if bp == 0 then + bp = false +else + bp = true + gconf.tr_scp = scp_file +end + +gconf.randomize = bp +gconf.lrate = lrate +gconf.batch_size = batch_size +gconf.network[1] = nnet_in +nerv.info_stderr("input network: %%s", gconf.network[1]) +--nerv.info_stderr(gconf.randomize) +nerv.info_stderr("input batch_size: %%d", gconf.batch_size) +nerv.info_stderr("input scp_file: %%s", scp_file) +nerv.info_stderr("input lrate: %%f", gconf.lrate) + + +share_mutex:lock() +share_gpu:select_gpu() + +nerv.context = nerv.CCuContext() +--print(nerv.context) + +nerv.info_stderr("thread %%d loading transf ...", thread_idx) +local param_transf_repo = nerv.ParamRepo() +param_transf_repo:import(gconf.transf, nil, gconf) +local transf_node_repo = make_transf_node_repo(param_transf_repo) +local transf_layer_repo = make_transf_link_repo(transf_node_repo, param_transf_repo) +local transf = transf_layer_repo:get_layer("global_transf") + +nerv.info_stderr("thread %%d loading network ...", thread_idx) +local param_network_repo = nerv.ParamRepo() +param_network_repo:import(gconf.network, nil, gconf) +local network_node_repo = make_network_node_repo(param_network_repo) +local network_layer_repo = make_network_link_repo(network_node_repo, param_network_repo) +local network = get_network(network_layer_repo) +share_mutex:unlock() + +local buffer = make_buffer(make_readers(nil, transf_layer_repo, feat_repo_shareid, data_mutex_shareid)) + +local input_order = get_input_order() + + -- initialize the network + network:init(gconf.batch_size) + gconf.cnt = 0 + err_input = {nerv.CuMatrixFloat(gconf.batch_size, 1)} + err_input[1]:fill(1) + + share_master:Initialize(network) + share_master:SyncInc() + + for data in buffer.get_data, buffer do + + gconf.cnt = gconf.cnt + 1 + if gconf.cnt == 2000 then + print_stat(network_node_repo) + gconf.cnt = 0 + end + + local input = {} + + for i, id in ipairs(input_order) do + if data[id] == nil then + nerv.error("input data %%s not found", id) + end + table.insert(input, data[id]) + end + + local output = {nerv.CuMatrixFloat(gconf.batch_size, 1)} + err_output = {input[1]:create()} + network:propagate(input, output) + + if bp then + network:back_propagate(err_input, err_output, input, output) + network:gradient(err_input, input, output) + + share_master:LockModel() + share_master:WeightToD(network) + network:update_gradient() + -- network:update(err_input, input, output) + share_master:WeightFromD(network) + share_master:UnLockModel() + end + + -- collect garbage in-time to save GPU memory + collectgarbage("collect") + end + + --print_stat(network_node_repo) + local ce_crit = network_node_repo:get_layer("ce_crit") + local xent = fastnn.CXent(ce_crit.total_frames, ce_crit.total_correct, ce_crit.total_ce, ce_crit.total_ce) + + share_master:LockModel() + share_xent:add(xent) + share_master:SyncDec() + --print(string.format("ThreadCount: %%d", share_master:ThreadCount())) + if share_master:ThreadCount() == 0 and bp then + share_master:WeightToD(network) + local fname = string.format("%%s_tr%%.3f", + nnet_out, frame_acc(share_xent)) + nerv.info_stderr("writing back %%s ...", fname) + network:get_params():export(fname, nil) + end + share_master:UnLockModel() +]] + + +function get_train_thread(train_thread_code, env, thread_idx, feat_repo_shareid, + data_mutex_shareid, master_shareid, gpu_shareid, xent_shareid, + batch_size, lrate, bp, scp_file, nnet_in, nnet_out) + return string.format(train_thread_code, env, thread_idx, feat_repo_shareid, + data_mutex_shareid, master_shareid, gpu_shareid, xent_shareid, + batch_size, lrate, bp, scp_file, nnet_in, nnet_out) +end + +function trainer(batch_size, lrate, bp, scp_file, nnet_in, nnet_out, num_threads) + local train_threads={} + local trainer = {} + local num_threads=num_threads + + local data_mutex = threads.Mutex() + local data_mutex_shareid = data_mutex:id() + + local master = fastnn.CModelSync() + local master_shareid = master:id() + --print(master) + + local xent = fastnn.CXent() + local xent_shareid = xent:id() + --print(xent) + + local gpu = fastnn.CDevice() + local gpu_shareid = gpu:id() + --print(gpu_shareid) + gpu:init() + + local feat_repo = nerv.TNetFeatureRepo(scp_file, gconf.htk_conf, gconf.frm_ext) + local feat_repo_shareid = feat_repo:id() + + for i=1,num_threads,1 do + + train_threads[i] = get_train_thread(train_thread_code, env, i, feat_repo_shareid, + data_mutex_shareid, master_shareid, gpu_shareid, xent_shareid, + batch_size, lrate, bp, scp_file, nnet_in, nnet_out) + --print(train_threads[i]) + trainer[i] = threads.Thread(train_threads[i]) + end + + nerv.info_stderr('| waiting for thread...') + + for i=1,num_threads,1 do + trainer[i]:free() + end + + print_xent(xent) + + nerv.info_stderr('| all thread finished!') + + return frame_acc(xent) +end + +function get_filename(fname) + return string.gsub((string.gsub(fname, "(.*/)(.*)", "%2")),"(.*)%..*", "%1") +end + +function do_sds(tr_scp, sds_scp, sds_rate) + math.randomseed(os.time()) + local scp_file = io.open(tr_scp, "r") + local sds_file = io.open(sds_scp, "w") + for line in scp_file:lines() do + rate = math.random() + if (rate < sds_rate) then + sds_file:write(line.."\n") + end + end + scp_file:close() + sds_file:close() +end + +function print_tag(iter) + io.stderr:write(string.format("########################################################\n")) + io.stderr:write(string.format("# NN TRAINING ITERATION:%d, %s\n", iter, os.date())) + io.stderr:write(string.format("########################################################\n")) +end + + +start_halving_inc = 0.5 +halving_factor = 0.8 +end_halving_inc = 0.1 +min_iter = 1 +max_iter = 20 +min_halving = 0 +gconf.batch_size = 256 +pf0 = get_filename(gconf.network[1]) +nnet_in = gconf.network[1] +nnet_out = "" +sds_scp = "tr_sds_"..string.format("%.4d", math.random()*10000)..".scp" --"tr_sds.scp" +sds_factor = 0.4 +num_threads = 2 +global_option = nil + +os.execute("export MALLOC_CHECK_=0") +print_gconf() + +-- training begin +nerv.info_stderr("begin initial cross validation") +accu_best = trainer(gconf.batch_size, gconf.lrate, 0, + gconf.cv_scp, nnet_in, "", num_threads) +local do_halving = false +local accu_new = accu_best + +nerv.info_stderr("initial cross validation: %.3f\n", accu_best) + +for i = 1, max_iter do + + if accu_new >= accu_best then + local sds_rate = math.cos((i-1)*11.0/180*math.pi) + if (sds_rate <= sds_factor) then + sds_rate = sds_factor + end + nerv.info_stderr("iteration %d sds_rate: %.6f", i, sds_rate) + do_sds(gconf.tr_scp, sds_scp, sds_rate) + end + + nnet_out=pf0.."_iter"..i + --print(nnet_out) + print_tag(i) + nerv.info_stderr("[NN] begin iteration %d learning_rate: %.3f batch_size: %d.", i, gconf.lrate, gconf.batch_size) + accu_tr = trainer(gconf.batch_size, gconf.lrate, 1, + sds_scp, nnet_in, nnet_out, num_threads) + collectgarbage("collect") + nerv.info_stderr("[TR] end iteration %d frame_accuracy: %.3f.\n", i, accu_tr) + os.execute("sleep " .. 3) + + nnet_out = nnet_out.."_tr"..accu_tr + accu_new = trainer(gconf.batch_size, gconf.lrate, 0, + gconf.cv_scp, nnet_out, "", num_threads) + collectgarbage("collect") + nerv.info_stderr("[CV] end iteration %d frame_accuracy: %.3f.\n\n", i, accu_new) + os.execute("sleep " .. 3) + + local nnet_tmp = string.format("%s_%s_iter_%d_lr%f_tr%.3f_cv%.3f", + pf0, + os.date("%Y%m%d%H%M%S"), + i, gconf.lrate, accu_tr, accu_new) + + -- TODO: revert the weights + local accu_diff = accu_new - accu_best + local cmd + if accu_new > accu_best then + accu_best = accu_new + nnet_in = nnet_tmp + gconf.batch_size = gconf.batch_size + 128 + if gconf.batch_size > 1024 then + gconf.batch_size = 1024 + end + else + -- reject + nnet_tmp = nnet_tmp.."_rejected" + do_halving = true + end + cmd = "mv "..nnet_out.." "..nnet_tmp + os.execute(cmd) + + if do_halving and accu_diff < end_halving_inc and i > min_iter then + break; + end + + if accu_diff < start_halving_inc and i >= min_halving then + do_halving = true + end + + if do_halving then + gconf.lrate = gconf.lrate * halving_factor + halving_factor = halving_factor - 0.025 + if halving_factor < 0.6 then + halving_factor = 0.6 + end + end + nerv.info_stderr("iteration %d done!", i) +end + + |