summaryrefslogtreecommitdiff
path: root/nerv/nn/layer_dag.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/layer_dag.lua')
-rw-r--r--nerv/nn/layer_dag.lua16
1 files changed, 6 insertions, 10 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
index 6896878..f999752 100644
--- a/nerv/nn/layer_dag.lua
+++ b/nerv/nn/layer_dag.lua
@@ -134,20 +134,16 @@ function DAGLayer:__init(id, global_conf, layer_conf)
end
end
+ nerv.Layer.__init(self, id, global_conf, layer_conf)
self.layers = layers
self.inputs = inputs
self.outputs = outputs
- self.id = id
- self.dim_in = dim_in
- self.dim_out = dim_out
self.parsed_conn = parsed_conn
self.queue = queue
- self.gconf = global_conf
- if self.gconf.use_cpu then
- self.mat_type = self.gconf.mmat_type
- else
- self.mat_type = self.gconf.cumat_type
- end
+end
+
+function DAGLayer:bind_params()
+ -- do nothing (instead of rebinding params for each layer)
end
function DAGLayer:init(batch_size, chunk_size)
@@ -325,7 +321,7 @@ function DAGLayer:get_params()
for id, ref in pairs(self.queue) do
table.insert(param_repos, ref.layer:get_params())
end
- return nerv.ParamRepo.merge(param_repos)
+ return nerv.ParamRepo.merge(param_repos, self.loc_type)
end
DAGLayer.PORT_TYPES = {