aboutsummaryrefslogtreecommitdiff
path: root/nerv/io/sgd_buffer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/io/sgd_buffer.lua')
-rw-r--r--nerv/io/sgd_buffer.lua7
1 files changed, 4 insertions, 3 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua
index 3cf4f5a..d78f6d1 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/sgd_buffer.lua
@@ -2,8 +2,9 @@ local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")
function SGDBuffer:__init(global_conf, buffer_conf)
self.gconf = global_conf
+ self.batch_size = buffer_conf.batch_size
self.buffer_size = math.floor(buffer_conf.buffer_size /
- global_conf.batch_size) * global_conf.batch_size
+ self.batch_size) * self.batch_size
self.randomize = buffer_conf.randomize
self.consume = buffer_conf.consume
local cumat_type = global_conf.cumat_type
@@ -112,11 +113,11 @@ function SGDBuffer:saturate()
end
self.rand_map = self.perm_gen(self.tail) -- generate shuffled index
collectgarbage("collect")
- return self.tail >= self.gconf.batch_size
+ return self.tail >= self.batch_size
end
function SGDBuffer:get_data()
- local batch_size = self.gconf.batch_size
+ local batch_size = self.batch_size
if self.head >= self.tail then -- buffer is empty
local t = os.clock()
if (not self:saturate()) and (not self.consume) then