aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/swb_baseline.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/swb_baseline.lua')
-rw-r--r--nerv/examples/swb_baseline.lua7
1 files changed, 5 insertions, 2 deletions
diff --git a/nerv/examples/swb_baseline.lua b/nerv/examples/swb_baseline.lua
index cacc401..4cb2389 100644
--- a/nerv/examples/swb_baseline.lua
+++ b/nerv/examples/swb_baseline.lua
@@ -1,7 +1,5 @@
require 'htk_io'
gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
- cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.MMatrixFloat,
rearrange = true, -- just to make the context order consistent with old results, deprecated
frm_ext = 5,
frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated
@@ -173,6 +171,7 @@ function make_buffer(readers)
return nerv.SGDBuffer(gconf,
{
buffer_size = gconf.buffer_size,
+ batch_size = gconf.batch_size,
randomize = gconf.randomize,
readers = readers,
use_gpu = true
@@ -184,6 +183,10 @@ function get_input_order()
{id = "phone_state"}}
end
+function get_decode_input_order()
+ return {{id = "main_scp", global_transf = true}}
+end
+
function get_accuracy(layer_repo)
local ce_crit = layer_repo:get_layer("ce_crit")
return ce_crit.total_correct / ce_crit.total_frames * 100