diff options
Diffstat (limited to 'nerv/tnn/tnn.lua')
-rw-r--r-- | nerv/tnn/tnn.lua | 28 |
1 file changed, 26 insertions, 2 deletions
diff --git a/nerv/tnn/tnn.lua b/nerv/tnn/tnn.lua index cf02123..d527fe6 100644 --- a/nerv/tnn/tnn.lua +++ b/nerv/tnn/tnn.lua @@ -64,7 +64,7 @@ function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, extend_t, gl if (type(st) ~= "table") then nerv.error("st should be a table") end - for i = 1 - extend_t - 1, chunk_size + extend_t + 1 do --intentionally allocated more time + for i = 1 - extend_t - 2, chunk_size + extend_t + 2 do --intentionally allocated more time if (st[i] == nil) then st[i] = {} end @@ -77,6 +77,7 @@ function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, extend_t, gl st_c[i + t_c][p_c] = st[i][p] end end + collectgarbage("collect") --free the old one to save memory end function TNN:out_of_feedrange(t) --out of chunk, or no input, for the current feed @@ -111,6 +112,10 @@ function TNN:__init(id, global_conf, layer_conf) local dim_out = layer_conf.dim_out local parsed_conns = {} local _ + + for id, _ in pairs(layer_conf.sub_layers.layers) do --caution: with this line, some layer not connected will be included + discover(id, layers, layer_conf.sub_layers) + end for _, ll in pairs(layer_conf.connections) do local id_from, port_from = parse_id(ll[1]) @@ -183,7 +188,6 @@ function TNN:init(batch_size, chunk_size) self.make_initial_store(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.inputs_m, port_to, time) ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim) self.make_initial_store(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.err_outputs_m, port_to, time) - end self.outputs_m = {} @@ -216,7 +220,9 @@ function TNN:init(batch_size, chunk_size) end end -- initialize sub layers + nerv.info("TNN initing sub-layer %s", ref.id) ref.layer:init(batch_size, chunk_size) + collectgarbage("collect") end local flags_now = {} @@ -339,6 +345,13 @@ function TNN:net_propagate() --propagate according to feeds_now end local 
feeds_now = self.feeds_now + for t = 1, self.chunk_size do --some layer maybe do not have inputs from time 1..chunk_size + for id, ref in pairs(self.layers) do + if #ref.dim_in > 0 then --some layer is just there(only to save some parameter) + self:propagate_dfs(ref, t) + end + end + end for t = 1, self.chunk_size do if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then for i = 1, #self.dim_in do @@ -362,6 +375,7 @@ function TNN:net_propagate() --propagate according to feeds_now end end end + if (flag_out == false) then nerv.error("some thing wrong, some labeled output is not propagated") end @@ -458,6 +472,13 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now end local feeds_now = self.feeds_now + for t = 1, self.chunk_size do --some layer maybe do not have outputs from time 1..chunk_size + for id, ref in pairs(self.layers) do + if #ref.dim_out > 0 then --some layer is just there(only to save some parameter) + self:backpropagate_dfs(ref, t, do_update) + end + end + end for t = 1, self.chunk_size do if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0 then for i = 1, #self.dim_out do @@ -489,6 +510,9 @@ end --ref: the TNN_ref of a layer --t: the current time to propagate function TNN:backpropagate_dfs(ref, t, do_update) + if do_update == nil then + nerv.error("got a nil do_update") + end if self:out_of_feedrange(t) then return end |