aboutsummaryrefslogtreecommitdiff
path: root/fastnn/lib/modelsync.lua
blob: a247562bc4217e259854e4d54e3c1e4d229e583b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
local C =  require 'libfastnn'
local T =  require 'libthreads'

--- Lua-side wrapper class around the native fastnn.CModelSync object.
local ModelSync = nerv.class("fastnn.ModelSync")

-- Re-export the native bindings from libfastnn / libthreads under fastnn.
fastnn.CModelSync, fastnn.Thread = C.CModelSync, T.Thread


--- Construct the wrapper around a native model-sync object.
-- @param shareid identifier forwarded unchanged to fastnn.CModelSync
function ModelSync:__init(shareid)
	local native = fastnn.CModelSync(shareid)
	self.modelsync = native
end

--- Size of the shared buffer needed to hold nnet's matrix parameters.
-- Sums nrow() * nstride() over every param whose `trans` is a nerv.Matrix;
-- params of any other type are ignored. (Presumably an element count that
-- includes any row padding implied by the stride -- confirm against the
-- native initbuffer.)
-- @param nnet object exposing get_params() whose result has a `params` table
-- @return the accumulated size (0 if there are no matrix params)
function ModelSync:GetDim(nnet)
	local total = 0
	for _, param in pairs(nnet:get_params().params) do
		local mat = param.trans
		if nerv.is_type(mat, "nerv.Matrix") then
			total = total + mat:nrow() * mat:nstride()
		end
	end
	return total
end


--- Set up the shared weight buffer from nnet if the native object says it
-- has not been initialized yet; otherwise do nothing.
-- When initialization is needed, the buffer is sized with GetDim(nnet) and
-- then seeded via WeightFromD(nnet). The whole check-and-initialize runs
-- under the model lock, and the lock is released even if any step raises
-- (the error is then re-raised to the caller unchanged).
-- @param nnet object exposing get_params() with a `params` table
function ModelSync:Initialize(nnet)
	self:LockModel()

	local ok, err = pcall(function()
		if not self.modelsync:initialized() then
			-- `local` is required: previously this leaked a global `dim`.
			local dim = self:GetDim(nnet)
			self.modelsync:initbuffer(dim)
			self:WeightFromD(nnet)
		end
	end)

	-- Always release the lock; leaving it held on error would block every
	-- other holder of this shared model forever.
	self:UnLockModel()

	if not ok then
		error(err, 0)
	end
end

-- Collect nnet's nerv.Matrix parameters in a deterministic order.
-- The shared buffer is addressed positionally, so every reader and writer
-- must visit the matrices in the same sequence. `pairs` order is unspecified
-- and may differ between separately built param tables (e.g. in other
-- threads/processes), so keys are sorted by their string form instead.
local function ordered_matrix_params(nnet)
	local params = nnet:get_params().params
	local ids = {}
	for pid, ref in pairs(params) do
		if nerv.is_type(ref.trans, "nerv.Matrix") then
			ids[#ids + 1] = pid
		end
	end
	-- Compare uniformly via tostring so mixed key types still form a valid
	-- strict ordering for table.sort.
	table.sort(ids, function(a, b) return tostring(a) < tostring(b) end)
	local mats = {}
	for i, pid in ipairs(ids) do
		mats[i] = params[pid].trans
	end
	return mats
end

--- Pass every matrix parameter of nnet to the native weightfromd(), in a
-- deterministic order, starting from buffer position 0.
-- (Presumably copies the network's weights into the shared buffer; Initialize
-- uses it to seed a freshly allocated buffer -- confirm against libfastnn.)
-- @param nnet object exposing get_params() with a `params` table
function ModelSync:WeightFromD(nnet)
	local mats = ordered_matrix_params(nnet)
	self.modelsync:setpos(0)
	for _, mat in ipairs(mats) do
		self.modelsync:weightfromd(mat)
	end
end

--- Pass every matrix parameter of nnet to the native weighttod(), in the same
-- deterministic order as WeightFromD, starting from buffer position 0.
-- (Presumably copies the shared buffer back into the network's weights --
-- confirm against libfastnn.)
-- @param nnet object exposing get_params() with a `params` table
function ModelSync:WeightToD(nnet)
	local mats = ordered_matrix_params(nnet)
	self.modelsync:setpos(0)
	for _, mat in ipairs(mats) do
		self.modelsync:weighttod(mat)
	end
end

--- Delegate to the native lockstate() (presumably acquires the shared state
-- lock -- confirm in libfastnn). Returns nothing.
function ModelSync:LockState()
	local native = self.modelsync
	native:lockstate()
end

--- Delegate to the native unlockstate() (presumably releases the lock taken by
-- LockState -- confirm in libfastnn). Returns nothing.
function ModelSync:UnLockState()
	local native = self.modelsync
	native:unlockstate()
end


--- Delegate to the native lockmodel() (presumably acquires the model lock
-- guarding the shared weight buffer -- confirm in libfastnn). Returns nothing.
function ModelSync:LockModel()
	local native = self.modelsync
	native:lockmodel()
end

	
--- Delegate to the native unlockmodel() (presumably releases the lock taken by
-- LockModel -- confirm in libfastnn). Returns nothing.
function ModelSync:UnLockModel()
	local native = self.modelsync
	native:unlockmodel()
end

--- Return whatever the native id() returns (presumably an identifier for this
-- sync object or worker -- confirm in libfastnn).
function ModelSync:Id()
	local native = self.modelsync
	return native:id()
end

--- Return whatever the native threadcount() returns (presumably the number of
-- workers attached to the shared model -- confirm in libfastnn).
function ModelSync:ThreadCount()
	local native = self.modelsync
	return native:threadcount()
end

--- Return whatever the native syncinc() returns (presumably increments a
-- shared synchronisation counter -- confirm in libfastnn).
function ModelSync:SyncInc()
	local native = self.modelsync
	return native:syncinc()
end

--- Return whatever the native syncdec() returns (presumably decrements the
-- counter raised by SyncInc -- confirm in libfastnn).
function ModelSync:SyncDec()
	local native = self.modelsync
	return native:syncdec()
end