aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/trainer.lua')
-rw-r--r--nerv/examples/trainer.lua36
1 files changed, 19 insertions, 17 deletions
diff --git a/nerv/examples/trainer.lua b/nerv/examples/trainer.lua
index 783ff1d..8e3efcb 100644
--- a/nerv/examples/trainer.lua
+++ b/nerv/examples/trainer.lua
@@ -1,9 +1,9 @@
require 'lfs'
require 'pl'
--- =======================================================
--- Deal with command line input & init training envrioment
--- =======================================================
+-- =========================================================
+-- Deal with command line input & init training envrioment
+-- =========================================================
local function check_and_add_defaults(spec, opts)
local function get_opt_val(k)
@@ -14,15 +14,14 @@ local function check_and_add_defaults(spec, opts)
if opt_v then
nerv.info("resuming from previous training state")
gconf = dofile(opt_v)
- else
- for k, v in pairs(spec) do
- local opt_v, specified = get_opt_val(k)
- if (not specified) and gconf[k] ~= nil then
- nerv.info("using setting in network config file: %s = %s", k, gconf[k])
- elseif opt_v ~= nil then
- nerv.info("using setting in options: %s = %s", k, opt_v)
- gconf[k] = opt_v
- end
+ end
+ for k, v in pairs(spec) do
+ local opt_v, specified = get_opt_val(k)
+ if (not specified) and gconf[k] ~= nil then
+ nerv.info("using setting in network config file: %s = %s", k, gconf[k])
+ elseif opt_v ~= nil then
+ nerv.info("using setting in options: %s = %s", k, opt_v)
+ gconf[k] = opt_v
end
end
end
@@ -65,6 +64,7 @@ end
local trainer_defaults = {
lrate = 0.8,
+ hfactor = 0.5,
batch_size = 256,
chunk_size = 1,
buffer_size = 81920,
@@ -125,7 +125,8 @@ end
local date_pattern = "%Y-%m-%d_%H:%M:%S"
local logfile_name = "log"
-local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
+local working_dir = opts["dir"].val or
+ string.format("nerv_%s", os.date(date_pattern))
gconf.working_dir = working_dir
gconf.date_pattern = date_pattern
@@ -139,9 +140,9 @@ dir.copyfile(script, working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))
--- =============
--- main function
--- =============
+-- ============
+-- Main loop
+-- ============
local trainer = gconf.trainer(gconf)
trainer:training_preprocess()
@@ -160,6 +161,7 @@ for i = gconf.cur_iter, gconf.max_iter do
local test_err = trainer:process('test', false)
nerv.info('[TE] testset error %d: %.3f', i, test_err)
end
- trainer:halving(train_err, cv_err)
+ trainer:save_params(train_err, cv_err)
end
+dump_gconf(path.join(working_dir, string.format("iter_%d.meta", gconf.max_iter + 1)))
trainer:training_afterprocess()