about summary refs log tree commit diff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--	nerv/examples/asr_trainer.lua	18
1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 69cfeed..3fa2653 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -5,6 +5,12 @@ function build_trainer(ifname)
local network = get_network(layer_repo)
local global_transf = get_global_transf(layer_repo)
local input_order = get_input_order()
+ local mat_type
+ if gconf.use_cpu then
+ mat_type = gconf.mmat_type
+ else
+ mat_type = gconf.cumat_type
+ end
local iterative_trainer = function (prefix, scp_file, bp)
gconf.randomize = bp
-- build buffer
@@ -12,15 +18,15 @@ function build_trainer(ifname)
-- initialize the network
network:init(gconf.batch_size)
gconf.cnt = 0
- err_input = {nerv.CuMatrixFloat(gconf.batch_size, 1)}
+ err_input = {mat_type(gconf.batch_size, 1)}
err_input[1]:fill(1)
for data in buffer.get_data, buffer do
-- prine stat periodically
gconf.cnt = gconf.cnt + 1
if gconf.cnt == 1000 then
print_stat(layer_repo)
- nerv.CuMatrix.print_profile()
- nerv.CuMatrix.clear_profile()
+ mat_type.print_profile()
+ mat_type.clear_profile()
gconf.cnt = 0
-- break
end
@@ -42,7 +48,7 @@ function build_trainer(ifname)
end
table.insert(input, transformed)
end
- local output = {nerv.CuMatrixFloat(gconf.batch_size, 1)}
+ local output = {mat_type(gconf.batch_size, 1)}
err_output = {}
for i = 1, #input do
table.insert(err_output, input[i]:create())
@@ -56,8 +62,8 @@ function build_trainer(ifname)
collectgarbage("collect")
end
print_stat(layer_repo)
- nerv.CuMatrix.print_profile()
- nerv.CuMatrix.clear_profile()
+ mat_type.print_profile()
+ mat_type.clear_profile()
if (not bp) and prefix ~= nil then
nerv.info("writing back...")
local fname = string.format("%s_cv%.3f.nerv",