From ee81548f79c496f4e2f3e12325150bb96ecb432e Mon Sep 17 00:00:00 2001
From: txh18 <cloudygooseg@gmail.com>
Date: Fri, 13 Nov 2015 15:17:07 +0800
Subject: saving param file for every iter

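Previously a single parameter file (global_conf.param_fn) was
overwritten whenever the validation PPL improved. This patch writes
one parameter file per iteration instead: prepare_parameters() now
takes an iteration number, generating random parameters into
param_fn .. '.0' when called with -1 and otherwise loading
param_fn .. '.<iter>' into a ParamRepo. Each iteration of the
scheduler loads the previous iteration's file and exports its own;
when the validation PPL does not improve, the last accepted file is
copied forward so every index has a file, and training stops once
the learning rate is already halving. The vocabulary is now built
from a dedicated vocab_fn instead of train_fn, and the learning-rate
record lr_rec is folded into ppl_rec.

A minimal sketch of the per-iteration checkpoint cycle, condensed
from the scheduler loop in this file (ppl_improved is a hypothetical
flag standing in for the actual PPL comparison):

    -- iteration `iter` consumes checkpoint iter-1, produces checkpoint iter
    local tnn, paramRepo = load_net(global_conf, iter - 1)
    if ppl_improved then -- hypothetical; the real test compares ppl_rec[iter].valid to ppl_last
        paramRepo:export(global_conf.param_fn .. '.' .. tostring(iter), nil)
    else
        -- reuse the last accepted checkpoint so checkpoint `iter` still exists
        os.execute('cp ' .. global_conf.param_fn .. '.' .. tostring(iter - 1)
                   .. ' ' .. global_conf.param_fn .. '.' .. tostring(iter))
    end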
---
 nerv/examples/lmptb/tnn_ptb_main.lua | 56 ++++++++++++++++++++++--------------
 1 file changed, 35 insertions(+), 21 deletions(-)

diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua
index 5cc92c4..a59a44b 100644
--- a/nerv/examples/lmptb/tnn_ptb_main.lua
+++ b/nerv/examples/lmptb/tnn_ptb_main.lua
@@ -14,10 +14,11 @@ local LMTrainer = nerv.LMTrainer
 --global_conf: table
 --first_time: bool
 --Returns: a ParamRepo
-function prepare_parameters(global_conf, first_time)
+function prepare_parameters(global_conf, iter)
     printf("%s preparing parameters...\n", global_conf.sche_log_pre) 
     
-    if (first_time) then
+    if (iter == -1) then --first time
+        printf("%s first time, generating parameters...\n", global_conf.sche_log_pre) 
         ltp_ih = nerv.LinearTransParam("ltp_ih", global_conf)  
         ltp_ih.trans = global_conf.cumat_type(global_conf.vocab:size(), global_conf.hidden_size) --index 0 is for zero, others correspond to vocab index(starting from 1)
         ltp_ih.trans:generate(global_conf.param_random)
@@ -38,17 +39,20 @@ function prepare_parameters(global_conf, first_time)
         bp_o.trans = global_conf.cumat_type(1, global_conf.vocab:size())
         bp_o.trans:generate(global_conf.param_random)
 
-        local f = nerv.ChunkFile(global_conf.param_fn, 'w')
+        local f = nerv.ChunkFile(global_conf.param_fn .. '.0', 'w')
         f:write_chunk(ltp_ih)
         f:write_chunk(ltp_hh)
         f:write_chunk(ltp_ho)
         f:write_chunk(bp_h)
         f:write_chunk(bp_o)
         f:close()
+
+        return nil
     end
     
+    printf("%s loading parameter from file %s...\n", global_conf.sche_log_pre, global_conf.param_fn .. '.' .. tostring(iter)) 
     local paramRepo = nerv.ParamRepo()
-    paramRepo:import({global_conf.param_fn}, nil, global_conf)
+    paramRepo:import({global_conf.param_fn .. '.' .. tostring(iter)}, nil, global_conf)
 
     printf("%s preparing parameters end.\n", global_conf.sche_log_pre)
 
@@ -139,8 +143,8 @@ function prepare_tnn(global_conf, layerRepo)
     return tnn
 end
 
-function load_net(global_conf)
-    local paramRepo = prepare_parameters(global_conf, false)
+function load_net(global_conf, next_iter)
+    local paramRepo = prepare_parameters(global_conf, next_iter)
     local layerRepo = prepare_layers(global_conf, paramRepo)
     local tnn = prepare_tnn(global_conf, layerRepo)
     return tnn, paramRepo
@@ -152,9 +156,10 @@ local set = arg[1] --"test"
 if (set == "ptb") then
 
 data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata'
 train_fn = data_dir .. '/ptb.train.txt.adds'
 valid_fn = data_dir .. '/ptb.valid.txt.adds'
 test_fn = data_dir .. '/ptb.test.txt.adds'
+vocab_fn = data_dir .. '/vocab'
 
 global_conf = {
     lrate = 1, wcost = 1e-6, momentum = 0,
@@ -171,6 +176,7 @@ global_conf = {
     train_fn = train_fn,
     valid_fn = valid_fn,
     test_fn = test_fn,
+    vocab_fn = vocab_fn,
     sche_log_pre = "[SCHEDULER]:",
     log_w_num = 40000, --give a message when log_w_num words have been processed
     timer = nerv.Timer()
@@ -181,6 +187,7 @@ else
 valid_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
 train_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
 test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
+vocab_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
 
 global_conf = {
     lrate = 1, wcost = 1e-6, momentum = 0,
@@ -197,6 +204,7 @@ global_conf = {
     train_fn = train_fn,
     valid_fn = valid_fn,
     test_fn = test_fn,
+    vocab_fn = vocab_fn, 
     sche_log_pre = "[SCHEDULER]:",
     log_w_num = 10, --give a message when log_w_num words have been processed
     timer = nerv.Timer()
@@ -221,28 +229,30 @@ os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf)
 
 local vocab = nerv.LMVocab()
 global_conf["vocab"] = vocab
-global_conf.vocab:build_file(global_conf.train_fn, false)
+printf("%s building vocab...\n", global_conf.sche_log_pre)
+global_conf.vocab:build_file(global_conf.vocab_fn, false)
 
-prepare_parameters(global_conf, true) --randomly generate parameters
+prepare_parameters(global_conf, -1) --randomly generate parameters
 
 print("===INITIAL VALIDATION===") 
-local tnn, paramRepo = load_net(global_conf)
+local tnn, paramRepo = load_net(global_conf, 0)
 local result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update!
 nerv.LMUtil.wait(3)
 ppl_rec = {} 
-lr_rec = {}
 ppl_rec[0] = {} 
 ppl_rec[0].valid = result:ppl_all("rnn")  
 ppl_last = ppl_rec[0].valid 
 ppl_rec[0].train = 0 
 ppl_rec[0].test = 0
-lr_rec[0] = 0 
+ppl_rec[0].lr = 0 
 print() 
 local lr_half = false 
+local final_iter
 for iter = 1, global_conf.max_iter, 1 do
-    tnn, paramRepo = load_net(global_conf) 
-    printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate) 
+    final_iter = iter
     global_conf.sche_log_pre = "[SCHEDULER ITER"..iter.." LR"..global_conf.lrate.."]:" 
+    tnn, paramRepo = load_net(global_conf, iter - 1) 
+    printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate) 
     result = LMTrainer.lm_process_file(global_conf, global_conf.train_fn_shuf, tnn, true) --true update!
     ppl_rec[iter] = {}
     ppl_rec[iter].train = result:ppl_all("rnn")
@@ -256,29 +266,33 @@ for iter = 1, global_conf.max_iter, 1 do
     printf("===VALIDATION %d===\n", iter) 
     result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update!
     ppl_rec[iter].valid = result:ppl_all("rnn") 
-    lr_rec[iter] = global_conf.lrate 
+    ppl_rec[iter].lr = global_conf.lrate 
     if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then 
         global_conf.lrate = (global_conf.lrate * 0.6)
         lr_half = true 
     end 
     if (ppl_rec[iter].valid < ppl_last) then 
-        printf("%s saving net to file %s...\n", global_conf.sche_log_pre, global_conf.param_fn) 
-        paramRepo:export(global_conf.param_fn, nil) 
+        printf("%s PPL improves, saving net to file %s.%d...\n", global_conf.sche_log_pre, global_conf.param_fn, iter) 
+        paramRepo:export(global_conf.param_fn .. '.' .. tostring(iter), nil) 
         ppl_last = ppl_rec[iter].valid
     else 
-        printf("%s PPL did not improve, rejected...\n", global_conf.sche_log_pre) 
-        if (lr_halg == true) then
+        printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre) 
+        os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter))
+        if (lr_half == true) then
             printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre)
+            break
         end
     end 
     printf("\n") 
     nerv.LMUtil.wait(2) 
 end
 printf("===VALIDATION PPL record===\n") 
-for i = 0, #ppl_rec do printf("<ITER%d LR%.5f train:%.3f valid:%.3f test:%.3f> \n", i, lr_rec[i], ppl_rec[i].train, ppl_rec[i].valid, ppl_rec[i].test) end 
+for i = 0, final_iter do 
+    printf("<ITER%d LR%.5f train:%.3f valid:%.3f test:%.3f> \n", i, ppl_rec[i].lr, ppl_rec[i].train, ppl_rec[i].valid, ppl_rec[i].test) 
+end 
 printf("\n") 
 printf("===FINAL TEST===\n") 
 global_conf.sche_log_pre = "[SCHEDULER FINAL_TEST]:" 
-tnn, paramRepo = load_net(global_conf) 
+tnn, paramRepo = load_net(global_conf, final_iter) 
 LMTrainer.lm_process_file(global_conf, global_conf.test_fn, tnn, false) --false update!
 
-- 
cgit v1.2.3-70-g09d2