From 7579ff4941d7019d4e911978879ec07b62a4e523 Mon Sep 17 00:00:00 2001
From: Determinant <ted.sybil@gmail.com>
Date: Wed, 5 Aug 2015 09:29:24 +0800
Subject: use expanded features and do global transf in embedding_example

---
 embedding_example/setup_nerv.lua          | 10 ++++++----
 embedding_example/swb_baseline_decode.lua |  5 +++++
 2 files changed, 11 insertions(+), 4 deletions(-)

(limited to 'embedding_example')

diff --git a/embedding_example/setup_nerv.lua b/embedding_example/setup_nerv.lua
index 3ae878d..49a5dd6 100644
--- a/embedding_example/setup_nerv.lua
+++ b/embedding_example/setup_nerv.lua
@@ -7,17 +7,19 @@ param_repo:import(gconf.initialized_param, nil, gconf)
 local sublayer_repo = make_sublayer_repo(param_repo)
 local layer_repo = make_layer_repo(sublayer_repo, param_repo)
 local network = get_network(layer_repo)
+local global_transf = get_global_transf(layer_repo)
 local batch_size = 1
 network:init(batch_size)
 
 function propagator(input, output)
-    local gpu_input = nerv.CuMatrixFloat(input:nrow(), input:ncol())
+    local transformed = nerv.speech_utils.global_transf(input,
+                            global_transf, 0, gconf) -- preprocessing
+    local gpu_input = nerv.CuMatrixFloat(transformed:nrow(), transformed:ncol())
     local gpu_output = nerv.CuMatrixFloat(output:nrow(), output:ncol())
-    gpu_input:copy_fromh(input)
-    print(gpu_input)
+    print(transformed)
+    gpu_input:copy_fromh(transformed)
     network:propagate({gpu_input}, {gpu_output})
     gpu_output:copy_toh(output)
-    print(output)
     -- collect garbage in-time to save GPU memory
     collectgarbage("collect")
 end
diff --git a/embedding_example/swb_baseline_decode.lua b/embedding_example/swb_baseline_decode.lua
index 14a463b..8cdb320 100644
--- a/embedding_example/swb_baseline_decode.lua
+++ b/embedding_example/swb_baseline_decode.lua
@@ -107,3 +107,8 @@ end
 function get_network(layer_repo)
     return layer_repo:get_layer("main")
 end
+
+
+function get_global_transf(layer_repo)
+    return layer_repo:get_layer("global_transf")
+end
-- 
cgit v1.2.3-70-g09d2