summaryrefslogtreecommitdiff
path: root/lua/main.lua
diff options
context:
space:
mode:
authorQi Liu <[email protected]>2016-03-11 18:18:59 +0800
committerQi Liu <[email protected]>2016-03-11 18:18:59 +0800
commit2f46a5e2b37a054f482f76f4ac3d26b144cf988f (patch)
treee442e76741d664a29924c5f0ec6cc72e87345539 /lua/main.lua
parent13729e83219cd90e33f329c49a50f6f4a4420721 (diff)
add lua
Diffstat (limited to 'lua/main.lua')
-rw-r--r--lua/main.lua45
1 files changed, 45 insertions, 0 deletions
diff --git a/lua/main.lua b/lua/main.lua
new file mode 100644
index 0000000..39818aa
--- /dev/null
+++ b/lua/main.lua
@@ -0,0 +1,45 @@
+nerv.include('reader.lua')
+nerv.include('timer.lua')
+nerv.include('config.lua')
+nerv.include(arg[1])
+
+local global_conf = get_global_conf()
+local timer = global_conf.timer
+
+timer:tic('IO')
+
+local data_path = 'nerv/nerv/examples/lmptb/PTBdata/'
+local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.train.txt.adds')
+local val_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds')
+
+local train_data = train_reader:get_all_batch(global_conf)
+local val_data = val_reader:get_all_batch(global_conf)
+
+local layers = get_layers(global_conf)
+local connections = get_connections(global_conf)
+
+local NN = nerv.NN(global_conf, train_data, val_data, layers, connections)
+
+timer:toc('IO')
+timer:check('IO')
+io.flush()
+
+timer:tic('global')
+local best_cv = 1e10
+for i = 1, global_conf.max_iter do
+ timer:tic('Epoch' .. i)
+ local train_ppl, val_ppl = NN:epoch()
+ if val_ppl < best_cv then
+ best_cv = val_ppl
+ else
+ global_conf.lrate = global_conf.lrate / 2.0
+ end
+ nerv.printf('Epoch %d: %f %f %f\n', i, global_conf.lrate, train_ppl, val_ppl)
+ timer:toc('Epoch' .. i)
+ timer:check('Epoch' .. i)
+ io.flush()
+end
+timer:toc('global')
+timer:check('global')
+timer:check('network')
+timer:check('gc')