aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2015-11-27 19:58:16 +0800
committertxh18 <[email protected]>2015-11-27 19:58:16 +0800
commitf0ac603cbfc5bbb95dad885d35822f0f747b0ab2 (patch)
treed87781b6bd46685a31350cff4939e75aa706b528
parentc1d48a64432245fa19527816969e43a368728013 (diff)
lstm_tnn can be run, todo:testing
-rw-r--r--nerv/examples/lmptb/lstmlm_ptb_main.lua17
-rw-r--r--nerv/examples/lmptb/tnn/layer_dag_t.lua7
-rw-r--r--nerv/examples/lmptb/tnn/layersT/lstm_t.lua12
3 files changed, 22 insertions, 14 deletions
diff --git a/nerv/examples/lmptb/lstmlm_ptb_main.lua b/nerv/examples/lmptb/lstmlm_ptb_main.lua
index d3f38a2..42b541f 100644
--- a/nerv/examples/lmptb/lstmlm_ptb_main.lua
+++ b/nerv/examples/lmptb/lstmlm_ptb_main.lua
@@ -129,10 +129,17 @@ function prepare_tnn(global_conf, layerRepo)
--input: input_w, input_w, ... input_w_now, last_activation
local connections_t = {
{"<input>[1]", "selectL1[1]", 0},
- {"selectL1[1]", "recurrentL1[1]", 0},
- {"recurrentL1[1]", "sigmoidL1[1]", 0},
- {"sigmoidL1[1]", "combinerL1[1]", 0},
- {"combinerL1[1]", "recurrentL1[2]", 1},
+
+ --{"selectL1[1]", "recurrentL1[1]", 0},
+ --{"recurrentL1[1]", "sigmoidL1[1]", 0},
+ --{"sigmoidL1[1]", "combinerL1[1]", 0},
+ --{"combinerL1[1]", "recurrentL1[2]", 1},
+
+ {"selectL1[1]", "lstmL1[1]", 0},
+ {"lstmL1[2]", "lstmL1[3]", 1},
+ {"lstmL1[1]", "combinerL1[1]", 0},
+ {"combinerL1[1]", "lstmL1[2]", 1},
+
{"combinerL1[2]", "outputL[1]", 0},
{"outputL[1]", "softmaxL[1]", 0},
{"<input>[2]", "softmaxL[2]", 0},
@@ -268,7 +275,7 @@ if (arg[2] ~= nil) then
loadstring(arg[2])()
nerv.LMUtil.wait(0.5)
else
- printf("%s not user setting, all default...\n", global_conf.sche_log_pre)
+ printf("%s no user setting, all default...\n", global_conf.sche_log_pre)
end
global_conf.work_dir = global_conf.work_dir_base .. 'h' .. global_conf.hidden_size .. 'ch' .. global_conf.chunk_size .. 'ba' .. global_conf.batch_size .. 'slr' .. global_conf.lrate .. 'wc' .. global_conf.wcost
diff --git a/nerv/examples/lmptb/tnn/layer_dag_t.lua b/nerv/examples/lmptb/tnn/layer_dag_t.lua
index ade65cc..e3a9316 100644
--- a/nerv/examples/lmptb/tnn/layer_dag_t.lua
+++ b/nerv/examples/lmptb/tnn/layer_dag_t.lua
@@ -2,7 +2,7 @@ local DAGLayerT = nerv.class("nerv.DAGLayerT", "nerv.LayerT")
local function parse_id(str)
local id, port, _
- _, _, id, port = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%]")
+ _, _, id, port = string.find(str, "([a-zA-Z0-9_.]+)%[([0-9]+)%]")
if id == nil or port == nil then
_, _, id, port = string.find(str, "(.+)%[([0-9]+)%]")
if not (id == "<input>" or id == "<output>") then
@@ -142,7 +142,7 @@ function DAGLayerT:__init(id, global_conf, layer_conf)
end
function DAGLayerT:init(batch_size, chunk_size)
- nerv.info("initing DAGLayerT...\n")
+ nerv.info("initing DAGLayerT %s...\n", self.id)
if chunk_size == nil then
chunk_size = 1
nerv.info("(Initing DAGLayerT) chunk_size is nil, setting it to default 1\n")
@@ -321,7 +321,7 @@ function DAGLayerT:update(bp_err, input, output, t)
end
end
-function DAGLayerT:propagate(input, output)
+function DAGLayerT:propagate(input, output, t)
if t == nil then
t = 1
end
@@ -330,6 +330,7 @@ function DAGLayerT:propagate(input, output)
local ret = false
for i = 1, #self.queue do
local ref = self.queue[i]
+ --print("debug DAGLAyerT:propagate", ref.id, t)
ret = ref.layer:propagate(ref.inputs[t], ref.outputs[t], t)
end
return ret
diff --git a/nerv/examples/lmptb/tnn/layersT/lstm_t.lua b/nerv/examples/lmptb/tnn/layersT/lstm_t.lua
index d7d8a20..409c617 100644
--- a/nerv/examples/lmptb/tnn/layersT/lstm_t.lua
+++ b/nerv/examples/lmptb/tnn/layersT/lstm_t.lua
@@ -68,14 +68,14 @@ function LSTMLayerT:__init(id, global_conf, layer_conf)
[ap("inputHDup[3]")] = ap("forgetGateL[2]"),
[ap("inputCDup[3]")] = ap("forgetGateL[3]"),
- [ap("mainTanhL[1]")] = ap("inputGMul[1]"),
- [ap("inputGateL[1]")] = ap("inputGMul[2]"),
+ [ap("mainTanhL[1]")] = ap("inputGMulL[1]"),
+ [ap("inputGateL[1]")] = ap("inputGMulL[2]"),
- [ap("inputCDup[4]")] = ap("forgetGMul[1]"),
- [ap("forgetGateL[1]")] = ap("forgetGMul[2]"),
+ [ap("inputCDup[4]")] = ap("forgetGMulL[1]"),
+ [ap("forgetGateL[1]")] = ap("forgetGMulL[2]"),
- [ap("inputGMul[1]")] = ap("mainCDup[1]"),
- [ap("forgetGMul[1]")] = ap("mainCDup[2]"),
+ [ap("inputGMulL[1]")] = ap("mainCDup[1]"),
+ [ap("forgetGMulL[1]")] = ap("mainCDup[2]"),
[ap("mainCDup[2]")] = "<output>[2]",
[ap("mainCDup[1]")] = ap("outputTanhL[1]"),