From e68d606fbea209794d7380e84dbffc1c6d22021e Mon Sep 17 00:00:00 2001
From: txh18
Date: Tue, 29 Dec 2015 18:03:28 +0800
Subject: small change in tnn, will create all sub-layers first

---
 nerv/tnn/tnn.lua | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/nerv/tnn/tnn.lua b/nerv/tnn/tnn.lua
index 86b149e..95e23a9 100644
--- a/nerv/tnn/tnn.lua
+++ b/nerv/tnn/tnn.lua
@@ -112,6 +112,10 @@ function TNN:__init(id, global_conf, layer_conf)
     local dim_out = layer_conf.dim_out
     local parsed_conns = {}
     local _
+
+    for id, layer in pairs(layer_conf.sub_layers.layers) do
+        discover(id, layer, layer_conf.sub_layers)
+    end
 
     for _, ll in pairs(layer_conf.connections) do
         local id_from, port_from = parse_id(ll[1])
@@ -184,7 +188,6 @@ function TNN:init(batch_size, chunk_size)
                 self.make_initial_store(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.inputs_m, port_to, time)
                 ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim)
                 self.make_initial_store(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.err_outputs_m, port_to, time)
-
             end
 
     self.outputs_m = {}
@@ -217,6 +220,7 @@ function TNN:init(batch_size, chunk_size)
             end
         end
         -- initialize sub layers
+        nerv.info("TNN initing sub-layer %s", ref.id)
         ref.layer:init(batch_size, chunk_size)
         collectgarbage("collect")
     end
--
cgit v1.2.3
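
Note: with this change, TNN:__init walks layer_conf.sub_layers.layers and calls
discover() on every sub-layer before it parses layer_conf.connections, so a
connection can reference a sub-layer regardless of the order in which sub-layers
appear in the config. A minimal sketch of the kind of layer_conf this expects
follows; the layer ids, dimensions, and the third (time-shift) connection field
are illustrative assumptions, not taken from this commit -- only
sub_layers.layers, connections, dim_in/dim_out, and the "<id>[<port>]" strings
fed to parse_id are implied by the patch itself:

    -- hypothetical TNN setup (Lua); sub_layer_repo is assumed to be an
    -- object exposing a .layers table mapping id -> layer, as the new
    -- loop over layer_conf.sub_layers.layers requires
    local layer_conf = {
        dim_in = {429}, dim_out = {2048},
        sub_layers = sub_layer_repo,
        connections = {
            {"<input>[1]", "affine0[1]", 0},   -- {from, to, time shift} (assumed)
            {"affine0[1]", "sigmoid0[1]", 0},
            {"sigmoid0[1]", "<output>[1]", 0},
        },
    }
    -- signature matches the hunk header TNN:__init(id, global_conf, layer_conf)
    local tnn = nerv.TNN("tnn", gconf, layer_conf)
    -- triggers the new "TNN initing sub-layer" log line for each sub-layer
    tnn:init(batch_size, chunk_size)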