Diffstat (limited to 'nerv/nn/network.lua')
-rw-r--r--    nerv/nn/network.lua    63
1 file changed, 37 insertions(+), 26 deletions(-)
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index 39df5f0..35e11e3 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -2,8 +2,9 @@ local network = nerv.class('nerv.Network')
function network:__init(id, global_conf, network_conf)
self.id = id
- self.dim_in = network_conf.network.dim_in
- self.dim_out = network_conf.network.dim_out
+ self.network = network_conf.network
+ self.dim_in = self.network.dim_in
+ self.dim_out = self.network.dim_out
self.gconf = global_conf
if self.gconf.use_cpu then
self.mat_type = self.gconf.mmat_type
@@ -18,7 +19,6 @@ function network:__init(id, global_conf, network_conf)
self.layers = {}
self.input_conn = {}
self.output_conn = {}
- self.network = network_conf.network
self.socket = self:compile(self.network)
for i = 1, #self.dim_in do
local edge = self.socket.inputs[i]
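Taken together, the two hunks above simply hoist the `self.network = network_conf.network` assignment so that `dim_in` and `dim_out` can be read through `self.network` before `compile` runs. A minimal construction sketch under that constructor signature (the `gconf` and `graph` variables are illustrative assumptions, not taken from this commit):

    -- sketch: constructing a Network after this change; `gconf` is a
    -- global config table and `graph` a graph layer assumed to exist
    local net = nerv.Network('network', gconf, {network = graph})
    -- dim_in/dim_out are now read off the stored graph layer, so
    -- net.dim_in and net.dim_out mirror graph.dim_in and graph.dim_out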
@@ -368,8 +368,21 @@ function network:set_err_output(err_output)
end
end
-function network:mini_batch_init(information)
- self.info = information
+--[[
+    [info] is a table that carries information about the current mini-batch. The following fields must be present:
+    [input], [output]: matrix arrays that store the network input and output
+    [seq_length]: a table containing the length of every sequence
+    [new_seq]: a table containing the batch indices of new sequences
+    [do_train]: a boolean value indicating whether to train
+    if [do_train] is true, these fields must also be present:
+    [err_input], [err_output]: matrix arrays that store the network err_input and err_output
+--]]
+function network:mini_batch_init(info)
+ self.info = info
+ self:set_input(self.info.input)
+ self:set_output(self.info.output)
+
+ -- calculate border
self.max_length = 0
self.border = {}
for i = 1, self.chunk_size do
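Given the new doc comment, a call to the reworked entry point might look like the sketch below; every field value is an illustrative assumption, not taken from the commit:

    -- sketch: preparing one mini-batch for the new mini_batch_init
    local info = {
        input      = input_mats,   -- matrix array, one entry per chunk
        output     = output_mats,  -- matrix array for the network output
        seq_length = {50, 42, 38}, -- length of each sequence in the batch
        new_seq    = {2, 3},       -- batch indices that start a new sequence
        do_train   = true,
        -- required only because do_train is true:
        err_input  = err_in_mats,
        err_output = err_out_mats,
    }
    net:mini_batch_init(info)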
@@ -387,6 +400,7 @@ function network:mini_batch_init(information)
table.insert(self.border[chunk], i)
end
end
+
-- copy legacy
for t = 1 - self.delay, 0 do
for i = 1, #self.layers do
@@ -402,23 +416,27 @@ function network:mini_batch_init(information)
end
end
end
- -- flush border gradient
- for t = self.max_length + 1, self.max_length + self.delay do
- if t > self.chunk_size then
- break
- end
- for i = 1, #self.layers do
- local dim_in, _ = self.layers[i]:get_dim()
- for j = 1, #dim_in do
- self.err_output[t][i][j]:fill(0)
+
+ if self.info.do_train then
+ self:set_err_input(self.info.err_input)
+ self:set_err_output(self.info.err_output)
+
+ -- flush border gradient
+ for t = self.max_length + 1, self.max_length + self.delay do
+ if t > self.chunk_size then
+ break
+ end
+ for i = 1, #self.layers do
+ local dim_in, _ = self.layers[i]:get_dim()
+ for j = 1, #dim_in do
+ self.err_output[t][i][j]:fill(0)
+ end
end
end
end
end
-function network:propagate(input, output)
- self:set_input(input)
- self:set_output(output)
+function network:propagate()
for i = 1, #self.queue do
local t, id = self.queue[i].chunk, self.queue[i].id
if t <= self.max_length then
@@ -435,11 +453,7 @@ function network:propagate(input, output)
end
end
-function network:back_propagate(bp_err, next_bp_err, input, output)
- self:set_input(input)
- self:set_output(output)
- self:set_err_input(bp_err)
- self:set_err_output(next_bp_err)
+function network:back_propagate()
for i = #self.queue, 1, -1 do
local t, id = self.queue[i].chunk, self.queue[i].id
if t <= self.max_length then
@@ -462,10 +476,7 @@ function network:back_propagate(bp_err, next_bp_err, input, output)
end
end
-function network:update(bp_err, input, output)
- self:set_input(input)
- self:set_output(output)
- self:set_err_input(bp_err)
+function network:update()
for i = 1, #self.queue do
local t, id = self.queue[i].chunk, self.queue[i].id
if t <= self.max_length then
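With the buffers bound once in mini_batch_init, the driver loop no longer threads input, output, and error tables through every call. A hedged sketch of the resulting call order per training step (the surrounding trainer variables are assumptions):

    -- one training step after this refactor
    net:mini_batch_init(info)  -- binds input/output/err buffers, computes borders
    net:propagate()            -- forward pass over the scheduled queue
    net:back_propagate()       -- backward pass (info.do_train must be true)
    net:update()               -- parameter update using the bound err_input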