aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2015-08-04 11:11:50 +0800
committerDeterminant <ted.sybil@gmail.com>2015-08-04 11:11:50 +0800
commit5b16335a903551ffef4fafa88d67146b9131a74e (patch)
tree2691b465eb1ebb905d12b73648fda8012d844704 /nerv/examples/asr_trainer.lua
parentb385d55268b7b327534e227065907a5ea2d2b731 (diff)
...
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--nerv/examples/asr_trainer.lua4
1 files changed, 2 insertions, 2 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 4fa4096..8dfb2ac 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -12,7 +12,7 @@ function build_trainer(ifname)
-- initialize the network
network:init(gconf.batch_size)
gconf.cnt = 0
- err_input = {nerv.CuMatrixFloat(256, 1)}
+ err_input = {nerv.CuMatrixFloat(gconf.batch_size, 1)}
err_input[1]:fill(1)
for data in buffer.get_data, buffer do
-- prine stat periodically
@@ -32,7 +32,7 @@ function build_trainer(ifname)
end
table.insert(input, data[id])
end
- local output = {nerv.CuMatrixFloat(256, 1)}
+ local output = {nerv.CuMatrixFloat(gconf.batch_size, 1)}
err_output = {input[1]:create()}
network:propagate(input, output)
if bp then