aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/layer/duplicate.lua4
-rw-r--r--nerv/layer/graph.lua24
-rw-r--r--nerv/layer/init.lua2
-rw-r--r--nerv/layer/rnn.lua4
-rw-r--r--nerv/main.lua12
-rw-r--r--nerv/nn/network.lua15
6 files changed, 54 insertions, 7 deletions
diff --git a/nerv/layer/duplicate.lua b/nerv/layer/duplicate.lua
index 1a93b26..8988617 100644
--- a/nerv/layer/duplicate.lua
+++ b/nerv/layer/duplicate.lua
@@ -38,3 +38,7 @@ end
function DuplicateLayer:update()
end
+
+function DuplicateLayer:get_params()
+ return nerv.ParamRepo({})
+end
diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua
index d72d849..1406eff 100644
--- a/nerv/layer/graph.lua
+++ b/nerv/layer/graph.lua
@@ -21,6 +21,30 @@ local function parse_id(str)
return id, port
end
+function GraphLayer:add_prefix(layers, connections)
+ local function ap(name)
+ return self.id .. '.' .. name
+ end
+
+ for layer_type, sublayers in pairs(layers) do
+ local tmp = {}
+ for name, layer_config in pairs(sublayers) do
+ tmp[ap(name)] = layer_config
+ end
+ layers[layer_type] = tmp
+ end
+
+ for i = 1, #connections do
+ local from, to = connections[i][1], connections[i][2]
+ if parse_id(from) ~= '<input>' then
+ connections[i][1] = ap(from)
+ end
+ if parse_id(to) ~= '<output>' then
+ connections[i][2] = ap(to)
+ end
+ end
+end
+
function GraphLayer:discover(id, layer_repo)
if id == '<output>' then
id = '<input>'
diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua
index 39f97b1..4fabefa 100644
--- a/nerv/layer/init.lua
+++ b/nerv/layer/init.lua
@@ -75,7 +75,7 @@ function Layer:set_attr(name, value)
end
function Layer:get_sublayer(id)
- nerv.error('primitive layer does not have sublayers.')
+ nerv.error('primitive layer does not have sublayers')
end
function Layer:find_param(pid_list, lconf, gconf, p_type, p_dim)
diff --git a/nerv/layer/rnn.lua b/nerv/layer/rnn.lua
index 806ac58..38f2326 100644
--- a/nerv/layer/rnn.lua
+++ b/nerv/layer/rnn.lua
@@ -27,8 +27,6 @@ function RNNLayer:__init(id, global_conf, layer_conf)
}
}
- local layer_repo = nerv.LayerRepo(layers, pr, global_conf)
-
local connections = {
{'<input>[1]', 'main[1]', 0},
{'main[1]', 'sigmoid[1]', 0},
@@ -37,5 +35,7 @@ function RNNLayer:__init(id, global_conf, layer_conf)
{'dup[2]', '<output>[1]', 0},
}
+ self:add_prefix(layers, connections)
+ local layer_repo = nerv.LayerRepo(layers, pr, global_conf)
self:graph_init(layer_repo, connections)
end
diff --git a/nerv/main.lua b/nerv/main.lua
index 865aba0..7c82ebf 100644
--- a/nerv/main.lua
+++ b/nerv/main.lua
@@ -10,7 +10,8 @@ local global_conf = {
local layer_repo = nerv.LayerRepo(
{
['nerv.RNNLayer'] = {
- rnn = {dim_in = {23}, dim_out = {26}},
+ rnn1 = {dim_in = {23}, dim_out = {26}},
+ rnn2 = {dim_in = {26}, dim_out = {26}},
},
['nerv.AffineLayer'] = {
input = {dim_in = {62}, dim_out = {23}},
@@ -30,8 +31,9 @@ local layer_repo = nerv.LayerRepo(
local connections = {
{'<input>[1]', 'input[1]', 0},
{'input[1]', 'sigmoid[1]', 0},
- {'sigmoid[1]', 'rnn[1]', 0},
- {'rnn[1]', 'output[1]', 0},
+ {'sigmoid[1]', 'rnn1[1]', 0},
+ {'rnn1[1]', 'rnn2[1]', 0},
+ {'rnn2[1]', 'output[1]', 0},
{'output[1]', 'dup[1]', 0},
{'dup[1]', 'output[2]', -1},
{'dup[2]', 'softmax[1]', 0},
@@ -65,3 +67,7 @@ for i = 1, 100 do
network:back_propagate(err_input, err_output, input, output)
network:update(err_input, input, output)
end
+
+local tmp = network:get_params()
+
+tmp:export('../../workspace/test.param')
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index 0bbcc59..39df5f0 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -18,7 +18,8 @@ function network:__init(id, global_conf, network_conf)
self.layers = {}
self.input_conn = {}
self.output_conn = {}
- self.socket = self:compile(network_conf.network)
+ self.network = network_conf.network
+ self.socket = self:compile(self.network)
for i = 1, #self.dim_in do
local edge = self.socket.inputs[i]
local id, port, time = edge[1], edge[2], edge[3]
@@ -472,3 +473,15 @@ function network:update(bp_err, input, output)
end
end
end
+
+function network:set_attr(name, value)
+ self.network:set_attr(name, value)
+end
+
+function network:get_sublayer(id)
+ return self.network:get_sublayer(id)
+end
+
+function network:get_params()
+ return self.network:get_params()
+end