summaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/rnn
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/rnn')
-rw-r--r--nerv/examples/lmptb/rnn/tnn.lua16
1 files changed, 13 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index d10ab82..c2e397c 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -291,11 +291,21 @@ function TNN:getfeed_from_reader(reader)
return got_new, feeds_now
end
-function TNN:move_right_to_nextmb() --move output history activations of 1..chunk_size to 1-chunk_size..0
- for t = 1, self.chunk_size, 1 do
+function TNN:move_right_to_nextmb(list_t) --move output history activations of 1..chunk_size to 1-chunk_size..0
+ if list_t == nil then
+ list_t = {}
+ for i = 1, self.chunk_size do
+ list_t[i] = i - self.chunk_size
+ end
+ end
+ for i = 1, #list_t do
+ t = list_t[i]
+ if t < 1 - self.chunk_size or t > 0 then
+ nerv.error("MB move range error")
+ end
for id, ref in pairs(self.layers) do
for p = 1, #ref.dim_out do
- ref.outputs_m[t - self.chunk_size][p]:copy_fromd(ref.outputs_m[t][p])
+ ref.outputs_m[t][p]:copy_fromd(ref.outputs_m[t + self.chunk_size][p])
end
end
end