aboutsummaryrefslogtreecommitdiff
path: root/nerv/io
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2015-08-31 18:59:22 +0800
committerDeterminant <ted.sybil@gmail.com>2015-08-31 18:59:22 +0800
commit3721c74d56ffdea43851489617f33cd13b87ab76 (patch)
treeeefc73df5c2dc535155d137f544f64ad0409bb6c /nerv/io
parentcad144243b898a7bed91c18572bf42944e9db3b3 (diff)
...
Diffstat (limited to 'nerv/io')
-rw-r--r--nerv/io/sgd_buffer.lua18
1 files changed, 10 insertions, 8 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua
index 3f854f0..74c4934 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/sgd_buffer.lua
@@ -5,9 +5,7 @@ function SGDBuffer:__init(global_conf, buffer_conf)
self.buffer_size = math.floor(buffer_conf.buffer_size /
global_conf.batch_size) * global_conf.batch_size
self.randomize = buffer_conf.randomize
- if self.randomize == nil then
- self.randomize = false
- end
+ self.consume = buffer_conf.consume
local cumat_type = global_conf.cumat_type
if buffer_conf.use_gpu then
self.mat_type = cumat_type
@@ -104,26 +102,30 @@ function SGDBuffer:get_data()
local batch_size = self.gconf.batch_size
if self.head >= self.tail then -- buffer is empty
local t = os.clock()
- if not self:saturate() then
+ if (not self:saturate()) and (not self.consume) then
return nil -- the remaining data cannot build a batch
end
+ if self.tail == self.head then
+ return nil -- nothing left
+ end
nerv.info("%.3fs to fill the buffer", os.clock() - t)
end
- if self.head + batch_size > self.tail then
+ if self.head + batch_size > self.tail and (not self.consume) then
return nil -- the remaining data cannot build a batch
end
+ actual_batch_size = math.min(batch_size, self.tail - self.head)
local res = {}
for i, reader in ipairs(self.readers) do
for id, buff in pairs(reader.buffs) do
- local batch = self.gconf.cumat_type(batch_size, buff.width)
+ local batch = self.gconf.cumat_type(actual_batch_size, buff.width)
if self.randomize then
self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
else
- self.copy_from(batch, buff.data, self.head, self.head + batch_size)
+ self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
end
res[id] = batch
end
end
- self.head = self.head + batch_size
+ self.head = self.head + actual_batch_size
return res
end