diff options
Diffstat (limited to 'nerv/io/sgd_buffer.lua')
-rw-r--r-- | nerv/io/sgd_buffer.lua | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua index 3cf4f5a..d78f6d1 100644 --- a/nerv/io/sgd_buffer.lua +++ b/nerv/io/sgd_buffer.lua @@ -2,8 +2,9 @@ local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer") function SGDBuffer:__init(global_conf, buffer_conf) self.gconf = global_conf + self.batch_size = buffer_conf.batch_size self.buffer_size = math.floor(buffer_conf.buffer_size / - global_conf.batch_size) * global_conf.batch_size + self.batch_size) * self.batch_size self.randomize = buffer_conf.randomize self.consume = buffer_conf.consume local cumat_type = global_conf.cumat_type @@ -112,11 +113,11 @@ function SGDBuffer:saturate() end self.rand_map = self.perm_gen(self.tail) -- generate shuffled index collectgarbage("collect") - return self.tail >= self.gconf.batch_size + return self.tail >= self.batch_size end function SGDBuffer:get_data() - local batch_size = self.gconf.batch_size + local batch_size = self.batch_size if self.head >= self.tail then -- buffer is empty local t = os.clock() if (not self:saturate()) and (not self.consume) then |