aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/sample_grulm_ptb_main.lua
diff options
context:
space:
mode:
authortxh18 <cloudygooseg@gmail.com>2016-01-20 15:15:28 +0800
committertxh18 <cloudygooseg@gmail.com>2016-01-20 15:15:28 +0800
commite829ef9253dfece9ff9f599130bc625f1267e136 (patch)
treec23d6661450a178d563599ad2a97ea3560967f72 /nerv/examples/lmptb/sample_grulm_ptb_main.lua
parent20e0c009ab387107ed1a492dd42b0253ca19ed37 (diff)
used prefixsum_row operation to speed up sampling on the softmax output
Diffstat (limited to 'nerv/examples/lmptb/sample_grulm_ptb_main.lua')
-rw-r--r--nerv/examples/lmptb/sample_grulm_ptb_main.lua4
1 files changed, 2 insertions, 2 deletions
diff --git a/nerv/examples/lmptb/sample_grulm_ptb_main.lua b/nerv/examples/lmptb/sample_grulm_ptb_main.lua
index 9a13d36..30dfe26 100644
--- a/nerv/examples/lmptb/sample_grulm_ptb_main.lua
+++ b/nerv/examples/lmptb/sample_grulm_ptb_main.lua
@@ -424,11 +424,11 @@ if commands["sampling"] == 1 then
local dagL = load_net_dagL(global_conf, global_conf.fn_to_sample)
local sampler = nerv.LMSampler(global_conf)
sampler:load_dagL(dagL)
- for k = 1, 5 do
+ for k = 1, 1 do
local res = sampler:lm_sample_rnn_dagL(10, {})
for i = 1, #res do
for j = 1, #res[i] do
- nerv.printf("%s ", res[i][j].w)
+ nerv.printf("%s(%f) ", res[i][j].w, res[i][j].p)
end
nerv.printf("\n")
end