about summary refs log tree commit diff
path: root/nerv/examples/lmptb/lmptb/lmseqreader.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmseqreader.lua')
-rw-r--r-- nerv/examples/lmptb/lmptb/lmseqreader.lua | 28
1 files changed, 20 insertions, 8 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index e0dcd95..cc805a4 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -30,6 +30,13 @@ function LMReader:open_file(fn)
for i = 1, self.batch_size, 1 do
self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0}
end
+
+    self.bak_inputs_m = {} --backup MMatrix for temporary storage (then copy to TNN CuMatrix)
+ for j = 1, self.chunk_size, 1 do
+ self.bak_inputs_m[j] = {}
+ self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1)
+ self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used
+ end
end
--id: int
@@ -78,7 +85,7 @@ function LMReader:get_batch(feeds)
local labels_s = feeds.labels_s
for i = 1, self.chunk_size, 1 do
inputs_s[i] = {}
- labels_s[i] = {}
+ labels_s[i] = {}
end
local inputs_m = feeds.inputs_m --port 1 : word_id, port 2 : label
@@ -86,20 +93,24 @@ function LMReader:get_batch(feeds)
local flagsPack = feeds.flagsPack_now
local got_new = false
+ for j = 1, self.chunk_size, 1 do
+ inputs_m[j][2]:fill(0)
+ end
for i = 1, self.batch_size, 1 do
local st = self.streams[i]
for j = 1, self.chunk_size, 1 do
flags[j][i] = 0
self:refresh_stream(i)
- if (st.store[st.head] ~= nil) then
+ if st.store[st.head] ~= nil then
inputs_s[j][i] = st.store[st.head]
- inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+ --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+ self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
else
inputs_s[j][i] = self.vocab.null_token
- inputs_m[j][1][i - 1][0] = 0
+ --inputs_m[j][1][i - 1][0] = 0
+ self.bak_inputs_m[j][1][i - 1][0] = 0
end
- inputs_m[j][2][i - 1]:fill(0)
- if (st.store[st.head + 1] ~= nil) then
+ if st.store[st.head + 1] ~= nil then
labels_s[j][i] = st.store[st.head + 1]
inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
else
@@ -116,12 +127,12 @@ function LMReader:get_batch(feeds)
got_new = true
st.store[st.head] = nil
st.head = st.head + 1
- if (labels_s[j][i] == self.vocab.sen_end_token) then
+ if labels_s[j][i] == self.vocab.sen_end_token then
flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END)
st.store[st.head] = nil --sentence end is passed
st.head = st.head + 1
end
- if (inputs_s[j][i] == self.vocab.sen_end_token) then
+ if inputs_s[j][i] == self.vocab.sen_end_token then
flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
end
end
@@ -133,6 +144,7 @@ function LMReader:get_batch(feeds)
for i = 1, self.batch_size, 1 do
flagsPack[j] = bit.bor(flagsPack[j], flags[j][i])
end
+ inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1])
end
if (got_new == false) then