diff options
author | Qi Liu <[email protected]> | 2016-03-11 20:11:00 +0800 |
---|---|---|
committer | Qi Liu <[email protected]> | 2016-03-11 20:11:00 +0800 |
commit | e2a9af061db485d4388902d738c9d8be3f94ab34 (patch) | |
tree | 468d6c6afa0801f6a6bf794b3674f8814b8827f7 /lua/tnn.lua | |
parent | 2f46a5e2b37a054f482f76f4ac3d26b144cf988f (diff) |
add recipe and fix bugs
Diffstat (limited to 'lua/tnn.lua')
-rw-r--r-- | lua/tnn.lua | 136 |
1 file changed, 0 insertions, 136 deletions
-- lua/tnn.lua — reconstructed from the deletion diff above.
-- Training driver for a time-delayed neural network (TNN) language model
-- built on the nerv framework: a batch reader plus an NN wrapper that
-- prepares flag/error-input tensors and runs train/validation epochs,
-- reporting word-level perplexity.

nerv.include('select_linear.lua')

local reader = nerv.class('nerv.TNNReader')

--- Initialize the reader over pre-batched data.
-- @param global_conf shared configuration table (chunk_size, vocab_size, ...)
-- @param data array of batches as produced by nn:get_data
function reader:__init(global_conf, data)
    self.gconf = global_conf
    self.offset = 0
    self.data = data
end

--- Advance to the next batch and copy it into the TNN's feed matrices.
-- @param feeds feed structure owned by the TNN (inputs_m, flags_now, ...)
-- @return false when the data is exhausted, true otherwise
function reader:get_batch(feeds)
    self.offset = self.offset + 1
    if self.offset > #self.data then
        return false
    end
    for i = 1, self.gconf.chunk_size do
        -- input[i] holds word indices; output[i] is a compressed target
        -- matrix expanded to a one-hot of vocab_size columns.
        feeds.inputs_m[i][1]:copy_from(self.data[self.offset].input[i])
        feeds.inputs_m[i][2]:copy_from(self.data[self.offset].output[i]:decompress(self.gconf.vocab_size))
    end
    feeds.flags_now = self.data[self.offset].flags
    feeds.flagsPack_now = self.data[self.offset].flagsPack
    return true
end

