aboutsummaryrefslogtreecommitdiff
path: root/nerv/io/seq_buffer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/io/seq_buffer.lua')
-rw-r--r--nerv/io/seq_buffer.lua34
1 files changed, 34 insertions, 0 deletions
diff --git a/nerv/io/seq_buffer.lua b/nerv/io/seq_buffer.lua
index 029e7b8..65df617 100644
--- a/nerv/io/seq_buffer.lua
+++ b/nerv/io/seq_buffer.lua
@@ -1,5 +1,36 @@
+--- Implements a sequence-level chopped and shuffled buffer which
+-- shall be used for cyclic NNs.
+-- @author Qi Liu <liuq901@163.com>
+
+--- The class for a sequence-level chopped and shuffled buffer which
+-- shall be used for cyclic NNs.
+-- @type nerv.SeqBuffer
+
local SeqBuffer = nerv.class('nerv.SeqBuffer', 'nerv.DataBuffer')
+--- The constructor.
+-- @param global_conf a table describing the computation state and providing
+-- with some global settings
+--
+-- The following fields in `global_conf` will be used:
+--
+-- * `use_cpu`: whether to provide with the chunks/"mini-batches" stored in the
+-- main memory on invocation of `get_data()`
+-- * `mmat_type`: the class used for creating matrices in CPU computation
+-- * `cumat_type` (if `use_cpu = false`): the class used for creating matrices
+-- in GPU computation
+--
+-- @param buffer_conf a table providing with settings dedicated for the buffer.
+-- Available fields includes:
+--
+-- * `readers`: an array of `nerv.DataReader` instances specifying the
+-- readers used to read data
+-- * `batch_size`: the number of rows for each batch matrix
+-- * `chunk_size`: the length of the BPTT context (number of batch
+-- matrices to provide upon each invocation of `get_data()`)
+-- * `nn_act_default`: the default value to fill into the "holes" (non-data
+-- frames)
+
function SeqBuffer:__init(global_conf, buffer_conf)
self.gconf = global_conf
@@ -84,6 +115,9 @@ function SeqBuffer:saturate(batch)
return true
end
+--- Get a batch group from the buffer.
+-- See `nerv.DataBuffer` for reference
+
function SeqBuffer:get_data()
local has_data = false
for i = 1, self.batch_size do