aboutsummaryrefslogtreecommitdiff
path: root/nerv/io/seq_buffer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/io/seq_buffer.lua')
-rw-r--r--nerv/io/seq_buffer.lua92
1 files changed, 77 insertions, 15 deletions
diff --git a/nerv/io/seq_buffer.lua b/nerv/io/seq_buffer.lua
index 8cde1b3..5c60f64 100644
--- a/nerv/io/seq_buffer.lua
+++ b/nerv/io/seq_buffer.lua
@@ -28,6 +28,9 @@ local SeqBuffer = nerv.class('nerv.SeqBuffer', 'nerv.DataBuffer')
-- * `batch_size`: the number of rows for each batch matrix
-- * `chunk_size`: the length of the BPTT context (number of batch
-- matrices to provide upon each invocation of `get_data()`)
+-- * `buffer_size`: the number of frames to be buffered and shuffled at once (shuffle
+-- in the sequence level, not in the frame level)
+-- * `randomize`: shuffle the buffer after filled if true
-- * `nn_act_default`: the default value to fill into the "holes" (non-data
-- frames)
@@ -36,6 +39,8 @@ function SeqBuffer:__init(global_conf, buffer_conf)
self.batch_size = buffer_conf.batch_size
self.chunk_size = buffer_conf.chunk_size
+ self.buffer_size = buffer_conf.buffer_size
+ self.randomize = buffer_conf.randomize
self.readers = {}
for _, v in ipairs(buffer_conf.readers) do
table.insert(self.readers, v.reader)
@@ -49,6 +54,11 @@ function SeqBuffer:__init(global_conf, buffer_conf)
self.queue = {}
self.head = 1
self.tail = 0
+ self.offset = 1
+ self.buffer = {}
+ self.length = {}
+ self.index = {}
+ self.complete = false
end
function SeqBuffer:new_mini_batch()
@@ -62,25 +72,77 @@ function SeqBuffer:new_mini_batch()
return res
end
+local function random_shuffle(a)
+ for i = #a, 2, -1 do
+ local j = math.random(i)
+ a[i], a[j] = a[j], a[i]
+ end
+end
+
+function SeqBuffer:fill_buffer()
+ if self.complete then
+ return false
+ end
+ local t = os.clock()
+ self.buffer = {}
+ self.length = {}
+ local size = 0
+ while size < self.buffer_size do
+ local drow = nil
+ local data = {}
+ for i = 1, #self.readers do
+ local tmp = self.readers[i]:get_data()
+ if tmp == nil then
+ self.complete = true
+ break
+ end
+ for id, d in pairs(tmp) do
+ if drow == nil then
+ drow = d:nrow()
+ elseif d:nrow() ~= drow then
+ nerv.error('readers provides with inconsistent rows of data')
+ end
+ data[id] = d
+ end
+ end
+ if self.complete then
+ break
+ end
+ size = size + drow
+ table.insert(self.buffer, data)
+ table.insert(self.length, drow)
+ end
+ self.index = {}
+ for i = 1, #self.buffer do
+ self.index[i] = i
+ end
+ if self.randomize then
+ random_shuffle(self.index)
+ end
+ self.offset = 1
+ collectgarbage('collect')
+ nerv.info('%.3fs to fill the buffer', os.clock() - t)
+ return #self.buffer > 0
+end
+
+function SeqBuffer:get_buffered_data()
+ if self.offset > #self.buffer then
+ if not self:fill_buffer() then
+ return nil
+ end
+ end
+ local id = self.index[self.offset]
+ self.offset = self.offset + 1
+ return self.buffer[id], self.length[id]
+end
+
function SeqBuffer:saturate(batch)
if self.queue[self.head] ~= nil and self.queue[self.head].seq_length[batch] ~= 0 then
return true
end
- local data = {}
- local drow = nil
- for i = 1, #self.readers do
- local tmp = self.readers[i]:get_data()
- if tmp == nil then
- return false
- end
- for id, d in pairs(tmp) do
- if drow == nil then
- drow = d:nrow()
- elseif d:nrow() ~= drow then
- nerv.error('readers provides with inconsistent rows of data')
- end
- data[id] = d
- end
+ local data, drow = self:get_buffered_data()
+ if data == nil then
+ return false
end
local offset = 0
local head = self.head