diff options
Diffstat (limited to 'nn')
-rw-r--r-- | nn/layer_dag.lua | 23 | ||||
-rw-r--r-- | nn/layer_repo.lua | 4 | ||||
-rw-r--r-- | nn/param_repo.lua | 70 |
3 files changed, 73 insertions, 24 deletions
diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua index 2dda7c9..8e30216 100644 --- a/nn/layer_dag.lua +++ b/nn/layer_dag.lua @@ -85,13 +85,14 @@ function DAGLayer:__init(id, global_conf, layer_conf) end end + -- topology sort local queue = {} local l = 1 local r = 1 for id, ref in pairs(layers) do if ref.in_deg == 0 then table.insert(queue, ref) - nerv.utils.printf("adding source layer: %s\n", id) + nerv.info("adding source layer: %s", id) r = r + 1 end end @@ -111,13 +112,13 @@ function DAGLayer:__init(id, global_conf, layer_conf) end end for i = 1, #queue do - nerv.utils.printf("queued layer: %s\n", queue[i].layer.id) + nerv.info("enqueued layer: %s", queue[i].layer.id) end for id, ref in pairs(layers) do -- check wether the graph is connected if ref.visited == false then - nerv.utils.printf("warning: layer %s is ignored\n", id) + nerv.warning("layer %s is ignored", id) end end @@ -131,7 +132,7 @@ function DAGLayer:__init(id, global_conf, layer_conf) self.gconf = global_conf end -function DAGLayer:init(batch_size) -- topology sort +function DAGLayer:init(batch_size) for i, conn in ipairs(self.parsed_conn) do local _, output_dim local ref_from, port_from, ref_to, port_to @@ -160,7 +161,7 @@ function DAGLayer:init(batch_size) -- topology sort end end -- initialize sub layers - ref.layer:init() + ref.layer:init(batch_size) end for i = 1, #self.dim_in do if self.inputs[i] == nil then @@ -227,7 +228,7 @@ function DAGLayer:propagate(input, output) end end -function DAGLayer:back_propagate(next_bp_err, bp_err, input, output) +function DAGLayer:back_propagate(bp_err, next_bp_err, input, output) self:set_err_outputs(next_bp_err) self:set_err_inputs(bp_err) self:set_inputs(input) @@ -235,16 +236,14 @@ function DAGLayer:back_propagate(next_bp_err, bp_err, input, output) for i = #self.queue, 1, -1 do local ref = self.queue[i] -- print(ref.layer.id) - ref.layer:back_propagate(ref.err_outputs, ref.err_inputs, ref.inputs, ref.outputs) + ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs) end end function DAGLayer:get_params() - local res = {} + local param_repos = {} for id, ref in pairs(self.queue) do - for i, p in ipairs(ref.layer:get_params()) do - table.insert(res, p) - end + table.insert(param_repos, ref.layer:get_params()) end - return res + return nerv.ParamRepo.merge(param_repos) end diff --git a/nn/layer_repo.lua b/nn/layer_repo.lua index b1d2248..602c37c 100644 --- a/nn/layer_repo.lua +++ b/nn/layer_repo.lua @@ -8,7 +8,7 @@ function LayerRepo:__init(layer_spec, param_repo, global_conf) if layers[id] ~= nil then nerv.error("a layer with id %s already exists", id) end - nerv.utils.printf("id: %s\n", id) + nerv.info("create layer: %s", id) if type(spec[2]) ~= "table" then nerv.error("layer config table is need") end @@ -17,7 +17,7 @@ function LayerRepo:__init(layer_spec, param_repo, global_conf) nerv.error("parameter description table is needed") end for pname, pid in pairs(spec[1]) do - layer_config[pname] = param_repo:get_param(pid, global_conf) + layer_config[pname] = param_repo:get_param(pid) end layers[id] = layer_type(id, global_conf, layer_config) end diff --git a/nn/param_repo.lua b/nn/param_repo.lua index 3e37c31..ab971ba 100644 --- a/nn/param_repo.lua +++ b/nn/param_repo.lua @@ -1,26 +1,76 @@ local ParamRepo = nerv.class("nerv.ParamRepo") +function ParamRepo:__init(plist) + self.params = {} + if plist ~= nil then + for i, p in ipairs(plist) do + self.params[p.id] = p + end + end +end + +function ParamRepo:add(pid, p) + if self.params[pid] ~= nil then + nerv.error("duplicate params with the same id: %s", pid) + end + self.params[pid] = p +end -function ParamRepo:__init(param_files) - local param_table = {} +function ParamRepo:remove(pid, p) + if self.params[pid] == nil then + nerv.error("param %s does not exit", pid) + end + table.remove(self.params, pid) +end + +function ParamRepo.merge(repos) + local self = nerv.ParamRepo() + for i, repo in ipairs(repos) do + if not nerv.is_type(repo, "nerv.ParamRepo") then + nerv.error("nerv.ParamRepo objects expected, got %s", repo) + end + for pid, p in pairs(repo.params) do + self:add(pid, p) + end + end + return self +end + +function ParamRepo:import(param_files, pids, gconf) if type(param_files) ~= "table" then nerv.error("param file table is need") end for i = 1, #param_files do local pf = nerv.ChunkFile(param_files[i], "r") for cid, cspec in pairs(pf.metadata) do - if param_table[cid] ~= nil then - nerv.error("conflicting chunk id in param files") + if pids == nil or pids[cid] ~= nil then + local p = pf:read_chunk(cid, gconf) + if not nerv.is_type(p, "nerv.Param") then + nerv.error("param chunk is expected") + end + self:add(cid, p) end - param_table[cid] = pf end end - self.param_table = param_table end -function ParamRepo:get_param(pid, global_conf) - local pf = self.param_table[pid] - if pf == nil then +function ParamRepo:export(param_file, pids) + cf = nerv.ChunkFile(param_file, "w") + if pids == nil then + for id, p in pairs(self.params) do + cf:write_chunk(p) + end + else + for i, pid in ipairs(pids) do + cf:write_chunk(self:get_param(pid)) + end + end + cf:close() +end + +function ParamRepo:get_param(pid) + local p = self.params[pid] + if p == nil then nerv.error("param with id %s not found", pid) end - return pf:read_chunk(pid, global_conf) + return p end |