nerv.include('select_linear.lua')
-- Declares the 'nerv.TNNReader' class; its methods are attached to `reader` below.
local reader = nerv.class('nerv.TNNReader')
--- Create a reader over a list of prepared batches.
-- @param global_conf configuration table (chunk_size and vocab_size are read
--        later by get_batch)
-- @param data array of batch records; each record is indexed for input,
--        output, flags, flagsPack, seq_len and err_input by the other methods
function reader:__init(global_conf, data)
    self.gconf = global_conf
    self.data = data
    -- Cursor into self.data; 0 means "before the first batch".
    self.offset = 0
end
--- Advance to the next batch and load it into `feeds`.
-- The cursor is incremented on every call, including the call that runs past
-- the end of the data.
-- @param feeds table whose inputs_m[i][1] / inputs_m[i][2] matrices receive
--        the batch's input[i] and decompressed output[i]; flags_now and
--        flagsPack_now are overwritten with the batch's flag tables
-- @return true if a batch was loaded, false when the data is exhausted
function reader:get_batch(feeds)
    local pos = self.offset + 1
    self.offset = pos
    if pos > #self.data then
        return false
    end

    local batch = self.data[pos]
    for step = 1, self.gconf.chunk_size do
        local slot = feeds.inputs_m[step]
        slot[1]:copy_from(batch.input[step])
        -- The stored output is expanded with decompress(vocab_size) before copying.
        slot[2]:copy_from(batch.output[step]:decompress(self.gconf.vocab_size))
    end

    feeds.flags_now = batch.flags
    feeds.flagsPack_now = batch.flagsPack
    return true
end
--- Tell whether step `t` lies within the recorded length of entry `i`
-- in the current batch (the one selected by the last get_batch call).
-- @param t step index, compared against seq_len[i]
-- @param i index into the current batch's seq_len table
-- @return boolean, true when t <= seq_len[i]
function reader:has_data(t, i)
    local current = self.data[self.offset]
    local length = current.seq_len[i]
    return t <= length
end
--- Fetch the err_input table stored with the current batch
-- (the one selected by the last get_batch call).
-- @return the batch's err_input field
function reader:get_err_input()
    local current = self.data[self.offset]
    return current.err_input
end
-- Declares the 'nerv.NN' class; its methods are attached to `nn` below.
local nn = nerv.class('nerv.NN')
--- Build the network and pre-process both data sets.
-- @param global_conf configuration table, stored as self.gconf
-- @param train_data raw training batches, converted with get_data
-- @param val_data raw validation batches, converted with get_data
-- @param layers layer specification forwarded to get_tnn
-- @param connections connection specification forwarded to get_tnn
function nn:__init(global_conf, train_data, val_data, layers, connections)
    self.gconf = global_conf

    -- The network is constructed first, then the two data sets in order.
    local network = self:get_tnn(layers, connections)
    self.tnn = network

    local prepared_train = self:get_data(train_data)
    self.train_data = prepared_train

    local prepared_val = self:get_data(val_data)
    self.val_data = prepared_val
end
--- Construct and initialise the nerv.TNN instance used by this object.
-- Side effect: self.gconf.dropout_rate is reset to 0 before the layers are built.
-- @param layers layer specification passed to nerv.LayerRepo
-- @param connections connection specification passed to nerv.TNN
-- @return the initialised TNN
function nn:get_tnn(layers, connections)
    local gconf = self.gconf
    gconf.dropout_rate = 0

    local layer_repo = nerv.LayerRepo(layers, gconf.pr, gconf)

    -- Two inputs (widths 1 and vocab_size) and a single output of width 1.
    local tnn_conf = {
        dim_in = {1, gconf.vocab_size},
        dim_out = {1},
        sub_layers = layer_repo,
        connections = connections,
        clip = gconf.clip,
    }

    local network = nerv.TNN('TNN', gconf, tnn_conf)
    network:init(gconf.batch_size, gconf.chunk_size)
    return network
