aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/param_repo.lua
blob: 1e7a366eb8b9f7cdf8ad8a4f14af7e6c84af87ec (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
--- Repository of nerv.Param objects, indexed by parameter id.
local ParamRepo = nerv.class("nerv.ParamRepo")

--- Location tags for a repo.
-- Each tag is a distinct empty table, so tags are compared by identity.
-- The constructor uses ON_HOST to require nerv.MMatrix data; any other tag
-- requires nerv.CuMatrix data.
ParamRepo.LOC_TYPES = {ON_DEVICE = {}, ON_HOST = {}}

--- Construct a parameter repository.
-- @param plist optional list (sequence) of nerv.Param objects. Each is
--   validated against the repo's matrix type and registered under `p.id`.
--   Duplicate ids raise an error, exactly as `ParamRepo:add` does.
-- @param loc_type one of `ParamRepo.LOC_TYPES`; defaults to `ON_HOST`.
--   `ON_HOST` repos only accept `nerv.MMatrix` data; any other value
--   requires `nerv.CuMatrix`.
function ParamRepo:__init(plist, loc_type)
    self.params = {}
    self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST

    -- Build a validator (handed to Param:check) that raises for any
    -- matrix that is not of type `tname`.
    local function make_checker(tname)
        return function (mat)
            if not nerv.is_type(mat, tname) then
                nerv.error("unexpected param type in repo specification")
            end
        end
    end

    -- Factory used by ParamRepo:copy(). It returns a function that allocates
    -- a `mat_type` matrix with the same shape as its argument, fills it by
    -- calling the method named `copy_method` on the source, and returns it.
    self.make_copier = function (mat_type, copy_method)
        return function (mat)
            local target = mat_type(mat:nrow(), mat:ncol())
            mat[copy_method](mat, target)
            return target
        end
    end

    if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        self.checker = make_checker("nerv.MMatrix")
    else
        self.checker = make_checker("nerv.CuMatrix")
    end

    if plist ~= nil then
        for _, p in ipairs(plist) do
            -- Go through add() so a duplicate id is rejected instead of
            -- silently overwriting an earlier param.
            self:add(p.id, p)
        end
    end
end

--- Register param `p` under id `pid`.
-- Raises an error if `pid` is already taken. Otherwise validates `p` with
-- the repo's checker (which raises on a wrong matrix type) and stores it.
function ParamRepo:add(pid, p)
    local params, checker = self.params, self.checker
    if params[pid] ~= nil then
        nerv.error("duplicate params with the same id: %s", pid)
    end
    p:check(checker)
    params[pid] = p
end

--- Remove the param registered under `pid`.
-- @param pid id of the param to remove; raises an error if it is absent.
-- @param p unused; kept so existing callers remain compatible.
function ParamRepo:remove(pid, p)
    if self.params[pid] == nil then
        nerv.error("param %s does not exist", pid)
    end
    -- `self.params` is keyed by id, not a sequence, so the entry is deleted
    -- by assigning nil. table.remove() takes an integer position: it raises
    -- for string ids and would shift unrelated entries for numeric ones.
    self.params[pid] = nil
end

--- Merge several repos into a new one.
-- @param repos list of nerv.ParamRepo objects. Their params are added in
--   order, so an id present in more than one repo raises an error (via add).
-- @param loc_type location tag of the resulting repo (the constructor
--   defaults it to ON_HOST). Every param must pass that repo's type check.
-- @return a new nerv.ParamRepo referencing (not copying) the input params.
function ParamRepo.merge(repos, loc_type)
    local merged = nerv.ParamRepo(nil, loc_type)
    for _, repo in ipairs(repos) do
        if not nerv.is_type(repo, "nerv.ParamRepo") then
            -- tostring(): a raw table passed to a "%s" format can itself
            -- raise (string.format on Lua 5.1), hiding this message.
            nerv.error("nerv.ParamRepo objects expected, got %s", tostring(repo))
        end
        for pid, p in pairs(repo.params) do
            merged:add(pid, p)
        end
    end
    return merged
end

--- Load params from chunk files into this repo.
-- @param param_files list of chunk file paths to read.
-- @param gconf global config forwarded to ChunkFile:read_chunk().
-- @param pids optional filter. NOTE: unlike export()/copy(), which take a
--   list of ids, this is a set-like table tested with `pids[cid] ~= nil`
--   (e.g. `{some_id = true}`). nil loads every chunk.
-- Raises if `param_files` is not a table, if a chunk is not a nerv.Param,
-- or if an id already exists in the repo (via add()).
function ParamRepo:import(param_files, gconf, pids)
    if type(param_files) ~= "table" then
        nerv.error("param file table is need")
    end
    for i = 1, #param_files do
        local pf = nerv.ChunkFile(param_files[i], "r")
        for cid in pairs(pf.metadata) do
            if pids == nil or pids[cid] ~= nil then
                local p = pf:read_chunk(cid, gconf)
                if not nerv.is_type(p, "nerv.Param") then
                    nerv.error("param chunk is expected")
                end
                self:add(cid, p)
            end
        end
        -- Release the file once its wanted chunks have been read, instead
        -- of leaving one open handle per imported file.
        pf:close()
    end
end

--- Write params of this repo to a chunk file.
-- @param param_file path of the chunk file to create (opened in "w" mode).
-- @param pids optional list of ids to write, in order. A missing id raises
--   via get_param(). nil writes every param.
function ParamRepo:export(param_file, pids)
    -- Must be local: a bare assignment would create or clobber a global `cf`.
    local cf = nerv.ChunkFile(param_file, "w")
    if pids == nil then
        for _, p in pairs(self.params) do
            cf:write_chunk(p)
        end
    else
        for _, pid in ipairs(pids) do
            cf:write_chunk(self:get_param(pid))
        end
    end
    cf:close()
end

--- Report whether a param is registered under `pid`.
-- @return true if present, false otherwise.
function ParamRepo:has_param(pid)
    local found = self.params[pid]
    return found ~= nil
end

--- Look up the param registered under `pid`.
-- Calls nerv.error when no param with that id exists.
-- @return the stored nerv.Param.
function ParamRepo:get_param(pid)
    local param = self.params[pid]
    if param ~= nil then
        return param
    end
    nerv.error("param with id %s not found", pid)
    -- not reached when nerv.error raises; `param` is nil here
    return param
end

--- Copy params into a new repo, allocating fresh matrices via a copier.
-- @param loc_type location tag of the new repo. Defaults to this repo's own
--   `loc_type`.
-- @param gconf config providing `mmat_type` (used for ON_HOST) or
--   `cumat_type` (used otherwise) as the constructor for the new matrices.
-- @param pids optional list of ids to copy. A missing id raises via
--   get_param(). nil copies every param.
-- @return a new nerv.ParamRepo whose `loc_type` matches the copied data.
function ParamRepo:copy(loc_type, gconf, pids)
    -- Resolve the default before constructing the target. Otherwise the
    -- constructor falls back to ON_HOST even when data is copied with
    -- 'copy_tod' into `cumat_type` matrices, leaving the new repo's tag and
    -- checker out of sync with its contents.
    if loc_type == nil then
        loc_type = self.loc_type
    end
    local target = nerv.ParamRepo(nil, loc_type)
    local copier
    if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        copier = self.make_copier(gconf.mmat_type, 'copy_toh')
    else
        copier = self.make_copier(gconf.cumat_type, 'copy_tod')
    end
    if pids == nil then
        for id, p in pairs(self.params) do
            target.params[id] = p:copy(copier)
        end
    else
        for _, pid in ipairs(pids) do
            target.params[pid] = self:get_param(pid):copy(copier)
        end
    end
    return target
end