require 'lmptb.lmvocab'
require 'lmptb.lmutil'
--require 'tnn.init'

local LMReader = nerv.class("nerv.LMSeqReader")
local printf = nerv.printf

--global_conf: table
--batch_size: int
--chunk_size: int
--vocab: nerv.LMVocab
--r_conf: table of optional reader settings (se_mode, compressed_label, same_io)
function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf)
    self.gconf = global_conf
    self.fh = nil --file handle to read from, nil means no file is currently open
    self.batch_size = batch_size
    self.chunk_size = chunk_size
    self.log_pre = "[LOG]LMSeqReader:"
    self.vocab = vocab
    self.streams = nil
    if r_conf == nil then
        r_conf = {}
    end
    self.se_mode = false --sentence-end mode: once a sentence end is met, the rest of the stream is null
    if r_conf.se_mode == true then
        self.se_mode = true
    end
    self.compressed_label = false
    if r_conf.compressed_label == true then
        self.compressed_label = true
    end
    self.same_io = false
    if r_conf.same_io == true then --can be used to train P(wi|w1..(i-1),(i+1)..n)
        self.same_io = true
    end
end

--fn: string
--Open the file and initialize all streams
function LMReader:open_file(fn)
    if (self.fh ~= nil) then
        nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
    end
    nerv.printf("%s opening file %s...\n", self.log_pre, fn)
    nerv.printf("%s batch_size:%d chunk_size:%d\n", self.log_pre, self.batch_size, self.chunk_size)
    nerv.printf("%s se_mode:%s same_io:%s\n", self.log_pre, tostring(self.se_mode), tostring(self.same_io))
    self.fh = io.open(fn, "r")
    self.streams = {}
    for i = 1, self.batch_size, 1 do
        self.streams[i] = {["store"] = {}, ["head"] = 1, ["tail"] = 0}
    end
    self.stat = {} --stats collected during file reading
    self.stat.al_sen_start = true --whether every minibatch starts at a sentence start
    self.bak_inputs_m = {} --backup MMatrix for temporary storage (later copied to the TNN CuMatrix)
    for j = 1, self.chunk_size, 1 do
        self.bak_inputs_m[j] = {}
        self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1)
        if self.compressed_label == true then
            self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, 1)
        end
        --self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used
    end
end

--id: int
--Refresh stream id: read a line from the file and check that it is cntklm-style
function LMReader:refresh_stream(id)
    if (self.streams[id] == nil) then
        nerv.error("stream %d does not exist.", id)
    end
    local st = self.streams[id]
    if (st.store[st.head] ~= nil) then
        return
    end
    if (self.fh == nil) then
        return
    end
    local list = nerv.LMUtil.read_line(self.fh)
    if (list == nil) then --the file has ended
        printf("%s file ends, closing.\n", self.log_pre)
        self.fh:close()
        self.fh = nil
        return
    end
    --some sanity checks
    if (list[1] ~= self.vocab.sen_end_token or list[#list] ~= self.vocab.sen_end_token) then --check for cntklm-style input
        nerv.error("%s sentence does not begin and end with the sentence-end token: %s", self.log_pre, table.tostring(list))
    end
    for i = 2, #list - 1, 1 do
        if (list[i] == self.vocab.sen_end_token) then
            nerv.error("%s got a sentence-end token in the middle of a line (%s) in file", self.log_pre, table.tostring(list))
        end
    end
    for i = 1, #list, 1 do
        st.tail = st.tail + 1
        st.store[st.tail] = list[i]
    end
end

--feeds: a table that will be filled by the reader
--Returns: bool, false once the file is exhausted and no new token was read
function LMReader:get_batch(feeds)
    if (feeds == nil or type(feeds) ~= "table") then
        nerv.error("feeds is not a table")
    end

    feeds["inputs_s"] = {}
    feeds["labels_s"] = {}
    local inputs_s = feeds.inputs_s
    local labels_s = feeds.labels_s
    for i = 1, self.chunk_size, 1 do
        inputs_s[i] = {}
        labels_s[i] = {}
    end

    local inputs_m = feeds.inputs_m --port 1: word_id, port 2: label
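
    --Flag semantics used below (as set by the code in this function):
    --  feeds.flags_now[j][i] is a bit-or of nerv.TNN.FC constants for position j of stream i:
    --    SEQ_NORM  : this position carries both an input word and a label
    --    SEQ_START : the input is the sentence-end token, i.e. a new sentence starts here
    --    SEQ_END   : the label is the sentence-end token, i.e. the current sentence ends here
    --  feeds.flagsPack_now[j] is the bit-or of the flags of all streams at position j.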
    local flags = feeds.flags_now
    local flagsPack = feeds.flagsPack_now

    local got_new = false
    for j = 1, self.chunk_size, 1 do
        inputs_m[j][2]:fill(0)
    end
    for i = 1, self.batch_size, 1 do
        local st = self.streams[i]
        local end_stream = false --used for se_mode, indicating that this stream has ended
        for j = 1, self.chunk_size, 1 do
            flags[j][i] = 0
            if end_stream == true then
                if self.se_mode == false then
                    nerv.error("lmseqreader:getbatch: error, end_stream is true while se_mode is false")
                end
                inputs_s[j][i] = self.vocab.null_token
                self.bak_inputs_m[j][1][i - 1][0] = 0
                if self.compressed_label == true then
                    self.bak_inputs_m[j][2][i - 1][0] = 0
                end
                labels_s[j][i] = self.vocab.null_token
            else
                self:refresh_stream(i)
                if st.store[st.head] ~= nil then
                    if self.same_io == false then
                        inputs_s[j][i] = st.store[st.head]
                        self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
                    else
                        inputs_s[j][i] = st.store[st.head + 1]
                        self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1
                    end
                else
                    inputs_s[j][i] = self.vocab.null_token
                    self.bak_inputs_m[j][1][i - 1][0] = 0
                end
                if st.store[st.head + 1] ~= nil then
                    labels_s[j][i] = st.store[st.head + 1]
                    if self.compressed_label == true then
                        self.bak_inputs_m[j][2][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1
                    else
                        inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
                    end
                else
                    if inputs_s[j][i] ~= self.vocab.null_token then
                        nerv.error("reader error: input is not null but label is null_token")
                    end
                    labels_s[j][i] = self.vocab.null_token
                end
                if inputs_s[j][i] ~= self.vocab.null_token then
                    if labels_s[j][i] == self.vocab.null_token then
                        nerv.error("reader error: label is null while input is not null")
                    end
                    flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) --has both input and label
                    got_new = true
                    if st.store[st.head] == self.vocab.sen_end_token then
                        flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
                    end
                    st.store[st.head] = nil
                    st.head = st.head + 1
                    if labels_s[j][i] == self.vocab.sen_end_token then
                        flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END)
                        st.store[st.head] = nil --the sentence end has been passed
                        st.head = st.head + 1
                        if self.se_mode == true then
                            end_stream = true --met sentence end, this stream ends now
                        end
                    end
                end
            end
        end
    end

    for j = 1, self.chunk_size, 1 do
        flagsPack[j] = 0
        for i = 1, self.batch_size, 1 do
            flagsPack[j] = bit.bor(flagsPack[j], flags[j][i])
        end
        inputs_m[j][1]:copy_fromh(self.bak_inputs_m[j][1])
        if self.compressed_label == true then
            inputs_m[j][2]:copy_fromh(self.bak_inputs_m[j][2])
        end
    end

    --check for self.stat.al_sen_start
    for i = 1, self.batch_size do
        if bit.band(flags[1][i], nerv.TNN.FC.SEQ_START) == 0 and flags[1][i] > 0 then
            self.stat.al_sen_start = false
        end
    end

    if got_new == false then
        nerv.info("lmseqreader file ends, printing stats...")
        nerv.printf("al_sen_start:%s\n", tostring(self.stat.al_sen_start))
        return false
    else
        return true
    end
end

--[[
do
    local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/some-text"
    --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
    local vocab = nerv.LMVocab()
    vocab:build_file(test_fn)
    local batch_size = 3
    local feeder = nerv.LMFeeder({}, batch_size, vocab)
    feeder:open_file(test_fn)
    while (1) do
        local list = feeder:get_batch()
        if (list == nil) then
            break
        end
        for i = 1, batch_size, 1 do
            printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id)
        end
        printf("\n")
    end
end
]]--
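
--A rough usage sketch for LMSeqReader (commented out, in the same spirit as the test
--block above). It is not part of the reader; it only illustrates the feeds layout that
--get_batch expects. It assumes the TNN module is loaded (for nerv.TNN.FC), that the
--global_conf provides nerv.MMatrixFloat/nerv.CuMatrixFloat as mmat_type/cumat_type,
--and that "some-text" is a corpus with one cntklm-style sentence per line; adjust
--these names to your own setup.
--[[
do
    local test_fn = "some-text" --hypothetical corpus file
    local gconf = {mmat_type = nerv.MMatrixFloat, cumat_type = nerv.CuMatrixFloat}
    local vocab = nerv.LMVocab()
    vocab:build_file(test_fn)
    local batch_size, chunk_size = 3, 5
    local reader = nerv.LMSeqReader(gconf, batch_size, chunk_size, vocab, {se_mode = true})
    reader:open_file(test_fn)

    --pre-allocate the fields that get_batch reads and writes:
    --inputs_m[j][1] is (batch_size x 1) word ids, inputs_m[j][2] is (batch_size x vocab_size)
    --one-hot labels (it would be (batch_size x 1) with compressed_label enabled)
    local feeds = {}
    feeds.inputs_m = {}
    feeds.flags_now = {}
    feeds.flagsPack_now = {}
    for j = 1, chunk_size do
        feeds.inputs_m[j] = {gconf.cumat_type(batch_size, 1),
                             gconf.cumat_type(batch_size, vocab:size())}
        feeds.flags_now[j] = {}
        feeds.flagsPack_now[j] = 0
    end

    while reader:get_batch(feeds) do
        for j = 1, chunk_size do
            for i = 1, batch_size do
                printf("%s/%s ", feeds.inputs_s[j][i], feeds.labels_s[j][i])
            end
            printf("\n")
        end
        printf("---\n")
    end
end
]]--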