From df2a5d287c1889da0d3c91a2f057086b5a080be7 Mon Sep 17 00:00:00 2001 From: txh18 Date: Wed, 2 Dec 2015 21:24:54 +0800 Subject: added se_mode for lmseqreader, todo:check it --- nerv/examples/lmptb/lmptb/lmseqreader.lua | 86 +++++++++++++++--------- nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 3 +- 2 files changed, 56 insertions(+), 33 deletions(-) diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index 04eba45..ff07415 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -1,4 +1,5 @@ require 'lmptb.lmvocab' +require 'tnn.init' local LMReader = nerv.class("nerv.LMSeqReader") @@ -7,7 +8,7 @@ local printf = nerv.printf --global_conf: table --batch_size: int --vocab: nerv.LMVocab -function LMReader:__init(global_conf, batch_size, chunk_size, vocab) +function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf) self.gconf = global_conf self.fh = nil --file handle to read, nil means currently no file self.batch_size = batch_size @@ -15,6 +16,13 @@ function LMReader:__init(global_conf, batch_size, chunk_size, vocab) self.log_pre = "[LOG]LMSeqReader:" self.vocab = vocab self.streams = nil + if r_conf == nil then + r_conf = {} + end + self.se_mode = false --sentence end mode, when a sentence end is met, the stream after will be null + if r_conf.se_mode == true then + self.se_mode = true + end end --fn: string @@ -25,6 +33,7 @@ function LMReader:open_file(fn) end printf("%s opening file %s...\n", self.log_pre, fn) print(self.log_pre, "batch_size:", self.batch_size, "chunk_size", self.chunk_size) + print(self.log_pre, "se_mode:", self.se_mode) self.fh = io.open(fn, "r") self.streams = {} for i = 1, self.batch_size, 1 do @@ -35,7 +44,7 @@ function LMReader:open_file(fn) for j = 1, self.chunk_size, 1 do self.bak_inputs_m[j] = {} self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1) - self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used + --self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used end end @@ -98,44 +107,57 @@ function LMReader:get_batch(feeds) end for i = 1, self.batch_size, 1 do local st = self.streams[i] + local end_stream = false --used for se_mode, indicating that this stream is ended for j = 1, self.chunk_size, 1 do flags[j][i] = 0 - self:refresh_stream(i) - if st.store[st.head] ~= nil then - inputs_s[j][i] = st.store[st.head] - --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 - self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 - else + if end_stream == true then + if self.se_mode == false then + nerv.error("lmseqreader:getbatch: error, end_stream is true while se_mode is false") + end inputs_s[j][i] = self.vocab.null_token - --inputs_m[j][1][i - 1][0] = 0 self.bak_inputs_m[j][1][i - 1][0] = 0 - end - if st.store[st.head + 1] ~= nil then - labels_s[j][i] = st.store[st.head + 1] - inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1 + labels_s[j][i] = self.vocab.null_token else - if (inputs_s[j][i] ~= self.vocab.null_token) then - nerv.error("reader error : input not null but label is null_token") + self:refresh_stream(i) + if st.store[st.head] ~= nil then + inputs_s[j][i] = st.store[st.head] + --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 + self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 + else + inputs_s[j][i] = self.vocab.null_token + --inputs_m[j][1][i - 1][0] = 0 + self.bak_inputs_m[j][1][i - 1][0] = 0 end - labels_s[j][i] = self.vocab.null_token - end - if (inputs_s[j][i] ~= self.vocab.null_token) then - if (labels_s[j][i] == self.vocab.null_token) then - nerv.error("reader error : label is null while input is not null") + if st.store[st.head + 1] ~= nil then + labels_s[j][i] = st.store[st.head + 1] + inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1 + else + if (inputs_s[j][i] ~= self.vocab.null_token) then + nerv.error("reader error : input not null but label is null_token") + end + labels_s[j][i] = self.vocab.null_token end - flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) - got_new = true - st.store[st.head] = nil - st.head = st.head + 1 - if labels_s[j][i] == self.vocab.sen_end_token then - flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END) - st.store[st.head] = nil --sentence end is passed + if inputs_s[j][i] ~= self.vocab.null_token then + if labels_s[j][i] == self.vocab.null_token then + nerv.error("reader error : label is null while input is not null") + end + flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) --has both input and label + got_new = true + st.store[st.head] = nil st.head = st.head + 1 - end - if inputs_s[j][i] == self.vocab.sen_end_token then - flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START) - end - end + if labels_s[j][i] == self.vocab.sen_end_token then + flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END) + st.store[st.head] = nil --sentence end is passed + st.head = st.head + 1 + if self.se_mode == true then + end_stream = true --meet sentence end, this stream ends now + end + end + if inputs_s[j][i] == self.vocab.sen_end_token then + flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START) + end + end + end end end diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua index cbcdcbe..b98ff95 100644 --- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua +++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua @@ -12,7 +12,7 @@ local batch_size = 3 local global_conf = { lrate = 1, wcost = 1e-6, momentum = 0, cumat_type = nerv.CuMatrixFloat, - mmat_type = nerv.CuMatrixFloat, + mmat_type = nerv.MMatrixFloat, hidden_size = 20, chunk_size = chunk_size, @@ -35,6 +35,7 @@ reader:open_file(test_fn) local feeds = {} feeds.flags_now = {} feeds.inputs_m = {} +feeds.flagsPack_now = {} for j = 1, chunk_size do feeds.inputs_m[j] = {global_conf.cumat_type(batch_size, 1), global_conf.cumat_type(batch_size, global_conf.vocab:size())} feeds.flags_now[j] = {} -- cgit v1.2.3-70-g09d2