aboutsummaryrefslogblamecommitdiff
path: root/fastnn/lib/modelsync.lua
blob: a247562bc4217e259854e4d54e3c1e4d229e583b (plain) (tree)










































































































                                                                   
local C =  require 'libfastnn'
local T =  require 'libthreads'

-- Lua-side class registered as "fastnn.ModelSync". Each instance holds a
-- native fastnn.CModelSync object in `self.modelsync` (see __init), and the
-- methods below forward lock / buffer / counter operations to it.
local ModelSync = nerv.class("fastnn.ModelSync")

-- Re-export the native classes from the C modules under the `fastnn`
-- namespace so other code can reach them as fastnn.CModelSync / fastnn.Thread.
-- NOTE(review): assumes the global `fastnn` table already exists when this
-- file is loaded; it is not created here.
fastnn.CModelSync = C.CModelSync
fastnn.Thread = T.Thread


--- Construct a ModelSync bound to a native fastnn.CModelSync object.
-- @param shareid forwarded unchanged to fastnn.CModelSync; presumably it
--   selects which shared model buffer this object attaches to (confirm
--   against the C implementation).
function ModelSync:__init(shareid)
	local native = fastnn.CModelSync(shareid)
	self.modelsync = native
end

--- Size (in matrix elements) of the shared buffer needed for `nnet`.
-- Sums nrow() * nstride() over every parameter whose `trans` field is a
-- nerv.Matrix; other parameters are ignored. nstride() rather than ncol()
-- is used, presumably so any per-row padding is included to match how the
-- C side copies rows (confirm against CModelSync).
-- @param nnet network object exposing get_params() returning a table with
--   a `params` field
-- @return total element count
function ModelSync:GetDim(nnet)
	local total = 0
	for _, entry in pairs(nnet:get_params().params) do
		local mat = entry.trans
		if nerv.is_type(mat, "nerv.Matrix") then
			total = total + mat:nrow() * mat:nstride()
		end
	end
	return total
end


--- Allocate and seed the shared weight buffer from `nnet`, if not done yet.
-- Runs while holding the model lock (LockModel/UnLockModel). If the native
-- object reports it is not yet initialized, the buffer is sized with
-- GetDim(nnet), allocated with initbuffer, and filled via WeightFromD(nnet).
-- Otherwise nothing is done.
-- The model lock is always released, even if one of those steps raises;
-- the original error is then re-raised to the caller.
-- @param nnet network object exposing get_params()
function ModelSync:Initialize(nnet)
	self:LockModel()

	local ok, err = pcall(function()
		if not self.modelsync:initialized() then
			-- `local` here: the previous version leaked `dim` as a global.
			local dim = self:GetDim(nnet)
			self.modelsync:initbuffer(dim)
			self:WeightFromD(nnet)
		end
	end)

	-- Release before propagating any error so other users of the lock
	-- are not blocked forever by a failed initialization.
	self:UnLockModel()

	if not ok then
		error(err, 0)
	end
end

--- Collect the matrix-valued parameters of `nnet` in a deterministic order.
-- Only parameters whose `trans` field is a nerv.Matrix are kept, which is the
-- same filter GetDim uses. The result is sorted by tostring(param id),
-- because `pairs` order is unspecified and may differ between Lua states.
-- Without the sort, two users of the same shared buffer could lay
-- parameters out differently.
-- @param nnet network object exposing get_params()
-- @return array of nerv.Matrix objects
local function ordered_matrices(nnet)
	local params = nnet:get_params().params
	local entries = {}
	for pid, ref in pairs(params) do
		if nerv.is_type(ref.trans, "nerv.Matrix") then
			entries[#entries + 1] = { key = tostring(pid), mat = ref.trans }
		end
	end
	table.sort(entries, function(a, b)
		return a.key < b.key
	end)
	local mats = {}
	for i, entry in ipairs(entries) do
		mats[i] = entry.mat
	end
	return mats
end

--- Transfer every matrix parameter of `nnet` via CModelSync:weightfromd.
-- The buffer position is reset to 0 first. Matrices are then passed one
-- after another in the fixed order given by ordered_matrices.
-- Presumably this copies nnet's weights INTO the shared buffer, since
-- Initialize calls it right after initbuffer (confirm against the C side).
-- @param nnet network object exposing get_params()
function ModelSync:WeightFromD(nnet)
	self.modelsync:setpos(0)
	for _, mat in ipairs(ordered_matrices(nnet)) do
		self.modelsync:weightfromd(mat)
	end
end

--- Transfer every matrix parameter of `nnet` via CModelSync:weighttod.
-- Uses the same position reset and the same parameter order as
-- WeightFromD, so each matrix maps to the same buffer segment.
-- Presumably this copies the shared buffer back OUT into nnet's weights
-- (confirm against the C side).
-- @param nnet network object exposing get_params()
function ModelSync:WeightToD(nnet)
	self.modelsync:setpos(0)
	for _, mat in ipairs(ordered_matrices(nnet)) do
		self.modelsync:weighttod(mat)
	end
end

--- Take the "state" lock exposed by the native CModelSync object.
-- Thin forwarder to CModelSync:lockstate(); returns nothing.
function ModelSync:LockState()
	local sync = self.modelsync
	sync:lockstate()
end

--- Release the "state" lock via CModelSync:unlockstate(); returns nothing.
function ModelSync:UnLockState()
	local sync = self.modelsync
	sync:unlockstate()
end

--- Take the "model" lock exposed by the native CModelSync object.
-- Thin forwarder to CModelSync:lockmodel(); returns nothing.
function ModelSync:LockModel()
	local sync = self.modelsync
	sync:lockmodel()
end

--- Release the "model" lock via CModelSync:unlockmodel(); returns nothing.
function ModelSync:UnLockModel()
	local sync = self.modelsync
	sync:unlockmodel()
end

--- Return whatever CModelSync:id() reports for the native object.
function ModelSync:Id()
	local sync = self.modelsync
	return sync:id()
end

--- Return whatever CModelSync:threadcount() reports.
function ModelSync:ThreadCount()
	local sync = self.modelsync
	return sync:threadcount()
end

--- Forward to CModelSync:syncinc() and return its results unchanged.
function ModelSync:SyncInc()
	local sync = self.modelsync
	return sync:syncinc()
end

--- Forward to CModelSync:syncdec() and return its results unchanged.
function ModelSync:SyncDec()
	local sync = self.modelsync
	return sync:syncdec()
end