aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/swb_baseline2.lua
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2016-03-02 18:24:09 +0800
committerDeterminant <ted.sybil@gmail.com>2016-03-02 18:24:09 +0800
commitad704f2623cc9e0a5d702434bfdebc345465ca12 (patch)
tree898d0688e913efc3ff098ba51e5c1a5488f5771d /nerv/examples/swb_baseline2.lua
parentd3abc6459a776ff7fa3777f4f561bc4f5d5e2075 (diff)
major changes in asr_trainer.lua; unified settings in `gconf`
Diffstat (limited to 'nerv/examples/swb_baseline2.lua')
-rw-r--r--nerv/examples/swb_baseline2.lua7
1 files changed, 5 insertions, 2 deletions
diff --git a/nerv/examples/swb_baseline2.lua b/nerv/examples/swb_baseline2.lua
index 0e2a6e0..b0b9689 100644
--- a/nerv/examples/swb_baseline2.lua
+++ b/nerv/examples/swb_baseline2.lua
@@ -1,7 +1,5 @@
require 'htk_io'
gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
- cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.MMatrixFloat,
rearrange = true, -- just to make the context order consistent with old results, deprecated
frm_ext = 5,
frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated
@@ -173,6 +171,7 @@ function make_buffer(readers)
return nerv.SGDBuffer(gconf,
{
buffer_size = gconf.buffer_size,
+ batch_size = gconf.batch_size,
randomize = gconf.randomize,
readers = readers,
use_gpu = true
@@ -184,6 +183,10 @@ function get_input_order()
{id = "phone_state"}}
end
+function get_decode_input_order()
+ return {{id = "main_scp", global_transf = true}}
+end
+
function get_accuracy(layer_repo)
local ce_crit = layer_repo:get_layer("ce_crit")
return ce_crit.total_correct / ce_crit.total_frames * 100