blob: a247562bc4217e259854e4d54e3c1e4d229e583b (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
|
-- Native bindings: libfastnn provides CModelSync, libthreads provides Thread.
local C = require("libfastnn")
local T = require("libthreads")

-- Lua-side wrapper class; each instance holds a native CModelSync object.
local ModelSync = nerv.class("fastnn.ModelSync")

-- Re-export the native types under the fastnn namespace (assigned after
-- nerv.class, which is what declares the "fastnn.ModelSync" class name).
fastnn.Thread = T.Thread
fastnn.CModelSync = C.CModelSync
--- Construct the wrapper around a native fastnn.CModelSync object.
-- @param shareid value forwarded unchanged to fastnn.CModelSync; presumably
--   identifies which shared model buffer to attach to (TODO confirm in
--   libfastnn).
function ModelSync:__init(shareid)
    local native = fastnn.CModelSync(shareid)
    self.modelsync = native
end
--- Number of elements needed to hold every matrix parameter of nnet.
-- Every entry of nnet:get_params().params whose `trans` field is a
-- nerv.Matrix contributes nrow() * nstride(). nstride() is presumably the
-- (possibly padded) row stride, so this sizes the storage, not nrow * ncol.
-- @param nnet network object exposing get_params()
-- @return summed element count; 0 when there are no matrix parameters
function ModelSync:GetDim(nnet)
    local total = 0
    for _, ref in pairs(nnet:get_params().params) do
        local mat = ref.trans
        if nerv.is_type(mat, "nerv.Matrix") then
            total = total + mat:nrow() * mat:nstride()
        end
    end
    return total
end
--- Size and fill the shared weight buffer from nnet if it is not set up yet.
-- Runs entirely under the model lock. When the native object reports
-- initialized() == false, the buffer is allocated with GetDim(nnet) elements
-- and then filled via WeightFromD(nnet). Otherwise nothing is done.
-- (Presumably initialized() turns true after initbuffer, making this a
-- one-time setup across sharers; confirm in libfastnn.)
-- The lock is always released, even when a step raises. The original error
-- is then re-raised unchanged.
-- @param nnet network whose matrix parameters are synchronized
function ModelSync:Initialize(nnet)
    self:LockModel()
    local ok, err = pcall(function()
        if not self.modelsync:initialized() then
            -- local: must not leak a `dim` global shared by all code
            local dim = self:GetDim(nnet)
            self.modelsync:initbuffer(dim)
            self:WeightFromD(nnet)
        end
    end)
    -- Unlock before propagating any failure; otherwise every other user of
    -- the shared model would block forever on LockModel().
    self:UnLockModel()
    if not ok then
        error(err, 0)
    end
end
--- Collect nnet's nerv.Matrix parameters in a deterministic order.
-- The shared buffer is laid out positionally: setpos(0), then one sequential
-- copy per matrix. Writers and readers therefore must visit the matrices in
-- exactly the same order. pairs() order is unspecified and can differ
-- between Lua states, so the parameter ids are sorted instead.
-- @param nnet network object exposing get_params()
-- @return array of matrices ordered by parameter id
local function ordered_matrices(nnet)
    local params = nnet:get_params().params
    local ids = {}
    for pid, ref in pairs(params) do
        if nerv.is_type(ref.trans, "nerv.Matrix") then
            ids[#ids + 1] = pid
        end
    end
    -- Strict ordering that also tolerates mixed key types:
    -- group by type name, then compare values (or their string form).
    table.sort(ids, function(a, b)
        local ta, tb = type(a), type(b)
        if ta ~= tb then
            return ta < tb
        end
        if ta == "string" or ta == "number" then
            return a < b
        end
        return tostring(a) < tostring(b)
    end)
    local mats = {}
    for i, pid in ipairs(ids) do
        mats[i] = params[pid].trans
    end
    return mats
end

--- Copy every matrix parameter of nnet into the shared buffer.
-- Resets the native position to 0, then calls weightfromd() on each matrix
-- in parameter-id order. weightfromd presumably copies the matrix into the
-- buffer at the current position and advances it (confirm in libfastnn).
-- @param nnet network object exposing get_params()
function ModelSync:WeightFromD(nnet)
    local native = self.modelsync
    native:setpos(0)
    for _, mat in ipairs(ordered_matrices(nnet)) do
        native:weightfromd(mat)
    end
end

--- Copy the shared buffer back into every matrix parameter of nnet.
-- Mirror of WeightFromD. It uses the same ordering, so each matrix reads
-- back the slice that WeightFromD wrote for it.
-- @param nnet network object exposing get_params()
function ModelSync:WeightToD(nnet)
    local native = self.modelsync
    native:setpos(0)
    for _, mat in ipairs(ordered_matrices(nnet)) do
        native:weighttod(mat)
    end
end
--- Call the native lockstate() on the wrapped CModelSync object.
function ModelSync:LockState()
    local native = self.modelsync
    native:lockstate()
end

--- Call the native unlockstate() on the wrapped CModelSync object.
function ModelSync:UnLockState()
    local native = self.modelsync
    native:unlockstate()
end

--- Call the native lockmodel() on the wrapped CModelSync object.
function ModelSync:LockModel()
    local native = self.modelsync
    native:lockmodel()
end

--- Call the native unlockmodel() on the wrapped CModelSync object.
function ModelSync:UnLockModel()
    local native = self.modelsync
    native:unlockmodel()
end
--- Return whatever the native id() reports.
function ModelSync:Id()
    local native = self.modelsync
    return native:id()
end

--- Return whatever the native threadcount() reports.
function ModelSync:ThreadCount()
    local native = self.modelsync
    return native:threadcount()
end

--- Call the native syncinc() and return its result(s).
function ModelSync:SyncInc()
    local native = self.modelsync
    return native:syncinc()
end

--- Call the native syncdec() and return its result(s).
function ModelSync:SyncDec()
    local native = self.modelsync
    return native:syncdec()
end
|