aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <cloudygooseg@gmail.com>2015-12-21 13:36:54 +0800
committertxh18 <cloudygooseg@gmail.com>2015-12-21 13:36:54 +0800
commit7f03ce8da24870f2757473385a75ed990b36d817 (patch)
tree4af71834fa4a33d82f0fa0e4e410745ee8e6d329
parent996472e76c31ba560622841b4b31318244317c84 (diff)
added compressed_label support in the reader
-rw-r--r--nerv/examples/lmptb/lm_trainer.lua12
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua19
2 files changed, 28 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index eab6e2d..06c1a4c 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -23,6 +23,9 @@ function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
end
local reader
local r_conf = {}
+ if p_conf.compressed_label ~= nil then
+ r_conf.compressed_label = p_conf.compressed_label
+ end
local chunk_size, batch_size
if p_conf.one_sen_report == true then --report log prob one by one sentence
if do_train == true then
@@ -156,13 +159,16 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
local reader
local chunk_size, batch_size
local r_conf = {["se_mode"] = true}
+ if p_conf.compressed_label ~= nil then
+ r_conf.compressed_label = p_conf.compressed_label
+ end
if p_conf.one_sen_report == true then --report log prob one by one sentence
if do_train == true then
nerv.warning("LMTrainer.lm_process_file_birnn: warning, one_sen_report is true while do_train is also true, strange")
end
nerv.printf("lm_process_file_birnn: one_sen report mode, set batch_size to 1 and chunk_size to max_sen_len(%d)\n",
global_conf.max_sen_len)
- batch_size = 1
+ batch_size = global_conf.batch_size
chunk_size = global_conf.max_sen_len
else
batch_size = global_conf.batch_size
@@ -239,7 +245,9 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
end
if p_conf.one_sen_report == true then
for i = 1, batch_size do
- nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i])
+ if sen_logp[i] ~= nil then
+ nerv.printf("LMTrainer.lm_process_file_birnn: one_sen_report_output, %f\n", sen_logp[i])
+ end
end
end
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index b603911..0f29f8b 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -24,6 +24,10 @@ function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf)
if r_conf.se_mode == true then
self.se_mode = true
end
+ self.compressed_label = false
+ if r_conf.compressed_label == true then
+ self.compressed_label = true
+ end
end
--fn: string
@@ -46,6 +50,9 @@ function LMReader:open_file(fn)
for j = 1, self.chunk_size, 1 do
self.bak_inputs_m[j] = {}
self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1)
+ if self.compressed_label == true then
+ self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, 1)
+ end
--self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used
end
end
@@ -118,6 +125,9 @@ function LMReader:get_batch(feeds)
end
inputs_s[j][i] = self.vocab.null_token
self.bak_inputs_m[j][1][i - 1][0] = 0
+ if self.compressed_label == true then
+ self.bak_inputs_m[j][2][i - 1][0] = 0
+ end
labels_s[j][i] = self.vocab.null_token
else
self:refresh_stream(i)
@@ -132,7 +142,11 @@ function LMReader:get_batch(feeds)
end
if st.store[st.head + 1] ~= nil then
labels_s[j][i] = st.store[st.head + 1]
- inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
+ if self.compressed_label == true then
+ self.bak_inputs_m[j][2][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1
+ else
+ inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
+ end
else
if (inputs_s[j][i] ~= self.vocab.null_token) then
nerv.error("reader error : input not null but label is null_token")
@@ -169,6 +183,9 @@ function LMReader:get_batch(feeds)
flagsPack[j] = bit.bor(flagsPack[j], flags[j][i])
end
inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1])
+ if self.compressed_label == true then
+ inputs_m[j][2]:copy_fromh(self.bak_inputs_m[j][2])
+ end
end
--check for self.al_sen_start