aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/param_repo.lua
blob: a9eb0bd46c2aaa42b6606b3f4945b7ee1a70e6cc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
--- Implements a concept that stores a collection of parameter groups.

--- The class for storing a collection of parameter groups (`nerv.Param`).

local ParamRepo = nerv.class("nerv.ParamRepo")

--- The location constants for `loc_type`.
-- Each constant is a distinct empty table used purely as a sentinel, so
-- location types are compared by identity (`==`).
-- @field ON_DEVICE the storage is on device (GPU RAM)
-- @field ON_HOST the storage is on host (main RAM)

ParamRepo.LOC_TYPES = {ON_DEVICE = {}, ON_HOST = {}}

--- The constructor.
-- @param plist optional, an array of parameters that will be initially in
-- the collection; each one is added with `add`, so it is validated against
-- the storage location and a duplicate id raises an error
-- @param loc_type the type of storage location, see
-- `nerv.ParamRepo.LOC_TYPES` (defaults to `ON_HOST`)

function ParamRepo:__init(plist, loc_type)
    self.params = {}
    self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST

    -- Build a validator that raises unless a matrix is of type `tname`;
    -- it is handed to `p:check` for every parameter stored in the repo.
    local function make_checker(tname)
        return function (mat)
            if not nerv.is_type(mat, tname) then
                nerv.error("unexpected param type in repo specification")
            end
        end
    end

    -- Build a function that allocates a `mat_type` matrix with the same
    -- shape as its argument and fills it through the matrix method named
    -- `copy_method` (e.g. 'copy_toh' / 'copy_tod'); used by `copy`.
    self.make_copier = function (mat_type, copy_method)
        return function (mat)
            local target = mat_type(mat:nrow(), mat:ncol())
            mat[copy_method](mat, target)
            return target
        end
    end

    if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        self.checker = make_checker("nerv.MMatrix")
    else
        self.checker = make_checker("nerv.CuMatrix")
    end

    if plist ~= nil then
        -- Route through `add` so initial parameters get the same
        -- duplicate-id detection as parameters added later, instead of a
        -- later entry silently replacing an earlier one with the same id.
        for _, p in ipairs(plist) do
            self:add(p)
        end
    end
end

--- Add a parameter to the collection.
-- Raises an error if a parameter with the same id is already stored;
-- otherwise the parameter is passed this repo's type checker via `p:check`
-- before it is stored under its id.
-- @param p the parameter to be added

function ParamRepo:add(p)
    local pid = p.id
    local stored = self.params
    if stored[pid] ~= nil then
        nerv.error("duplicate params with the same id: %s", pid)
    end
    p:check(self.checker)
    stored[pid] = p
end

--- Remove a parameter from the collection.
-- Raises an error if no parameter with the given id is stored.
-- @param pid the id of the parameter to be removed

function ParamRepo:remove(pid)
    if self.params[pid] == nil then
        nerv.error("param %s does not exist", pid)
    end
    self.params[pid] = nil
end

--- Merge two or more parameter collections.
-- Every parameter of every input repo is added to a fresh repo with `add`.
-- A duplicate id across the inputs therefore raises an error, and each
-- parameter is checked against `loc_type`. The parameter objects themselves
-- are shared with the input repos, not copied.
-- @param repos an array of parameter repos to be merged
-- @param loc_type the type of storage location, see `nerv.ParamRepo.LOC_TYPES`
-- @return the merged parameter collection (repo)

function ParamRepo.merge(repos, loc_type)

-- TODO: remove redundant `loc_type` and check the consistency of `loc_type`
-- from different merging param repos.

    local merged = nerv.ParamRepo(nil, loc_type)
    for _, repo in ipairs(repos) do
        if not nerv.is_type(repo, "nerv.ParamRepo") then
            -- `repo` is usually a table here. Convert it explicitly, since
            -- the message is presumably built with `string.format`, and
            -- Lua 5.1 rejects tables for `%s`.
            nerv.error("nerv.ParamRepo objects expected, got %s",
                       tostring(repo))
        end
        for _, p in pairs(repo.params) do
            merged:add(p)
        end
    end
    return merged
end

