diff options
-rw-r--r-- | nerv/examples/lmptb/lm_sampler.lua | 20 | ||||
-rw-r--r-- | nerv/examples/lmptb/sample_grulm_ptb_main.lua | 15 |
2 files changed, 25 insertions, 10 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua index c165127..c25a75c 100644 --- a/nerv/examples/lmptb/lm_sampler.lua +++ b/nerv/examples/lmptb/lm_sampler.lua @@ -37,14 +37,16 @@ function LMSampler:load_dagL(dagL) self.store[i][1].id = self.sen_end_id self.store[i][1].p = 0 end + self.repo = {} end function LMSampler:sample_to_store(smout) for i = 1, self.batch_size do local ran = math.random() - local s = 0, id = self.vocab:size() + local s = 0 + local id = self.vocab:size() for j = 0, self.vocab:size() - 1 do - s = s + smout[i][j] + s = s + smout[i - 1][j] if s >= ran then id = j + 1 break @@ -56,7 +58,7 @@ function LMSampler:sample_to_store(smout) local tmp = {} tmp.w = self.vocab:get_word_id(id).str tmp.id = id - tmp.p = smout[i][id] + tmp.p = smout[i - 1][id - 1] table.insert(self.store[i], tmp) end end @@ -66,9 +68,8 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) local dagL = self.dagL local inputs = self.dagL_inputs local outputs = self.dagL_outputs - - local res = {} - while #res < sample_num do + + while #self.repo < sample_num do dagL:propagate(inputs, outputs) inputs[2]:copy_fromd(outputs[2]) --copy hidden activation @@ -80,7 +81,7 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1 if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end if #self.store[i] >= 3 then - res[#res + 1] = self.store[i] + self.repo[#self.repo + 1] = self.store[i] end inputs[2][i - 1]:fill(0) self.store[i] = {} @@ -94,5 +95,10 @@ function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) collectgarbage("collect") end + local res = {} + for i = 1, sample_num do + res[i] = self.repo[#self.repo] + self.repo[#self.repo] = nil + end return res end diff --git a/nerv/examples/lmptb/sample_grulm_ptb_main.lua b/nerv/examples/lmptb/sample_grulm_ptb_main.lua index b6351bd..9a13d36 100644 --- a/nerv/examples/lmptb/sample_grulm_ptb_main.lua +++ b/nerv/examples/lmptb/sample_grulm_ptb_main.lua @@ -178,7 +178,7 @@ function prepare_dagL(global_conf, layerRepo) dagL:init(global_conf.batch_size) nerv.printf("%s Initing DAGL end.\n", global_conf.sche_log_pre) - return tnn + return dagL end function load_net_tnn(global_conf, fn) @@ -191,8 +191,8 @@ end function load_net_dagL(global_conf, fn) prepare_parameters(global_conf, fn) local layerRepo = prepare_layers(global_conf) - local tnn = prepare_dagL(global_conf, layerRepo) - return tnn + local dagL = prepare_dagL(global_conf, layerRepo) + return dagL end local train_fn, valid_fn, test_fn @@ -424,6 +424,15 @@ if commands["sampling"] == 1 then local dagL = load_net_dagL(global_conf, global_conf.fn_to_sample) local sampler = nerv.LMSampler(global_conf) sampler:load_dagL(dagL) + for k = 1, 5 do + local res = sampler:lm_sample_rnn_dagL(10, {}) + for i = 1, #res do + for j = 1, #res[i] do + nerv.printf("%s ", res[i][j].w) + end + nerv.printf("\n") + end + end --global_conf.dropout_rate = 0 --LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update! end --if commands["sampling"] |