path: root/fastnn/threads/test/test-threads-shared.lua
blob: 3d63851df1ae4c440972b180a502a49e9f22e6fc
require 'torch'

local threads = require 'threads'
local status, tds = pcall(require, 'tds')
tds = status and tds or nil
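-- tds ("torch data structures") provides C-allocated containers such as
-- tds.hash that live outside the Lua heap, so they can be shared across
-- threads by pointer; without tds we fall back to plain Lua tables, which
-- have to be copied between threads instead.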

local nthread = 4
local njob = 10
local msg = "hello from a satellite thread"

threads.Threads.serialization('threads.sharedserialize')
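-- With 'threads.sharedserialize', torch tensors/storages and tds objects
-- are passed between threads by reference (the underlying C memory is
-- shared) instead of being deep-copied as with the default serializer.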

local x = {}
local xh = tds and tds.hash() or {}
local xs = {}
local z = tds and tds.hash() or {}
local D = 10
local K = tds and 100000 or 100  -- good luck in non-shared (30M)
for i=1,njob do
   x[i] = torch.ones(D)
   xh[i] = torch.ones(D)
   xs[i] = torch.FloatStorage(D):fill(1)
   for j=1,K do
      z[(i-1)*K+j] = "blah" .. i .. j
   end
end
collectgarbage()
collectgarbage()
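-- collectgarbage() twice: the first pass runs finalizers, the second
-- actually frees the finalized objects.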

print('GO')

local pool = threads.Threads(
   nthread,
   function(threadIdx)
      pcall(require, 'tds')
      print('starting a new thread/state number:', threadIdx)
      gmsg = msg -- copy an upvalue of the main thread into this thread's globals
   end
)
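-- The init callback above runs once in each worker thread's own Lua state;
-- 'gmsg' is created as a global there and remains visible to every job
-- later executed on that thread.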

local jobdone = 0
for i=1,njob do
   pool:addjob(
      function()
         assert(x[i]:sum() == D)
         assert(xh[i]:sum() == D)
         assert(torch.FloatTensor(xs[i]):sum() == D)
         for j=1,K do
            assert(z[(i-1)*K+j] == "blah" .. i .. j)
         end
         x[i]:add(1)
         xh[i]:add(1)
         torch.FloatTensor(xs[i]):add(1)
         print(string.format('%s -- thread ID is %x', gmsg, __threadid))
         collectgarbage()
         collectgarbage()
         return __threadid
      end,

      function(id)
         print(string.format("task %d finished (ran on thread ID %x)", i, id))
         jobdone = jobdone + 1
      end
   )
end
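-- addjob takes two functions: the first runs on a worker thread, and its
-- return values are passed to the second (the end-callback), which runs on
-- the main thread; that is why updating 'jobdone' there needs no locking.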

for i=1,njob do
   pool:addjob(
      function()
         collectgarbage()
         collectgarbage()
      end
   )
end
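-- The empty jobs above just force a garbage-collection pass on whichever
-- workers pick them up, releasing worker-side references to the shared data.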

pool:synchronize()
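-- synchronize() blocks until every queued job has run and every
-- end-callback has fired on the main thread.
-- Illustrative sanity check: all end-callbacks must have run by now.
assert(jobdone == njob, 'expected all jobs done after synchronize()')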

print(string.format('%d jobs done', jobdone))

pool:terminate()

-- did we do the job in shared mode?
for i=1,njob do
   assert(x[i]:sum() == 2*D)
   assert(xh[i]:sum() == 2*D)
   assert(torch.FloatTensor(xs[i]):sum() == 2*D)
end
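-- Each worker added 1 to every element in place, so the sums seen from the
-- main thread doubled from D to 2*D: the storages really were shared, not
-- copied per thread.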

-- serialize x, xh and xs, then zero the originals
local str = torch.serialize(x)
local strh = torch.serialize(xh)
local strs = torch.serialize(xs)
for i=1,njob do
   x[i]:zero()
   xh[i]:zero()
   xs[i]:fill(0)
end

-- check that the deserialized copies do not share storage with the originals
local y = torch.deserialize(str)
local yh = torch.deserialize(strh)
local ys = torch.deserialize(strs)
for i=1,njob do
   assert(y[i]:sum() == 2*D)
   assert(yh[i]:sum() == 2*D)
   assert(torch.FloatTensor(ys[i]):sum() == 2*D)
end
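-- torch.serialize/torch.deserialize always deep-copy, so y, yh and ys
-- still sum to 2*D even though the originals were zeroed above.
-- Illustrative check that the copies are independent of the originals:
assert(x[1]:sum() == 0 and y[1]:sum() == 2*D)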

pool:terminate()

print('PASSED')