1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
|
--- A collection of parameter groups.
-- `nerv.ParamRepo` is the class for storing a collection of parameter groups
-- (`nerv.Param`).
local ParamRepo = nerv.class("nerv.ParamRepo")

--- Storage location constants accepted as `loc_type`.
-- Each constant is its own empty table, so two locations are distinguished by
-- reference identity rather than by value.
-- @field ON_DEVICE the storage is on device (GPU RAM)
-- @field ON_HOST the storage is on host (main RAM)
ParamRepo.LOC_TYPES = {
    ON_DEVICE = {},
    ON_HOST = {},
}
--- Construct a parameter repo.
-- Sets up `self.params` (id -> parameter), `self.loc_type`, `self.checker`
-- and `self.make_copier`, then registers the initial parameters.
-- Unlike `add`, no duplicate-id check is made here: a later entry of `plist`
-- with the same id replaces an earlier one.
-- @param plist optional, an array of parameters that will be initially in the
-- collection; each one is passed through its `check` method with the repo's
-- checker before being stored
-- @param loc_type the type of storage location, see
-- `nerv.ParamRepo.LOC_TYPES`; defaults to `ON_HOST` when omitted
function ParamRepo:__init(plist, loc_type)
    local on_host = ParamRepo.LOC_TYPES.ON_HOST

    self.params = {}
    self.loc_type = loc_type or on_host

    -- Factory for copy functions: the returned function allocates a matrix of
    -- `mat_type` with the source's dimensions and invokes the source's
    -- `copy_method` on it.
    self.make_copier = function (mat_type, copy_method)
        return function (src)
            local dst = mat_type(src:nrow(), src:ncol())
            src[copy_method](src, dst)
            return dst
        end
    end

    -- Any location other than ON_HOST is checked against the CUDA matrix type.
    local expected_type
    if self.loc_type == on_host then
        expected_type = "nerv.MMatrix"
    else
        expected_type = "nerv.CuMatrix"
    end
    self.checker = function (mat)
        if not nerv.is_type(mat, expected_type) then
            nerv.error("unexpected param type in repo specification")
        end
    end

    if plist == nil then
        return
    end
    for _, param in ipairs(plist) do
        param:check(self.checker)
        self.params[param.id] = param
    end
end
--- Register a parameter in the collection.
-- Raises an error through `nerv.error` when a parameter with the same id is
-- already stored. The parameter's `check` method is called with the repo's
-- checker before it is stored.
-- @param p the parameter to be added
function ParamRepo:add(p)
    local registry = self.params
    local pid = p.id
    if registry[pid] ~= nil then
        nerv.error("duplicate params with the same id: %s", pid)
    end
    p:check(self.checker)
    registry[pid] = p
end
--- Remove a parameter from the collection.
-- Raises an error through `nerv.error` when no parameter with the given id is
-- stored.
-- @param pid the id of the parameter to be removed
function ParamRepo:remove(pid)
    if self.params[pid] == nil then
        -- Message previously read "does not exit"; "exist" is what is meant.
        nerv.error("param %s does not exist", pid)
    end
    self.params[pid] = nil
end
--- Merge two or more parameter collections.
-- A new repo is created with `loc_type` and every parameter of every input
-- repo is inserted with `add`. As a consequence an id that occurs in more than
-- one input repo raises a duplicate-id error, and every parameter is checked
-- against the new repo's checker.
-- @param repos an array of parameter repos to be merged
-- @param loc_type the type of storage location, see `nerv.ParamRepo.LOC_TYPES`
-- @return the merged parameter collection (repo)
function ParamRepo.merge(repos, loc_type)
    -- TODO: remove redundant `loc_type` and check the consistency of `loc_type`
    -- from different merging param repos.
    local merged = nerv.ParamRepo(nil, loc_type)
    for _, repo in ipairs(repos) do
        if not nerv.is_type(repo, "nerv.ParamRepo") then
            -- The offending value is by definition not a repo and is usually
            -- not a string. Stringify it explicitly so the "%s" placeholder
            -- always receives a string (assuming `nerv.error` formats with
            -- `string.format`, which on Lua 5.1 rejects non-string "%s"
            -- arguments).
            nerv.error("nerv.ParamRepo objects expected, got %s",
                       tostring(repo))
        end
        for _, p in pairs(repo.params) do
            merged:add(p)
        end
    end
    return merged
