aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/lmptb/lmseqreader.lua38
-rw-r--r--nerv/examples/lmptb/m-tests/dagl_test.lua2
-rw-r--r--nerv/examples/lmptb/m-tests/lmseqreader_test.lua17
-rw-r--r--nerv/examples/lmptb/rnn/layer_tdag.lua11
4 files changed, 44 insertions, 24 deletions
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index 9396783..307c5a3 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -1,4 +1,5 @@
require 'lmptb.lmvocab'
+require 'rnn.layer_tdag'
local LMReader = nerv.class("nerv.LMSeqReader")
@@ -7,11 +8,11 @@ local printf = nerv.printf
--global_conf: table
--batch_size: int
--vocab: nerv.LMVocab
-function LMReader:__init(global_conf, batch_size, seq_size, vocab)
+function LMReader:__init(global_conf, batch_size, chunk_size, vocab)
self.gconf = global_conf
self.fh = nil --file handle to read, nil means currently no file
self.batch_size = batch_size
- self.seq_size = seq_size
+ self.chunk_size = chunk_size
self.log_pre = "[LOG]LMSeqReader:"
self.vocab = vocab
self.streams = nil
@@ -64,27 +65,42 @@ function LMReader:refresh_stream(id)
end
end
-function LMReader:get_batch(input, label)
+--feeds: a table that will be filled by the reader
+--Returns: inputs_m, labels_m
+function LMReader:get_batch(feeds)
+ if (feeds == nil or type(feeds) ~= "table") then
+ nerv.error("feeds is not a table")
+ end
+
+ feeds["inputs_s"] = {}
+ feeds["labels_s"] = {}
+ inputs_s = feeds.inputs_s
+ labels_s = feeds.labels_s
+ for i = 1, self.chunk_size, 1 do
+ inputs_s[i] = {}
+ labels_s[i] = {}
+ end
+
local got_new = false
- for i = 1, self.seq_size, 1 do
+ for i = 1, self.batch_size, 1 do
local st = self.streams[i]
- for j = 1, self.batch_size, 1 do
+ for j = 1, self.chunk_size, 1 do
self:refresh_stream(i)
if (st.store[st.head] ~= nil) then
- input[i][j] = st.store[st.head]
+ inputs_s[j][i] = st.store[st.head]
else
- input[i][j] = self.vocab.null_token
+ inputs_s[j][i] = self.vocab.null_token
end
if (st.store[st.head + 1] ~= nil) then
- label[i][j] = st.store[st.head + 1]
+ labels_s[j][i] = st.store[st.head + 1]
else
- label[i][j] = self.vocab.null_token
+ labels_s[j][i] = self.vocab.null_token
end
- if (input[i][j] ~= self.vocab.null_token) then
+ if (inputs_s[j][i] ~= self.vocab.null_token) then
got_new = true
st.store[st.head] = nil
st.head = st.head + 1
- if (label[i][j] == self.vocab.sen_end_token) then
+ if (labels_s[j][i] == self.vocab.sen_end_token) then
st.store[st.head] = nil --sentence end is passed
st.head = st.head + 1
end
diff --git a/nerv/examples/lmptb/m-tests/dagl_test.lua b/nerv/examples/lmptb/m-tests/dagl_test.lua
index 8959a04..5e90551 100644
--- a/nerv/examples/lmptb/m-tests/dagl_test.lua
+++ b/nerv/examples/lmptb/m-tests/dagl_test.lua
@@ -162,3 +162,5 @@ local paramRepo = prepare_parameters(global_conf, true)
local layerRepo = prepare_layers(global_conf, paramRepo)
local dagL = prepare_dagLayer(global_conf, layerRepo)
dagL:init(global_conf.batch_size, global_conf.chunk_size)
+
+
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
index e1dae8c..504698f 100644
--- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -6,22 +6,23 @@ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-te
--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
local vocab = nerv.LMVocab()
vocab:build_file(test_fn)
-local batch_size = 5
-local seq_size = 3
-local reader = nerv.LMSeqReader({}, batch_size, seq_size, vocab)
+local chunk_size = 5
+local batch_size = 3
+local reader = nerv.LMSeqReader({}, batch_size, chunk_size, vocab)
reader:open_file(test_fn)
local input = {}
local label = {}
-for i = 1, seq_size, 1 do
+for i = 1, batch_size, 1 do
input[i] = {}
label[i] = {}
end
+local feeds = {}
while (1) do
- local r = reader:get_batch(input, label)
+ local r = reader:get_batch(feeds)
if (r == false) then break end
- for j = 1, batch_size, 1 do
- for i = 1, seq_size, 1 do
- printf("%s[L(%s)] ", input[i][j], label[i][j]) --vocab:get_word_str(input[i][j]).id
+ for j = 1, chunk_size, 1 do
+ for i = 1, batch_size, 1 do
+ printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
end
printf("\n")
end
diff --git a/nerv/examples/lmptb/rnn/layer_tdag.lua b/nerv/examples/lmptb/rnn/layer_tdag.lua
index 3fa501e..6e5d774 100644
--- a/nerv/examples/lmptb/rnn/layer_tdag.lua
+++ b/nerv/examples/lmptb/rnn/layer_tdag.lua
@@ -220,15 +220,16 @@ function DAGLayer:set_err_inputs(bp_errs_m)
end
end
---[[
function DAGLayer:set_err_outputs(next_bp_err)
for i = 1, #self.dim_in do
- local layer = self.inputs[i][1]
- local port = self.inputs[i][2]
- layer.err_outputs[port] = next_bp_err[i]
+ if (next_bp_err[i] == nil) then
+ nerv.error("next_bp_err[%d] is not provided", i)
+ end
+ local ref = self.inputs_p[i].ref
+ local p = self.inputs_p[i].port
+ ref.err_outputs_m[p] = next_bp_err[i]
end
end
-]]--
function DAGLayer:update(bp_err, input, output)
self:set_err_inputs(bp_err)