summaryrefslogtreecommitdiff
path: root/lua/tnn.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua/tnn.lua')
-rw-r--r--lua/tnn.lua136
1 files changed, 136 insertions, 0 deletions
diff --git a/lua/tnn.lua b/lua/tnn.lua
new file mode 100644
index 0000000..bf9f118
--- /dev/null
+++ b/lua/tnn.lua
@@ -0,0 +1,136 @@
+nerv.include('select_linear.lua')
+
+local reader = nerv.class('nerv.TNNReader')
+
+function reader:__init(global_conf, data)
+ self.gconf = global_conf
+ self.offset = 0
+ self.data = data
+end
+
+function reader:get_batch(feeds)
+ self.offset = self.offset + 1
+ if self.offset > #self.data then
+ return false
+ end
+ for i = 1, self.gconf.chunk_size do
+ feeds.inputs_m[i][1]:copy_from(self.data[self.offset].input[i])
+ feeds.inputs_m[i][2]:copy_from(self.data[self.offset].output[i]:decompress(self.gconf.vocab_size))
+ end
+ feeds.flags_now = self.data[self.offset].flags
+ feeds.flagsPack_now = self.data[self.offset].flagsPack
+ return true
+end
+
+function reader:has_data(t, i)
+ return t <= self.data[self.offset].seq_len[i]
+end
+
+function reader:get_err_input()
+ return self.data[self.offset].err_input
+end
+
+local nn = nerv.class('nerv.NN')
+
+function nn:__init(global_conf, train_data, val_data, layers, connections)
+ self.gconf = global_conf
+ self.tnn = self:get_tnn(layers, connections)
+ self.train_data = self:get_data(train_data)
+ self.val_data = self:get_data(val_data)
+end
+
+function nn:get_tnn(layers, connections)
+ self.gconf.dropout_rate = 0
+ local layer_repo = nerv.LayerRepo(layers, self.gconf.pr, self.gconf)
+ local tnn = nerv.TNN('TNN', self.gconf, {dim_in = {1, self.gconf.vocab_size},
+ dim_out = {1}, sub_layers = layer_repo, connections = connections,
+ clip = self.gconf.clip})
+ tnn:init(self.gconf.batch_size, self.gconf.chunk_size)
+ return tnn
+end
+
+function nn:get_data(data)
+ local ret = {}
+ for i = 1, #data do
+ ret[i] = {}
+ ret[i].input = data[i].input
+ ret[i].output = data[i].output
+ ret[i].flags = {}
+ ret[i].err_input = {}
+ for t = 1, self.gconf.chunk_size do
+ ret[i].flags[t] = {}
+ local err_input = self.gconf.mmat_type(self.gconf.batch_size, 1)
+ for j = 1, self.gconf.batch_size do
+ if t <= data[i].seq_len[j] then
+ ret[i].flags[t][j] = nerv.TNN.FC.SEQ_NORM
+ err_input[j - 1][0] = 1
+ else
+ ret[i].flags[t][j] = 0
+ err_input[j - 1][0] = 0
+ end
+ end
+ ret[i].err_input[t] = self.gconf.cumat_type.new_from_host(err_input)
+ end
+ for j = 1, self.gconf.batch_size do
+ if data[i].seq_start[j] then
+ ret[i].flags[1][j] = bit.bor(ret[i].flags[1][j], nerv.TNN.FC.SEQ_START)
+ end
+ if data[i].seq_end[j] then
+ local t = data[i].seq_len[j]
+ ret[i].flags[t][j] = bit.bor(ret[i].flags[t][j], nerv.TNN.FC.SEQ_END)
+ end
+ end
+ ret[i].flagsPack = {}
+ for t = 1, self.gconf.chunk_size do
+ ret[i].flagsPack[t] = 0
+ for j = 1, self.gconf.batch_size do
+ ret[i].flagsPack[t] = bit.bor(ret[i].flagsPack[t], ret[i].flags[t][j])
+ end
+ end
+ ret[i].seq_len = data[i].seq_len
+ end
+ return ret
+end
+
+function nn:process(data, do_train)
+ local total_err = 0
+ local total_frame = 0
+ local reader = nerv.TNNReader(self.gconf, data)
+ while true do
+ local r, _ = self.tnn:getfeed_from_reader(reader)
+ if not r then
+ break
+ end
+ if do_train then
+ self.gconf.dropout_rate = self.gconf.dropout
+ else
+ self.gconf.dropout_rate = 0
+ end
+ self.tnn:net_propagate()
+ for t = 1, self.gconf.chunk_size do
+ local tmp = self.tnn.outputs_m[t][1]:new_to_host()
+ for i = 1, self.gconf.batch_size do
+ if reader:has_data(t, i) then
+ total_err = total_err + math.log10(math.exp(tmp[i - 1][0]))
+ total_frame = total_frame + 1
+ end
+ end
+ end
+ if do_train then
+ local err_input = reader:get_err_input()
+ for i = 1, self.gconf.chunk_size do
+ self.tnn.err_inputs_m[i][1]:copy_from(err_input[i])
+ end
+ self.tnn:net_backpropagate(false)
+ self.tnn:net_backpropagate(true)
+ end
+ collectgarbage('collect')
+ end
+ return math.pow(10, - total_err / total_frame)
+end
+
+function nn:epoch()
+ local train_error = self:process(self.train_data, true)
+ local val_error = self:process(self.val_data, false)
+ return train_error, val_error
+end