aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/bilstmlm_ptb_main.lua23
-rw-r--r--nerv/examples/lmptb/lm_trainer.lua2
2 files changed, 12 insertions, 13 deletions
diff --git a/nerv/examples/lmptb/bilstmlm_ptb_main.lua b/nerv/examples/lmptb/bilstmlm_ptb_main.lua
index 9100d0d..cf0009b 100644
--- a/nerv/examples/lmptb/bilstmlm_ptb_main.lua
+++ b/nerv/examples/lmptb/bilstmlm_ptb_main.lua
@@ -325,7 +325,7 @@ global_conf = {
hidden_size = 20,
layer_num = 1,
- chunk_size = 2,
+ chunk_size = 20,
batch_size = 10,
max_iter = 3,
param_random = function() return (math.random() / 5 - 0.1) end,
@@ -372,7 +372,6 @@ global_conf.dropout_list = nerv.SUtil.parse_schedule(global_conf.dropout_str)
global_conf.log_fn = global_conf.work_dir .. '/log_lstm_tnn_' .. commands_str ..os.date("_TT%m_%d_%X",os.time())
global_conf.log_fn, _ = string.gsub(global_conf.log_fn, ':', '-')
commands = nerv.SUtil.parse_commands_set(commands_str)
-
nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir)
nerv.LMUtil.wait(2)
os.execute("mkdir -p "..global_conf.work_dir)
@@ -419,10 +418,10 @@ if commands["train"] == 1 then
global_conf.paramRepo = tnn:get_params() --get auto-generted params
global_conf.paramRepo:export(global_conf.param_fn .. '.0', nil) --some parameters are auto-generated, saved again to param.0 file
global_conf.dropout_rate = 0
- local result = LMTrainer.lm_process_file_rnn(global_conf, global_conf.valid_fn, tnn, false) --false update!
+ local result = LMTrainer.lm_process_file_birnn(global_conf, global_conf.valid_fn, tnn, false) --false update!
nerv.LMUtil.wait(1)
ppl_rec[0] = {}
- ppl_rec[0].valid = result:ppl_all("rnn")
+ ppl_rec[0].valid = result:ppl_all("birnn")
ppl_last = ppl_rec[0].valid
ppl_rec[0].train = 0
ppl_rec[0].test = 0
@@ -439,20 +438,20 @@ if commands["train"] == 1 then
tnn = load_net(global_conf, iter - 1)
nerv.printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate)
global_conf.dropout_rate = nerv.SUtil.sche_get(global_conf.dropout_list, iter)
- result = LMTrainer.lm_process_file_rnn(global_conf, global_conf.train_fn_shuf, tnn, true) --true update!
+ result = LMTrainer.lm_process_file_birnn(global_conf, global_conf.train_fn_shuf, tnn, true) --true update!
global_conf.dropout_rate = 0
ppl_rec[iter] = {}
- ppl_rec[iter].train = result:ppl_all("rnn")
+ ppl_rec[iter].train = result:ppl_all("birnn")
--shuffling training file
nerv.printf("%s shuffling training file\n", global_conf.sche_log_pre)
os.execute('cp ' .. global_conf.train_fn_shuf .. ' ' .. global_conf.train_fn_shuf_bak)
os.execute('cat ' .. global_conf.train_fn_shuf_bak .. ' | sort -R --random-source=/dev/zero > ' .. global_conf.train_fn_shuf)
nerv.printf("===PEEK ON TEST %d===\n", iter)
- result = LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update!
- ppl_rec[iter].test = result:ppl_all("rnn")
+ result = LMTrainer.lm_process_file_birnn(global_conf, global_conf.test_fn, tnn, false) --false update!
+ ppl_rec[iter].test = result:ppl_all("birnn")
nerv.printf("===VALIDATION %d===\n", iter)
- result = LMTrainer.lm_process_file_rnn(global_conf, global_conf.valid_fn, tnn, false) --false update!
- ppl_rec[iter].valid = result:ppl_all("rnn")
+ result = LMTrainer.lm_process_file_birnn(global_conf, global_conf.valid_fn, tnn, false) --false update!
+ ppl_rec[iter].valid = result:ppl_all("birnn")
ppl_rec[iter].lr = global_conf.lrate
if ((ppl_last / ppl_rec[iter].valid < global_conf.lr_decay or lr_half == true) and iter > global_conf.decay_iter) then
global_conf.lrate = (global_conf.lrate * 0.6)
@@ -494,7 +493,7 @@ if commands["test"] == 1 then
end
tnn = load_net(global_conf, test_iter)
global_conf.dropout_rate = 0
- LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update!
+ LMTrainer.lm_process_file_birnn(global_conf, global_conf.test_fn, tnn, false) --false update!
end --if commands["test"]
if commands["testout"] == 1 then
@@ -510,7 +509,7 @@ if commands["testout"] == 1 then
end
tnn = load_net(global_conf, test_iter)
global_conf.dropout_rate = 0
- LMTrainer.lm_process_file_rnn(global_conf, q_fn, tnn, false) --false update!
+ LMTrainer.lm_process_file_birnn(global_conf, q_fn, tnn, false) --false update!
end --if commands["testout"]
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 2cdbd4f..0ccd847 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -168,7 +168,7 @@ function LMTrainer.lm_process_file_birnn(global_conf, fn, tnn, do_train, p_conf)
chunk_size = global_conf.chunk_size
end
- reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab)
+ reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, global_conf.vocab, r_conf)
reader:open_file(fn)
local result = nerv.LMResult(global_conf, global_conf.vocab)