diff options
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmseqreader.lua')
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmseqreader.lua | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index e0dcd95..cc805a4 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -30,6 +30,13 @@ function LMReader:open_file(fn) for i = 1, self.batch_size, 1 do self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0} end + + self.bak_inputs_m = {} --backup MMatrix for temporary storey(then copy to TNN CuMatrix) + for j = 1, self.chunk_size, 1 do + self.bak_inputs_m[j] = {} + self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1) + self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used + end end --id: int @@ -78,7 +85,7 @@ function LMReader:get_batch(feeds) local labels_s = feeds.labels_s for i = 1, self.chunk_size, 1 do inputs_s[i] = {} - labels_s[i] = {} + labels_s[i] = {} end local inputs_m = feeds.inputs_m --port 1 : word_id, port 2 : label @@ -86,20 +93,24 @@ function LMReader:get_batch(feeds) local flagsPack = feeds.flagsPack_now local got_new = false + for j = 1, self.chunk_size, 1 do + inputs_m[j][2]:fill(0) + end for i = 1, self.batch_size, 1 do local st = self.streams[i] for j = 1, self.chunk_size, 1 do flags[j][i] = 0 self:refresh_stream(i) - if (st.store[st.head] ~= nil) then + if st.store[st.head] ~= nil then inputs_s[j][i] = st.store[st.head] - inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 + --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 + self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 else inputs_s[j][i] = self.vocab.null_token - inputs_m[j][1][i - 1][0] = 0 + --inputs_m[j][1][i - 1][0] = 0 + self.bak_inputs_m[j][1][i - 1][0] = 0 end - inputs_m[j][2][i - 1]:fill(0) - if (st.store[st.head + 1] ~= nil) then + if st.store[st.head + 1] ~= nil then labels_s[j][i] = st.store[st.head + 1] inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1 else @@ -116,12 +127,12 @@ function LMReader:get_batch(feeds) got_new = true st.store[st.head] = nil st.head = st.head + 1 - if (labels_s[j][i] == self.vocab.sen_end_token) then + if labels_s[j][i] == self.vocab.sen_end_token then flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END) st.store[st.head] = nil --sentence end is passed st.head = st.head + 1 end - if (inputs_s[j][i] == self.vocab.sen_end_token) then + if inputs_s[j][i] == self.vocab.sen_end_token then flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START) end end @@ -133,6 +144,7 @@ function LMReader:get_batch(feeds) for i = 1, self.batch_size, 1 do flagsPack[j] = bit.bor(flagsPack[j], flags[j][i]) end + inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1]) end if (got_new == false) then |