about | summary | refs | log | tree | commit | diff
path: root/nerv/io/seq_buffer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/io/seq_buffer.lua')
-rw-r--r-- nerv/io/seq_buffer.lua 108
1 files changed, 108 insertions, 0 deletions
diff --git a/nerv/io/seq_buffer.lua b/nerv/io/seq_buffer.lua
index e69de29..029e7b8 100644
--- a/nerv/io/seq_buffer.lua
+++ b/nerv/io/seq_buffer.lua
@@ -0,0 +1,108 @@
+-- Declare the 'nerv.SeqBuffer' class through nerv.class; 'nerv.DataBuffer' is
+-- passed as the second argument (presumably the parent class -- confirm
+-- against nerv.class). The methods below buffer whole sequences obtained from
+-- a set of readers and hand them out as chunked mini-batches.
+local SeqBuffer = nerv.class('nerv.SeqBuffer', 'nerv.DataBuffer')
+
+--- Set up an empty sequence buffer.
+-- @param global_conf global configuration table; stored as `self.gconf`, and
+--        its `mmat_type` field is stored as the constructor used for the
+--        buffered matrices
+-- @param buffer_conf buffer options: `batch_size`, `chunk_size`, `readers`
+--        (array of entries, each carrying a `reader` field) and the optional
+--        `nn_act_default` fill value (0 when omitted)
+function SeqBuffer:__init(global_conf, buffer_conf)
+    self.gconf = global_conf
+
+    self.batch_size = buffer_conf.batch_size
+    self.chunk_size = buffer_conf.chunk_size
+
+    -- keep only the reader objects themselves, in configuration order
+    local readers = {}
+    for _, entry in ipairs(buffer_conf.readers) do
+        readers[#readers + 1] = entry.reader
+    end
+    self.readers = readers
+
+    -- an explicit nil test is required: `or 0` would also replace `false`
+    local fill = buffer_conf.nn_act_default
+    if fill == nil then
+        fill = 0
+    end
+    self.nn_act_default = fill
+
+    self.mat_type = self.gconf.mmat_type
+
+    -- queue[head .. tail] holds the pending mini-batches; head > tail means
+    -- that nothing is queued
+    self.queue = {}
+    self.head = 1
+    self.tail = 0
+end
+
+--- Create an empty mini-batch record.
+-- @return a table with empty `data` and `new_seq` tables and a `seq_length`
+--         array that holds 0 for each of the `batch_size` streams
+function SeqBuffer:new_mini_batch()
+    local lengths = {}
+    for stream = 1, self.batch_size do
+        lengths[stream] = 0
+    end
+    return {
+        data = {},
+        new_seq = {},
+        seq_length = lengths,
+    }
+end
+
+--- Make sure stream `batch` has data in the mini-batch at the queue head.
+-- When the head mini-batch already holds rows for this stream nothing is
+-- read. Otherwise one sequence is pulled from every reader, cut into pieces of
+-- at most `chunk_size` rows, and the pieces are written into consecutive
+-- queued mini-batches (allocated on demand), starting at the queue head.
+--
+-- Zero-row sequences are skipped and the next sequence is read instead.
+-- Returning true without queuing anything would let get_data pop a missing
+-- mini-batch and push `head` past `tail`.
+--
+-- @param batch 1-based index of the stream to fill
+-- @return true when the head mini-batch holds data for the stream, false when
+--         a reader has no more data
+function SeqBuffer:saturate(batch)
+    local front = self.queue[self.head]
+    if front ~= nil and front.seq_length[batch] ~= 0 then
+        return true
+    end
+
+    -- Read one sequence from all readers, repeating while it is empty.
+    local data, drow
+    repeat
+        data, drow = {}, nil
+        for i = 1, #self.readers do
+            local tmp = self.readers[i]:get_data()
+            if tmp == nil then
+                return false
+            end
+            for id, d in pairs(tmp) do
+                if drow == nil then
+                    drow = d:nrow()
+                elseif d:nrow() ~= drow then
+                    nerv.error('readers provides with inconsistent rows of data')
+                end
+                data[id] = d
+            end
+        end
+        if drow == nil then
+            -- no reader configured, or the readers returned empty tables;
+            -- without this check the comparison below fails on a nil value
+            nerv.error('readers provide no data matrices')
+        end
+    until drow > 0
+
+    local offset = 0
+    local head = self.head
+    while offset < drow do
+        local last = math.min(offset + self.chunk_size, drow)
+        if head > self.tail then
+            -- ran past the queued mini-batches: append a fresh one
+            self.tail = self.tail + 1
+            self.queue[self.tail] = self:new_mini_batch()
+        end
+        local entry = self.queue[head]
+        entry.seq_length[batch] = last - offset
+        if offset == 0 then
+            -- first piece of the sequence: flag the stream as starting anew
+            table.insert(entry.new_seq, batch)
+        end
+        for id, d in pairs(data) do
+            local frames = entry.data[id]
+            if frames == nil then
+                frames = {}
+                entry.data[id] = frames
+            end
+            for i = offset + 1, last do
+                local chunk = i - offset
+                if frames[chunk] == nil then
+                    -- one (batch_size x ncol) matrix per position in the
+                    -- chunk, pre-filled with the default value
+                    frames[chunk] = self.mat_type(self.batch_size, d:ncol())
+                    frames[chunk]:fill(self.nn_act_default)
+                end
+                -- presumably copies rows [i - 1, i) of `d` (0-based) into
+                -- destination row `batch - 1` -- confirm against the matrix
+                -- copy_from API
+                frames[chunk]:copy_from(d, i - 1, i, batch - 1)
+            end
+        end
+        head = head + 1
+        offset = last
+    end
+    return true
+end
+
+--- Pop the next mini-batch from the queue.
+-- Every stream is topped up through saturate first. Unless `gconf.use_cpu`
+-- is set, each buffered matrix is replaced by the result of
+-- `gconf.cumat_type.new_from_host` before the mini-batch is returned.
+-- @return the mini-batch record at the queue head, or nil when saturate
+--         reported no data for any stream
+function SeqBuffer:get_data()
+    local any_stream = false
+    for stream = 1, self.batch_size do
+        -- saturate is the left operand so that it runs for every stream
+        any_stream = self:saturate(stream) or any_stream
+    end
+    if not any_stream then
+        return nil
+    end
+
+    local front = self.head
+    local mini_batch = self.queue[front]
+    self.queue[front] = nil
+    self.head = front + 1
+
+    if self.gconf.use_cpu then
+        return mini_batch
+    end
+    for _, frames in pairs(mini_batch.data) do
+        for t = 1, #frames do
+            frames[t] = self.gconf.cumat_type.new_from_host(frames[t])
+        end
+    end
+    return mini_batch
+end