diff options
author | Determinant <ted.sybil@gmail.com> | 2016-03-10 13:40:11 +0800 |
---|---|---|
committer | Determinant <ted.sybil@gmail.com> | 2016-03-10 13:40:11 +0800 |
commit | a32195e3e2ae9ca0f0c7a82e73e6bddb64568c05 (patch) | |
tree | a19f21f8cbadecff7357f9a102f160f5fe699b65 /nerv/nn | |
parent | 4a6872601f05e9ecc059f83fb64a0a4887992b99 (diff) |
major change: clearer param binding semantics; permit rebinding; enable
resuming from previous training
Diffstat (limited to 'nerv/nn')
-rw-r--r-- | nerv/nn/layer_dag.lua | 16 | ||||
-rw-r--r-- | nerv/nn/layer_repo.lua | 30 | ||||
-rw-r--r-- | nerv/nn/param_repo.lua | 59 |
3 files changed, 77 insertions, 28 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
index 6896878..f999752 100644
--- a/nerv/nn/layer_dag.lua
+++ b/nerv/nn/layer_dag.lua
@@ -134,20 +134,16 @@ function DAGLayer:__init(id, global_conf, layer_conf)
         end
     end
 
+    nerv.Layer.__init(self, id, global_conf, layer_conf)
     self.layers = layers
     self.inputs = inputs
     self.outputs = outputs
-    self.id = id
-    self.dim_in = dim_in
-    self.dim_out = dim_out
     self.parsed_conn = parsed_conn
     self.queue = queue
-    self.gconf = global_conf
-    if self.gconf.use_cpu then
-        self.mat_type = self.gconf.mmat_type
-    else
-        self.mat_type = self.gconf.cumat_type
-    end
+end
+
+function DAGLayer:bind_params()
+    -- do nothing (instead of rebinding params for each layer)
 end
 
 function DAGLayer:init(batch_size, chunk_size)
@@ -325,7 +321,7 @@ function DAGLayer:get_params()
     for id, ref in pairs(self.queue) do
         table.insert(param_repos, ref.layer:get_params())
     end
-    return nerv.ParamRepo.merge(param_repos)
+    return nerv.ParamRepo.merge(param_repos, self.loc_type)
 end
 
 DAGLayer.PORT_TYPES = {
diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua
index 3d3a79f..acef54a 100644
--- a/nerv/nn/layer_repo.lua
+++ b/nerv/nn/layer_repo.lua
@@ -12,29 +12,29 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf)
         if layer_type == nil then
             nerv.error('layer type `%s` not found', ltype)
         end
-        for id, spec in pairs(llist) do
-            if layers[id] ~= nil then
-                nerv.error("a layer with id %s already exists", id)
-            end
-            nerv.info("create layer: %s", id)
-            if type(spec[2]) ~= "table" then
+        for id, lconf in pairs(llist) do
+            if type(lconf) ~= "table" then
                 nerv.error("layer config table is need")
             end
-            layer_config = spec[2]
-            if type(spec[1]) ~= "table" then
-                nerv.error("parameter description table is needed")
-            end
-            for pname, pid in pairs(spec[1]) do
-                layer_config[pname] = param_repo:get_param(pid)
+            if lconf.pr == nil then
+                lconf.pr = param_repo
             end
-            if layer_config.pr == nil then
-                layer_config.pr = param_repo
+            if layers[id] ~= nil then
+                nerv.error("a layer with id %s already exists", id)
             end
-            layers[id] = layer_type(id, global_conf, layer_config)
+            nerv.info("create layer: %s", id)
+            layers[id] = layer_type(id, global_conf, lconf)
         end
     end
 end
 
+function LayerRepo:rebind(param_repo)
+    for id, layer in pairs(self.layers) do
+        layer.lconf.pr = param_repo
+        layer:bind_params()
+    end
+end
+
 function LayerRepo:get_layer(lid)
     local layer = self.layers[lid]
     if layer == nil then
diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua
index c124e08..aba7765 100644
--- a/nerv/nn/param_repo.lua
+++ b/nerv/nn/param_repo.lua
@@ -1,8 +1,37 @@
 local ParamRepo = nerv.class("nerv.ParamRepo")
-function ParamRepo:__init(plist)
+
+ParamRepo.LOC_TYPES = {
+    ON_DEVICE = {},
+    ON_HOST = {}
+}
+
+function ParamRepo:__init(plist, loc_type)
     self.params = {}
+    self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST
+    local function make_checker(tname)
+        return function (mat)
+            if not nerv.is_type(mat, tname) then
+                nerv.error("unexpected param type in repo specification")
+            end
+        end
+    end
+    self.make_copier = function (mat_type, copy_method)
+        return function (mat)
+            local target = mat_type(mat:nrow(), mat:ncol())
+            mat[copy_method](mat, target)
+            return target
+        end
+    end
+
+    if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then
+        self.checker = make_checker("nerv.MMatrix")
+    else
+        self.checker = make_checker("nerv.CuMatrix")
+    end
+
     if plist ~= nil then
         for i, p in ipairs(plist) do
+            p:check(self.checker)
             self.params[p.id] = p
         end
     end
@@ -12,6 +41,7 @@ function ParamRepo:add(pid, p)
     if self.params[pid] ~= nil then
         nerv.error("duplicate params with the same id: %s", pid)
     end
+    p:check(self.checker)
     self.params[pid] = p
 end
 
@@ -22,8 +52,8 @@ function ParamRepo:remove(pid, p)
     table.remove(self.params, pid)
 end
 
-function ParamRepo.merge(repos)
-    local self = nerv.ParamRepo()
+function ParamRepo.merge(repos, loc_type)
+    local self = nerv.ParamRepo(nil, loc_type)
     for i, repo in ipairs(repos) do
         if not nerv.is_type(repo, "nerv.ParamRepo") then
             nerv.error("nerv.ParamRepo objects expected, got %s", repo)
@@ -78,3 +108,26 @@ function ParamRepo:get_param(pid)
     end
     return p
 end
+
+function ParamRepo:copy(loc_type, pids)
+    local copier
+    local target = nerv.ParamRepo(nil, loc_type)
+    if loc_type == nil then
+        loc_type = self.loc_type
+    end
+    if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
+        copier = self.make_copier(gconf.mmat_type, 'copy_toh')
+    else
+        copier = self.make_copier(gconf.cumat_type, 'copy_tod')
+    end
+    if pids == nil then
+        for id, p in pairs(self.params) do
+            target.params[id] = p:copy(copier)
+        end
+    else
+        for i, pid in ipairs(pids) do
+            target.params[pid] = self:get_param(pid):copy(copier)
+        end
+    end
+    return target
+end
|