--- Import parameters from NERV chunk files.
-- Each file is opened for reading, and the selected chunks are read and
-- added with `add`, so a duplicate id raises an error. The file is then
-- closed.
-- @param param_files an array of filenames of the files to be loaded from
-- @param gconf a table describing the computation state and providing
-- with some global settings
-- @param pids optional, the identifiers of the parameters to be imported,
-- given either as an array of ids (`{"a", "b"}`) or as a set keyed by id
-- (`{a = true, b = true}`); when omitted, every chunk is imported

function ParamRepo:import(param_files, gconf, pids)
    if type(param_files) ~= "table" then
        nerv.error("param file table is need")
    end
    -- Normalize `pids` into a set. This honors both the array form (as
    -- passed to `export` / `copy`) and the set-keyed-by-id form.
    local wanted
    if pids ~= nil then
        wanted = {}
        for k, v in pairs(pids) do
            if type(k) == "number" then
                wanted[v] = true    -- array form: the value is the id
            else
                wanted[k] = true    -- set form: the key is the id
            end
        end
    end
    for i = 1, #param_files do
        local pf = nerv.ChunkFile(param_files[i], "r")
        for cid in pairs(pf.metadata) do
            if wanted == nil or wanted[cid] then
                local p = pf:read_chunk(cid, gconf)
                if not nerv.is_type(p, "nerv.Param") then
                    nerv.error("param chunk is expected")
                end
                self:add(p)
            end
        end
        -- Release the file handle once the wanted chunks have been read,
        -- rather than leaving it to the garbage collector.
        pf:close()
    end
end

--- Export the parameter collection to a NERV chunk file.
-- If any requested id is unknown, an error is raised before the output
-- file is opened.
-- @param param_file the output filename
-- @param pids optional, an array of identifiers of the parameters to be
-- exported; when omitted, every parameter in the collection is written

function ParamRepo:export(param_file, pids)
    -- Resolve the requested parameters first. An unknown id then aborts
    -- before the output file is opened in "w" mode, instead of leaving a
    -- partially written, unclosed file behind.
    local selected = {}
    if pids == nil then
        for _, p in pairs(self.params) do
            selected[#selected + 1] = p
        end
    else
        for i, pid in ipairs(pids) do
            selected[i] = self:get_param(pid)
        end
    end
    -- Keep the handle local; a bare assignment would create a global.
    local cf = nerv.ChunkFile(param_file, "w")
    for _, p in ipairs(selected) do
        cf:write_chunk(p)
    end
    cf:close()
end

--- Test whether the collection has a parameter.
-- @param pid the identifier to be tested
-- @return `true` if a parameter with the identifier is stored, `false`
-- otherwise

function ParamRepo:has_param(pid)
    local found = self.params[pid]
    return not (found == nil)
end

--- Retrieve the parameter by the identifier.
-- Raises an error if no parameter with the identifier is stored.
-- @param pid the identifier of the parameter to be retrieved
-- @return the retrieved parameter

function ParamRepo:get_param(pid)
    local found = self.params[pid]
    if found == nil then
        nerv.error("param with id %s not found", pid)
    end
    return found
end

--- Create a copy of the current collection.
-- Each selected parameter is copied with `p:copy`. The copier allocates a
-- new matrix of the target location's type (`gconf.mmat_type` for host,
-- `gconf.cumat_type` otherwise) and fills it via `copy_toh` / `copy_tod`.
-- @param loc_type optional, the storage location of the new copy; defaults
-- to the location of this collection
-- @param gconf a table describing the computation state and providing
-- with some global settings
-- @param pids optional, an array of identifiers of the parameters to be copied
-- @return a new `nerv.ParamRepo` holding the copied parameters

function ParamRepo:copy(loc_type, gconf, pids)
    -- Resolve the default before constructing the target, so the new repo's
    -- `loc_type` (and type checker) matches the matrices copied into it.
    -- Otherwise a device repo copied with `loc_type == nil` would yield a
    -- repo tagged ON_HOST that actually holds device matrices.
    if loc_type == nil then
        loc_type = self.loc_type
    end
    local target = nerv.ParamRepo(nil, loc_type)
    local copier
    if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        copier = self.make_copier(gconf.mmat_type, 'copy_toh')
    else
        copier = self.make_copier(gconf.cumat_type, 'copy_tod')
    end
    if pids == nil then
        for id, p in pairs(self.params) do
            target.params[id] = p:copy(copier)
        end
    else
        for _, pid in ipairs(pids) do
            target.params[pid] = self:get_param(pid):copy(copier)
        end
    end
    return target
end