about summary refs log tree commit diff
path: root/nerv/examples
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2015-08-26 14:26:54 +0800
committerDeterminant <ted.sybil@gmail.com>2015-08-26 14:26:54 +0800
commite81e9832ec4f2ad031fd42b5018cea134e8cda7e (patch)
treeed49289619399a99c80f47398ccc4de9ae4cedf6 /nerv/examples
parented2a4148dbb9c18f428571b3e2970d7b2adfb058 (diff)
move global_transf to asr_trainer.lua
Diffstat (limited to 'nerv/examples')
-rw-r--r--  nerv/examples/asr_trainer.lua       | 23
-rw-r--r--  nerv/examples/swb_baseline.lua      |  7
-rw-r--r--  nerv/examples/swb_baseline_basic.lua |  7
3 files changed, 27 insertions, 10 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index dcadfa3..5a50542 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -3,6 +3,7 @@ function build_trainer(ifname)
param_repo:import(ifname, nil, gconf)
local layer_repo = make_layer_repo(param_repo)
local network = get_network(layer_repo)
+ local global_transf = get_global_transf(layer_repo)
local input_order = get_input_order()
local iterative_trainer = function (prefix, scp_file, bp)
gconf.randomize = bp
@@ -24,15 +25,29 @@ function build_trainer(ifname)
-- break
end
local input = {}
--- if gconf.cnt == 100 then break end
- for i, id in ipairs(input_order) do
+-- if gconf.cnt == 1000 then break end
+ for i, e in ipairs(input_order) do
+ local id = e.id
if data[id] == nil then
nerv.error("input data %s not found", id)
end
- table.insert(input, data[id])
+ local transformed
+ if e.global_transf then
+ transformed = nerv.speech_utils.global_transf(data[id],
+ global_transf,
+ gconf.frm_ext or 0,
+ gconf.frm_trim or 0,
+ gconf)
+ else
+ transformed = data[id]
+ end
+ table.insert(input, transformed)
end
local output = {nerv.CuMatrixFloat(gconf.batch_size, 1)}
- err_output = {input[1]:create()}
+ err_output = {}
+ for i = 1, #input do
+ table.insert(err_output, input[i]:create())
+ end
network:propagate(input, output)
if bp then
network:back_propagate(err_input, err_output, input, output)
diff --git a/nerv/examples/swb_baseline.lua b/nerv/examples/swb_baseline.lua
index 0e9f897..bbc6467 100644
--- a/nerv/examples/swb_baseline.lua
+++ b/nerv/examples/swb_baseline.lua
@@ -3,6 +3,7 @@ gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.MMatrixFloat,
frm_ext = 5,
+ frm_trim = 5,
tr_scp = "/slfs1/users/mfy43/swb_ivec/train_bp.scp",
cv_scp = "/slfs1/users/mfy43/swb_ivec/train_cv.scp",
htk_conf = "/slfs1/users/mfy43/swb_ivec/plp_0_d_a.conf",
@@ -161,8 +162,7 @@ function make_readers(scp_file, layer_repo)
dir = "*/",
ext = "lab"
}
- },
- global_transf = layer_repo:get_layer("global_transf")
+ }
}),
data = {main_scp = 429, phone_state = 1}}
}
@@ -178,7 +178,8 @@ function make_buffer(readers)
end
function get_input_order()
- return {"main_scp", "phone_state"}
+ return {{id = "main_scp", global_transf = true},
+ {id = "phone_state"}}
end
function get_accuracy(layer_repo)
diff --git a/nerv/examples/swb_baseline_basic.lua b/nerv/examples/swb_baseline_basic.lua
index c47ec3e..71f04a3 100644
--- a/nerv/examples/swb_baseline_basic.lua
+++ b/nerv/examples/swb_baseline_basic.lua
@@ -3,6 +3,7 @@ gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.MMatrixFloat,
frm_ext = 5,
+ frm_trim = 5,
tr_scp = "/slfs1/users/mfy43/swb_ivec/train_bp.scp",
cv_scp = "/slfs1/users/mfy43/swb_ivec/train_cv.scp",
htk_conf = "/slfs1/users/mfy43/swb_ivec/plp_0_d_a.conf",
@@ -124,8 +125,7 @@ function make_readers(scp_file, layer_repo)
dir = "*/",
ext = "lab"
}
- },
- global_transf = layer_repo:get_layer("global_transf")
+ }
}),
data = {main_scp = 429, phone_state = 1}}
}
@@ -141,7 +141,8 @@ function make_buffer(readers)
end
function get_input_order()
- return {"main_scp", "phone_state"}
+ return {{id = "main_scp", global_transf = true},
+ {id = "phone_state"}}
end
function get_accuracy(layer_repo)