summaryrefslogtreecommitdiff
path: root/embedding_example
diff options
context:
space:
mode:
Diffstat (limited to 'embedding_example')
-rw-r--r--embedding_example/setup_nerv.lua9
1 files changed, 4 insertions, 5 deletions
diff --git a/embedding_example/setup_nerv.lua b/embedding_example/setup_nerv.lua
index 5ade950..d80c306 100644
--- a/embedding_example/setup_nerv.lua
+++ b/embedding_example/setup_nerv.lua
@@ -11,12 +11,11 @@ local batch_size = 1
network:init(batch_size)
function propagator(input, output)
- local transformed = nerv.speech_utils.global_transf(input,
- global_transf, 0, gconf) -- preprocessing
- local gpu_input = nerv.CuMatrixFloat(transformed:nrow(), transformed:ncol())
+ local transformed = nerv.speech_utils.global_transf(
+ gconf.cumat_type.new_from_host(input),
+ global_transf, 0, 0, gconf) -- preprocessing
+ local gpu_input = transformed
local gpu_output = nerv.CuMatrixFloat(output:nrow(), output:ncol())
- print(transformed)
- gpu_input:copy_fromh(transformed)
network:propagate({gpu_input}, {gpu_output})
gpu_output:copy_toh(output)
-- collect garbage in-time to save GPU memory