Diffstat (limited to 'nerv/tnn/tnn.lua')
-rw-r--r--  nerv/tnn/tnn.lua  28
1 file changed, 26 insertions(+), 2 deletions(-)
diff --git a/nerv/tnn/tnn.lua b/nerv/tnn/tnn.lua
index cf02123..d527fe6 100644
--- a/nerv/tnn/tnn.lua
+++ b/nerv/tnn/tnn.lua
@@ -64,7 +64,7 @@ function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, extend_t, gl
if (type(st) ~= "table") then
nerv.error("st should be a table")
end
- for i = 1 - extend_t - 1, chunk_size + extend_t + 1 do --intentionally allocated more time
+ for i = 1 - extend_t - 2, chunk_size + extend_t + 2 do --intentionally allocated more time
if (st[i] == nil) then
st[i] = {}
end
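For context on the widened range above: make_initial_store creates one entry per time index, and the extra slots beyond the chunk absorb reads through time-shifted connections. A minimal sketch of that pattern, using plain Lua tables in place of the real nerv matrix objects:

    local function make_store(chunk_size, extend_t)
        local st = {}
        for i = 1 - extend_t - 2, chunk_size + extend_t + 2 do
            st[i] = {}  -- intentionally over-allocate slots at both ends of the chunk
        end
        return st
    end

    local st = make_store(5, 1)             -- a 5-step chunk with shifts of at most 1
    assert(st[-2] ~= nil and st[8] ~= nil)  -- shifted reads past the chunk never hit nil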
@@ -77,6 +77,7 @@ function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, extend_t, gl
st_c[i + t_c][p_c] = st[i][p]
end
end
+ collectgarbage("collect") --free the old one to save memory
end
function TNN:out_of_feedrange(t) --out of chunk, or no input, for the current feed
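collectgarbage("collect") is the standard Lua full-collection call: once the per-time entries are re-pointed at the shared storage, whatever they referenced before is unreachable, and a full cycle reclaims it immediately instead of waiting for incremental steps. A self-contained illustration with a stand-in object:

    local stale = { payload = ("x"):rep(2 ^ 20) }  -- stand-in for an old per-time matrix
    local before = collectgarbage("count")         -- KB in use while the table is still referenced
    stale = nil                                    -- drop the last reference, as the re-aliasing above does
    collectgarbage("collect")                      -- full cycle: the old object is reclaimed now
    print(("freed roughly %.0f KB"):format(before - collectgarbage("count")))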
@@ -111,6 +112,10 @@ function TNN:__init(id, global_conf, layer_conf)
local dim_out = layer_conf.dim_out
local parsed_conns = {}
local _
+
+ for id, _ in pairs(layer_conf.sub_layers.layers) do --caution: with this line, layers that are not connected will also be included
+ discover(id, layers, layer_conf.sub_layers)
+ end
for _, ll in pairs(layer_conf.connections) do
local id_from, port_from = parse_id(ll[1])
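Iterating over layer_conf.sub_layers.layers registers every declared sub-layer, so a layer that only carries parameters and never appears in a connection still gets a reference. A hedged sketch of that discovery pattern, with a simplified discover and made-up layer ids rather than the real nerv internals:

    local layers = {}

    local function discover(id, layers, sub_layers)
        if layers[id] == nil then
            layers[id] = { id = id, layer = sub_layers.layers[id] }  -- register the layer once
        end
        return layers[id]
    end

    local sub_layers = { layers = { affine0 = {}, param_only0 = {} } }  -- hypothetical layer ids
    for id, _ in pairs(sub_layers.layers) do
        discover(id, layers, sub_layers)  -- param_only0 is registered even though nothing connects to it
    end
    assert(layers.param_only0 ~= nil)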
@@ -183,7 +188,6 @@ function TNN:init(batch_size, chunk_size)
self.make_initial_store(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.inputs_m, port_to, time)
ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim)
self.make_initial_store(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.err_outputs_m, port_to, time)
-
end
self.outputs_m = {}
@@ -216,7 +220,9 @@ function TNN:init(batch_size, chunk_size)
end
end
-- initialize sub layers
+ nerv.info("TNN initializing sub-layer %s", ref.id)
ref.layer:init(batch_size, chunk_size)
+ collectgarbage("collect")
end
local flags_now = {}
@@ -339,6 +345,13 @@ function TNN:net_propagate() --propagate according to feeds_now
end
local feeds_now = self.feeds_now
+ for t = 1, self.chunk_size do --some layers may not have inputs at every time in 1..chunk_size
+ for id, ref in pairs(self.layers) do
+ if #ref.dim_in > 0 then --some layers exist only to hold parameters
+ self:propagate_dfs(ref, t)
+ end
+ end
+ end
for t = 1, self.chunk_size do
if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
for i = 1, #self.dim_in do
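The pre-pass above visits every layer that has input ports at each time step, so layers that are never the direct target of an input feed still get propagated; the per-feed loops that follow can then revisit the same (layer, t) pairs safely. A self-contained sketch of the loop structure with stand-in objects (the names and the re-visit guard are assumptions, not the real TNN code):

    local visited = {}
    local tnn = {
        chunk_size = 3,
        layers = {
            affine0     = { dim_in = { 429 } },  -- a normal layer with one input port (hypothetical)
            param_only0 = { dim_in = {} },       -- a parameter-only layer, never propagated
        },
        propagate_dfs = function(self, ref, t) visited[#visited + 1] = t end,
    }

    for t = 1, tnn.chunk_size do
        for _, ref in pairs(tnn.layers) do
            if #ref.dim_in > 0 then         -- skip parameter-only layers
                tnn:propagate_dfs(ref, t)   -- the real propagate_dfs guards against re-visits
            end
        end
    end
    assert(#visited == 3)                   -- only affine0 was visited, once per time step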
@@ -362,6 +375,7 @@ function TNN:net_propagate() --propagate according to feeds_now
end
end
end
+
if (flag_out == false) then
nerv.error("some thing wrong, some labeled output is not propagated")
end
@@ -458,6 +472,13 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now
end
local feeds_now = self.feeds_now
+ for t = 1, self.chunk_size do --some layers may not have outputs at every time in 1..chunk_size
+ for id, ref in pairs(self.layers) do
+ if #ref.dim_out > 0 then --some layers exist only to hold parameters
+ self:backpropagate_dfs(ref, t, do_update)
+ end
+ end
+ end
for t = 1, self.chunk_size do
if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0 then
for i = 1, #self.dim_out do
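The flagsPack_now tests rely on packed bit flags via the bit module (bit.band, bit.bor). A small illustration with made-up flag values; the actual nerv.TNN.FC constants may differ:

    local bit = require("bit")                       -- LuaBitOp-style module, as used above
    local HAS_INPUT, HAS_LABEL = 1, 2                -- illustrative values only
    local flagsPack = bit.bor(HAS_INPUT, HAS_LABEL)  -- both conditions hold at this time step
    assert(bit.band(flagsPack, HAS_LABEL) > 0)       -- the label bit is set
    assert(bit.band(flagsPack, 4) == 0)              -- an unrelated bit is not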
@@ -489,6 +510,9 @@ end
--ref: the TNN_ref of a layer
--t: the current time to propagate
function TNN:backpropagate_dfs(ref, t, do_update)
+ if do_update == nil then
+ nerv.error("got a nil do_update")
+ end
if self:out_of_feedrange(t) then
return
end