-- blob: a247562bc4217e259854e4d54e3c1e4d229e583b (plain) (tree)
-- Native bindings: `libfastnn` supplies CModelSync and `libthreads` supplies
-- Thread (both are read from the returned module tables below).
local C = require 'libfastnn'
local T = require 'libthreads'
-- Declare the Lua-side wrapper class under the name "fastnn.ModelSync".
local ModelSync = nerv.class("fastnn.ModelSync")
-- Publish the native types on the global `fastnn` table; the constructor in
-- this file reaches the native type as fastnn.CModelSync.
-- NOTE(review): assumes the global `fastnn` table already exists at this point
-- (presumably created by nerv.class or an earlier module) -- confirm.
fastnn.CModelSync = C.CModelSync
fastnn.Thread = T.Thread
--- Build a ModelSync that wraps a native CModelSync handle.
-- The handle is stored in `self.modelsync`; every other method of this class
-- delegates to it.
-- @param shareid value forwarded unchanged to fastnn.CModelSync
--   (presumably identifies the shared sync object -- confirm in libfastnn).
function ModelSync:__init(shareid)
    local handle = fastnn.CModelSync(shareid)
    self.modelsync = handle
end
--- Sum the storage footprint of every matrix parameter of a network.
-- Walks `nnet:get_params().params` and, for each entry whose `trans` field is
-- a nerv.Matrix, adds `nrow() * nstride()` to the total. Entries whose `trans`
-- is not a nerv.Matrix are skipped.
-- @param nnet object exposing `get_params()`
-- @return accumulated total (0 when no matrix parameter is found)
function ModelSync:GetDim(nnet)
    local total = 0
    for _, entry in pairs(nnet:get_params().params) do
        local mat = entry.trans
        if nerv.is_type(mat, "nerv.Matrix") then
            -- nstride (not the column count) is used, so any per-row padding
            -- reported by the matrix is included in the total.
            total = total + mat:nrow() * mat:nstride()
        end
    end
    return total
end
--- One-time setup of the native sync buffer from a network's parameters.
-- Under the model lock: if the native handle reports it is not yet
-- initialized, size its buffer with GetDim(nnet) and then call
-- WeightFromD(nnet). If it is already initialized, nothing is done.
--
-- The model lock is always released, even when one of the calls in between
-- raises; the original error value is then re-raised to the caller.
-- @param nnet object exposing `get_params()`
function ModelSync:Initialize(nnet)
    self:LockModel()
    -- Run the guarded section under pcall so that an error cannot leave the
    -- model lock held.
    local ok, err = pcall(function()
        if not self.modelsync:initialized() then
            -- `dim` is local: the previous version assigned a global here.
            local dim = self:GetDim(nnet)
            self.modelsync:initbuffer(dim)
            self:WeightFromD(nnet)
        end
    end)
    self:UnLockModel()
    if not ok then
        -- Level 0 keeps the original message/value without adding a new
        -- position prefix.
        error(err, 0)
    end
end
--- Pass every matrix parameter of `nnet` to the native `weightfromd` call.
-- Calls `setpos(0)` on the native handle first, then visits
-- `nnet:get_params().params` and hands each `trans` that is a nerv.Matrix to
-- `weightfromd`. Presumably this copies the network's weights into the shared
-- buffer -- confirm against libfastnn.
-- NOTE(review): iteration uses `pairs`, whose order is unspecified; this
-- looks like it relies on the same order as WeightToD -- confirm.
-- @param nnet object exposing `get_params()`
function ModelSync:WeightFromD(nnet)
    local entries = nnet:get_params().params
    local sync = self.modelsync
    sync:setpos(0)
    for _, entry in pairs(entries) do
        local mat = entry.trans
        if nerv.is_type(mat, "nerv.Matrix") then
            sync:weightfromd(mat)
        end
    end
end
--- Pass every matrix parameter of `nnet` to the native `weighttod` call.
-- Calls `setpos(0)` on the native handle first, then visits
-- `nnet:get_params().params` and hands each `trans` that is a nerv.Matrix to
-- `weighttod`. Presumably this copies the shared buffer back into the
-- network's weights -- confirm against libfastnn.
-- NOTE(review): iteration uses `pairs`, whose order is unspecified; this
-- looks like it relies on the same order as WeightFromD -- confirm.
-- @param nnet object exposing `get_params()`
function ModelSync:WeightToD(nnet)
    local entries = nnet:get_params().params
    local sync = self.modelsync
    sync:setpos(0)
    for _, entry in pairs(entries) do
        local mat = entry.trans
        if nerv.is_type(mat, "nerv.Matrix") then
            sync:weighttod(mat)
        end
    end
end
--- Delegate to the native handle's `lockstate` method; returns nothing.
-- Presumably acquires the lock guarding shared state -- confirm in libfastnn.
function ModelSync:LockState()
    local handle = self.modelsync
    handle:lockstate()
end
--- Delegate to the native handle's `unlockstate` method; returns nothing.
-- Presumably releases the lock taken by LockState -- confirm in libfastnn.
function ModelSync:UnLockState()
    local handle = self.modelsync
    handle:unlockstate()
end
--- Delegate to the native handle's `lockmodel` method; returns nothing.
-- Initialize calls this before touching the native buffer.
function ModelSync:LockModel()
    local handle = self.modelsync
    handle:lockmodel()
end
--- Delegate to the native handle's `unlockmodel` method; returns nothing.
-- Initialize calls this after it has finished with the native buffer.
function ModelSync:UnLockModel()
    local handle = self.modelsync
    handle:unlockmodel()
end
--- Return whatever the native handle's `id` method returns.
-- Presumably the share id this object was created with -- confirm in libfastnn.
function ModelSync:Id()
    local handle = self.modelsync
    return handle:id()
end
--- Return whatever the native handle's `threadcount` method returns.
-- Presumably the number of threads currently attached -- confirm in libfastnn.
function ModelSync:ThreadCount()
    local handle = self.modelsync
    return handle:threadcount()
end
--- Return whatever the native handle's `syncinc` method returns.
-- Presumably increments a shared counter -- confirm in libfastnn.
function ModelSync:SyncInc()
    local handle = self.modelsync
    return handle:syncinc()
end
--- Return whatever the native handle's `syncdec` method returns.
-- Presumably decrements a shared counter -- confirm in libfastnn.
function ModelSync:SyncDec()
    local handle = self.modelsync
    return handle:syncdec()
end
|