aboutsummaryrefslogtreecommitdiff
path: root/nn/layer_repo.lua
diff options
context:
space:
mode:
authorcloudygoose <[email protected]>2015-06-03 10:29:41 +0800
committercloudygoose <[email protected]>2015-06-03 10:29:41 +0800
commitbf01fd6cea42def51becb6ea866d4fd335e45842 (patch)
tree09d12e50e3a6156c7e0cd7412b22fa4b61189495 /nn/layer_repo.lua
parent6984519cbb659aac0b0b323de93d5a90aa2049b7 (diff)
parentbb56a806e0636a0b20117b1644701d63e2bfaefb (diff)
...
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'nn/layer_repo.lua')
-rw-r--r--nn/layer_repo.lua34
1 files changed, 34 insertions, 0 deletions
diff --git a/nn/layer_repo.lua b/nn/layer_repo.lua
new file mode 100644
index 0000000..b1d2248
--- /dev/null
+++ b/nn/layer_repo.lua
@@ -0,0 +1,34 @@
+local LayerRepo = nerv.class("nerv.LayerRepo")
+
+function LayerRepo:__init(layer_spec, param_repo, global_conf)
+ local layers = {}
+ for ltype, llist in pairs(layer_spec) do
+ local layer_type = nerv.get_type(ltype)
+ for id, spec in pairs(llist) do
+ if layers[id] ~= nil then
+ nerv.error("a layer with id %s already exists", id)
+ end
+ nerv.utils.printf("id: %s\n", id)
+ if type(spec[2]) ~= "table" then
+ nerv.error("layer config table is need")
+ end
+ layer_config = spec[2]
+ if type(spec[1]) ~= "table" then
+ nerv.error("parameter description table is needed")
+ end
+ for pname, pid in pairs(spec[1]) do
+ layer_config[pname] = param_repo:get_param(pid, global_conf)
+ end
+ layers[id] = layer_type(id, global_conf, layer_config)
+ end
+ end
+ self.layers = layers
+end
+
+function LayerRepo:get_layer(lid)
+ local layer = self.layers[lid]
+ if layer == nil then
+ nerv.error("layer with id %s not found", lid)
+ end
+ return layer
+end