summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua2
-rw-r--r--nerv/examples/lmptb/lmptb/layer/select_linear.lua21
-rw-r--r--nerv/examples/lmptb/lmptb/lmutil.lua12
-rw-r--r--nerv/examples/lmptb/main.lua11
4 files changed, 27 insertions, 19 deletions
diff --git a/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua b/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua
index c43e567..a5ecce1 100644
--- a/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua
+++ b/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua
@@ -14,7 +14,7 @@ function LMRecurrent:propagate(input, output)
output[1]:copy_fromd(input[1])
if (self.independent == true) then
for i = 1, input[1]:nrow() do
- if (self.gconf.input_word_id[self.id][i - 1][0] == self.break_id) then --here is sentence break
+ if (self.gconf.input_word_id[self.id][0][i - 1] == self.break_id) then --here is sentence break
input[2][i - 1]:fill(0)
end
end
diff --git a/nerv/examples/lmptb/lmptb/layer/select_linear.lua b/nerv/examples/lmptb/lmptb/layer/select_linear.lua
index 4798536..e4afac4 100644
--- a/nerv/examples/lmptb/lmptb/layer/select_linear.lua
+++ b/nerv/examples/lmptb/lmptb/layer/select_linear.lua
@@ -30,22 +30,23 @@ function SL:init(batch_size)
end
function SL:update(bp_err, input, output)
- for i = 1, input[1]:nrow(), 1 do
- if (input[1][i - 1][0] ~= 0) then
- local word_vec = self.ltp.trans[input[1][i - 1][0] - 1]
+ for i = 1, input[1]:ncol(), 1 do
+ if (input[1][0][i - 1] ~= 0) then
+ local word_vec = self.ltp.trans[input[1][0][i - 1]]
word_vec:add(word_vec, bp_err[1][i - 1], 1, - self.gconf.lrate / self.gconf.batch_size)
end
end
end
function SL:propagate(input, output)
- for i = 0, input[1]:nrow() - 1, 1 do
- if (input[1][i][0] > 0) then
- output[1][i]:copy_fromd(self.ltp.trans[input[1][i][0] - 1])
- else
- output[1][i]:fill(0)
- end
- end
+ --for i = 0, input[1]:ncol() - 1, 1 do
+ -- if (input[1][0][i] > 0) then
+ -- output[1][i]:copy_fromd(self.ltp.trans[input[1][0][i]])
+ -- else
+ -- output[1][i]:fill(0)
+ -- end
+ --end
+ output[1]:copy_rows_fromd_by_idx(self.ltp.trans, input[1])
end
function SL:back_propagate(bp_err, next_bp_err, input, output)
diff --git a/nerv/examples/lmptb/lmptb/lmutil.lua b/nerv/examples/lmptb/lmptb/lmutil.lua
index 7f45a49..77babef 100644
--- a/nerv/examples/lmptb/lmptb/lmutil.lua
+++ b/nerv/examples/lmptb/lmptb/lmutil.lua
@@ -1,5 +1,11 @@
local Util = nerv.class("nerv.LMUtil")
+--function rounds a number to the given number of decimal places.
+function Util.round(num, idp)
+ local mult = 10^(idp or 0)
+ return math.floor(num * mult + 0.5) / mult
+end
+
--list: table, list of string(word)
--vocab: nerv.LMVocab
--ty: nerv.CuMatrix
@@ -42,15 +48,15 @@ end
--Returns: nerv.MMatrixInt
--Set the matrix to be ids of the words, id starting at 1, not 0
function Util.set_id(m, list, vocab)
- if (m:nrow() ~= #list or m:ncol() ~= 1) then
+ if (m:ncol() ~= #list or m:nrow() ~= 1) then
nerv.error("nrow of matrix mismatch with list or its col not one")
end
for i = 1, #list, 1 do
--index in matrix starts at 0
if (list[i] ~= vocab.null_token) then
- m[i - 1][0] = vocab:get_word_str(list[i]).id
+ m[0][i - 1] = vocab:get_word_str(list[i]).id
else
- m[i - 1][0] = 0
+ m[0][i - 1] = 0
end
end
return m
diff --git a/nerv/examples/lmptb/main.lua b/nerv/examples/lmptb/main.lua
index 9b39e83..13d610e 100644
--- a/nerv/examples/lmptb/main.lua
+++ b/nerv/examples/lmptb/main.lua
@@ -15,7 +15,7 @@ function prepare_parameters(global_conf, first_time)
if (first_time) then
ltp_ih = nerv.LinearTransParam("ltp_ih", global_conf)
- ltp_ih.trans = global_conf.cumat_type(global_conf.vocab:size(), global_conf.hidden_size)
+ ltp_ih.trans = global_conf.cumat_type(global_conf.vocab:size() + 1, global_conf.hidden_size) --index 0 is for zero, others correspond to vocab index(starting from 1)
ltp_ih.trans:generate(global_conf.param_random)
ltp_hh = nerv.LinearTransParam("ltp_hh", global_conf)
@@ -164,7 +164,7 @@ function propagateFile(global_conf, dagL, fn, config)
local dagL_input = {}
for i = 1, global_conf.bptt + 1 do
- dagL_input[i] = nerv.MMatrixInt(global_conf.batch_size, 1)
+ dagL_input[i] = global_conf.cumat_type(1, global_conf.batch_size) --changed to row vector, debughtx
end
dagL_input[global_conf.bptt + 2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size)
dagL_input[global_conf.bptt + 3] = global_conf.cumat_type(global_conf.batch_size, global_conf.vocab:size())
@@ -192,7 +192,7 @@ function propagateFile(global_conf, dagL, fn, config)
global_conf.input_word_id["recurrentL"..i] = dagL_input[i] --for IndRecurrent
end
dagL_input[global_conf.bptt + 2]:copy_fromd(hidden_store[tnow - global_conf.bptt - 1])
- nerv.LMUtil.set_onehot(dagL_input[global_conf.bptt + 3], token_store[tnow + 1], global_conf.vocab)
+ nerv.LMUtil.set_onehot(dagL_input[global_conf.bptt + 3], token_store[tnow + 1], global_conf.vocab) --for softmax
--local dagL_input = create_dag_input(global_conf, token_store, hidden_store, tnow)
global_conf.timer:tic("dagL-propagate")
@@ -224,6 +224,7 @@ function propagateFile(global_conf, dagL, fn, config)
for key, value in pairs(global_conf.timer.rec) do
printf("\t [global_conf.timer]: time spent on %s:%.5fs\n", key, value)
end
+ --comment this for debughtx
global_conf.timer:flush()
--nerv.CuMatrix.print_profile()
--nerv.CuMatrix.clear_profile()
@@ -277,10 +278,10 @@ if (set == "ptb") then
valid_fn = valid_fn,
test_fn = test_fn,
sche_log_pre = "[SCHEDULER]:",
- log_w_num = 10000, --give a message when log_w_num words have been processed
+ log_w_num = 500000, --give a message when log_w_num words have been processed
timer = nerv.Timer()
}
- global_conf.work_dir = work_dir_base.."/h"..global_conf.hidden_size.."bp"..global_conf.bptt.."slr"..global_conf.lrate..os.date("_%bD%dH%H")
+ global_conf.work_dir = work_dir_base.."/h"..global_conf.hidden_size.."bp"..global_conf.bptt.."slr"..global_conf.lrate --..os.date("_%bD%dH%H") --comment this for testing
global_conf.param_fn = global_conf.work_dir.."/params"
elseif (set == "test") then
train_fn = "/slfs1/users/txh18/workspace/nerv-project/some-text"