aboutsummaryrefslogblamecommitdiff
path: root/nerv/nn/param_repo.lua
blob: a9eb0bd46c2aaa42b6606b3f4945b7ee1a70e6cc (plain) (tree)
1
2
3
4
5
6
7
8
9
10



                                                                          
                                              
 



                                                      




                       



                                                                               
                                          
                    





















                                                                         

                                    
                                 




                                 


                                      


                                                                 
       
                         
                         
   
 


                                                   
                              


                                                 
                          

   




                                                                               
                                         



                                                                            
                                              




                                                                       
                       




               





                                                                                 
                                                   





                                                      




                                                         
                           
               

           

   



                                                                         

                                           





                                          





                                               



                                                         
                                 
                                  

   



                                                             
                                 




                                                     
   
 





                                                                               
                                              




















                                                                 
--- Implements a concept that stores a collection of parameter groups.

--- The class for storing a collection of parameter groups (`nerv.Param`).

local ParamRepo = nerv.class("nerv.ParamRepo")

--- Storage-location tags accepted wherever a `loc_type` is expected.
-- Each tag is a distinct empty table, so tags are told apart by identity.
-- @field ON_DEVICE parameters are stored in device memory (GPU RAM)
-- @field ON_HOST parameters are stored in host memory (main RAM)

ParamRepo.LOC_TYPES = {}
ParamRepo.LOC_TYPES.ON_DEVICE = {}
ParamRepo.LOC_TYPES.ON_HOST = {}

--- Construct a parameter repo.
-- @param plist optional, an array of parameters (`nerv.Param`) to place in
-- the repo initially; a later entry with the same id replaces an earlier one
-- @param loc_type optional, the storage location, one of
-- `nerv.ParamRepo.LOC_TYPES`; defaults to `ON_HOST`

function ParamRepo:__init(plist, loc_type)
    self.params = {}
    self.loc_type = loc_type or ParamRepo.LOC_TYPES.ON_HOST

    -- Factory used by `copy`: the returned closure allocates a `mat_type`
    -- matrix with the source's dimensions and fills it by calling the
    -- source matrix's method named `copy_method`.
    self.make_copier = function (mat_type, copy_method)
        return function (mat)
            local dst = mat_type(mat:nrow(), mat:ncol())
            mat[copy_method](mat, dst)
            return dst
        end
    end

    -- Host repos only accept `nerv.MMatrix` storage; every other location
    -- only accepts `nerv.CuMatrix`.
    local expected_type
    if self.loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        expected_type = "nerv.MMatrix"
    else
        expected_type = "nerv.CuMatrix"
    end
    self.checker = function (mat)
        if not nerv.is_type(mat, expected_type) then
            nerv.error("unexpected param type in repo specification")
        end
    end

    if plist == nil then
        return
    end
    for _, param in ipairs(plist) do
        -- hand the checker to the param so it can validate its storage
        param:check(self.checker)
        self.params[param.id] = param
    end
end

--- Add a parameter to the collection.
-- Raises an error (via `nerv.error`) if a parameter with the same id is
-- already stored. The parameter is also passed the repo's checker through
-- `p:check`, which is meant to reject storage of the wrong matrix type.
-- @param p the parameter (`nerv.Param`) to be added

function ParamRepo:add(p)
    local pid = p.id
    local params = self.params
    if params[pid] ~= nil then
        nerv.error("duplicate params with the same id: %s", pid)
    end
    p:check(self.checker)
    params[pid] = p
end

--- Remove a parameter from the collection.
-- Raises an error (via `nerv.error`) if no parameter with the id is stored.
-- @param pid the id of the parameter to be removed

function ParamRepo:remove(pid)
    if self.params[pid] == nil then
        nerv.error("param %s does not exist", pid)
    end
    self.params[pid] = nil
end

--- Merge two or more parameter collections.
-- @param repos an array of parameter repos (`nerv.ParamRepo`) to be merged
-- @param loc_type optional, the storage location of the merged repo (see
-- `nerv.ParamRepo.LOC_TYPES`). When omitted, the location shared by all of
-- `repos` is used; if they disagree, or `repos` is empty, it falls back to
-- the constructor default (`ON_HOST`).
-- @return the merged parameter collection (repo)
-- Raises an error if an element of `repos` is not a `nerv.ParamRepo`, if two
-- repos hold parameters with the same id, or if a parameter's storage does
-- not match the merged repo's location.

function ParamRepo.merge(repos, loc_type)
    -- Validate every input before reading any of them.
    for _, repo in ipairs(repos) do
        if not nerv.is_type(repo, "nerv.ParamRepo") then
            -- tostring: `repo` may be a table, which Lua 5.1's
            -- `string.format('%s', ...)` rejects
            nerv.error("nerv.ParamRepo objects expected, got %s",
                       tostring(repo))
        end
    end

    if loc_type == nil then
        -- Infer the location only when every repo agrees on it; otherwise
        -- leave it nil so the constructor applies its ON_HOST default.
        for i, repo in ipairs(repos) do
            if i == 1 then
                loc_type = repo.loc_type
            elseif repo.loc_type ~= loc_type then
                loc_type = nil
                break
            end
        end
    end

    local merged = nerv.ParamRepo(nil, loc_type)
    for _, repo in ipairs(repos) do
        for _, p in pairs(repo.params) do
            -- `add` rejects duplicate ids across repos
            merged:add(p)
        end
    end
    return merged
