aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--nerv/examples/asr_trainer.lua4
1 files changed, 2 insertions, 2 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 4fa4096..8dfb2ac 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -12,7 +12,7 @@ function build_trainer(ifname)
-- initialize the network
network:init(gconf.batch_size)
gconf.cnt = 0
- err_input = {nerv.CuMatrixFloat(256, 1)}
+ err_input = {nerv.CuMatrixFloat(gconf.batch_size, 1)}
err_input[1]:fill(1)
for data in buffer.get_data, buffer do
-- prine stat periodically
@@ -32,7 +32,7 @@ function build_trainer(ifname)
end
table.insert(input, data[id])
end
- local output = {nerv.CuMatrixFloat(256, 1)}
+ local output = {nerv.CuMatrixFloat(gconf.batch_size, 1)}
err_output = {input[1]:create()}
network:propagate(input, output)
if bp then