diff options
-rw-r--r-- | nerv/examples/lmptb/lstmlm_ptb_main.lua | 17 | ||||
-rw-r--r-- | nerv/examples/lmptb/tnn/layer_dag_t.lua | 7 | ||||
-rw-r--r-- | nerv/examples/lmptb/tnn/layersT/lstm_t.lua | 12 |
3 files changed, 22 insertions, 14 deletions
diff --git a/nerv/examples/lmptb/lstmlm_ptb_main.lua b/nerv/examples/lmptb/lstmlm_ptb_main.lua index d3f38a2..42b541f 100644 --- a/nerv/examples/lmptb/lstmlm_ptb_main.lua +++ b/nerv/examples/lmptb/lstmlm_ptb_main.lua @@ -129,10 +129,17 @@ function prepare_tnn(global_conf, layerRepo) --input: input_w, input_w, ... input_w_now, last_activation local connections_t = { {"<input>[1]", "selectL1[1]", 0}, - {"selectL1[1]", "recurrentL1[1]", 0}, - {"recurrentL1[1]", "sigmoidL1[1]", 0}, - {"sigmoidL1[1]", "combinerL1[1]", 0}, - {"combinerL1[1]", "recurrentL1[2]", 1}, + + --{"selectL1[1]", "recurrentL1[1]", 0}, + --{"recurrentL1[1]", "sigmoidL1[1]", 0}, + --{"sigmoidL1[1]", "combinerL1[1]", 0}, + --{"combinerL1[1]", "recurrentL1[2]", 1}, + + {"selectL1[1]", "lstmL1[1]", 0}, + {"lstmL1[2]", "lstmL1[3]", 1}, + {"lstmL1[1]", "combinerL1[1]", 0}, + {"combinerL1[1]", "lstmL1[2]", 1}, + {"combinerL1[2]", "outputL[1]", 0}, {"outputL[1]", "softmaxL[1]", 0}, {"<input>[2]", "softmaxL[2]", 0}, @@ -268,7 +275,7 @@ if (arg[2] ~= nil) then loadstring(arg[2])() nerv.LMUtil.wait(0.5) else - printf("%s not user setting, all default...\n", global_conf.sche_log_pre) + printf("%s no user setting, all default...\n", global_conf.sche_log_pre) end global_conf.work_dir = global_conf.work_dir_base .. 'h' .. global_conf.hidden_size .. 'ch' .. global_conf.chunk_size .. 'ba' .. global_conf.batch_size .. 'slr' .. global_conf.lrate .. 'wc' .. global_conf.wcost diff --git a/nerv/examples/lmptb/tnn/layer_dag_t.lua b/nerv/examples/lmptb/tnn/layer_dag_t.lua index ade65cc..e3a9316 100644 --- a/nerv/examples/lmptb/tnn/layer_dag_t.lua +++ b/nerv/examples/lmptb/tnn/layer_dag_t.lua @@ -2,7 +2,7 @@ local DAGLayerT = nerv.class("nerv.DAGLayerT", "nerv.LayerT") local function parse_id(str) local id, port, _ - _, _, id, port = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%]") + _, _, id, port = string.find(str, "([a-zA-Z0-9_.]+)%[([0-9]+)%]") if id == nil or port == nil then _, _, id, port = string.find(str, "(.+)%[([0-9]+)%]") if not (id == "<input>" or id == "<output>") then @@ -142,7 +142,7 @@ function DAGLayerT:__init(id, global_conf, layer_conf) end function DAGLayerT:init(batch_size, chunk_size) - nerv.info("initing DAGLayerT...\n") + nerv.info("initing DAGLayerT %s...\n", self.id) if chunk_size == nil then chunk_size = 1 nerv.info("(Initing DAGLayerT) chunk_size is nil, setting it to default 1\n") @@ -321,7 +321,7 @@ function DAGLayerT:update(bp_err, input, output, t) end end -function DAGLayerT:propagate(input, output) +function DAGLayerT:propagate(input, output, t) if t == nil then t = 1 end @@ -330,6 +330,7 @@ function DAGLayerT:propagate(input, output) local ret = false for i = 1, #self.queue do local ref = self.queue[i] + --print("debug DAGLAyerT:propagate", ref.id, t) ret = ref.layer:propagate(ref.inputs[t], ref.outputs[t], t) end return ret diff --git a/nerv/examples/lmptb/tnn/layersT/lstm_t.lua b/nerv/examples/lmptb/tnn/layersT/lstm_t.lua index d7d8a20..409c617 100644 --- a/nerv/examples/lmptb/tnn/layersT/lstm_t.lua +++ b/nerv/examples/lmptb/tnn/layersT/lstm_t.lua @@ -68,14 +68,14 @@ function LSTMLayerT:__init(id, global_conf, layer_conf) [ap("inputHDup[3]")] = ap("forgetGateL[2]"), [ap("inputCDup[3]")] = ap("forgetGateL[3]"), - [ap("mainTanhL[1]")] = ap("inputGMul[1]"), - [ap("inputGateL[1]")] = ap("inputGMul[2]"), + [ap("mainTanhL[1]")] = ap("inputGMulL[1]"), + [ap("inputGateL[1]")] = ap("inputGMulL[2]"), - [ap("inputCDup[4]")] = ap("forgetGMul[1]"), - [ap("forgetGateL[1]")] = ap("forgetGMul[2]"), + [ap("inputCDup[4]")] = ap("forgetGMulL[1]"), + [ap("forgetGateL[1]")] = ap("forgetGMulL[2]"), - [ap("inputGMul[1]")] = ap("mainCDup[1]"), - [ap("forgetGMul[1]")] = ap("mainCDup[2]"), + [ap("inputGMulL[1]")] = ap("mainCDup[1]"), + [ap("forgetGMulL[1]")] = ap("mainCDup[2]"), [ap("mainCDup[2]")] = "<output>[2]", [ap("mainCDup[1]")] = ap("outputTanhL[1]"), |