local C = require 'libfastnn'
local T = require 'libthreads'

local ModelSync = nerv.class("fastnn.ModelSync")

fastnn.CModelSync = C.CModelSync
fastnn.Thread = T.Thread

function ModelSync:__init(shareid)
    self.modelsync = fastnn.CModelSync(shareid)
end

-- Total number of weight elements in the network, counting the row stride
-- so the shared buffer matches the matrix memory layout.
function ModelSync:GetDim(nnet)
    local repo = nnet:get_params()
    local params = repo.params
    local dim = 0
    for pid, ref in pairs(params) do
        if nerv.is_type(ref.trans, "nerv.Matrix") then
            dim = dim + ref.trans:nrow() * ref.trans:nstride()
        end
    end
    return dim
end

-- Allocate the shared buffer once (first thread to arrive) and seed it
-- with this thread's weights.
function ModelSync:Initialize(nnet)
    self:LockModel()
    if not self.modelsync:initialized() then
        local dim = self:GetDim(nnet)
        self.modelsync:initbuffer(dim)
        self:WeightFromD(nnet)
    end
    self:UnLockModel()
end

-- Copy this thread's weights into the shared buffer (weight "from" device).
function ModelSync:WeightFromD(nnet)
    local repo = nnet:get_params()
    local params = repo.params
    self.modelsync:setpos(0)
    for pid, ref in pairs(params) do
        if nerv.is_type(ref.trans, "nerv.Matrix") then
            self.modelsync:weightfromd(ref.trans)
        end
    end
end

-- Copy the shared buffer back into this thread's weights (weight "to" device).
function ModelSync:WeightToD(nnet)
    local repo = nnet:get_params()
    local params = repo.params
    self.modelsync:setpos(0)
    for pid, ref in pairs(params) do
        if nerv.is_type(ref.trans, "nerv.Matrix") then
            self.modelsync:weighttod(ref.trans)
        end
    end
end

function ModelSync:LockState()
    self.modelsync:lockstate()
end

function ModelSync:UnLockState()
    self.modelsync:unlockstate()
end

function ModelSync:LockModel()
    self.modelsync:lockmodel()
end

function ModelSync:UnLockModel()
    self.modelsync:unlockmodel()
end

function ModelSync:Id()
    return self.modelsync:id()
end

function ModelSync:ThreadCount()
    return self.modelsync:threadcount()
end

function ModelSync:SyncInc()
    return self.modelsync:syncinc()
end

function ModelSync:SyncDec()
    return self.modelsync:syncdec()
end
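
-- A minimal usage sketch of the methods defined above, kept commented out so
-- the module still loads cleanly. It assumes a `shareid` obtained from the
-- parent thread and a NERV network `nnet` exposing `get_params()`; both are
-- hypothetical here and created elsewhere in the training script.
--
--     local sync = fastnn.ModelSync(shareid)
--     sync:Initialize(nnet)        -- first thread allocates and seeds the buffer
--
--     sync:LockModel()
--     sync:WeightToD(nnet)         -- pull the shared weights into this thread's model
--     sync:UnLockModel()
--
--     -- ... run a mini-batch of training on nnet ...
--
--     sync:LockModel()
--     sync:WeightFromD(nnet)       -- push the updated weights back to the shared buffer
--     sync:UnLockModel()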