1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
|
require 'lfs'
require 'pl'
-- =========================================================
-- Handle command-line input and initialize the training environment
-- =========================================================
--- Merge the trainer defaults in `spec` into the global `gconf`.
-- If --resume-from is given, `gconf` is first replaced by the table returned
-- from that file. Then, for every key in `spec`:
--   * an explicitly specified command-line option wins;
--   * otherwise a value already present in `gconf` is kept;
--   * otherwise the option's value (its default) is stored.
-- @param spec table of default settings keyed by config name (underscores)
-- @param opts parsed options keyed by option name (dashes); each entry has
--   `val` and `specified` fields
local function check_and_add_defaults(spec, opts)
    -- config keys use '_', option names use '-'
    local function get_opt_val(k)
        local opt = opts[(string.gsub(k, '_', '-'))]
        return opt.val, opt.specified
    end
    local resume_file = get_opt_val("resume_from")
    if resume_file then
        nerv.info("resuming from previous training state")
        local state = dofile(resume_file)
        -- fail loudly here instead of crashing later on `gconf[k]`
        if type(state) ~= "table" then
            nerv.error("resume file %s did not return a table", resume_file)
        end
        gconf = state
    end
    for k in pairs(spec) do
        local opt_v, specified = get_opt_val(k)
        -- values may be booleans/tables; tostring keeps %s portable to Lua 5.1
        if (not specified) and gconf[k] ~= nil then
            nerv.info("using setting in network config file: %s = %s",
                k, tostring(gconf[k]))
        elseif opt_v ~= nil then
            nerv.info("using setting in options: %s = %s", k, tostring(opt_v))
            gconf[k] = opt_v
        end
    end
end
--- Turn a table of defaults into option specs for nerv.parse_args.
-- Each `name = value` pair yields {"name-with-dashes", nil, type(value),
-- default = value}: the second slot is left nil and the option's type is
-- taken from the default value.
-- @param spec table of default settings keyed by config name
-- @return array of option spec tables
local function make_options(spec)
    local result = {}
    for name, default_value in pairs(spec) do
        local long_name = (string.gsub(name, '_', '-'))
        result[#result + 1] = {
            long_name, nil, type(default_value), default = default_value,
        }
    end
    return result
end
--- Print the trainer's usage line followed by the option list.
-- @param options option specs, as passed to nerv.parse_args
local function print_help(options)
    local usage_line = "Usage: <trainer.lua> [options] <network_config.lua>\n"
    nerv.printf(usage_line)
    nerv.print_usage(options)
