aboutsummaryrefslogtreecommitdiff
path: root/fastnn/threads/test
diff options
context:
space:
mode:
authoruphantom <[email protected]>2015-08-28 17:41:14 +0800
committeruphantom <[email protected]>2015-08-28 17:41:14 +0800
commita68d3c982ed0dd4ef5bbc9e0c22b9ecf9565b924 (patch)
treebc59ef1a69b32276cc97454fbc3c881fc8c518cc /fastnn/threads/test
parent1a9f63e351582f54fec7817927168cb1dbb0c1d6 (diff)
fastnn version 1.0
Diffstat (limited to 'fastnn/threads/test')
-rw-r--r--fastnn/threads/test/test-low-level.lua39
-rw-r--r--fastnn/threads/test/test-threads-async.lua66
-rw-r--r--fastnn/threads/test/test-threads-multiple.lua15
-rw-r--r--fastnn/threads/test/test-threads-shared.lua111
-rw-r--r--fastnn/threads/test/test-threads.lua20
5 files changed, 251 insertions, 0 deletions
diff --git a/fastnn/threads/test/test-low-level.lua b/fastnn/threads/test/test-low-level.lua
new file mode 100644
index 0000000..aea31db
--- /dev/null
+++ b/fastnn/threads/test/test-low-level.lua
@@ -0,0 +1,39 @@
+local t = require 'libthreads'
+
+local m = t.Mutex()
+local c = t.Condition()
+print(string.format('| main thread mutex: 0x%x', m:id()))
+print(string.format('| main thread condition: 0x%x', c:id()))
+
+local code = string.format([[
+ local t = require 'libthreads'
+ local c = t.Condition(%d)
+ print('|| hello from thread')
+ print(string.format('|| thread condition: 0x%%x', c:id()))
+ print('|| doing some stuff in thread...')
+ local x = 0
+ for i=1,10000 do
+ for i=1,10000 do
+ x = x + math.sin(i)
+ end
+ x = x / 10000
+ end
+ print(string.format('|| ...ok (x=%%f)', x))
+ c:signal()
+]], c:id())
+
+print(code)
+
+local thread = t.Thread(code)
+
+
+print('| waiting for thread...')
+m:lock()
+c:wait(m)
+print('| thread finished!')
+
+thread:free()
+m:free()
+c:free()
+
+print('| done')
diff --git a/fastnn/threads/test/test-threads-async.lua b/fastnn/threads/test/test-threads-async.lua
new file mode 100644
index 0000000..68bcd35
--- /dev/null
+++ b/fastnn/threads/test/test-threads-async.lua
@@ -0,0 +1,66 @@
+local threads = require 'threads'
+
+local nthread = 4
+local njob = 100
+
+local pool = threads.Threads(
+ nthread,
+ function(threadid)
+ print('starting a new thread/state number ' .. threadid)
+ end
+)
+
+
+local jobid = 0
+local result -- DO NOT put this in get
+local function get()
+
+ -- fill up the queue as much as we can
+ -- this will not block
+ while jobid < njob and pool:acceptsjob() do
+ jobid = jobid + 1
+
+ pool:addjob(
+ function(jobid)
+ print(string.format('thread ID %d is performing job %d', __threadid, jobid))
+ return string.format("job output from job %d", jobid)
+ end,
+
+ function(jobres)
+ result = jobres
+ end,
+
+ jobid
+ )
+ end
+
+ -- is there still something to do?
+ if pool:hasjob() then
+ pool:dojob() -- yes? do it!
+ if pool:haserror() then -- check for errors
+ pool:synchronize() -- finish everything and throw error
+ end
+ return result
+ end
+
+end
+
+local jobdone = 0
+repeat
+ -- get something asynchronously
+ local res = get()
+
+ -- do something with res (if any)
+ if res then
+ print(res)
+ jobdone = jobdone + 1
+ end
+
+until not res -- until there is nothing remaining
+
+assert(jobid == 100)
+assert(jobdone == 100)
+
+print('PASSED')
+
+pool:terminate()
diff --git a/fastnn/threads/test/test-threads-multiple.lua b/fastnn/threads/test/test-threads-multiple.lua
new file mode 100644
index 0000000..4429696
--- /dev/null
+++ b/fastnn/threads/test/test-threads-multiple.lua
@@ -0,0 +1,15 @@
+local threads = require 'threads'
+
+for i=1,1000 do
+ io.write(string.format('%04d.', tonumber(i)))
+ io.flush()
+ local pool =
+ threads.Threads(
+ 4,
+ function(threadid)
+ require 'torch'
+ end
+ )
+end
+print()
+print('PASSED')
diff --git a/fastnn/threads/test/test-threads-shared.lua b/fastnn/threads/test/test-threads-shared.lua
new file mode 100644
index 0000000..3d63851
--- /dev/null
+++ b/fastnn/threads/test/test-threads-shared.lua
@@ -0,0 +1,111 @@
+require 'torch'
+
+local threads = require 'threads'
+local status, tds = pcall(require, 'tds')
+tds = status and tds or nil
+
+local nthread = 4
+local njob = 10
+local msg = "hello from a satellite thread"
+
+threads.Threads.serialization('threads.sharedserialize')
+
+local x = {}
+local xh = tds and tds.hash() or {}
+local xs = {}
+local z = tds and tds.hash() or {}
+local D = 10
+local K = tds and 100000 or 100 -- good luck in non-shared (30M)
+for i=1,njob do
+ x[i] = torch.ones(D)
+ xh[i] = torch.ones(D)
+ xs[i] = torch.FloatStorage(D):fill(1)
+ for j=1,K do
+ z[(i-1)*K+j] = "blah" .. i .. j
+ end
+end
+collectgarbage()
+collectgarbage()
+
+print('GO')
+
+local pool = threads.Threads(
+ nthread,
+ function(threadIdx)
+ pcall(require, 'tds')
+ print('starting a new thread/state number:', threadIdx)
+ gmsg = msg -- we copy here an upvalue of the main thread
+ end
+)
+
+local jobdone = 0
+for i=1,njob do
+ pool:addjob(
+ function()
+ assert(x[i]:sum() == D)
+ assert(xh[i]:sum() == D)
+ assert(torch.FloatTensor(xs[i]):sum() == D)
+ for j=1,K do
+ assert(z[(i-1)*K+j] == "blah" .. i .. j)
+ end
+ x[i]:add(1)
+ xh[i]:add(1)
+ torch.FloatTensor(xs[i]):add(1)
+ print(string.format('%s -- thread ID is %x', gmsg, __threadid))
+ collectgarbage()
+ collectgarbage()
+ return __threadid
+ end,
+
+ function(id)
+ print(string.format("task %d finished (ran on thread ID %x)", i, id))
+ jobdone = jobdone + 1
+ end
+ )
+end
+
+for i=1,njob do
+ pool:addjob(
+ function()
+ collectgarbage()
+ collectgarbage()
+ end
+ )
+end
+
+pool:synchronize()
+
+print(string.format('%d jobs done', jobdone))
+
+pool:terminate()
+
+-- did we do the job in shared mode?
+for i=1,njob do
+ assert(x[i]:sum() == 2*D)
+ assert(xh[i]:sum() == 2*D)
+ assert(torch.FloatTensor(xs[i]):sum() == 2*D)
+end
+
+-- serialize and zero x
+local str = torch.serialize(x)
+local strh = torch.serialize(xh)
+local strs = torch.serialize(xs)
+for i=1,njob do
+ x[i]:zero()
+ xh[i]:zero()
+ xs[i]:fill(0)
+end
+
+-- dude, check that unserialized x does not point on x
+local y = torch.deserialize(str)
+local yh = torch.deserialize(strh)
+local ys = torch.deserialize(strs)
+for i=1,njob do
+ assert(y[i]:sum() == 2*D)
+ assert(yh[i]:sum() == 2*D)
+ assert(torch.FloatTensor(ys[i]):sum() == 2*D)
+end
+
+pool:terminate()
+
+print('PASSED')
diff --git a/fastnn/threads/test/test-threads.lua b/fastnn/threads/test/test-threads.lua
new file mode 100644
index 0000000..e49a381
--- /dev/null
+++ b/fastnn/threads/test/test-threads.lua
@@ -0,0 +1,20 @@
+local clib = require 'libthreads'
+
+
+nthread = 1
+
+str='ad;alkfakd;af'
+code = [[ function func(str) print(str) end; print(str);]]
+print(code)
+
+--thread = clib.Thread(code)
+
+
+--thread:free()
+
+
+require 'threads'
+
+tt = threads.Thread(code)
+
+