diff options
Diffstat (limited to 'layer')
-rw-r--r-- | layer/affine.lua | 3 | ||||
-rw-r--r-- | layer/init.lua | 13 |
2 files changed, 9 insertions, 7 deletions
diff --git a/layer/affine.lua b/layer/affine.lua index 97703a8..d88813e 100644 --- a/layer/affine.lua +++ b/layer/affine.lua @@ -3,7 +3,8 @@ local BiasParam = nerv.class('nerv.BiasParam', 'nerv.LinearTransParam') local AffineLayer = nerv.class('nerv.AffineLayer', 'nerv.Layer') function LinearTransParam:read(pcdata) - self.trans = nerv.CuMatrixFloat.new_from_host(nerv.MMatrixFloat.load(pcdata)) + self.trans = self.gconf.mat_type.new_from_host( + nerv.MMatrixFloat.load(pcdata)) end function LinearTransParam:write(pfhandle) diff --git a/layer/init.lua b/layer/init.lua index 0f0afe8..a98621d 100644 --- a/layer/init.lua +++ b/layer/init.lua @@ -2,12 +2,9 @@ local Param = nerv.class('nerv.Param') -function nerv.Param:__init(id) +function nerv.Param:__init(id, global_conf) self.id = id -end - -function nerv.Param:init(id) - nerv.error_method_not_implemented() + self.gconf = global_conf end function nerv.Param:get_info() @@ -28,7 +25,11 @@ end local Layer = nerv.class('nerv.Layer') -function nerv.Layer:_init(id, global_conf, ...) +function nerv.Layer:__init(id, global_conf, ...) + nerv.error_method_not_implemented() +end + +function nerv.Layer:init(id) nerv.error_method_not_implemented() end |