diff options
Diffstat (limited to 'fastnn/lib/modelsync.lua')
-rw-r--r-- | fastnn/lib/modelsync.lua | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/fastnn/lib/modelsync.lua b/fastnn/lib/modelsync.lua new file mode 100644 index 0000000..a247562 --- /dev/null +++ b/fastnn/lib/modelsync.lua @@ -0,0 +1,107 @@ + +local C = require 'libfastnn' +local T = require 'libthreads' + +local ModelSync = nerv.class("fastnn.ModelSync") + +fastnn.CModelSync = C.CModelSync +fastnn.Thread = T.Thread + + +function ModelSync:__init(shareid) + self.modelsync = fastnn.CModelSync(shareid) +-- print(self.modelsync.initbuffer) + --print(self.modelsync.setpos) + --print(self.modelsync.initialized) + --print(self.modelsync.weightfromd) +-- print(self.modelsync.weighttod) +-- print(self.modelsync.aaaa) +-- print(self.modelsync.bbbb) +-- print(self.modelsync.cccc) +end + +function ModelSync:GetDim(nnet) + + local repo = nnet:get_params() + local params = repo.params + local dim = 0 + for pid, ref in pairs(params) do + if nerv.is_type(ref.trans, "nerv.Matrix") then + dim = dim + ref.trans:nrow() * ref.trans:nstride() + end + end + + return dim +end + + +function ModelSync:Initialize(nnet) + + self:LockModel() + + if not self.modelsync:initialized() then + dim = self:GetDim(nnet) + self.modelsync:initbuffer(dim) + self:WeightFromD(nnet) + end + + self:UnLockModel() +end + +function ModelSync:WeightFromD(nnet) + local repo = nnet:get_params() + local params = repo.params + self.modelsync:setpos(0) + for pid, ref in pairs(params) do + if nerv.is_type(ref.trans, "nerv.Matrix") then + self.modelsync:weightfromd(ref.trans) + end + end +end + +function ModelSync:WeightToD(nnet) + local repo = nnet:get_params() + local params = repo.params + self.modelsync:setpos(0) + for pid, ref in pairs(params) do + if nerv.is_type(ref.trans, "nerv.Matrix") then + self.modelsync:weighttod(ref.trans) + end + end +end + +function ModelSync:LockState() + self.modelsync:lockstate() +end + +function ModelSync:UnLockState() + self.modelsync:unlockstate() +end + + +function ModelSync:LockModel() + self.modelsync:lockmodel() +end + + +function ModelSync:UnLockModel() + self.modelsync:unlockmodel() +end + +function ModelSync:Id() + return self.modelsync:id() +end + +function ModelSync:ThreadCount() + return self.modelsync:threadcount() +end + +function ModelSync:SyncInc() + return self.modelsync:syncinc() +end + +function ModelSync:SyncDec() + return self.modelsync:syncdec() +end + + |