aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/trainer.lua
blob: a17b36cddf31029ea9550ff58254c4f284a2ae31 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
-- Base trainer class. It drives parameter loading, the per-mini-batch
-- propagate / back-propagate / update loop (`process`) and the
-- accept-or-rollback decision after each iteration (`save_params`).
-- Model-specific pieces (`make_layer_repo`, `get_network`, `get_readers`,
-- `get_input_order`, `get_error`) raise "not implemented" here and must be
-- supplied by a subclass.
-- NOTE(review): `nerv.class` is assumed to register the class under the name
-- 'nerv.Trainer' — confirm against nerv's class helper.
local trainer = nerv.class('nerv.Trainer')

--- Build the trainer: load the initial parameters, create the network and
-- preallocate the per-time-step buffers consumed by `process`.
--
-- Parameters are imported into a host-side `nerv.ParamRepo` from
-- `gconf.initialized_param`, then copied to the training location. That is
-- host memory with `gconf.mmat_type` matrices when `gconf.use_cpu` is set,
-- and device memory with `gconf.cumat_type` matrices otherwise. The subclass
-- hooks `make_layer_repo`, `get_network` and `get_input_order` supply the
-- model-specific pieces.
-- @param gconf global configuration table. Fields read here: `use_cpu`,
--   `mmat_type`, `cumat_type`, `initialized_param`, `nn_act_default`,
--   `batch_size`, `chunk_size` and `mask`.
function trainer:__init(gconf)
    local LOC = nerv.ParamRepo.LOC_TYPES
    self.gconf = gconf
    -- location parameters are copied back to before export (see save_params)
    self.src_loc_type = LOC.ON_HOST
    local mat_type
    if gconf.use_cpu then
        mat_type = gconf.mmat_type
        self.train_loc_type = LOC.ON_HOST
    else
        mat_type = gconf.cumat_type
        self.train_loc_type = LOC.ON_DEVICE
    end

    -- import the parameters from chunk files, then move them to where
    -- training happens
    local host_param_repo = nerv.ParamRepo()
    host_param_repo:import(gconf.initialized_param, gconf)
    local param_repo = host_param_repo:copy(self.train_loc_type, gconf)

    -- create layers and establish initial bindings
    self.layer_repo = self:make_layer_repo(param_repo)
    -- compile the network to be trained
    local graph = self:get_network(self.layer_repo)
    self.input_order = self:get_input_order()
    local network = nerv.Network('network', gconf,
                                 {network = graph,
                                  nn_act_default = gconf.nn_act_default})
    self.network = network
    network:init(gconf.batch_size, gconf.chunk_size)

    local batch_size, chunk_size = gconf.batch_size, gconf.chunk_size
    local dim_in, dim_out = network.dim_in, network.dim_out

    -- Error outputs with respect to the network inputs: one matrix per input,
    -- and the same matrix object is reused for every time step.
    -- NOTE(review): presumably these gradients are never consumed, which
    -- would make sharing safe — confirm.
    local err_output = {}
    for i = 1, #dim_in do
        local dummy = mat_type(batch_size, dim_in[i])
        local steps = {}
        for t = 1, chunk_size do
            steps[t] = dummy
        end
        err_output[i] = steps
    end
    self.err_output = err_output

    -- network outputs and the error fed back into each output, per time step
    local output, err_input = {}, {}
    for i = 1, #dim_out do
        local out_steps = {}
        for t = 1, chunk_size do
            out_steps[t] = mat_type(batch_size, dim_out[i])
        end
        output[i] = out_steps

        local err_steps = {}
        if dim_out[i] == 1 then
            -- single-column output: the matrices in gconf.mask are used
            -- directly (not copied) as the error input
            for t = 1, chunk_size do
                err_steps[t] = gconf.mask[t]
            end
        else
            nerv.warning("the output has multiple heads, the default " ..
                         "`err_input` will be zero")
            for t = 1, chunk_size do
                local zeros = mat_type(batch_size, dim_out[i])
                zeros:fill(0)
                err_steps[t] = zeros
            end
        end
        err_input[i] = err_steps
    end
    self.output = output
    self.err_input = err_input
