diff options
Diffstat (limited to 'fastnn/test.lua')
-rw-r--r-- | fastnn/test.lua | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/fastnn/test.lua b/fastnn/test.lua new file mode 100644 index 0000000..3b56eb1 --- /dev/null +++ b/fastnn/test.lua @@ -0,0 +1,57 @@ +require 'fastnn' + +function build_trainer(ifname) + local param_repo = nerv.ParamRepo() + param_repo:import(ifname, nil, gconf) + local sublayer_repo = make_sublayer_repo(param_repo) + local layer_repo = make_layer_repo(sublayer_repo, param_repo) + local nnet = get_network(layer_repo) + local input_order = get_input_order() + + + local iterative_trainer = function (prefix, scp_file, bp) + + -- build buffer + local buffer = make_buffer(make_readers(scp_file, layer_repo)) + + --[[local control = fastnn.modelsync(); + local lua_control = fastnn.ModelSync(control:id()) + print(control:__tostring()) + print(lua_control:GetDim(nnet)) + lua_control:Initialize(nnet) + lua_control:WeightToD(nnet) + lua_control:WeightToD(nnet) + ]] + + local example_repo = fastnn.CExamplesRepo(128, false) + -- print(example_repo) + local share_repo = fastnn.CExamplesRepo(example_repo:id(), true) + feat_id = get_feat_id() + + local t = 1; + for data in buffer.get_data, buffer do + local example = fastnn.Example.PrepareData(data, layer_repo.global_transf, feat_id) + print(example) + share_repo:accept(example) + end + + end + return iterative_trainer +end + + +dofile(arg[1]) +start_halving_inc = 0.5 +halving_factor = 0.6 +end_halving_inc = 0.1 +min_iter = 1 +max_iter = 20 +min_halving = 5 +gconf.batch_size = 256 +gconf.buffer_size = 81920 + +local pf0 = gconf.initialized_param +local trainer = build_trainer(pf0) + +local accu_best = trainer(nil, gconf.cv_scp, false) + |