From 7f03ce8da24870f2757473385a75ed990b36d817 Mon Sep 17 00:00:00 2001 From: txh18 Date: Mon, 21 Dec 2015 13:36:54 +0800 Subject: added compressed_label support in the reader --- nerv/examples/lmptb/lm_trainer.lua | 12 ++++++++++-- nerv/examples/lmptb/lmptb/lmseqreader.lua | 19 ++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua index eab6e2d..06c1a4c 100644 --- a/nerv/examples/lmptb/lm_trainer.lua +++ b/nerv/examples/lmptb/lm_trainer.lua @@ -23,6 +23,9 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf) end local reader local r_conf = {} + if p_conf.compressed_label ~= nil then + r_conf.compressed_label = p_conf.compressed_label + end local chunk_size, batch_size if p_conf.one_sen_report == true then --report log prob one by one sentence if do_train == true then @@ -156,13 +159,16 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf) local reader local chunk_size, batch_size local r_conf = {["se_mode"] = true} + if p_conf.compressed_label ~= nil then + r_conf.compressed_label = p_conf.compressed_label + end if p_conf.one_sen_report == true then --report log prob one by one sentence if do_train == true then nerv.warning("LMTrainer.lm_process_file_birnn: warning, one_sen_report is true while do_train is also true, strange") end nerv.printf("lm_process_file_birnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n", global_conf.max_sen_len) - batch_size = 1 + batch_size = global_conf.batch_size chunk_size = global_conf.max_sen_len else batch_size = global_conf.batch_size @@ -239,7 +245,9 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf) end if p_conf.one_sen_report == true then for i = 1, batch_size do - nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i]) + if sen_logp[i] ~= nil then + nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i]) + end end end diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index b603911..0f29f8b 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -24,6 +24,10 @@ function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf) if r_conf.se_mode == true then self.se_mode = true end + self.compressed_label = false + if r_conf.compressed_label == true then + self.compressed_label = true + end end --fn: string @@ -46,6 +50,9 @@ function LMReader:open_file(fn) for j = 1, self.chunk_size, 1 do self.bak_inputs_m[j] = {} self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1) + if self.compressed_label == true then + self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, 1) + end --self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used end end @@ -118,6 +125,9 @@ function LMReader:get_batch(feeds) end inputs_s[j][i] = self.vocab.null_token self.bak_inputs_m[j][1][i - 1][0] = 0 + if self.compressed_label == true then + self.bak_inputs_m[j][2][i - 1][0] = 0 + end labels_s[j][i] = self.vocab.null_token else self:refresh_stream(i) @@ -132,7 +142,11 @@ function LMReader:get_batch(feeds) end if st.store[st.head + 1] ~= nil then labels_s[j][i] = st.store[st.head + 1] - inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1 + if self.compressed_label == true then + self.bak_inputs_m[j][2][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1 + else + inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1 + end else if (inputs_s[j][i] ~= self.vocab.null_token) then nerv.error("reader error : input not null but label is null_token") @@ -169,6 +183,9 @@ function LMReader:get_batch(feeds) flagsPack[j] = bit.bor(flagsPack[j], flags[j][i]) end inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1]) + if self.compressed_label == true then + inputs_m[j][2]:copy_fromh(self.bak_inputs_m[j][2]) + end end --check for self.al_sen_start -- cgit v1.2.3