aboutsummaryrefslogtreecommitdiff
path: root/fastnn/test.lua
diff options
context:
space:
mode:
Diffstat (limited to 'fastnn/test.lua')
-rw-r--r--fastnn/test.lua57
1 files changed, 57 insertions, 0 deletions
diff --git a/fastnn/test.lua b/fastnn/test.lua
new file mode 100644
index 0000000..3b56eb1
--- /dev/null
+++ b/fastnn/test.lua
@@ -0,0 +1,57 @@
+require 'fastnn'
+
+--- Build a trainer closure over a network loaded from a parameter file.
+-- @param ifname path of the parameter file imported into a nerv.ParamRepo
+-- @return function (prefix, scp_file, bp) that streams data from scp_file,
+--         converts each batch to a fastnn Example and pushes it into a
+--         shared examples repository
+function build_trainer(ifname)
+   local param_repo = nerv.ParamRepo()
+   param_repo:import(ifname, nil, gconf)
+   local sublayer_repo = make_sublayer_repo(param_repo)
+   local layer_repo = make_layer_repo(sublayer_repo, param_repo)
+   local nnet = get_network(layer_repo)
+   local input_order = get_input_order()
+
+   local iterative_trainer = function (prefix, scp_file, bp)
+      -- build a data buffer over the readers for this scp file
+      local buffer = make_buffer(make_readers(scp_file, layer_repo))
+
+      --[[local control = fastnn.modelsync();
+      local lua_control = fastnn.ModelSync(control:id())
+      print(control:__tostring())
+      print(lua_control:GetDim(nnet))
+      lua_control:Initialize(nnet)
+      lua_control:WeightToD(nnet)
+      lua_control:WeightToD(nnet)
+      ]]
+
+      -- Master repo plus a second handle constructed from its id
+      -- (second ctor arg true presumably means "attach to existing
+      -- repo" -- TODO confirm against fastnn.CExamplesRepo).
+      local example_repo = fastnn.CExamplesRepo(128, false)
+      local share_repo = fastnn.CExamplesRepo(example_repo:id(), true)
+      -- fix: feat_id was an accidental global; keep it local here
+      local feat_id = get_feat_id()
+
+      for data in buffer.get_data, buffer do
+         local example = fastnn.Example.PrepareData(data, layer_repo.global_transf, feat_id)
+         print(example)
+         share_repo:accept(example)
+      end
+   end
+   return iterative_trainer
+end
+
+
+-- Load the site-specific configuration script given on the command line;
+-- presumably it defines gconf and the make_*/get_* helpers used by
+-- build_trainer above -- nothing in this file defines them.
+dofile(arg[1])
+-- Learning-rate halving schedule knobs. These are globals, apparently
+-- intended to be read by an outer training driver -- TODO confirm;
+-- nothing in this file uses them.
+start_halving_inc = 0.5
+halving_factor = 0.6
+end_halving_inc = 0.1
+min_iter = 1
+max_iter = 20
+min_halving = 5
+gconf.batch_size = 256
+gconf.buffer_size = 81920
+
+-- Build the trainer from the initial parameter file named in gconf.
+local pf0 = gconf.initialized_param
+local trainer = build_trainer(pf0)
+
+-- Run one pass over the cross-validation set (no prefix, no backprop).
+-- NOTE(review): iterative_trainer has no return statement, so accu_best
+-- is always nil here -- confirm whether a CV score was meant to be returned.
+local accu_best = trainer(nil, gconf.cv_scp, false)
+