-- nerv/nn/network.lua
local network = nerv.class('nerv.Network')

function network:__init(id, global_conf, network_conf)
    self.id = id
    self.network = network_conf.network
    self.dim_in = self.network.dim_in
    self.dim_out = self.network.dim_out
    self.gconf = global_conf
    if self.gconf.use_cpu then
        self.mat_type = self.gconf.mmat_type
    else
        self.mat_type = self.gconf.cumat_type
    end
    self.clip = network_conf.clip
    self.nn_act_default = network_conf.nn_act_default
    if self.nn_act_default == nil then
        self.nn_act_default = 0
    end

    self.layers = {}
    self.input_conn = {}
    self.output_conn = {}
    self.socket = self:compile(self.network)
    for i = 1, #self.dim_in do
        local edge = self.socket.inputs[i]
        local id, port, time = edge[1], edge[2], edge[3]
        if self.input_conn[id][port] ~= nil then
            nerv.error('duplicate edge')
        end
        self.input_conn[id][port] = {0, i, time}
    end
    for i = 1, #self.dim_out do
        local edge = self.socket.outputs[i]
        local id, port, time = edge[1], edge[2], edge[3]
        if self.output_conn[id][port] ~= nil then
            nerv.error('duplicate edge')
        end
        self.output_conn[id][port] = {0, i, time}
    end

    -- self.delay is the largest absolute time shift over all connections
    self.delay = 0
    for i = 1, #self.layers do
        local dim_in, _ = self.layers[i]:get_dim()
        for j = 1, #dim_in do
            if self.input_conn[i][j] == nil then
                nerv.error('dangling input')
            end
            local time = self.input_conn[i][j][3]
            if math.abs(time) > self.delay then
                self.delay = math.abs(time)
            end
        end
    end

    -- input_edge[t] / output_edge[t] list the {layer, port} pairs whose
    -- connection carries a time shift of t
    self.input_edge = {}
    self.output_edge = {}
    for t = -self.delay, self.delay do
        self.input_edge[t] = {}
        self.output_edge[t] = {}
    end
    for i = 1, #self.layers do
        local dim_in, dim_out = self.layers[i]:get_dim()
        for j = 1, #dim_in do
            local time = self.input_conn[i][j][3]
            table.insert(self.input_edge[time], {i, j})
        end
        for j = 1, #dim_out do
            if self.output_conn[i][j] == nil then
                nerv.error('dangling output')
            end
            local time = self.output_conn[i][j][3]
            table.insert(self.output_edge[time], {i, j})
        end
    end
end

function network:compile(layer)
    -- flatten nested graph layers into self.layers; the returned socket maps
    -- the external input/output ports of [layer] to {id, port, time} triples
    local socket = {inputs = {}, outputs = {}}
    if not nerv.is_type(layer, 'nerv.GraphLayer') then
        table.insert(self.layers, layer)
        local id = #self.layers
        self.input_conn[id] = {}
        self.output_conn[id] = {}
        local dim_in, dim_out = layer:get_dim()
        for i = 1, #dim_in do
            socket.inputs[i] = {id, i, 0}
        end
        for i = 1, #dim_out do
            socket.outputs[i] = {id, i, 0}
        end
    else
        local sublayer_socket = {}
        for id, sublayer in pairs(layer.layers) do
            if id ~= '<input>' then
                sublayer_socket[sublayer.id] = self:compile(sublayer.layer)
            end
        end
        for _, edge in pairs(layer.connections) do
            -- id = 0 means <input> or <output>
            local id_from, port_from = edge[1], edge[2]
            local id_to, port_to = edge[3], edge[4]
            local time = edge[5]
            if id_from == 0 then
                if socket.inputs[port_from] ~= nil then
                    nerv.error('duplicate input socket')
                end
                local input = sublayer_socket[id_to].inputs[port_to]
                local id, port, t = input[1], input[2], input[3] + time
                socket.inputs[port_from] = {id, port, t}
            else
                local output = sublayer_socket[id_from].outputs[port_from]
                local id, port, t = output[1], output[2], output[3] + time
                if id_to == 0 then
                    if socket.outputs[port_to] ~= nil then
                        nerv.error('duplicate output socket')
                    end
                    socket.outputs[port_to] = {id, port, t}
                else
                    local input = sublayer_socket[id_to].inputs[port_to]
                    local id1, port1, t1 = input[1], input[2], input[3]
                    if self.input_conn[id1][port1] ~= nil or self.output_conn[id][port] ~= nil then
                        nerv.error('duplicate edge')
                    end
                    self.input_conn[id1][port1] = {id, port, t + t1}
                    self.output_conn[id][port] = {id1, port1, t + t1}
                end
            end
        end
    end
    return socket
end
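
--[[
    A minimal illustration of the connection format consumed above; the
    sublayer id 'affine' is hypothetical. Each edge is a quintuple
    {id_from, port_from, id_to, port_to, time}, where id 0 stands for
    <input> / <output>:

        connections = {
            {0, 1, 'affine', 1, 0},        -- network input -> affine input port 1
            {'affine', 1, 'affine', 2, 1}, -- recurrent self-connection, delayed one step
            {'affine', 1, 0, 1, 0},        -- affine output port 1 -> network output
        }

    compile() would record the recurrent edge in input_conn / output_conn
    with time = 1 and expose the other two edges through the returned socket.
--]]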

function network:init(batch_size, chunk_size)
    self.batch_size = batch_size
    self.chunk_size = chunk_size

    self:topsort()

    self:make_initial_store()
    collectgarbage('collect')

    self.flush = {}
    for t = 1, self.chunk_size do
        self.flush[t] = {}
    end
end

function network:epoch_init()
    self.timestamp = 0
    for i = 1, #self.layers do
        self.layers[i]:init(self.batch_size, self.chunk_size)
        for t = 1, self.chunk_size do
            self.flush[t][i] = {timestamp = 0, input = {}, output = {}}
        end
    end
end

function network:topsort()
    nerv.info('network topological sort')
    -- Kahn's algorithm over (chunk, layer) pairs: degree[t][i] counts how
    -- many in-chunk edges feed layer i at time t
    local degree = {}
    for t = 1, self.chunk_size do
        degree[t] = {}
        for i = 1, #self.layers do
            degree[t][i] = 0
        end
    end

    for t = 1, self.chunk_size do
        for i = 1, #self.layers do
            local _, dim_out = self.layers[i]:get_dim()
            for j = 1, #dim_out do
                local edge = self.output_conn[i][j]
                local id, time = edge[1], edge[3] + t
                if time >= 1 and time <= self.chunk_size and id ~= 0 then
                    degree[time][id] = degree[time][id] + 1
                end
            end
        end
    end

    self.queue = {}
    local l = 1
    local r = 0
    for t = 1, self.chunk_size do
        for i = 1, #self.layers do
            if degree[t][i] == 0 then
                r = r + 1
                self.queue[r] = {chunk = t, id = i}
            end
        end
    end
    while l <= r do
        local t, i = self.queue[l].chunk, self.queue[l].id
        l = l + 1
        local _, dim_out = self.layers[i]:get_dim()
        for j = 1, #dim_out do
            local edge = self.output_conn[i][j]
            local id, time = edge[1], edge[3] + t
            if time >= 1 and time <= self.chunk_size and id ~= 0 then
                degree[time][id] = degree[time][id] - 1
                if degree[time][id] == 0 then
                    r = r + 1
                    self.queue[r] = {chunk = time, id = id}
                end
            end
        end
    end

    if r ~= self.chunk_size * #self.layers then
        nerv.error('loop detected')
    end
end

function network:make_initial_store()
    nerv.info('network initializing storage')

    -- allocate activation (memory) and gradient (err_memory) matrices for
    -- every layer at every time step, with self.delay extra steps on both
    -- sides of the chunk for recurrent connections
    local memory = {}
    local err_memory = {}
    for t = 1 - self.delay, self.chunk_size + self.delay do
        memory[t] = {}
        err_memory[t] = {}
        for i = 1, #self.layers do
            memory[t][i] = {}
            err_memory[t][i] = {}
            local dim_in, dim_out = self.layers[i]:get_dim()
            for j = 1, #dim_in do
                err_memory[t][i][j] = self.mat_type(self.batch_size, dim_in[j])
                err_memory[t][i][j]:fill(0)
            end
            for j = 1, #dim_out do
                memory[t][i][j] = self.mat_type(self.batch_size, dim_out[j])
                memory[t][i][j]:fill(self.nn_act_default)
            end
        end
        if t < 1 or t > self.chunk_size then
            -- memory[t][0] stores network input
            memory[t][0] = {}
            for j = 1, #self.dim_in do
                memory[t][0][j] = self.mat_type(self.batch_size, self.dim_in[j])
                memory[t][0][j]:fill(self.nn_act_default)
            end
            -- err_memory[t][0] stores network err_input
            err_memory[t][0] = {}
            for j = 1, #self.dim_out do
                err_memory[t][0][j] = self.mat_type(self.batch_size, self.dim_out[j])
                err_memory[t][0][j]:fill(0)
            end
        end
    end

    -- wire per-layer input/output references to the shared storage
    self.input = {}
    self.output = {}
    self.err_input = {}
    self.err_output = {}
    for t = 1, self.chunk_size do
        self.input[t] = {}
        self.output[t] = {}
        self.err_input[t] = {}
        self.err_output[t] = {}
        for i = 1, #self.layers do
            self.input[t][i] = {}
            self.output[t][i] = {}
            self.err_input[t][i] = {}
            self.err_output[t][i] = {}
            local dim_in, dim_out = self.layers[i]:get_dim()
            for j = 1, #dim_in do
                local edge = self.input_conn[i][j]
                local id, port, time = edge[1], edge[2], edge[3]
                if id ~= 0 or t - time < 1 or t - time > self.chunk_size then
                    self.input[t][i][j] = memory[t - time][id][port]
                end
                if id ~= 0 then
                    self.err_output[t][i][j] = err_memory[t][i][j]
                end
            end
            for j = 1, #dim_out do
                local edge = self.output_conn[i][j]
                local id, port, time = edge[1], edge[2], edge[3]
                if id ~= 0 then
                    self.output[t][i][j] = memory[t][i][j]
                end
                if id ~= 0 or t + time < 1 or t + time > self.chunk_size then
                    self.err_input[t][i][j] = err_memory[t + time][id][port]
                end
            end
        end
    end

    -- check for dangling references
    for t = 1, self.chunk_size do
        for i = 1, #self.dim_in do
            local edge = self.socket.inputs[i]
            local id, port, time = edge[1], edge[2], edge[3]
            if t + time >= 1 and t + time <= self.chunk_size then
                if self.input[t + time][id][port] ~= nil then
                    nerv.error('input reference not nil')
                end
                self.input[t + time][id][port] = true      -- just a placeholder
                if self.err_output[t + time][id][port] ~= nil then
                    nerv.error('err_output reference not nil')
                end
                self.err_output[t + time][id][port] = true -- just a placeholder
            end
        end
        for i = 1, #self.dim_out do
            local edge = self.socket.outputs[i]
            local id, port, time = edge[1], edge[2], edge[3]
            if t - time >= 1 and t - time <= self.chunk_size then
                if self.output[t - time][id][port] ~= nil then
                    nerv.error('output reference not nil')
                end
                self.output[t - time][id][port] = true     -- just a placeholder
                if self.err_input[t - time][id][port] ~= nil then
                    nerv.error('err_input reference not nil')
                end
                self.err_input[t - time][id][port] = true  -- just a placeholder
            end
        end
    end
    for t = 1, self.chunk_size do
        for i = 1, #self.layers do
            local dim_in, dim_out = self.layers[i]:get_dim()
            for j = 1, #dim_in do
                if self.input[t][i][j] == nil then
                    nerv.error('input reference dangling')
                end
                if self.err_output[t][i][j] == nil then
                    nerv.error('err_output reference dangling')
                end
            end
            for j = 1, #dim_out do
                if self.output[t][i][j] == nil then
                    nerv.error('output reference dangling')
                end
                if self.err_input[t][i][j] == nil then
                    nerv.error('err_input reference dangling')
                end
            end
        end
    end

    -- allocate references for activations carried over from the previous
    -- mini-batch (legacy)
    self.legacy = {}
    for t = 1 - self.delay, 0 do
        self.legacy[t] = {}
        for i = 1, #self.layers do
            self.legacy[t][i] = {}
        end
    end
    for d = 1, self.delay do
        for t = 1 - d, 0 do
            for i = 1, #self.output_edge[d] do
                local edge = self.output_edge[d][i]
                local id, port = edge[1], edge[2]
                self.legacy[t][id][port] = memory[t][id][port]
            end
        end
    end
end
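
--[[
    A sketch of how the storage above is shared across time, with
    hypothetical layer ids and ports. Suppose layer 2 reads layer 1's output
    through a one-step recurrent edge, i.e. input_conn[2][1] = {1, 1, 1};
    then for every t

        self.input[t][2][1] == memory[t - 1][1][1]

    so layer 2 at time t reads the matrix layer 1 wrote at time t - 1, and
    memory[0] (pre-filled with nn_act_default, refreshed via self.legacy
    between mini-batches, since legacy[t] aliases memory[t] for t <= 0)
    supplies the activation needed at t = 1.
--]]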

function network:set_input(input)
    for t = 1, self.chunk_size do
        for i = 1, #self.dim_in do
            local edge = self.socket.inputs[i]
            local id, port, time = edge[1], edge[2], edge[3]
            if t + time >= 1 and t + time <= self.chunk_size then
                self.input[t + time][id][port] = input[t][i]
            end
        end
    end
end

function network:set_output(output)
    for t = 1, self.chunk_size do
        for i = 1, #self.dim_out do
            local edge = self.socket.outputs[i]
            local id, port, time = edge[1], edge[2], edge[3]
            if t - time >= 1 and t - time <= self.chunk_size then
                self.output[t - time][id][port] = output[t][i]
            end
        end
    end
end

function network:set_err_input(err_input)
    for t = 1, self.chunk_size do
        for i = 1, #self.dim_out do
            local edge = self.socket.outputs[i]
            local id, port, time = edge[1], edge[2], edge[3]
            if t - time >= 1 and t - time <= self.chunk_size then
                self.err_input[t - time][id][port] = err_input[t][i]
            end
        end
    end
end

function network:set_err_output(err_output)
    for t = 1, self.chunk_size do
        for i = 1, #self.dim_in do
            local edge = self.socket.inputs[i]
            local id, port, time = edge[1], edge[2], edge[3]
            if t + time >= 1 and t + time <= self.chunk_size then
                self.err_output[t + time][id][port] = err_output[t][i]
            end
        end
    end
end

--[[
    [info] is a table carrying the information of the current mini-batch.
    It must contain these fields:
        [input], [output]: matrix arrays storing the network input and output
        [seq_length]: a table with the length of every sequence
        [new_seq]: a table with the batch indices of sequences that start in
                   this mini-batch
        [do_train]: a boolean indicating whether to train
        If [do_train] is true, these fields must also be present:
            [err_input], [err_output]: matrix arrays storing the network
                                       err_input and err_output
    A sketch of a typical [info] table follows this function.
--]]
function network:mini_batch_init(info)
    self.info = info
    self:set_input(self.info.input)
    self:set_output(self.info.output)
    if self.info.do_train then
        self:set_err_input(self.info.err_input)
        self:set_err_output(self.info.err_output)
    end

    -- compute sequence borders and record which activations/gradients must
    -- be flushed at them
    self.max_length = 0
    self.timestamp = self.timestamp + 1
    for i = 1, self.batch_size do
        if self.info.seq_length[i] > self.max_length then
            self.max_length = self.info.seq_length[i]
        end
        local border = self.info.seq_length[i]
        for d = 1, self.delay do
            for t = border + 1, border + d do
                if t > self.chunk_size then
                    break
                end
                for j = 1, #self.output_edge[-d] do
                    local edge = self.output_edge[-d][j]
                    local id, port = edge[1], edge[2]
                    local flush = self.flush[t][id]
                    if flush.timestamp ~= self.timestamp then
                        flush.timestamp = self.timestamp
                        flush.input = {}
                        flush.output = {}
                    end
                    table.insert(flush.output, {port, i})
                end
            end
            if self.info.do_train then
                for t = border, border - d + 1, -1 do
                    if t < 1 then
                        break
                    end
                    for j = 1, #self.input_edge[-d] do
                        local edge = self.input_edge[-d][j]
                        local id, port = edge[1], edge[2]
                        local flush = self.flush[t][id]
                        if flush.timestamp ~= self.timestamp then
                            flush.timestamp = self.timestamp
                            flush.input = {}
                            flush.output = {}
                        end
                        table.insert(flush.input, {port, i})
                    end
                end
            end
        end
    end

    -- copy the tail of the previous chunk into the legacy buffers
    for d = 1, self.delay do
        for t = 1 - d, 0 do
            for i = 1, #self.output_edge[d] do
                local edge = self.output_edge[d][i]
                local id, port = edge[1], edge[2]
                if t + self.chunk_size >= 1 and self.output_conn[id][port][1] ~= 0 then
                    self.legacy[t][id][port]:copy_from(self.output[t + self.chunk_size][id][port])
                end
                for j = 1, #self.info.new_seq do
                    local batch = self.info.new_seq[j]
                    self.legacy[t][id][port][batch - 1]:fill(self.nn_act_default)
                end
            end
        end
    end
end
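
--[[
    A minimal sketch of assembling [info] for one mini-batch; mat_type,
    batch_size, chunk_size and the single input/output dimensions are
    assumptions about the caller's setup, not part of this file:

        local info = {
            input = {}, output = {},
            err_input = {}, err_output = {},
            seq_length = {5, 3},   -- per-batch sequence lengths
            new_seq = {2},         -- batch 2 starts a new sequence here
            do_train = true,
        }
        for t = 1, chunk_size do
            info.input[t] = {mat_type(batch_size, dim_in[1])}
            info.output[t] = {mat_type(batch_size, dim_out[1])}
            info.err_input[t] = {mat_type(batch_size, dim_out[1])}  -- grad w.r.t. output
            info.err_output[t] = {mat_type(batch_size, dim_in[1])}  -- grad w.r.t. input
        end
        net:mini_batch_init(info)
--]]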

function network:propagate()
    for i = 1, #self.queue do
        local t, id = self.queue[i].chunk, self.queue[i].id
        if t <= self.max_length then
            self.layers[id]:propagate(self.input[t][id], self.output[t][id], t)
        end
        -- flush activations at sequence borders
        if self.flush[t][id].timestamp == self.timestamp then
            for j = 1, #self.flush[t][id].output do
                local border = self.flush[t][id].output[j]
                local port, batch = border[1], border[2]
                self.output[t][id][port][batch - 1]:fill(self.nn_act_default)
            end
        end
    end
end

function network:back_propagate()
    for i = #self.queue, 1, -1 do
        local t, id = self.queue[i].chunk, self.queue[i].id
        if t <= self.max_length then
            self.layers[id]:back_propagate(self.err_input[t][id], self.err_output[t][id], self.input[t][id], self.output[t][id], t)
            -- gradient clip
            if self.clip ~= nil then
                local dim_in, _ = self.layers[id]:get_dim()
                for j = 1, #dim_in do
                    self.err_output[t][id][j]:clip(-self.clip, self.clip)
                end
            end
        end
        -- flush gradients at sequence borders
        if self.flush[t][id].timestamp == self.timestamp then
            for j = 1, #self.flush[t][id].input do
                local border = self.flush[t][id].input[j]
                local port, batch = border[1], border[2]
                self.err_output[t][id][port][batch - 1]:fill(0)
            end
        end
    end
end

function network:update()
    for t = 1, self.max_length do
        for i = 1, #self.layers do
            self.layers[i]:update(self.err_input[t][i], self.input[t][i], self.output[t][i], t)
        end
    end
end

function network:set_attr(name, value)
    self.network:set_attr(name, value)
end

function network:get_sublayer(id)
    return self.network:get_sublayer(id)
end

function network:get_params()
    return self.network:get_params()
end
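
--[[
    A sketch of the overall lifecycle, assuming a graph layer [graph] and a
    global configuration [gconf] built elsewhere (both hypothetical here, as
    is the data loop):

        local net = nerv.Network('net', gconf, {network = graph, clip = 5})
        net:init(batch_size, chunk_size)
        for epoch = 1, n_epoch do
            net:epoch_init()
            for _, info in ipairs(minibatches) do
                net:mini_batch_init(info)   -- see the sketch after mini_batch_init
                net:propagate()
                if info.do_train then
                    net:back_propagate()
                    net:update()
                end
            end
        end
--]]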