blob: 3b56eb1a31921e43583d1ea0382c64c99c4acb3d (
plain) (
tree)
|
|
require 'fastnn'
--- Build a trainer closure that streams feature data into a fastnn examples repo.
-- Imports parameters from `ifname` into a fresh `nerv.ParamRepo`, then builds
-- the sublayer/layer repositories and the network from them.
-- Relies on globals defined by the config loaded via `dofile(arg[1])`:
-- `gconf`, `make_sublayer_repo`, `make_layer_repo`, `get_network`,
-- `get_input_order`, `make_readers`, `make_buffer`, `get_feat_id`.
-- @param ifname parameter source handed to `ParamRepo:import`
-- @treturn function iterative_trainer(prefix, scp_file, bp)
function build_trainer(ifname)
  local param_repo = nerv.ParamRepo()
  param_repo:import(ifname, nil, gconf)
  local sublayer_repo = make_sublayer_repo(param_repo)
  local layer_repo = make_layer_repo(sublayer_repo, param_repo)
  -- `nnet` and `input_order` are not used by the active code path; `nnet` is
  -- referenced by the disabled ModelSync code below. The calls are kept so
  -- any errors they raise still surface at build time.
  local nnet = get_network(layer_repo)
  local input_order = get_input_order()

  --- Feed every item produced for `scp_file` into a shared examples repo.
  -- @param prefix currently unused
  -- @param scp_file list passed to `make_readers`
  -- @param bp currently unused (presumably a "do back-propagation" flag — TODO confirm)
  -- Returns nothing. NOTE(review): the caller stores the result as `accu_best`,
  -- which is therefore always nil.
  local iterative_trainer = function(prefix, scp_file, bp)
    -- build buffer
    local buffer = make_buffer(make_readers(scp_file, layer_repo))
    --[[ Disabled ModelSync experiment:
    local control = fastnn.modelsync();
    local lua_control = fastnn.ModelSync(control:id())
    print(control:__tostring())
    print(lua_control:GetDim(nnet))
    lua_control:Initialize(nnet)
    lua_control:WeightToD(nnet)
    ]]
    local example_repo = fastnn.CExamplesRepo(128, false)
    -- Second handle created from the first repo's id; presumably the `true`
    -- flag attaches to the existing repo rather than creating one — TODO confirm.
    local share_repo = fastnn.CExamplesRepo(example_repo:id(), true)
    -- Local on purpose: this used to leak as an accidental global.
    local feat_id = get_feat_id()
    for data in buffer.get_data, buffer do
      local example = fastnn.Example.PrepareData(data, layer_repo.global_transf, feat_id)
      print(example)
      share_repo:accept(example)
    end
  end
  return iterative_trainer
end
-- Entry point. Usage: <lua launcher> <this script> <config.lua>
-- The config file must define `gconf` plus the builder helpers that
-- build_trainer calls.
assert(arg and arg[1], "usage: <this script> <config.lua>")
dofile(arg[1])
assert(gconf, "config file '" .. arg[1] .. "' did not define gconf")

-- Presumably a learning-rate halving schedule. NOTE(review): nothing in this
-- script reads these yet; they stay global in case helpers from the config
-- do — TODO confirm.
start_halving_inc = 0.5
halving_factor = 0.6
end_halving_inc = 0.1
min_iter = 1
max_iter = 20
min_halving = 5

-- Override whatever the config set for these.
gconf.batch_size = 256
gconf.buffer_size = 81920

local pf0 = gconf.initialized_param
local trainer = build_trainer(pf0)
-- Single pass over the cross-validation list with bp = false.
-- NOTE(review): the trainer currently returns nothing, so accu_best is nil.
local accu_best = trainer(nil, gconf.cv_scp, false)
|