diff options
Diffstat (limited to 'nerv/nn/layer_dag.lua')
-rw-r--r-- | nerv/nn/layer_dag.lua | 16 |
1 files changed, 6 insertions, 10 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua index 6896878..f999752 100644 --- a/nerv/nn/layer_dag.lua +++ b/nerv/nn/layer_dag.lua @@ -134,20 +134,16 @@ function DAGLayer:__init(id, global_conf, layer_conf) end end + nerv.Layer.__init(self, id, global_conf, layer_conf) self.layers = layers self.inputs = inputs self.outputs = outputs - self.id = id - self.dim_in = dim_in - self.dim_out = dim_out self.parsed_conn = parsed_conn self.queue = queue - self.gconf = global_conf - if self.gconf.use_cpu then - self.mat_type = self.gconf.mmat_type - else - self.mat_type = self.gconf.cumat_type - end +end + +function DAGLayer:bind_params() + -- do nothing (instead of rebinding params for each layer) end function DAGLayer:init(batch_size, chunk_size) @@ -325,7 +321,7 @@ function DAGLayer:get_params() for id, ref in pairs(self.queue) do table.insert(param_repos, ref.layer:get_params()) end - return nerv.ParamRepo.merge(param_repos) + return nerv.ParamRepo.merge(param_repos, self.loc_type) end DAGLayer.PORT_TYPES = { |