From 4d349c4b8639e074aafa7d4245231bf1f3decae6 Mon Sep 17 00:00:00 2001 From: txh18 Date: Mon, 2 Nov 2015 22:55:24 +0800 Subject: ... --- nerv/examples/lmptb/lmptb/lmseqreader.lua | 38 +++++++++++++++++------- nerv/examples/lmptb/m-tests/dagl_test.lua | 2 ++ nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 17 ++++++----- nerv/examples/lmptb/rnn/layer_tdag.lua | 11 +++---- 4 files changed, 44 insertions(+), 24 deletions(-) diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index 9396783..307c5a3 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -1,4 +1,5 @@ require 'lmptb.lmvocab' +require 'rnn.layer_tdag' local LMReader = nerv.class("nerv.LMSeqReader") @@ -7,11 +8,11 @@ local printf = nerv.printf --global_conf: table --batch_size: int --vocab: nerv.LMVocab -function LMReader:__init(global_conf, batch_size, seq_size, vocab) +function LMReader:__init(global_conf, batch_size, chunk_size, vocab) self.gconf = global_conf self.fh = nil --file handle to read, nil means currently no file self.batch_size = batch_size - self.seq_size = seq_size + self.chunk_size = chunk_size self.log_pre = "[LOG]LMSeqReader:" self.vocab = vocab self.streams = nil @@ -64,27 +65,42 @@ function LMReader:refresh_stream(id) end end -function LMReader:get_batch(input, label) +--feeds: a table that will be filled by the reader +--Returns: inputs_m, labels_m +function LMReader:get_batch(feeds) + if (feeds == nil or type(feeds) ~= "table") then + nerv.error("feeds is not a table") + end + + feeds["inputs_s"] = {} + feeds["labels_s"] = {} + inputs_s = feeds.inputs_s + labels_s = feeds.labels_s + for i = 1, self.chunk_size, 1 do + inputs_s[i] = {} + labels_s[i] = {} + end + local got_new = false - for i = 1, self.seq_size, 1 do + for i = 1, self.batch_size, 1 do local st = self.streams[i] - for j = 1, self.batch_size, 1 do + for j = 1, self.chunk_size, 1 do self:refresh_stream(i) if (st.store[st.head] ~= nil) then - input[i][j] = st.store[st.head] + inputs_s[j][i] = st.store[st.head] else - input[i][j] = self.vocab.null_token + inputs_s[j][i] = self.vocab.null_token end if (st.store[st.head + 1] ~= nil) then - label[i][j] = st.store[st.head + 1] + labels_s[j][i] = st.store[st.head + 1] else - label[i][j] = self.vocab.null_token + labels_s[j][i] = self.vocab.null_token end - if (input[i][j] ~= self.vocab.null_token) then + if (inputs_s[j][i] ~= self.vocab.null_token) then got_new = true st.store[st.head] = nil st.head = st.head + 1 - if (label[i][j] == self.vocab.sen_end_token) then + if (labels_s[j][i] == self.vocab.sen_end_token) then st.store[st.head] = nil --sentence end is passed st.head = st.head + 1 end diff --git a/nerv/examples/lmptb/m-tests/dagl_test.lua b/nerv/examples/lmptb/m-tests/dagl_test.lua index 8959a04..5e90551 100644 --- a/nerv/examples/lmptb/m-tests/dagl_test.lua +++ b/nerv/examples/lmptb/m-tests/dagl_test.lua @@ -162,3 +162,5 @@ local paramRepo = prepare_parameters(global_conf, true) local layerRepo = prepare_layers(global_conf, paramRepo) local dagL = prepare_dagLayer(global_conf, layerRepo) dagL:init(global_conf.batch_size, global_conf.chunk_size) + + diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua index e1dae8c..504698f 100644 --- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua +++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua @@ -6,22 +6,23 @@ local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-te --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt" local vocab = nerv.LMVocab() vocab:build_file(test_fn) -local batch_size = 5 -local seq_size = 3 -local reader = nerv.LMSeqReader({}, batch_size, seq_size, vocab) +local chunk_size = 5 +local batch_size = 3 +local reader = nerv.LMSeqReader({}, batch_size, chunk_size, vocab) reader:open_file(test_fn) local input = {} local label = {} -for i = 1, seq_size, 1 do +for i = 1, batch_size, 1 do input[i] = {} label[i] = {} end +local feeds = {} while (1) do - local r = reader:get_batch(input, label) + local r = reader:get_batch(feeds) if (r == false) then break end - for j = 1, batch_size, 1 do - for i = 1, seq_size, 1 do - printf("%s[L(%s)] ", input[i][j], label[i][j]) --vocab:get_word_str(input[i][j]).id + for j = 1, chunk_size, 1 do + for i = 1, batch_size, 1 do + printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id end printf("\n") end diff --git a/nerv/examples/lmptb/rnn/layer_tdag.lua b/nerv/examples/lmptb/rnn/layer_tdag.lua index 3fa501e..6e5d774 100644 --- a/nerv/examples/lmptb/rnn/layer_tdag.lua +++ b/nerv/examples/lmptb/rnn/layer_tdag.lua @@ -220,15 +220,16 @@ function DAGLayer:set_err_inputs(bp_errs_m) end end ---[[ function DAGLayer:set_err_outputs(next_bp_err) for i = 1, #self.dim_in do - local layer = self.inputs[i][1] - local port = self.inputs[i][2] - layer.err_outputs[port] = next_bp_err[i] + if (next_bp_err[i] == nil) then + nerv.error("next_bp_err[%d] is not provided", i) + end + local ref = self.inputs_p[i].ref + local p = self.inputs_p[i].port + ref.err_outputs_m[p] = next_bp_err[i] end end -]]-- function DAGLayer:update(bp_err, input, output) self:set_err_inputs(bp_err) -- cgit v1.2.3-70-g09d2