aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <cloudygooseg@gmail.com>2015-11-30 19:50:04 +0800
committertxh18 <cloudygooseg@gmail.com>2015-11-30 19:50:04 +0800
commit64256d367cf575fb61d666bdf0b9285dfdb4db25 (patch)
treed389fe873857f9fcbc52440fd1e73f2c605852a9
parente6ea10bd32cef61565206358a104d1b17ba162f7 (diff)
bug fix for lstm_t layer, t not inclueded in propagate!
-rw-r--r--nerv/examples/lmptb/lstmlm_ptb_main.lua2
-rw-r--r--nerv/examples/lmptb/tnn/layersT/lstm_t.lua12
2 files changed, 7 insertions, 7 deletions
diff --git a/nerv/examples/lmptb/lstmlm_ptb_main.lua b/nerv/examples/lmptb/lstmlm_ptb_main.lua
index 7ec583d..ca9530b 100644
--- a/nerv/examples/lmptb/lstmlm_ptb_main.lua
+++ b/nerv/examples/lmptb/lstmlm_ptb_main.lua
@@ -184,7 +184,7 @@ test_fn = data_dir .. '/ptb.test.txt.adds'
vocab_fn = data_dir .. '/vocab'
global_conf = {
- lrate = 0.1, wcost = 1e-6, momentum = 0, clip_t = 10,
+ lrate = 0.1, wcost = 1e-5, momentum = 0, clip_t = 10,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.MMatrixFloat,
nn_act_default = 0,
diff --git a/nerv/examples/lmptb/tnn/layersT/lstm_t.lua b/nerv/examples/lmptb/tnn/layersT/lstm_t.lua
index 0bd9c76..ded6058 100644
--- a/nerv/examples/lmptb/tnn/layersT/lstm_t.lua
+++ b/nerv/examples/lmptb/tnn/layersT/lstm_t.lua
@@ -108,16 +108,16 @@ function LSTMLayerT:batch_resize(batch_size, chunk_size)
self.dagL:batch_resize(batch_size, chunk_size)
end
-function LSTMLayerT:update(bp_err, input, output)
- self.dagL:update(bp_err, input, output)
+function LSTMLayerT:update(bp_err, input, output, t)
+ self.dagL:update(bp_err, input, output, t)
end
-function LSTMLayerT:propagate(input, output)
- self.dagL:propagate(input, output)
+function LSTMLayerT:propagate(input, output, t)
+ self.dagL:propagate(input, output, t)
end
-function LSTMLayerT:back_propagate(bp_err, next_bp_err, input, output)
- self.dagL:back_propagate(bp_err, next_bp_err, input, output)
+function LSTMLayerT:back_propagate(bp_err, next_bp_err, input, output, t)
+ self.dagL:back_propagate(bp_err, next_bp_err, input, output, t)
end
function LSTMLayerT:get_params()