end
--- Import parameters from NERV chunk files.
-- Every selected chunk must be a `nerv.Param`; it is inserted with `add`, so
-- an id that is already in the collection raises a duplicate-id error.
-- @param param_files an array of filenames of the files to be loaded from
-- @param gconf a table describing the computation state and providing
-- with some global settings; it is forwarded to `read_chunk`
-- @param pids optional, the identifiers of the parameters to be imported,
-- given either as an array of identifiers or as a table keyed by identifier;
-- when omitted every chunk of every file is imported
function ParamRepo:import(param_files, gconf, pids)
    if type(param_files) ~= "table" then
        nerv.error("param file table is need")
    end

    -- Normalise `pids` into a lookup set. Keys are kept (the historical
    -- behaviour was a `pids[cid] ~= nil` test) and the values of array entries
    -- are added, so the documented array form selects parameters as well.
    local wanted
    if pids ~= nil then
        wanted = {}
        for key, value in pairs(pids) do
            wanted[key] = true
            if type(key) == "number" then
                wanted[value] = true
            end
        end
    end

    for i = 1, #param_files do
        local pf = nerv.ChunkFile(param_files[i], "r")
        for cid, _ in pairs(pf.metadata) do
            if wanted == nil or wanted[cid] then
                local p = pf:read_chunk(cid, gconf)
                if not nerv.is_type(p, "nerv.Param") then
                    nerv.error("param chunk is expected")
                end
                self:add(p)
            end
        end
        -- Release the chunk file once its chunks have been read, mirroring
        -- the `close` call `export` makes on its chunk file.
        -- NOTE(review): assumes `read_chunk` has finished reading from the
        -- file by the time it returns -- confirm against nerv.ChunkFile.
        pf:close()
    end
end
--- Export the parameter collection to a NERV chunk file.
-- @param param_file the output filename
-- @param pids optional, an array of identifiers of the parameters to be
-- exported; when omitted every parameter in the collection is written. An
-- identifier that is not in the collection is reported by `get_param`.
function ParamRepo:export(param_file, pids)
    -- Keep the chunk file handle local; a bare assignment would create (or
    -- overwrite) a global named `cf`.
    local cf = nerv.ChunkFile(param_file, "w")
    if pids == nil then
        for _, p in pairs(self.params) do
            cf:write_chunk(p)
        end
    else
        for _, pid in ipairs(pids) do
            cf:write_chunk(self:get_param(pid))
        end
    end
    cf:close()
end
--- Tell whether a parameter is stored under an identifier.
-- @param pid the identifier to be tested
-- @return true if a parameter with the identifier exists, false otherwise
function ParamRepo:has_param(pid)
    local entry = self.params[pid]
    if entry == nil then
        return false
    end
    return true
end
--- Look up a parameter by its identifier.
-- Reports an error through `nerv.error` when nothing is stored under `pid`.
-- @param pid the identifier of the parameter to be retrieved
-- @return the retrieved parameter
function ParamRepo:get_param(pid)
    local found = self.params[pid]
    if found ~= nil then
        return found
    end
    nerv.error("param with id %s not found", pid)
    return found
end
--- Create a copy of the current collection.
-- Each selected parameter is duplicated through its `copy` method, using a
-- copier built by `self.make_copier` for the destination location. The copies
-- are stored directly in the new repo's `params` table.
-- @param loc_type the storage location of the new copy, see
-- `nerv.ParamRepo.LOC_TYPES`; defaults to the location of this repo
-- @param gconf a table describing the computation state and providing
-- with some global settings; `gconf.mmat_type` is used for host copies and
-- `gconf.cumat_type` for all other locations
-- @param pids optional, an array of identifiers of the parameters to be
-- copied; when omitted every parameter is copied
-- @return the new parameter collection (repo)
function ParamRepo:copy(loc_type, gconf, pids)
    -- Resolve the default location BEFORE constructing the target repo.
    -- Otherwise a nil `loc_type` makes the constructor fall back to ON_HOST
    -- while the copier below follows `self.loc_type`, leaving an on-device
    -- copy labelled (and type-checked) as a host repo.
    if loc_type == nil then
        loc_type = self.loc_type
    end
    local target = nerv.ParamRepo(nil, loc_type)

    local copier
    if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        copier = self.make_copier(gconf.mmat_type, 'copy_toh')
    else
        copier = self.make_copier(gconf.cumat_type, 'copy_tod')
    end

    if pids == nil then
        for id, p in pairs(self.params) do
            target.params[id] = p:copy(copier)
        end
    else
        for _, pid in ipairs(pids) do
            target.params[pid] = self:get_param(pid):copy(copier)
        end
    end
    return target
end
|