aboutsummaryrefslogblamecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmseqreader.lua
blob: 127292940851d78de0cc7c071f649bb2e2ae138a (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
                       
                      
                    







                                               
                                                                            


                                                                    
                                
                                      

                      






                                                                                                       



                                           



                                                                                  





                               
                                                                                              
       


                                                                                                           


                                    
                                                                      
       

                                                                                                           



                                                                                           


                                                                              
                                                                                                                                                             
       


         
                                                                                         






                                                  
                                               





                                                           










                                                                                                                           





                                   
                                                  
               






                                                    

                                   

                                    
                         

       

                                                                      
                                         
 
                         


                                    
                                    
                                  
                                                                                         
                                        
                           



                                                                                                        
                                                      
                                                     


                                                         
                                                      
                

                                                






                                                                                                                                       

                                                          
                                                         
                   

                                                          




                                                                                                                 
                    
                                                                   


                                                                                           
                   





                                                                                                       


                                                                                 
                                           
                                         







                                                                                       
                   
               
           
       
    




                                                             
                                                          


                                                              

       

                                 
                                                                                     





                                                             
                                                                          
                    
        
                   





















                                                                                                           
require 'lmptb.lmvocab'
require 'lmptb.lmutil'
--require 'tnn.init'

--shorthand for nerv's formatted print
local printf = nerv.printf

--class object for the sequence reader, registered with nerv under "nerv.LMSeqReader"
local LMReader = nerv.class("nerv.LMSeqReader")

--Constructor; no file is opened here (see open_file).
--global_conf: table, stored as self.gconf
--batch_size: int, number of streams open_file will create
--chunk_size: int
--vocab: nerv.LMVocab
--r_conf: table or nil, optional reader flags; se_mode, compressed_label and
--        same_io are each enabled only when set to exactly true
function LMReader:__init(global_conf, batch_size, chunk_size, vocab, r_conf)
    local opts = r_conf
    if opts == nil then
        opts = {}
    end

    self.gconf = global_conf
    self.fh = nil --file handle to read; nil until open_file succeeds
    self.batch_size = batch_size
    self.chunk_size = chunk_size
    self.log_pre = "[LOG]LMSeqReader:"
    self.vocab = vocab
    self.streams = nil --built by open_file

    --sentence end mode: when a sentence end is met, the stream after will be null
    self.se_mode = (opts.se_mode == true)
    self.compressed_label = (opts.compressed_label == true)
    --same_io can be used to train P(wi|w1..(i-1),(i+1)..n)
    self.same_io = (opts.same_io == true)
end

--fn: string
--Initialize all streams
function LMReader:open_file(fn)
    if (self.fh ~= nil) then
        nerv.error("%s error: in open_file(fn is %s), file handle not nil.", self.log_pre, fn)
    end
    nerv.printf("%s opening file %s...\n", self.log_pre, fn)
    nerv.printf("%s batch_size:%d chunk_size:%d\n", self.log_pre, self.batch_size, self.chunk_size)
    nerv.printf("%s se_mode:%s same_io:%s\n", self.log_pre, tostring(self.se_mode), tostring(self.same_io))
    self.fh = io.open(fn, "r")
    self.streams = {}
    for i = 1, self.batch_size, 1 do
        self.streams[i] = {["store"] = {}, ["head"]