diff options
author | Determinant <ted.sybil@gmail.com> | 2016-02-15 15:04:13 +0800 |
---|---|---|
committer | Determinant <ted.sybil@gmail.com> | 2016-02-15 15:04:13 +0800 |
commit | 3362020a6bc43766a92882abe6d127c8bb98a628 (patch) | |
tree | cad93eb88c2813694c0ae4ca4ecb9873a719ad85 /nerv/io | |
parent | dcad8a3f80fc55ca93984d981f9b829d2e4ea728 (diff) |
try a basic merge
Diffstat (limited to 'nerv/io')
-rw-r--r-- | nerv/io/sgd_buffer.lua | 31 |
1 files changed, 24 insertions, 7 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua index 74c4934..3cf4f5a 100644 --- a/nerv/io/sgd_buffer.lua +++ b/nerv/io/sgd_buffer.lua @@ -7,21 +7,38 @@ function SGDBuffer:__init(global_conf, buffer_conf) self.randomize = buffer_conf.randomize self.consume = buffer_conf.consume local cumat_type = global_conf.cumat_type + if self.gconf.use_cpu then + self.output_mat_type = self.gconf.mmat_type + else + self.output_mat_type = self.gconf.cumat_type + end if buffer_conf.use_gpu then self.mat_type = cumat_type - self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx - self.copy_from = cumat_type.copy_fromd - self.copy_from_reader = cumat_type.copy_fromh + if self.gconf.use_cpu then + -- gpu buffer -> cpu training + nerv.error("not implemeted") + else + -- gpu buffer -> gpu training + self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx + self.copy_from = cumat_type.copy_fromd + end self.perm_gen = function (x) return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x)) end else self.mat_type = global_conf.mmat_type - self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx - self.copy_from = cumat_type.copy_fromh + if self.gconf.use_cpu then + -- cpu buffer -> cpu training + self.copy_rows_from_by_idx = gconf.mmat_type.copy_rows_fromh_by_idx + self.copy_from = gconf.mmat_type.copy_fromh + else + -- cpu buffer -> gpu training + self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx + self.copy_from = cumat_type.copy_fromh + end self.perm_gen = nerv.MMatrixFloat.perm_gen - self.copy_from_reader = self.mat_type.copy_from end + self.copy_from_reader = self.mat_type.copy_fromh self.head = 0 self.tail = 0 self.readers = {} @@ -117,7 +134,7 @@ function SGDBuffer:get_data() local res = {} for i, reader in ipairs(self.readers) do for id, buff in pairs(reader.buffs) do - local batch = self.gconf.cumat_type(actual_batch_size, buff.width) + local batch = self.output_mat_type(actual_batch_size, buff.width) if self.randomize then self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head) else |