aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/io/sgd_buffer.lua18
-rw-r--r--nerv/nn/layer_dag.lua2
2 files changed, 11 insertions, 9 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua
index 3f854f0..74c4934 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/sgd_buffer.lua
@@ -5,9 +5,7 @@ function SGDBuffer:__init(global_conf, buffer_conf)
self.buffer_size = math.floor(buffer_conf.buffer_size /
global_conf.batch_size) * global_conf.batch_size
self.randomize = buffer_conf.randomize
- if self.randomize == nil then
- self.randomize = false
- end
+ self.consume = buffer_conf.consume
local cumat_type = global_conf.cumat_type
if buffer_conf.use_gpu then
self.mat_type = cumat_type
@@ -104,26 +102,30 @@ function SGDBuffer:get_data()
local batch_size = self.gconf.batch_size
if self.head >= self.tail then -- buffer is empty
local t = os.clock()
- if not self:saturate() then
+ if (not self:saturate()) and (not self.consume) then
return nil -- the remaining data cannot build a batch
end
+ if self.tail == self.head then
+ return nil -- nothing left
+ end
nerv.info("%.3fs to fill the buffer", os.clock() - t)
end
- if self.head + batch_size > self.tail then
+ if self.head + batch_size > self.tail and (not self.consume) then
return nil -- the remaining data cannot build a batch
end
+ actual_batch_size = math.min(batch_size, self.tail - self.head)
local res = {}
for i, reader in ipairs(self.readers) do
for id, buff in pairs(reader.buffs) do
- local batch = self.gconf.cumat_type(batch_size, buff.width)
+ local batch = self.gconf.cumat_type(actual_batch_size, buff.width)
if self.randomize then
self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
else
- self.copy_from(batch, buff.data, self.head, self.head + batch_size)
+ self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
end
res[id] = batch
end
end
- self.head = self.head + batch_size
+ self.head = self.head + actual_batch_size
return res
end
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
index 25297c2..f69d31c 100644
--- a/nerv/nn/layer_dag.lua
+++ b/nerv/nn/layer_dag.lua
@@ -266,7 +266,7 @@ function DAGLayer:get_intermediate(id, port_type)
if id == "<input>" or id == "<output>" then
nerv.error("an actual real layer id is expected")
end
- local layer = layers[id]
+ local layer = self.layers[id]
if layer == nil then
nerv.error("layer id %s not found", id)
end