diff options
author | Determinant <[email protected]> | 2016-02-17 20:14:06 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2016-02-17 20:14:06 +0800 |
commit | 0ee43c21af4fcd3aed070b1f5ad1eb9feb2ad159 (patch) | |
tree | ceb1d38328767fb657bc0d37ec6e513b08a86277 /nerv/examples/lmptb/lmptb/lmseqreader.lua | |
parent | 490a10c2130773bd022f05513fa2905b6a6c6e91 (diff) |
try to merge manually
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmseqreader.lua')
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmseqreader.lua | 34 |
1 file changed, 20 insertions, 14 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index 0f29f8b..1272929 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -28,6 +28,10 @@ function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf) if r_conf.compressed_label == true then self.compressed_label = true end + self.same_io = false + if r_conf.same_io == true then --can be used to train P(wi|w1..(i-1),(i+1)..n) + self.same_io = true + end end --fn: string @@ -36,9 +40,9 @@ function LMReader:open_file(fn) if (self.fh ~= nil) then nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn) end - printf("%s opening file %s...\n", self.log_pre, fn) - print(self.log_pre, "batch_size:", self.batch_size, "chunk_size", self.chunk_size) - print(self.log_pre, "se_mode:", self.se_mode) + nerv.printf("%s opening file %s...\n", self.log_pre, fn) + nerv.printf("%s batch_size:%d chunk_size:%d\n", self.log_pre, self.batch_size, self.chunk_size) + nerv.printf("%s se_mode:%s same_io:%s\n", self.log_pre, tostring(self.se_mode), tostring(self.same_io)) self.fh = io.open(fn, "r") self.streams = {} for i = 1, self.batch_size, 1 do @@ -132,12 +136,15 @@ function LMReader:get_batch(feeds) else self:refresh_stream(i) if st.store[st.head] ~= nil then - inputs_s[j][i] = st.store[st.head] - --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 - self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 + if self.same_io == false then + inputs_s[j][i] = st.store[st.head] + self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 + else + inputs_s[j][i] = st.store[st.head + 1] + self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1 + end else inputs_s[j][i] = self.vocab.null_token - --inputs_m[j][1][i - 1][0] = 0 self.bak_inputs_m[j][1][i - 1][0] = 0 end if 
st.store[st.head + 1] ~= nil then @@ -148,7 +155,7 @@ function LMReader:get_batch(feeds) inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1 end else - if (inputs_s[j][i] ~= self.vocab.null_token) then + if inputs_s[j][i] ~= self.vocab.null_token then nerv.error("reader error : input not null but label is null_token") end labels_s[j][i] = self.vocab.null_token @@ -159,6 +166,9 @@ function LMReader:get_batch(feeds) end flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) --has both input and label got_new = true + if st.store[st.head] == self.vocab.sen_end_token then + flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START) + end st.store[st.head] = nil st.head = st.head + 1 if labels_s[j][i] == self.vocab.sen_end_token then @@ -169,10 +179,7 @@ function LMReader:get_batch(feeds) end_stream = true --meet sentence end, this stream ends now end end - if inputs_s[j][i] == self.vocab.sen_end_token then - flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START) - end - end + end end end end @@ -190,7 +197,7 @@ function LMReader:get_batch(feeds) --check for self.al_sen_start for i = 1, self.batch_size do - if inputs_s[1][i] ~= self.vocab.sen_end_token and inputs_s[1][i] ~= self.vocab.null_token then + if bit.band(flags[1][i], nerv.TNN.FC.SEQ_START) == 0 and flags[1][i] > 0 then self.stat.al_sen_start = false end end @@ -198,7 +205,6 @@ function LMReader:get_batch(feeds) if got_new == false then nerv.info("lmseqreader file ends, printing stats...") nerv.printf("al_sen_start:%s\n", tostring(self.stat.al_sen_start)) - return false else return true |