aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2016-03-26 15:23:58 +0800
committerDeterminant <[email protected]>2016-03-26 15:23:58 +0800
commit86dbfcfd490ce3f8fd4591b0950fbea7f1826c70 (patch)
treeb14298e8a020ab110af8cf667e1bb7c01bea693c
parent38a2afc7d9c50859e99e09f4f64af3a4254f6f37 (diff)
fix "not implemented" and lstm rebinding bugsalpha-3.1
-rw-r--r--nerv/examples/asr_trainer.lua3
-rw-r--r--nerv/layer/duplicate.lua4
-rw-r--r--nerv/layer/graph.lua7
-rw-r--r--nerv/layer/identity.lua4
-rw-r--r--nerv/layer/init.lua2
-rw-r--r--nerv/layer/lstm.lua1
-rw-r--r--nerv/layer/rnn.lua1
-rw-r--r--nerv/nn/layer_repo.lua17
8 files changed, 34 insertions, 5 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index aa1019d..38ba6e9 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -196,6 +196,7 @@ local trainer_defaults = {
max_iter = 20,
min_halving = 5,
do_halving = false,
+ keep_halving = false,
cumat_tname = "nerv.CuMatrixFloat",
mmat_tname = "nerv.MMatrixFloat",
debug = false,
@@ -294,6 +295,8 @@ for i = gconf.cur_iter, gconf.max_iter do
if gconf.accu_best - accu_prev < gconf.start_halving_inc and
i >= gconf.min_halving then
gconf.do_halving = true
+ elseif not gconf.keep_halving then
+ gconf.do_halving = false
end
if gconf.do_halving then
gconf.lrate = gconf.lrate * gconf.halving_factor
diff --git a/nerv/layer/duplicate.lua b/nerv/layer/duplicate.lua
index 2621cdf..3f38579 100644
--- a/nerv/layer/duplicate.lua
+++ b/nerv/layer/duplicate.lua
@@ -13,6 +13,10 @@ function DuplicateLayer:__init(id, global_conf, layer_conf)
end
end
+function DuplicateLayer:bind_params()
+ -- do nothing
+end
+
function DuplicateLayer:init()
end
diff --git a/nerv/layer/graph.lua b/nerv/layer/graph.lua
index ddbc85e..5790f95 100644
--- a/nerv/layer/graph.lua
+++ b/nerv/layer/graph.lua
@@ -2,7 +2,8 @@ local GraphLayer = nerv.class('nerv.GraphLayer', 'nerv.Layer')
function GraphLayer:__init(id, global_conf, layer_conf)
nerv.Layer.__init(self, id, global_conf, layer_conf)
- self:graph_init(layer_conf.layer_repo, layer_conf.connections)
+ self.lrepo = layer_conf.layer_repo
+ self:graph_init(self.lrepo, layer_conf.connections)
end
local function parse_id(str)
@@ -164,3 +165,7 @@ function GraphLayer:get_params()
end
return nerv.ParamRepo.merge(param_repos, self.loc_type)
end
+
+function GraphLayer:bind_params()
+ self.lrepo:rebind(self.lconf.pr)
+end
diff --git a/nerv/layer/identity.lua b/nerv/layer/identity.lua
index d56337d..a7ba8b2 100644
--- a/nerv/layer/identity.lua
+++ b/nerv/layer/identity.lua
@@ -28,3 +28,7 @@ end
function IdentityLayer:get_params()
return nerv.ParamRepo({}, self.loc_type)
end
+
+function IdentityLayer:bind_params()
+ -- do nothing
+end
diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua
index 475ef62..d266773 100644
--- a/nerv/layer/init.lua
+++ b/nerv/layer/init.lua
@@ -113,7 +113,7 @@ function Layer:find_param(plist, lconf, gconf, p_type, p_dim)
end
pid = self.id .. '_' .. plist[1]
if lconf.pr:has_param(pid) then
- nerv.info("param id for [%s] of layer [%s] is generated automatically.", pname, self.id)
+ nerv.info("param id for [%s] of layer [%s] is generated automatically.", plist[1], self.id)
return lconf.pr:get_param(pid)
end
nerv.info("param id for [%s] of layer [%s] is not found in the specified param repo, " ..
diff --git a/nerv/layer/lstm.lua b/nerv/layer/lstm.lua
index 56f674a..3de3453 100644
--- a/nerv/layer/lstm.lua
+++ b/nerv/layer/lstm.lua
@@ -81,5 +81,6 @@ function LSTMLayer:__init(id, global_conf, layer_conf)
self:add_prefix(layers, connections)
local layer_repo = nerv.LayerRepo(layers, pr, global_conf)
+ self.lrepo = layer_repo
self:graph_init(layer_repo, connections)
end
diff --git a/nerv/layer/rnn.lua b/nerv/layer/rnn.lua
index aad2b94..fd6e753 100644
--- a/nerv/layer/rnn.lua
+++ b/nerv/layer/rnn.lua
@@ -44,5 +44,6 @@ function RNNLayer:__init(id, global_conf, layer_conf)
self:add_prefix(layers, connections)
local layer_repo = nerv.LayerRepo(layers, pr, global_conf)
+ self.lrepo = layer_repo
self:graph_init(layer_repo, connections)
end
diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua
index acef54a..647aac9 100644
--- a/nerv/nn/layer_repo.lua
+++ b/nerv/nn/layer_repo.lua
@@ -29,10 +29,21 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf)
end
function LayerRepo:rebind(param_repo)
- for id, layer in pairs(self.layers) do
- layer.lconf.pr = param_repo
- layer:bind_params()
+ if self.__rebinding then
+ return
end
+ self.__rebinding = true
+ for _, layer in pairs(self.layers) do
+ if not layer.__already_rebound then
+ layer.__already_rebound = true
+ layer.lconf.pr = param_repo
+ layer:bind_params()
+ end
+ end
+ for _, layer in pairs(self.layers) do
+ layer.__already_rebound = false
+ end
+ self.__rebinding = false
end
function LayerRepo:get_layer(lid)