about summary refs log tree commit diff
path: root/nerv/examples/lmptb/lmptb/lmseqreader.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmseqreader.lua')
-rw-r--r-- nerv/examples/lmptb/lmptb/lmseqreader.lua 34
1 files changed, 20 insertions, 14 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index 0f29f8b..1272929 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -28,6 +28,10 @@ function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf)
if r_conf.compressed_label == true then
self.compressed_label = true
end
+ self.same_io = false
+ if r_conf.same_io == true then --can be used to train P(wi|w1..(i-1),(i+1)..n)
+ self.same_io = true
+ end
end
--fn: string
@@ -36,9 +40,9 @@ function LMReader:open_file(fn)
if (self.fh ~= nil) then
nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
end
- printf("%s opening file %s...\n", self.log_pre, fn)
- print(self.log_pre, "batch_size:", self.batch_size, "chunk_size", self.chunk_size)
- print(self.log_pre, "se_mode:", self.se_mode)
+ nerv.printf("%s opening file %s...\n", self.log_pre, fn)
+ nerv.printf("%s batch_size:%d chunk_size:%d\n", self.log_pre, self.batch_size, self.chunk_size)
+ nerv.printf("%s se_mode:%s same_io:%s\n", self.log_pre, tostring(self.se_mode), tostring(self.same_io))
self.fh = io.open(fn, "r")
self.streams = {}
for i = 1, self.batch_size, 1 do
@@ -132,12 +136,15 @@ function LMReader:get_batch(feeds)
else
self:refresh_stream(i)
if st.store[st.head] ~= nil then
- inputs_s[j][i] = st.store[st.head]
- --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
- self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+ if self.same_io == false then
+ inputs_s[j][i] = st.store[st.head]
+ self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+ else
+ inputs_s[j][i] = st.store[st.head + 1]
+ self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1
+ end
else
inputs_s[j][i] = self.vocab.null_token
- --inputs_m[j][1][i - 1][0] = 0
self.bak_inputs_m[j][1][i - 1][0] = 0
end
if st.store[st.head + 1] ~= nil then
@@ -148,7 +155,7 @@ function LMReader:get_batch(feeds)
inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
end
else
- if (inputs_s[j][i] ~= self.vocab.null_token) then
+ if inputs_s[j][i] ~= self.vocab.null_token then
nerv.error("reader error : input not null but label is null_token")
end
labels_s[j][i] = self.vocab.null_token
@@ -159,6 +166,9 @@ function LMReader:get_batch(feeds)
end
flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) --has both input and label
got_new = true
+ if st.store[st.head] == self.vocab.sen_end_token then
+ flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
+ end
st.store[st.head] = nil
st.head = st.head + 1
if labels_s[j][i] == self.vocab.sen_end_token then
@@ -169,10 +179,7 @@ function LMReader:get_batch(feeds)
end_stream = true --meet sentence end, this stream ends now
end
end
- if inputs_s[j][i] == self.vocab.sen_end_token then
- flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
- end
- end
+ end
end
end
end
@@ -190,7 +197,7 @@ function LMReader:get_batch(feeds)
--check for self.al_sen_start
for i = 1, self.batch_size do
- if inputs_s[1][i] ~= self.vocab.sen_end_token and inputs_s[1][i] ~= self.vocab.null_token then
+ if bit.band(flags[1][i], nerv.TNN.FC.SEQ_START) == 0 and flags[1][i] > 0 then
self.stat.al_sen_start = false
end
end
@@ -198,7 +205,6 @@ function LMReader:get_batch(feeds)
if got_new == false then
nerv.info("lmseqreader file ends, printing stats...")
nerv.printf("al_sen_start:%s\n", tostring(self.stat.al_sen_start))
-
return false
else
return true