end
--- Convert raw batches into the records consumed by nerv.TNNReader.
-- For every batch the result keeps input, output and seq_len unchanged and adds:
--   flags[t][j]   SEQ_NORM when t <= seq_len[j], else 0; SEQ_START is OR-ed into
--                 step 1 when seq_start[j] is set, SEQ_END into step seq_len[j]
--                 when seq_end[j] is set
--   flagsPack[t]  bitwise OR of flags[t][j] over all j
--   err_input[t]  batch_size x 1 matrix holding 1 where t <= seq_len[j], else 0,
--                 converted with gconf.cumat_type.new_from_host
-- @param data array of raw batches with input, output, seq_len, seq_start, seq_end
-- @return array of prepared batch records, same length as `data`
function nn:get_data(data)
    local gconf = self.gconf
    local chunk_size = gconf.chunk_size
    local batch_size = gconf.batch_size
    local prepared = {}

    for idx = 1, #data do
        local src = data[idx]
        local flags = {}
        local err_masks = {}

        for t = 1, chunk_size do
            local row = {}
            local mask = gconf.mmat_type(batch_size, 1)
            for j = 1, batch_size do
                -- Matrix rows are addressed from 0 here, hence j - 1.
                if t > src.seq_len[j] then
                    row[j] = 0
                    mask[j - 1][0] = 0
                else
                    row[j] = nerv.TNN.FC.SEQ_NORM
                    mask[j - 1][0] = 1
                end
            end
            flags[t] = row
            err_masks[t] = gconf.cumat_type.new_from_host(mask)
        end

        -- Mark sequence boundaries on top of the per-step flags.
        for j = 1, batch_size do
            if src.seq_start[j] then
                local first = flags[1]
                first[j] = bit.bor(first[j], nerv.TNN.FC.SEQ_START)
            end
            if src.seq_end[j] then
                local last = flags[src.seq_len[j]]
                last[j] = bit.bor(last[j], nerv.TNN.FC.SEQ_END)
            end
        end

        -- Collapse each step's flags into a single OR-ed value.
        local packed = {}
        for t = 1, chunk_size do
            local acc = 0
            local row = flags[t]
            for j = 1, batch_size do
                acc = bit.bor(acc, row[j])
            end
            packed[t] = acc
        end

        prepared[idx] = {
            input = src.input,
            output = src.output,
            flags = flags,
            err_input = err_masks,
            flagsPack = packed,
            seq_len = src.seq_len,
        }
    end

    return prepared
end
--- Run one pass over `data`, optionally training, and return the score
-- 10 ^ (-sum / frames), where `sum` is the base-10 logarithm of exp(output)
-- accumulated over every (step, entry) pair for which the reader has data.
-- NOTE(review): this is a perplexity if the network output is a natural-log
-- probability -- confirm against the output layer.
-- If no frame is counted the result is NaN (0 / 0), as before.
-- @param data array of prepared batches (as produced by get_data)
-- @param do_train when truthy, dropout is enabled and back-propagation is run
-- @return number, the score described above
function nn:process(data, do_train)
    local gconf = self.gconf
    local tnn = self.tnn
    -- log10(exp(x)) == x / ln(10). Dividing directly avoids exp() underflowing
    -- to 0 (giving log10(0) = -inf) or overflowing for large |x|, which would
    -- otherwise poison the whole accumulated sum.
    local ln10 = math.log(10)
    local total_err = 0
    local total_frame = 0
    -- Named batch_reader so it does not shadow the file-level `reader` class.
    local batch_reader = nerv.TNNReader(gconf, data)

    while tnn:getfeed_from_reader(batch_reader) do
        -- Dropout is active only while training.
        if do_train then
            gconf.dropout_rate = gconf.dropout
        else
            gconf.dropout_rate = 0
        end

        tnn:net_propagate()

        for t = 1, gconf.chunk_size do
            local host_out = tnn.outputs_m[t][1]:new_to_host()
            for i = 1, gconf.batch_size do
                if batch_reader:has_data(t, i) then
                    -- Matrix rows are addressed from 0 here, hence i - 1.
                    total_err = total_err + host_out[i - 1][0] / ln10
                    total_frame = total_frame + 1
                end
            end
        end

        if do_train then
            local err_input = batch_reader:get_err_input()
            for t = 1, gconf.chunk_size do
                tnn.err_inputs_m[t][1]:copy_from(err_input[t])
            end
            -- Two back-propagation calls, first with false then with true
            -- (presumably gradient computation followed by the parameter
            -- update -- confirm against nerv.TNN).
            tnn:net_backpropagate(false)
            tnn:net_backpropagate(true)
        end

        collectgarbage('collect')
    end

    -- `^` is used instead of math.pow, which newer Lua versions deprecate/remove.
    return 10 ^ (-total_err / total_frame)
end
--- Run one epoch: a training pass followed by a validation pass.
-- @return the result of process on the training data, then on the validation data
function nn:epoch()
    local train_score = self:process(self.train_data, true)
    -- Validation runs after training, with training disabled.
    local val_score = self:process(self.val_data, false)
    return train_score, val_score
end