author    Determinant <ted.sybil@gmail.com>    2015-08-05 09:29:24 +0800
committer Determinant <ted.sybil@gmail.com>    2015-08-05 09:29:24 +0800
commit    7579ff4941d7019d4e911978879ec07b62a4e523 (patch)
tree      534e756b083f02bf4f2573a5f82541db9b936466 /embedding_example
parent    7ae89059d68850e12826bc6812e4a6d521e45b53 (diff)
use expanded features and do global transf in embedding_example
Diffstat (limited to 'embedding_example')
-rw-r--r--  embedding_example/setup_nerv.lua          | 10
-rw-r--r--  embedding_example/swb_baseline_decode.lua |  5
2 files changed, 11 insertions(+), 4 deletions(-)
diff --git a/embedding_example/setup_nerv.lua b/embedding_example/setup_nerv.lua
index 3ae878d..49a5dd6 100644
--- a/embedding_example/setup_nerv.lua
+++ b/embedding_example/setup_nerv.lua
@@ -7,17 +7,19 @@ param_repo:import(gconf.initialized_param, nil, gconf)
 local sublayer_repo = make_sublayer_repo(param_repo)
 local layer_repo = make_layer_repo(sublayer_repo, param_repo)
 local network = get_network(layer_repo)
+local global_transf = get_global_transf(layer_repo)
 local batch_size = 1
 network:init(batch_size)
 
 function propagator(input, output)
-    local gpu_input = nerv.CuMatrixFloat(input:nrow(), input:ncol())
+    local transformed = nerv.speech_utils.global_transf(input,
+                                global_transf, 0, gconf) -- preprocessing
+    local gpu_input = nerv.CuMatrixFloat(transformed:nrow(), transformed:ncol())
     local gpu_output = nerv.CuMatrixFloat(output:nrow(), output:ncol())
-    gpu_input:copy_fromh(input)
-    print(gpu_input)
+    print(transformed)
+    gpu_input:copy_fromh(transformed)
     network:propagate({gpu_input}, {gpu_output})
     gpu_output:copy_toh(output)
-    print(output)
     -- collect garbage in-time to save GPU memory
     collectgarbage("collect")
 end
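
Pieced together from the hunk above, the post-patch propagator in setup_nerv.lua reads as follows; every line comes from the diff itself, and only the exact indentation of the continuation line is guessed:

local global_transf = get_global_transf(layer_repo)

function propagator(input, output)
    -- apply the global transform (feature preprocessing) on the host first
    local transformed = nerv.speech_utils.global_transf(input,
                                                        global_transf, 0, gconf)
    -- size the GPU buffer after the *transformed* features, not the raw input
    local gpu_input = nerv.CuMatrixFloat(transformed:nrow(), transformed:ncol())
    local gpu_output = nerv.CuMatrixFloat(output:nrow(), output:ncol())
    gpu_input:copy_fromh(transformed)
    network:propagate({gpu_input}, {gpu_output})
    gpu_output:copy_toh(output)
    -- collect garbage in-time to save GPU memory
    collectgarbage("collect")
end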
diff --git a/embedding_example/swb_baseline_decode.lua b/embedding_example/swb_baseline_decode.lua
index 14a463b..8cdb320 100644
--- a/embedding_example/swb_baseline_decode.lua
+++ b/embedding_example/swb_baseline_decode.lua
@@ -107,3 +107,8 @@ end
 function get_network(layer_repo)
     return layer_repo:get_layer("main")
 end
+
+
+function get_global_transf(layer_repo)
+    return layer_repo:get_layer("global_transf")
+end
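
A minimal driver for the functions this commit exports might look like the sketch below. It is an illustration only: the host-side matrix type nerv.MMatrixFloat and the feature/output dimensions are assumptions, not part of this commit.

-- hypothetical caller; assumes setup_nerv.lua has already been run so that
-- propagator(input, output) is defined and the network is initialized
local feat_dim, out_dim = 429, 3001            -- assumed dimensions, illustrative only
local input  = nerv.MMatrixFloat(1, feat_dim)  -- one frame per call (batch_size = 1)
local output = nerv.MMatrixFloat(1, out_dim)
-- ... fill `input` with one frame of spliced features ...
propagator(input, output)  -- global transf -> copy to GPU -> propagate -> copy back
-- `output` now holds the network's output for this frame on the host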