aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/network_debug/main.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/network_debug/main.lua')
-rw-r--r--nerv/examples/network_debug/main.lua21
1 files changed, 7 insertions, 14 deletions
diff --git a/nerv/examples/network_debug/main.lua b/nerv/examples/network_debug/main.lua
index 790c404..bbcdb6c 100644
--- a/nerv/examples/network_debug/main.lua
+++ b/nerv/examples/network_debug/main.lua
@@ -6,35 +6,26 @@ nerv.include(arg[1])
local global_conf = get_global_conf()
local timer = global_conf.timer
-timer:tic('IO')
-
local data_path = 'examples/lmptb/PTBdata/'
-local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.train.txt.adds')
-local val_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds')
-
-local train_data = train_reader:get_all_batch(global_conf)
-local val_data = val_reader:get_all_batch(global_conf)
local layers = get_layers(global_conf)
local connections = get_connections(global_conf)
-local NN = nerv.NN(global_conf, train_data, val_data, layers, connections)
-
-timer:toc('IO')
-timer:check('IO')
-io.flush()
+local NN = nerv.NN(global_conf, layers, connections)
timer:tic('global')
local best_cv = 1e10
for i = 1, global_conf.max_iter do
timer:tic('Epoch' .. i)
- local train_ppl, val_ppl = NN:epoch()
+ local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.train.txt.adds')
+ local val_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds')
+ local train_ppl, val_ppl = NN:epoch(train_reader, val_reader)
+ nerv.printf('Epoch %d: %f %f %f\n', i, global_conf.lrate, train_ppl, val_ppl)
if val_ppl < best_cv then
best_cv = val_ppl
else
global_conf.lrate = global_conf.lrate / 2.0
end
- nerv.printf('Epoch %d: %f %f %f\n', i, global_conf.lrate, train_ppl, val_ppl)
timer:toc('Epoch' .. i)
timer:check('Epoch' .. i)
io.flush()
@@ -43,3 +34,5 @@ timer:toc('global')
timer:check('global')
timer:check('network')
timer:check('gc')
+timer:check('IO')
+global_conf.cumat_type.print_profile()