 nerv/tnn/tnn.lua | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/nerv/tnn/tnn.lua b/nerv/tnn/tnn.lua
index 86b149e..95e23a9 100644
--- a/nerv/tnn/tnn.lua
+++ b/nerv/tnn/tnn.lua
@@ -112,6 +112,10 @@ function TNN:__init(id, global_conf, layer_conf)
     local dim_out = layer_conf.dim_out
     local parsed_conns = {}
     local _
+
+    for id, layer in pairs(layer_conf.sub_layers.layers) do
+        discover(id, layer, layer_conf.sub_layers)
+    end
 
     for _, ll in pairs(layer_conf.connections) do
         local id_from, port_from = parse_id(ll[1])
@@ -184,7 +188,6 @@ function TNN:init(batch_size, chunk_size)
                 self.make_initial_store(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.inputs_m, port_to, time)
                 ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim)
                 self.make_initial_store(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.err_outputs_m, port_to, time)
-
             end
 
     self.outputs_m = {}
@@ -217,6 +220,7 @@ function TNN:init(batch_size, chunk_size)
         end
     end
     -- initialize sub layers
+    nerv.info("TNN initing sub-layer %s", ref.id)
     ref.layer:init(batch_size, chunk_size)
     collectgarbage("collect")
 end
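
Note: the first hunk makes TNN:__init walk layer_conf.sub_layers.layers and register each sub-layer via discover before the connection list is parsed. Below is a minimal standalone sketch of that registration pattern; the discover body and the layer_refs table are hypothetical stand-ins, since the diff does not show nerv's actual implementation of discover.

-- Sketch of the sub-layer registration pattern from the first hunk.
-- `discover` here is a hypothetical stand-in: it records one ref entry
-- per sub-layer id so that later connection parsing can resolve them.
local layer_refs = {}

local function discover(id, layer, owner)
    layer_refs[id] = {id = id, layer = layer, owner = owner}
end

-- stand-in for layer_conf.sub_layers
local sub_layers = {layers = {affine0 = {}, sigmoid0 = {}}}

for id, layer in pairs(sub_layers.layers) do
    discover(id, layer, sub_layers)
end

for id in pairs(layer_refs) do
    print("registered sub-layer: " .. id)
end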