-- Load select_linear.lua before the class is declared; the layer specs passed
-- to nn:__init presumably reference the layer it defines (NOTE(review): confirm).
nerv.include('select_linear.lua')
--- nerv.NN: owns a nerv.Network built from layer/connection specs together with
-- prepared training and validation mini-batches, and runs passes over them.
local nn = nerv.class('nerv.NN')
--- Construct the trainer: build the network, then convert both data sets.
-- @param global_conf shared configuration table, stored as self.gconf
-- @param train_data raw training batches, converted with nn:get_data
-- @param val_data raw validation batches, converted with nn:get_data
-- @param layers layer specification handed to nerv.LayerRepo
-- @param connections connection list handed to nerv.GraphLayer
function nn:__init(global_conf, train_data, val_data, layers, connections)
    -- get_network and get_data both read self.gconf, so it must be set first.
    self.gconf = global_conf
    local network = self:get_network(layers, connections)
    self.network = network
    local train_batches = self:get_data(train_data)
    self.train_data = train_batches
    local val_batches = self:get_data(val_data)
    self.val_data = val_batches
end
--- Build and initialize the nerv.Network used for training and evaluation.
-- The graph takes two inputs (widths 1 and gconf.vocab_size) and produces one
-- output of width 1. Side effect: resets gconf.dropout_rate to 0 before the
-- layers are built (nn:process sets it again for every mini-batch).
-- @param layers layer specification for nerv.LayerRepo
-- @param connections connection list for nerv.GraphLayer
-- @return the network, already initialized with gconf.batch_size/chunk_size
function nn:get_network(layers, connections)
    local gconf = self.gconf
    gconf.dropout_rate = 0
    local repo = nerv.LayerRepo(layers, gconf.pr, gconf)
    local graph_opts = {
        dim_in = {1, gconf.vocab_size},
        dim_out = {1},
        layer_repo = repo,
        connections = connections,
    }
    local graph = nerv.GraphLayer('graph', gconf, graph_opts)
    local net = nerv.Network('network', gconf, {network = graph, clip = gconf.clip})
    net:init(gconf.batch_size, gconf.chunk_size)
    return net
end
--- Convert raw batches into the tables consumed by nerv.Network.
-- Each raw batch must provide input[t], output[t] (t = 1..chunk_size) and
-- seq_len[j], seq_start[j] (j = 1..batch_size).
-- NOTE: the output / err_output matrices are allocated once per call and are
-- shared by every returned batch, so running one batch overwrites the buffers
-- seen by the others.
-- @param data list of raw batches
-- @return list of batches, each with fields:
--   input[t] = {raw input[t], raw output[t]}
--   output[t] = {shared batch_size x 1 buffer}
--   err_input[t] = {batch_size x 1 mask: 1 where t <= seq_len[j], else 0}
--   err_output[t] = {shared batch_size x 1 buffer, shared batch_size x vocab buffer}
--   info = {seq_length = raw seq_len, new_seq = indices j whose seq_start[j] is truthy}
function nn:get_data(data)
    local gconf = self.gconf
    local batch_size, chunk_size = gconf.batch_size, gconf.chunk_size

    -- Per-timestep buffers reused by all batches.
    local shared_err_out, shared_softmax_out, shared_out = {}, {}, {}
    for t = 1, chunk_size do
        shared_err_out[t] = gconf.cumat_type(batch_size, 1)
        shared_softmax_out[t] = gconf.cumat_type(batch_size, gconf.vocab_size)
        shared_out[t] = gconf.cumat_type(batch_size, 1)
    end

    local batches = {}
    for idx = 1, #data do
        local src = data[idx]
        local batch = {input = {}, output = {}, err_input = {}, err_output = {}}
        batches[idx] = batch

        for t = 1, chunk_size do
            -- Host-side mask, uploaded to the device below; 0-based indexing.
            local mask = gconf.mmat_type(batch_size, 1)
            for j = 1, batch_size do
                mask[j - 1][0] = (t <= src.seq_len[j]) and 1 or 0
            end
            batch.input[t] = {src.input[t], src.output[t]}
            batch.output[t] = {shared_out[t]}
            batch.err_input[t] = {gconf.cumat_type.new_from_host(mask)}
            batch.err_output[t] = {shared_err_out[t], shared_softmax_out[t]}
        end

        local starts = {}
        for j = 1, batch_size do
            if src.seq_start[j] then
                starts[#starts + 1] = j
            end
        end
        batch.info = {seq_length = src.seq_len, new_seq = starts}
    end
    return batches
end
--- Run one pass over prepared batches, optionally training on them.
-- @param data list of batches as produced by nn:get_data
-- @param do_train when truthy, use gconf.dropout as the dropout rate and
--   back-propagate/update after each forward pass; otherwise run forward only
--   with dropout_rate = 0
-- @return 10 ^ (-mean(x / ln 10)) over every output value x whose timestep t
--   is within that row's sequence length. This is a perplexity if the network
--   outputs natural-log probabilities (NOTE(review): confirm). The result is
--   NaN when no frame is counted, e.g. for empty data.
function nn:process(data, do_train)
    local gconf = self.gconf
    local network = self.network
    local chunk_size, batch_size = gconf.chunk_size, gconf.batch_size
    local LN10 = math.log(10)
    local total_err = 0
    local total_frame = 0
    for id = 1, #data do
        local batch = data[id]
        if do_train then
            gconf.dropout_rate = gconf.dropout
        else
            gconf.dropout_rate = 0
        end
        network:mini_batch_init(batch.info)
        local input = {}
        for t = 1, chunk_size do
            input[t] = {batch.input[t][1], batch.input[t][2]:decompress(gconf.vocab_size)}
        end
        network:propagate(input, batch.output)
        local seq_length = batch.info.seq_length
        for t = 1, chunk_size do
            local host = batch.output[t][1]:new_to_host()
            for i = 1, batch_size do
                if t <= seq_length[i] then
                    -- x / ln(10) equals log10(exp(x)) but avoids exp(), which
                    -- underflows to 0 (so log10 gives -inf) for very negative x.
                    total_err = total_err + host[i - 1][0] / LN10
                    total_frame = total_frame + 1
                end
            end
        end
        if do_train then
            network:back_propagate(batch.err_input, batch.err_output, input, batch.output)
            network:update(batch.err_input, input, batch.output)
        end
        collectgarbage('collect')
    end
    -- '^' instead of math.pow: math.pow is deprecated/absent in Lua 5.3+.
    return 10 ^ (-total_err / total_frame)
end
--- Run one training pass followed by one evaluation pass.
-- @return the nn:process score for the training data (with updates), then
--   the score for the validation data (forward only)
function nn:epoch()
    local train_score = self:process(self.train_data, true)
    local val_score = self:process(self.val_data, false)
    return train_score, val_score
end