diff options
Diffstat (limited to 'nerv/nn/layer_dag.lua')
-rw-r--r-- | nerv/nn/layer_dag.lua | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua index 73bb77d..6ad7ae9 100644 --- a/nerv/nn/layer_dag.lua +++ b/nerv/nn/layer_dag.lua @@ -131,6 +131,11 @@ function DAGLayer:__init(id, global_conf, layer_conf) self.parsed_conn = parsed_conn self.queue = queue self.gconf = global_conf + if self.gconf.use_cpu then + self.mat_type = self.gconf.mmat_type + else + self.mat_type = self.gconf.cumat_type + end end function DAGLayer:init(batch_size) @@ -144,7 +149,7 @@ function DAGLayer:init(batch_size) if output_dim[port_from] > 0 then dim = output_dim[port_from] end - local mid = self.gconf.cumat_type(batch_size, dim) + local mid = self.mat_type(batch_size, dim) local err_mid = mid:create() ref_from.outputs[port_from] = mid @@ -190,7 +195,7 @@ function DAGLayer:batch_resize(batch_size) _, output_dim = ref_from.layer:get_dim() if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then - local mid = self.gconf.cumat_type(batch_size, output_dim[port_from]) + local mid = self.mat_type(batch_size, output_dim[port_from]) local err_mid = mid:create() ref_from.outputs[port_from] = mid |