end

--- Build the buffer that yields mini-batches to `process`.
-- A frame-level `nerv.FrmBuffer` is used when `gconf.chunk_size` is 1;
-- otherwise a sequence-level `nerv.SeqBuffer` is used.
-- @param readers the readers returned by `get_readers`
-- @return the buffer object (iterated with `buffer.get_data`)
function trainer:make_buffer(readers)
    local gconf = self.gconf
    if gconf.chunk_size == 1 then
        return nerv.FrmBuffer(gconf, {
            buffer_size = gconf.buffer_size,
            batch_size = gconf.batch_size,
            chunk_size = gconf.chunk_size,
            randomize = gconf.randomize,
            readers = readers,
            -- Follow the device choice made in __init: host (mmat) when
            -- gconf.use_cpu is set, device (cumat) otherwise. This stays
            -- true when use_cpu is unset.
            -- NOTE(review): assumes FrmBuffer's `use_gpu` selects the
            -- device-side output matrix type — confirm in frm_buffer.lua.
            use_gpu = not gconf.use_cpu,
        })
    end
    return nerv.SeqBuffer(gconf, {
        buffer_size = gconf.buffer_size,
        batch_size = gconf.batch_size,
        chunk_size = gconf.chunk_size,
        randomize = gconf.randomize,
        readers = readers,
        nn_act_default = gconf.nn_act_default,
    })
end

--- Run one pass over `dataset`, training when `do_train` is set.
--
-- Readers come from `get_readers(dataset)` and are wrapped by `make_buffer`.
-- For every mini-batch the steps run in this order:
--   1. `mini_batch_preprocess`
--   2. forward propagation
--   3. `mini_batch_inprocess`
--   4. back-propagation and parameter update (only when `do_train`)
--   5. `mini_batch_afterprocess`
-- A full garbage collection is forced after each mini-batch.
-- @param dataset dataset handle (its form is defined by the subclass)
-- @param do_train true to back-propagate and update the parameters
-- @return the result of `get_error()` after the pass
function trainer:process(dataset, do_train)
    self:epoch_preprocess(dataset, do_train)
    local buffer = self:make_buffer(self:get_readers(dataset))
    local batch_no = 0
    local net, order = self.network, self.input_order
    local out_bufs = self.output
    local err_in_bufs = self.err_input
    local err_out_bufs = self.err_output
    net:epoch_init()

    for batch in buffer.get_data, buffer do
        batch_no = batch_no + 1

        -- map network input slot k to the buffer field named by order[k]
        local inputs = {}
        for k = 1, #net.dim_in do
            inputs[k] = batch.data[order[k]]
        end

        local info = {
            input = inputs,
            output = out_bufs,
            err_input = err_in_bufs,
            err_output = err_out_bufs,
            do_train = do_train,
            seq_length = batch.seq_length,
            new_seq = batch.new_seq,
        }

        self:mini_batch_preprocess(batch_no, info)
        net:mini_batch_init(info)
        net:propagate()
        self:mini_batch_inprocess(batch_no, info)
        if do_train then
            net:back_propagate()
            net:update()
        end
        self:mini_batch_afterprocess(batch_no, info)

        -- full collection after every mini-batch (presumably so that
        -- temporary matrices are released promptly)
        collectgarbage('collect')
    end

    self:epoch_afterprocess(dataset, do_train)
    return self:get_error()
end

--- Decide whether the parameters of the current iteration are kept.
-- Reads the trainer's own configuration (`self.gconf`), the same table that
-- `save_params` updates, rather than an undeclared global `gconf`.
-- @param cv_err cross-validation error of the current iteration
-- @return true when `cv_err` is strictly lower than `self.gconf.best_cv`,
--   or when no best value has been recorded yet (`best_cv` is nil)
function trainer:if_accept(cv_err)
    local best_cv = self.gconf.best_cv
    if best_cv == nil then
        -- nothing to compare against yet; comparing with nil would raise
        return true
    end
    return cv_err < best_cv
