diff options
Diffstat (limited to 'fastnn/threads/test')
-rw-r--r-- | fastnn/threads/test/test-low-level.lua | 39 | ||||
-rw-r--r-- | fastnn/threads/test/test-threads-async.lua | 66 | ||||
-rw-r--r-- | fastnn/threads/test/test-threads-multiple.lua | 15 | ||||
-rw-r--r-- | fastnn/threads/test/test-threads-shared.lua | 111 | ||||
-rw-r--r-- | fastnn/threads/test/test-threads.lua | 20 |
5 files changed, 251 insertions, 0 deletions
diff --git a/fastnn/threads/test/test-low-level.lua b/fastnn/threads/test/test-low-level.lua new file mode 100644 index 0000000..aea31db --- /dev/null +++ b/fastnn/threads/test/test-low-level.lua @@ -0,0 +1,39 @@ +local t = require 'libthreads' + +local m = t.Mutex() +local c = t.Condition() +print(string.format('| main thread mutex: 0x%x', m:id())) +print(string.format('| main thread condition: 0x%x', c:id())) + +local code = string.format([[ + local t = require 'libthreads' + local c = t.Condition(%d) + print('|| hello from thread') + print(string.format('|| thread condition: 0x%%x', c:id())) + print('|| doing some stuff in thread...') + local x = 0 + for i=1,10000 do + for i=1,10000 do + x = x + math.sin(i) + end + x = x / 10000 + end + print(string.format('|| ...ok (x=%%f)', x)) + c:signal() +]], c:id()) + +print(code) + +local thread = t.Thread(code) + + +print('| waiting for thread...') +m:lock() +c:wait(m) +print('| thread finished!') + +thread:free() +m:free() +c:free() + +print('| done') diff --git a/fastnn/threads/test/test-threads-async.lua b/fastnn/threads/test/test-threads-async.lua new file mode 100644 index 0000000..68bcd35 --- /dev/null +++ b/fastnn/threads/test/test-threads-async.lua @@ -0,0 +1,66 @@ +local threads = require 'threads' + +local nthread = 4 +local njob = 100 + +local pool = threads.Threads( + nthread, + function(threadid) + print('starting a new thread/state number ' .. threadid) + end +) + + +local jobid = 0 +local result -- DO NOT put this in get +local function get() + + -- fill up the queue as much as we can + -- this will not block + while jobid < njob and pool:acceptsjob() do + jobid = jobid + 1 + + pool:addjob( + function(jobid) + print(string.format('thread ID %d is performing job %d', __threadid, jobid)) + return string.format("job output from job %d", jobid) + end, + + function(jobres) + result = jobres + end, + + jobid + ) + end + + -- is there still something to do? + if pool:hasjob() then + pool:dojob() -- yes? do it! + if pool:haserror() then -- check for errors + pool:synchronize() -- finish everything and throw error + end + return result + end + +end + +local jobdone = 0 +repeat + -- get something asynchronously + local res = get() + + -- do something with res (if any) + if res then + print(res) + jobdone = jobdone + 1 + end + +until not res -- until there is nothing remaining + +assert(jobid == 100) +assert(jobdone == 100) + +print('PASSED') + +pool:terminate() diff --git a/fastnn/threads/test/test-threads-multiple.lua b/fastnn/threads/test/test-threads-multiple.lua new file mode 100644 index 0000000..4429696 --- /dev/null +++ b/fastnn/threads/test/test-threads-multiple.lua @@ -0,0 +1,15 @@ +local threads = require 'threads' + +for i=1,1000 do + io.write(string.format('%04d.', tonumber(i))) + io.flush() + local pool = + threads.Threads( + 4, + function(threadid) + require 'torch' + end + ) +end +print() +print('PASSED') diff --git a/fastnn/threads/test/test-threads-shared.lua b/fastnn/threads/test/test-threads-shared.lua new file mode 100644 index 0000000..3d63851 --- /dev/null +++ b/fastnn/threads/test/test-threads-shared.lua @@ -0,0 +1,111 @@ +require 'torch' + +local threads = require 'threads' +local status, tds = pcall(require, 'tds') +tds = status and tds or nil + +local nthread = 4 +local njob = 10 +local msg = "hello from a satellite thread" + +threads.Threads.serialization('threads.sharedserialize') + +local x = {} +local xh = tds and tds.hash() or {} +local xs = {} +local z = tds and tds.hash() or {} +local D = 10 +local K = tds and 100000 or 100 -- good luck in non-shared (30M) +for i=1,njob do + x[i] = torch.ones(D) + xh[i] = torch.ones(D) + xs[i] = torch.FloatStorage(D):fill(1) + for j=1,K do + z[(i-1)*K+j] = "blah" .. i .. j + end +end +collectgarbage() +collectgarbage() + +print('GO') + +local pool = threads.Threads( + nthread, + function(threadIdx) + pcall(require, 'tds') + print('starting a new thread/state number:', threadIdx) + gmsg = msg -- we copy here an upvalue of the main thread + end +) + +local jobdone = 0 +for i=1,njob do + pool:addjob( + function() + assert(x[i]:sum() == D) + assert(xh[i]:sum() == D) + assert(torch.FloatTensor(xs[i]):sum() == D) + for j=1,K do + assert(z[(i-1)*K+j] == "blah" .. i .. j) + end + x[i]:add(1) + xh[i]:add(1) + torch.FloatTensor(xs[i]):add(1) + print(string.format('%s -- thread ID is %x', gmsg, __threadid)) + collectgarbage() + collectgarbage() + return __threadid + end, + + function(id) + print(string.format("task %d finished (ran on thread ID %x)", i, id)) + jobdone = jobdone + 1 + end + ) +end + +for i=1,njob do + pool:addjob( + function() + collectgarbage() + collectgarbage() + end + ) +end + +pool:synchronize() + +print(string.format('%d jobs done', jobdone)) + +pool:terminate() + +-- did we do the job in shared mode? +for i=1,njob do + assert(x[i]:sum() == 2*D) + assert(xh[i]:sum() == 2*D) + assert(torch.FloatTensor(xs[i]):sum() == 2*D) +end + +-- serialize and zero x +local str = torch.serialize(x) +local strh = torch.serialize(xh) +local strs = torch.serialize(xs) +for i=1,njob do + x[i]:zero() + xh[i]:zero() + xs[i]:fill(0) +end + +-- dude, check that unserialized x does not point on x +local y = torch.deserialize(str) +local yh = torch.deserialize(strh) +local ys = torch.deserialize(strs) +for i=1,njob do + assert(y[i]:sum() == 2*D) + assert(yh[i]:sum() == 2*D) + assert(torch.FloatTensor(ys[i]):sum() == 2*D) +end + +pool:terminate() + +print('PASSED') diff --git a/fastnn/threads/test/test-threads.lua b/fastnn/threads/test/test-threads.lua new file mode 100644 index 0000000..e49a381 --- /dev/null +++ b/fastnn/threads/test/test-threads.lua @@ -0,0 +1,20 @@ +local clib = require 'libthreads' + + +nthread = 1 + +str='ad;alkfakd;af' +code = [[ function func(str) print(str) end; print(str);]] +print(code) + +--thread = clib.Thread(code) + + +--thread:free() + + +require 'threads' + +tt = threads.Thread(code) + + |