end

--- Import parameters from NERV chunk files.
-- @param param_files an array of filenames of the chunk files to be loaded
-- @param gconf a table describing the computation state and providing
-- with some global settings (passed through to `read_chunk`)
-- @param pids optional, the identifiers of the parameters to be imported,
-- given either as an array of ids or as a table keyed by id; when omitted,
-- every chunk in the files is imported
-- Raises an error if `param_files` is not a table, if a selected chunk is
-- not a `nerv.Param`, or if an imported id is already in the repo.

function ParamRepo:import(param_files, gconf, pids)
    if type(param_files) ~= "table" then
        nerv.error("param file table is need")
    end

    -- Support the array form of `pids` in addition to the keyed form
    -- tested by `pids[cid]`.
    local listed = {}
    if pids ~= nil then
        for _, pid in ipairs(pids) do
            listed[pid] = true
        end
    end
    local function wanted(cid)
        return pids == nil or pids[cid] ~= nil or listed[cid] ~= nil
    end

    for i = 1, #param_files do
        local pf = nerv.ChunkFile(param_files[i], "r")
        for cid in pairs(pf.metadata) do
            if wanted(cid) then
                local p = pf:read_chunk(cid, gconf)
                if not nerv.is_type(p, "nerv.Param") then
                    nerv.error("param chunk is expected")
                end
                self:add(p)
            end
        end
        -- release the file handle once the requested chunks are read
        pf:close()
    end
end

--- Export the parameter collection to a NERV chunk file.
-- @param param_file the output filename
-- @param pids optional, an array of identifiers of the parameters to be
-- exported; when omitted, every parameter in the repo is written
-- Raises an error if any id in `pids` is not in the repo; this happens
-- before the output file is opened, so the file is left untouched.

function ParamRepo:export(param_file, pids)
    -- Resolve every requested parameter first so an unknown id fails
    -- before the output file is created or truncated.
    local selected = {}
    if pids == nil then
        for _, p in pairs(self.params) do
            selected[#selected + 1] = p
        end
    else
        for i, pid in ipairs(pids) do
            selected[i] = self:get_param(pid)
        end
    end

    -- keep the handle local so it does not leak into the global environment
    local cf = nerv.ChunkFile(param_file, "w")
    for _, p in ipairs(selected) do
        cf:write_chunk(p)
    end
    cf:close()
end

--- Check whether a parameter is stored in the collection.
-- @param pid the identifier to look up
-- @return `true` if a parameter with this identifier is stored,
-- `false` otherwise

function ParamRepo:has_param(pid)
    if self.params[pid] == nil then
        return false
    end
    return true
end

--- Look up a parameter by its identifier.
-- Reports a missing identifier through `nerv.error`.
-- @param pid the identifier of the parameter to be retrieved
-- @return the stored parameter

function ParamRepo:get_param(pid)
    local found = self.params[pid]
    if found ~= nil then
        return found
    end
    -- unknown id; nerv.error presumably raises, so the return below is
    -- only reached if it does not
    nerv.error("param with id %s not found", pid)
    return nil
end

--- Create a copy of the current collection.
-- @param loc_type optional, the storage location of the new copy (see
-- `nerv.ParamRepo.LOC_TYPES`); defaults to this repo's `loc_type`
-- @param gconf a table describing the computation state and providing
-- with some global settings; its `mmat_type` (host copies) or `cumat_type`
-- (other locations) is used to allocate the copied matrices
-- @param pids optional, an array of identifiers of the parameters to be
-- copied; when omitted, every parameter is copied
-- @return a new `nerv.ParamRepo` holding the copies
-- Raises an error if an id in `pids` is not in the repo.

function ParamRepo:copy(loc_type, gconf, pids)
    -- Resolve the default before building the target so the new repo's
    -- `loc_type` (and thus its checker) matches the matrices copied into it.
    if loc_type == nil then
        loc_type = self.loc_type
    end
    local target = nerv.ParamRepo(nil, loc_type)

    local copier
    if loc_type == ParamRepo.LOC_TYPES.ON_HOST then
        copier = self.make_copier(gconf.mmat_type, 'copy_toh')
    else
        copier = self.make_copier(gconf.cumat_type, 'copy_tod')
    end

    if pids == nil then
        for id, p in pairs(self.params) do
            target.params[id] = p:copy(copier)
        end
    else
        for _, pid in ipairs(pids) do
            target.params[pid] = self:get_param(pid):copy(copier)
        end
    end
    return target
end