aboutsummaryrefslogblamecommitdiff
path: root/fastnn/test.lua
blob: 3b56eb1a31921e43583d1ea0382c64c99c4acb3d (plain) (tree)
























































                                                                                                   
require 'fastnn'

--- Build a trainer closure for the network described by the loaded config.
-- Imports parameters from `ifname` into a `nerv.ParamRepo`. It then builds the
-- sub-layer repo, the layer repo and the network through the config-supplied
-- helpers (`make_sublayer_repo`, `make_layer_repo`, `get_network`,
-- `get_input_order`).
-- @param ifname parameter file spec handed to `nerv.ParamRepo:import`
--   (the script passes `gconf.initialized_param`)
-- @treturn function `iterative_trainer(prefix, scp_file, bp)`
function build_trainer(ifname)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, nil, gconf)
    local sublayer_repo = make_sublayer_repo(param_repo)
    local layer_repo = make_layer_repo(sublayer_repo, param_repo)
    -- NOTE(review): `nnet` and `input_order` are only referenced by the
    -- disabled ModelSync experiment below. The calls are kept in case the
    -- helpers have side effects other code relies on.
    local nnet = get_network(layer_repo)
    local input_order = get_input_order()

    --- Stream every example of `scp_file` into a shared fastnn example repo.
    -- Each example is printed, then handed to `share_repo:accept`.
    -- @param prefix currently unused
    -- @param scp_file scp list passed to `make_readers`
    -- @param bp currently unused
    -- @return nothing (callers get nil)
    local iterative_trainer = function(prefix, scp_file, bp)
        local buffer = make_buffer(make_readers(scp_file, layer_repo))

        --[[ Disabled ModelSync experiment, kept for reference:
        local control = fastnn.modelsync();
        local lua_control = fastnn.ModelSync(control:id())
        print(control:__tostring())
        print(lua_control:GetDim(nnet))
        lua_control:Initialize(nnet)
        lua_control:WeightToD(nnet)
        lua_control:WeightToD(nnet)
        ]]

        -- `share_repo` is constructed from `example_repo:id()`. It presumably
        -- attaches to the same underlying repo (TODO confirm fastnn
        -- semantics). `example_repo` stays in scope for the whole loop, so
        -- it remains referenced while the shared handle is used.
        local example_repo = fastnn.CExamplesRepo(128, false)
        local share_repo = fastnn.CExamplesRepo(example_repo:id(), true)
        -- Local on purpose: this previously leaked into the global table.
        local feat_id = get_feat_id()

        -- Generic-for over the buffer: calls buffer.get_data(buffer, prev)
        -- until it returns nil.
        for data in buffer.get_data, buffer do
            local example = fastnn.Example.PrepareData(data, layer_repo.global_transf, feat_id)
            print(example)
            share_repo:accept(example)
        end
    end

    return iterative_trainer
end


-- Load the experiment configuration named on the command line. It must
-- define `gconf` and the helpers build_trainer uses.
local conf_file = arg and arg[1]
if conf_file == nil then
    -- Guard: dofile(nil) would silently execute stdin instead of failing.
    error("usage: test.lua <config.lua>")
end
dofile(conf_file)
if type(gconf) ~= "table" then
    error(string.format("config '%s' did not define a gconf table", conf_file))
end

-- Schedule knobs, presumably for learning-rate halving (names only; not
-- read in this file). They are assigned after dofile, so they override
-- whatever the config set.
start_halving_inc = 0.5
halving_factor = 0.6
end_halving_inc = 0.1
min_iter = 1
max_iter = 20
min_halving = 5
gconf.batch_size = 256
gconf.buffer_size = 81920

local pf0 = gconf.initialized_param
local trainer = build_trainer(pf0)

-- One pass over the cross-validation list. The trainer currently returns
-- no value, so there is no accuracy to capture yet.
trainer(nil, gconf.cv_scp, false)