diff options
Diffstat (limited to 'nerv/io')
-rw-r--r-- | nerv/io/sgd_buffer.lua | 34 |
1 files changed, 25 insertions, 9 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua index f9d281c..3f854f0 100644 --- a/nerv/io/sgd_buffer.lua +++ b/nerv/io/sgd_buffer.lua @@ -8,13 +8,29 @@ function SGDBuffer:__init(global_conf, buffer_conf) if self.randomize == nil then self.randomize = false end + local cumat_type = global_conf.cumat_type + if buffer_conf.use_gpu then + self.mat_type = cumat_type + self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx + self.copy_from = cumat_type.copy_fromd + self.copy_from_reader = cumat_type.copy_fromh + self.perm_gen = function (x) + return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x)) + end + else + self.mat_type = global_conf.mmat_type + self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx + self.copy_from = cumat_type.copy_fromh + self.perm_gen = nerv.MMatrixFloat.perm_gen + self.copy_from_reader = self.mat_type.copy_from + end self.head = 0 self.tail = 0 self.readers = {} for i, reader_spec in ipairs(buffer_conf.readers) do local buffs = {} for id, width in pairs(reader_spec.data) do - buffs[id] = {data = global_conf.mmat_type(self.buffer_size, width), + buffs[id] = {data = self.mat_type(self.buffer_size, width), leftover = nil, width = width} end @@ -41,7 +57,7 @@ function SGDBuffer:saturate() buff.data:copy_from(buff.leftover, 0, lrow) buff.leftover = nil end - nerv.printf("buffer leftover: %d\n", lrow) + nerv.info("buffer leftover: %d\n", lrow) reader.tail = lrow reader.has_leftover = false end @@ -65,21 +81,21 @@ function SGDBuffer:saturate() if d == nil then nerv.error("reader does not provide data for %s", id) end - buff.leftover = self.gconf.mmat_type(drow - remain, - buff.width) - buff.leftover:copy_from(d, remain, drow) + buff.leftover = self.mat_type(drow - remain, + buff.width) + self.copy_from_reader(buff.leftover, d, remain, drow) end drow = remain reader.has_leftover = true end for id, buff in pairs(reader.buffs) do - buff.data:copy_from(data[id], 0, drow, reader.tail) + self.copy_from_reader(buff.data, data[id], 0, drow, reader.tail) end reader.tail = reader.tail + drow end self.tail = math.min(self.tail, reader.tail) end - self.rand_map = nerv.MMatrixInt.perm_gen(self.tail) -- generate shuffled index + self.rand_map = self.perm_gen(self.tail) -- generate shuffled index collectgarbage("collect") return self.tail >= self.gconf.batch_size end @@ -101,9 +117,9 @@ function SGDBuffer:get_data() for id, buff in pairs(reader.buffs) do local batch = self.gconf.cumat_type(batch_size, buff.width) if self.randomize then - batch:copy_rows_fromh_by_idx(buff.data, self.rand_map, self.head) + self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head) else - batch:copy_fromh(buff.data, self.head, self.head + batch_size) + self.copy_from(batch, buff.data, self.head, self.head + batch_size) end res[id] = batch end |