diff options
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r-- | nerv/examples/asr_trainer.lua | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua index 69cfeed..3fa2653 100644 --- a/nerv/examples/asr_trainer.lua +++ b/nerv/examples/asr_trainer.lua @@ -5,6 +5,12 @@ function build_trainer(ifname) local network = get_network(layer_repo) local global_transf = get_global_transf(layer_repo) local input_order = get_input_order() + local mat_type + if gconf.use_cpu then + mat_type = gconf.mmat_type + else + mat_type = gconf.cumat_type + end local iterative_trainer = function (prefix, scp_file, bp) gconf.randomize = bp -- build buffer @@ -12,15 +18,15 @@ function build_trainer(ifname) -- initialize the network network:init(gconf.batch_size) gconf.cnt = 0 - err_input = {nerv.CuMatrixFloat(gconf.batch_size, 1)} + err_input = {mat_type(gconf.batch_size, 1)} err_input[1]:fill(1) for data in buffer.get_data, buffer do -- prine stat periodically gconf.cnt = gconf.cnt + 1 if gconf.cnt == 1000 then print_stat(layer_repo) - nerv.CuMatrix.print_profile() - nerv.CuMatrix.clear_profile() + mat_type.print_profile() + mat_type.clear_profile() gconf.cnt = 0 -- break end @@ -42,7 +48,7 @@ function build_trainer(ifname) end table.insert(input, transformed) end - local output = {nerv.CuMatrixFloat(gconf.batch_size, 1)} + local output = {mat_type(gconf.batch_size, 1)} err_output = {} for i = 1, #input do table.insert(err_output, input[i]:create()) @@ -56,8 +62,8 @@ function build_trainer(ifname) collectgarbage("collect") end print_stat(layer_repo) - nerv.CuMatrix.print_profile() - nerv.CuMatrix.clear_profile() + mat_type.print_profile() + mat_type.clear_profile() if (not bp) and prefix ~= nil then nerv.info("writing back...") local fname = string.format("%s_cv%.3f.nerv", |