diff options
-rw-r--r-- | nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua | 2 | ||||
-rw-r--r-- | nerv/examples/lmptb/lmptb/layer/select_linear.lua | 21 | ||||
-rw-r--r-- | nerv/examples/lmptb/lmptb/lmutil.lua | 12 | ||||
-rw-r--r-- | nerv/examples/lmptb/main.lua | 11 |
4 files changed, 27 insertions, 19 deletions
diff --git a/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua b/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua index c43e567..a5ecce1 100644 --- a/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua +++ b/nerv/examples/lmptb/lmptb/layer/lm_affine_recurrent.lua @@ -14,7 +14,7 @@ function LMRecurrent:propagate(input, output) output[1]:copy_fromd(input[1]) if (self.independent == true) then for i = 1, input[1]:nrow() do - if (self.gconf.input_word_id[self.id][i - 1][0] == self.break_id) then --here is sentence break + if (self.gconf.input_word_id[self.id][0][i - 1] == self.break_id) then --here is sentence break input[2][i - 1]:fill(0) end end diff --git a/nerv/examples/lmptb/lmptb/layer/select_linear.lua b/nerv/examples/lmptb/lmptb/layer/select_linear.lua index 4798536..e4afac4 100644 --- a/nerv/examples/lmptb/lmptb/layer/select_linear.lua +++ b/nerv/examples/lmptb/lmptb/layer/select_linear.lua @@ -30,22 +30,23 @@ function SL:init(batch_size) end function SL:update(bp_err, input, output) - for i = 1, input[1]:nrow(), 1 do - if (input[1][i - 1][0] ~= 0) then - local word_vec = self.ltp.trans[input[1][i - 1][0] - 1] + for i = 1, input[1]:ncol(), 1 do + if (input[1][0][i - 1] ~= 0) then + local word_vec = self.ltp.trans[input[1][0][i - 1]] word_vec:add(word_vec, bp_err[1][i - 1], 1, - self.gconf.lrate / self.gconf.batch_size) end end end function SL:propagate(input, output) - for i = 0, input[1]:nrow() - 1, 1 do - if (input[1][i][0] > 0) then - output[1][i]:copy_fromd(self.ltp.trans[input[1][i][0] - 1]) - else - output[1][i]:fill(0) - end - end + --for i = 0, input[1]:ncol() - 1, 1 do + -- if (input[1][0][i] > 0) then + -- output[1][i]:copy_fromd(self.ltp.trans[input[1][0][i]]) + -- else + -- output[1][i]:fill(0) + -- end + --end + output[1]:copy_rows_fromd_by_idx(self.ltp.trans, input[1]) end function SL:back_propagate(bp_err, next_bp_err, input, output) diff --git a/nerv/examples/lmptb/lmptb/lmutil.lua b/nerv/examples/lmptb/lmptb/lmutil.lua index 7f45a49..77babef 100644 --- a/nerv/examples/lmptb/lmptb/lmutil.lua +++ b/nerv/examples/lmptb/lmptb/lmutil.lua @@ -1,5 +1,11 @@ local Util = nerv.class("nerv.LMUtil") +--function rounds a number to the given number of decimal places. +function Util.round(num, idp) + local mult = 10^(idp or 0) + return math.floor(num * mult + 0.5) / mult +end + --list: table, list of string(word) --vocab: nerv.LMVocab --ty: nerv.CuMatrix @@ -42,15 +48,15 @@ end --Returns: nerv.MMatrixInt --Set the matrix to be ids of the words, id starting at 1, not 0 function Util.set_id(m, list, vocab) - if (m:nrow() ~= #list or m:ncol() ~= 1) then + if (m:ncol() ~= #list or m:nrow() ~= 1) then nerv.error("nrow of matrix mismatch with list or its col not one") end for i = 1, #list, 1 do --index in matrix starts at 0 if (list[i] ~= vocab.null_token) then - m[i - 1][0] = vocab:get_word_str(list[i]).id + m[0][i - 1] = vocab:get_word_str(list[i]).id else - m[i - 1][0] = 0 + m[0][i - 1] = 0 end end return m diff --git a/nerv/examples/lmptb/main.lua b/nerv/examples/lmptb/main.lua index 9b39e83..13d610e 100644 --- a/nerv/examples/lmptb/main.lua +++ b/nerv/examples/lmptb/main.lua @@ -15,7 +15,7 @@ function prepare_parameters(global_conf, first_time) if (first_time) then ltp_ih = nerv.LinearTransParam("ltp_ih", global_conf) - ltp_ih.trans = global_conf.cumat_type(global_conf.vocab:size(), global_conf.hidden_size) + ltp_ih.trans = global_conf.cumat_type(global_conf.vocab:size() + 1, global_conf.hidden_size) --index 0 is for zero, others correspond to vocab index(starting from 1) ltp_ih.trans:generate(global_conf.param_random) ltp_hh = nerv.LinearTransParam("ltp_hh", global_conf) @@ -164,7 +164,7 @@ function propagateFile(global_conf, dagL, fn, config) local dagL_input = {} for i = 1, global_conf.bptt + 1 do - dagL_input[i] = nerv.MMatrixInt(global_conf.batch_size, 1) + dagL_input[i] = global_conf.cumat_type(1, global_conf.batch_size) --changed to row vector, debughtx end dagL_input[global_conf.bptt + 2] = global_conf.cumat_type(global_conf.batch_size, global_conf.hidden_size) dagL_input[global_conf.bptt + 3] = global_conf.cumat_type(global_conf.batch_size, global_conf.vocab:size()) @@ -192,7 +192,7 @@ function propagateFile(global_conf, dagL, fn, config) global_conf.input_word_id["recurrentL"..i] = dagL_input[i] --for IndRecurrent end dagL_input[global_conf.bptt + 2]:copy_fromd(hidden_store[tnow - global_conf.bptt - 1]) - nerv.LMUtil.set_onehot(dagL_input[global_conf.bptt + 3], token_store[tnow + 1], global_conf.vocab) + nerv.LMUtil.set_onehot(dagL_input[global_conf.bptt + 3], token_store[tnow + 1], global_conf.vocab) --for softmax --local dagL_input = create_dag_input(global_conf, token_store, hidden_store, tnow) global_conf.timer:tic("dagL-propagate") @@ -224,6 +224,7 @@ function propagateFile(global_conf, dagL, fn, config) for key, value in pairs(global_conf.timer.rec) do printf("\t [global_conf.timer]: time spent on %s:%.5fs\n", key, value) end + --comment this for debughtx global_conf.timer:flush() --nerv.CuMatrix.print_profile() --nerv.CuMatrix.clear_profile() @@ -277,10 +278,10 @@ if (set == "ptb") then valid_fn = valid_fn, test_fn = test_fn, sche_log_pre = "[SCHEDULER]:", - log_w_num = 10000, --give a message when log_w_num words have been processed + log_w_num = 500000, --give a message when log_w_num words have been processed timer = nerv.Timer() } - global_conf.work_dir = work_dir_base.."/h"..global_conf.hidden_size.."bp"..global_conf.bptt.."slr"..global_conf.lrate..os.date("_%bD%dH%H") + global_conf.work_dir = work_dir_base.."/h"..global_conf.hidden_size.."bp"..global_conf.bptt.."slr"..global_conf.lrate --..os.date("_%bD%dH%H") --comment this for testing global_conf.param_fn = global_conf.work_dir.."/params" elseif (set == "test") then train_fn = "/slfs1/users/txh18/workspace/nerv-project/some-text" |