aboutsummaryrefslogtreecommitdiff
path: root/nn/param_repo.lua
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-06-02 20:28:16 +0800
committerDeterminant <[email protected]>2015-06-02 20:28:16 +0800
commit74d9e9e7371c80394698fb9805cbf0cbde67a8f3 (patch)
tree36b070f1fcfa2be8fc80c50b7a221862a0dfd14a /nn/param_repo.lua
parent60083f2e51935ce55cec7a4c39d1724a16d9c769 (diff)
add ParamRepo, LayerRepo, DAGLayer
Diffstat (limited to 'nn/param_repo.lua')
-rw-r--r--nn/param_repo.lua26
1 files changed, 26 insertions, 0 deletions
diff --git a/nn/param_repo.lua b/nn/param_repo.lua
new file mode 100644
index 0000000..3e37c31
--- /dev/null
+++ b/nn/param_repo.lua
@@ -0,0 +1,26 @@
+local ParamRepo = nerv.class("nerv.ParamRepo")
+
+function ParamRepo:__init(param_files)
+ local param_table = {}
+ if type(param_files) ~= "table" then
+ nerv.error("param file table is need")
+ end
+ for i = 1, #param_files do
+ local pf = nerv.ChunkFile(param_files[i], "r")
+ for cid, cspec in pairs(pf.metadata) do
+ if param_table[cid] ~= nil then
+ nerv.error("conflicting chunk id in param files")
+ end
+ param_table[cid] = pf
+ end
+ end
+ self.param_table = param_table
+end
+
+function ParamRepo:get_param(pid, global_conf)
+ local pf = self.param_table[pid]
+ if pf == nil then
+ nerv.error("param with id %s not found", pid)
+ end
+ return pf:read_chunk(pid, global_conf)
+end