1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
|
--- Repository of parameters, keyed by param id.
-- Instances are created through the `nerv.class` constructor registered
-- under the name "nerv.ParamRepo".
local ParamRepo = nerv.class("nerv.ParamRepo")

--- Location tags for a repo's matrices.
-- Each value is a distinct empty table used purely as an identity sentinel;
-- the rest of the class compares loc_type values against these with `==`.
ParamRepo.LOC_TYPES = {ON_HOST = {}, ON_DEVICE = {}}
--- Construct a repo, optionally pre-populated from a list of params.
-- @param plist optional sequence of param objects; each one is passed to its
--   own `check` method with this repo's type checker and then stored under
--   its `id` field (a later entry with the same id replaces an earlier one)
-- @param loc_type one of ParamRepo.LOC_TYPES; defaults to ON_HOST
function ParamRepo:__init(plist, loc_type)
    self.params = {}
    self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST

    -- Factory for copy functions: the returned function allocates a fresh
    -- `mat_type` with the source's nrow/ncol and fills it by calling the
    -- source's `copy_method` with the new matrix as destination.
    self.make_copier = function (mat_type, copy_method)
        return function (src)
            local dst = mat_type(src:nrow(), src:ncol())
            src[copy_method](src, dst)
            return dst
        end
    end

    -- ON_HOST repos accept nerv.MMatrix; every other loc_type accepts
    -- nerv.CuMatrix.
    local expected_type
    if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        expected_type = "nerv.MMatrix"
    else
        expected_type = "nerv.CuMatrix"
    end
    self.checker = function (mat)
        if nerv.is_type(mat, expected_type) then
            return
        end
        nerv.error("unexpected param type in repo specification")
    end

    if plist == nil then
        return
    end
    for _, param in ipairs(plist) do
        param:check(self.checker)
        self.params[param.id] = param
    end
end
--- Register a param under an id that is not yet in use.
-- @param pid id to store the param under
-- @param p param object; it is validated via `p:check` with this repo's
--   type checker before being stored
function ParamRepo:add(pid, p)
    local registry = self.params
    if registry[pid] ~= nil then
        nerv.error("duplicate params with the same id: %s", pid)
    end
    p:check(self.checker)
    registry[pid] = p
end
--- Remove the param registered under `pid`.
-- @param pid id of the param to remove; reported via nerv.error when absent
-- @param p unused; kept so existing call sites keep working
function ParamRepo:remove(pid, p)
    if self.params[pid] == nil then
        nerv.error("param %s does not exist", pid)
    end
    -- `params` is a hash keyed by id, not a sequence: table.remove() would
    -- expect an integer position (and shift other entries), so delete the
    -- key directly instead.
    self.params[pid] = nil
end
--- Combine several repos into one new repo (static function, call with `.`).
-- @param repos sequence of nerv.ParamRepo objects
-- @param loc_type loc_type of the resulting repo (nil lets the constructor
--   pick its default); every merged param must pass that repo's checker
-- @return a new nerv.ParamRepo; an id present in more than one input repo is
--   reported by ParamRepo:add as a duplicate
function ParamRepo.merge(repos, loc_type)
    local merged = nerv.ParamRepo(nil, loc_type)
    for _, repo in ipairs(repos) do
        if not nerv.is_type(repo, "nerv.ParamRepo") then
            -- tostring(): nerv.error presumably formats via string.format,
            -- whose %s rejects table arguments on Lua 5.1 / LuaJIT 2.0 and
            -- would otherwise replace this message with a format error.
            nerv.error("nerv.ParamRepo objects expected, got %s", tostring(repo))
        end
        for pid, p in pairs(repo.params) do
            merged:add(pid, p)
        end
    end
    return merged
end
--- Load param chunks from chunk files into this repo.
-- @param param_files sequence of chunk-file paths
-- @param pids optional SET of ids to load (ids as keys, any non-nil value);
--   all chunks are loaded when nil. Note this is a set, whereas export()
--   takes a list.
-- @param gconf forwarded to ChunkFile:read_chunk
function ParamRepo:import(param_files, pids, gconf)
    if type(param_files) ~= "table" then
        nerv.error("param file table is need")
    end
    for i = 1, #param_files do
        local pf = nerv.ChunkFile(param_files[i], "r")
        for cid, _ in pairs(pf.metadata) do
            if pids == nil or pids[cid] ~= nil then
                local p = pf:read_chunk(cid, gconf)
                if not nerv.is_type(p, "nerv.Param") then
                    nerv.error("param chunk is expected")
                end
                self:add(cid, p)
            end
        end
        -- Release the chunk file once all wanted chunks have been read, so
        -- importing many files does not accumulate open handles.
        pf:close()
    end
end
--- Write params of this repo to a chunk file.
-- @param param_file output path, opened with mode "w"
-- @param pids optional LIST of ids to write, in list order; when nil every
--   param is written in (unspecified) pairs() order. An unknown id is
--   reported by get_param.
function ParamRepo:export(param_file, pids)
    -- Keep the handle local so it does not leak into the global table.
    local cf = nerv.ChunkFile(param_file, "w")
    if pids == nil then
        for _, p in pairs(self.params) do
            cf:write_chunk(p)
        end
    else
        for _, pid in ipairs(pids) do
            cf:write_chunk(self:get_param(pid))
        end
    end
    cf:close()
end
--- Report whether a param is registered under `pid`.
-- @return true if present, false otherwise
function ParamRepo:has_param(pid)
    local entry = self.params[pid]
    return not (entry == nil)
end
--- Look up the param registered under `pid`.
-- A missing id is reported through nerv.error.
-- @return the param object stored under `pid`
function ParamRepo:get_param(pid)
    local found = self.params[pid]
    if found == nil then
        nerv.error("param with id %s not found", pid)
    end
    return found
end
--- Copy (a subset of) this repo's params into a new repo.
-- @param loc_type target location, one of ParamRepo.LOC_TYPES; defaults to
--   this repo's own loc_type
-- @param pids optional LIST of ids to copy; every param when nil
-- @param conf optional config table providing `mmat_type` / `cumat_type`;
--   when omitted the global `gconf` is used (the previous behavior)
-- @return a new nerv.ParamRepo whose params are `p:copy(copier)` results.
--   The copies are stored directly and are not re-run through the target
--   repo's checker.
function ParamRepo:copy(loc_type, pids, conf)
    local config = conf or gconf
    if config == nil then
        nerv.error("ParamRepo:copy needs a gconf (pass one or set the global)")
    end
    -- Resolve the default BEFORE building the target repo, so that the
    -- target's loc_type (and hence its checker) agrees with the copier below.
    if loc_type == nil then
        loc_type = self.loc_type
    end
    local target = nerv.ParamRepo(nil, loc_type)
    local copier
    if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        copier = self.make_copier(config.mmat_type, 'copy_toh')
    else
        copier = self.make_copier(config.cumat_type, 'copy_tod')
    end
    if pids == nil then
        for id, p in pairs(self.params) do
            target.params[id] = p:copy(copier)
        end
    else
        for _, pid in ipairs(pids) do
            target.params[pid] = self:get_param(pid):copy(copier)
        end
    end
    return target
end
|