aboutsummaryrefslogtreecommitdiff
path: root/nn
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2015-06-20 20:00:25 +0800
committerDeterminant <ted.sybil@gmail.com>2015-06-20 20:00:25 +0800
commitf3f4e74eb4dbb8829e5ee136ba4b0c0a7938b551 (patch)
tree8beb12182020267ce32904d646ad0c736c27dcd2 /nn
parent2ab9610a4fff798c1668cdc041515256fa813865 (diff)
change concept of ParamRepo; provide generalized param update; code clean-up; #25 #26 #27 #29
Diffstat (limited to 'nn')
-rw-r--r--nn/layer_dag.lua23
-rw-r--r--nn/layer_repo.lua4
-rw-r--r--nn/param_repo.lua70
3 files changed, 73 insertions, 24 deletions
diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua
index 2dda7c9..8e30216 100644
--- a/nn/layer_dag.lua
+++ b/nn/layer_dag.lua
@@ -85,13 +85,14 @@ function DAGLayer:__init(id, global_conf, layer_conf)
end
end
+ -- topology sort
local queue = {}
local l = 1
local r = 1
for id, ref in pairs(layers) do
if ref.in_deg == 0 then
table.insert(queue, ref)
- nerv.utils.printf("adding source layer: %s\n", id)
+ nerv.info("adding source layer: %s", id)
r = r + 1
end
end
@@ -111,13 +112,13 @@ function DAGLayer:__init(id, global_conf, layer_conf)
end
end
for i = 1, #queue do
- nerv.utils.printf("queued layer: %s\n", queue[i].layer.id)
+ nerv.info("enqueued layer: %s", queue[i].layer.id)
end
for id, ref in pairs(layers) do
-- check wether the graph is connected
if ref.visited == false then
- nerv.utils.printf("warning: layer %s is ignored\n", id)
+ nerv.warning("layer %s is ignored", id)
end
end
@@ -131,7 +132,7 @@ function DAGLayer:__init(id, global_conf, layer_conf)
self.gconf = global_conf
end
-function DAGLayer:init(batch_size) -- topology sort
+function DAGLayer:init(batch_size)
for i, conn in ipairs(self.parsed_conn) do
local _, output_dim
local ref_from, port_from, ref_to, port_to
@@ -160,7 +161,7 @@ function DAGLayer:init(batch_size) -- topology sort
end
end
-- initialize sub layers
- ref.layer:init()
+ ref.layer:init(batch_size)
end
for i = 1, #self.dim_in do
if self.inputs[i] == nil then
@@ -227,7 +228,7 @@ function DAGLayer:propagate(input, output)
end
end
-function DAGLayer:back_propagate(next_bp_err, bp_err, input, output)
+function DAGLayer:back_propagate(bp_err, next_bp_err, input, output)
self:set_err_outputs(next_bp_err)
self:set_err_inputs(bp_err)
self:set_inputs(input)
@@ -235,16 +236,14 @@ function DAGLayer:back_propagate(next_bp_err, bp_err, input, output)
for i = #self.queue, 1, -1 do
local ref = self.queue[i]
-- print(ref.layer.id)
- ref.layer:back_propagate(ref.err_outputs, ref.err_inputs, ref.inputs, ref.outputs)
+ ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs)
end
end
function DAGLayer:get_params()
- local res = {}
+ local param_repos = {}
for id, ref in pairs(self.queue) do
- for i, p in ipairs(ref.layer:get_params()) do
- table.insert(res, p)
- end
+ table.insert(param_repos, ref.layer:get_params())
end
- return res
+ return nerv.ParamRepo.merge(param_repos)
end
diff --git a/nn/layer_repo.lua b/nn/layer_repo.lua
index b1d2248..602c37c 100644
--- a/nn/layer_repo.lua
+++ b/nn/layer_repo.lua
@@ -8,7 +8,7 @@ function LayerRepo:__init(layer_spec, param_repo, global_conf)
if layers[id] ~= nil then
nerv.error("a layer with id %s already exists", id)
end
- nerv.utils.printf("id: %s\n", id)
+ nerv.info("create layer: %s", id)
if type(spec[2]) ~= "table" then
nerv.error("layer config table is need")
end
@@ -17,7 +17,7 @@ function LayerRepo:__init(layer_spec, param_repo, global_conf)
nerv.error("parameter description table is needed")
end
for pname, pid in pairs(spec[1]) do
- layer_config[pname] = param_repo:get_param(pid, global_conf)
+ layer_config[pname] = param_repo:get_param(pid)
end
layers[id] = layer_type(id, global_conf, layer_config)
end
diff --git a/nn/param_repo.lua b/nn/param_repo.lua
index 3e37c31..ab971ba 100644
--- a/nn/param_repo.lua
+++ b/nn/param_repo.lua
@@ -1,26 +1,76 @@
local ParamRepo = nerv.class("nerv.ParamRepo")
+function ParamRepo:__init(plist)
+ self.params = {}
+ if plist ~= nil then
+ for i, p in ipairs(plist) do
+ self.params[p.id] = p
+ end
+ end
+end
+
+function ParamRepo:add(pid, p)
+ if self.params[pid] ~= nil then
+ nerv.error("duplicate params with the same id: %s", pid)
+ end
+ self.params[pid] = p
+end
-function ParamRepo:__init(param_files)
- local param_table = {}
+function ParamRepo:remove(pid, p)
+ if self.params[pid] == nil then
+ nerv.error("param %s does not exit", pid)
+ end
+ table.remove(self.params, pid)
+end
+
+function ParamRepo.merge(repos)
+ local self = nerv.ParamRepo()
+ for i, repo in ipairs(repos) do
+ if not nerv.is_type(repo, "nerv.ParamRepo") then
+ nerv.error("nerv.ParamRepo objects expected, got %s", repo)
+ end
+ for pid, p in pairs(repo.params) do
+ self:add(pid, p)
+ end
+ end
+ return self
+end
+
+function ParamRepo:import(param_files, pids, gconf)
if type(param_files) ~= "table" then
nerv.error("param file table is need")
end
for i = 1, #param_files do
local pf = nerv.ChunkFile(param_files[i], "r")
for cid, cspec in pairs(pf.metadata) do
- if param_table[cid] ~= nil then
- nerv.error("conflicting chunk id in param files")
+ if pids == nil or pids[cid] ~= nil then
+ local p = pf:read_chunk(cid, gconf)
+ if not nerv.is_type(p, "nerv.Param") then
+ nerv.error("param chunk is expected")
+ end
+ self:add(cid, p)
end
- param_table[cid] = pf
end
end
- self.param_table = param_table
end
-function ParamRepo:get_param(pid, global_conf)
- local pf = self.param_table[pid]
- if pf == nil then
+function ParamRepo:export(param_file, pids)
+ cf = nerv.ChunkFile(param_file, "w")
+ if pids == nil then
+ for id, p in pairs(self.params) do
+ cf:write_chunk(p)
+ end
+ else
+ for i, pid in ipairs(pids) do
+ cf:write_chunk(self:get_param(pid))
+ end
+ end
+ cf:close()
+end
+
+function ParamRepo:get_param(pid)
+ local p = self.params[pid]
+ if p == nil then
nerv.error("param with id %s not found", pid)
end
- return pf:read_chunk(pid, global_conf)
+ return p
end