diff options
author | Determinant <[email protected]> | 2016-03-11 13:59:46 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2016-03-11 13:59:46 +0800 |
commit | e6d28de460dfd06d696d369119247179c7a7525d (patch) | |
tree | 6263fb1555ddcba962edc31ee1312679135c06c4 /nerv/nn/network.lua | |
parent | a32195e3e2ae9ca0f0c7a82e73e6bddb64568c05 (diff) | |
parent | f26288ba61d3d16866e1b227a71e7d9c46923436 (diff) |
Merge branch 'master' of https://github.com/liuq901/nerv into liuq901-master
Conflicts:
nerv/layer/init.lua
nerv/nn/layer_repo.lua
Diffstat (limited to 'nerv/nn/network.lua')
-rw-r--r-- | nerv/nn/network.lua | 498 |
1 files changed, 498 insertions, 0 deletions
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua new file mode 100644 index 0000000..35e11e3 --- /dev/null +++ b/nerv/nn/network.lua @@ -0,0 +1,498 @@ +local network = nerv.class('nerv.Network') + +function network:__init(id, global_conf, network_conf) + self.id = id + self.network = network_conf.network + self.dim_in = self.network.dim_in + self.dim_out = self.network.dim_out + self.gconf = global_conf + if self.gconf.use_cpu then + self.mat_type = self.gconf.mmat_type + else + self.mat_type = self.gconf.cumat_type + end + self.clip = network_conf.clip + self.nn_act_default = network_conf.nn_act_default + if self.nn_act_default == nil then + self.nn_act_default = 0 + end + self.layers = {} + self.input_conn = {} + self.output_conn = {} + self.socket = self:compile(self.network) + for i = 1, #self.dim_in do + local edge = self.socket.inputs[i] + local id, port, time = edge[1], edge[2], edge[3] + if self.input_conn[id][port] ~= nil then + nerv.error('duplicate edge') + end + self.input_conn[id][port] = {0, i, time} + end + for i = 1, #self.dim_out do + local edge = self.socket.outputs[i] + local id, port, time = edge[1], edge[2], edge[3] + if self.output_conn[id][port] ~= nil then + nerv.error('duplicate edge') + end + self.output_conn[id][port] = {0, i, time} + end + self.delay = 0 + for i = 1, #self.layers do + local dim_in, _ = self.layers[i]:get_dim() + for j = 1, #dim_in do + local time = self.input_conn[i][j][3] + if math.abs(time) > self.delay then + self.delay = math.abs(time) + end + end + end +end + +function network:compile(layer) + local socket = {inputs = {}, outputs = {}} + if not nerv.is_type(layer, 'nerv.GraphLayer') then + table.insert(self.layers, layer) + local id = #self.layers + self.input_conn[id] = {} + self.output_conn[id] = {} + local dim_in, dim_out = layer:get_dim() + for i = 1, #dim_in do + socket.inputs[i] = {id, i, 0} + end + for i = 1, #dim_out do + socket.outputs[i] = {id, i, 0} + end + else + local sublayer_socket = {} + for id, sublayer in pairs(layer.layers) do + if id ~= '<input>' then + sublayer_socket[sublayer.id] = self:compile(sublayer.layer) + end + end + for _, edge in pairs(layer.connections) do + -- id = 0 means <input> or <output> + local id_from, port_from = edge[1], edge[2] + local id_to, port_to = edge[3], edge[4] + local time = edge[5] + if id_from == 0 then + if socket.inputs[port_from] ~= nil then + nerv.error('duplicate input socket') + end + local input = sublayer_socket[id_to].inputs[port_to] + local id, port, t = input[1], input[2], input[3] + time + socket.inputs[port_from] = {id, port, t} + else + local output = sublayer_socket[id_from].outputs[port_from] + local id, port, t = output[1], output[2], output[3] + time + if id_to == 0 then + if socket.outputs[port_to] ~= nil then + nerv.error('duplicate output socket') + end + socket.outputs[port_to] = {id, port, t} + else + local input = sublayer_socket[id_to].inputs[port_to] + local id1, port1, t1 = input[1], input[2], input[3] + if self.input_conn[id1][port1] ~= nil or self.output_conn[id][port] ~= nil then + nerv.error('duplicate edge') + end + self.input_conn[id1][port1] = {id, port, t + t1} + self.output_conn[id][port] = {id1, port1, t + t1} + end + end + end + end + return socket +end + +function network:init(batch_size, chunk_size) + self.batch_size = batch_size + self.chunk_size = chunk_size + + self:topsort() + + self:make_initial_store() + collectgarbage('collect') + + for i = 1, #self.layers do + self.layers[i]:init(batch_size, chunk_size) + end +end + +function network:topsort() + nerv.info('network topology sort') + local degree = {} + for t = 1, self.chunk_size do + degree[t] = {} + for i = 1, #self.layers do + degree[t][i] = 0 + end + end + + for t = 1, self.chunk_size do + for i = 1, #self.layers do + local _, dim_out = self.layers[i]:get_dim() + for j = 1, #dim_out do + if self.output_conn[i][j] ~= nil then + local edge = self.output_conn[i][j] + local id, time = edge[1], edge[3] + t + if time >= 1 and time <= self.chunk_size and id ~= 0 then + degree[time][id] = degree[time][id] + 1 + end + end + end + end + end + + self.queue = {} + local l = 1 + local r = 0 + for t = 1, self.chunk_size do + for i = 1, #self.layers do + if degree[t][i] == 0 then + r = r + 1 + self.queue[r] = {chunk = t, id = i} + end + end + end + while l<=r do + local t, i = self.queue[l].chunk, self.queue[l].id + l = l + 1 + local _, dim_out = self.layers[i]:get_dim() + for j = 1, #dim_out do + if self.output_conn[i][j] ~= nil then + local edge = self.output_conn[i][j] + local id, time = edge[1], edge[3] + t + if time >= 1 and time <= self.chunk_size and id ~= 0 then + degree[time][id] = degree[time][id] - 1 + if degree[time][id] == 0 then + r = r + 1 + self.queue[r] = {chunk = time, id = id} + end + end + end + end + end + + if r ~= self.chunk_size * #self.layers then + nerv.error('loop detected') + end +end + +function network:make_initial_store() + nerv.info('network initing storage') + + -- allocate memory + local memory = {} + local err_memory = {} + for t = 1 - self.delay, self.chunk_size + self.delay do + memory[t] = {} + err_memory[t] = {} + for i = 1, #self.layers do + memory[t][i] = {} + err_memory[t][i] = {} + local dim_in, dim_out = self.layers[i]:get_dim() + for j = 1, #dim_in do + err_memory[t][i][j] = self.mat_type(self.batch_size, dim_in[j]) + err_memory[t][i][j]:fill(0) + end + for j = 1, #dim_out do + memory[t][i][j] = self.mat_type(self.batch_size, dim_out[j]) + memory[t][i][j]:fill(self.nn_act_default) + end + end + -- memory[t][0] stores network input + memory[t][0] = {} + for j = 1, #self.dim_in do + memory[t][0][j] = self.mat_type(self.batch_size, self.dim_in[j]) + memory[t][0][j]:fill(self.nn_act_default) + end + -- err_memory[t][0] stores network err_input + err_memory[t][0] = {} + for j = 1, #self.dim_out do + err_memory[t][0][j] = self.mat_type(self.batch_size, self.dim_out[j]) + err_memory[t][0][j]:fill(0) + end + end + + -- connect memory and reference + self.input = {} + self.output = {} + self.err_input = {} + self.err_output = {} + for t = 1, self.chunk_size do + self.input[t] = {} + self.output[t] = {} + self.err_input[t] = {} + self.err_output[t] = {} + for i = 1, #self.layers do + self.input[t][i] = {} + self.output[t][i] = {} + self.err_input[t][i] = {} + self.err_output[t][i] = {} + local dim_in, dim_out = self.layers[i]:get_dim() + for j = 1, #dim_in do + local edge = self.input_conn[i][j] + local id, port, time = edge[1], edge[2], edge[3] + if id ~= 0 or t - time < 1 or t - time > self.chunk_size then + self.input[t][i][j] = memory[t - time][id][port] + end + if id ~= 0 then + self.err_output[t][i][j] = err_memory[t][i][j] + end + end + for j = 1, #dim_out do + local edge = self.output_conn[i][j] + local id, port, time = edge[1], edge[2], edge[3] + if id ~= 0 then + self.output[t][i][j] = memory[t][i][j] + end + if id ~= 0 or t + time < 1 or t + time > self.chunk_size then + self.err_input[t][i][j] = err_memory[t + time][id][port] + end + end + end + end + + -- check dangling reference + for t = 1, self.chunk_size do + for i = 1, #self.dim_in do + local edge = self.socket.inputs[i] + local id, port, time = edge[1], edge[2], edge[3] + if t + time >= 1 and t + time <= self.chunk_size then + if self.input[t + time][id][port] ~= nil then + nerv.error('input reference not nil') + end + self.input[t + time][id][port] = true -- just a place holder + if self.err_output[t + time][id][port] ~= nil then + nerv.error('err_output reference not nil') + end + self.err_output[t + time][id][port] = true -- just a place holder + end + end + for i = 1, #self.dim_out do + local edge = self.socket.outputs[i] + local id, port, time = edge[1], edge[2], edge[3] + if t - time >= 1 and t - time <= self.chunk_size then + if self.output[t - time][id][port] ~= nil then + nerv.error('output reference not nil') + end + self.output[t - time][id][port] = true -- just a place holder + if self.err_input[t - time][id][port] ~= nil then + nerv.error('err_output reference not nil') + end + self.err_input[t - time][id][port] = true -- just a place holder + end + end + end + for t = 1, self.chunk_size do + for i = 1, #self.layers do + local dim_in, dim_out = self.layers[i]:get_dim() + for j = 1, #dim_in do + if self.input[t][i][j] == nil then + nerv.error('input reference dangling') + end + if self.err_output[t][i][j] == nil then + nerv.error('err_output reference dangling') + end + end + for j = 1, #dim_out do + if self.output[t][i][j] == nil then + nerv.error('output reference dangling') + end + if self.err_input[t][i][j] == nil then + nerv.error('err_input reference dangling') + end + end + end + end + + -- allocate reference for legacy of previous mini-batch + self.legacy = {} + for t = 1 - self.delay, 0 do + self.legacy[t] = {} + for i = 1, #self.layers do + self.legacy[t][i] = {} + local _, dim_out = self.layers[i]:get_dim() + for j = 1, #dim_out do + self.legacy[t][i][j] = memory[t][i][j] + end + end + end +end + +function network:set_input(input) + for t = 1, self.chunk_size do + for i = 1, #self.dim_in do + local edge = self.socket.inputs[i] + local id, port, time = edge[1], edge[2], edge[3] + if t + time >= 1 and t + time <= self.chunk_size then + self.input[t + time][id][port] = input[t][i] + end + end + end +end + +function network:set_output(output) + for t = 1, self.chunk_size do + for i = 1, #self.dim_out do + local edge = self.socket.outputs[i] + local id, port, time = edge[1], edge[2], edge[3] + if t - time >= 1 and t - time <= self.chunk_size then + self.output[t - time][id][port] = output[t][i] + end + end + end +end + +function network:set_err_input(err_input) + for t = 1, self.chunk_size do + for i = 1, #self.dim_out do + local edge = self.socket.outputs[i] + local id, port, time = edge[1], edge[2], edge[3] + if t - time >= 1 and t - time <= self.chunk_size then + self.err_input[t - time][id][port] = err_input[t][i] + end + end + end +end + +function network:set_err_output(err_output) + for t = 1, self.chunk_size do + for i = 1, #self.dim_in do + local edge = self.socket.inputs[i] + local id, port, time = edge[1], edge[2], edge[3] + if t + time >= 1 and t + time <= self.chunk_size then + self.err_output[t + time][id][port] = err_output[t][i] + end + end + end +end + +--[[ + [info] is a table that contains information of current mini-batch. These fields must be contained: + [input], [output] : matrix array which stores the network input and output + [seq_length] : a table contains the length of every sequences + [new_seq]: a table contains the batch number of new sequences + [do_train]: a bool value indicates do train or not + if [do_train] is true, these fileds also must be contained: + [err_input], [err_output] : matrix array which stores the network err_input and err_output +--]] +function network:mini_batch_init(info) + self.info = info + self:set_input(self.info.input) + self:set_output(self.info.output) + + -- calculate border + self.max_length = 0 + self.border = {} + for i = 1, self.chunk_size do + self.border[i] = {} + end + for i = 1, self.batch_size do + if self.info.seq_length[i] > self.max_length then + self.max_length = self.info.seq_length[i] + end + for t = 1, self.delay do + local chunk = self.info.seq_length[i] + t + if chunk > self.chunk_size then + break + end + table.insert(self.border[chunk], i) + end + end + + -- copy legacy + for t = 1 - self.delay, 0 do + for i = 1, #self.layers do + local _, dim_out = self.layers[i]:get_dim() + for j = 1, #dim_out do + if t + self.chunk_size >= 1 and self.output_conn[i][j][1] ~= 0 then + self.legacy[t][i][j]:copy_from(self.output[t + self.chunk_size][i][j]) + end + for k = 1, #self.info.new_seq do + local batch = self.info.new_seq[k] + self.legacy[t][i][j][batch - 1]:fill(self.nn_act_default) + end + end + end + end + + if self.info.do_train then + self:set_err_input(self.info.err_input) + self:set_err_output(self.info.err_output) + + -- flush border gradient + for t = self.max_length + 1, self.max_length + self.delay do + if t > self.chunk_size then + break + end + for i = 1, #self.layers do + local dim_in, _ = self.layers[i]:get_dim() + for j = 1, #dim_in do + self.err_output[t][i][j]:fill(0) + end + end + end + end +end + +function network:propagate() + for i = 1, #self.queue do + local t, id = self.queue[i].chunk, self.queue[i].id + if t <= self.max_length then + self.layers[id]:propagate(self.input[t][id], self.output[t][id], t) + end + -- flush border activation + for j = 1, #self.border[t] do + local batch = self.border[t][j] + local _, dim_out = self.layers[id]:get_dim() + for k = 1, #dim_out do + self.output[t][id][k][batch - 1]:fill(self.nn_act_default) + end + end + end +end + +function network:back_propagate() + for i = #self.queue, 1, -1 do + local t, id = self.queue[i].chunk, self.queue[i].id + if t <= self.max_length then + -- flush border gradient + for j = 1, #self.border[t] do + local batch = self.border[t][j] + local _, dim_out = self.layers[id]:get_dim() + for k = 1, #dim_out do + self.err_input[t][id][k][batch - 1]:fill(0) + end + end + self.layers[id]:back_propagate(self.err_input[t][id], self.err_output[t][id], self.input[t][id], self.output[t][id], t) + if self.clip ~= nil then + local dim_in, _ = self.layers[id]:get_dim() + for j = 1, #dim_in do + self.err_output[t][id][j]:clip(-self.clip, self.clip) + end + end + end + end +end + +function network:update() + for i = 1, #self.queue do + local t, id = self.queue[i].chunk, self.queue[i].id + if t <= self.max_length then + self.layers[id]:update(self.err_input[t][id], self.input[t][id], self.output[t][id], t) + end + end +end + +function network:set_attr(name, value) + self.network:set_attr(name, value) +end + +function network:get_sublayer(id) + return self.network:get_sublayer(id) +end + +function network:get_params() + return self.network:get_params() +end |