aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/m-tests/tnn_test.lua20
1 files changed, 15 insertions, 5 deletions
diff --git a/nerv/examples/lmptb/m-tests/tnn_test.lua b/nerv/examples/lmptb/m-tests/tnn_test.lua
index c033696..ddea54c 100644
--- a/nerv/examples/lmptb/m-tests/tnn_test.lua
+++ b/nerv/examples/lmptb/m-tests/tnn_test.lua
@@ -191,6 +191,7 @@ function lm_process_file(global_conf, fn, tnn, do_train)
end
end
if (result["rnn"].cn_w > next_log_wcn) then
+ next_log_wcn = next_log_wcn + global_conf.log_w_num
printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
end
@@ -216,14 +217,14 @@ function lm_process_file(global_conf, fn, tnn, do_train)
end
local train_fn, valid_fn, test_fn, global_conf
-local set = "test"
+local set = arg[1] --"test"
if (set == "ptb") then
data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata'
-train_fn = data_dir .. '/ptb.train.txt.cntk'
-valid_fn = data_dir .. '/ptb.valid.txt.cntk'
-test_fn = data_dir .. '/ptb.test.txt.cntk'
+train_fn = data_dir .. '/ptb.train.txt.adds'
+valid_fn = data_dir .. '/ptb.valid.txt.adds'
+test_fn = data_dir .. '/ptb.test.txt.adds'
global_conf = {
lrate = 1, wcost = 1e-6, momentum = 0,
@@ -275,8 +276,14 @@ global_conf = {
end
global_conf.work_dir = '/home/slhome/txh18/workspace/nerv/play/dagL_test'
+global_conf.train_fn_shuf = global_conf.work_dir .. '/train_fn_shuf'
+global_conf.train_fn_shuf_bak = global_conf.train_fn_shuf .. '_bak'
global_conf.param_fn = global_conf.work_dir .. "/params"
+printf("%s creating work_dir...\n", global_conf.sche_log_pre)
+os.execute("mkdir -p "..global_conf.work_dir)
+os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf)
+
local vocab = nerv.LMVocab()
global_conf["vocab"] = vocab
global_conf.vocab:build_file(global_conf.train_fn, false)
@@ -296,7 +303,10 @@ for iter = 1, global_conf.max_iter, 1 do
tnn, paramRepo = load_net(global_conf)
printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate)
global_conf.sche_log_pre = "[SCHEDULER ITER"..iter.." LR"..global_conf.lrate.."]:"
- lm_process_file(global_conf, global_conf.train_fn, tnn, true) --true update!
+ lm_process_file(global_conf, global_conf.train_fn_shuf, tnn, true) --true update!
+ --shuffling training file
+ os.execute('cp ' .. global_conf.train_fn_shuf .. ' ' .. global_conf.train_fn_shuf_bak)
+ os.execute('cat ' .. global_conf.train_fn_shuf_bak .. ' | sort -R --random-source=/dev/zero > ' .. global_conf.train_fn_shuf)
printf("===VALIDATION %d===\n", iter)
result = lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update!
ppl_rec[iter] = result:ppl_net("rnn")