author     txh18 <[email protected]>   2016-02-03 14:13:33 +0800
committer  txh18 <[email protected]>   2016-02-03 14:13:33 +0800
commit     2fc05a9b3bb28ea8cae66c82b891028cccc40e53 (patch)
tree       f26f0d8f87ee6f65bf3f30336f1e323f9c45e515
parent     bb0f58c82882d34ee1737227476167be9367433c (diff)
added same_io option to lm_seq_reader
-rw-r--r--   nerv/examples/lmptb/lmptb/lmseqreader.lua          34
-rw-r--r--   nerv/examples/lmptb/m-tests/lmseqreader_test.lua   10
-rw-r--r--   nerv/examples/lmptb/m-tests/some-text               2
3 files changed, 27 insertions, 19 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index 0f29f8b..1272929 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -28,6 +28,10 @@ function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf)
     if r_conf.compressed_label == true then
         self.compressed_label = true
     end
+    self.same_io = false
+    if r_conf.same_io == true then --can be used to train P(wi|w1..(i-1),(i+1)..n)
+        self.same_io = true
+    end
 end
 
 --fn: string
@@ -36,9 +40,9 @@ function LMReader:open_file(fn)
     if (self.fh ~= nil) then
         nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
     end
-    printf("%s opening file %s...\n", self.log_pre, fn)
-    print(self.log_pre, "batch_size:", self.batch_size, "chunk_size", self.chunk_size)
-    print(self.log_pre, "se_mode:", self.se_mode)
+    nerv.printf("%s opening file %s...\n", self.log_pre, fn)
+    nerv.printf("%s batch_size:%d chunk_size:%d\n", self.log_pre, self.batch_size, self.chunk_size)
+    nerv.printf("%s se_mode:%s same_io:%s\n", self.log_pre, tostring(self.se_mode), tostring(self.same_io))
     self.fh = io.open(fn, "r")
     self.streams = {}
     for i = 1, self.batch_size, 1 do
@@ -132,12 +136,15 @@ function LMReader:get_batch(feeds)
             else
                 self:refresh_stream(i)
                 if st.store[st.head] ~= nil then
-                    inputs_s[j][i] = st.store[st.head]
-                    --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
-                    self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+                    if self.same_io == false then
+                        inputs_s[j][i] = st.store[st.head]
+                        self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+                    else
+                        inputs_s[j][i] = st.store[st.head + 1]
+                        self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1
+                    end
                 else
                     inputs_s[j][i] = self.vocab.null_token
-                    --inputs_m[j][1][i - 1][0] = 0
                     self.bak_inputs_m[j][1][i - 1][0] = 0
                 end
                 if st.store[st.head + 1] ~= nil then
@@ -148,7 +155,7 @@ function LMReader:get_batch(feeds)
                         inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
                     end
                 else
-                    if (inputs_s[j][i] ~= self.vocab.null_token) then
+                    if inputs_s[j][i] ~= self.vocab.null_token then
                         nerv.error("reader error : input not null but label is null_token")
                     end
                     labels_s[j][i] = self.vocab.null_token
@@ -159,6 +166,9 @@
                     end
                     flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) --has both input and label
                     got_new = true
+                    if st.store[st.head] == self.vocab.sen_end_token then
+                        flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
+                    end
                     st.store[st.head] = nil
                     st.head = st.head + 1
                     if labels_s[j][i] == self.vocab.sen_end_token then
@@ -169,10 +179,7 @@ end
                             end_stream = true --meet sentence end, this stream ends now
                         end
                     end
-                    if inputs_s[j][i] == self.vocab.sen_end_token then
-                        flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
-                    end
-                end
+                end
             end
         end
     end
@@ -190,7 +197,7 @@
 
     --check for self.al_sen_start
     for i = 1, self.batch_size do
-        if inputs_s[1][i] ~= self.vocab.sen_end_token and inputs_s[1][i] ~= self.vocab.null_token then
+        if bit.band(flags[1][i], nerv.TNN.FC.SEQ_START) == 0 and flags[1][i] > 0 then
             self.stat.al_sen_start = false
         end
     end
@@ -198,7 +205,6 @@
     if got_new == false then
         nerv.info("lmseqreader file ends, printing stats...")
         nerv.printf("al_sen_start:%s\n", tostring(self.stat.al_sen_start))
-
         return false
     else
         return true
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
index 9127559..3f99741 100644
--- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -7,7 +7,7 @@ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-te
 --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
 local vocab = nerv.LMVocab()
 vocab:build_file(test_fn)
-local chunk_size = 20
+local chunk_size = 15
 local batch_size = 3
 local global_conf = {
     lrate = 1, wcost = 1e-6, momentum = 0,
@@ -30,7 +30,8 @@ local global_conf = {
     vocab = vocab
 }
 
-local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab, {["se_mode"] = true})
+local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab,
+    {["se_mode"] = true, ["same_io"] = true})
 reader:open_file(test_fn)
 local feeds = {}
 feeds.flags_now = {}
@@ -40,14 +41,15 @@ for j = 1, chunk_size do
     feeds.inputs_m[j] = {global_conf.cumat_type(batch_size, 1), global_conf.cumat_type(batch_size, global_conf.vocab:size())}
     feeds.flags_now[j] = {}
 end
-while (1) do
+for k = 1, 5 do
     local r = reader:get_batch(feeds)
     if (r == false) then break end
     for j = 1, chunk_size, 1 do
         for i = 1, batch_size, 1 do
-            printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
+            printf("%s[L(%s)]F%d ", feeds.inputs_s[j][i], feeds.labels_s[j][i], feeds.flags_now[j][i]) --vocab:get_word_str(input[i][j]).id
         end
         printf("\n")
     end
     printf("\n")
 end
+printf("reader.sen_start %s\n", tostring(reader.stat.al_sen_start))
diff --git a/nerv/examples/lmptb/m-tests/some-text b/nerv/examples/lmptb/m-tests/some-text
index da4bea9..6756fa0 100644
--- a/nerv/examples/lmptb/m-tests/some-text
+++ b/nerv/examples/lmptb/m-tests/some-text
@@ -1,4 +1,4 @@
-</s> aa bb cc aa bb cc aa bb cc aa bb cc aa bb cc aa </s>
+</s> aa bb cc aa bb cc aa bb cc aa bb cc aa </s>
 </s> aa bb cc aa bb cc aa bb cc aa </s>
 </s> bb cc aa bb cc aa bb cc aa </s>
 </s> aa bb cc aa </s>
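
The snippet below is a minimal usage sketch of the new option, condensed from the updated m-tests/lmseqreader_test.lua above. It assumes global_conf, batch_size, chunk_size, vocab, test_fn, and the feeds table are prepared exactly as in that test; only the reader construction and read loop are shown.

-- Sketch of same_io usage (setup assumed to follow m-tests/lmseqreader_test.lua).
-- With same_io = true the reader feeds the label word itself as the input,
-- which can be used to train P(wi|w1..(i-1),(i+1)..n)-style models.
local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab,
                                {["se_mode"] = true, ["same_io"] = true})
reader:open_file(test_fn)
while true do
    if reader:get_batch(feeds) == false then break end
    -- here feeds.inputs_s[j][i] mirrors feeds.labels_s[j][i];
    -- SEQ_START is set in feeds.flags_now[j][i] whenever the word at the
    -- stream head is the sentence-end token
end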