aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/layer_dag.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/layer_dag.lua')
-rw-r--r--nerv/nn/layer_dag.lua9
1 files changed, 7 insertions, 2 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
index 73bb77d..6ad7ae9 100644
--- a/nerv/nn/layer_dag.lua
+++ b/nerv/nn/layer_dag.lua
@@ -131,6 +131,11 @@ function DAGLayer:__init(id, global_conf, layer_conf)
self.parsed_conn = parsed_conn
self.queue = queue
self.gconf = global_conf
+ if self.gconf.use_cpu then
+ self.mat_type = self.gconf.mmat_type
+ else
+ self.mat_type = self.gconf.cumat_type
+ end
end
function DAGLayer:init(batch_size)
@@ -144,7 +149,7 @@ function DAGLayer:init(batch_size)
if output_dim[port_from] > 0 then
dim = output_dim[port_from]
end
- local mid = self.gconf.cumat_type(batch_size, dim)
+ local mid = self.mat_type(batch_size, dim)
local err_mid = mid:create()
ref_from.outputs[port_from] = mid
@@ -190,7 +195,7 @@ function DAGLayer:batch_resize(batch_size)
_, output_dim = ref_from.layer:get_dim()
if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then
- local mid = self.gconf.cumat_type(batch_size, output_dim[port_from])
+ local mid = self.mat_type(batch_size, output_dim[port_from])
local err_mid = mid:create()
ref_from.outputs[port_from] = mid