From 8374e8fbc545633b6adf5c4090af8997a65778d2 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Thu, 3 Mar 2016 19:42:15 +0800 Subject: update add_prefix for graph layer --- nerv/layer/duplicate.lua | 4 ++++ nerv/layer/graph.lua | 24 ++++++++++++++++++++++++ nerv/layer/init.lua | 2 +- nerv/layer/rnn.lua | 4 ++-- nerv/main.lua | 12 +++++++++--- nerv/nn/network.lua | 15 ++++++++++++++- 6 files changed, 54 insertions(+), 7 deletions(-) (limited to 'nerv') diff --git a/nerv/layer/duplicate.lua b/nerv/layer/duplicate.lua index 1a93b26..8988617 100644 --- a/nerv/layer/duplicate.lua +++ b/nerv/layer/duplicate.lua @@ -38,3 +38,7 @@ end function DuplicateLayer:update() end + +function DuplicateLayer:get_params() + return nerv.ParamRepo({}) +end diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua index d72d849..1406eff 100644 --- a/nerv/layer/graph.lua +++ b/nerv/layer/graph.lua @@ -21,6 +21,30 @@ local function parse_id(str) return id, port end +function GraphLayer:add_prefix(layers, connections) + local function ap(name) + return self.id .. '.' .. name + end + + for layer_type, sublayers in pairs(layers) do + local tmp = {} + for name, layer_config in pairs(sublayers) do + tmp[ap(name)] = layer_config + end + layers[layer_type] = tmp + end + + for i = 1, #connections do + local from, to = connections[i][1], connections[i][2] + if parse_id(from) ~= '' then + connections[i][1] = ap(from) + end + if parse_id(to) ~= '' then + connections[i][2] = ap(to) + end + end +end + function GraphLayer:discover(id, layer_repo) if id == '' then id = '' diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua index 39f97b1..4fabefa 100644 --- a/nerv/layer/init.lua +++ b/nerv/layer/init.lua @@ -75,7 +75,7 @@ function Layer:set_attr(name, value) end function Layer:get_sublayer(id) - nerv.error('primitive layer does not have sublayers.') + nerv.error('primitive layer does not have sublayers') end function Layer:find_param(pid_list, lconf, gconf, p_type, p_dim) diff --git a/nerv/layer/rnn.lua b/nerv/layer/rnn.lua index 806ac58..38f2326 100644 --- a/nerv/layer/rnn.lua +++ b/nerv/layer/rnn.lua @@ -27,8 +27,6 @@ function RNNLayer:__init(id, global_conf, layer_conf) } } - local layer_repo = nerv.LayerRepo(layers, pr, global_conf) - local connections = { {'[1]', 'main[1]', 0}, {'main[1]', 'sigmoid[1]', 0}, @@ -37,5 +35,7 @@ function RNNLayer:__init(id, global_conf, layer_conf) {'dup[2]', '[1]', 0}, } + self:add_prefix(layers, connections) + local layer_repo = nerv.LayerRepo(layers, pr, global_conf) self:graph_init(layer_repo, connections) end diff --git a/nerv/main.lua b/nerv/main.lua index 865aba0..7c82ebf 100644 --- a/nerv/main.lua +++ b/nerv/main.lua @@ -10,7 +10,8 @@ local global_conf = { local layer_repo = nerv.LayerRepo( { ['nerv.RNNLayer'] = { - rnn = {dim_in = {23}, dim_out = {26}}, + rnn1 = {dim_in = {23}, dim_out = {26}}, + rnn2 = {dim_in = {26}, dim_out = {26}}, }, ['nerv.AffineLayer'] = { input = {dim_in = {62}, dim_out = {23}}, @@ -30,8 +31,9 @@ local layer_repo = nerv.LayerRepo( local connections = { {'[1]', 'input[1]', 0}, {'input[1]', 'sigmoid[1]', 0}, - {'sigmoid[1]', 'rnn[1]', 0}, - {'rnn[1]', 'output[1]', 0}, + {'sigmoid[1]', 'rnn1[1]', 0}, + {'rnn1[1]', 'rnn2[1]', 0}, + {'rnn2[1]', 'output[1]', 0}, {'output[1]', 'dup[1]', 0}, {'dup[1]', 'output[2]', -1}, {'dup[2]', 'softmax[1]', 0}, @@ -65,3 +67,7 @@ for i = 1, 100 do network:back_propagate(err_input, err_output, input, output) network:update(err_input, input, output) end + +local tmp = network:get_params() + +tmp:export('../../workspace/test.param') diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua index 0bbcc59..39df5f0 100644 --- a/nerv/nn/network.lua +++ b/nerv/nn/network.lua @@ -18,7 +18,8 @@ function network:__init(id, global_conf, network_conf) self.layers = {} self.input_conn = {} self.output_conn = {} - self.socket = self:compile(network_conf.network) + self.network = network_conf.network + self.socket = self:compile(self.network) for i = 1, #self.dim_in do local edge = self.socket.inputs[i] local id, port, time = edge[1], edge[2], edge[3] @@ -472,3 +473,15 @@ function network:update(bp_err, input, output) end end end + +function network:set_attr(name, value) + self.network:set_attr(name, value) +end + +function network:get_sublayer(id) + return self.network:get_sublayer(id) +end + +function network:get_params() + return self.network:get_params() +end -- cgit v1.2.3