diff options
author | txh18 <[email protected]> | 2015-12-06 15:18:13 +0800 |
---|---|---|
committer | txh18 <[email protected]> | 2015-12-06 15:18:13 +0800 |
commit | 313011f24dfdacfe9c18d018d5bb877625a09ec7 (patch) | |
tree | 1f5ce2cfc4f4faf5d5bb12cc80bda8ee3125d0e9 | |
parent | 79c711d9c92a8e92f7ad9187a66d3e2aac239356 (diff) |
small bug fix on lm train script
-rw-r--r-- | nerv/examples/lmptb/bilstmlm_ptb_main.lua | 23 | ||||
-rw-r--r-- | nerv/examples/lmptb/lm_trainer.lua | 2 |
2 files changed, 12 insertions, 13 deletions
diff --git a/nerv/examples/lmptb/bilstmlm_ptb_main.lua b/nerv/examples/lmptb/bilstmlm_ptb_main.lua index 9100d0d..cf0009b 100644 --- a/nerv/examples/lmptb/bilstmlm_ptb_main.lua +++ b/nerv/examples/lmptb/bilstmlm_ptb_main.lua @@ -325,7 +325,7 @@ global_conf = { hidden_size = 20, layer_num = 1, - chunk_size = 2, + chunk_size = 20, batch_size = 10, max_iter = 3, param_random = function() return (math.random() / 5 - 0.1) end, @@ -372,7 +372,6 @@ global_conf.dropout_list = nerv.SUtil.parse_schedule(global_conf.dropout_str) global_conf.log_fn = global_conf.work_dir .. '/log_lstm_tnn_' .. commands_str ..os.date("_TT%m_%d_%X",os.time()) global_conf.log_fn, _ = string.gsub(global_conf.log_fn, ':', '-') commands = nerv.SUtil.parse_commands_set(commands_str) - nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir) nerv.LMUtil.wait(2) os.execute("mkdir -p "..global_conf.work_dir) @@ -419,10 +418,10 @@ if commands["train"] == 1 then global_conf.paramRepo = tnn:get_params() --get auto-generted params global_conf.paramRepo:export(global_conf.param_fn .. '.0', nil) --some parameters are auto-generated, saved again to param.0 file global_conf.dropout_rate = 0 - local result = LMTrainer.lm_process_file_rnn(global_conf, global_conf.valid_fn, tnn, false) --false update! + local result = LMTrainer.lm_process_file_birnn(global_conf, global_conf.valid_fn, tnn, false) --false update! nerv.LMUtil.wait(1) ppl_rec[0] = {} - ppl_rec[0].valid = result:ppl_all("rnn") + ppl_rec[0].valid = result:ppl_all("birnn") ppl_last = ppl_rec[0].valid ppl_rec[0].train = 0 ppl_rec[0].test = 0 @@ -439,20 +438,20 @@ if commands["train"] == 1 then tnn = load_net(global_conf, iter - 1) nerv.printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate) global_conf.dropout_rate = nerv.SUtil.sche_get(global_conf.dropout_list, iter) - result = LMTrainer.lm_process_file_rnn(global_conf, global_conf.train_fn_shuf, tnn, true) --true update! + result = LMTrainer.lm_process_file_birnn(global_conf, global_conf.train_fn_shuf, tnn, true) --true update! global_conf.dropout_rate = 0 ppl_rec[iter] = {} - ppl_rec[iter].train = result:ppl_all("rnn") + ppl_rec[iter].train = result:ppl_all("birnn") --shuffling training file nerv.printf("%s shuffling training file\n", global_conf.sche_log_pre) os.execute('cp ' .. global_conf.train_fn_shuf .. ' ' .. global_conf.train_fn_shuf_bak) os.execute('cat ' .. global_conf.train_fn_shuf_bak .. ' | sort -R --random-source=/dev/zero > ' .. global_conf.train_fn_shuf) nerv.printf("===PEEK ON TEST %d===\n", iter) - result = LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update! - ppl_rec[iter].test = result:ppl_all("rnn") + result = LMTrainer.lm_process_file_birnn(global_conf, global_conf.test_fn, tnn, false) --false update! + ppl_rec[iter].test = result:ppl_all("birnn") nerv.printf("===VALIDATION %d===\n", iter) - result = LMTrainer.lm_process_file_rnn(global_conf, global_conf.valid_fn, tnn, false) --false update! - ppl_rec[iter].valid = result:ppl_all("rnn") + result = LMTrainer.lm_process_file_birnn(global_conf, global_conf.valid_fn, tnn, false) --false update! + ppl_rec[iter].valid = result:ppl_all("birnn") ppl_rec[iter].lr = global_conf.lrate if ((ppl_last / ppl_rec[iter].valid < global_conf.lr_decay or lr_half == true) and iter > global_conf.decay_iter) then global_conf.lrate = (global_conf.lrate * 0.6) @@ -494,7 +493,7 @@ if commands["test"] == 1 then end tnn = load_net(global_conf, test_iter) global_conf.dropout_rate = 0 - LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update! + LMTrainer.lm_process_file_birnn(global_conf, global_conf.test_fn, tnn, false) --false update! end --if commands["test"] if commands["testout"] == 1 then @@ -510,7 +509,7 @@ if commands["testout"] == 1 then end tnn = load_net(global_conf, test_iter) global_conf.dropout_rate = 0 - LMTrainer.lm_process_file_rnn(global_conf, q_fn, tnn, false) --false update! + LMTrainer.lm_process_file_birnn(global_conf, q_fn, tnn, false) --false update! end --if commands["testout"] diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua index 2cdbd4f..0ccd847 100644 --- a/nerv/examples/lmptb/lm_trainer.lua +++ b/nerv/examples/lmptb/lm_trainer.lua @@ -168,7 +168,7 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf) chunk_size = global_conf.chunk_size end - reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab) + reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab, r_conf) reader:open_file(fn) local result = nerv.LMResult(global_conf, global_conf.vocab) |