aboutsummaryrefslogtreecommitdiff
path: root/nerv/io/sgd_buffer.lua
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2015-08-28 13:21:52 +0800
committerDeterminant <ted.sybil@gmail.com>2015-08-28 13:21:52 +0800
commit1a9f63e351582f54fec7817927168cb1dbb0c1d6 (patch)
treec340b648c60d93b956be5956fa03233383e37e5d /nerv/io/sgd_buffer.lua
parent8bf9c7575ffeeabb3924e9e02a35afe187071fe2 (diff)
support gpu buffering
Diffstat (limited to 'nerv/io/sgd_buffer.lua')
-rw-r--r--nerv/io/sgd_buffer.lua34
1 files changed, 25 insertions, 9 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua
index f9d281c..3f854f0 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/sgd_buffer.lua
@@ -8,13 +8,29 @@ function SGDBuffer:__init(global_conf, buffer_conf)
if self.randomize == nil then
self.randomize = false
end
+ local cumat_type = global_conf.cumat_type
+ if buffer_conf.use_gpu then
+ self.mat_type = cumat_type
+ self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx
+ self.copy_from = cumat_type.copy_fromd
+ self.copy_from_reader = cumat_type.copy_fromh
+ self.perm_gen = function (x)
+ return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x))
+ end
+ else
+ self.mat_type = global_conf.mmat_type
+ self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx
+ self.copy_from = cumat_type.copy_fromh
+ self.perm_gen = nerv.MMatrixFloat.perm_gen
+ self.copy_from_reader = self.mat_type.copy_from
+ end
self.head = 0
self.tail = 0
self.readers = {}
for i, reader_spec in ipairs(buffer_conf.readers) do
local buffs = {}
for id, width in pairs(reader_spec.data) do
- buffs[id] = {data = global_conf.mmat_type(self.buffer_size, width),
+ buffs[id] = {data = self.mat_type(self.buffer_size, width),
leftover = nil,
width = width}
end
@@ -41,7 +57,7 @@ function SGDBuffer:saturate()
buff.data:copy_from(buff.leftover, 0, lrow)
buff.leftover = nil
end
- nerv.printf("buffer leftover: %d\n", lrow)
+ nerv.info("buffer leftover: %d\n", lrow)
reader.tail = lrow
reader.has_leftover = false
end
@@ -65,21 +81,21 @@ function SGDBuffer:saturate()
if d == nil then
nerv.error("reader does not provide data for %s", id)
end
- buff.leftover = self.gconf.mmat_type(drow - remain,
- buff.width)
- buff.leftover:copy_from(d, remain, drow)
+ buff.leftover = self.mat_type(drow - remain,
+ buff.width)
+ self.copy_from_reader(buff.leftover, d, remain, drow)
end
drow = remain
reader.has_leftover = true
end
for id, buff in pairs(reader.buffs) do
- buff.data:copy_from(data[id], 0, drow, reader.tail)
+ self.copy_from_reader(buff.data, data[id], 0, drow, reader.tail)
end
reader.tail = reader.tail + drow
end
self.tail = math.min(self.tail, reader.tail)
end
- self.rand_map = nerv.MMatrixInt.perm_gen(self.tail) -- generate shuffled index
+ self.rand_map = self.perm_gen(self.tail) -- generate shuffled index
collectgarbage("collect")
return self.tail >= self.gconf.batch_size
end
@@ -101,9 +117,9 @@ function SGDBuffer:get_data()
for id, buff in pairs(reader.buffs) do
local batch = self.gconf.cumat_type(batch_size, buff.width)
if self.randomize then
- batch:copy_rows_fromh_by_idx(buff.data, self.rand_map, self.head)
+ self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
else
- batch:copy_fromh(buff.data, self.head, self.head + batch_size)
+ self.copy_from(batch, buff.data, self.head, self.head + batch_size)
end
res[id] = batch
end