diff options
Diffstat (limited to 'lua/network.lua')
-rw-r--r-- | lua/network.lua | 106 |
1 files changed, 0 insertions, 106 deletions
diff --git a/lua/network.lua b/lua/network.lua deleted file mode 100644 index d106ba1..0000000 --- a/lua/network.lua +++ /dev/null @@ -1,106 +0,0 @@ -nerv.include('select_linear.lua') - -local nn = nerv.class('nerv.NN') - -function nn:__init(global_conf, train_data, val_data, layers, connections) - self.gconf = global_conf - self.network = self:get_network(layers, connections) - self.train_data = self:get_data(train_data) - self.val_data = self:get_data(val_data) -end - -function nn:get_network(layers, connections) - local layer_repo = nerv.LayerRepo(layers, self.gconf.pr, self.gconf) - local graph = nerv.GraphLayer('graph', self.gconf, - {dim_in = {1, self.gconf.vocab_size}, dim_out = {1}, - layer_repo = layer_repo, connections = connections}) - local network = nerv.Network('network', self.gconf, - {network = graph, clip = self.gconf.clip}) - network:init(self.gconf.batch_size, self.gconf.chunk_size) - return network -end - -function nn:get_data(data) - local err_output = {} - local softmax_output = {} - local output = {} - for i = 1, self.gconf.chunk_size do - err_output[i] = self.gconf.cumat_type(self.gconf.batch_size, 1) - softmax_output[i] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.vocab_size) - output[i] = self.gconf.cumat_type(self.gconf.batch_size, 1) - end - local ret = {} - for i = 1, #data do - ret[i] = {} - ret[i].input = {} - ret[i].output = {} - ret[i].err_input = {} - ret[i].err_output = {} - for t = 1, self.gconf.chunk_size do - ret[i].input[t] = {} - ret[i].output[t] = {} - ret[i].err_input[t] = {} - ret[i].err_output[t] = {} - ret[i].input[t][1] = data[i].input[t] - ret[i].input[t][2] = data[i].output[t] - ret[i].output[t][1] = output[t] - local err_input = self.gconf.mmat_type(self.gconf.batch_size, 1) - for j = 1, self.gconf.batch_size do - if t <= data[i].seq_len[j] then - err_input[j - 1][0] = 1 - else - err_input[j - 1][0] = 0 - end - end - ret[i].err_input[t][1] = self.gconf.cumat_type.new_from_host(err_input) - ret[i].err_output[t][1] = err_output[t] - ret[i].err_output[t][2] = softmax_output[t] - end - ret[i].seq_length = data[i].seq_len - ret[i].new_seq = {} - for j = 1, self.gconf.batch_size do - if data[i].seq_start[j] then - table.insert(ret[i].new_seq, j) - end - end - end - return ret -end - -function nn:process(data, do_train) - local timer = self.gconf.timer - local total_err = 0 - local total_frame = 0 - for id = 1, #data do - data[id].do_train = do_train - timer:tic('network') - self.network:mini_batch_init(data[id]) - self.network:propagate() - timer:toc('network') - for t = 1, self.gconf.chunk_size do - local tmp = data[id].output[t][1]:new_to_host() - for i = 1, self.gconf.batch_size do - if t <= data[id].seq_length[i] then - total_err = total_err + math.log10(math.exp(tmp[i - 1][0])) - total_frame = total_frame + 1 - end - end - end - if do_train then - timer:tic('network') - self.network:back_propagate() - self.network:update() - timer:toc('network') - end - timer:tic('gc') - collectgarbage('collect') - timer:toc('gc') - end - return math.pow(10, - total_err / total_frame) -end - -function nn:epoch() - local train_error = self:process(self.train_data, true) - local val_error = self:process(self.val_data, false) - return train_error, val_error -end |