aboutsummaryrefslogtreecommitdiff
path: root/fastnn/threads/test/test-threads-shared.lua
diff options
context:
space:
mode:
Diffstat (limited to 'fastnn/threads/test/test-threads-shared.lua')
-rw-r--r--fastnn/threads/test/test-threads-shared.lua111
1 files changed, 111 insertions, 0 deletions
diff --git a/fastnn/threads/test/test-threads-shared.lua b/fastnn/threads/test/test-threads-shared.lua
new file mode 100644
index 0000000..3d63851
--- /dev/null
+++ b/fastnn/threads/test/test-threads-shared.lua
@@ -0,0 +1,111 @@
+require 'torch'
+
+local threads = require 'threads'
+local status, tds = pcall(require, 'tds')
+tds = status and tds or nil
+
+local nthread = 4
+local njob = 10
+local msg = "hello from a satellite thread"
+
+threads.Threads.serialization('threads.sharedserialize')
+
+local x = {}
+local xh = tds and tds.hash() or {}
+local xs = {}
+local z = tds and tds.hash() or {}
+local D = 10
+local K = tds and 100000 or 100 -- good luck in non-shared (30M)
+for i=1,njob do
+ x[i] = torch.ones(D)
+ xh[i] = torch.ones(D)
+ xs[i] = torch.FloatStorage(D):fill(1)
+ for j=1,K do
+ z[(i-1)*K+j] = "blah" .. i .. j
+ end
+end
+collectgarbage()
+collectgarbage()
+
+print('GO')
+
+local pool = threads.Threads(
+ nthread,
+ function(threadIdx)
+ pcall(require, 'tds')
+ print('starting a new thread/state number:', threadIdx)
+ gmsg = msg -- we copy here an upvalue of the main thread
+ end
+)
+
+local jobdone = 0
+for i=1,njob do
+ pool:addjob(
+ function()
+ assert(x[i]:sum() == D)
+ assert(xh[i]:sum() == D)
+ assert(torch.FloatTensor(xs[i]):sum() == D)
+ for j=1,K do
+ assert(z[(i-1)*K+j] == "blah" .. i .. j)
+ end
+ x[i]:add(1)
+ xh[i]:add(1)
+ torch.FloatTensor(xs[i]):add(1)
+ print(string.format('%s -- thread ID is %x', gmsg, __threadid))
+ collectgarbage()
+ collectgarbage()
+ return __threadid
+ end,
+
+ function(id)
+ print(string.format("task %d finished (ran on thread ID %x)", i, id))
+ jobdone = jobdone + 1
+ end
+ )
+end
+
+for i=1,njob do
+ pool:addjob(
+ function()
+ collectgarbage()
+ collectgarbage()
+ end
+ )
+end
+
+pool:synchronize()
+
+print(string.format('%d jobs done', jobdone))
+
+pool:terminate()
+
+-- did we do the job in shared mode?
+for i=1,njob do
+ assert(x[i]:sum() == 2*D)
+ assert(xh[i]:sum() == 2*D)
+ assert(torch.FloatTensor(xs[i]):sum() == 2*D)
+end
+
+-- serialize and zero x
+local str = torch.serialize(x)
+local strh = torch.serialize(xh)
+local strs = torch.serialize(xs)
+for i=1,njob do
+ x[i]:zero()
+ xh[i]:zero()
+ xs[i]:fill(0)
+end
+
+-- dude, check that unserialized x does not point on x
+local y = torch.deserialize(str)
+local yh = torch.deserialize(strh)
+local ys = torch.deserialize(strs)
+for i=1,njob do
+ assert(y[i]:sum() == 2*D)
+ assert(yh[i]:sum() == 2*D)
+ assert(torch.FloatTensor(ys[i]):sum() == 2*D)
+end
+
+pool:terminate()
+
+print('PASSED')