--- Whether batch row i still carries real data at timestep t
-- (i.e. t does not run past that sequence's length).
function reader:has_data(t, i)
    return t <= self.data[self.offset].seq_len[i]
end

--- Error-input matrices (one per timestep) for the current batch.
function reader:get_err_input()
    return self.data[self.offset].err_input
end

local nn = nerv.class('nerv.NN')

--- Build the TNN and preprocess both data sets.
-- @param global_conf shared configuration table
-- @param train_data raw training batches
-- @param val_data raw validation batches
-- @param layers layer specification passed to nerv.LayerRepo
-- @param connections TNN connection specification
function nn:__init(global_conf, train_data, val_data, layers, connections)
    self.gconf = global_conf
    self.tnn = self:get_tnn(layers, connections)
    self.train_data = self:get_data(train_data)
    self.val_data = self:get_data(val_data)
end

--- Construct and initialize the nerv.TNN from layer/connection specs.
-- Dropout is disabled during construction; nn:process toggles it per pass.
function nn:get_tnn(layers, connections)
    self.gconf.dropout_rate = 0
    local layer_repo = nerv.LayerRepo(layers, self.gconf.pr, self.gconf)
    local tnn = nerv.TNN('TNN', self.gconf, {dim_in = {1, self.gconf.vocab_size},
        dim_out = {1}, sub_layers = layer_repo, connections = connections,
        clip = self.gconf.clip})
    tnn:init(self.gconf.batch_size, self.gconf.chunk_size)
    return tnn
end

--- Precompute per-batch TNN metadata: per-timestep sequence flags,
-- packed flag words, and device-side error-input masks.
-- @param data raw batches with input/output/seq_len/seq_start/seq_end
-- @return array of enriched batches consumed by nerv.TNNReader
function nn:get_data(data)
    local ret = {}
    for i = 1, #data do
        ret[i] = {}
        ret[i].input = data[i].input
        ret[i].output = data[i].output
        ret[i].flags = {}
        ret[i].err_input = {}
        for t = 1, self.gconf.chunk_size do
            ret[i].flags[t] = {}
            -- Host-side column vector: 1 where row j is live at timestep t,
            -- 0 where it is padding (masks the gradient for padded frames).
            local err_input = self.gconf.mmat_type(self.gconf.batch_size, 1)
            for j = 1, self.gconf.batch_size do
                if t <= data[i].seq_len[j] then
                    ret[i].flags[t][j] = nerv.TNN.FC.SEQ_NORM
                    err_input[j - 1][0] = 1  -- nerv matrices are 0-indexed
                else
                    ret[i].flags[t][j] = 0
                    err_input[j - 1][0] = 0
                end
            end
            ret[i].err_input[t] = self.gconf.cumat_type.new_from_host(err_input)
        end
        -- Mark sequence boundaries: START on the first timestep, END on the
        -- last real timestep of each row.
        for j = 1, self.gconf.batch_size do
            if data[i].seq_start[j] then
                ret[i].flags[1][j] = bit.bor(ret[i].flags[1][j], nerv.TNN.FC.SEQ_START)
            end
            if data[i].seq_end[j] then
                local t = data[i].seq_len[j]
                ret[i].flags[t][j] = bit.bor(ret[i].flags[t][j], nerv.TNN.FC.SEQ_END)
            end
        end
        -- flagsPack[t] is the OR of all rows' flags at t, so the TNN can
        -- test "any row starts/ends here" with one word.
        ret[i].flagsPack = {}
        for t = 1, self.gconf.chunk_size do
            ret[i].flagsPack[t] = 0
            for j = 1, self.gconf.batch_size do
                ret[i].flagsPack[t] = bit.bor(ret[i].flagsPack[t], ret[i].flags[t][j])
            end
        end
        ret[i].seq_len = data[i].seq_len
    end
    return ret
end

--- Run one full pass over `data`; backpropagate when do_train is true.
-- @param data preprocessed batches from nn:get_data
-- @param do_train boolean; enables dropout and weight updates
-- @return base-10 perplexity of the pass (NaN if data held zero frames)
function nn:process(data, do_train)
    local total_err = 0     -- accumulated log10-likelihood
    local total_frame = 0   -- number of real (non-padding) frames scored
    local reader = nerv.TNNReader(self.gconf, data)
    while true do
        local r, _ = self.tnn:getfeed_from_reader(reader)
        if not r then
            break
        end
        -- Dropout only during training; gconf.dropout_rate is read by the
        -- dropout layers inside the TNN.
        if do_train then
            self.gconf.dropout_rate = self.gconf.dropout
        else
            self.gconf.dropout_rate = 0
        end
        self.tnn:net_propagate()
        for t = 1, self.gconf.chunk_size do
            local tmp = self.tnn.outputs_m[t][1]:new_to_host()
            for i = 1, self.gconf.batch_size do
                if reader:has_data(t, i) then
                    -- tmp[i-1][0] appears to be a natural-log likelihood
                    -- (TODO confirm against nerv.TNN's output layer).
                    -- BUG FIX: the original math.log10(math.exp(x)) underflows
                    -- to log10(0) = -inf for x < ~-745, poisoning the total;
                    -- x / ln(10) is the same base change without under/overflow.
                    total_err = total_err + tmp[i - 1][0] / math.log(10)
                    total_frame = total_frame + 1
                end
            end
        end
        if do_train then
            local err_input = reader:get_err_input()
            for i = 1, self.gconf.chunk_size do
                self.tnn.err_inputs_m[i][1]:copy_from(err_input[i])
            end
            self.tnn:net_backpropagate(false)  -- propagate errors
            self.tnn:net_backpropagate(true)   -- update weights
        end
        -- Per-batch full GC: presumably to release device matrices promptly
        -- and bound GPU memory use — verify before removing.
        collectgarbage('collect')
    end
    -- Perplexity = 10^(-avg log10 likelihood); 0/0 -> NaN on empty data.
    return math.pow(10, - total_err / total_frame)
end

--- One training epoch followed by one validation pass.
-- @return train_perplexity, val_perplexity
function nn:epoch()
    local train_error = self:process(self.train_data, true)
    local val_error = self:process(self.val_data, false)
    return train_error, val_error
end