aboutsummaryrefslogblamecommitdiff
path: root/lua/tnn.lua
blob: bf9f1189c389105245f205718be2cf42489db2bd (plain) (tree)







































































































































                                                                                                          
nerv.include('select_linear.lua')

-- Class table registered under the name 'nerv.TNNReader' via nerv.class;
-- the reader methods defined in this file are attached to it.
local reader = nerv.class('nerv.TNNReader')

--- Set up a reader over a list of prepared batches.
-- @param global_conf configuration table, kept as self.gconf
-- @param data array of batch records, consumed in order by get_batch
function reader:__init(global_conf, data)
    self.data = data
    self.gconf = global_conf
    -- Index of the batch most recently handed out; 0 means none yet.
    self.offset = 0
end

--- Advance to the next batch and copy it into the feed buffers.
-- For every step t in 1..gconf.chunk_size, feeds.inputs_m[t][1] receives the
-- batch input and feeds.inputs_m[t][2] receives the batch output after
-- decompress(gconf.vocab_size). The batch's flag tables are handed over by
-- reference through feeds.flags_now and feeds.flagsPack_now.
-- @param feeds feed table to fill in place
-- @return true when a batch was loaded, false once all batches are consumed
function reader:get_batch(feeds)
    local cursor = self.offset + 1
    -- The offset is advanced even when the data is exhausted.
    self.offset = cursor
    if cursor > #self.data then
        return false
    end

    local batch = self.data[cursor]
    local gconf = self.gconf
    for t = 1, gconf.chunk_size do
        local slot = feeds.inputs_m[t]
        slot[1]:copy_from(batch.input[t])
        slot[2]:copy_from(batch.output[t]:decompress(gconf.vocab_size))
    end

    feeds.flags_now = batch.flags
    feeds.flagsPack_now = batch.flagsPack
    return true
end

--- Tell whether stream i of the current batch still has data at step t.
-- @param t time step within the chunk
-- @param i index into the current batch's seq_len table
-- @return true when t does not exceed seq_len[i] of the current batch
function reader:has_data(t, i)
    local batch = self.data[self.offset]
    return batch.seq_len[i] >= t
end

--- Fetch the err_input field of the batch most recently loaded by get_batch.
-- @return the current batch's err_input value
function reader:get_err_input()
    local batch = self.data[self.offset]
    return batch.err_input
end

-- Class table registered under the name 'nerv.NN' via nerv.class; the
-- network-driver methods defined in this file are attached to it.
local nn = nerv.class('nerv.NN')

--- Build the network and pre-process both data sets.
-- @param global_conf configuration table, kept as self.gconf
-- @param train_data raw training batches, converted with get_data
-- @param val_data raw validation batches, converted with get_data
-- @param layers layer specification forwarded to get_tnn
-- @param connections connection specification forwarded to get_tnn
function nn:__init(global_conf, train_data, val_data, layers, connections)
    -- gconf must be in place first: get_tnn and get_data both read it.
    self.gconf = global_conf

    local network = self:get_tnn(layers, connections)
    self.tnn = network

    local prepared_train = self:get_data(train_data)
    self.train_data = prepared_train

    local prepared_val = self:get_data(val_data)
    self.val_data = prepared_val
end

--- Create and initialise the nerv.TNN instance used by this object.
-- @param layers layer specification handed to nerv.LayerRepo
-- @param connections connection list placed in the TNN configuration
-- @return the TNN after init(batch_size, chunk_size) has been called on it
function nn:get_tnn(layers, connections)
    local gconf = self.gconf
    -- dropout_rate is reset to 0 before any layer is constructed.
    gconf.dropout_rate = 0

    local repo = nerv.LayerRepo(layers, gconf.pr, gconf)
    local tnn_conf = {
        dim_in = {1, gconf.vocab_size},
        dim_out = {1},
        sub_layers = repo,
        connections = connections,
        clip = gconf.clip,
    }

    local net = nerv.TNN('TNN', gconf, tnn_conf)
    net:init(gconf.batch_size, gconf.chunk_size)
    return net
end

