aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/lm_sampler.lua20
-rw-r--r--nerv/examples/lmptb/sample_grulm_ptb_main.lua15
2 files changed, 25 insertions, 10 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua
index c165127..c25a75c 100644
--- a/nerv/examples/lmptb/lm_sampler.lua
+++ b/nerv/examples/lmptb/lm_sampler.lua
@@ -37,14 +37,16 @@ function LMSampler:load_dagL(dagL)
self.store[i][1].id = self.sen_end_id
self.store[i][1].p = 0
end
+ self.repo = {}
end
function LMSampler:sample_to_store(smout)
for i = 1, self.batch_size do
local ran = math.random()
- local s = 0, id = self.vocab:size()
+ local s = 0
+ local id = self.vocab:size()
for j = 0, self.vocab:size() - 1 do
- s = s + smout[i][j]
+ s = s + smout[i - 1][j]
if s >= ran then
id = j + 1
break
@@ -56,7 +58,7 @@ function LMSampler:sample_to_store(smout)
local tmp = {}
tmp.w = self.vocab:get_word_id(id).str
tmp.id = id
- tmp.p = smout[i][id]
+ tmp.p = smout[i - 1][id - 1]
table.insert(self.store[i], tmp)
end
end
@@ -66,9 +68,8 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
local dagL = self.dagL
local inputs = self.dagL_inputs
local outputs = self.dagL_outputs
-
- local res = {}
- while #res < sample_num do
+
+ while #self.repo < sample_num do
dagL:propagate(inputs, outputs)
inputs[2]:copy_fromd(outputs[2]) --copy hidden activation
@@ -80,7 +81,7 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1
if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end
if #self.store[i] >= 3 then
- res[#res + 1] = self.store[i]
+ self.repo[#self.repo + 1] = self.store[i]
end
inputs[2][i - 1]:fill(0)
self.store[i] = {}
@@ -94,5 +95,10 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
collectgarbage("collect")
end
+ local res = {}
+ for i = 1, sample_num do
+ res[i] = self.repo[#self.repo]
+ self.repo[#self.repo] = nil
+ end
return res
end
diff --git a/nerv/examples/lmptb/sample_grulm_ptb_main.lua b/nerv/examples/lmptb/sample_grulm_ptb_main.lua
index b6351bd..9a13d36 100644
--- a/nerv/examples/lmptb/sample_grulm_ptb_main.lua
+++ b/nerv/examples/lmptb/sample_grulm_ptb_main.lua
@@ -178,7 +178,7 @@ function prepare_dagL(global_conf, layerRepo)
dagL:init(global_conf.batch_size)
nerv.printf("%s Initing DAGL end.\n", global_conf.sche_log_pre)
- return tnn
+ return dagL
end
function load_net_tnn(global_conf, fn)
@@ -191,8 +191,8 @@ end
function load_net_dagL(global_conf, fn)
prepare_parameters(global_conf, fn)
local layerRepo = prepare_layers(global_conf)
- local tnn = prepare_dagL(global_conf, layerRepo)
- return tnn
+ local dagL = prepare_dagL(global_conf, layerRepo)
+ return dagL
end
local train_fn, valid_fn, test_fn
@@ -424,6 +424,15 @@ if commands["sampling"] == 1 then
local dagL = load_net_dagL(global_conf, global_conf.fn_to_sample)
local sampler = nerv.LMSampler(global_conf)
sampler:load_dagL(dagL)
+ for k = 1, 5 do
+ local res = sampler:lm_sample_rnn_dagL(10, {})
+ for i = 1, #res do
+ for j = 1, #res[i] do
+ nerv.printf("%s ", res[i][j].w)
+ end
+ nerv.printf("\n")
+ end
+ end
--global_conf.dropout_rate = 0
--LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update!
end --if commands["sampling"]