-- nerv/examples/trainer.lua

require 'lfs'
require 'pl'

-- =========================================================
--  Deal with command-line input & init training environment
-- =========================================================

local function check_and_add_defaults(spec, opts)
    local function get_opt_val(k)
        local k = string.gsub(k, '_', '-')
        return opts[k].val, opts[k].specified
    end
    local opt_v = get_opt_val("resume_from")
    if opt_v then
        nerv.info("resuming from previous training state")
        gconf = dofile(opt_v)
    end
    for k, v in pairs(spec) do
        local opt_v, specified = get_opt_val(k)
        if (not specified) and gconf[k] ~= nil then
            nerv.info("using setting in network config file: %s = %s", k, gconf[k])
        elseif opt_v ~= nil then
            nerv.info("using setting in options: %s = %s", k, opt_v)
            gconf[k] = opt_v
        end
    end
end

local function make_options(spec)
    local options = {}
    for k, v in pairs(spec) do
        table.insert(options,
                    {string.gsub(k, '_', '-'), nil, type(v), default = v})
    end
    return options
end
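
-- For illustration (the values are just the defaults below): make_options turns
-- a spec entry such as batch_size = 256 into an option descriptor
--     {"batch-size", nil, "number", default = 256}
-- i.e. underscores become dashes and the spec value becomes the option default.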

local function print_help(options)
    nerv.printf("Usage: <trainer.lua> [options] <network_config.lua>\n")
    nerv.print_usage(options)
end

local function print_gconf()
    local key_maxlen = 0
    for k, v in pairs(gconf) do
        key_maxlen = math.max(key_maxlen, #k)
    end
    local function pattern_gen()
        return string.format("%%-%ds = %%s\n", key_maxlen)
    end
    nerv.info("ready to train with the following gconf settings:")
    nerv.printf(pattern_gen(), "Key", "Value")
    for k, v in pairs(gconf) do
        nerv.printf(pattern_gen(), k, tostring(v))
    end
end
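
-- The output of print_gconf looks roughly like the sketch below (keys are
-- left-justified to the longest key name; the actual keys, values and order
-- depend on the final gconf):
--     Key        = Value
--     lrate      = 0.8
--     batch_size = 256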

local function dump_gconf(fname)
    local f = io.open(fname, "w")
    f:write("return ")
    f:write(table.tostring(gconf))
    f:close()
end
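
-- dump_gconf serializes the current gconf as a loadable Lua chunk, i.e. a file
-- containing something like (values illustrative):
--     return {lrate = 0.8, cur_iter = 3, batch_size = 256, ...}
-- A file of this form is what --resume-from expects: it is loaded back with
-- dofile() in check_and_add_defaults.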

local trainer_defaults = {
    lrate = 0.8,
    hfactor = 0.5,
    batch_size = 256,
    chunk_size = 1,
    buffer_size = 81920,
    wcost = 1e-6,
    momentum = 0.9,
    cur_iter = 1,
    max_iter = 20,
    randomize = true,
    cumat_tname = "nerv.CuMatrixFloat",
    mmat_tname = "nerv.MMatrixFloat",
    trainer_tname = "nerv.Trainer",
}

local options = make_options(trainer_defaults)
local extra_opt_spec = {
    {"clip", nil, "number"},
    {"resume-from", nil, "string"},
    {"help", "h", "boolean", default = false, desc = "show this help information"},
    {"dir", nil, "string", desc = "specify the working directory"},
}

table.extend(options, extra_opt_spec)

local opts
arg, opts = nerv.parse_args(arg, options)

if #arg < 1 or opts["help"].val then
    print_help(options)
    return
end

local script = arg[1]
local script_arg = {}
for i = 2, #arg do
    table.insert(script_arg, arg[i])
end
arg = script_arg
dofile(script)
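
-- The network config script loaded above is expected to define the global
-- table gconf. A minimal sketch of such a config (field names other than the
-- ones read by this trainer are hypothetical and consumed by the Trainer
-- subclass in use):
--     gconf = {
--         lrate = 0.1,              -- overrides the trainer default
--         batch_size = 128,
--         tr_scp = "train.scp",     -- hypothetical data list for training
--         cv_scp = "cv.scp",        -- hypothetical data list for validation
--     }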

--[[

Rule: a command-line option overrides the network config, which overrides the
trainer default.
Note: a config key such as aaa_bbbb_cc can be overridden on the command line
by passing --aaa-bbbb-cc.

]]--
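
-- Example of the precedence (hypothetical invocation; the launcher command may
-- differ in your setup): with the trainer default batch_size = 256 and a
-- network config that sets gconf.batch_size = 128,
--     trainer.lua --batch-size 64 mynet.lua   -- effective batch_size is 64
--     trainer.lua mynet.lua                   -- effective batch_size is 128
-- and with no batch_size in the config at all, the default 256 is used.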

check_and_add_defaults(trainer_defaults, opts)
gconf.mmat_type = nerv.get_type(gconf.mmat_tname)
gconf.cumat_type = nerv.get_type(gconf.cumat_tname)
gconf.trainer = nerv.get_type(gconf.trainer_tname)
gconf.use_cpu = gconf.use_cpu or false
if gconf.initialized_param == nil then
    gconf.initialized_param = {}
end
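-- default initializer for new parameters: uniform random values in [-0.1, 0.1)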
if gconf.param_random == nil then
    gconf.param_random = function() return math.random() / 5 - 0.1 end
end

local date_pattern = "%Y-%m-%d_%H:%M:%S"
local logfile_name = "log"
local working_dir = opts["dir"].val or
                    string.format("nerv_%s", os.date(date_pattern))
gconf.working_dir = working_dir
gconf.date_pattern = date_pattern

print_gconf()
local ok, mkdir_err = lfs.mkdir(working_dir)
if not ok then
    nerv.error("[trainer] cannot create working directory (does it already exist?): " ..
               tostring(mkdir_err))
end

-- copy the network config
dir.copyfile(script, working_dir)
-- set logfile path
nerv.set_logfile(path.join(working_dir, logfile_name))

-- ============
--  Main loop
-- ============

local trainer = gconf.trainer(gconf)
trainer:training_preprocess()
gconf.best_cv = trainer:process('validate', false)
nerv.info("initial cross validation: %.3f", gconf.best_cv)

for i = gconf.cur_iter, gconf.max_iter do
    gconf.cur_iter = i
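    -- snapshot gconf (including cur_iter) at the start of each iteration so the
    -- run can be picked up again later with --resume-from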
    dump_gconf(path.join(working_dir, string.format("iter_%d.meta", i)))
    nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
    local train_err = trainer:process('train', true)
    nerv.info("[TR] training set %d: %.3f", i, train_err)
    local cv_err = trainer:process('validate', false)
    nerv.info("[CV] cross validation %d: %.3f", i, cv_err)
    if gconf.test then
        local test_err = trainer:process('test', false)
        nerv.info('[TE] testset error %d: %.3f', i, test_err)
    end
    trainer:save_params(train_err, cv_err)
end
dump_gconf(path.join(working_dir, string.format("iter_%d.meta", gconf.max_iter + 1)))
trainer:training_afterprocess()