require 'torch' local threads = require 'threads' local status, tds = pcall(require, 'tds') tds = status and tds or nil local nthread = 4 local njob = 10 local msg = "hello from a satellite thread" threads.Threads.serialization('threads.sharedserialize') local x = {} local xh = tds and tds.hash() or {} local xs = {} local z = tds and tds.hash() or {} local D = 10 local K = tds and 100000 or 100 -- good luck in non-shared (30M) for i=1,njob do x[i] = torch.ones(D) xh[i] = torch.ones(D) xs[i] = torch.FloatStorage(D):fill(1) for j=1,K do z[(i-1)*K+j] = "blah" .. i .. j end end collectgarbage() collectgarbage() print('GO') local pool = threads.Threads( nthread, function(threadIdx) pcall(require, 'tds') print('starting a new thread/state number:', threadIdx) gmsg = msg -- we copy here an upvalue of the main thread end ) local jobdone = 0 for i=1,njob do pool:addjob( function() assert(x[i]:sum() == D) assert(xh[i]:sum() == D) assert(torch.FloatTensor(xs[i]):sum() == D) for j=1,K do assert(z[(i-1)*K+j] == "blah" .. i .. j) end x[i]:add(1) xh[i]:add(1) torch.FloatTensor(xs[i]):add(1) print(string.format('%s -- thread ID is %x', gmsg, __threadid)) collectgarbage() collectgarbage() return __threadid end, function(id) print(string.format("task %d finished (ran on thread ID %x)", i, id)) jobdone = jobdone + 1 end ) end for i=1,njob do pool:addjob( function() collectgarbage() collectgarbage() end ) end pool:synchronize() print(string.format('%d jobs done', jobdone)) pool:terminate() -- did we do the job in shared mode? for i=1,njob do assert(x[i]:sum() == 2*D) assert(xh[i]:sum() == 2*D) assert(torch.FloatTensor(xs[i]):sum() == 2*D) end -- serialize and zero x local str = torch.serialize(x) local strh = torch.serialize(xh) local strs = torch.serialize(xs) for i=1,njob do x[i]:zero() xh[i]:zero() xs[i]:fill(0) end -- dude, check that unserialized x does not point on x local y = torch.deserialize(str) local yh = torch.deserialize(strh) local ys = torch.deserialize(strs) for i=1,njob do assert(y[i]:sum() == 2*D) assert(yh[i]:sum() == 2*D) assert(torch.FloatTensor(ys[i]):sum() == 2*D) end pool:terminate() print('PASSED')