about summary refs log tree commit diff
path: root/nerv/io/sgd_buffer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/io/sgd_buffer.lua')
-rw-r--r-- nerv/io/sgd_buffer.lua 31
1 files changed, 24 insertions, 7 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua
index 74c4934..3cf4f5a 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/sgd_buffer.lua
@@ -7,21 +7,38 @@ function SGDBuffer:__init(global_conf, buffer_conf)
self.randomize = buffer_conf.randomize
self.consume = buffer_conf.consume
local cumat_type = global_conf.cumat_type
+ if self.gconf.use_cpu then
+ self.output_mat_type = self.gconf.mmat_type
+ else
+ self.output_mat_type = self.gconf.cumat_type
+ end
if buffer_conf.use_gpu then
self.mat_type = cumat_type
- self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx
- self.copy_from = cumat_type.copy_fromd
- self.copy_from_reader = cumat_type.copy_fromh
+ if self.gconf.use_cpu then
+ -- gpu buffer -> cpu training
+ nerv.error("not implemeted")
+ else
+ -- gpu buffer -> gpu training
+ self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx
+ self.copy_from = cumat_type.copy_fromd
+ end
self.perm_gen = function (x)
return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x))
end
else
self.mat_type = global_conf.mmat_type
- self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx
- self.copy_from = cumat_type.copy_fromh
+ if self.gconf.use_cpu then
+ -- cpu buffer -> cpu training
+ self.copy_rows_from_by_idx = gconf.mmat_type.copy_rows_fromh_by_idx
+ self.copy_from = gconf.mmat_type.copy_fromh
+ else
+ -- cpu buffer -> gpu training
+ self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx
+ self.copy_from = cumat_type.copy_fromh
+ end
self.perm_gen = nerv.MMatrixFloat.perm_gen
- self.copy_from_reader = self.mat_type.copy_from
end
+ self.copy_from_reader = self.mat_type.copy_fromh
self.head = 0
self.tail = 0
self.readers = {}
@@ -117,7 +134,7 @@ function SGDBuffer:get_data()
local res = {}
for i, reader in ipairs(self.readers) do
for id, buff in pairs(reader.buffs) do
- local batch = self.gconf.cumat_type(actual_batch_size, buff.width)
+ local batch = self.output_mat_type(actual_batch_size, buff.width)
if self.randomize then
self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
else