aboutsummaryrefslogtreecommitdiff
path: root/lua/network.lua
blob: d106ba15c4416f9fe3ac315301a23ab9a08b673f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
-- Load select_linear.lua through nerv.include before the class below is set up.
nerv.include('select_linear.lua')

-- Class table created by nerv.class with the name 'nerv.NN'; the methods
-- defined below are attached to it.
local nn = nerv.class('nerv.NN')

--- Set up the object: store the configuration, build the network, and
-- convert both data sets into the per-mini-batch tables used by nn:process.
-- @param global_conf  configuration table; kept as self.gconf and read by
--                     get_network, get_data and process
-- @param train_data   training mini-batches, converted with self:get_data
-- @param val_data     validation mini-batches, converted with self:get_data
-- @param layers       layer description forwarded to self:get_network
-- @param connections  connection description forwarded to self:get_network
function nn:__init(global_conf, train_data, val_data, layers, connections)
    -- The configuration has to be stored first: both helpers below read
    -- it through self.gconf.
    self.gconf = global_conf

    local built_network = self:get_network(layers, connections)
    self.network = built_network

    local train_batches = self:get_data(train_data)
    self.train_data = train_batches

    local val_batches = self:get_data(val_data)
    self.val_data = val_batches
end

--- Assemble and initialise the nerv.Network used for training/evaluation.
-- A LayerRepo is built from `layers`, wrapped in a GraphLayer named 'graph'
-- wired by `connections`, and that graph is wrapped in a Network named
-- 'network' which is initialised with the configured batch and chunk sizes.
-- @param layers       layer description passed to nerv.LayerRepo
-- @param connections  connection description passed to nerv.GraphLayer
-- @return the initialised nerv.Network
function nn:get_network(layers, connections)
    local conf = self.gconf

    local repo = nerv.LayerRepo(layers, conf.pr, conf)

    -- Two graph inputs (widths 1 and vocab_size) and one output of width 1.
    local graph_options = {
        dim_in = {1, conf.vocab_size},
        dim_out = {1},
        layer_repo = repo,
        connections = connections,
    }
    local graph = nerv.GraphLayer('graph', conf, graph_options)

    local net_options = {network = graph, clip = conf.clip}
    local net = nerv.Network('network', conf, net_options)
    net:init(conf.batch_size, conf.chunk_size)
    return net
end

--- Convert raw mini-batches into the tables consumed by nn:process.
-- For every entry of `data` a table is produced holding, per time step t
-- (1..chunk_size):
--   input[t]      = { data[i].input[t], data[i].output[t] }
--   output[t]     = { shared output matrix for step t }
--   err_input[t]  = { batch_size x 1 mask: 1 where t <= seq_len[j], else 0 }
--   err_output[t] = { shared error matrix, shared softmax-sized matrix }
-- plus `seq_length` (the entry's seq_len) and `new_seq` (the indices j for
-- which seq_start[j] is truthy).
-- The output/err_output matrices are allocated once per time step and
-- shared by all mini-batches; only the mask is allocated per mini-batch.
-- @param data  array of tables with fields input, output, seq_len, seq_start
-- @return array of prepared mini-batch tables, same length as `data`
function nn:get_data(data)
    local conf = self.gconf
    local batch_size, chunk_size = conf.batch_size, conf.chunk_size
    local cumat, mmat = conf.cumat_type, conf.mmat_type

    -- One set of scratch matrices per time step, reused by every batch.
    local shared_err, shared_softmax, shared_out = {}, {}, {}
    for t = 1, chunk_size do
        shared_err[t] = cumat(batch_size, 1)
        shared_softmax[t] = cumat(batch_size, conf.vocab_size)
        shared_out[t] = cumat(batch_size, 1)
    end

    local batches = {}
    for b = 1, #data do
        local src = data[b]
        local batch = {input = {}, output = {}, err_input = {}, err_output = {}}
        batches[b] = batch

        for t = 1, chunk_size do
            batch.input[t] = {src.input[t], src.output[t]}
            batch.output[t] = {shared_out[t]}

            -- Host-side mask; rows are addressed from 0, hence j - 1.
            local mask = mmat(batch_size, 1)
            for j = 1, batch_size do
                mask[j - 1][0] = (t <= src.seq_len[j]) and 1 or 0
            end
            batch.err_input[t] = {cumat.new_from_host(mask)}
            batch.err_output[t] = {shared_err[t], shared_softmax[t]}
        end

        batch.seq_length = src.seq_len

        -- Collect the batch slots flagged in seq_start.
        local starts = {}
        for j = 1, batch_size do
            if src.seq_start[j] then
                starts[#starts + 1] = j
            end
        end
        batch.new_seq = starts
    end
    return batches
end

--- Run one pass over prepared mini-batches and return an aggregate score.
-- For each mini-batch the network is initialised with it and propagated;
-- when `do_train` is truthy it is also back-propagated and updated. The
-- first output matrix of every time step is copied to the host, and each
-- entry whose time step lies within the sequence length (t <= seq_length[i])
-- is converted from a natural-log value to base 10 and accumulated.
-- The result is 10 ^ (-sum / count), presumably a perplexity (confirm
-- against the output layer's definition).
--
-- The conversion uses x / ln(10) instead of log10(exp(x)): the round trip
-- through exp() underflows to 0 (giving -inf) for very negative x and
-- overflows for large x, and costs two transcendental calls per frame.
-- `^` is used instead of math.pow, and math.log10 is no longer needed, so
-- the function does not depend on library functions that newer Lua
-- versions deprecate or drop.
--
-- @param data      array of mini-batch tables as produced by nn:get_data;
--                  each gets its `do_train` field overwritten
-- @param do_train  truthy to back-propagate and update after each batch
-- @return 10 ^ (-total / frames); NaN when no frame was counted (0/0)
function nn:process(data, do_train)
    local gconf = self.gconf
    local timer = gconf.timer
    local ln10 = math.log(10)
    local total_err = 0
    local total_frame = 0
    for id = 1, #data do
        local batch = data[id]
        -- The flag is stored on the batch before it is handed to
        -- mini_batch_init.
        batch.do_train = do_train
        timer:tic('network')
        self.network:mini_batch_init(batch)
        self.network:propagate()
        timer:toc('network')
        for t = 1, gconf.chunk_size do
            local tmp = batch.output[t][1]:new_to_host()
            for i = 1, gconf.batch_size do
                -- Only steps inside the sequence are counted; host
                -- matrix rows are addressed from 0, hence i - 1.
                if t <= batch.seq_length[i] then
                    total_err = total_err + tmp[i - 1][0] / ln10
                    total_frame = total_frame + 1
                end
            end
        end
        if do_train then
            timer:tic('network')
            self.network:back_propagate()
            self.network:update()
            timer:toc('network')
        end
        -- Full collection after every mini-batch, timed separately.
        timer:tic('gc')
        collectgarbage('collect')
        timer:toc('gc')
    end
    return 10 ^ (-total_err / total_frame)
end

--- Run one epoch: a training pass followed by a validation pass.
-- The training pass runs first with do_train = true; the validation pass
-- then runs with do_train = false on whatever state the network holds.
-- @return the value of nn:process for the training data, then for the
--         validation data
function nn:epoch()
    local train_result = self:process(self.train_data, true)
    local val_result = self:process(self.val_data, false)
    return train_result, val_result
end