aboutsummaryrefslogtreecommitdiff
path: root/lua/tnn.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua/tnn.lua')
-rw-r--r--lua/tnn.lua136
1 files changed, 0 insertions, 136 deletions
diff --git a/lua/tnn.lua b/lua/tnn.lua
deleted file mode 100644
index bf9f118..0000000
--- a/lua/tnn.lua
+++ /dev/null
@@ -1,136 +0,0 @@
-nerv.include('select_linear.lua')
-
-local reader = nerv.class('nerv.TNNReader')
-
-function reader:__init(global_conf, data)
- self.gconf = global_conf
- self.offset = 0
- self.data = data
-end
-
-function reader:get_batch(feeds)
- self.offset = self.offset + 1
- if self.offset > #self.data then
- return false
- end
- for i = 1, self.gconf.chunk_size do
- feeds.inputs_m[i][1]:copy_from(self.data[self.offset].input[i])
- feeds.inputs_m[i][2]:copy_from(self.data[self.offset].output[i]:decompress(self.gconf.vocab_size))
- end
- feeds.flags_now = self.data[self.offset].flags
- feeds.flagsPack_now = self.data[self.offset].flagsPack
- return true
-end
-
-function reader:has_data(t, i)
- return t <= self.data[self.offset].seq_len[i]
-end
-
-function reader:get_err_input()
- return self.data[self.offset].err_input
-end
-
-local nn = nerv.class('nerv.NN')
-
-function nn:__init(global_conf, train_data, val_data, layers, connections)
- self.gconf = global_conf
- self.tnn = self:get_tnn(layers, connections)
- self.train_data = self:get_data(train_data)
- self.val_data = self:get_data(val_data)
-end
-
-function nn:get_tnn(layers, connections)
- self.gconf.dropout_rate = 0
- local layer_repo = nerv.LayerRepo(layers, self.gconf.pr, self.gconf)
- local tnn = nerv.TNN('TNN', self.gconf, {dim_in = {1, self.gconf.vocab_size},
- dim_out = {1}, sub_layers = layer_repo, connections = connections,
- clip = self.gconf.clip})
- tnn:init(self.gconf.batch_size, self.gconf.chunk_size)
- return tnn
-end
-
-function nn:get_data(data)
- local ret = {}
- for i = 1, #data do
- ret[i] = {}
- ret[i].input = data[i].input
- ret[i].output = data[i].output
- ret[i].flags = {}
- ret[i].err_input = {}
- for t = 1, self.gconf.chunk_size do
- ret[i].flags[t] = {}
- local err_input = self.gconf.mmat_type(self.gconf.batch_size, 1)
- for j = 1, self.gconf.batch_size do
- if t <= data[i].seq_len[j] then
- ret[i].flags[t][j] = nerv.TNN.FC.SEQ_NORM
- err_input[j - 1][0] = 1
- else
- ret[i].flags[t][j] = 0
- err_input[j - 1][0] = 0
- end
- end
- ret[i].err_input[t] = self.gconf.cumat_type.new_from_host(err_input)
- end
- for j = 1, self.gconf.batch_size do
- if data[i].seq_start[j] then
- ret[i].flags[1][j] = bit.bor(ret[i].flags[1][j], nerv.TNN.FC.SEQ_START)
- end
- if data[i].seq_end[j] then
- local t = data[i].seq_len[j]
- ret[i].flags[t][j] = bit.bor(ret[i].flags[t][j], nerv.TNN.FC.SEQ_END)
- end
- end
- ret[i].flagsPack = {}
- for t = 1, self.gconf.chunk_size do
- ret[i].flagsPack[t] = 0
- for j = 1, self.gconf.batch_size do
- ret[i].flagsPack[t] = bit.bor(ret[i].flagsPack[t], ret[i].flags[t][j])
- end
- end
- ret[i].seq_len = data[i].seq_len
- end
- return ret
-end
-
-function nn:process(data, do_train)
- local total_err = 0
- local total_frame = 0
- local reader = nerv.TNNReader(self.gconf, data)
- while true do
- local r, _ = self.tnn:getfeed_from_reader(reader)
- if not r then
- break
- end
- if do_train then
- self.gconf.dropout_rate = self.gconf.dropout
- else
- self.gconf.dropout_rate = 0
- end
- self.tnn:net_propagate()
- for t = 1, self.gconf.chunk_size do
- local tmp = self.tnn.outputs_m[t][1]:new_to_host()
- for i = 1, self.gconf.batch_size do
- if reader:has_data(t, i) then
- total_err = total_err + math.log10(math.exp(tmp[i - 1][0]))
- total_frame = total_frame + 1
- end
- end
- end
- if do_train then
- local err_input = reader:get_err_input()
- for i = 1, self.gconf.chunk_size do
- self.tnn.err_inputs_m[i][1]:copy_from(err_input[i])
- end
- self.tnn:net_backpropagate(false)
- self.tnn:net_backpropagate(true)
- end
- collectgarbage('collect')
- end
- return math.pow(10, - total_err / total_frame)
-end
-
-function nn:epoch()
- local train_error = self:process(self.train_data, true)
- local val_error = self:process(self.val_data, false)
- return train_error, val_error
-end