summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorQi Liu <[email protected]>2016-05-11 16:53:56 +0800
committerQi Liu <[email protected]>2016-05-11 16:53:56 +0800
commite85a4af65640aef77378275d478c1ba8b06b785e (patch)
treed7c95e62c067ec510fd3ad8b546fc367805ab632
parent4585970021f75d4c9e7154fc1681a80efa0f48ab (diff)
seq buffer support shuffle
-rw-r--r--nerv/io/seq_buffer.lua91
-rw-r--r--nerv/nn/trainer.lua3
2 files changed, 79 insertions, 15 deletions
diff --git a/nerv/io/seq_buffer.lua b/nerv/io/seq_buffer.lua
index 8cde1b3..b1c2a75 100644
--- a/nerv/io/seq_buffer.lua
+++ b/nerv/io/seq_buffer.lua
@@ -28,6 +28,8 @@ local SeqBuffer = nerv.class('nerv.SeqBuffer', 'nerv.DataBuffer')
-- * `batch_size`: the number of rows for each batch matrix
-- * `chunk_size`: the length of the BPTT context (number of batch
-- matrices to provide upon each invocation of `get_data()`)
+-- * `buffer_size`: the number of frames to be buffered and shuffled at once (shuffle
+-- in the sequence level, not in the frame level)
-- * `nn_act_default`: the default value to fill into the "holes" (non-data
-- frames)
@@ -36,6 +38,8 @@ function SeqBuffer:__init(global_conf, buffer_conf)
self.batch_size = buffer_conf.batch_size
self.chunk_size = buffer_conf.chunk_size
+ self.buffer_size = buffer_conf.buffer_size
+ self.randomize = buffer_conf.randomize
self.readers = {}
for _, v in ipairs(buffer_conf.readers) do
table.insert(self.readers, v.reader)
@@ -49,6 +53,11 @@ function SeqBuffer:__init(global_conf, buffer_conf)
self.queue = {}
self.head = 1
self.tail = 0
+ self.offset = 1
+ self.buffer = {}
+ self.length = {}
+ self.index = {}
+ self.complete = false
end
function SeqBuffer:new_mini_batch()
@@ -62,25 +71,77 @@ function SeqBuffer:new_mini_batch()
return res
end
+local function random_shuffle(a)
+ for i = #a, 2, -1 do
+ local j = math.random(i)
+ a[i], a[j] = a[j], a[i]
+ end
+end
+
+function SeqBuffer:fill_buffer()
+ if self.complete then
+ return false
+ end
+ local t = os.clock()
+ self.buffer = {}
+ self.length = {}
+ local size = 0
+ while size < self.buffer_size do
+ local drow = nil
+ local data = {}
+ for i = 1, #self.readers do
+ local tmp = self.readers[i]:get_data()
+ if tmp == nil then
+ self.complete = true
+ break
+ end
+ for id, d in pairs(tmp) do
+ if drow == nil then
+ drow = d:nrow()
+ elseif d:nrow() ~= drow then
+ nerv.error('readers provides with inconsistent rows of data')
+ end
+ data[id] = d
+ end
+ end
+ if self.complete then
+ break
+ end
+ size = size + drow
+ table.insert(self.buffer, data)
+ table.insert(self.length, drow)
+ end
+ self.index = {}
+ for i = 1, #self.buffer do
+ self.index[i] = i
+ end
+ if self.randomize then
+ random_shuffle(self.index)
+ end
+ self.offset = 1
+ collectgarbage('collect')
+ nerv.info('%.3fs to fill the buffer', os.clock() - t)
+ return #self.buffer > 0
+end
+
+function SeqBuffer:get_buffered_data()
+ if self.offset > #self.buffer then
+ if not self:fill_buffer() then
+ return nil
+ end
+ end
+ local id = self.index[self.offset]
+ self.offset = self.offset + 1
+ return self.buffer[id], self.length[id]
+end
+
function SeqBuffer:saturate(batch)
if self.queue[self.head] ~= nil and self.queue[self.head].seq_length[batch] ~= 0 then
return true
end
- local data = {}
- local drow = nil
- for i = 1, #self.readers do
- local tmp = self.readers[i]:get_data()
- if tmp == nil then
- return false
- end
- for id, d in pairs(tmp) do
- if drow == nil then
- drow = d:nrow()
- elseif d:nrow() ~= drow then
- nerv.error('readers provides with inconsistent rows of data')
- end
- data[id] = d
- end
+ local data, drow = self:get_buffered_data()
+ if data == nil then
+ return false
end
local offset = 0
local head = self.head
diff --git a/nerv/nn/trainer.lua b/nerv/nn/trainer.lua
index 8357c10..a17b36c 100644
--- a/nerv/nn/trainer.lua
+++ b/nerv/nn/trainer.lua
@@ -77,9 +77,12 @@ function trainer:make_buffer(readers)
})
else
return nerv.SeqBuffer(gconf, {
+ buffer_size = gconf.buffer_size,
batch_size = gconf.batch_size,
chunk_size = gconf.chunk_size,
+ randomize = gconf.randomize,
readers = readers,
+ nn_act_default = gconf.nn_act_default,
})
end
end