-- nerv/examples/lmptb/lm_sampler.lua
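--LMSampler: draws whole sentences from a trained RNN language model by
--repeatedly propagating a DAG network and sampling each next word from the
--cumulative softmax output. Reads batch_size, chunk_size, hidden_size, vocab,
--cumat_type and mmat_type from the global_conf it is constructed with.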
local LMSampler = nerv.class('nerv.LMSampler')

function LMSampler:__init(global_conf)
    self.log_pre = "LMSampler"
    self.gconf = global_conf
    self.batch_size = self.gconf.batch_size
    self.chunk_size = self.gconf.chunk_size --maximum length of a sampled sentence
    self.vocab = self.gconf.vocab
    self.sen_end_token = self.vocab.sen_end_token
    self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id

    self.loaded = false --becomes true after load_dagL
end

function LMSampler:load_dagL(dagL)
    nerv.printf("%s loading dagL\n", self.log_pre)

    self.dagL = dagL
    self.dagL:init(self.batch_size)

    --inputs[1]: id (0-based) of the previous word, inputs[2]: previous hidden state
    self.dagL_inputs = {}
    self.dagL_inputs[1] = self.gconf.cumat_type(self.gconf.batch_size, 1)
    self.dagL_inputs[1]:fill(self.sen_end_id - 1)
    self.dagL_inputs[2] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.hidden_size)
    self.dagL_inputs[2]:fill(0)

    --outputs[1]: scores over the vocabulary, outputs[2]: new hidden state
    self.dagL_outputs = {}
    self.dagL_outputs[1] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.vocab:size())
    self.dagL_outputs[2] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.hidden_size)

    --softmax output, its row-wise prefix sum (CDF) on device, and a host copy
    self.smout_d = self.gconf.cumat_type(self.batch_size, self.vocab:size())
    self.ssout_d = self.gconf.cumat_type(self.batch_size, self.vocab:size())
    self.ssout_h = self.gconf.mmat_type(self.batch_size, self.vocab:size())

    --store[i] holds the sentence being built for batch slot i, seeded with the sentence-end token
    self.store = {}
    for i = 1, self.batch_size do
        self.store[i] = {}
        self.store[i][1] = {}
        self.store[i][1].w = self.sen_end_token
        self.store[i][1].id = self.sen_end_id
        self.store[i][1].p = 0
    end
    self.repo = {} --finished sentences waiting to be returned

    self.loaded = true
end

--Sample one word for every batch slot from the cumulative softmax rows in ssout
--and append it to the corresponding sentence in self.store.
function LMSampler:sample_to_store(ssout) --private
    for i = 1, self.batch_size do
        local ran = math.random()
        local id = 1
        local low = 0
        local high = ssout:ncol() - 1
        --the last entry of a prefix-sum row is the total probability mass and should be 1
        if ssout[i - 1][high] < 0.9999 or ssout[i - 1][high] > 1.0001 then
            nerv.error("%s ERROR, softmax output sum (%f) deviates from 1", self.log_pre, ssout[i - 1][high])
        end
        --inverse-CDF sampling: binary search for the first column whose cumulative
        --probability reaches ran; word ids are 1-based, matrix columns are 0-based
        if ssout[i - 1][low] < ran then
            while low + 1 < high do
                local mid = math.floor((low + high) / 2)
                if ssout[i - 1][mid] < ran then
                    low = mid
                else
                    high = mid
                end
            end
            id = high + 1
        end
        --[[ equivalent linear scan over the plain softmax output, kept for reference
        local s = 0
        local id = self.vocab:size()
        for j = 0, self.vocab:size() - 1 do
            s = s + smout[i - 1][j]
            if s >= ran then
                id = j + 1
                break
            end
        end
        ]]--
        --force a sentence end once the sentence approaches chunk_size
        if #self.store[i] >= self.chunk_size - 2 then
            id = self.sen_end_id
        end
        local tmp = {}
        tmp.w = self.vocab:get_word_id(id).str
        tmp.id = id
        --recover the word's own probability from adjacent prefix sums
        if id == 1 then
            tmp.p = ssout[i - 1][id - 1]
        else
            tmp.p = ssout[i - 1][id - 1] - ssout[i - 1][id - 2]
        end
        table.insert(self.store[i], tmp)
    end
end
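--Worked illustration of the inverse-CDF step above (values made up for the
--example): if a cumulative row is {0.2, 0.5, 1.0} and ran = 0.6, the binary
--search ends with high = 2, so id = 3 and tmp.p = 1.0 - 0.5 = 0.5.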

--Sample at least sample_num sentences from the RNN and return sample_num of them.
function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
    assert(self.loaded == true)

    local dagL = self.dagL
    local inputs = self.dagL_inputs
    local outputs = self.dagL_outputs

    while #self.repo < sample_num do
        dagL:propagate(inputs, outputs)
        inputs[2]:copy_fromd(outputs[2]) --carry the hidden activation to the next step

        --turn the scores into a softmax distribution, build its row-wise prefix
        --sum (CDF) on device, then copy the CDF to host memory for sampling
        self.smout_d:softmax(outputs[1])
        self.ssout_d:prefixsum_row(self.smout_d)
        self.ssout_d:copy_toh(self.ssout_h)

        self:sample_to_store(self.ssout_h)
        for i = 1, self.batch_size do
            --feed the sampled word back as the next input (0-based id)
            inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1
            if self.store[i][#self.store[i]].id == self.sen_end_id then --met a sentence end
                if #self.store[i] >= 3 then --keep sentences with at least one real word
                    self.repo[#self.repo + 1] = self.store[i]
                end
                --reset the hidden state and start a new sentence in this slot
                inputs[2][i - 1]:fill(0)
                self.store[i] = {}
                self.store[i][1] = {}
                self.store[i][1].w = self.sen_end_token
                self.store[i][1].id = self.sen_end_id
                self.store[i][1].p = 0
            end
        end

        collectgarbage("collect")
    end

    --pop sample_num finished sentences off the repo
    local res = {}
    for i = 1, sample_num do
        res[i] = self.repo[#self.repo]
        self.repo[#self.repo] = nil
    end
    return res
end
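
--Illustrative usage sketch (an assumption, not part of this module: it presumes
--a global_conf set up as in the lmptb example scripts and a trained network
--already wrapped as a DAG layer called dagL):
--[[
local sampler = nerv.LMSampler(global_conf)
sampler:load_dagL(dagL)
local sens = sampler:lm_sample_rnn_dagL(10) --draw 10 sentences
for _, sen in ipairs(sens) do
    local words = {}
    for _, item in ipairs(sen) do
        table.insert(words, item.w) --item.w: word string, item.p: its probability
    end
    nerv.printf("%s\n", table.concat(words, " "))
end
]]--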