diff options
author | Determinant <ted.sybil@gmail.com> | 2016-03-16 17:53:39 +0800 |
---|---|---|
committer | Determinant <ted.sybil@gmail.com> | 2016-03-16 17:53:39 +0800 |
commit | 289ac7f4b6e88b935da5c891e1efcf91fc047403 (patch) | |
tree | d4fc3a4fc20f2d5908624b3f6587ecd57966d719 /nerv/io/sgd_buffer.lua | |
parent | 07fc1e2794027d44c255e1062c4491346b101a08 (diff) |
merge seq_buffer and change asr_trainer.lua accordingly
Diffstat (limited to 'nerv/io/sgd_buffer.lua')
-rw-r--r-- | nerv/io/sgd_buffer.lua | 149 |
1 files changed, 0 insertions, 149 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua deleted file mode 100644 index d78f6d1..0000000 --- a/nerv/io/sgd_buffer.lua +++ /dev/null @@ -1,149 +0,0 @@ -local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer") - -function SGDBuffer:__init(global_conf, buffer_conf) - self.gconf = global_conf - self.batch_size = buffer_conf.batch_size - self.buffer_size = math.floor(buffer_conf.buffer_size / - self.batch_size) * self.batch_size - self.randomize = buffer_conf.randomize - self.consume = buffer_conf.consume - local cumat_type = global_conf.cumat_type - if self.gconf.use_cpu then - self.output_mat_type = self.gconf.mmat_type - else - self.output_mat_type = self.gconf.cumat_type - end - if buffer_conf.use_gpu then - self.mat_type = cumat_type - if self.gconf.use_cpu then - -- gpu buffer -> cpu training - nerv.error("not implemeted") - else - -- gpu buffer -> gpu training - self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx - self.copy_from = cumat_type.copy_fromd - end - self.perm_gen = function (x) - return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x)) - end - else - self.mat_type = global_conf.mmat_type - if self.gconf.use_cpu then - -- cpu buffer -> cpu training - self.copy_rows_from_by_idx = gconf.mmat_type.copy_rows_fromh_by_idx - self.copy_from = gconf.mmat_type.copy_fromh - else - -- cpu buffer -> gpu training - self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx - self.copy_from = cumat_type.copy_fromh - end - self.perm_gen = nerv.MMatrixFloat.perm_gen - end - self.copy_from_reader = self.mat_type.copy_fromh - self.head = 0 - self.tail = 0 - self.readers = {} - for i, reader_spec in ipairs(buffer_conf.readers) do - local buffs = {} - for id, width in pairs(reader_spec.data) do - buffs[id] = {data = self.mat_type(self.buffer_size, width), - leftover = nil, - width = width} - end - table.insert(self.readers, {buffs = buffs, - reader = reader_spec.reader, - tail = 0, - has_leftover = false}) - end -end - -function SGDBuffer:saturate() - local buffer_size = self.buffer_size - self.head = 0 - self.tail = buffer_size - for i, reader in ipairs(self.readers) do - reader.tail = 0 - if reader.has_leftover then - local lrow - for id, buff in pairs(reader.buffs) do - lrow = buff.leftover:nrow() - if lrow > buffer_size then - nerv.error("buffer size is too small to contain leftovers") - end - buff.data:copy_from(buff.leftover, 0, lrow) - buff.leftover = nil - end - nerv.info("buffer leftover: %d\n", lrow) - reader.tail = lrow - reader.has_leftover = false - end - while reader.tail < buffer_size do - local data = reader.reader:get_data() - if data == nil then - break - end - local drow = nil - for id, d in pairs(data) do - if drow == nil then - drow = d:nrow() - elseif d:nrow() ~= drow then - nerv.error("reader provides with inconsistent rows of data") - end - end - local remain = buffer_size - reader.tail - if drow > remain then - for id, buff in pairs(reader.buffs) do - local d = data[id] - if d == nil then - nerv.error("reader does not provide data for %s", id) - end - buff.leftover = self.mat_type(drow - remain, - buff.width) - self.copy_from_reader(buff.leftover, d, remain, drow) - end - drow = remain - reader.has_leftover = true - end - for id, buff in pairs(reader.buffs) do - self.copy_from_reader(buff.data, data[id], 0, drow, reader.tail) - end - reader.tail = reader.tail + drow - end - self.tail = math.min(self.tail, reader.tail) - end - self.rand_map = self.perm_gen(self.tail) -- generate shuffled index - collectgarbage("collect") - return self.tail >= self.batch_size -end - -function SGDBuffer:get_data() - local batch_size = self.batch_size - if self.head >= self.tail then -- buffer is empty - local t = os.clock() - if (not self:saturate()) and (not self.consume) then - return nil -- the remaining data cannot build a batch - end - if self.tail == self.head then - return nil -- nothing left - end - nerv.info("%.3fs to fill the buffer", os.clock() - t) - end - if self.head + batch_size > self.tail and (not self.consume) then - return nil -- the remaining data cannot build a batch - end - actual_batch_size = math.min(batch_size, self.tail - self.head) - local res = {} - for i, reader in ipairs(self.readers) do - for id, buff in pairs(reader.buffs) do - local batch = self.output_mat_type(actual_batch_size, buff.width) - if self.randomize then - self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head) - else - self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size) - end - res[id] = batch - end - end - self.head = self.head + actual_batch_size - return res -end |