diff options
author | Qi Liu <[email protected]> | 2016-05-11 16:53:56 +0800 |
---|---|---|
committer | Qi Liu <[email protected]> | 2016-05-11 16:53:56 +0800 |
commit | e85a4af65640aef77378275d478c1ba8b06b785e (patch) | |
tree | d7c95e62c067ec510fd3ad8b546fc367805ab632 /nerv | |
parent | 4585970021f75d4c9e7154fc1681a80efa0f48ab (diff) |
seq buffer support shuffle
Diffstat (limited to 'nerv')
-rw-r--r-- | nerv/io/seq_buffer.lua | 91 | ||||
-rw-r--r-- | nerv/nn/trainer.lua | 3 |
2 files changed, 79 insertions, 15 deletions
diff --git a/nerv/io/seq_buffer.lua b/nerv/io/seq_buffer.lua
index 8cde1b3..b1c2a75 100644
--- a/nerv/io/seq_buffer.lua
+++ b/nerv/io/seq_buffer.lua
@@ -28,6 +28,8 @@ local SeqBuffer = nerv.class('nerv.SeqBuffer', 'nerv.DataBuffer')
 -- * `batch_size`: the number of rows for each batch matrix
 -- * `chunk_size`: the length of the BPTT context (number of batch
 -- matrices to provide upon each invocation of `get_data()`)
+-- * `buffer_size`: the number of frames to be buffered and shuffled at once (shuffle
+-- in the sequence level, not in the frame level)
 -- * `nn_act_default`: the default value to fill into the "holes" (non-data
 -- frames)
 
@@ -36,6 +38,8 @@ function SeqBuffer:__init(global_conf, buffer_conf)
 
     self.batch_size = buffer_conf.batch_size
     self.chunk_size = buffer_conf.chunk_size
+    self.buffer_size = buffer_conf.buffer_size
+    self.randomize = buffer_conf.randomize
     self.readers = {}
     for _, v in ipairs(buffer_conf.readers) do
         table.insert(self.readers, v.reader)
@@ -49,6 +53,11 @@ function SeqBuffer:__init(global_conf, buffer_conf)
     self.queue = {}
     self.head = 1
     self.tail = 0
+    self.offset = 1
+    self.buffer = {}
+    self.length = {}
+    self.index = {}
+    self.complete = false
 end
 
 function SeqBuffer:new_mini_batch()
@@ -62,25 +71,77 @@ function SeqBuffer:new_mini_batch()
     return res
 end
 
+local function random_shuffle(a)
+    for i = #a, 2, -1 do
+        local j = math.random(i)
+        a[i], a[j] = a[j], a[i]
+    end
+end
+
+function SeqBuffer:fill_buffer()
+    if self.complete then
+        return false
+    end
+    local t = os.clock()
+    self.buffer = {}
+    self.length = {}
+    local size = 0
+    while size < self.buffer_size do
+        local drow = nil
+        local data = {}
+        for i = 1, #self.readers do
+            local tmp = self.readers[i]:get_data()
+            if tmp == nil then
+                self.complete = true
+                break
+            end
+            for id, d in pairs(tmp) do
+                if drow == nil then
+                    drow = d:nrow()
+                elseif d:nrow() ~= drow then
+                    nerv.error('readers provides with inconsistent rows of data')
+                end
+                data[id] = d
+            end
+        end
+        if self.complete then
+            break
+        end
+        size = size + drow
+        table.insert(self.buffer, data)
+        table.insert(self.length, drow)
+    end
+    self.index = {}
+    for i = 1, #self.buffer do
+        self.index[i] = i
+    end
+    if self.randomize then
+        random_shuffle(self.index)
+    end
+    self.offset = 1
+    collectgarbage('collect')
+    nerv.info('%.3fs to fill the buffer', os.clock() - t)
+    return #self.buffer > 0
+end
+
+function SeqBuffer:get_buffered_data()
+    if self.offset > #self.buffer then
+        if not self:fill_buffer() then
+            return nil
+        end
+    end
+    local id = self.index[self.offset]
+    self.offset = self.offset + 1
+    return self.buffer[id], self.length[id]
+end
+
 function SeqBuffer:saturate(batch)
     if self.queue[self.head] ~= nil and self.queue[self.head].seq_length[batch] ~= 0 then
         return true
     end
-    local data = {}
-    local drow = nil
-    for i = 1, #self.readers do
-        local tmp = self.readers[i]:get_data()
-        if tmp == nil then
-            return false
-        end
-        for id, d in pairs(tmp) do
-            if drow == nil then
-                drow = d:nrow()
-            elseif d:nrow() ~= drow then
-                nerv.error('readers provides with inconsistent rows of data')
-            end
-            data[id] = d
-        end
+    local data, drow = self:get_buffered_data()
+    if data == nil then
+        return false
     end
     local offset = 0
     local head = self.head
diff --git a/nerv/nn/trainer.lua b/nerv/nn/trainer.lua
index 8357c10..a17b36c 100644
--- a/nerv/nn/trainer.lua
+++ b/nerv/nn/trainer.lua
@@ -77,9 +77,12 @@ function trainer:make_buffer(readers)
         })
     else
         return nerv.SeqBuffer(gconf, {
+            buffer_size = gconf.buffer_size,
             batch_size = gconf.batch_size,
             chunk_size = gconf.chunk_size,
+            randomize = gconf.randomize,
             readers = readers,
+            nn_act_default = gconf.nn_act_default,
         })
     end
 end