diff options
Diffstat (limited to 'nerv/io/frm_buffer.lua')
-rw-r--r-- | nerv/io/frm_buffer.lua | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/nerv/io/frm_buffer.lua b/nerv/io/frm_buffer.lua index 9761f16..45f73a0 100644 --- a/nerv/io/frm_buffer.lua +++ b/nerv/io/frm_buffer.lua @@ -1,5 +1,38 @@ +--- Implements a frame-level chopped and shuffled buffer which shall be used +-- for acyclic feed forward NNs (`chunk_size = 1`). +-- @author Ted Yin <ted.sybil@gmail.com> + +--- The class for a frame-level chopped and shuffled buffer +-- which shall be used for acyclic feed forward NNs +-- @type nerv.FrmBuffer + local FrmBuffer = nerv.class("nerv.FrmBuffer", "nerv.DataBuffer") +--- The constructor. +-- @param global_conf a table describing the computation state and providing +-- with some global settings +-- +-- The following fields in `global_conf` will be used: +-- +-- * `use_cpu`: whether to provide with the chunks/"mini-batches" stored in the +-- main memory on invocation of `get_data()` +-- * `mmat_type`: the class used for creating matrices in CPU computation +-- * `cumat_type` (if `use_cpu = false`): the class used for creating matrices +-- in GPU computation +-- +-- @param buffer_conf a table providing with settings dedicated for the buffer. +-- Available fields includes: +-- +-- * `readers`: an array of `nerv.DataReader` instances specifying the +-- readers used to read data +-- * `batch_size`: the number of rows for each batch matrix +-- * `buffer_size`: the number of frames to be buffered and shuffled at once +-- * `randomize`: shuffle the buffer after filled if true +-- * `consume`: drop the last frames which cannot make up a full `batch_size` +-- matrix if false +-- * `use_gpu`: the buffer space will be allocated on the device (graphics +-- card) if true + function FrmBuffer:__init(global_conf, buffer_conf) self.gconf = global_conf self.batch_size = buffer_conf.batch_size @@ -116,6 +149,9 @@ function FrmBuffer:saturate() return self.tail >= self.batch_size end +--- Get a batch group from the buffer. +-- See `nerv.DataBuffer` for reference + function FrmBuffer:get_data() local batch_size = self.batch_size if self.head >= self.tail then -- buffer is empty |