aboutsummaryrefslogtreecommitdiff
path: root/layer
diff options
context:
space:
mode:
Diffstat (limited to 'layer')
-rw-r--r--layer/affine.lua3
-rw-r--r--layer/init.lua13
2 files changed, 9 insertions, 7 deletions
diff --git a/layer/affine.lua b/layer/affine.lua
index 97703a8..d88813e 100644
--- a/layer/affine.lua
+++ b/layer/affine.lua
@@ -3,7 +3,8 @@ local BiasParam = nerv.class('nerv.BiasParam', 'nerv.LinearTransParam')
local AffineLayer = nerv.class('nerv.AffineLayer', 'nerv.Layer')
function LinearTransParam:read(pcdata)
- self.trans = nerv.CuMatrixFloat.new_from_host(nerv.MMatrixFloat.load(pcdata))
+ self.trans = self.gconf.mat_type.new_from_host(
+ nerv.MMatrixFloat.load(pcdata))
end
function LinearTransParam:write(pfhandle)
diff --git a/layer/init.lua b/layer/init.lua
index 0f0afe8..a98621d 100644
--- a/layer/init.lua
+++ b/layer/init.lua
@@ -2,12 +2,9 @@
local Param = nerv.class('nerv.Param')
-function nerv.Param:__init(id)
+function nerv.Param:__init(id, global_conf)
self.id = id
-end
-
-function nerv.Param:init(id)
- nerv.error_method_not_implemented()
+ self.gconf = global_conf
end
function nerv.Param:get_info()
@@ -28,7 +25,11 @@ end
local Layer = nerv.class('nerv.Layer')
-function nerv.Layer:_init(id, global_conf, ...)
+function nerv.Layer:__init(id, global_conf, ...)
+ nerv.error_method_not_implemented()
+end
+
+function nerv.Layer:init(id)
nerv.error_method_not_implemented()
end