-- path: root/nerv/main.lua
-- blob: 85e291c2aeaf7beb4d81fb028ee38145344293fe
-- Greeting emitted as soon as the script starts running.
print('Hello')

-- Parameter initializer that ignores any arguments and always yields 0.
-- NOTE(review): presumably consumed by layers when creating parameters;
-- confirm against the layer implementations.
local function zero_param()
    return 0
end

-- Configuration table shared by the layer repository and the graph layer
-- constructed further down in this script.
local global_conf = {
    cumat_type = nerv.CuMatrixFloat,
    param_random = zero_param,
}

-- Layer specifications, keyed first by layer class name and then by layer id.
-- Each layer lists its input and output dimensions; the 'output' affine layer
-- declares two inputs (26 and 79) and a single output of 79.
local layer_specs = {
    ['nerv.RNNLayer'] = {
        rnn = {
            dim_in = {23},
            dim_out = {26},
        },
    },
    ['nerv.AffineLayer'] = {
        input = {
            dim_in = {20},
            dim_out = {23},
        },
        output = {
            dim_in = {26, 79},
            dim_out = {79},
        },
    },
    ['nerv.SigmoidLayer'] = {
        sigmoid = {
            dim_in = {23},
            dim_out = {23},
        },
    },
}

-- Build the repository from the specs above, a fresh (empty-argument)
-- parameter repository, and the shared configuration.
local layer_repo = nerv.LayerRepo(layer_specs, nerv.ParamRepo(), global_conf)

-- Wiring of the graph. Every entry is {source port, destination port, number}.
-- The third field is 0 on all edges except the one that feeds output[1] back
-- into output[2], where it is 1.
-- NOTE(review): the third field is presumably a time delay in steps; confirm
-- against nerv.GraphLayer.
local connections = {
    {'<input>[1]', 'input[1]',    0},  -- graph input -> input affine
    {'input[1]',   'sigmoid[1]',  0},  -- input affine -> sigmoid
    {'sigmoid[1]', 'rnn[1]',      0},  -- sigmoid -> rnn
    {'rnn[1]',     'output[1]',   0},  -- rnn -> first input of output affine
    {'output[1]',  'output[2]',   1},  -- output affine -> its own second input
    {'output[1]',  '<output>[1]', 0},  -- output affine -> graph output
}

-- Options for the top-level graph: one input of dimension 20, one output of
-- dimension 79, plus the layer repository and edge list defined above.
local network_opts = {
    dim_in = {20},
    dim_out = {79},
    layer_repo = layer_repo,
    connections = connections,
}

-- Assemble the graph layer named 'network' using the shared configuration.
local network = nerv.GraphLayer('network', global_conf, network_opts)