-rw-r--r--  nerv/examples/lmptb/m-tests/tnn_test.lua | 45
-rw-r--r--  nerv/examples/lmptb/rnn/tnn.lua          | 32
2 files changed, 57 insertions, 20 deletions
diff --git a/nerv/examples/lmptb/m-tests/tnn_test.lua b/nerv/examples/lmptb/m-tests/tnn_test.lua
index 7b2de4a..c033696 100644
--- a/nerv/examples/lmptb/m-tests/tnn_test.lua
+++ b/nerv/examples/lmptb/m-tests/tnn_test.lua
@@ -158,7 +158,8 @@ function lm_process_file(global_conf, fn, tnn, do_train)
         tnn.err_inputs_m[t][1]:fill(1)
     end
 
-    local batch_num = 1
+    local next_log_wcn = global_conf.log_w_num
+
     while (1) do
         local r, feeds
 
@@ -187,7 +188,11 @@ function lm_process_file(global_conf, fn, tnn, do_train)
                 if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
                     result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
                 end
-            end
+            end
+        end
+        if (result["rnn"].cn_w > next_log_wcn) then
+            printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
+            printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
         end
 
         --[[
@@ -217,27 +222,26 @@ if (set == "ptb") then
 
 data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata'
 train_fn = data_dir .. '/ptb.train.txt.cntk'
-test_fn = data_dir .. '/ptb.test.txt.cntk'
 valid_fn = data_dir .. '/ptb.valid.txt.cntk'
+test_fn = data_dir .. '/ptb.test.txt.cntk'
 
 global_conf = {
     lrate = 1, wcost = 1e-6, momentum = 0,
     cumat_type = nerv.CuMatrixFloat,
-    mmat_type = nerv.CuMatrixFloat,
+    mmat_type = nerv.MMatrixFloat,
     nn_act_default = 0,
 
-    hidden_size = 20,
-    chunk_size = 5,
-    batch_size = 3,
-    max_iter = 18,
+    hidden_size = 200,
+    chunk_size = 15,
+    batch_size = 10,
+    max_iter = 25,
     param_random = function() return (math.random() / 5 - 0.1) end,
-    independent = true,
 
     train_fn = train_fn,
     valid_fn = valid_fn,
     test_fn = test_fn,
     sche_log_pre = "[SCHEDULER]:",
-    log_w_num = 10, --give a message when log_w_num words have been processed
+    log_w_num = 10000, --give a message when log_w_num words have been processed
     timer = nerv.Timer()
 }
 
@@ -256,7 +260,7 @@ global_conf = {
     hidden_size = 20,
     chunk_size = 5,
     batch_size = 3,
-    max_iter = 18,
+    max_iter = 3,
     param_random = function() return (math.random() / 5 - 0.1) end,
     independent = true,
 
@@ -264,7 +268,7 @@ global_conf = {
     valid_fn = valid_fn,
     test_fn = test_fn,
     sche_log_pre = "[SCHEDULER]:",
-    log_w_num = 10, --give a message when log_w_num words have been processed
+    log_w_num = 20, --give a message when log_w_num words have been processed
     timer = nerv.Timer()
 }
 
@@ -310,11 +314,12 @@ for iter = 1, global_conf.max_iter, 1 do
     end
     printf("\n")
     nerv.LMUtil.wait(2)
-end
-a= " printf(\"===VALIDATION PPL record===\\n\") \
-    for i = 0, #ppl_rec do printf(\"<ITER%d LR%.5f: %.3f> \", i, lr_rec[i], ppl_rec[i]) end \
-    printf(\"\\n\") \
-    printf(\"===FINAL TEST===\\n\") \
-    global_conf.sche_log_pre = \"[SCHEDULER FINAL_TEST]:\" \
-    dagL, _ = load_net(global_conf) \
-    propagateFile(global_conf, dagL, global_conf.test_fn, {do_train = false, report_word = false})"
+end
+printf("===VALIDATION PPL record===\n")
+for i = 0, #ppl_rec do printf("<ITER%d LR%.5f: %.3f> ", i, lr_rec[i], ppl_rec[i]) end
+printf("\n")
+printf("===FINAL TEST===\n")
+global_conf.sche_log_pre = "[SCHEDULER FINAL_TEST]:"
+tnn, paramRepo = load_net(global_conf)
+lm_process_file(global_conf, global_conf.test_fn, tnn, false) --false update!
+
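The tnn_test.lua hunks above replace the unused batch counter with a word-count threshold (next_log_wcn), so training progress is reported once roughly every log_w_num words. A minimal, self-contained Lua sketch of that logging pattern follows; the helper name and the line that advances next_log_wcn are illustrative assumptions, since the committed hunk only compares against the threshold:

-- Hypothetical sketch of threshold-based progress logging.
local log_w_num = 10000          -- report every ~10000 words, as in the new config
local next_log_wcn = log_w_num   -- first reporting threshold
local words_seen = 0

local function note_words(n)
    words_seen = words_seen + n
    if words_seen > next_log_wcn then
        -- ASSUMPTION: advance the threshold so the message fires once per
        -- window; the committed hunk does not show this line.
        next_log_wcn = next_log_wcn + log_w_num
        print(string.format("[SCHEDULER]: %d words processed %s.", words_seen, os.date()))
    end
end

note_words(4096); note_words(8192) -- second call crosses 10000 and logs once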
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index 019d24c..ae9ed7a 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -321,6 +321,22 @@ function TNN:net_propagate() --propagate according to feeds_now
             end
         end
     end
+
+    local flag_out = true
+    for t = 1, self.chunk_size do --check whether every output has been computed
+        if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then
+            for i = 1, #self.dim_out do
+                local ref = self.outputs_p[i].ref
+                if (ref.outputs_b[t][1] ~= true) then
+                    flag_out = false
+                    break
+                end
+            end
+        end
+    end
+    if (flag_out == false) then
+        nerv.error("some thing wrong, some labeled output is not propagated")
+    end
 end
 
 --ref: the TNN_ref of a layer
@@ -421,6 +437,22 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now
             end
         end
     end
+
+    local flag_out = true
+    for t = 1, self.chunk_size do --check whether every output has been computed
+        if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
+            for i = 1, #self.dim_in do
+                local ref = self.inputs_p[i].ref
+                if (ref.err_outputs_b[t][1] ~= true) then
+                    flag_out = false
+                    break
+                end
+            end
+        end
+    end
+    if (flag_out == false) then
+        nerv.error("some thing wrong, some input is not back_propagated")
+    end
 end
 
 --ref: the TNN_ref of a layer
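Both tnn.lua hunks add the same sanity check after the propagation loops: for every timestep whose packed frame flags mark it as carrying a label (forward pass) or an input (backward pass), verify that the corresponding output buffer was actually computed, and abort via nerv.error otherwise. A simplified sketch of the forward-direction check, with a made-up flag constant and a plain boolean matrix standing in for the TNN internals (and plain error() in place of nerv.error to stay standalone):

-- Hypothetical, self-contained version of the "every labeled output computed" check.
local bit = require("bit")    -- LuaJIT bitops, as used by tnn.lua

local FC = { HAS_LABEL = 2 }  -- assumed flag bit; the real value lives in nerv.TNN.FC

local function check_outputs(flags, computed, chunk_size, n_out)
    for t = 1, chunk_size do
        if bit.band(flags[t], FC.HAS_LABEL) > 0 then -- this step was fed a label
            for i = 1, n_out do
                if computed[t][i] ~= true then       -- its output was never produced
                    error(("output %d at step %d not propagated"):format(i, t))
                end
            end
        end
    end
end

check_outputs({2, 0, 2}, {{true}, {false}, {true}}, 3, 1) -- passes: only steps 1 and 3 carry labels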