aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/tnn/tnn.lua6
1 files changed, 5 insertions, 1 deletions
diff --git a/nerv/tnn/tnn.lua b/nerv/tnn/tnn.lua
index 86b149e..95e23a9 100644
--- a/nerv/tnn/tnn.lua
+++ b/nerv/tnn/tnn.lua
@@ -112,6 +112,10 @@ function TNN:__init(id, global_conf, layer_conf)
local dim_out = layer_conf.dim_out
local parsed_conns = {}
local _
+
+ for id, layer in pairs(layer_conf.sub_layers.layers) do
+ discover(id, layer, layer_conf.sub_layers)
+ end
for _, ll in pairs(layer_conf.connections) do
local id_from, port_from = parse_id(ll[1])
@@ -184,7 +188,6 @@ function TNN:init(batch_size, chunk_size)
self.make_initial_store(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.inputs_m, port_to, time)
ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim)
self.make_initial_store(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.err_outputs_m, port_to, time)
-
end
self.outputs_m = {}
@@ -217,6 +220,7 @@ function TNN:init(batch_size, chunk_size)
end
end
-- initialize sub layers
+ nerv.info("TNN initing sub-layer %s", ref.id)
ref.layer:init(batch_size, chunk_size)
collectgarbage("collect")
end