diff options
Diffstat (limited to 'io/sgd_buffer.lua')
-rw-r--r-- | io/sgd_buffer.lua | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/io/sgd_buffer.lua b/io/sgd_buffer.lua index dadcf67..bf72744 100644 --- a/io/sgd_buffer.lua +++ b/io/sgd_buffer.lua @@ -4,6 +4,10 @@ function SGDBuffer:__init(global_conf, buffer_conf) self.gconf = global_conf self.buffer_size = math.floor(buffer_conf.buffer_size / global_conf.batch_size) * global_conf.batch_size + self.randomize = buffer_conf.randomize + if self.randomize == nil then + self.randomize = false + end self.head = 0 self.tail = 0 self.readers = {} @@ -35,7 +39,9 @@ function SGDBuffer:saturate() nerv.error("buffer size is too small to contain leftovers") end buff.data:copy_from(buff.leftover, 0, lrow) + buff.leftover = nil end + nerv.utils.printf("leftover: %d\n", lrow) reader.tail = lrow reader.has_leftover = false end @@ -73,6 +79,8 @@ function SGDBuffer:saturate() end self.tail = math.min(self.tail, reader.tail) end + self.rand_map = nerv.MMatrixInt.perm_gen(self.tail) -- generate shuffled index + collectgarbage("collect") return self.tail >= self.gconf.batch_size end @@ -90,7 +98,11 @@ function SGDBuffer:get_data() for i, reader in ipairs(self.readers) do for id, buff in pairs(reader.buffs) do local batch = self.gconf.cumat_type(batch_size, buff.width) - batch:copy_fromh(buff.data, self.head, self.head + batch_size) + if self.randomize then + batch:copy_rows_fromh_by_idx(buff.data, self.rand_map, self.head) + else + batch:copy_fromh(buff.data, self.head, self.head + batch_size) + end res[id] = batch end end |