diff options
-rw-r--r-- | nerv/examples/lmptb/m-tests/tnn_test.lua | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/nerv/examples/lmptb/m-tests/tnn_test.lua b/nerv/examples/lmptb/m-tests/tnn_test.lua index c033696..ddea54c 100644 --- a/nerv/examples/lmptb/m-tests/tnn_test.lua +++ b/nerv/examples/lmptb/m-tests/tnn_test.lua @@ -191,6 +191,7 @@ function lm_process_file(global_conf, fn, tnn, do_train) end end if (result["rnn"].cn_w > next_log_wcn) then + next_log_wcn = next_log_wcn + global_conf.log_w_num printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date()) printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn")) end @@ -216,14 +217,14 @@ function lm_process_file(global_conf, fn, tnn, do_train) end local train_fn, valid_fn, test_fn, global_conf -local set = "test" +local set = arg[1] --"test" if (set == "ptb") then data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata' -train_fn = data_dir .. '/ptb.train.txt.cntk' -valid_fn = data_dir .. '/ptb.valid.txt.cntk' -test_fn = data_dir .. '/ptb.test.txt.cntk' +train_fn = data_dir .. '/ptb.train.txt.adds' +valid_fn = data_dir .. '/ptb.valid.txt.adds' +test_fn = data_dir .. '/ptb.test.txt.adds' global_conf = { lrate = 1, wcost = 1e-6, momentum = 0, @@ -275,8 +276,14 @@ global_conf = { end global_conf.work_dir = '/home/slhome/txh18/workspace/nerv/play/dagL_test' +global_conf.train_fn_shuf = global_conf.work_dir .. '/train_fn_shuf' +global_conf.train_fn_shuf_bak = global_conf.train_fn_shuf .. '_bak' global_conf.param_fn = global_conf.work_dir .. "/params" +printf("%s creating work_dir...\n", global_conf.sche_log_pre) +os.execute("mkdir -p "..global_conf.work_dir) +os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf) + local vocab = nerv.LMVocab() global_conf["vocab"] = vocab global_conf.vocab:build_file(global_conf.train_fn, false) @@ -296,7 +303,10 @@ for iter = 1, global_conf.max_iter, 1 do tnn, paramRepo = load_net(global_conf) printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate) global_conf.sche_log_pre = "[SCHEDULER ITER"..iter.." LR"..global_conf.lrate.."]:" - lm_process_file(global_conf, global_conf.train_fn, tnn, true) --true update! + lm_process_file(global_conf, global_conf.train_fn_shuf, tnn, true) --true update! + --shuffling training file + os.execute('cp ' .. global_conf.train_fn_shuf .. ' ' .. global_conf.train_fn_shuf_bak) + os.execute('cat ' .. global_conf.train_fn_shuf_bak .. ' | sort -R --random-source=/dev/zero > ' .. global_conf.train_fn_shuf) printf("===VALIDATION %d===\n", iter) result = lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update! ppl_rec[iter] = result:ppl_net("rnn") |