author    Determinant <ted.sybil@gmail.com>  2015-08-31 19:01:08 +0800
committer Determinant <ted.sybil@gmail.com>  2015-08-31 19:01:08 +0800
commit    bfec52ea59aee722f6fed2aa60600c02f5f5e76b (patch)
tree      59db1cb85030d5ada3ef336ed230c5ff2b425a47
parent    447bd1ec6b7be07f22653874fc9db84c9b6a9f9a (diff)
parent    3721c74d56ffdea43851489617f33cd13b87ab76 (diff)

Merge branch 'master' into fastnn

Conflicts:
    nerv/io/sgd_buffer.lua

 nerv/io/sgd_buffer.lua | 20 +++++++++++---------
 nerv/nn/layer_dag.lua  |  2 +-
 2 files changed, 12 insertions(+), 10 deletions(-)
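The functional change merged in from master is a new `consume` flag for SGDBuffer: when set, get_data() no longer drops the tail of the data set that cannot fill a whole batch, but returns it as a final, smaller batch. Below is a minimal sketch of how the flag would be passed; the constructor call and the reader setup are assumptions, only the buffer_conf keys (buffer_size, randomize, use_gpu, consume) appear in this diff.

    -- hypothetical construction; `gconf` and `readers` are assumed to be
    -- configured elsewhere in the training script
    local buffer = nerv.SGDBuffer(gconf,
        {
            buffer_size = 81920,
            randomize = true,
            consume = true,   -- keep the last partial batch instead of dropping it
            readers = readers
        })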
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua
index 65d6da1..dd5d285 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/sgd_buffer.lua
@@ -5,9 +5,7 @@ function SGDBuffer:__init(global_conf, buffer_conf)
     self.buffer_size = math.floor(buffer_conf.buffer_size /
                                     global_conf.batch_size) * global_conf.batch_size
     self.randomize = buffer_conf.randomize
-    if self.randomize == nil then
-        self.randomize = false
-    end
+    self.consume = buffer_conf.consume
     local cumat_type = global_conf.cumat_type
     if buffer_conf.use_gpu then
         self.mat_type = cumat_type
@@ -104,26 +102,30 @@ function SGDBuffer:get_data()
     local batch_size = self.gconf.batch_size
     if self.head >= self.tail then -- buffer is empty
         local t = os.clock()
-        if not self:saturate() then
+        if (not self:saturate()) and (not self.consume) then
             return nil -- the remaining data cannot build a batch
         end
-        --nerv.info("%.3fs to fill the buffer", os.clock() - t)
+        if self.tail == self.head then
+            return nil -- nothing left
+        end
+        nerv.info("%.3fs to fill the buffer", os.clock() - t)
     end
-    if self.head + batch_size > self.tail then
+    if self.head + batch_size > self.tail and (not self.consume) then
         return nil -- the remaining data cannot build a batch
     end
+    actual_batch_size = math.min(batch_size, self.tail - self.head)
     local res = {}
     for i, reader in ipairs(self.readers) do
         for id, buff in pairs(reader.buffs) do
-            local batch = self.gconf.cumat_type(batch_size, buff.width)
+            local batch = self.gconf.cumat_type(actual_batch_size, buff.width)
             if self.randomize then
                 self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
             else
-                self.copy_from(batch, buff.data, self.head, self.head + batch_size)
+                self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
             end
             res[id] = batch
         end
     end
-    self.head = self.head + batch_size
+    self.head = self.head + actual_batch_size
     return res
 end
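With this change, get_data() computes actual_batch_size = min(batch_size, tail - head), so the last call before exhaustion can return matrices with fewer rows than gconf.batch_size when consume is set. Note that, as committed, actual_batch_size is assigned without `local`, so it ends up in the global environment. A sketch of a consumer loop under the new semantics (hypothetical usage, not part of this commit):

    -- get_data() returns a table of matrices keyed by buffer id, or nil once
    -- no further batch can be built
    while true do
        local data = buffer:get_data()
        if data == nil then
            break
        end
        -- with consume = true, the final data[id] may have fewer rows than
        -- gconf.batch_size; process it here
    end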
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
index a262a72..7843509 100644
--- a/nerv/nn/layer_dag.lua
+++ b/nerv/nn/layer_dag.lua
@@ -285,7 +285,7 @@ function DAGLayer:get_intermediate(id, port_type)
     if id == "<input>" or id == "<output>" then
         nerv.error("an actual real layer id is expected")
     end
-    local layer = layers[id]
+    local layer = self.layers[id]
     if layer == nil then
         nerv.error("layer id %s not found", id)
     end
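The layer_dag.lua hunk fixes DAGLayer:get_intermediate(), which indexed an undefined global `layers`; since reading an unset global yields nil in standard Lua, `layers[id]` would raise an error for every id, so the lookup has to go through self.layers. A hypothetical call, with made-up argument values:

    -- "affine0" and "out" are placeholder arguments; the accepted port_type
    -- strings are not shown in this diff
    local mat = dag:get_intermediate("affine0", "out")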