From 86dbfcfd490ce3f8fd4591b0950fbea7f1826c70 Mon Sep 17 00:00:00 2001 From: Determinant Date: Sat, 26 Mar 2016 15:23:58 +0800 Subject: fix "not implemented" and lstm rebinding bugs --- nerv/examples/asr_trainer.lua | 3 +++ nerv/layer/duplicate.lua | 4 ++++ nerv/layer/graph.lua | 7 ++++++- nerv/layer/identity.lua | 4 ++++ nerv/layer/init.lua | 2 +- nerv/layer/lstm.lua | 1 + nerv/layer/rnn.lua | 1 + nerv/nn/layer_repo.lua | 17 ++++++++++++++--- 8 files changed, 34 insertions(+), 5 deletions(-) diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua index aa1019d..38ba6e9 100644 --- a/nerv/examples/asr_trainer.lua +++ b/nerv/examples/asr_trainer.lua @@ -196,6 +196,7 @@ local trainer_defaults = { max_iter = 20, min_halving = 5, do_halving = false, + keep_halving = false, cumat_tname = "nerv.CuMatrixFloat", mmat_tname = "nerv.MMatrixFloat", debug = false, @@ -294,6 +295,8 @@ for i = gconf.cur_iter, gconf.max_iter do if gconf.accu_best - accu_prev < gconf.start_halving_inc and i >= gconf.min_halving then gconf.do_halving = true + elseif not gconf.keep_halving then + gconf.do_halving = false end if gconf.do_halving then gconf.lrate = gconf.lrate * gconf.halving_factor diff --git a/nerv/layer/duplicate.lua b/nerv/layer/duplicate.lua index 2621cdf..3f38579 100644 --- a/nerv/layer/duplicate.lua +++ b/nerv/layer/duplicate.lua @@ -13,6 +13,10 @@ function DuplicateLayer:__init(id, global_conf, layer_conf) end end +function DuplicateLayer:bind_params() + -- do nothing +end + function DuplicateLayer:init() end diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua index ddbc85e..5790f95 100644 --- a/nerv/layer/graph.lua +++ b/nerv/layer/graph.lua @@ -2,7 +2,8 @@ local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer') function GraphLayer:__init(id, global_conf, layer_conf) nerv.Layer.__init(self, id, global_conf, layer_conf) - self:graph_init(layer_conf.layer_repo, layer_conf.connections) + self.lrepo = layer_conf.layer_repo + self:graph_init(self.lrepo, layer_conf.connections) end local function parse_id(str) @@ -164,3 +165,7 @@ function GraphLayer:get_params() end return nerv.ParamRepo.merge(param_repos, self.loc_type) end + +function GraphLayer:bind_params() + self.lrepo:rebind(self.lconf.pr) +end diff --git a/nerv/layer/identity.lua b/nerv/layer/identity.lua index d56337d..a7ba8b2 100644 --- a/nerv/layer/identity.lua +++ b/nerv/layer/identity.lua @@ -28,3 +28,7 @@ end function IdentityLayer:get_params() return nerv.ParamRepo({}, self.loc_type) end + +function IdentityLayer:bind_params() + -- do nothing +end diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua index 475ef62..d266773 100644 --- a/nerv/layer/init.lua +++ b/nerv/layer/init.lua @@ -113,7 +113,7 @@ function Layer:find_param(plist, lconf, gconf, p_type, p_dim) end pid = self.id .. '_' .. plist[1] if lconf.pr:has_param(pid) then - nerv.info("param id for [%s] of layer [%s] is generated automatically.", pname, self.id) + nerv.info("param id for [%s] of layer [%s] is generated automatically.", plist[1], self.id) return lconf.pr:get_param(pid) end nerv.info("param id for [%s] of layer [%s] is not found in the specified param repo, " .. diff --git a/nerv/layer/lstm.lua b/nerv/layer/lstm.lua index 56f674a..3de3453 100644 --- a/nerv/layer/lstm.lua +++ b/nerv/layer/lstm.lua @@ -81,5 +81,6 @@ function LSTMLayer:__init(id, global_conf, layer_conf) self:add_prefix(layers, connections) local layer_repo = nerv.LayerRepo(layers, pr, global_conf) + self.lrepo = layer_repo self:graph_init(layer_repo, connections) end diff --git a/nerv/layer/rnn.lua b/nerv/layer/rnn.lua index aad2b94..fd6e753 100644 --- a/nerv/layer/rnn.lua +++ b/nerv/layer/rnn.lua @@ -44,5 +44,6 @@ function RNNLayer:__init(id, global_conf, layer_conf) self:add_prefix(layers, connections) local layer_repo = nerv.LayerRepo(layers, pr, global_conf) + self.lrepo = layer_repo self:graph_init(layer_repo, connections) end diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua index acef54a..647aac9 100644 --- a/nerv/nn/layer_repo.lua +++ b/nerv/nn/layer_repo.lua @@ -29,10 +29,21 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf) end function LayerRepo:rebind(param_repo) - for id, layer in pairs(self.layers) do - layer.lconf.pr = param_repo - layer:bind_params() + if self.__rebinding then + return end + self.__rebinding = true + for _, layer in pairs(self.layers) do + if not layer.__already_rebound then + layer.__already_rebound = true + layer.lconf.pr = param_repo + layer:bind_params() + end + end + for _, layer in pairs(self.layers) do + layer.__already_rebound = false + end + self.__rebinding = false end function LayerRepo:get_layer(lid) -- cgit v1.2.3