-- nerv/main.lua — example RNN training script for the NERV toolkit.
-- (cgit page header and line-number gutter removed; they were web-scrape
-- residue, not part of the source file.)
-- Shared configuration consumed by every layer, the parameter repo, and
-- the network below.
local global_conf = {
    cumat_type   = nerv.CuMatrixFloat,      -- CUDA float matrix backend
    param_random = function() return 0 end, -- deterministic init: all params start at 0
    lrate        = 0.1,                     -- learning rate
    wcost        = 0,                       -- presumably weight decay — confirm against nerv docs
    momentum     = 0.9,
    batch_size   = 2,
}

-- Declare all named layers, keyed by layer class.  Dimensions chain:
-- input 62 -> affine 'input' -> 23 -> sigmoid -> 23 -> rnn -> 26 ->
-- affine 'output' (which also takes a 79-dim recurrent input) -> 79.
-- Parameters start from a fresh (empty) ParamRepo, filled per
-- global_conf.param_random.
local layer_repo = nerv.LayerRepo(
    {
        ['nerv.RNNLayer'] = {
            rnn = {dim_in = {23}, dim_out = {26}},
        },
        ['nerv.AffineLayer'] = {
            input = {dim_in = {62}, dim_out = {23}},
            -- 'output' has two inputs: the RNN output (26) and its own
            -- previous output (79), fed back via 'dup' in connections.
            output = {dim_in = {26, 79}, dim_out = {79}},
        },
        ['nerv.SigmoidLayer'] = {
            sigmoid = {dim_in = {23}, dim_out = {23}},
        },
        ['nerv.IdentityLayer'] = {
            -- NOTE(review): named 'softmax' but declared as an identity
            -- pass-through — presumably a placeholder; confirm intent.
            softmax = {dim_in = {79}, dim_out = {79}},
        },
        ['nerv.DuplicateLayer'] = {
            -- Fans the 79-dim output out into two copies (one for the
            -- recurrent edge, one for the graph output).
            dup = {dim_in = {79}, dim_out = {79, 79}},
        },
    }, nerv.ParamRepo(), global_conf)

-- Graph edges as {source_port, dest_port, offset}.  NOTE(review): the
-- third field appears to be a time offset — 0 is a same-step edge and
-- -1 feeds the previous time step's value (the recurrent feedback from
-- 'dup' back into 'output'); confirm against nerv.GraphLayer docs.
local connections = {
    {'<input>[1]', 'input[1]', 0},
    {'input[1]', 'sigmoid[1]', 0},
    {'sigmoid[1]', 'rnn[1]', 0},
    {'rnn[1]', 'output[1]', 0},
    {'output[1]', 'dup[1]', 0},
    {'dup[1]', 'output[2]', -1},
    {'dup[2]', 'softmax[1]', 0},
    {'softmax[1]', '<output>[1]', 0},
}

-- Assemble the layers and connections into a single graph layer, then
-- wrap it in a trainable network.
local graph = nerv.GraphLayer('graph', global_conf, {
    dim_in = {62},
    dim_out = {79},
    layer_repo = layer_repo,
    connections = connections,
})

local network = nerv.Network('network', global_conf, {network = graph})

local batch = global_conf.batch_size
local chunk = 5 -- number of time steps unrolled per mini-batch
network:init(batch, chunk)

-- Pre-allocate one set of forward/backward buffers per time step:
-- inputs/err_output are 62-dim, outputs/err_input are 79-dim, matching
-- the graph layer's dim_in/dim_out.
local input_size, output_size = 62, 79
local input, output = {}, {}
local err_input, err_output = {}, {}
for t = 1, chunk do
    input[t]      = {global_conf.cumat_type(batch, input_size)}
    output[t]     = {global_conf.cumat_type(batch, output_size)}
    err_input[t]  = {global_conf.cumat_type(batch, output_size)}
    err_output[t] = {global_conf.cumat_type(batch, input_size)}
end

-- Run 100 forward/backward/update iterations over the (uninitialized)
-- buffers — this example exercises the plumbing, not real training data.
for i = 1, 100 do
    -- NOTE(review): seq_length/new_seq semantics come from nerv.Network;
    -- presumably {5, 3} are the per-stream sequence lengths within the
    -- 5-step chunk and new_seq = {2} marks stream 2 as starting a fresh
    -- sequence — confirm against nerv's Network documentation.
    network:mini_batch_init({seq_length = {5, 3}, new_seq = {2}})
    network:propagate(input, output)
    network:back_propagate(err_input, err_output, input, output)
    network:update(err_input, input, output)
end