aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua')
-rw-r--r--nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua6
1 files changed, 3 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua b/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua
index f1eb4a1..c43e567 100644
--- a/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua
+++ b/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua
@@ -1,4 +1,4 @@
-local LMRecurrent = nerv.class('nerv.LMAffineRecurrentLayer', 'nerv.AffineRecurrentLayer') --breaks at sentence end, when </s> is met, input will be set to zero
+local LMRecurrent = nerv.class('nerv.IndRecurrentLayer', 'nerv.AffineRecurrentLayer') --breaks at sentence end, when </s> is met, input will be set to zero
--id: string
--global_conf: table
@@ -11,10 +11,10 @@ function LMRecurrent:__init(id, global_conf, layer_conf)
end
function LMRecurrent:propagate(input, output)
- output[1]:mul(input[1], self.ltp_ih.trans, 1.0, 0.0, 'N', 'N')
+ output[1]:copy_fromd(input[1])
if (self.independent == true) then
for i = 1, input[1]:nrow() do
- if (input[1][i - 1][self.break_id - 1] > 0.1) then --here is sentence break
+ if (self.gconf.input_word_id[self.id][i - 1][0] == self.break_id) then --here is sentence break
input[2][i - 1]:fill(0)
end
end