diff options
 nerv/io/sgd_buffer.lua | 18 ++++++++++--------
 nerv/nn/layer_dag.lua  |  2 +-
 2 files changed, 11 insertions(+), 9 deletions(-)
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua
index 3f854f0..74c4934 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/sgd_buffer.lua
@@ -5,9 +5,7 @@ function SGDBuffer:__init(global_conf, buffer_conf)
     self.buffer_size = math.floor(buffer_conf.buffer_size /
                                     global_conf.batch_size) * global_conf.batch_size
     self.randomize = buffer_conf.randomize
-    if self.randomize == nil then
-        self.randomize = false
-    end
+    self.consume = buffer_conf.consume
     local cumat_type = global_conf.cumat_type
     if buffer_conf.use_gpu then
         self.mat_type = cumat_type
@@ -104,26 +102,30 @@ function SGDBuffer:get_data()
     local batch_size = self.gconf.batch_size
     if self.head >= self.tail then -- buffer is empty
         local t = os.clock()
-        if not self:saturate() then
+        if (not self:saturate()) and (not self.consume) then
             return nil -- the remaining data cannot build a batch
         end
+        if self.tail == self.head then
+            return nil -- nothing left
+        end
         nerv.info("%.3fs to fill the buffer", os.clock() - t)
     end
-    if self.head + batch_size > self.tail then
+    if self.head + batch_size > self.tail and (not self.consume) then
         return nil -- the remaining data cannot build a batch
     end
+    actual_batch_size = math.min(batch_size, self.tail - self.head)
     local res = {}
     for i, reader in ipairs(self.readers) do
         for id, buff in pairs(reader.buffs) do
-            local batch = self.gconf.cumat_type(batch_size, buff.width)
+            local batch = self.gconf.cumat_type(actual_batch_size, buff.width)
             if self.randomize then
                 self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
             else
-                self.copy_from(batch, buff.data, self.head, self.head + batch_size)
+                self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
             end
             res[id] = batch
         end
     end
-    self.head = self.head + batch_size
+    self.head = self.head + actual_batch_size
     return res
 end
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
index 25297c2..f69d31c 100644
--- a/nerv/nn/layer_dag.lua
+++ b/nerv/nn/layer_dag.lua
@@ -266,7 +266,7 @@ function DAGLayer:get_intermediate(id, port_type)
     if id == "<input>" or id == "<output>" then
         nerv.error("an actual real layer id is expected")
     end
-    local layer = layers[id]
+    local layer = self.layers[id]
     if layer == nil then
         nerv.error("layer id %s not found", id)
     end