diff options
author | txh18 <[email protected]> | 2016-02-05 22:46:53 +0800 |
---|---|---|
committer | txh18 <[email protected]> | 2016-02-05 22:46:53 +0800 |
commit | 7a421571300417dba0d5c703d9a460ad19aeef14 (patch) | |
tree | 67062f09266ecef0580434435f4388749be3fc4c | |
parent | 3d7a2be2d8ac3083617df2b7194921971f0ac94e (diff) |
enhanced m-test/lm_sample
-rw-r--r-- | nerv/examples/lmptb/m-tests/lm_sampler_test.lua | 45 |
1 files changed, 30 insertions, 15 deletions
diff --git a/nerv/examples/lmptb/m-tests/lm_sampler_test.lua b/nerv/examples/lmptb/m-tests/lm_sampler_test.lua index 42a5787..0313d77 100644 --- a/nerv/examples/lmptb/m-tests/lm_sampler_test.lua +++ b/nerv/examples/lmptb/m-tests/lm_sampler_test.lua @@ -362,15 +362,12 @@ global_conf = { end -lr_half = false --can not be local, to be set by loadstring -start_iter = -1 -start_lr = nil -ppl_last = 100000 commands_str = "sampling" --"train:test" commands = {} -test_iter = -1 ---for testout(question) -q_file = "/home/slhome/txh18/workspace/ptb/questionGen/gen/ptb.test.txt.q10rs1_Msss.adds" +test_iter = -1 --obselete +random_seed = 1 +sample_num = 10 +out_fn = nil if arg[2] ~= nil then nerv.printf("%s applying arg[2](%s)...\n", global_conf.sche_log_pre, arg[2]) @@ -407,16 +404,16 @@ end nerv.LMUtil.wait(2) nerv.printf("%s printing training scheduling options...\n", global_conf.sche_log_pre) -nerv.printf("lr_half:\t%s\n", tostring(lr_half)) -nerv.printf("start_iter:\t%s\n", tostring(start_iter)) -nerv.printf("ppl_last:\t%s\n", tostring(ppl_last)) nerv.printf("commands_str:\t%s\n", commands_str) nerv.printf("test_iter:\t%s\n", tostring(test_iter)) +nerv.printf("random_seed:\t%s\n", tostring(random_seed)) +nerv.printf("sample_num:\t%s\n", tostring(sample_num)) +nerv.printf("out_fn:\t%s\n", tostring(out_fn)) nerv.printf("%s printing training scheduling end.\n", global_conf.sche_log_pre) nerv.LMUtil.wait(2) ------------------printing options end------------------------------ -math.randomseed(1) +math.randomseed(random_seed) local vocab = nerv.LMVocab() global_conf["vocab"] = vocab @@ -438,15 +435,33 @@ if commands["sampling"] == 1 then nerv.printf("===SAMPLE===\n") global_conf.sche_log_pre = "[SCHEDULER SAMPLING]:" local sampler = prepare_sampler(sm_conf) - for k = 1, 1 do - local res = sampler:lm_sample_rnn_dagL(10, {}) + local out_fh = nil + if out_fn ~= nil then + out_fh = assert(io.open(out_fn, "w")) + nerv.printf("%s outputing samples to file \"%s\"...\n", global_conf.sche_log_pre, out_fn) + end + for k = 1, sample_num do + local res = sampler:lm_sample_rnn_dagL(1, {}) for i = 1, #res do + if out_fh == nil then nerv.printf("lm_sampler_output_sample: ") end for j = 1, #res[i] do - nerv.printf("%s(%f) ", res[i][j].w, res[i][j].p) + if out_fh == nil then + nerv.printf("%s %f ", res[i][j].w, res[i][j].p) + else + out_fh:write(nerv.sprintf("%s %f ", res[i][j].w, res[i][j].p)) + end + end + if out_fh == nil then + nerv.printf("\n") + else + out_fh:write(nerv.sprintf("\n")) end - nerv.printf("\n") end + if k % 100 == 0 and out_fh ~= nil then nerv.printf("%s %d sample done\n", global_conf.sche_log_pre, k) end end + + if out_fh ~= nil then out_fh:close() end + nerv.printf("%s complete,bye\n", global_conf.sche_log_pre) --global_conf.dropout_rate = 0 --LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update! end --if commands["sampling"] |