diff options
author | txh18 <[email protected]> | 2015-12-04 17:53:31 +0800 |
---|---|---|
committer | txh18 <[email protected]> | 2015-12-04 17:53:31 +0800 |
commit | cce6efcdfbe50a59e260cb5d55ae2c77326dc67c (patch) | |
tree | 074e41d6090c71cb6aa88a0ab5919878b3393de2 | |
parent | 618450eb71817ded45c422f35d8fede2d52a66b2 (diff) |
added testout command for lstmlm
-rw-r--r-- | nerv/examples/lmptb/lm_trainer.lua | 15 | ||||
-rw-r--r-- | nerv/examples/lmptb/lstmlm_ptb_main.lua | 44 |
2 files changed, 48 insertions, 11 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua index 9ef4794..58d5bfc 100644 --- a/nerv/examples/lmptb/lm_trainer.lua +++ b/nerv/examples/lmptb/lm_trainer.lua @@ -17,8 +17,19 @@ function nerv.BiasParam:update_by_gradient(gradient) end --Returns: LMResult -function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train) - local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab) +function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf) + if p_conf == nil then + p_conf = {} + end + local reader + if p_conf.one_sen_report == true then --report log prob one by one sentence + if do_train == true then + nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange") + end + reader = nerv.LMSeqReader(global_conf, 1, global_conf.max_sen_len, global_conf.vocab) + else + reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab) + end reader:open_file(fn) local result = nerv.LMResult(global_conf, global_conf.vocab) result:init("rnn") diff --git a/nerv/examples/lmptb/lstmlm_ptb_main.lua b/nerv/examples/lmptb/lstmlm_ptb_main.lua index 681c308..9f02324 100644 --- a/nerv/examples/lmptb/lstmlm_ptb_main.lua +++ b/nerv/examples/lmptb/lstmlm_ptb_main.lua @@ -195,20 +195,23 @@ local set = arg[1] --"test" if (set == "ptb") then -data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata' +root_dir = '/home/slhome/txh18/workspace' +data_dir = root_dir .. '/ptb/DATA' train_fn = data_dir .. '/ptb.train.txt.adds' valid_fn = data_dir .. '/ptb.valid.txt.adds' test_fn = data_dir .. '/ptb.test.txt.adds' vocab_fn = data_dir .. '/vocab' +qdata_dir = root_dir .. '/ptb/questionGen/gen' + global_conf = { - lrate = 0.15, wcost = 1e-5, momentum = 0, clip_t = 2, + lrate = 0.15, wcost = 1e-5, momentum = 0, clip_t = 5, cumat_type = nerv.CuMatrixFloat, mmat_type = nerv.MMatrixFloat, nn_act_default = 0, - hidden_size = 650, - layer_num = 2, + hidden_size = 300, + layer_num = 1, chunk_size = 15, batch_size = 20, max_iter = 45, @@ -221,10 +224,11 @@ global_conf = { valid_fn = valid_fn, test_fn = test_fn, vocab_fn = vocab_fn, + max_sen_len = 90, sche_log_pre = "[SCHEDULER]:", log_w_num = 40000, --give a message when log_w_num words have been processed timer = nerv.Timer(), - work_dir_base = '/home/slhome/txh18/workspace/nerv/play/ptbEXP/tnn_lstm_test' + work_dir_base = '/home/slhome/txh18/workspace/ptb/EXP-nerv/lstmlm_v1.0' } elseif (set == "msr_sc") then @@ -303,6 +307,9 @@ local commands_str = "train:test" local commands = {} local test_iter = -1 +--for testout(question) +local q_file = "ptb.test.txt.q10rs1_Msss.adds" + if arg[2] ~= nil then nerv.printf("%s applying arg[2](%s)...\n", global_conf.sche_log_pre, arg[2]) loadstring(arg[2])() @@ -311,21 +318,22 @@ else nerv.printf("%s no user setting, all default...\n", global_conf.sche_log_pre) end -global_conf.work_dir = global_conf.work_dir_base .. 'h' .. global_conf.hidden_size .. 'l' .. global_conf.layer_num --.. 'ch' .. global_conf.chunk_size .. 'ba' .. global_conf.batch_size .. 'slr' .. global_conf.lrate .. 'wc' .. global_conf.wcost +global_conf.work_dir = global_conf.work_dir_base .. 'h' .. global_conf.hidden_size .. 'l' .. global_conf.layer_num .. 'ch' .. global_conf.chunk_size .. 'ba' .. global_conf.batch_size .. 'slr' .. global_conf.lrate .. 'wc' .. global_conf.wcost global_conf.train_fn_shuf = global_conf.work_dir .. '/train_fn_shuf' global_conf.train_fn_shuf_bak = global_conf.train_fn_shuf .. '_bak' global_conf.param_fn = global_conf.work_dir .. "/params" global_conf.dropout_list = nerv.SUtil.parse_schedule(global_conf.dropout_str) -global_conf.log_fn = global_conf.work_dir .. '/lstm_tnn_' .. commands_str .. '_log' +global_conf.log_fn = global_conf.work_dir .. '/log_lstm_tnn_' .. commands_str ..os.date("_TT%X_%m_%d",os.time()) commands = nerv.SUtil.parse_commands_set(commands_str) -nerv.printf("%s creating work_dir...\n", global_conf.sche_log_pre) -nerv.LMUtil.wait(1) +nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir) +nerv.LMUtil.wait(2) os.execute("mkdir -p "..global_conf.work_dir) os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf) --redirecting log outputs! nerv.SUtil.log_redirect(global_conf.log_fn) +nerv.LMUtil.wait(2) ----------------printing options--------------------------------- nerv.printf("%s printing global_conf...\n", global_conf.sche_log_pre) @@ -441,3 +449,21 @@ if commands["test"] == 1 then global_conf.dropout_rate = 0 LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update! end --if commands["test"] + +if commands["testout"] == 1 then + nerv.printf("===TEST OUT===\n") + nerv.printf("q_file:\t%s\n", q_file) + local q_fn = qdata_dir .. q_file + global_conf.sche_log_pre = "[SCHEDULER FINAL_TEST]:" + if final_iter ~= -1 and test_iter == -1 then + test_iter = final_iter + end + if test_iter == -1 then + test_iter = "final" + end + tnn = load_net(global_conf, test_iter) + global_conf.dropout_rate = 0 + LMTrainer.lm_process_file_rnn(global_conf, q_fn, tnn, false) --false update! +end --if commands["testout"] + + |