diff options
Diffstat (limited to 'nerv/layer')
-rw-r--r-- | nerv/layer/duplicate.lua | 4 | ||||
-rw-r--r-- | nerv/layer/graph.lua | 24 | ||||
-rw-r--r-- | nerv/layer/init.lua | 2 | ||||
-rw-r--r-- | nerv/layer/rnn.lua | 4 |
4 files changed, 31 insertions, 3 deletions
diff --git a/nerv/layer/duplicate.lua b/nerv/layer/duplicate.lua index 1a93b26..8988617 100644 --- a/nerv/layer/duplicate.lua +++ b/nerv/layer/duplicate.lua @@ -38,3 +38,7 @@ end function DuplicateLayer:update() end + +function DuplicateLayer:get_params() + return nerv.ParamRepo({}) +end diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua index d72d849..1406eff 100644 --- a/nerv/layer/graph.lua +++ b/nerv/layer/graph.lua @@ -21,6 +21,30 @@ local function parse_id(str) return id, port end +function GraphLayer:add_prefix(layers, connections) + local function ap(name) + return self.id .. '.' .. name + end + + for layer_type, sublayers in pairs(layers) do + local tmp = {} + for name, layer_config in pairs(sublayers) do + tmp[ap(name)] = layer_config + end + layers[layer_type] = tmp + end + + for i = 1, #connections do + local from, to = connections[i][1], connections[i][2] + if parse_id(from) ~= '<input>' then + connections[i][1] = ap(from) + end + if parse_id(to) ~= '<output>' then + connections[i][2] = ap(to) + end + end +end + function GraphLayer:discover(id, layer_repo) if id == '<output>' then id = '<input>' diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua index 39f97b1..4fabefa 100644 --- a/nerv/layer/init.lua +++ b/nerv/layer/init.lua @@ -75,7 +75,7 @@ function Layer:set_attr(name, value) end function Layer:get_sublayer(id) - nerv.error('primitive layer does not have sublayers.') + nerv.error('primitive layer does not have sublayers') end function Layer:find_param(pid_list, lconf, gconf, p_type, p_dim) diff --git a/nerv/layer/rnn.lua b/nerv/layer/rnn.lua index 806ac58..38f2326 100644 --- a/nerv/layer/rnn.lua +++ b/nerv/layer/rnn.lua @@ -27,8 +27,6 @@ function RNNLayer:__init(id, global_conf, layer_conf) } } - local layer_repo = nerv.LayerRepo(layers, pr, global_conf) - local connections = { {'<input>[1]', 'main[1]', 0}, {'main[1]', 'sigmoid[1]', 0}, @@ -37,5 +35,7 @@ function RNNLayer:__init(id, global_conf, layer_conf) {'dup[2]', '<output>[1]', 0}, } + self:add_prefix(layers, connections) + local layer_repo = nerv.LayerRepo(layers, pr, global_conf) self:graph_init(layer_repo, connections) end |