aboutsummaryrefslogblamecommitdiff
path: root/nerv/examples/lmptb/main.lua
blob: 8764998cdd49c0a8ddb6c8fc7d4bebee401c4879 (plain) (tree)
























































                                                                                                  
                                                                                                                                                                                                                                                                                                  
                    
                                      


                                                



                                                                                                               














                                                                                                                                                      
                                                                                     
                                                                                                                                                    
                                                                                                                                                    













                                                                    
                                                          






                                                             
                                               



                                           
                                                    





                                                    

                                                                           


                                                                            
                                                                                                 














                                                                                                                          






















                                                                                                              







                                                                                                               



                                                                                
                                                                                                          





                                                                                                                   

                                  



                                                                              
                                                 










                                                                                                          




                                                                                                    





                                                                                  









                                                                                                                                                       






                                                                                                               




                                                       

                                                                               





















                                                                     



                                                                                                         
                   
                                              


                                        
                          
                        

                                                                  






                                                                       

                                                                                    



                                                                                                                                               



                                                                          






                                                
                                                                  











                                                                                    
                            


















                                                                                                                    
                 
                                                                 
                   
                     
                           







                                                                                                                        

                                                                      
                                                         
                            











                                                                                                        
                                                                                             















                                                                                                   

 
require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
nerv.include('lmptb/layer/init.lua')

--[[global function rename]]--
-- Expose nerv.printf under the shorter name `printf`; the functions below call
-- it unqualified. Intentionally assigned without `local` (i.e. as a global),
-- presumably so other scripts sharing this environment can use it too -- TODO confirm.
printf = nerv.printf
--[[global function rename ends]]--

--- Optionally generate, then load, the RNN LM parameters.
-- When `first_time` is true, fresh parameters are created, filled via
-- `global_conf.param_random`, and written to `global_conf.param_fn`.
-- In every case the parameters are then imported from that file into a new
-- ParamRepo.
-- @param global_conf table: reads sche_log_pre, cumat_type, vocab, hidden_size,
--   param_random and param_fn
-- @param first_time bool: generate and save new parameters before loading
-- @return a nerv.ParamRepo holding the parameters imported from param_fn
function prepare_parameters(global_conf, first_time)
    printf("%s preparing parameters...\n", global_conf.sche_log_pre)

    if first_time then
        local vocab_size = global_conf.vocab:size()
        local hidden_size = global_conf.hidden_size
        -- {constructor, id, rows, cols}. The order matters: it fixes both the
        -- order in which param_random is consumed and the chunk order on disk.
        local specs = {
            {nerv.LinearTransParam, "ltp_ih", vocab_size, hidden_size},
            {nerv.LinearTransParam, "ltp_hh", hidden_size, hidden_size},
            {nerv.LinearTransParam, "ltp_ho", hidden_size, vocab_size},
            {nerv.BiasParam, "bp_h", 1, hidden_size},
            {nerv.BiasParam, "bp_o", 1, vocab_size},
        }

        -- Kept in a local table (previously leaked as five accidental globals).
        local params = {}
        for i, spec in ipairs(specs) do
            local ctor, id, nrow, ncol = spec[1], spec[2], spec[3], spec[4]
            local param = ctor(id, global_conf)
            param.trans = global_conf.cumat_type(nrow, ncol)
            param.trans:generate(global_conf.param_random)
            params[i] = param
        end

        local f = nerv.ChunkFile(global_conf.param_fn, 'w')
        for _, param in ipairs(params) do
            f:write_chunk(param)
        end
        f:close()
    end

    local paramRepo = nerv.ParamRepo()
    paramRepo:import({global_conf.param_fn}, nil, global_conf)

    printf("%s preparing parameters end.\n", global_conf.sche_log_pre)

    return paramRepo
end

--global_conf: table
--Returns: nerv.LayerRepo
function prepare_layers(global_conf, paramRepo)
    printf("%s preparing layers...\n", global_conf.sche_log_pre)
    local recurrentLconfig = {{["bp"] = "bp_h", ["ltp_hh"] = "ltp_hh"}, {["dim_in"] = {global_conf.hidden_size, global_conf.hidden_size}, ["dim_out"] = {global_conf.hidden_size}, ["break_id"] = global_conf.vocab:get_sen_entry().id, ["independent"] = global_conf.independent, ["clip"] = 10}}
    local layers = {
        ["nerv.IndRecurrentLayer"] = {
            ["recurrentL1"] = recurrentLconfig, 
        },

        ["nerv.SelectLinearLayer"] = {
            ["selectL1"] = {{["ltp"] = "ltp_ih"}, {["dim_in"] = {1}, ["dim_out"