blob: 52cb754893320bd3c21f1424b42199552e17aaa8 (
plain) (
tree)
|
|
require 'lfs'
require 'pl'
local function build_trainer(ifname)
local host_param_repo = nerv.ParamRepo()
local mat_type
local src_loc_type
local train_loc_type
host_param_repo:import(ifname, gconf)
if gconf.use_cpu then
mat_type = gconf.mmat_type
src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
else
mat_type = gconf.cumat_type
src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
end
local param_repo = host_param_repo:copy(train_loc_type, gconf)
local layer_repo = make_layer_repo(param_repo)
local network =
|