diff options
Diffstat (limited to 'nerv/examples/lmptb/lmptb/lmseqreader.lua')
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmseqreader.lua | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index ed791d2..0f29f8b 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -1,4 +1,5 @@ require 'lmptb.lmvocab' +require 'lmptb.lmutil' --require 'tnn.init' local LMReader = nerv.class("nerv.LMSeqReader") @@ -23,6 +24,10 @@ function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf) if r_conf.se_mode == true then self.se_mode = true end + self.compressed_label = false + if r_conf.compressed_label == true then + self.compressed_label = true + end end --fn: string @@ -45,6 +50,9 @@ function LMReader:open_file(fn) for j = 1, self.chunk_size, 1 do self.bak_inputs_m[j] = {} self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1) + if self.compressed_label == true then + self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, 1) + end --self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used end end @@ -58,7 +66,7 @@ function LMReader:refresh_stream(id) local st = self.streams[id] if (st.store[st.head] ~= nil) then return end if (self.fh == nil) then return end - local list = self.vocab:read_line(self.fh) + local list = nerv.LMUtil.read_line(self.fh) if (list == nil) then --file has end printf("%s file expires, closing.\n", self.log_pre) self.fh:close() @@ -117,6 +125,9 @@ function LMReader:get_batch(feeds) end inputs_s[j][i] = self.vocab.null_token self.bak_inputs_m[j][1][i - 1][0] = 0 + if self.compressed_label == true then + self.bak_inputs_m[j][2][i - 1][0] = 0 + end labels_s[j][i] = self.vocab.null_token else self:refresh_stream(i) @@ -131,7 +142,11 @@ function LMReader:get_batch(feeds) end if st.store[st.head + 1] ~= nil then labels_s[j][i] = st.store[st.head + 1] - inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1 + if self.compressed_label == true then + self.bak_inputs_m[j][2][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1 + else + inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1 + end else if (inputs_s[j][i] ~= self.vocab.null_token) then nerv.error("reader error : input not null but label is null_token") @@ -168,6 +183,9 @@ function LMReader:get_batch(feeds) flagsPack[j] = bit.bor(flagsPack[j], flags[j][i]) end inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1]) + if self.compressed_label == true then + inputs_m[j][2]:copy_fromh(self.bak_inputs_m[j][2]) + end end --check for self.al_sen_start |