diff options
author | Determinant <[email protected]> | 2015-06-03 23:00:45 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-06-03 23:00:45 +0800 |
commit | ea6f2990f99dd9ded6a0e74d75a3ec84900a2518 (patch) | |
tree | 03b4ea34fa373189bf6b2b017bf54793d5c89f8e /io | |
parent | bb56a806e0636a0b20117b1644701d63e2bfaefb (diff) |
demo now works (without random shuffle)
Diffstat (limited to 'io')
-rw-r--r-- | io/init.lua | 22 | ||||
-rw-r--r-- | io/sgd_buffer.lua | 99 |
2 files changed, 121 insertions, 0 deletions
diff --git a/io/init.lua b/io/init.lua index 4a663a7..9bbd51a 100644 --- a/io/init.lua +++ b/io/init.lua @@ -28,3 +28,25 @@ function nerv.ChunkFile:read_chunk(id, global_conf) chunk:read(self:get_chunkdata(id)) return chunk end + +local DataReader = nerv.class("nerv.DataReader") + +function DataReader:__init(global_conf, reader_conf) + nerv.error_method_not_implemented() +end + +function DataReader:get_data() + nerv.error_method_not_implemented() +end + +local DataBuffer = nerv.class("nerv.DataBuffer") + +function DataBuffer:__init(global_conf, buffer_conf) + nerv.error_method_not_implemented() +end + +function DataBuffer:get_batch() + nerv.error_method_not_implemented() +end + +require 'io.sgd_buffer' diff --git a/io/sgd_buffer.lua b/io/sgd_buffer.lua new file mode 100644 index 0000000..dadcf67 --- /dev/null +++ b/io/sgd_buffer.lua @@ -0,0 +1,99 @@ +local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer") + +function SGDBuffer:__init(global_conf, buffer_conf) + self.gconf = global_conf + self.buffer_size = math.floor(buffer_conf.buffer_size / + global_conf.batch_size) * global_conf.batch_size + self.head = 0 + self.tail = 0 + self.readers = {} + for i, reader_spec in ipairs(buffer_conf.readers) do + local buffs = {} + for id, width in pairs(reader_spec.data) do + buffs[id] = {data = global_conf.mmat_type(self.buffer_size, width), + leftover = {}, + width = width} + end + table.insert(self.readers, {buffs = buffs, + reader = reader_spec.reader, + tail = 0, + has_leftover = false}) + end +end + +function SGDBuffer:saturate() + local buffer_size = self.buffer_size + self.head = 0 + self.tail = buffer_size + for i, reader in ipairs(self.readers) do + reader.tail = 0 + if reader.has_leftover then + local lrow + for id, buff in pairs(reader.buffs) do + lrow = buff.leftover:nrow() + if lrow > buffer_size then + nerv.error("buffer size is too small to contain leftovers") + end + buff.data:copy_from(buff.leftover, 0, lrow) + end + reader.tail = lrow + reader.has_leftover = false + end + while reader.tail < buffer_size do + local data = reader.reader:get_data() + if data == nil then + break + end + local drow = nil + for id, d in pairs(data) do + if drow == nil then + drow = d:nrow() + elseif d:nrow() ~= drow then + nerv.error("reader provides with inconsistent rows of data") + end + end + local remain = buffer_size - reader.tail + if drow > remain then + for id, buff in pairs(reader.buffs) do + local d = data[id] + if d == nil then + nerv.error("reader does not provide data for %s", id) + end + buff.leftover = self.gconf.mmat_type(drow - remain, + buff.width) + buff.leftover:copy_from(d, remain, drow) + end + drow = remain + reader.has_leftover = true + end + for id, buff in pairs(reader.buffs) do + buff.data:copy_from(data[id], 0, drow, reader.tail) + end + reader.tail = reader.tail + drow + end + self.tail = math.min(self.tail, reader.tail) + end + return self.tail >= self.gconf.batch_size +end + +function SGDBuffer:get_data() + local batch_size = self.gconf.batch_size + if self.head >= self.tail then -- buffer is empty + if not self:saturate() then + return nil -- the remaining data cannot build a batch + end + end + if self.head + batch_size > self.tail then + return nil -- the remaining data cannot build a batch + end + local res = {} + for i, reader in ipairs(self.readers) do + for id, buff in pairs(reader.buffs) do + local batch = self.gconf.cumat_type(batch_size, buff.width) + batch:copy_fromh(buff.data, self.head, self.head + batch_size) + res[id] = batch + end + end + self.head = self.head + batch_size + return res +end |