diff options
-rw-r--r-- | nerv/nn/network.lua | 124 |
1 files changed, 119 insertions, 5 deletions
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index 01290e7..e1a9629 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -111,6 +111,10 @@ function network:init(batch_size, chunk_size)
 
     self:make_initial_store()
     collectgarbage('collect')
+
+    for i = 1, #self.layers do -- forward the batch/chunk sizes to every layer once the stores exist
+        self.layers[i]:init(batch_size, chunk_size)
+    end
 end
 
 function network:topsort()
@@ -315,23 +319,86 @@ function network:make_initial_store()
     end
 end
 
+function network:set_input(input) -- bind input[t][i] to the (id, port) slot named by socket.inputs[i], at chunk t + time
+    for t = 1, self.chunk_size do -- chunk_size is a number: iterate up to it, do not take its length
+        for i = 1, #self.dim_in do
+            local edge = self.socket.inputs[i]
+            local id, port, time = edge[1], edge[2], edge[3]
+            if t + time >= 1 and t + time <= self.chunk_size then -- skip slots shifted outside 1..chunk_size
+                self.input[t + time][id][port] = input[t][i]
+            end
+        end
+    end
+end
+
+function network:set_output(output) -- bind output[t][i] to the (id, port) slot named by socket.outputs[i], at chunk t - time
+    for t = 1, self.chunk_size do -- chunk_size is a number: iterate up to it, do not take its length
+        for i = 1, #self.dim_out do
+            local edge = self.socket.outputs[i]
+            local id, port, time = edge[1], edge[2], edge[3]
+            if t - time >= 1 and t - time <= self.chunk_size then -- skip slots shifted outside 1..chunk_size
+                self.output[t - time][id][port] = output[t][i]
+            end
+        end
+    end
+end
+
+function network:set_err_input(err_input) -- bind err_input[t][i] to the slot named by socket.outputs[i], at chunk t - time
+    for t = 1, self.chunk_size do -- chunk_size is a number: iterate up to it, do not take its length
+        for i = 1, #self.dim_out do
+            local edge = self.socket.outputs[i]
+            local id, port, time = edge[1], edge[2], edge[3]
+            if t - time >= 1 and t - time <= self.chunk_size then -- skip slots shifted outside 1..chunk_size
+                self.err_input[t - time][id][port] = err_input[t][i]
+            end
+        end
+    end
+end
+
+function network:set_err_output(err_output) -- bind err_output[t][i] to the slot named by socket.inputs[i], at chunk t + time
+    for t = 1, self.chunk_size do
+        for i = 1, #self.dim_in do
+            local edge = self.socket.inputs[i]
+            local id, port, time = edge[1], edge[2], edge[3]
+            if t + time >= 1 and t + time <= self.chunk_size then -- skip slots shifted outside 1..chunk_size
+                self.err_output[t + time][id][port] = err_output[t][i]
+            end
+        end
+    end
+end
+
 function network:mini_batch_init(information)
     self.info = information
-    self.max_chunk = 0
+    self.max_length = 0 -- becomes the largest info.seq_length[i] over the batch
+    self.border = {} -- border[c] lists batch indices i with seq_length[i] < c <= seq_length[i] + delay
+    for i = 1, self.chunk_size do
+        self.border[i] = {}
+    end
     for i = 1, self.batch_size do
-        if self.info.seq_length[i] > self.max_chunk then
-            self.max_chunk = self.info.seq_length[i]
+        if self.info.seq_length[i] > self.max_length then
+            self.max_length = self.info.seq_length[i]
+        end
+        for t = 1, self.delay do -- record the `delay` chunks that follow the end of sequence i
+            local chunk = self.info.seq_length[i] + t
+            if chunk > self.chunk_size then -- border only has entries for 1..chunk_size
+                break
+            end
+            table.insert(self.border[chunk], i)
         end
     end
     for t = 1 - self.delay, 0 do
         for i = 1, #self.layers do
             local _, dim_out = self.layers[i]:get_dim()
             for j = 1, #dim_out do
-                self.output[t][i][j]:copy_from(self.output[t + self.chunk_size][i][j])
+                self.legacy[t][i][j]:copy_from(self.output[t + self.chunk_size][i][j]) -- carry the previous chunk's tail into legacy
+                for k = 1, #self.info.new_seq do -- reset the carried value for every entry of info.new_seq
+                    local batch = self.info.new_seq[k]
+                    self.legacy[t][i][j][batch - 1]:fill(self.nn_act_default) -- index batch - 1: presumably 0-based rows, confirm
+                end
             end
         end
     end
-    for t = self.max_chunk + 1, self.max_chunk + self.delay do
+    for t = self.max_length + 1, self.max_length + self.delay do
         if t > self.chunk_size then
             break
         end
@@ -345,4 +412,51 @@ function network:mini_batch_init(information)
 end
 
 function network:propagate(input, output)
+    self:set_input(input) -- call on the instance; network:set_input would pass the class table as self
+    self:set_output(output)
+    for i = 1, #self.queue do -- visit (chunk, layer id) pairs in queue order
+        local t, id = self.queue[i].chunk, self.queue[i].id
+        if t <= self.max_length then -- chunks beyond the longest sequence are not propagated
+            self.layers[id]:propagate(self.input[t][id], self.output[t][id], t)
+        end
+        for j = 1, #self.border[t] do -- overwrite outputs of the entries listed in border[t] with the default
+            local batch = self.border[t][j]
+            local _, dim_out = self.layers[id]:get_dim()
+            for k = 1, #dim_out do
+                self.output[t][id][k][batch - 1]:fill(self.nn_act_default)
+            end
+        end
+    end
+end
+
+function network:back_propagate(bp_err, next_bp_err, input, output) -- backward pass over the queue in reverse order
+    self:set_input(input) -- call on the instance; network:set_input would pass the class table as self
+    self:set_output(output)
+    self:set_err_input(bp_err)
+    self:set_err_output(next_bp_err)
+    for i = #self.queue, 1, -1 do
+        local t, id = self.queue[i].chunk, self.queue[i].id
+        if t <= self.max_length then
+            for j = 1, #self.border[t] do -- zero err_input of the entries listed in border[t] before the layer runs
+                local batch = self.border[t][j]
+                local dim_in, _ = self.layers[id]:get_dim()
+                for k = 1, #dim_in do
+                    self.err_input[t][id][k][batch - 1]:fill(0)
+                end
+            end
+            self.layers[id]:back_propagate(self.err_input[t][id], self.err_output[t][id], self.input[t][id], self.output[t][id], t)
+        end
+    end
+end
+
+function network:update(bp_err, input, output) -- call update on each queued layer for chunks up to max_length
+    self:set_input(input) -- call on the instance; network:set_input would pass the class table as self
+    self:set_output(output)
+    self:set_err_input(bp_err)
+    for i = 1, #self.queue do
+        local t, id = self.queue[i].chunk, self.queue[i].id
+        if t <= self.max_length then
+            self.layers[id]:update(self.err_input[t][id], self.input[t][id], self.output[t][id], t)
+        end
+    end
 end