aboutsummaryrefslogtreecommitdiff
path: root/nerv
diff options
context:
space:
mode:
authorQi Liu <[email protected]>2016-03-11 13:32:00 +0800
committerQi Liu <[email protected]>2016-03-11 13:32:00 +0800
commitf26288ba61d3d16866e1b227a71e7d9c46923436 (patch)
treeea41bb08994d9d2ee59c3ac5f3ec2c41bcaac6d2 /nerv
parent05fcde5bf0caa1ceb70fef02fc88eda6f00c5ed5 (diff)
update mini_batch_init
Diffstat (limited to 'nerv')
-rw-r--r--nerv/main.lua73
-rw-r--r--nerv/nn/network.lua63
2 files changed, 37 insertions, 99 deletions
diff --git a/nerv/main.lua b/nerv/main.lua
deleted file mode 100644
index 7c82ebf..0000000
--- a/nerv/main.lua
+++ /dev/null
@@ -1,73 +0,0 @@
-local global_conf = {
- cumat_type = nerv.CuMatrixFloat,
- param_random = function() return 0 end,
- lrate = 0.1,
- wcost = 0,
- momentum = 0.9,
- batch_size = 2,
-}
-
-local layer_repo = nerv.LayerRepo(
- {
- ['nerv.RNNLayer'] = {
- rnn1 = {dim_in = {23}, dim_out = {26}},
- rnn2 = {dim_in = {26}, dim_out = {26}},
- },
- ['nerv.AffineLayer'] = {
- input = {dim_in = {62}, dim_out = {23}},
- output = {dim_in = {26, 79}, dim_out = {79}},
- },
- ['nerv.SigmoidLayer'] = {
- sigmoid = {dim_in = {23}, dim_out = {23}},
- },
- ['nerv.IdentityLayer'] = {
- softmax = {dim_in = {79}, dim_out = {79}},
- },
- ['nerv.DuplicateLayer'] = {
- dup = {dim_in = {79}, dim_out = {79, 79}},
- },
- }, nerv.ParamRepo(), global_conf)
-
-local connections = {
- {'<input>[1]', 'input[1]', 0},
- {'input[1]', 'sigmoid[1]', 0},
- {'sigmoid[1]', 'rnn1[1]', 0},
- {'rnn1[1]', 'rnn2[1]', 0},
- {'rnn2[1]', 'output[1]', 0},
- {'output[1]', 'dup[1]', 0},
- {'dup[1]', 'output[2]', -1},
- {'dup[2]', 'softmax[1]', 0},
- {'softmax[1]', '<output>[1]', 0},
-}
-
-local graph = nerv.GraphLayer('graph', global_conf, {dim_in = {62}, dim_out = {79}, layer_repo = layer_repo, connections = connections})
-
-local network = nerv.Network('network', global_conf, {network = graph})
-
-local batch = global_conf.batch_size
-local chunk = 5
-network:init(batch, chunk)
-
-local input = {}
-local output = {}
-local err_input = {}
-local err_output = {}
-local input_size = 62
-local output_size = 79
-for i = 1, chunk do
- input[i] = {global_conf.cumat_type(batch, input_size)}
- output[i] = {global_conf.cumat_type(batch, output_size)}
- err_input[i] = {global_conf.cumat_type(batch, output_size)}
- err_output[i] = {global_conf.cumat_type(batch, input_size)}
-end
-
-for i = 1, 100 do
- network:mini_batch_init({seq_length = {5, 3}, new_seq = {2}})
- network:propagate(input, output)
- network:back_propagate(err_input, err_output, input, output)
- network:update(err_input, input, output)
-end
-
-local tmp = network:get_params()
-
-tmp:export('../../workspace/test.param')
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index 39df5f0..35e11e3 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -2,8 +2,9 @@ local network = nerv.class('nerv.Network')
function network:__init(id, global_conf, network_conf)
self.id = id
- self.dim_in = network_conf.network.dim_in
- self.dim_out = network_conf.network.dim_out
+ self.network = network_conf.network
+ self.dim_in = self.network.dim_in
+ self.dim_out = self.network.dim_out
self.gconf = global_conf
if self.gconf.use_cpu then
self.mat_type = self.gconf.mmat_type
@@ -18,7 +19,6 @@ function network:__init(id, global_conf, network_conf)
self.layers = {}
self.input_conn = {}
self.output_conn = {}
- self.network = network_conf.network
self.socket = self:compile(self.network)
for i = 1, #self.dim_in do
local edge = self.socket.inputs[i]
@@ -368,8 +368,21 @@ function network:set_err_output(err_output)
end
end
-function network:mini_batch_init(information)
- self.info = information
+--[[
+ [info] is a table that contains information about the current mini-batch. It must contain these fields:
+ [input], [output] : matrix arrays which store the network input and output
+ [seq_length] : a table that contains the length of every sequence
+ [new_seq] : a table that contains the batch numbers of new sequences
+ [do_train] : a bool value that indicates whether to train or not
+ if [do_train] is true, these fields must also be contained:
+ [err_input], [err_output] : matrix arrays which store the network err_input and err_output
+--]]
+function network:mini_batch_init(info)
+ self.info = info
+ self:set_input(self.info.input)
+ self:set_output(self.info.output)
+
+ -- calculate border
self.max_length = 0
self.border = {}
for i = 1, self.chunk_size do
@@ -387,6 +400,7 @@ function network:mini_batch_init(information)
table.insert(self.border[chunk], i)
end
end
+
-- copy legacy
for t = 1 - self.delay, 0 do
for i = 1, #self.layers do
@@ -402,23 +416,27 @@ function network:mini_batch_init(information)
end
end
end
- -- flush border gradient
- for t = self.max_length + 1, self.max_length + self.delay do
- if t > self.chunk_size then
- break
- end
- for i = 1, #self.layers do
- local dim_in, _ = self.layers[i]:get_dim()
- for j = 1, #dim_in do
- self.err_output[t][i][j]:fill(0)
+
+ if self.info.do_train then
+ self:set_err_input(self.info.err_input)
+ self:set_err_output(self.info.err_output)
+
+ -- flush border gradient
+ for t = self.max_length + 1, self.max_length + self.delay do
+ if t > self.chunk_size then
+ break
+ end
+ for i = 1, #self.layers do
+ local dim_in, _ = self.layers[i]:get_dim()
+ for j = 1, #dim_in do
+ self.err_output[t][i][j]:fill(0)
+ end
end
end
end
end
-function network:propagate(input, output)
- self:set_input(input)
- self:set_output(output)
+function network:propagate()
for i = 1, #self.queue do
local t, id = self.queue[i].chunk, self.queue[i].id
if t <= self.max_length then
@@ -435,11 +453,7 @@ function network:propagate(input, output)
end
end
-function network:back_propagate(bp_err, next_bp_err, input, output)
- self:set_input(input)
- self:set_output(output)
- self:set_err_input(bp_err)
- self:set_err_output(next_bp_err)
+function network:back_propagate()
for i = #self.queue, 1, -1 do
local t, id = self.queue[i].chunk, self.queue[i].id
if t <= self.max_length then
@@ -462,10 +476,7 @@ function network:back_propagate(bp_err, next_bp_err, input, output)
end
end
-function network:update(bp_err, input, output)
- self:set_input(input)
- self:set_output(output)
- self:set_err_input(bp_err)
+function network:update()
for i = 1, #self.queue do
local t, id = self.queue[i].chunk, self.queue[i].id
if t <= self.max_length then