end
--- Log every gconf entry as an aligned "key = value" table.
-- Keys are printed in sorted order so the output is stable between runs.
-- Keys and values go through tostring, so `false` settings are shown as
-- "false" instead of being blanked, and non-string keys do not crash.
-- Padding is done by hand rather than with a "%-Ns" width: an empty gconf
-- would otherwise yield the invalid format "%-0ds", and string.format caps
-- widths at two digits.
local function print_gconf()
    local keys = {}
    local width = #"Key"
    for k in pairs(gconf) do
        keys[#keys + 1] = k
        width = math.max(width, #tostring(k))
    end
    table.sort(keys, function(a, b) return tostring(a) < tostring(b) end)
    local function row(key, value)
        local key_str = tostring(key)
        nerv.printf("%s = %s\n",
            key_str .. string.rep(" ", width - #key_str), tostring(value))
    end
    nerv.info("ready to train with the following gconf settings:")
    row("Key", "Value")
    for _, k in ipairs(keys) do
        row(k, gconf[k])
    end
end
--- Write the global gconf to `fname` as a Lua chunk ("return <table>").
-- The result can be loaded back with dofile, e.g. via --resume-from.
-- gconf is serialized before the file is opened, so a serialization error
-- neither truncates an existing file nor leaks a handle.
-- @param fname path of the file to (over)write
-- Raises an error naming the path when the file cannot be opened.
local function dump_gconf(fname)
    local body = table.tostring(gconf)
    local f = assert(io.open(fname, "w"))
    f:write("return ", body)
    f:close()
end
-- Built-in trainer settings, the lowest-precedence source for gconf.
-- Every key is also exposed as a command-line option, with '_' written as '-'
-- (e.g. batch_size -> --batch-size).
local trainer_defaults = {
    -- numeric training settings
    ["lrate"] = 0.8,
    ["hfactor"] = 0.5,
    ["batch_size"] = 256,
    ["chunk_size"] = 1,
    ["buffer_size"] = 81920,
    ["wcost"] = 1e-6,
    ["momentum"] = 0.9,
    -- iteration range of the main loop
    ["cur_iter"] = 1,
    ["max_iter"] = 20,
    ["randomize"] = true,
    -- type names, resolved with nerv.get_type after merging
    ["cumat_tname"] = "nerv.CuMatrixFloat",
    ["mmat_tname"] = "nerv.MMatrixFloat",
    ["trainer_tname"] = "nerv.Trainer",
}
-- One option per trainer default, plus trainer-only switches that have no
-- entry in trainer_defaults.
local cli_options = make_options(trainer_defaults)
table.extend(cli_options, {
    {"clip", nil, "number"},
    {"resume-from", nil, "string"},
    {"help", "h", "boolean", default = false, desc = "show this help information"},
    {"dir", nil, "string", desc = "specify the working directory"},
})
local opts
arg, opts = nerv.parse_args(arg, cli_options)
-- The first positional argument must be the network config script.
if #arg < 1 or opts["help"].val then
    print_help(cli_options)
    return
end
-- Peel the config script off the positional arguments and hand the rest to
-- it through the global `arg`.
local script = arg[1]
local forwarded = {}
for idx = 2, #arg do
    forwarded[#forwarded + 1] = arg[idx]
end
arg = forwarded
-- Running the config is expected to set up the global gconf
-- (NOTE(review): inferred from later use of gconf; confirm against configs).
dofile(script)
--[[
Precedence: command-line options override the network config, which in turn
overrides the trainer defaults.
Note: a config key such as aaa_bbbb_cc can be overridden by passing
--aaa-bbbb-cc on the command line.
]]--
-- Merge defaults, network config and command line into gconf.
check_and_add_defaults(trainer_defaults, opts)
-- Resolve the configured type names into the actual types.
gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
gconf.trainer = nerv.get_type(gconf.trainer_tname)
-- NOTE(review): econf is a global not defined here (presumably provided by
-- the nerv launcher) -- confirm it always exists.
gconf.use_cpu = econf.use_cpu or false
if gconf.initialized_param == nil then
    gconf.initialized_param = {}
end
if gconf.param_random == nil then
    -- uniform draw in [-0.1, 0.1)
    gconf.param_random = function() return math.random() / 5 - 0.1 end
end
local date_pattern = "%Y-%m-%d_%H:%M:%S"
local logfile_name = "log"
-- Without --dir, use a timestamped directory name.
local working_dir = opts["dir"].val or
    string.format("nerv_%s", os.date(date_pattern))
gconf.working_dir = working_dir
gconf.date_pattern = date_pattern
print_gconf()
-- lfs.mkdir returns nil plus a message on failure; report the real reason
-- (existing directory, missing parent, permissions, ...).
local mkdir_ok, mkdir_err = lfs.mkdir(working_dir)
if not mkdir_ok then
    nerv.error("[trainer] cannot create working directory %s: %s",
        working_dir, tostring(mkdir_err))
end
-- copy the network config
dir.copyfile(script, working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))
-- ============
-- Main loop
-- ============
local trainer = gconf.trainer(gconf)
trainer:training_preprocess()
gconf.best_cv = trainer:process('validate', false)
nerv.info("initial cross validation: %.3f", gconf.best_cv)

-- Path of the gconf snapshot written before iteration `n` (and once more
-- after the last one). dump_gconf writes a Lua chunk returning gconf, so
-- these files can be passed to --resume-from.
local function meta_file(n)
    return path.join(working_dir, string.format("iter_%d.meta", n))
end

for iter = gconf.cur_iter, gconf.max_iter do
    gconf.cur_iter = iter
    dump_gconf(meta_file(iter))
    nerv.info("[NN] begin iteration %d with lrate = %.6f", iter, gconf.lrate)
    local tr_err = trainer:process('train', true)
    nerv.info("[TR] training set %d: %.3f", iter, tr_err)
    local va_err = trainer:process('validate', false)
    nerv.info("[CV] cross validation %d: %.3f", iter, va_err)
    if gconf.test then
        local te_err = trainer:process('test', false)
        nerv.info('[TE] testset error %d: %.3f', iter, te_err)
    end
    trainer:save_params(tr_err, va_err)
end
dump_gconf(meta_file(gconf.max_iter + 1))
trainer:training_afterprocess()
|