aboutsummaryrefslogtreecommitdiff
path: root/layer
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2015-05-26 14:06:52 +0800
committerDeterminant <ted.sybil@gmail.com>2015-05-26 14:06:52 +0800
commit81bf2d653902860c5d28ccade19ac6e1fd56acaf (patch)
tree831a25c84332ac0839dbe498f61620ea634170e0 /layer
parent8c5246a8794011ca0c25f15643771f031d258594 (diff)
add layer and param
Diffstat (limited to 'layer')
-rw-r--r--layer/affine.lua11
-rw-r--r--layer/init.lua41
2 files changed, 52 insertions, 0 deletions
diff --git a/layer/affine.lua b/layer/affine.lua
new file mode 100644
index 0000000..d5c50fc
--- /dev/null
+++ b/layer/affine.lua
@@ -0,0 +1,11 @@
+local LinearTransParam = nerv.class('nerv.LinearTransParam', 'nerv.Param')
+local BiasParam = nerv.class('nerv.BiasParam', 'nerv.LinearTransParam')
+local AffineLayer = nerv.class('nerv.AffineLayer', 'nerv.Layer')
+
+function LinearTransParam:read(pcdata)
+ self.trans = nerv.CuMatrixFloat.new_from_host(nerv.MMatrixFloat.load(pcdata))
+end
+
+function LinearTransParam:write(pfhandle)
+ self.trans:new_to_host():save(pfhandle)
+end
diff --git a/layer/init.lua b/layer/init.lua
new file mode 100644
index 0000000..c57a405
--- /dev/null
+++ b/layer/init.lua
@@ -0,0 +1,41 @@
+-- The following methods must be implemented to let a layer work properly
+
+local Param = nerv.class('nerv.Param')
+
+function nerv.Param:__init(id)
+ self.id = id
+end
+
+function nerv.Param:get_info()
+ return self.info
+end
+
+function nerv.Param:set_info(info)
+ self.info = info
+end
+
+function nerv.Param:read(pfhandle)
+ nerv.error_method_not_implemented()
+end
+
+function nerv.Param:write(pfhandle)
+ nerv.error_method_not_implemented()
+end
+
+local Layer = nerv.class('nerv.Layer')
+
+function nerv.Layer:_init(param)
+ nerv.error_method_not_implemented()
+end
+
+function nerv.Layer:update(bp_err, input, output)
+ nerv.error_method_not_implemented()
+end
+
+function nerv.Layer:propagate(input, output)
+ nerv.error_method_not_implemented()
+end
+
+function nerv.Layer:back_propagate(next_bp_err, bp_err, input, output)
+ nerv.error_method_not_implemented()
+end