--- Convert raw batches into the records consumed by nerv.TNNReader.
-- Each returned record shares input, output and seq_len with the source
-- batch and adds:
--   flags[t][j]   SEQ_NORM when t <= seq_len[j], otherwise 0; SEQ_START is
--                 OR-ed into step 1 and SEQ_END into step seq_len[j] when
--                 the source batch marks them
--   flagsPack[t]  bitwise OR of flags[t][j] over all j
--   err_input[t]  batch_size x 1 matrix holding 1 where t <= seq_len[j] and
--                 0 elsewhere, built on the host and converted with
--                 cumat_type.new_from_host
-- @param data array of raw batches with input, output, seq_len, seq_start
--   and seq_end fields
-- @return array of prepared batch records, one per entry of data
function nn:get_data(data)
    local gconf = self.gconf
    local prepared = {}

    for i = 1, #data do
        local src = data[i]
        local flags = {}
        local err_inputs = {}

        -- Per-step activity flags and the matching 0/1 mask.
        for t = 1, gconf.chunk_size do
            local row = {}
            flags[t] = row
            local mask = gconf.mmat_type(gconf.batch_size, 1)
            for j = 1, gconf.batch_size do
                -- The host matrix is addressed with 0-based indices here.
                if t <= src.seq_len[j] then
                    row[j] = nerv.TNN.FC.SEQ_NORM
                    mask[j - 1][0] = 1
                else
                    row[j] = 0
                    mask[j - 1][0] = 0
                end
            end
            err_inputs[t] = gconf.cumat_type.new_from_host(mask)
        end

        -- Mark sequence boundaries on top of the activity flags.
        for j = 1, gconf.batch_size do
            if src.seq_start[j] then
                flags[1][j] = bit.bor(flags[1][j], nerv.TNN.FC.SEQ_START)
            end
            if src.seq_end[j] then
                local last_step = src.seq_len[j]
                flags[last_step][j] = bit.bor(flags[last_step][j], nerv.TNN.FC.SEQ_END)
            end
        end

        -- Collapse every step's flags into one word; this has to run after
        -- the boundary bits have been added.
        local packed = {}
        for t = 1, gconf.chunk_size do
            local acc = 0
            for j = 1, gconf.batch_size do
                acc = bit.bor(acc, flags[t][j])
            end
            packed[t] = acc
        end

        prepared[i] = {
            input = src.input,
            output = src.output,
            flags = flags,
            err_input = err_inputs,
            flagsPack = packed,
            seq_len = src.seq_len,
        }
    end

    return prepared
end

--- Run one pass over a prepared data set and return its aggregate score.
-- For every batch the network is propagated; each output value that belongs
-- to a live frame (reader:has_data) is converted from a natural-log quantity
-- to base 10 and accumulated. When do_train is true the batch's err_input is
-- copied into the network and back-propagation is run.
-- @param data prepared batches (as produced by get_data)
-- @param do_train true to enable dropout and back-propagation
-- @return 10 ^ (-total_err / total_frame); this is not-a-number when no
--   frame was counted (0 / 0)
function nn:process(data, do_train)
    local total_err = 0
    local total_frame = 0
    -- log10(exp(x)) == x / ln(10). Dividing directly avoids the exp/log
    -- round trip, which loses precision and under/overflows to +-inf for
    -- large |x|.
    local ln10 = math.log(10)
    local batch_reader = nerv.TNNReader(self.gconf, data)
    while true do
        local r = self.tnn:getfeed_from_reader(batch_reader)
        if not r then
            break
        end
        -- Dropout is active only while training.
        if do_train then
            self.gconf.dropout_rate = self.gconf.dropout
        else
            self.gconf.dropout_rate = 0
        end
        self.tnn:net_propagate()
        for t = 1, self.gconf.chunk_size do
            -- Host copy of the step output; indexed 0-based below.
            local tmp = self.tnn.outputs_m[t][1]:new_to_host()
            for i = 1, self.gconf.batch_size do
                if batch_reader:has_data(t, i) then
                    total_err = total_err + tmp[i - 1][0] / ln10
                    total_frame = total_frame + 1
                end
            end
        end
        if do_train then
            local err_input = batch_reader:get_err_input()
            for i = 1, self.gconf.chunk_size do
                self.tnn.err_inputs_m[i][1]:copy_from(err_input[i])
            end
            -- NOTE(review): two calls, first with false and then with true;
            -- presumably gradient computation followed by the parameter
            -- update. Confirm against nerv.TNN.
            self.tnn:net_backpropagate(false)
            self.tnn:net_backpropagate(true)
        end
        collectgarbage('collect')
    end
    return 10 ^ (-total_err / total_frame)
end

--- Run one full epoch: a training pass, then a validation pass.
-- @return the value of process() on the training data, followed by the
--   value of process() on the validation data
function nn:epoch()
    -- Kept as two statements so the training pass always runs first.
    local train_result = self:process(self.train_data, true)
    local val_result = self:process(self.val_data, false)
    return train_result, val_result
end