aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/rnn/tnn.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/rnn/tnn.lua')
-rw-r--r--nerv/examples/lmptb/rnn/tnn.lua200
1 files changed, 165 insertions, 35 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index 8037918..460fcc4 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -34,9 +34,11 @@ local function discover(id, layers, layer_repo)
layer = layer,
inputs_m = {}, --storage for computation, inputs_m[time][port]
inputs_b = {}, --inputs_g[time][port], whether this input can been computed
+ inputs_p_matbak = {}, --which is a back-up space to handle some cross-border computation, inputs_p_matbak[port]
outputs_m = {},
outputs_b = {},
err_inputs_m = {},
+ err_inputs_p_matbak = {}, --which is a back-up space to handle some cross-border computation
err_inputs_b = {},
err_outputs_m = {},
err_outputs_b = {},
@@ -57,26 +59,36 @@ nerv.TNN.FC.HAS_INPUT = 1
nerv.TNN.FC.HAS_LABEL = 2
nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --This instance have both input and label
-function TNN.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c)
+function TNN.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c, t_c)
--Return a table of matrix storage from time (1-chunk_size)..(2*chunk_size)
if (type(st) ~= "table") then
nerv.error("st should be a table")
end
- for i = 1 - chunk_size, chunk_size * 2 do
+ for i = 1 - chunk_size - 1, chunk_size * 2 + 1 do --intentionally allocated more time, should be [1-chunk_size, chunk_size*2]
if (st[i] == nil) then
st[i] = {}
end
st[i][p] = global_conf.cumat_type(batch_size, dim)
st[i][p]:fill(0)
if (st_c ~= nil) then
- if (st_c[i] == nil) then
- st_c[i] = {}
+ if (st_c[i + t_c] == nil) then
+ st_c[i + t_c] = {}
end
- st_c[i][p_c] = st[i][p]
+ st_c[i + t_c][p_c] = st[i][p]
end
end
end
+function TNN:outOfFeedRange(t) --out of chunk, or no input, for the current feed
+ if (t < 1 or t > self.chunk_size) then
+ return true
+ end
+ if (self.feeds_now.flagsPack_now[t] == 0 or self.feeds_now.flagsPack_now[t] == nil) then
+ return true
+ end
+ return false
+end
+
function TNN:__init(id, global_conf, layer_conf)
local layers = {}
local inputs_p = {} --map:port of the TDAGLayer to layer ref and port
@@ -109,7 +121,7 @@ function TNN:__init(id, global_conf, layer_conf)
outputs_p[port_to] = {["ref"] = ref_from, ["port"] = port_from}
ref_from.outputs_m[port_from] = {} --just a place holder
else
- conn_now = {
+ local conn_now = {
["src"] = {["ref"] = ref_from, ["port"] = port_from},
["dst"] = {["ref"] = ref_to, ["port"] = port_to},
["time"] = time_to
@@ -138,9 +150,11 @@ function TNN:init(batch_size, chunk_size)
self.chunk_size = chunk_size
for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN
local _, output_dim
- local ref_from, port_from, ref_to, port_to
+ local ref_from, port_from, ref_to, port_to, time
ref_from, port_from = conn.src.ref, conn.src.port
ref_to, port_to = conn.dst.ref, conn.dst.port
+ time = conn.time
+
local dim = ref_from.dim_out[port_from]
if (dim == 0) then
nerv.error("layer %s has a zero dim port", ref_from.layer.id)
@@ -148,9 +162,9 @@ function TNN:init(batch_size, chunk_size)
print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id)
ref_to.inputs_p_matbak[port_to] = self.gconf.cumat_type(batch_size, dim)
- self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.inputs_m, port_to)
+ self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.inputs_m, port_to, time)
ref_from.err_inputs_p_matbak[port_from] = self.gconf.cumat_type(batch_size, dim)
- self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.err_outputs_m, port_to)
+ self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.err_outputs_m, port_to, time)
end
@@ -159,8 +173,8 @@ function TNN:init(batch_size, chunk_size)
for i = 1, #self.dim_out do --Init storage for output ports
local ref = self.outputs_p[i].ref
local p = self.outputs_p[i].port
- self.makeInitialStore(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.outputs_m, i)
- self.makeInitialStore(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.err_inputs_m, i)
+ self.makeInitialStore(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.outputs_m, i, 0)
+ self.makeInitialStore(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.err_inputs_m, i, 0)
end
self.inputs_m = {}
@@ -168,8 +182,8 @@ function TNN:init(batch_size, chunk_size)
for i = 1, #self.dim_in do --Init storage for input ports
local ref = self.inputs_p[i].ref
local p = self.inputs_p[i].port
- self.makeInitialStore(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.inputs_m, i)
- self.makeInitialStore(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.err_outputs_m, i)
+ self.makeInitialStore(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.inputs_m, i, 0)
+ self.makeInitialStore(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.err_outputs_m, i, 0)
end
for id, ref in pairs(self.layers) do --Calling init for child layers
@@ -274,45 +288,161 @@ function TNN:getFeedFromReader(reader)
return got_new, feeds_now
end
+function TNN:moveRightToNextMB() --move output history activations of 1..chunk_size to 1-chunk_size..0
+ for t = self.chunk_size, 1, -1 do
+ for id, ref in pairs(self.layers) do
+ for p = 1, #ref.dim_out do
+ ref.outputs_m[t - self.chunk_size][p]:copy_fromd(ref.outputs_m[t][p])
+ end
+ end
+ end
+end
+
function TNN:net_propagate() --propagate according to feeds_now
+ for t = 1, self.chunk_size, 1 do
+ for id, ref in pairs(self.layers) do
+ for p = 1, #ref.dim_out do
+ ref.outputs_b[t][p] = false
+ end
+ for p = 1, #ref.dim_in do
+ ref.inputs_b[t][p] = false
+ end
+ end
+ end
+
local feeds_now = self.feeds_now
- for t = 1, chunk_size do
+ for t = 1, self.chunk_size do
if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
for i = 1, #self.dim_in do
- local ref = inputs_p[i].ref
- local p = inputs_p[i].port
+ local ref = self.inputs_p[i].ref
+ local p = self.inputs_p[i].port
ref.inputs_b[t][p] = true
+ self:propagate_dfs(ref, t)
end
- --TODO
end
end
end
-function DAGLayer:update(bp_err, input, output)
- self:set_err_inputs(bp_err)
- self:set_inputs(input)
- self:set_outputs(output)
- -- print("update")
- for id, ref in pairs(self.queue) do
- -- print(ref.layer.id)
- ref.layer:update(ref.err_inputs, ref.inputs, ref.outputs)
+--ref: the TNN_ref of a layer
+--t: the current time to propagate
+function TNN:propagate_dfs(ref, t)
+ if (self:outOfFeedRange(t)) then
+ return
+ end
+ if (ref.outputs_b[t][1] == true) then --already propagated, 1 is just a random port
+ return
+ end
+
+ --print("debug dfs", ref.layer.id, t)
+
+ local flag = true --whether have all inputs
+ for _, conn in pairs(ref.conns_i) do
+ local p = conn.dst.port
+ if (not (ref.inputs_b[t][p] or self:outOfFeedRange(t - conn.time))) then
+ flag = false
+ break
+ end
+ end
+ if (flag == false) then
+ return
+ end
+
+ --ok, do propagate
+ --print("debug ok, propagating");
+ ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t])
+ for i = 1, #ref.dim_out do
+ if (ref.outputs_b[t][i] == true) then
+ nerv.error("this time's outputs_b should be false")
+ end
+ ref.outputs_b[t][i] = true
+ end
+
+ --try dfs for further layers
+ for _, conn in pairs(ref.conns_o) do
+ --print("debug dfs-searching", conn.dst.ref.layer.id)
+ conn.dst.ref.inputs_b[t + conn.time][conn.dst.port] = true
+ self:propagate_dfs(conn.dst.ref, t + conn.time)
end
end
-function DAGLayer:back_propagate(bp_err, next_bp_err, input, output)
- self:set_err_outputs(next_bp_err)
- self:set_err_inputs(bp_err)
- self:set_inputs(input)
- self:set_outputs(output)
- for i = #self.queue, 1, -1 do
- local ref = self.queue[i]
- -- print(ref.layer.id)
- ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs)
+--do_update: bool, whether we are doing back-propagate or updating the parameters
+function TNN:net_backpropagate(do_update) --propagate according to feeds_now
+ if (do_update == nil) then
+ nerv.error("do_update should not be nil")
+ end
+ for t = 1, self.chunk_size, 1 do
+ for id, ref in pairs(self.layers) do
+ for p = 1, #ref.dim_out do
+ ref.err_inputs_b[t][p] = false
+ end
+ for p = 1, #ref.dim_in do
+ ref.err_outputs_b[t][p] = false
+ end
+ end
+ end
+
+ local feeds_now = self.feeds_now
+ for t = 1, self.chunk_size do
+ if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then
+ for i = 1, #self.dim_out do
+ local ref = self.outputs_p[i].ref
+ local p = self.outputs_p[i].port
+ ref.err_inputs_b[t][p] = true
+ self:backpropagate_dfs(ref, t, do_update)
+ end
+ end
+ end
+end
+
+--ref: the TNN_ref of a layer
+--t: the current time to propagate
+function TNN:backpropagate_dfs(ref, t, do_update)
+ if (self:outOfFeedRange(t)) then
+ return
+ end
+ if (ref.err_outputs_b[t][1] == true) then --already back_propagated, 1 is just a random port
+ return
+ end
+
+ --print("debug dfs", ref.layer.id, t)
+
+ local flag = true --whether have all inputs
+ for _, conn in pairs(ref.conns_o) do
+ local p = conn.src.port
+ if (not (ref.err_inputs_b[t][p] or self:outOfFeedRange(t + conn.time))) then
+ flag = false
+ break
+ end
+ end
+ if (flag == false) then
+ return
+ end
+
+ --ok, do back_propagate
+ --print("debug ok, back-propagating(or updating)")
+ if (do_update == false) then
+ ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t])
+ else
+ --print(ref.err_inputs_m[t][1])
+ ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t])
+ end
+ for i = 1, #ref.dim_in do
+ if (ref.err_outputs_b[t][i] == true) then
+ nerv.error("this time's outputs_b should be false")
+ end
+ ref.err_outputs_b[t][i] = true
+ end
+
+ --try dfs for further layers
+ for _, conn in pairs(ref.conns_i) do
+ --print("debug dfs-searching", conn.src.ref.layer.id)
+ conn.src.ref.err_inputs_b[t - conn.time][conn.src.port] = true
+ self:backpropagate_dfs(conn.src.ref, t - conn.time, do_update)
end
end
--Return: nerv.ParamRepo
-function DAGLayer:get_params()
+function TNN:get_params()
local param_repos = {}
for id, ref in pairs(self.queue) do
table.insert(param_repos, ref.layer:get_params())