aboutsummaryrefslogtreecommitdiff
path: root/lua/tnn.lua
blob: bf9f1189c389105245f205718be2cf42489db2bd (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
nerv.include('select_linear.lua')

local reader = nerv.class('nerv.TNNReader')

--- Initialize a TNN data reader over a preprocessed minibatch list.
-- @param global_conf the global configuration table
-- @param data minibatch list as produced by nerv.NN:get_data()
function reader:__init(global_conf, data)
    self.data = data
    self.gconf = global_conf
    -- index of the batch most recently served; 0 = nothing served yet
    self.offset = 0
end

--- Advance to the next minibatch and fill the TNN feed buffers.
-- Copies the input features and the one-hot-decompressed targets for every
-- time step of the chunk, and publishes the per-step frame flags.
-- @param feeds the feed table supplied by TNN:getfeed_from_reader
-- @return true when a batch was loaded, false when data is exhausted
function reader:get_batch(feeds)
    self.offset = self.offset + 1
    local batch = self.data[self.offset]
    if batch == nil then
        return false
    end
    for t = 1, self.gconf.chunk_size do
        feeds.inputs_m[t][1]:copy_from(batch.input[t])
        feeds.inputs_m[t][2]:copy_from(batch.output[t]:decompress(self.gconf.vocab_size))
    end
    feeds.flags_now = batch.flags
    feeds.flagsPack_now = batch.flagsPack
    return true
end

--- Whether batch slot i still carries a live frame at time step t
-- (i.e. t has not run past that sequence's length in the current batch).
function reader:has_data(t, i)
    local current = self.data[self.offset]
    return t <= current.seq_len[i]
end

--- Return the precomputed per-step error-input masks of the current batch.
function reader:get_err_input()
    local current = self.data[self.offset]
    return current.err_input
end

local nn = nerv.class('nerv.NN')

--- Build the trainer: construct the TNN and preprocess both corpora.
-- @param global_conf the global configuration table
-- @param train_data raw training batches
-- @param val_data raw validation batches
-- @param layers layer specification passed to nerv.LayerRepo
-- @param connections TNN connection specification
function nn:__init(global_conf, train_data, val_data, layers, connections)
    self.gconf = global_conf
    -- network first: get_tnn() also zeroes dropout for initialization
    self.tnn = self:get_tnn(layers, connections)
    -- then turn both corpora into TNN-ready minibatches
    self.train_data = self:get_data(train_data)
    self.val_data = self:get_data(val_data)
end

--- Construct and initialize the TNN from a layer/connection specification.
-- Input dims are {1, vocab_size} (word id + one-hot target), output dim {1}.
-- @return the initialized nerv.TNN instance
function nn:get_tnn(layers, connections)
    local gconf = self.gconf
    gconf.dropout_rate = 0 -- no dropout while building/initializing
    local repo = nerv.LayerRepo(layers, gconf.pr, gconf)
    local tnn = nerv.TNN('TNN', gconf, {
        dim_in = {1, gconf.vocab_size},
        dim_out = {1},
        sub_layers = repo,
        connections = connections,
        clip = gconf.clip,
    })
    tnn:init(gconf.batch_size, gconf.chunk_size)
    return tnn
end

--- Convert raw corpus batches into TNN-ready structures.
-- For every batch this builds, per time step: the per-slot frame flags
-- (SEQ_NORM for live frames, OR'ed with SEQ_START/SEQ_END at sequence
-- boundaries), the packed flag word, and a device-side error-input mask
-- (1 for live frames, 0 for padding).
-- @param data raw batch list with input/output/seq_len/seq_start/seq_end
-- @return the preprocessed batch list consumed by nerv.TNNReader
function nn:get_data(data)
    local gconf = self.gconf
    local bor = bit.bor
    local batches = {}
    for i = 1, #data do
        local src = data[i]
        local b = {
            input = src.input,
            output = src.output,
            seq_len = src.seq_len,
            flags = {},
            err_input = {},
        }
        for t = 1, gconf.chunk_size do
            b.flags[t] = {}
            -- host matrix is 0-indexed: row j-1 masks batch slot j
            local mask = gconf.mmat_type(gconf.batch_size, 1)
            for j = 1, gconf.batch_size do
                if t <= src.seq_len[j] then
                    b.flags[t][j] = nerv.TNN.FC.SEQ_NORM
                    mask[j - 1][0] = 1
                else
                    b.flags[t][j] = 0
                    mask[j - 1][0] = 0
                end
            end
            b.err_input[t] = gconf.cumat_type.new_from_host(mask)
        end
        -- mark sequence boundaries on the first / last live step
        for j = 1, gconf.batch_size do
            if src.seq_start[j] then
                b.flags[1][j] = bor(b.flags[1][j], nerv.TNN.FC.SEQ_START)
            end
            if src.seq_end[j] then
                local last = src.seq_len[j]
                b.flags[last][j] = bor(b.flags[last][j], nerv.TNN.FC.SEQ_END)
            end
        end
        -- OR all per-slot flags into one packed word per time step
        b.flagsPack = {}
        for t = 1, gconf.chunk_size do
            local packed = 0
            for j = 1, gconf.batch_size do
                packed = bor(packed, b.flags[t][j])
            end
            b.flagsPack[t] = packed
        end
        batches[i] = b
    end
    return batches
end

--- Run one full pass over a preprocessed data set.
-- Propagates every minibatch through the TNN, accumulates the log10
-- probability of each live frame, and — when do_train is true — feeds the
-- error masks back and backpropagates (gradient + update passes).
-- @param data minibatch list produced by nn:get_data()
-- @param do_train true to train (dropout on, backprop), false to evaluate
-- @return the per-frame perplexity over all live frames
function nn:process(data, do_train)
    local total_err = 0
    local total_frame = 0
    -- 1 / ln(10): converts a natural-log value to log10 directly.
    local log10_e = 1 / math.log(10)
    local reader = nerv.TNNReader(self.gconf, data)
    while true do
        local have_data, _ = self.tnn:getfeed_from_reader(reader)
        if not have_data then
            break
        end
        -- dropout only while training; evaluation must be deterministic
        if do_train then
            self.gconf.dropout_rate = self.gconf.dropout
        else
            self.gconf.dropout_rate = 0
        end
        self.tnn:net_propagate()
        for t = 1, self.gconf.chunk_size do
            local out = self.tnn.outputs_m[t][1]:new_to_host()
            for i = 1, self.gconf.batch_size do
                if reader:has_data(t, i) then
                    -- The output is a natural-log probability; the previous
                    -- math.log10(math.exp(x)) round trip overflows to inf
                    -- for x > ~709, whereas x * (1/ln 10) is safe and exact.
                    total_err = total_err + out[i - 1][0] * log10_e
                    total_frame = total_frame + 1
                end
            end
        end
        if do_train then
            local err_input = reader:get_err_input()
            for i = 1, self.gconf.chunk_size do
                self.tnn.err_inputs_m[i][1]:copy_from(err_input[i])
            end
            self.tnn:net_backpropagate(false) -- gradient pass
            self.tnn:net_backpropagate(true)  -- update pass
        end
        collectgarbage('collect')
    end
    -- PPL = 10 ^ (-mean log10 prob). The '^' operator exists on every Lua
    -- version, unlike math.pow (removed in Lua 5.4).
    return 10 ^ (- total_err / total_frame)
end

--- Run one training epoch followed by a validation pass.
-- @return training perplexity, validation perplexity
function nn:epoch()
    local train_ppl = self:process(self.train_data, true)
    local valid_ppl = self:process(self.val_data, false)
    return train_ppl, valid_ppl
end