diff options
author | Determinant <[email protected]> | 2015-08-05 09:29:24 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-08-05 09:29:24 +0800 |
commit | 7579ff4941d7019d4e911978879ec07b62a4e523 (patch) | |
tree | 534e756b083f02bf4f2573a5f82541db9b936466 /embedding_example | |
parent | 7ae89059d68850e12826bc6812e4a6d521e45b53 (diff) |
use expanded features and do global transf in embedding_example
Diffstat (limited to 'embedding_example')
-rw-r--r-- | embedding_example/setup_nerv.lua | 10 | ||||
-rw-r--r-- | embedding_example/swb_baseline_decode.lua | 5 |
2 files changed, 11 insertions, 4 deletions
diff --git a/embedding_example/setup_nerv.lua b/embedding_example/setup_nerv.lua index 3ae878d..49a5dd6 100644 --- a/embedding_example/setup_nerv.lua +++ b/embedding_example/setup_nerv.lua @@ -7,17 +7,19 @@ param_repo:import(gconf.initialized_param, nil, gconf) local sublayer_repo = make_sublayer_repo(param_repo) local layer_repo = make_layer_repo(sublayer_repo, param_repo) local network = get_network(layer_repo) +local global_transf = get_global_transf(layer_repo) local batch_size = 1 network:init(batch_size) function propagator(input, output) - local gpu_input = nerv.CuMatrixFloat(input:nrow(), input:ncol()) + local transformed = nerv.speech_utils.global_transf(input, + global_transf, 0, gconf) -- preprocessing + local gpu_input = nerv.CuMatrixFloat(transformed:nrow(), transformed:ncol()) local gpu_output = nerv.CuMatrixFloat(output:nrow(), output:ncol()) - gpu_input:copy_fromh(input) - print(gpu_input) + print(transformed) + gpu_input:copy_fromh(transformed) network:propagate({gpu_input}, {gpu_output}) gpu_output:copy_toh(output) - print(output) -- collect garbage in-time to save GPU memory collectgarbage("collect") end diff --git a/embedding_example/swb_baseline_decode.lua b/embedding_example/swb_baseline_decode.lua index 14a463b..8cdb320 100644 --- a/embedding_example/swb_baseline_decode.lua +++ b/embedding_example/swb_baseline_decode.lua @@ -107,3 +107,8 @@ end function get_network(layer_repo) return layer_repo:get_layer("main") end + + +function get_global_transf(layer_repo) + return layer_repo:get_layer("global_transf") +end |