-- fastnn/example/asgd_sds_trainer.lua
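--
-- Multi-threaded ASGD trainer with stochastic data sweeping (SDS): worker
-- threads train the same network in parallel, synchronizing their weights
-- through fastnn.ModelSync, while the main loop below draws a random subset
-- of the training data each iteration, cross-validates the result and
-- adjusts the learning rate and batch size accordingly.
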
NERV_ROOT = "/sgfs/users/wd007/src/nerv-2"

env = string.format([[
package.path="/home/slhome/wd007/.luarocks/share/lua/5.1/?.lua;/home/slhome/wd007/.luarocks/share/lua/5.1/?/init.lua;%s/install/share/lua/5.1/?.lua;%s/install/share/lua/5.1/?/init.lua;"..package.path; 
package.cpath="/home/slhome/wd007/.luarocks/lib/lua/5.1/?.so;%s/install/lib/lua/5.1/?.so;"..package.cpath
local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
]], NERV_ROOT, NERV_ROOT, NERV_ROOT)
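
-- The env preamble above is also injected into every worker thread (the
-- leading %s placeholder in train_thread_code), so the threads see the same
-- package paths and luarocks context as the main script.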

loadstring(env)()

require 'nerv'

require 'fastnn'
require 'libhtkio'
require 'threads'

dofile("fastnn/example/fastnn_baseline.lua")



-- Worker thread body, kept as a format-string template: get_train_thread()
-- fills the %s/%d/%f placeholders with per-thread arguments, so any literal
-- percent sign that must survive into the generated chunk is escaped as %%.
train_thread_code = [[
%s

require 'nerv'
require 'fastnn'
require 'libhtkio'

dofile("fastnn/example/fastnn_baseline.lua")
os.execute("export MALLOC_CHECK_=0")

local thread_idx = %d
local feat_repo_shareid = %d
local data_mutex_shareid = %d 
local master_shareid = %d
local gpu_shareid = %d
local xent_shareid = %d
local batch_size = %d
local lrate = %f
local bp = %d
local scp_file = '%s'
local nnet_in = '%s'
local nnet_out = '%s'
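
-- Re-attach the objects shared with the master process through their share
-- ids: the data mutex serializes feature reading and parameter loading,
-- ModelSync holds the globally shared weights and CXent accumulates the
-- cross-entropy statistics across threads.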

local share_mutex = threads.Mutex(data_mutex_shareid)
local share_master = fastnn.ModelSync(master_shareid)
local share_gpu = fastnn.CDevice(gpu_shareid)
local share_xent = fastnn.CXent(xent_shareid)
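
-- bp == 1 requests a training pass (back-propagation on scp_file with a
-- randomized buffer); bp == 0 runs a forward-only cross-validation pass.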

if bp == 0 then
	bp = false
else
	bp = true
	gconf.tr_scp = scp_file
end
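
-- GPU selection, parameter import and network construction are serialized
-- across threads by the shared mutex.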

share_mutex:lock()

gconf.randomize = bp
gconf.lrate = lrate
gconf.batch_size = batch_size
gconf.initialized_param[2] = nnet_in
nerv.info_stderr("input network: %%s", gconf.initialized_param[2])
--nerv.info_stderr(gconf.randomize)
nerv.info_stderr("input batch_size: %%d", gconf.batch_size)
nerv.info_stderr("input scp_file: %%s", scp_file)
nerv.info_stderr("input lrate: %%f", gconf.lrate)



share_gpu:select_gpu()

nerv.context = nerv.CCuContext()
--print(nerv.context)

nerv.info_stderr("thread %%d loading parameters ...", thread_idx)
local param_repo = nerv.ParamRepo()
param_repo:import(gconf.initialized_param, nil, gconf)
local layer_repo = make_layer_repo(param_repo)
local network = get_network(layer_repo)
local global_transf = get_global_transf(layer_repo)

share_mutex:unlock()

local buffer = make_buffer(make_readers(nil, layer_repo, feat_repo_shareid, data_mutex_shareid))

local input_order = get_input_order()

-- initialize the network
network:init(gconf.batch_size)
gconf.cnt = 0
-- err_input is a constant all-ones matrix: every frame contributes with
-- unit weight when the CE criterion is back-propagated
err_input = {nerv.CuMatrixFloat(gconf.batch_size, 1)}
err_input[1]:fill(1)

share_master:Initialize(network)
share_master:SyncInc()

for data in buffer.get_data, buffer do

    gconf.cnt = gconf.cnt + 1
    if gconf.cnt == 2000 then
        print_stat(layer_repo)
        gconf.cnt = 0
    end

    local input = {}

    for i, e in ipairs(input_order) do
        local id = e.id
        if data[id] == nil then
            nerv.error("input data %%s not found", id)
        end
        local transformed
        if e.global_transf then
            transformed = nerv.speech_utils.global_transf(data[id],
                                                          global_transf,
                                                          gconf.frm_ext or 0, 0,
                                                          gconf)
        else
            transformed = data[id]
        end
        table.insert(input, transformed)
    end

    local output = {nerv.CuMatrixFloat(gconf.batch_size, 1)}
    err_output = {}
    for i = 1, #input do
        table.insert(err_output, input[i]:create())
    end

    network:propagate(input, output)

    if bp then
        network:back_propagate(err_input, err_output, input, output)
        network:gradient(err_input, input, output)

        -- ASGD update under the model lock: copy the shared weights into
        -- this thread's network (WeightToD), apply the locally accumulated
        -- gradient, then copy the result back to the shared model
        -- (WeightFromD)
        share_master:LockModel()
        share_master:WeightToD(network)
        network:update_gradient()
        -- network:update(err_input, input, output)
        share_master:WeightFromD(network)
        share_master:UnLockModel()
    end

    -- collect garbage in time to save GPU memory
    collectgarbage("collect")
end

-- Merge this thread's cross-entropy statistics into the shared accumulator;
-- the last training thread to finish writes the updated network back to disk.
--print_stat(network_node_repo)
local ce_crit = layer_repo:get_layer("ce_crit")
local xent = fastnn.CXent(ce_crit.total_frames, ce_crit.total_correct, ce_crit.total_ce, ce_crit.total_ce)

share_master:LockModel()
share_xent:add(xent)
share_master:SyncDec()
--print(string.format("ThreadCount: %%d", share_master:ThreadCount()))
if share_master:ThreadCount() == 0 and bp then
    share_master:WeightToD(network)
    local fname = string.format("%%s_tr%%.3f",
                                nnet_out, frame_acc(share_xent))
    nerv.info_stderr("writing back %%s ...", fname)
    network:get_params():export(fname, nil)
end
share_master:UnLockModel()
]]


function get_train_thread(train_thread_code, env, thread_idx, feat_repo_shareid,
			data_mutex_shareid, master_shareid, gpu_shareid, xent_shareid,
			batch_size, lrate, bp, scp_file, nnet_in, nnet_out)
	return string.format(train_thread_code, env, thread_idx, feat_repo_shareid,
			data_mutex_shareid, master_shareid, gpu_shareid, xent_shareid,
			batch_size, lrate, bp, scp_file, nnet_in, nnet_out)
end
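
-- Launch num_threads workers that share one model (fastnn.CModelSync), one
-- feature repository and one cross-entropy accumulator; after the workers
-- are freed, print the combined statistics and return the overall frame
-- accuracy.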

function trainer(batch_size, lrate, bp, scp_file, nnet_in, nnet_out, num_threads)
	local train_threads={}
	local trainer = {}

	local data_mutex = threads.Mutex()
	local data_mutex_shareid = data_mutex:id()

	local master = fastnn.CModelSync()
	local master_shareid = master:id()
	--print(master)

	local xent = fastnn.CXent()
	local xent_shareid = xent:id()
	--print(xent)

	local gpu = fastnn.CDevice()
	local gpu_shareid = gpu:id()
	--print(gpu_shareid)
	gpu:init()

	local feat_repo = nerv.TNetFeatureRepo(scp_file, gconf.htk_conf, gconf.frm_ext)
	local feat_repo_shareid = feat_repo:id()

	for i = 1, num_threads do
		train_threads[i] = get_train_thread(train_thread_code, env, i, feat_repo_shareid,
				data_mutex_shareid, master_shareid, gpu_shareid, xent_shareid,
				batch_size, lrate, bp, scp_file, nnet_in, nnet_out)
		--print(train_threads[i])
		trainer[i] = threads.Thread(train_threads[i])
	end

	nerv.info_stderr('| waiting for threads ...')

	for i = 1, num_threads do
		trainer[i]:free()
	end

	print_xent(xent)

	nerv.info_stderr('| all threads finished!')

	return frame_acc(xent)
end

function get_filename(fname)
	return string.gsub((string.gsub(fname, "(.*/)(.*)", "%2")),"(.*)%..*", "%1")
end
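-- e.g. get_filename("exp/dnn/nnet_init.nerv") --> "nnet_init": the directory
-- and the last extension are stripped (the example path is made up).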

-- Stochastic data sweeping: write each line of the full training scp into
-- sds_scp with probability sds_rate, producing a random subset for one
-- training iteration.
function do_sds(tr_scp, sds_scp, sds_rate)
	math.randomseed(os.time())
	local scp_file = io.open(tr_scp, "r")
	local sds_file = io.open(sds_scp, "w")
	for line in scp_file:lines() do
		local rate = math.random()
		if rate < sds_rate then
			sds_file:write(line.."\n")
		end
	end
	scp_file:close()
	sds_file:close()
end

function print_tag(iter)
	 io.stderr:write(string.format("########################################################\n"))
	 io.stderr:write(string.format("# NN TRAINING ITERATION:%d, %s\n", iter, os.date()))
	 io.stderr:write(string.format("########################################################\n"))
end
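
-- Training schedule: after an initial cross-validation pass, each iteration
-- trains on an SDS subset of the data and is accepted only if CV frame
-- accuracy improves. On acceptance batch_size grows by 128 (capped at 1024);
-- a rejected model is renamed *_rejected. Halving starts once an iteration
-- is rejected or the CV gain drops below start_halving_inc: lrate is then
-- multiplied by halving_factor, which itself decays from 0.8 towards 0.6.
-- Training stops when the CV gain falls below end_halving_inc after halving
-- has started (and i > min_iter).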
	

start_halving_inc = 0.5
halving_factor = 0.8
end_halving_inc = 0.1
min_iter = 1
max_iter = 20
min_halving = 0
gconf.batch_size = 256
pf0 = get_filename(gconf.initialized_param[2])
nnet_in = gconf.initialized_param[2]
nnet_out = ""
sds_scp = "tr_sds_"..string.format("%.4d", math.random()*10000)..".scp"  --"tr_sds.scp"
sds_factor = 0.4
num_threads = 2
global_option = nil

os.execute("export MALLOC_CHECK_=0")
print_gconf()

-- training begin
nerv.info_stderr("begin initial cross validation")  
accu_best = trainer(gconf.batch_size, gconf.lrate, 0, 
				gconf.cv_scp, nnet_in, "", num_threads) 
local do_halving = false
local accu_new = accu_best

nerv.info_stderr("initial cross validation: %.3f\n", accu_best)  

for i = 1, max_iter do
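	-- The SDS keep-rate follows a cosine schedule floored at sds_factor:
	-- sds_rate = cos((i-1) * 11 degrees), i.e. 1.0 at i = 1, about 0.98 at
	-- i = 2, about 0.41 at i = 7 and clipped to 0.4 from i = 8 on. A new
	-- subset is drawn only when the previous CV accuracy did not drop
	-- below the best so far.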

	if accu_new >= accu_best then
		local sds_rate = math.cos((i-1)*11.0/180*math.pi)
		if (sds_rate <= sds_factor) then
			sds_rate = sds_factor
		end
		nerv.info_stderr("iteration %d sds_rate: %.6f", i, sds_rate)
		do_sds(gconf.tr_scp, sds_scp, sds_rate)
	end
		
        nnet_out=pf0.."_iter"..i
	--print(nnet_out)
	print_tag(i)
	nerv.info_stderr("[NN] begin iteration %d learning_rate: %.3f batch_size: %d.", i, gconf.lrate, gconf.batch_size)
	accu_tr = trainer(gconf.batch_size, gconf.lrate, 1,
					sds_scp, nnet_in, nnet_out, num_threads)
	collectgarbage("collect")
	nerv.info_stderr("[TR] end iteration %d frame_accuracy: %.3f.\n", i, accu_tr)	
	os.execute("sleep " .. 3)

	nnet_out = nnet_out.."_tr"..accu_tr
	accu_new = trainer(gconf.batch_size, gconf.lrate, 0,
					gconf.cv_scp, nnet_out, "", num_threads)
	collectgarbage("collect")
	nerv.info_stderr("[CV] end iteration %d frame_accuracy: %.3f.\n\n", i, accu_new)	
	os.execute("sleep " .. 3)

	local nnet_tmp = string.format("%s_%s_iter_%d_lr%f_tr%.3f_cv%.3f", 
					pf0,
					os.date("%Y%m%d%H%M%S"),
   					i, gconf.lrate, accu_tr, accu_new)
	
	-- TODO: revert the weights
	local accu_diff = accu_new - accu_best
	local cmd
	if accu_new > accu_best then
		accu_best = accu_new
		nnet_in = nnet_tmp
		gconf.batch_size = gconf.batch_size + 128
		if gconf.batch_size > 1024 then
			gconf.batch_size = 1024
		end
	else  
	-- reject
		nnet_tmp = nnet_tmp.."_rejected"
		do_halving = true
	end
	cmd = "mv "..nnet_out.." "..nnet_tmp
	os.execute(cmd)	

	if do_halving and accu_diff < end_halving_inc and i > min_iter then
		break;
	end
	
	if accu_diff < start_halving_inc and i >= min_halving then
		do_halving = true
	end

	if do_halving then
		gconf.lrate = gconf.lrate * halving_factor
		halving_factor = halving_factor - 0.025
		if halving_factor < 0.6 then
			halving_factor = 0.6
		end
	end
	nerv.info_stderr("iteration %d done!", i)
end