Diffstat (limited to 'nerv/examples/lmptb/rnn/tnn.lua')
-rw-r--r--  nerv/examples/lmptb/rnn/tnn.lua | 42
1 file changed, 26 insertions(+), 16 deletions(-)
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index d6bf42e..c2e397c 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -58,7 +58,7 @@ nerv.TNN.FC.HAS_INPUT = 1
nerv.TNN.FC.HAS_LABEL = 2
nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --This instance has both input and label
-function TNN.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c, t_c)
+function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c, t_c)
--Return a table of matrix storage covering times (1-chunk_size)..(2*chunk_size)
if (type(st) ~= "table") then
nerv.error("st should be a table")
@@ -78,7 +78,7 @@ function TNN.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, s
end
end
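For orientation, here is a minimal standalone sketch of the storage scheme the renamed make_initial_store implements (simplified and hypothetical: plain Lua tables stand in for gconf.cumat_type matrices, error checks are omitted, and it assumes the consumer table aliases the producer's matrices rather than copying them, as the paired outputs_m/inputs_m calls in TNN:init suggest). One slot is prepared per time step over the window (1-chunk_size)..(2*chunk_size), and the consumer sees the same storage shifted by the connection delay t_c:

local function make_initial_store_sketch(st, p, dim, batch_size, chunk_size, st_c, p_c, t_c)
    for t = 1 - chunk_size, 2 * chunk_size do
        st[t] = st[t] or {}
        st[t][p] = {rows = batch_size, cols = dim} --plain table standing in for a matrix
        if st_c ~= nil then
            st_c[t + t_c] = st_c[t + t_c] or {}
            st_c[t + t_c][p_c] = st[t][p] --an alias, not a copy: producer writes are visible to the consumer
        end
    end
end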
-function TNN:outOfFeedRange(t) --out of chunk, or no input, for the current feed
+function TNN:out_of_feedrange(t) --out of chunk, or no input, for the current feed
if (t < 1 or t > self.chunk_size) then
return true
end
@@ -165,9 +165,9 @@ function TNN:init(batch_size, chunk_size)
print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id)
ref_to.inputs_matbak_p[port_to] = self.gconf.cumat_type(batch_size, dim)
- self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.inputs_m, port_to, time)
+ self.make_initial_store(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.inputs_m, port_to, time)
ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim)
- self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.err_outputs_m, port_to, time)
+ self.make_initial_store(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.err_outputs_m, port_to, time)
end
@@ -176,8 +176,8 @@ function TNN:init(batch_size, chunk_size)
for i = 1, #self.dim_out do --Init storage for output ports
local ref = self.outputs_p[i].ref
local p = self.outputs_p[i].port
- self.makeInitialStore(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.outputs_m, i, 0)
- self.makeInitialStore(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.err_inputs_m, i, 0)
+ self.make_initial_store(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.outputs_m, i, 0)
+ self.make_initial_store(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.err_inputs_m, i, 0)
end
self.inputs_m = {}
@@ -185,8 +185,8 @@ function TNN:init(batch_size, chunk_size)
for i = 1, #self.dim_in do --Init storage for input ports
local ref = self.inputs_p[i].ref
local p = self.inputs_p[i].port
- self.makeInitialStore(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.inputs_m, i, 0)
- self.makeInitialStore(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.err_outputs_m, i, 0)
+ self.make_initial_store(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.inputs_m, i, 0)
+ self.make_initial_store(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.err_outputs_m, i, 0)
end
for id, ref in pairs(self.layers) do --Calling init for child layers
@@ -285,17 +285,27 @@ end
--reader: a reader object implementing get_batch(feeds)
--Returns: bool, whether a new feed was obtained
--Returns: feeds, a table that will be filled with the reader's feeds
-function TNN:getFeedFromReader(reader)
+function TNN:getfeed_from_reader(reader)
local feeds_now = self.feeds_now
local got_new = reader:get_batch(feeds_now)
return got_new, feeds_now
end
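A hedged usage sketch of the renamed accessor (the surrounding loop and the driver function are illustrative, not from this file): the caller polls the reader once per minibatch and stops when nothing new arrives; the returned table is the TNN-owned self.feeds_now, refilled in place on each call.

while true do
    local got_new, feeds = tnn:getfeed_from_reader(reader)
    if not got_new then break end --reader exhausted, no new minibatch
    --feeds holds this minibatch's data; per the FC flags above,
    --a normal sequence entry carries both input and label
    train_one_minibatch(tnn, feeds) --hypothetical driver, not part of this file
end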
-function TNN:moveRightToNextMB() --move output history activations of 1..chunk_size to 1-chunk_size..0
- for t = 1, self.chunk_size, 1 do
+function TNN:move_right_to_nextmb(list_t) --move output history activations of 1..chunk_size to 1-chunk_size..0
+ if list_t == nil then
+ list_t = {}
+ for i = 1, self.chunk_size do
+ list_t[i] = i - self.chunk_size
+ end
+ end
+ for i = 1, #list_t do
+ local t = list_t[i]
+ if t < 1 - self.chunk_size or t > 0 then
+ nerv.error("MB move range error")
+ end
for id, ref in pairs(self.layers) do
for p = 1, #ref.dim_out do
- ref.outputs_m[t - self.chunk_size][p]:copy_fromd(ref.outputs_m[t][p])
+ ref.outputs_m[t][p]:copy_fromd(ref.outputs_m[t + self.chunk_size][p])
end
end
end
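The new optional list_t argument restricts the history shift to selected target slots; when omitted, the defaulted list reproduces the old behavior exactly. A worked check with an illustrative chunk_size of 3:

local chunk_size = 3
local list_t = {}
for i = 1, chunk_size do list_t[i] = i - chunk_size end
--list_t is {-2, -1, 0}; each target slot t is filled from t + chunk_size,
--so the copies are [-2]<-[1], [-1]<-[2], [0]<-[3]: the same pairs as the
--old loop's [t - chunk_size]<-[t] for t = 1..3, just indexed from the target side.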
@@ -345,7 +355,7 @@ end
--ref: the TNN_ref of a layer
--t: the current time to propagate
function TNN:propagate_dfs(ref, t)
- if (self:outOfFeedRange(t)) then
+ if (self:out_of_feedrange(t)) then
return
end
if (ref.outputs_b[t][1] == true) then --already propagated, 1 is just an arbitrary port
@@ -357,7 +367,7 @@ function TNN:propagate_dfs(ref, t)
local flag = true --whether all inputs are ready
for _, conn in pairs(ref.i_conns_p) do
local p = conn.dst.port
- if (not (ref.inputs_b[t][p] or self:outOfFeedRange(t - conn.time))) then
+ if (not (ref.inputs_b[t][p] or self:out_of_feedrange(t - conn.time))) then
flag = false
break
end
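To make the readiness test concrete (illustrative values; only the chunk-boundary clause of out_of_feedrange is modeled here, not the per-sequence input flags): for a recurrent connection delayed by one step, the source time t - conn.time falls outside the chunk at t = 1, so that port counts as satisfied from stored history instead of blocking the DFS.

local chunk_size = 3
local function out_of_feedrange(t) return t < 1 or t > chunk_size end
local conn = {time = 1} --a recurrent edge: the input at t comes from the producer at t - 1
print(out_of_feedrange(1 - conn.time)) --> true: at t = 1 the source is history, port satisfied
print(out_of_feedrange(2 - conn.time)) --> false: at t = 2 the source lies inside the chunk and must be computed first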
@@ -465,7 +475,7 @@ end
--ref: the TNN_ref of a layer
--t: the current time to propagate
function TNN:backpropagate_dfs(ref, t, do_update)
- if (self:outOfFeedRange(t)) then
+ if (self:out_of_feedrange(t)) then
return
end
if (ref.err_outputs_b[t][1] == true) then --already back-propagated, 1 is just an arbitrary port
@@ -477,7 +487,7 @@ function TNN:backpropagate_dfs(ref, t, do_update)
local flag = true --whether all error inputs are ready
for _, conn in pairs(ref.o_conns_p) do
local p = conn.src.port
- if (not (ref.err_inputs_b[t][p] or self:outOfFeedRange(t + conn.time))) then
+ if (not (ref.err_inputs_b[t][p] or self:out_of_feedrange(t + conn.time))) then
flag = false
break
end
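Note the mirrored sign relative to the forward pass: propagation waits on sources at t - conn.time, while back-propagation waits on consumers at t + conn.time, so a delayed edge that was satisfied from history going forward is satisfied past the chunk's right edge going backward. Reusing the illustrative values from the forward sketch:

local chunk_size = 3
local function out_of_feedrange(t) return t < 1 or t > chunk_size end
local conn = {time = 1}
print(out_of_feedrange(3 + conn.time)) --> true: at t = chunk_size the delayed consumer is beyond the chunk
print(out_of_feedrange(2 + conn.time)) --> false: the error from t = 3 must arrive before t = 2 can run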