aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/network_debug/main.lua
blob: bbcdb6c1253fa6c79d4713def64dcc7e0274b0b4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
-- Training driver for the network_debug example.
--
-- Loads the helper scripts plus the extra script named by arg[1], builds a
-- nerv.NN from get_layers/get_connections, then runs up to
-- global_conf.max_iter epochs over the PTB data under data_path. After each
-- epoch it prints the learning rate and train/validation perplexity, and
-- halves global_conf.lrate whenever validation perplexity does not improve
-- on the best value seen so far. Timing summaries and the cumat profile are
-- printed at the end.
--
-- NOTE(review): get_global_conf, get_layers and get_connections are expected
-- to be defined by one of the included scripts (config.lua and/or arg[1]);
-- confirm which one provides each.
nerv.include('reader.lua')
nerv.include('timer.lua')
nerv.include('config.lua')

-- Fail early with a clear message rather than handing nil to nerv.include.
local script_path = arg and arg[1]
if script_path == nil then
    error('main.lua: expected a Lua script path as the first argument (arg[1])')
end
nerv.include(script_path)

local global_conf = get_global_conf()
local timer = global_conf.timer

local data_path = 'examples/lmptb/PTBdata/'

local layers = get_layers(global_conf)
local connections = get_connections(global_conf)

local NN = nerv.NN(global_conf, layers, connections)

timer:tic('global')
-- Best validation perplexity seen so far. math.huge guarantees the first
-- epoch is always recorded as an improvement, whatever its perplexity.
local best_cv = math.huge
for i = 1, global_conf.max_iter do
    local epoch_tag = 'Epoch' .. i
    timer:tic(epoch_tag)
    -- Fresh readers are constructed every epoch.
    -- NOTE(review): presumably because a Reader is consumed by one pass; confirm.
    local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.train.txt.adds')
    local val_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds')
    local train_ppl, val_ppl = NN:epoch(train_reader, val_reader)
    -- Columns: epoch index, lrate (printed before any halving below),
    -- training perplexity, validation perplexity.
    nerv.printf('Epoch %d: %f %f %f\n', i, global_conf.lrate, train_ppl, val_ppl)
    if val_ppl < best_cv then
        best_cv = val_ppl
    else
        -- No improvement on validation data: halve the learning rate.
        global_conf.lrate = global_conf.lrate / 2.0
    end
    timer:toc(epoch_tag)
    timer:check(epoch_tag)
    -- Flush so per-epoch progress is visible immediately when output is piped.
    io.flush()
end
timer:toc('global')
timer:check('global')
-- Report the remaining timer categories.
-- NOTE(review): 'network', 'gc' and 'IO' are not tic/toc'd in this file;
-- presumably NN:epoch or the reader records them — confirm.
timer:check('network')
timer:check('gc')
timer:check('IO')
global_conf.cumat_type.print_profile()