diff options
author | Qi Liu <[email protected]> | 2016-03-09 11:58:13 +0800 |
---|---|---|
committer | Qi Liu <[email protected]> | 2016-03-09 11:58:13 +0800 |
commit | 05fcde5bf0caa1ceb70fef02fc88eda6f00c5ed5 (patch) | |
tree | a3bfb245d3f106525ec2ff4f987848fcd3f56217 /lua/main.lua | |
parent | 4e56b863203ab6919192efe973ba9f8ee0d5ac65 (diff) |
add recipe
Diffstat (limited to 'lua/main.lua')
-rw-r--r-- | lua/main.lua | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/lua/main.lua b/lua/main.lua new file mode 100644 index 0000000..684efac --- /dev/null +++ b/lua/main.lua @@ -0,0 +1,43 @@ +nerv.include('reader.lua') +nerv.include('timer.lua') +nerv.include('config.lua') +nerv.include(arg[1]) + +local global_conf = get_global_conf() +local timer = global_conf.timer + +timer:tic('IO') + +local data_path = 'nerv/nerv/examples/lmptb/PTBdata/' +local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.train.txt.adds') +local val_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds') + +local train_data = train_reader:get_all_batch(global_conf) +local val_data = val_reader:get_all_batch(global_conf) + +local layers = get_layers(global_conf) +local connections = get_connections(global_conf) + +local NN = nerv.NN(global_conf, train_data, val_data, layers, connections) + +timer:toc('IO') +timer:check('IO') +io.flush() + +timer:tic('global') +local best_cv = 1e10 +for i = 1, global_conf.max_iter do + timer:tic('Epoch' .. i) + local train_ppl, val_ppl = NN:epoch() + if val_ppl < best_cv then + best_cv = val_ppl + else + global_conf.lrate = global_conf.lrate / 2.0 + end + nerv.printf('Epoch %d: %f %f %f\n', i, global_conf.lrate, train_ppl, val_ppl) + timer:toc('Epoch' .. i) + timer:check('Epoch' .. i) + io.flush() +end +timer:toc('global') +timer:check('global') |