end

--- Scale the learning rate by `gconf.hfactor`.
-- `save_params` calls this after rejecting an iteration. It updates
-- `self.gconf`, the table the network was built with, rather than an
-- undeclared global `gconf`.
function trainer:do_halving()
    local gconf = self.gconf
    gconf.lrate = gconf.lrate * gconf.hfactor
end

--- Export the current parameters, then accept them or roll back.
--
-- The parameters are always copied to the host and written to
-- `<working_dir>/<date>_iter_<cur_iter>_lr<lrate>_tr<train_err>_cv<cv_err>.nerv`.
-- If `self:if_accept(cv_err)` holds, `gconf.best_cv` and
-- `gconf.initialized_param` are updated to point at the new file.
-- Otherwise:
--   * the file is renamed with a `.rejected` suffix;
--   * the previous parameters are re-imported and rebound to the layers;
--   * `do_halving` is called.
-- @param train_err training error of this iteration (used in the file name)
-- @param cv_err cross-validation error of this iteration
function trainer:save_params(train_err, cv_err)
    local gconf = self.gconf
    local host_loc, train_loc = self.src_loc_type, self.train_loc_type
    local layers = self.layer_repo

    local base = string.format('%s_iter_%d_lr%f_tr%.3f_cv%.3f.nerv',
                               os.date(gconf.date_pattern),
                               gconf.cur_iter,
                               gconf.lrate,
                               train_err,
                               cv_err)
    local fname = path.join(gconf.working_dir, base)
    self.network:get_params():copy(host_loc, gconf):export(fname)

    if self:if_accept(cv_err) then
        nerv.info("accepting the trained params")
        gconf.best_cv = cv_err
        gconf.initialized_param = {fname}
        return
    end

    nerv.info("rejecting the trained params, rollback to the previous one")
    file.move(fname, fname .. '.rejected')
    local prev_host = nerv.ParamRepo()
    prev_host:import(gconf.initialized_param, gconf)
    -- rebind the layers to the restored parameters on the training location
    layers:rebind(prev_host:copy(train_loc, gconf))
    self:do_halving()
end

-- Overridable hooks. Every default implementation is a no-op.

--- Not called from this class; presumably invoked by the driver script
-- before the first iteration.
function trainer:training_preprocess() end

--- Not called from this class; presumably invoked by the driver script
-- after the last iteration.
function trainer:training_afterprocess() end

--- Called at the very start of `process`, before the buffer is built.
function trainer:epoch_preprocess(dataset, do_train) end

--- Called at the end of `process`, before `get_error`.
function trainer:epoch_afterprocess(dataset, do_train) end

--- Called for each mini-batch before `network:mini_batch_init(info)`.
-- `cnt` is the 1-based mini-batch index within the pass.
function trainer:mini_batch_preprocess(cnt, info) end

--- Called for each mini-batch right after forward propagation.
function trainer:mini_batch_inprocess(cnt, info) end

--- Called for each mini-batch after the (optional) update step.
function trainer:mini_batch_afterprocess(cnt, info) end

-- Abstract methods: each raises via `nerv.error_method_not_implemented`
-- until a subclass overrides it.

--- Build the layer repo bound to `param_repo` (the parameters already copied
-- to the training location). The result becomes `self.layer_repo`.
function trainer:make_layer_repo(param_repo) nerv.error_method_not_implemented() end

--- Return the graph description passed as `network` to `nerv.Network`.
function trainer:get_network(layer_repo) nerv.error_method_not_implemented() end

--- Return the readers for `dataset`; they are handed to `make_buffer`.
function trainer:get_readers(dataset) nerv.error_method_not_implemented() end

--- Return a list whose k-th entry names the buffer field fed to network
-- input k.
function trainer:get_input_order() nerv.error_method_not_implemented() end

--- Return the error measure reported at the end of `process`.
function trainer:get_error() nerv.error_method_not_implemented() end