aboutsummaryrefslogtreecommitdiff
path: root/nerv/io/frm_buffer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/io/frm_buffer.lua')
-rw-r--r--nerv/io/frm_buffer.lua151
1 files changed, 151 insertions, 0 deletions
diff --git a/nerv/io/frm_buffer.lua b/nerv/io/frm_buffer.lua
new file mode 100644
index 0000000..9761f16
--- /dev/null
+++ b/nerv/io/frm_buffer.lua
@@ -0,0 +1,151 @@
+local FrmBuffer = nerv.class("nerv.FrmBuffer", "nerv.DataBuffer")
+
+function FrmBuffer:__init(global_conf, buffer_conf)
+ self.gconf = global_conf
+ self.batch_size = buffer_conf.batch_size
+ self.buffer_size = math.floor(buffer_conf.buffer_size /
+ self.batch_size) * self.batch_size
+ self.randomize = buffer_conf.randomize
+ self.consume = buffer_conf.consume
+ local cumat_type = global_conf.cumat_type
+ if self.gconf.use_cpu then
+ self.output_mat_type = self.gconf.mmat_type
+ else
+ self.output_mat_type = self.gconf.cumat_type
+ end
+ if buffer_conf.use_gpu then
+ self.mat_type = cumat_type
+ if self.gconf.use_cpu then
+ -- gpu buffer -> cpu training
+ nerv.error("not implemeted")
+ else
+ -- gpu buffer -> gpu training
+ self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx
+ self.copy_from = cumat_type.copy_fromd
+ end
+ self.perm_gen = function (x)
+ return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x))
+ end
+ else
+ self.mat_type = global_conf.mmat_type
+ if self.gconf.use_cpu then
+ -- cpu buffer -> cpu training
+ self.copy_rows_from_by_idx = gconf.mmat_type.copy_rows_fromh_by_idx
+ self.copy_from = gconf.mmat_type.copy_fromh
+ else
+ -- cpu buffer -> gpu training
+ self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx
+ self.copy_from = cumat_type.copy_fromh
+ end
+ self.perm_gen = nerv.MMatrixFloat.perm_gen
+ end
+ self.copy_from_reader = self.mat_type.copy_fromh
+ self.head = 0
+ self.tail = 0
+ self.readers = {}
+ for i, reader_spec in ipairs(buffer_conf.readers) do
+ local buffs = {}
+ for id, width in pairs(reader_spec.data) do
+ buffs[id] = {data = self.mat_type(self.buffer_size, width),
+ leftover = nil,
+ width = width}
+ end
+ table.insert(self.readers, {buffs = buffs,
+ reader = reader_spec.reader,
+ tail = 0,
+ has_leftover = false})
+ end
+end
+
+function FrmBuffer:saturate()
+ local buffer_size = self.buffer_size
+ self.head = 0
+ self.tail = buffer_size
+ for i, reader in ipairs(self.readers) do
+ reader.tail = 0
+ if reader.has_leftover then
+ local lrow
+ for id, buff in pairs(reader.buffs) do
+ lrow = buff.leftover:nrow()
+ if lrow > buffer_size then
+ nerv.error("buffer size is too small to contain leftovers")
+ end
+ buff.data:copy_from(buff.leftover, 0, lrow)
+ buff.leftover = nil
+ end
+ nerv.info("buffer leftover: %d\n", lrow)
+ reader.tail = lrow
+ reader.has_leftover = false
+ end
+ while reader.tail < buffer_size do
+ local data = reader.reader:get_data()
+ if data == nil then
+ break
+ end
+ local drow = nil
+ for id, d in pairs(data) do
+ if drow == nil then
+ drow = d:nrow()
+ elseif d:nrow() ~= drow then
+ nerv.error("reader provides with inconsistent rows of data")
+ end
+ end
+ local remain = buffer_size - reader.tail
+ if drow > remain then
+ for id, buff in pairs(reader.buffs) do
+ local d = data[id]
+ if d == nil then
+ nerv.error("reader does not provide data for %s", id)
+ end
+ buff.leftover = self.mat_type(drow - remain,
+ buff.width)
+ self.copy_from_reader(buff.leftover, d, remain, drow)
+ end
+ drow = remain
+ reader.has_leftover = true
+ end
+ for id, buff in pairs(reader.buffs) do
+ self.copy_from_reader(buff.data, data[id], 0, drow, reader.tail)
+ end
+ reader.tail = reader.tail + drow
+ end
+ self.tail = math.min(self.tail, reader.tail)
+ end
+ self.rand_map = self.perm_gen(self.tail) -- generate shuffled index
+ collectgarbage("collect")
+ return self.tail >= self.batch_size
+end
+
+function FrmBuffer:get_data()
+ local batch_size = self.batch_size
+ if self.head >= self.tail then -- buffer is empty
+ local t = os.clock()
+ if (not self:saturate()) and (not self.consume) then
+ return nil -- the remaining data cannot build a batch
+ end
+ if self.tail == self.head then
+ return nil -- nothing left
+ end
+ nerv.info("%.3fs to fill the buffer", os.clock() - t)
+ end
+ if self.head + batch_size > self.tail and (not self.consume) then
+ return nil -- the remaining data cannot build a batch
+ end
+ actual_batch_size = math.min(batch_size, self.tail - self.head)
+ local res = {seq_length = table.vector(gconf.batch_size, 1),
+ new_seq = {},
+ data = {}}
+ for i, reader in ipairs(self.readers) do
+ for id, buff in pairs(reader.buffs) do
+ local batch = self.output_mat_type(actual_batch_size, buff.width)
+ if self.randomize then
+ self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
+ else
+ self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
+ end
+ res.data[id] = {batch}
+ end
+ end
+ self.head = self.head + actual_batch_size
+ return res
+end