-- nerv/examples/ptb/main.lua
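-- Train an LSTM language model on the Penn Treebank (PTB) corpus with the
-- NERV toolkit. This file is a trainer config: it installs hooks on
-- nerv.Trainer (layer construction, graph wiring, data readers, error
-- reporting) and is meant to be loaded by NERV's generic trainer front end
-- (the exact launch command depends on the NERV install).
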
nerv.include('reader.lua')
nerv.include('select_linear.lua')

gconf = {
    chunk_size = 15,    -- BPTT unroll length (time steps per chunk)
    dropout_rate = 0,   -- placeholder; set per dataset in epoch_preprocess
    lrate = 1.5,        -- learning rate
    wcost = 1e-5,       -- L2 weight cost
    max_iter = 35,      -- maximum number of training epochs
    clip = 5,           -- gradient clipping threshold
    momentum = 0.9,     -- momentum for the parameter updates
    batch_size = 200,   -- number of sequences processed in parallel
    test = true,        -- evaluate on the test set as well
}

local hidden_size = 300   -- embedding and LSTM hidden dimension
local vocab_size = 10000  -- PTB vocabulary size
local layer_num = 1       -- number of stacked LSTM layers
local dropout_rate = 0.5  -- dropout applied during training epochs
local trainer = nerv.Trainer -- the hooks below are installed on the Trainer class

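-- Declare all layers: a word-embedding lookup (SelectLinearLayer), layer_num
-- LSTM/dropout pairs, an affine projection back to the vocabulary, and a
-- softmax cross-entropy criterion (compressed = true: labels arrive as class
-- indices rather than one-hot rows).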
function trainer:make_layer_repo(param_repo)
    local layers = {
        ['nerv.LSTMLayer'] = {},
        ['nerv.DropoutLayer'] = {},
        ['nerv.SelectLinearLayer'] = {
            ['select'] = {dim_in = {1}, dim_out = {hidden_size}, vocab = vocab_size, pr = param_repo},
        },
        ['nerv.AffineLayer'] = {
            output = {dim_in = {hidden_size}, dim_out = {vocab_size}, pr = param_repo},
        },
        ['nerv.SoftmaxCELayer'] = {
            softmax = {dim_in = {vocab_size, 1}, dim_out = {1}, compressed = true},
        },
    }
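    -- one LSTM/dropout pair per stacked layer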
    for i = 1, layer_num do
        layers['nerv.LSTMLayer']['lstm' .. i] = {dim_in = {hidden_size}, dim_out = {hidden_size}, pr = param_repo}
        layers['nerv.DropoutLayer']['dropout' .. i] = {dim_in = {hidden_size}, dim_out = {hidden_size}}
    end
    return nerv.LayerRepo(layers, param_repo, gconf)
end

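-- Wire the layers into a single graph: word id -> embedding -> LSTM stack
-- (dropout after every LSTM) -> affine output -> softmax CE. The third
-- element of each connection is the time shift in frames (0 = same time
-- step); the recurrent connections live inside nerv.LSTMLayer itself.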
function trainer:get_network(layer_repo)
    local connections = {
        {'<input>[1]', 'select[1]', 0},
        {'select[1]', 'lstm1[1]', 0},
        {'dropout' .. layer_num .. '[1]', 'output[1]', 0},
        {'output[1]', 'softmax[1]', 0},
        {'<input>[2]', 'softmax[2]', 0},
        {'softmax[1]', '<output>[1]', 0},
    }
    for i = 1, layer_num do
        table.insert(connections, {'lstm' .. i .. '[1]', 'dropout' .. i .. '[1]', 0})
        if i > 1 then
            -- stacked layers read from the previous layer's dropout output
            table.insert(connections, {'dropout' .. (i - 1) .. '[1]', 'lstm' .. i .. '[1]', 0})
        end
    end
    return nerv.GraphLayer('graph', gconf, {dim_in = {1, 1}, dim_out = {1}, layer_repo = layer_repo, connections = connections})
end

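-- Order in which the reader's slots are bound to the graph's two inputs:
-- the word id first, the target label second.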
function trainer:get_input_order()
    return {'input', 'label'}
end

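-- One reader per dataset over the preprocessed PTB text (paths are relative
-- to the directory NERV is launched from).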
function trainer:get_readers(dataset)
    local data_path = 'nerv/nerv/examples/lmptb/PTBdata/'
    local vocab_file = data_path .. 'vocab'
    local train_file = data_path .. 'ptb.train.txt.adds'
    local cv_file = data_path .. 'ptb.valid.txt.adds'
    local test_file = data_path .. 'ptb.test.txt.adds'
    local reader
    if dataset == 'train' then
        reader = nerv.Reader(vocab_file, train_file)
    elseif dataset == 'validate' then
        reader = nerv.Reader(vocab_file, cv_file)
    elseif dataset == 'test' then
        reader = nerv.Reader(vocab_file, test_file)
    else
        nerv.error('no such dataset')
    end
    return {{reader = reader, data = {input = 1, label = 1}}}
end

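-- Perplexity accumulators, reset at the start of every epoch pass.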
local total_err
local total_frame

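-- total_err sums (negative) base-10 log-probabilities of the target words,
-- so 10^(-total_err / total_frame) is the per-word perplexity (lower is
-- better).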
function trainer:get_error()
    return math.pow(10, -total_err / total_frame)
end

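-- Runs before every pass over a dataset: enable dropout for training,
-- disable it for evaluation, and zero the perplexity accumulators.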
function trainer:epoch_preprocess(dataset, do_train)
    if dataset == 'train' then
        gconf.dropout_rate = dropout_rate
        nerv.info('set dropout rate to %f', dropout_rate)
    elseif dataset == 'validate' or dataset == 'test' then
        gconf.dropout_rate = 0
        nerv.info('set dropout rate to 0')
    end
    total_err = 0
    total_frame = 0
end

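-- Runs after every mini-batch. info.output[1][t] is assumed to hold the
-- per-frame natural-log probability of the target word produced by the
-- compressed SoftmaxCELayer; log10(exp(x)) rebases it to 10 for the
-- perplexity computation. Host-side matrices are 0-indexed, hence tmp[i - 1].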
function trainer:mini_batch_middleprocess(cnt, info)
    for t = 1, gconf.chunk_size do
        local tmp = info.output[1][t]:new_to_host()
        for i = 1, gconf.batch_size do
            total_err = total_err + math.log10(math.exp(tmp[i - 1][0]))
        end
    end
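    -- count only the frames each sequence actually occupies in this chunk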
    for i = 1, gconf.batch_size do
        total_frame = total_frame + info.seq_length[i]
    end
end