aboutsummaryrefslogtreecommitdiff
path: root/nerv/io/sgd_buffer.lua
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2016-03-16 17:53:39 +0800
committerDeterminant <ted.sybil@gmail.com>2016-03-16 17:53:39 +0800
commit289ac7f4b6e88b935da5c891e1efcf91fc047403 (patch)
treed4fc3a4fc20f2d5908624b3f6587ecd57966d719 /nerv/io/sgd_buffer.lua
parent07fc1e2794027d44c255e1062c4491346b101a08 (diff)
merge seq_buffer and change asr_trainer.lua accordingly
Diffstat (limited to 'nerv/io/sgd_buffer.lua')
-rw-r--r--nerv/io/sgd_buffer.lua149
1 files changed, 0 insertions, 149 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua
deleted file mode 100644
index d78f6d1..0000000
--- a/nerv/io/sgd_buffer.lua
+++ /dev/null
@@ -1,149 +0,0 @@
-local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")
-
-function SGDBuffer:__init(global_conf, buffer_conf)
- self.gconf = global_conf
- self.batch_size = buffer_conf.batch_size
- self.buffer_size = math.floor(buffer_conf.buffer_size /
- self.batch_size) * self.batch_size
- self.randomize = buffer_conf.randomize
- self.consume = buffer_conf.consume
- local cumat_type = global_conf.cumat_type
- if self.gconf.use_cpu then
- self.output_mat_type = self.gconf.mmat_type
- else
- self.output_mat_type = self.gconf.cumat_type
- end
- if buffer_conf.use_gpu then
- self.mat_type = cumat_type
- if self.gconf.use_cpu then
- -- gpu buffer -> cpu training
- nerv.error("not implemeted")
- else
- -- gpu buffer -> gpu training
- self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx
- self.copy_from = cumat_type.copy_fromd
- end
- self.perm_gen = function (x)
- return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x))
- end
- else
- self.mat_type = global_conf.mmat_type
- if self.gconf.use_cpu then
- -- cpu buffer -> cpu training
- self.copy_rows_from_by_idx = gconf.mmat_type.copy_rows_fromh_by_idx
- self.copy_from = gconf.mmat_type.copy_fromh
- else
- -- cpu buffer -> gpu training
- self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx
- self.copy_from = cumat_type.copy_fromh
- end
- self.perm_gen = nerv.MMatrixFloat.perm_gen
- end
- self.copy_from_reader = self.mat_type.copy_fromh
- self.head = 0
- self.tail = 0
- self.readers = {}
- for i, reader_spec in ipairs(buffer_conf.readers) do
- local buffs = {}
- for id, width in pairs(reader_spec.data) do
- buffs[id] = {data = self.mat_type(self.buffer_size, width),
- leftover = nil,
- width = width}
- end
- table.insert(self.readers, {buffs = buffs,
- reader = reader_spec.reader,
- tail = 0,
- has_leftover = false})
- end
-end
-
-function SGDBuffer:saturate()
- local buffer_size = self.buffer_size
- self.head = 0
- self.tail = buffer_size
- for i, reader in ipairs(self.readers) do
- reader.tail = 0
- if reader.has_leftover then
- local lrow
- for id, buff in pairs(reader.buffs) do
- lrow = buff.leftover:nrow()
- if lrow > buffer_size then
- nerv.error("buffer size is too small to contain leftovers")
- end
- buff.data:copy_from(buff.leftover, 0, lrow)
- buff.leftover = nil
- end
- nerv.info("buffer leftover: %d\n", lrow)
- reader.tail = lrow
- reader.has_leftover = false
- end
- while reader.tail < buffer_size do
- local data = reader.reader:get_data()
- if data == nil then
- break
- end
- local drow = nil
- for id, d in pairs(data) do
- if drow == nil then
- drow = d:nrow()
- elseif d:nrow() ~= drow then
- nerv.error("reader provides with inconsistent rows of data")
- end
- end
- local remain = buffer_size - reader.tail
- if drow > remain then
- for id, buff in pairs(reader.buffs) do
- local d = data[id]
- if d == nil then
- nerv.error("reader does not provide data for %s", id)
- end
- buff.leftover = self.mat_type(drow - remain,
- buff.width)
- self.copy_from_reader(buff.leftover, d, remain, drow)
- end
- drow = remain
- reader.has_leftover = true
- end
- for id, buff in pairs(reader.buffs) do
- self.copy_from_reader(buff.data, data[id], 0, drow, reader.tail)
- end
- reader.tail = reader.tail + drow
- end
- self.tail = math.min(self.tail, reader.tail)
- end
- self.rand_map = self.perm_gen(self.tail) -- generate shuffled index
- collectgarbage("collect")
- return self.tail >= self.batch_size
-end
-
-function SGDBuffer:get_data()
- local batch_size = self.batch_size
- if self.head >= self.tail then -- buffer is empty
- local t = os.clock()
- if (not self:saturate()) and (not self.consume) then
- return nil -- the remaining data cannot build a batch
- end
- if self.tail == self.head then
- return nil -- nothing left
- end
- nerv.info("%.3fs to fill the buffer", os.clock() - t)
- end
- if self.head + batch_size > self.tail and (not self.consume) then
- return nil -- the remaining data cannot build a batch
- end
- actual_batch_size = math.min(batch_size, self.tail - self.head)
- local res = {}
- for i, reader in ipairs(self.readers) do
- for id, buff in pairs(reader.buffs) do
- local batch = self.output_mat_type(actual_batch_size, buff.width)
- if self.randomize then
- self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
- else
- self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
- end
- res[id] = batch
- end
- end
- self.head = self.head + actual_batch_size
- return res
-end