about | summary | refs | log | tree | commit | diff
path: root/io/sgd_buffer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'io/sgd_buffer.lua')
-rw-r--r-- io/sgd_buffer.lua 14
1 file changed, 13 insertions, 1 deletion
diff --git a/io/sgd_buffer.lua b/io/sgd_buffer.lua
index dadcf67..bf72744 100644
--- a/io/sgd_buffer.lua
+++ b/io/sgd_buffer.lua
@@ -4,6 +4,10 @@ function SGDBuffer:__init(global_conf, buffer_conf)
self.gconf = global_conf
self.buffer_size = math.floor(buffer_conf.buffer_size /
global_conf.batch_size) * global_conf.batch_size
+ self.randomize = buffer_conf.randomize
+ if self.randomize == nil then
+ self.randomize = false
+ end
self.head = 0
self.tail = 0
self.readers = {}
@@ -35,7 +39,9 @@ function SGDBuffer:saturate()
nerv.error("buffer size is too small to contain leftovers")
end
buff.data:copy_from(buff.leftover, 0, lrow)
+ buff.leftover = nil
end
+ nerv.utils.printf("leftover: %d\n", lrow)
reader.tail = lrow
reader.has_leftover = false
end
@@ -73,6 +79,8 @@ function SGDBuffer:saturate()
end
self.tail = math.min(self.tail, reader.tail)
end
+ self.rand_map = nerv.MMatrixInt.perm_gen(self.tail) -- generate shuffled index
+ collectgarbage("collect")
return self.tail >= self.gconf.batch_size
end
@@ -90,7 +98,11 @@ function SGDBuffer:get_data()
for i, reader in ipairs(self.readers) do
for id, buff in pairs(reader.buffs) do
local batch = self.gconf.cumat_type(batch_size, buff.width)
- batch:copy_fromh(buff.data, self.head, self.head + batch_size)
+ if self.randomize then
+ batch:copy_rows_fromh_by_idx(buff.data, self.rand_map, self.head)
+ else
+ batch:copy_fromh(buff.data, self.head, self.head + batch_size)
+ end
res[id] = batch
end
end