aboutsummaryrefslogtreecommitdiff
path: root/nerv/io/frm_buffer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/io/frm_buffer.lua')
-rw-r--r--nerv/io/frm_buffer.lua36
1 files changed, 36 insertions, 0 deletions
diff --git a/nerv/io/frm_buffer.lua b/nerv/io/frm_buffer.lua
index 9761f16..45f73a0 100644
--- a/nerv/io/frm_buffer.lua
+++ b/nerv/io/frm_buffer.lua
@@ -1,5 +1,38 @@
+--- Implements a frame-level chopped and shuffled buffer which shall be used
+-- for acyclic feed forward NNs (`chunk_size = 1`).
+-- @author Ted Yin <ted.sybil@gmail.com>
+
+--- The class for a frame-level chopped and shuffled buffer
+-- which shall be used for acyclic feed forward NNs
+-- @type nerv.FrmBuffer
+
local FrmBuffer = nerv.class("nerv.FrmBuffer", "nerv.DataBuffer")
+--- The constructor.
+-- @param global_conf a table describing the computation state and providing
+-- with some global settings
+--
+-- The following fields in `global_conf` will be used:
+--
+-- * `use_cpu`: whether to provide with the chunks/"mini-batches" stored in the
+-- main memory on invocation of `get_data()`
+-- * `mmat_type`: the class used for creating matrices in CPU computation
+-- * `cumat_type` (if `use_cpu = false`): the class used for creating matrices
+-- in GPU computation
+--
+-- @param buffer_conf a table providing with settings dedicated for the buffer.
+-- Available fields includes:
+--
+-- * `readers`: an array of `nerv.DataReader` instances specifying the
+-- readers used to read data
+-- * `batch_size`: the number of rows for each batch matrix
+-- * `buffer_size`: the number of frames to be buffered and shuffled at once
+-- * `randomize`: shuffle the buffer after filled if true
+-- * `consume`: drop the last frames which cannot make up a full `batch_size`
+-- matrix if false
+-- * `use_gpu`: the buffer space will be allocated on the device (graphics
+-- card) if true
+
function FrmBuffer:__init(global_conf, buffer_conf)
self.gconf = global_conf
self.batch_size = buffer_conf.batch_size
@@ -116,6 +149,9 @@ function FrmBuffer:saturate()
return self.tail >= self.batch_size
end
+--- Get a batch group from the buffer.
+-- See `nerv.DataBuffer` for reference
+
function FrmBuffer:get_data()
local batch_size = self.batch_size
if self.head >= self.tail then -- buffer is empty