diff options
-rw-r--r-- | nerv/examples/lmptb/lm_sampler.lua | 98 | ||||
-rw-r--r-- | nerv/examples/lmptb/sample_grulm_ptb_main.lua | 12 |
2 files changed, 104 insertions, 6 deletions
diff --git a/nerv/examples/lmptb/lm_sampler.lua b/nerv/examples/lmptb/lm_sampler.lua new file mode 100644 index 0000000..c165127 --- /dev/null +++ b/nerv/examples/lmptb/lm_sampler.lua @@ -0,0 +1,98 @@ +local LMSampler = nerv.class('nerv.LMSampler') + +function LMSampler:__init(global_conf) + self.log_pre = "LMSampler" + self.gconf = global_conf + self.vocab = self.gconf.vocab + self.sen_end_token = self.vocab.sen_end_token + self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id +end + +function LMSampler:load_dagL(dagL) + self.batch_size = self.gconf.batch_size + self.chunk_size = self.gconf.chunk_size + + nerv.printf("%s loading dagL\n", self.log_pre) + + self.dagL = dagL + + self.dagL_inputs = {} + self.dagL_inputs[1] = global_conf.cumat_type(global_conf.batch_size, 1) + self.dagL_inputs[1]:fill(self.sen_end_id - 1) + self.dagL_inputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size) + self.dagL_inputs[2]:fill(0) + + self.dagL_outputs = {} + self.dagL_outputs[1] = global_conf.cumat_type(global_conf.batch_size, global_conf.vocab:size()) + self.dagL_outputs[2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size) + + self.smout_d = global_conf.cumat_type(self.batch_size, self.vocab:size()) + self.smout_h = global_conf.mmat_type(self.batch_size, self.vocab:size()) + + self.store = {} + for i = 1, self.batch_size do + self.store[i] = {} + self.store[i][1] = {} + self.store[i][1].w = self.sen_end_token + self.store[i][1].id = self.sen_end_id + self.store[i][1].p = 0 + end +end + +function LMSampler:sample_to_store(smout) + for i = 1, self.batch_size do + local ran = math.random() + local s = 0, id = self.vocab:size() + for j = 0, self.vocab:size() - 1 do + s = s + smout[i][j] + if s >= ran then + id = j + 1 + break + end + end + if #self.store[i] >= self.chunk_size - 2 then + id = self.sen_end_id + end + local tmp = {} + tmp.w = self.vocab:get_word_id(id).str + tmp.id = id + tmp.p = smout[i][id] + table.insert(self.store[i], tmp) + end +end + +--Returns: LMResult +function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf) + local dagL = self.dagL + local inputs = self.dagL_inputs + local outputs = self.dagL_outputs + + local res = {} + while #res < sample_num do + dagL:propagate(inputs, outputs) + inputs[2]:copy_fromd(outputs[2]) --copy hidden activation + + self.smout_d:softmax(outputs[1]) + self.smout_d:copy_toh(self.smout_h) + + self:sample_to_store(self.smout_h) + for i = 1, self.batch_size do + inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1 + if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end + if #self.store[i] >= 3 then + res[#res + 1] = self.store[i] + end + inputs[2][i - 1]:fill(0) + self.store[i] = {} + self.store[i][1] = {} + self.store[i][1].w = self.sen_end_token + self.store[i][1].id = self.sen_end_id + self.store[i][1].p = 0 + end + end + + collectgarbage("collect") + end + + return res +end diff --git a/nerv/examples/lmptb/sample_grulm_ptb_main.lua b/nerv/examples/lmptb/sample_grulm_ptb_main.lua index 86209da..b6351bd 100644 --- a/nerv/examples/lmptb/sample_grulm_ptb_main.lua +++ b/nerv/examples/lmptb/sample_grulm_ptb_main.lua @@ -5,6 +5,7 @@ require 'lmptb.layer.init' --require 'tnn.init' require 'lmptb.lmseqreader' require 'lm_trainer' +require 'lm_sampler' --[[global function rename]]-- --local printf = nerv.printf @@ -134,7 +135,7 @@ function prepare_tnn(global_conf, layerRepo) end function prepare_dagL(global_conf, layerRepo) - nerv.printf("%s Generate and initing TNN ...\n", global_conf.sche_log_pre) + nerv.printf("%s Generate and initing dagL ...\n", global_conf.sche_log_pre) --input: input_w, input_w, ... input_w_now, last_activation local connections_t = { @@ -376,15 +377,12 @@ commands = nerv.SUtil.parse_commands_set(commands_str) if start_lr ~= nil then global_conf.lrate = start_lr end - -nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir) -nerv.LMUtil.wait(2) -os.execute("mkdir -p "..global_conf.work_dir) -os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf) +--[[ --redirecting log outputs! nerv.SUtil.log_redirect(global_conf.log_fn) nerv.LMUtil.wait(2) +]]-- ----------------printing options--------------------------------- nerv.printf("%s printing global_conf...\n", global_conf.sche_log_pre) @@ -424,6 +422,8 @@ if commands["sampling"] == 1 then nerv.printf("===SAMPLE===\n") global_conf.sche_log_pre = "[SCHEDULER SAMPLING]:" local dagL = load_net_dagL(global_conf, global_conf.fn_to_sample) + local sampler = nerv.LMSampler(global_conf) + sampler:load_dagL(dagL) --global_conf.dropout_rate = 0 --LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update! end --if commands["sampling"] |