summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2016-02-03 14:13:33 +0800
committertxh18 <[email protected]>2016-02-03 14:13:33 +0800
commit2fc05a9b3bb28ea8cae66c82b891028cccc40e53 (patch)
treef26f0d8f87ee6f65bf3f30336f1e323f9c45e515
parentbb0f58c82882d34ee1737227476167be9367433c (diff)
added same_io option to lm_seq_reader
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua34
-rw-r--r--nerv/examples/lmptb/m-tests/lmseqreader_test.lua10
-rw-r--r--nerv/examples/lmptb/m-tests/some-text2
3 files changed, 27 insertions, 19 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index 0f29f8b..1272929 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -28,6 +28,10 @@ function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf)
if r_conf.compressed_label == true then
self.compressed_label = true
end
+ self.same_io = false
+ if r_conf.same_io == true then --can be used to train P(wi|w1..(i-1),(i+1)..n)
+ self.same_io = true
+ end
end
--fn: string
@@ -36,9 +40,9 @@ function LMReader:open_file(fn)
if (self.fh ~= nil) then
nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
end
- printf("%s opening file %s...\n", self.log_pre, fn)
- print(self.log_pre, "batch_size:", self.batch_size, "chunk_size", self.chunk_size)
- print(self.log_pre, "se_mode:", self.se_mode)
+ nerv.printf("%s opening file %s...\n", self.log_pre, fn)
+ nerv.printf("%s batch_size:%d chunk_size:%d\n", self.log_pre, self.batch_size, self.chunk_size)
+ nerv.printf("%s se_mode:%s same_io:%s\n", self.log_pre, tostring(self.se_mode), tostring(self.same_io))
self.fh = io.open(fn, "r")
self.streams = {}
for i = 1, self.batch_size, 1 do
@@ -132,12 +136,15 @@ function LMReader:get_batch(feeds)
else
self:refresh_stream(i)
if st.store[st.head] ~= nil then
- inputs_s[j][i] = st.store[st.head]
- --inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
- self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+ if self.same_io == false then
+ inputs_s[j][i] = st.store[st.head]
+ self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1
+ else
+ inputs_s[j][i] = st.store[st.head + 1]
+ self.bak_inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head + 1]).id - 1
+ end
else
inputs_s[j][i] = self.vocab.null_token
- --inputs_m[j][1][i - 1][0] = 0
self.bak_inputs_m[j][1][i - 1][0] = 0
end
if st.store[st.head + 1] ~= nil then
@@ -148,7 +155,7 @@ function LMReader:get_batch(feeds)
inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1
end
else
- if (inputs_s[j][i] ~= self.vocab.null_token) then
+ if inputs_s[j][i] ~= self.vocab.null_token then
nerv.error("reader error : input not null but label is null_token")
end
labels_s[j][i] = self.vocab.null_token
@@ -159,6 +166,9 @@ function LMReader:get_batch(feeds)
end
flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) --has both input and label
got_new = true
+ if st.store[st.head] == self.vocab.sen_end_token then
+ flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
+ end
st.store[st.head] = nil
st.head = st.head + 1
if labels_s[j][i] == self.vocab.sen_end_token then
@@ -169,10 +179,7 @@ function LMReader:get_batch(feeds)
end_stream = true --meet sentence end, this stream ends now
end
end
- if inputs_s[j][i] == self.vocab.sen_end_token then
- flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START)
- end
- end
+ end
end
end
end
@@ -190,7 +197,7 @@ function LMReader:get_batch(feeds)
--check for self.al_sen_start
for i = 1, self.batch_size do
- if inputs_s[1][i] ~= self.vocab.sen_end_token and inputs_s[1][i] ~= self.vocab.null_token then
+ if bit.band(flags[1][i], nerv.TNN.FC.SEQ_START) == 0 and flags[1][i] > 0 then
self.stat.al_sen_start = false
end
end
@@ -198,7 +205,6 @@ function LMReader:get_batch(feeds)
if got_new == false then
nerv.info("lmseqreader file ends, printing stats...")
nerv.printf("al_sen_start:%s\n", tostring(self.stat.al_sen_start))
-
return false
else
return true
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
index 9127559..3f99741 100644
--- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -7,7 +7,7 @@ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-te
--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
local vocab = nerv.LMVocab()
vocab:build_file(test_fn)
-local chunk_size = 20
+local chunk_size = 15
local batch_size = 3
local global_conf = {
lrate = 1, wcost = 1e-6, momentum = 0,
@@ -30,7 +30,8 @@ local global_conf = {
vocab = vocab
}
-local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab, {["se_mode"] = true})
+local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab,
+ {["se_mode"] = true, ["same_io"] = true})
reader:open_file(test_fn)
local feeds = {}
feeds.flags_now = {}
@@ -40,14 +41,15 @@ for j = 1, chunk_size do
feeds.inputs_m[j] = {global_conf.cumat_type(batch_size, 1), global_conf.cumat_type(batch_size, global_conf.vocab:size())}
feeds.flags_now[j] = {}
end
-while (1) do
+for k = 1, 5 do
local r = reader:get_batch(feeds)
if (r == false) then break end
for j = 1, chunk_size, 1 do
for i = 1, batch_size, 1 do
- printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
+ printf("%s[L(%s)]F%d ", feeds.inputs_s[j][i], feeds.labels_s[j][i], feeds.flags_now[j][i]) --vocab:get_word_str(input[i][j]).id
end
printf("\n")
end
printf("\n")
end
+printf("reader.sen_start %s\n", tostring(reader.stat.al_sen_start))
diff --git a/nerv/examples/lmptb/m-tests/some-text b/nerv/examples/lmptb/m-tests/some-text
index da4bea9..6756fa0 100644
--- a/nerv/examples/lmptb/m-tests/some-text
+++ b/nerv/examples/lmptb/m-tests/some-text
@@ -1,4 +1,4 @@
-</s> aa bb cc aa bb cc aa bb cc aa bb cc aa bb cc aa </s>
+</s> aa bb cc aa bb cc aa bb cc aa bb cc aa </s>
</s> aa bb cc aa bb cc aa bb cc aa </s>
</s> bb cc aa bb cc aa bb cc aa </s>
</s> aa bb cc aa </s>