-rw-r--r--  nerv/examples/lmptb/m-tests/tnn_test.lua  45
-rw-r--r--  nerv/examples/lmptb/rnn/tnn.lua           32
2 files changed, 57 insertions(+), 20 deletions(-)
diff --git a/nerv/examples/lmptb/m-tests/tnn_test.lua b/nerv/examples/lmptb/m-tests/tnn_test.lua
index 7b2de4a..c033696 100644
--- a/nerv/examples/lmptb/m-tests/tnn_test.lua
+++ b/nerv/examples/lmptb/m-tests/tnn_test.lua
@@ -158,7 +158,8 @@ function lm_process_file(global_conf, fn, tnn, do_train)
tnn.err_inputs_m[t][1]:fill(1)
end
- local batch_num = 1
+ local next_log_wcn = global_conf.log_w_num
+
while (1) do
local r, feeds
@@ -187,7 +188,11 @@ function lm_process_file(global_conf, fn, tnn, do_train)
if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
end
- end
+ end
+ end
+ if (result["rnn"].cn_w > next_log_wcn) then
+ printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
+ printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
end
--[[
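Note on the hunk above: the new `next_log_wcn` threshold replaces the old per-batch counter, but no line in this hunk advances `next_log_wcn`, so the threshold is presumably bumped elsewhere in the file. A minimal standalone sketch of the intended word-count-throttled logging pattern (the function name and the increment are assumptions, not part of this commit):

    local log_w_num = 10000            -- log every log_w_num processed words
    local next_log_wcn = log_w_num
    local words_seen = 0

    local function count_and_maybe_log(n_new_words)
        words_seen = words_seen + n_new_words
        if words_seen > next_log_wcn then
            next_log_wcn = next_log_wcn + log_w_num  -- advance the threshold
            io.write(string.format("%d words processed %s.\n", words_seen, os.date()))
        end
    end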
@@ -217,27 +222,26 @@ if (set == "ptb") then
data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata'
train_fn = data_dir .. '/ptb.train.txt.cntk'
-test_fn = data_dir .. '/ptb.test.txt.cntk'
valid_fn = data_dir .. '/ptb.valid.txt.cntk'
+test_fn = data_dir .. '/ptb.test.txt.cntk'
global_conf = {
lrate = 1, wcost = 1e-6, momentum = 0,
cumat_type = nerv.CuMatrixFloat,
- mmat_type = nerv.CuMatrixFloat,
+ mmat_type = nerv.MMatrixFloat,
nn_act_default = 0,
- hidden_size = 20,
- chunk_size = 5,
- batch_size = 3,
- max_iter = 18,
+ hidden_size = 200,
+ chunk_size = 15,
+ batch_size = 10,
+ max_iter = 25,
param_random = function() return (math.random() / 5 - 0.1) end,
- independent = true,
train_fn = train_fn,
valid_fn = valid_fn,
test_fn = test_fn,
sche_log_pre = "[SCHEDULER]:",
- log_w_num = 10, --give a message when log_w_num words have been processed
+ log_w_num = 10000, --give a message when log_w_num words have been processed
timer = nerv.Timer()
}
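The `mmat_type` change above is a real fix, not a tuning tweak: `mmat_type` is the host-memory matrix class used by readers and IO code, while `cumat_type` is the GPU-resident class, so pointing both at `nerv.CuMatrixFloat` was wrong. A hedged illustration of the distinction (constructor usage follows the usual nerv `type(nrow, ncol)` pattern; treat this as a sketch):

    local cumat_type = nerv.CuMatrixFloat  -- device-side buffers for compute
    local mmat_type  = nerv.MMatrixFloat   -- host-side buffers for readers/logging

    local dev_m  = cumat_type(10, 20)  -- lives on the GPU
    local host_m = mmat_type(10, 20)   -- lives in main memory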
@@ -256,7 +260,7 @@ global_conf = {
hidden_size = 20,
chunk_size = 5,
batch_size = 3,
- max_iter = 18,
+ max_iter = 3,
param_random = function() return (math.random() / 5 - 0.1) end,
independent = true,
@@ -264,7 +268,7 @@ global_conf = {
valid_fn = valid_fn,
test_fn = test_fn,
sche_log_pre = "[SCHEDULER]:",
- log_w_num = 10, --give a message when log_w_num words have been processed
+ log_w_num = 20, --give a message when log_w_num words have been processed
timer = nerv.Timer()
}
@@ -310,11 +314,12 @@ for iter = 1, global_conf.max_iter, 1 do
end
printf("\n")
nerv.LMUtil.wait(2)
-end
-a= " printf(\"===VALIDATION PPL record===\\n\") \
- for i = 0, #ppl_rec do printf(\"<ITER%d LR%.5f: %.3f> \", i, lr_rec[i], ppl_rec[i]) end \
- printf(\"\\n\") \
- printf(\"===FINAL TEST===\\n\") \
- global_conf.sche_log_pre = \"[SCHEDULER FINAL_TEST]:\" \
- dagL, _ = load_net(global_conf) \
- propagateFile(global_conf, dagL, global_conf.test_fn, {do_train = false, report_word = false})"
+end
+printf("===VALIDATION PPL record===\n")
+for i = 0, #ppl_rec do printf("<ITER%d LR%.5f: %.3f> ", i, lr_rec[i], ppl_rec[i]) end
+printf("\n")
+printf("===FINAL TEST===\n")
+global_conf.sche_log_pre = "[SCHEDULER FINAL_TEST]:"
+tnn, paramRepo = load_net(global_conf)
+lm_process_file(global_conf, global_conf.test_fn, tnn, false) --do_train == false: no parameter updates on the test set
+
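The replacement above turns a previously dead string literal (the removed `a= "..."` block) into code that actually runs the validation report and the final test pass. Note the report loop runs `for i = 0, #ppl_rec`: Lua's `#` operator counts entries from index 1, so a baseline stored at index 0 is still visited by this loop. A small self-contained sketch of that convention (values are placeholders, not from this commit):

    local ppl_rec, lr_rec = {}, {}
    ppl_rec[0], lr_rec[0] = 820.0, 1.0  -- hypothetical pre-training baseline
    ppl_rec[1], lr_rec[1] = 310.0, 1.0
    ppl_rec[2], lr_rec[2] = 250.0, 1.0
    for i = 0, #ppl_rec do              -- visits 0, 1, 2
        io.write(string.format("<ITER%d LR%.5f: %.3f> ", i, lr_rec[i], ppl_rec[i]))
    end
    io.write("\n")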
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index 019d24c..ae9ed7a 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -321,6 +321,22 @@ function TNN:net_propagate() --propagate according to feeds_now
end
end
end
+
+ local flag_out = true
+ for t = 1, self.chunk_size do --check whether every output has been computed
+ if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then
+ for i = 1, #self.dim_out do
+ local ref = self.outputs_p[i].ref
+ if (ref.outputs_b[t][1] ~= true) then
+ flag_out = false
+ break
+ end
+ end
+ end
+ end
+ if (flag_out == false) then
+ nerv.error("some thing wrong, some labeled output is not propagated")
+ end
end
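The check added above walks every chunk position whose flags include `HAS_LABEL` and verifies that each output layer's completion marker (`outputs_b[t][1]`) was set during propagation. A self-contained sketch of the bit-flag test it relies on, using LuaJIT's `bit` library (the flag values here are made up for illustration; the real ones live in `nerv.TNN.FC`):

    local bit = require("bit")

    local FC = { HAS_INPUT = 1, HAS_LABEL = 2 }  -- hypothetical bit values

    local flags = bit.bor(FC.HAS_INPUT, FC.HAS_LABEL)  -- frame with both set
    if bit.band(flags, FC.HAS_LABEL) > 0 then
        -- this position carries a label, so its outputs must be computed
        io.write("labeled frame: output completion will be checked\n")
    end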
--ref: the TNN_ref of a layer
@@ -421,6 +437,22 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now
end
end
end
+
+ local flag_out = true
+ for t = 1, self.chunk_size do --check whether every input's error has been back-propagated
+ if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
+ for i = 1, #self.dim_in do
+ local ref = self.inputs_p[i].ref
+ if (ref.err_outputs_b[t][1] ~= true) then
+ flag_out = false
+ break
+ end
+ end
+ end
+ end
+ if (flag_out == false) then
+ nerv.error("some thing wrong, some input is not back_propagated")
+ end
end
--ref: the TNN_ref of a layer