aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/nn/network.lua124
1 files changed, 119 insertions, 5 deletions
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index 01290e7..e1a9629 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -111,6 +111,10 @@ function network:init(batch_size, chunk_size)
self:make_initial_store()
collectgarbage('collect')
+
+ for i = 1, #self.layers do
+ self.layers[i]:init(batch_size, chunk_size)
+ end
end
function network:topsort()
@@ -315,23 +319,86 @@ function network:make_initial_store()
end
end
+function network:set_input(input)
+ for t = 1, #self.chunk_size do
+ for i = 1, #self.dim_in do
+ local edge = self.socket.inputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if t + time >= 1 and t + time <= self.chunk_size then
+ self.input[t + time][id][port] = input[t][i]
+ end
+ end
+ end
+end
+
+function network:set_output(output)
+ for t = 1, #self.chunk_size do
+ for i = 1, #self.dim_out do
+ local edge = self.socket.outputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if t - time >= 1 and t - time <= self.chunk_size then
+ self.output[t - time][id][port] = output[t][i]
+ end
+ end
+ end
+end
+
+function network:set_err_input(err_input)
+ for t = 1, #self.chunk_size do
+ for i = 1, #self.dim_out do
+ local edge = self.socket.outputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if t - time >= 1 and t - time <= self.chunk_size then
+ self.err_input[t - time][id][port] = err_input[t][i]
+ end
+ end
+ end
+end
+
+function network:set_err_output(err_output)
+ for t = 1, self.chunk_size do
+ for i = 1, #self.dim_in do
+ local edge = self.socket.inputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if t + time >= 1 and t + time <= self.chunk_size then
+ self.err_output[t + time][id][port] = err_output[t][i]
+ end
+ end
+ end
+end
+
function network:mini_batch_init(information)
self.info = information
- self.max_chunk = 0
+ self.max_length = 0
+ self.border = {}
+ for i = 1, self.chunk_size do
+ self.border[i] = {}
+ end
for i = 1, self.batch_size do
- if self.info.seq_length[i] > self.max_chunk then
- self.max_chunk = self.info.seq_length[i]
+ if self.info.seq_length[i] > self.max_length then
+ self.max_length = self.info.seq_length[i]
+ end
+ for t = 1, self.delay do
+ local chunk = self.info.seq_length[i] + t
+ if chunk > self.chunk_size then
+ break
+ end
+ table.insert(self.border[chunk], i)
end
end
for t = 1 - self.delay, 0 do
for i = 1, #self.layers do
local _, dim_out = self.layers[i]:get_dim()
for j = 1, #dim_out do
- self.output[t][i][j]:copy_from(self.output[t + self.chunk_size][i][j])
+ self.legacy[t][i][j]:copy_from(self.output[t + self.chunk_size][i][j])
+ for k = 1, #self.info.new_seq do
+ local batch = self.info.new_seq[k]
+ self.legacy[t][i][j][batch - 1]:fill(self.nn_act_default)
+ end
end
end
end
- for t = self.max_chunk + 1, self.max_chunk + self.delay do
+ for t = self.max_length + 1, self.max_length + self.delay do
if t > self.chunk_size then
break
end
@@ -345,4 +412,51 @@ function network:mini_batch_init(information)
end
function network:propagate(input, output)
+ network:set_input(input)
+ network:set_output(output)
+ for i = 1, #self.queue do
+ local t, id = self.queue[i].chunk, self.queue[i].id
+ if t <= self.max_length then
+ self.layers[id]:propagate(self.input[t][id], self.output[t][id], t)
+ end
+ for j = 1, #self.border[t] do
+ local batch = self.border[t][j]
+ local _, dim_out = self.layers[id]:get_dim()
+ for k = 1, #dim_out do
+ self.output[t][id][k][batch - 1]:fill(self.nn_act_default)
+ end
+ end
+ end
+end
+
+function network:back_propagate(bp_err, next_bp_err, input, output)
+ network:set_input(input)
+ network:set_output(output)
+ network:set_err_input(bp_err)
+ network:set_err_output(next_bp_err)
+ for i = #self.queue, 1, -1 do
+ local t, id = self.queue[i].chunk, self.queue[i].id
+ if t <= self.max_length then
+ for j = 1, #self.border[t] do
+ local batch = self.border[t][j]
+ local dim_in, _ = self.layers[id]:get_dim()
+ for k = 1, #dim_in do
+ self.err_input[t][id][k][batch - 1]:fill(0)
+ end
+ end
+ self.layers[id]:back_propagate(self.err_input[t][id], self.err_output[t][id], self.input[t][id], self.output[t][id], t)
+ end
+ end
+end
+
+function network:update(bp_err, input, output)
+ network:set_input(input)
+ network:set_output(output)
+ network:set_err_input(bp_err)
+ for i = 1, #self.queue do
+ local t, id = self.queue[i].chunk, self.queue[i].id
+ if t <= self.max_length then
+ self.layers[id]:update(self.err_input[t][id], self.input[t][id], self.output[t][id], t)
+ end
+ end
end