aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua86
-rw-r--r--nerv/examples/lmptb/m-tests/lmseqreader_test.lua3
2 files changed, 56 insertions, 33 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index 04eba45..ff07415 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -1,4 +1,5 @@
require 'lmptb.lmvocab'
+require 'tnn.init'
local LMReader = nerv.class("nerv.LMSeqReader")
@@ -7,7 +8,7 @@ local printf = nerv.printf
--global_conf: table
--batch_size: int
--vocab: nerv.LMVocab
-function LMReader:__init(global_conf, batch_size, chunk_size, vocab)
+function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf)
self.gconf = global_conf
self.fh = nil --file handle to read, nil means currently no file
self.batch_size = batch_size
@@ -15,6 +16,13 @@ function LMReader:__init(global_conf, batch_size, chunk_size, vocab)
self.log_pre = "[LOG]LMSeqReader:"
self.vocab = vocab
self.streams = nil
+ if r_conf == nil then
+ r_conf = {}
+ end
+ self.se_mode = false --sentence end mode, when a sentence end is met, the stream after will be null
+ if r_conf.se_mode == true then
+ self.se_mode = true
+ end
end
--fn: string
@@ -25,6 +33,7 @@ function LMReader:open_file(fn)
end
printf("%s opening file %s...\n", self.log_pre, fn)
print(self.log_pre, "batch_size:", self.batch_size, "chunk_size", self.chunk_size)
+ print(self.log_pre, "se_mode:", self.se_mode)
self.fh = io.open(fn, "r")
self.streams = {}
for i = 1, self.batch_size, 1 do
@@ -35,7 +44,7 @@ function LMReader:open_file(fn)
for j = 1, self.chunk_size, 1 do
self.bak_inputs_m[j] = {}
self.bak_inputs_m[j][1] = self.gconf.mmat_type(self.batch_size, 1)
- self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used
+ --self.bak_inputs_m[j][2] = self.gconf.mmat_type(self.batch_size, self.vocab:size()) --since MMatrix does not yet have fill, this m[j][2] is not used
end
end
@@ -98,44 +107,57 @@ function LMReader:get_batch(feeds)
end
for i = 1, self.batch_size, 1 do
local st = self.streams[i]
+ local end_stream = false --used for se_mode, indicating that this stream is ended
for j = 1, self.chunk_size, 1 do
flags[j][i] = 0
- self:refresh_stream(i)
- if st.store[st.head] ~= nil then
- inputs_s[j][i] = st.store[st.head]
- --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
- self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
- else
+ if end_stream == true then
+ if self.se_mode == false then
+ nerv.error("lmseqreader:getbatch: error, end_stream is true while se_mode is false")
+ end
inputs_s[j][i] = self.vocab.null_token
- --inputs_m[j][1][i - 1][0] = 0
self.bak_inputs_m[j][1][i - 1][0] = 0
- end
- if st.store[st.head + 1] ~= nil then
- labels_s[j][i] = st.store[st.head + 1]
- inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
+ labels_s[j][i] = self.vocab.null_token
else
- if (inputs_s[j][i] ~= self.vocab.null_token) then
- nerv.error("reader error : input not null but label is null_token")
+ self:refresh_stream(i)
+ if st.store[st.head] ~= nil then
+ inputs_s[j][i] = st.store[st.head]
+ --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+ self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+ else
+ inputs_s[j][i] = self.vocab.null_token
+ --inputs_m[j][1][i - 1][0] = 0
+ self.bak_inputs_m[j][1][i - 1][0] = 0
end
- labels_s[j][i] = self.vocab.null_token
- end
- if (inputs_s[j][i] ~= self.vocab.null_token) then
- if (labels_s[j][i] == self.vocab.null_token) then
- nerv.error("reader error : label is null while input is not null")
+ if st.store[st.head + 1] ~= nil then
+ labels_s[j][i] = st.store[st.head + 1]
+ inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
+ else
+ if (inputs_s[j][i] ~= self.vocab.null_token) then
+ nerv.error("reader error : input not null but label is null_token")
+ end
+ labels_s[j][i] = self.vocab.null_token
end
- flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM)
- got_new = true
- st.store[st.head] = nil
- st.head = st.head + 1
- if labels_s[j][i] == self.vocab.sen_end_token then
- flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END)
- st.store[st.head] = nil --sentence end is passed
+ if inputs_s[j][i] ~= self.vocab.null_token then
+ if labels_s[j][i] == self.vocab.null_token then
+ nerv.error("reader error : label is null while input is not null")
+ end
+ flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) --has both input and label
+ got_new = true
+ st.store[st.head] = nil
st.head = st.head + 1
- end
- if inputs_s[j][i] == self.vocab.sen_end_token then
- flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
- end
- end
+ if labels_s[j][i] == self.vocab.sen_end_token then
+ flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END)
+ st.store[st.head] = nil --sentence end is passed
+ st.head = st.head + 1
+ if self.se_mode == true then
+ end_stream = true --meet sentence end, this stream ends now
+ end
+ end
+ if inputs_s[j][i] == self.vocab.sen_end_token then
+ flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
+ end
+ end
+ end
end
end
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
index cbcdcbe..b98ff95 100644
--- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -12,7 +12,7 @@ local batch_size = 3
local global_conf = {
lrate = 1, wcost = 1e-6, momentum = 0,
cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.CuMatrixFloat,
+ mmat_type = nerv.MMatrixFloat,
hidden_size = 20,
chunk_size = chunk_size,
@@ -35,6 +35,7 @@ reader:open_file(test_fn)
local feeds = {}
feeds.flags_now = {}
feeds.inputs_m = {}
+feeds.flagsPack_now = {}
for j = 1, chunk_size do
feeds.inputs_m[j] = {global_conf.cumat_type(batch_size, 1), global_conf.cumat_type(batch_size, global_conf.vocab:size())}
feeds.flags_now[j] = {}