Diffstat (limited to 'embedding_example/setup_nerv.lua')
-rw-r--r--  embedding_example/setup_nerv.lua  25
1 file changed, 25 insertions, 0 deletions
diff --git a/embedding_example/setup_nerv.lua b/embedding_example/setup_nerv.lua
new file mode 100644
index 0000000..d80c306
--- /dev/null
+++ b/embedding_example/setup_nerv.lua
@@ -0,0 +1,25 @@
+local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
+require 'nerv'
+local arg = {...}
+dofile(arg[1])
+local param_repo = nerv.ParamRepo()
+param_repo:import(gconf.initialized_param, nil, gconf)
+local layer_repo = make_layer_repo(param_repo)
+local network = get_decode_network(layer_repo)
+local global_transf = get_global_transf(layer_repo)
+local batch_size = 1
+network:init(batch_size)
+
+function propagator(input, output)
+    local transformed = nerv.speech_utils.global_transf(
+                            gconf.cumat_type.new_from_host(input),
+                            global_transf, 0, 0, gconf) -- preprocessing
+    local gpu_input = transformed
+    local gpu_output = nerv.CuMatrixFloat(output:nrow(), output:ncol())
+    network:propagate({gpu_input}, {gpu_output})
+    gpu_output:copy_toh(output)
+    -- collect garbage in-time to save GPU memory
+    collectgarbage("collect")
+end
+
+return network.dim_in[1], network.dim_out[1], propagator
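
The script takes a config file path as its first vararg, loads it (which defines the global gconf), builds the parameter and layer repos, and returns the network's input/output dimensions plus a propagator closure. A minimal Lua sketch of how a host chunk could drive it is below; in the actual embedding example the host is C code using the Lua C API, the config path here is hypothetical, and gconf.mmat_type is assumed to be the host-matrix constructor set up by the config file:

    -- load the script as a chunk so we can pass it the config path
    local setup = assert(loadfile("embedding_example/setup_nerv.lua"))
    local dim_in, dim_out, propagate = setup("path/to/config.lua")

    -- batch_size is 1 in setup_nerv.lua, so one row per matrix;
    -- gconf.mmat_type is an assumption about the config's host matrix type
    local input  = gconf.mmat_type(1, dim_in)   -- one frame of input features
    local output = gconf.mmat_type(1, dim_out)  -- receives the network output
    -- ... fill `input` with feature values, then run one forward pass:
    propagate(input, output)

Each call uploads the host input to the GPU, applies the global transform, propagates through the decode network, and copies the result back into the host output matrix, collecting garbage eagerly to keep GPU memory bounded.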