aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <cloudygooseg@gmail.com>2016-01-19 19:28:42 +0800
committertxh18 <cloudygooseg@gmail.com>2016-01-19 19:28:42 +0800
commitae051c0195374ddfab6a0e693d2c2cfa34f24e99 (patch)
treed287fb6cfaa787342fe449da33769074e24a45f5
parentacc2b72b08aa7fe16b85fb47b0308e3da5adae66 (diff)
added lm_sampler, todo:test it
-rw-r--r--nerv/examples/lmptb/lm_sampler.lua98
-rw-r--r--nerv/examples/lmptb/sample_grulm_ptb_main.lua12
2 files changed, 104 insertions, 6 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua
new file mode 100644
index 0000000..c165127
--- /dev/null
+++ b/nerv/examples/lmptb/lm_sampler.lua
@@ -0,0 +1,98 @@
+local LMSampler = nerv.class('nerv.LMSampler')
+
+function LMSampler:__init(global_conf)
+ self.log_pre = "LMSampler"
+ self.gconf = global_conf
+ self.vocab = self.gconf.vocab
+ self.sen_end_token = self.vocab.sen_end_token
+ self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id
+end
+
+function LMSampler:load_dagL(dagL)
+ self.batch_size = self.gconf.batch_size
+ self.chunk_size = self.gconf.chunk_size
+
+ nerv.printf("%s loading dagL\n", self.log_pre)
+
+ self.dagL = dagL
+
+ self.dagL_inputs = {}
+ self.dagL_inputs[1] = global_conf.cumat_type(global_conf.batch_size, 1)
+ self.dagL_inputs[1]:fill(self.sen_end_id - 1)
+ self.dagL_inputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
+ self.dagL_inputs[2]:fill(0)
+
+ self.dagL_outputs = {}
+ self.dagL_outputs[1] = global_conf.cumat_type(global_conf.batch_size, global_conf.vocab:size())
+ self.dagL_outputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
+
+ self.smout_d = global_conf.cumat_type(self.batch_size, self.vocab:size())
+ self.smout_h = global_conf.mmat_type(self.batch_size, self.vocab:size())
+
+ self.store = {}
+ for i = 1, self.batch_size do
+ self.store[i] = {}
+ self.store[i][1] = {}
+ self.store[i][1].w = self.sen_end_token
+ self.store[i][1].id = self.sen_end_id
+ self.store[i][1].p = 0
+ end
+end
+
+function LMSampler:sample_to_store(smout)
+ for i = 1, self.batch_size do
+ local ran = math.random()
+ local s = 0, id = self.vocab:size()
+ for j = 0, self.vocab:size() - 1 do
+ s = s + smout[i][j]
+ if s >= ran then
+ id = j + 1
+ break
+ end
+ end
+ if #self.store[i] >= self.chunk_size - 2 then
+ id = self.sen_end_id
+ end
+ local tmp = {}
+ tmp.w = self.vocab:get_word_id(id).str
+ tmp.id = id
+ tmp.p = smout[i][id]
+ table.insert(self.store[i], tmp)
+ end
+end
+
+--Returns: LMResult
+function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
+ local dagL = self.dagL
+ local inputs = self.dagL_inputs
+ local outputs = self.dagL_outputs
+
+ local res = {}
+ while #res < sample_num do
+ dagL:propagate(inputs, outputs)
+ inputs[2]:copy_fromd(outputs[2]) --copy hidden activation
+
+ self.smout_d:softmax(outputs[1])
+ self.smout_d:copy_toh(self.smout_h)
+
+ self:sample_to_store(self.smout_h)
+ for i = 1, self.batch_size do
+ inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1
+ if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end
+ if #self.store[i] >= 3 then
+ res[#res + 1] = self.store[i]
+ end
+ inputs[2][i - 1]:fill(0)
+ self.store[i] = {}
+ self.store[i][1] = {}
+ self.store[i][1].w = self.sen_end_token
+ self.store[i][1].id = self.sen_end_id
+ self.store[i][1].p = 0
+ end
+ end
+
+ collectgarbage("collect")
+ end
+
+ return res
+end
diff --git a/nerv/examples/lmptb/sample_grulm_ptb_main.lua b/nerv/examples/lmptb/sample_grulm_ptb_main.lua
index 86209da..b6351bd 100644
--- a/nerv/examples/lmptb/sample_grulm_ptb_main.lua
+++ b/nerv/examples/lmptb/sample_grulm_ptb_main.lua
@@ -5,6 +5,7 @@ require 'lmptb.layer.init'
--require 'tnn.init'
require 'lmptb.lmseqreader'
require 'lm_trainer'
+require 'lm_sampler'
--[[global function rename]]--
--local printf = nerv.printf
@@ -134,7 +135,7 @@ function prepare_tnn(global_conf, layerRepo)
end
function prepare_dagL(global_conf, layerRepo)
- nerv.printf("%s Generate and initing TNN ...\n", global_conf.sche_log_pre)
+ nerv.printf("%s Generate and initing dagL ...\n", global_conf.sche_log_pre)
--input: input_w, input_w, ... input_w_now, last_activation
local connections_t = {
@@ -376,15 +377,12 @@ commands = nerv.SUtil.parse_commands_set(commands_str)
if start_lr ~= nil then
global_conf.lrate = start_lr
end
-
-nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir)
-nerv.LMUtil.wait(2)
-os.execute("mkdir -p "..global_conf.work_dir)
-os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf)
+--[[
--redirecting log outputs!
nerv.SUtil.log_redirect(global_conf.log_fn)
nerv.LMUtil.wait(2)
+]]--
----------------printing options---------------------------------
nerv.printf("%s printing global_conf...\n", global_conf.sche_log_pre)
@@ -424,6 +422,8 @@ if commands["sampling"] == 1 then
nerv.printf("===SAMPLE===\n")
global_conf.sche_log_pre = "[SCHEDULER SAMPLING]:"
local dagL = load_net_dagL(global_conf, global_conf.fn_to_sample)
+ local sampler = nerv.LMSampler(global_conf)
+ sampler:load_dagL(dagL)
--global_conf.dropout_rate = 0
--LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update!
end --if commands["sampling"]