aboutsummaryrefslogblamecommitdiff
path: root/nerv/examples/swb_baseline2.lua
blob: 87b01fa08b6800b98f0e2e9f34026a08e6ac0605 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
                




                                                                                                        
                                                                                          
                       

                                                             




                                                                

                                                                          
                                                                                   
 
 





                                            




                                      

                                                                                                 


                              

                                                                                                     



                              
                                                                      
                                                                           
                                                                       
                                                                           
                                                                       
                                                                           
                                                                       
                                                                           
                                                                       
                                                                           
                                                                       
                                                                           
                                                                       
                                                                           
                                                                       
                                                                          


                               






                                                                         


                                                                                     
                                                                                   


                                                            
                                                                       




                          
                             
         
                             
                                                              
                                        
                               




                                                    
                 

                    
                                                               
                                        
                               















                                                     
                 
             




                          
                             
         
                         
                                                        
                                        
                               

                                                          


                                                    
                 

                              
                                                               
                                        
                               

                                                          

                                                    
                 
             





                         
                                        


                                            


                                               

                                        
                                       


                                                   
                                             

                         








                                                                  

   

                                      

   

                                         

   


                                                        

   



                                                   

   

                                                        




                                                                              
                                                                                            

                                          
require 'htk_io'
-- Global training configuration. It is passed to nerv.LayerRepo,
-- add_layers and nerv.HTKReader, and read by the trainer hooks in this file.
-- NOTE(review): kept as a global (not a local), presumably because the NERV
-- driver that loads this script reads `gconf` directly — verify before
-- localizing it.
do
    -- Every data file of this recipe lives under one directory; change it
    -- here to retarget the recipe. Scoped by `do ... end` so it does not leak.
    local data_dir = "/speechlab/users/mfy43/swb50/"
    gconf = {
        lrate = 0.8,
        wcost = 1e-6,
        momentum = 0.9,
        frm_ext = 5,
        rearrange = true, -- just to make the context order consistent with old TNet results, deprecated
        frm_trim = 5, -- trim the first and last 5 frames, TNet just does this, deprecated
        chunk_size = 1,
        tr_scp = data_dir .. "train_bp.scp",
        cv_scp = data_dir .. "train_cv.scp",
        -- shared alignment; tr_ali / cv_ali (if set) override it per dataset
        ali = {
            file = data_dir .. "ref.mlf",
            format = "map",
            format_arg = data_dir .. "dict",
            dir = "*/",
            ext = "lab"
        },
        htk_conf = data_dir .. "plp_0_d_a.conf",
        -- order preserved: initial network parameters, then global transform
        initialized_param = {
            data_dir .. "swb_init.nerv",
            data_dir .. "swb_global_transf.nerv"
        }
    }
end

-- Network dimensions shared by the layer definitions and the readers below.
-- NOTE(review): 429 presumably = 39-dim features (plp_0_d_a) x 11 frames,
-- i.e. (2 * gconf.frm_ext + 1) * 39 with frm_ext = 5 — confirm before
-- changing frm_ext or htk_conf.
local input_size, output_size, hidden_size = 429, 3001, 2048
-- Alias (not a copy), assuming nerv.Trainer is a table: the methods defined
-- on `trainer` below are installed on nerv.Trainer itself.
local trainer = nerv.Trainer

--- Build the layer repository for the SWB baseline DNN.
-- Registers, in three steps:
--   1. primitive layers: the bias/window pairs of the global feature
--      transform, affine0..affine7 with sigmoid0..sigmoid6 between
--      consecutive affine layers, a softmax + CE criterion (finetune output)
--      and a plain softmax (decode output);
--   2. graph layers `global_transf` (blayer1 -> wlayer1 -> blayer2 -> wlayer2)
--      and `main` (the affine/sigmoid chain);
--   3. top-level graphs `ce_output` (inputs: features, labels) and
--      `softmax_output` (input: features), both running global_transf -> main.
-- @param param_repo parameter repository whose entries are referenced by name
--   (e.g. "bias0", "affine0_ltp"); presumably loaded from
--   gconf.initialized_param — confirm against the trainer.
-- @return the populated layer repository
function trainer:make_layer_repo(param_repo)
    -- index of the last affine layer; sigmoids exist for 0 .. last_affine - 1
    local last_affine = 7

    -- affine0 maps input -> hidden, the last one maps hidden -> output,
    -- every other affine layer is hidden -> hidden
    local affine_specs, sigmoid_specs = {}, {}
    for i = 0, last_affine do
        local name = "affine" .. i
        affine_specs[name] = {
            dim_in = {i == 0 and input_size or hidden_size},
            dim_out = {i == last_affine and output_size or hidden_size},
            params = {ltp = name .. "_ltp", bp = name .. "_bp"}
        }
        if i < last_affine then
            sigmoid_specs["sigmoid" .. i] = {dim_in = {hidden_size}, dim_out = {hidden_size}}
        end
    end

    -- <input> -> affine0 -> sigmoid0 -> affine1 -> ... -> sigmoid6 -> affine7 -> <output>
    local main_links = {{"<input>[1]", "affine0[1]", 0}}
    for i = 0, last_affine - 1 do
        local affine_port = "affine" .. i .. "[1]"
        local sigmoid_port = "sigmoid" .. i .. "[1]"
        main_links[#main_links + 1] = {affine_port, sigmoid_port, 0}
        main_links[#main_links + 1] = {sigmoid_port, "affine" .. (i + 1) .. "[1]", 0}
    end
    main_links[#main_links + 1] = {"affine" .. last_affine .. "[1]", "<output>[1]", 0}

    -- step 1: primitive layers
    local layer_repo = nerv.LayerRepo({
        -- global transf
        ["nerv.BiasLayer"] = {
            blayer1 = {dim_in = {input_size}, dim_out = {input_size}, params = {bias = "bias0"}},
            blayer2 = {dim_in = {input_size}, dim_out = {input_size}, params = {bias = "bias1"}}
        },
        ["nerv.WindowLayer"] = {
            wlayer1 = {dim_in = {input_size}, dim_out = {input_size}, params = {window = "window0"}},
            wlayer2 = {dim_in = {input_size}, dim_out = {input_size}, params = {window = "window1"}}
        },
        -- biased linearity
        ["nerv.AffineLayer"] = affine_specs,
        ["nerv.SigmoidLayer"] = sigmoid_specs,
        -- softmax + ce criterion layer for finetune output
        ["nerv.SoftmaxCELayer"] = {
            ce_crit = {dim_in = {output_size, 1}, dim_out = {1}, compressed = true}
        },
        -- softmax for decode output
        ["nerv.SoftmaxLayer"] = {
            softmax = {dim_in = {output_size}, dim_out = {output_size}}
        }
    }, param_repo, gconf)

    -- step 2: graphs over the primitives. Kept as a separate add_layers call
    -- (and step 3 as another), presumably because GraphLayer resolves its
    -- member names through layer_repo when it is constructed.
    layer_repo:add_layers({
        ["nerv.GraphLayer"] = {
            global_transf = {
                dim_in = {input_size}, dim_out = {input_size},
                layer_repo = layer_repo,
                connections = {
                    {"<input>[1]", "blayer1[1]", 0},
                    {"blayer1[1]", "wlayer1[1]", 0},
                    {"wlayer1[1]", "blayer2[1]", 0},
                    {"blayer2[1]", "wlayer2[1]", 0},
                    {"wlayer2[1]", "<output>[1]", 0}
                }
            },
            main = {
                dim_in = {input_size}, dim_out = {output_size},
                layer_repo = layer_repo,
                connections = main_links
            }
        }
    }, param_repo, gconf)

    -- step 3: end-to-end graphs for training (CE) and decoding (softmax)
    layer_repo:add_layers({
        ["nerv.GraphLayer"] = {
            ce_output = {
                dim_in = {input_size, 1}, dim_out = {1},
                layer_repo = layer_repo,
                connections = {
                    {"<input>[1]", "global_transf[1]", 0},
                    {"global_transf[1]", "main[1]", 0},
                    {"main[1]", "ce_crit[1]", 0},
                    {"<input>[2]", "ce_crit[2]", 0},
                    {"ce_crit[1]", "<output>[1]", 0}
                }
            },
            softmax_output = {
                dim_in = {input_size}, dim_out = {output_size},
                layer_repo = layer_repo,
                connections = {
                    {"<input>[1]", "global_transf[1]", 0},
                    {"global_transf[1]", "main[1]", 0},
                    {"main[1]", "softmax[1]", 0},
                    {"softmax[1]", "<output>[1]", 0}
                }
            }
        }
    }, param_repo, gconf)

    return layer_repo
end

--- Return the graph that is trained: `ce_output` (features + labels -> CE).
-- @param layer_repo repository built by make_layer_repo
function trainer:get_network(layer_repo)
    local graph_name = "ce_output"
    return layer_repo:get_layer(graph_name)
end

--- Build the reader list for one dataset.
-- @param dataset 'train' (reads gconf.tr_scp) or 'validate' (reads gconf.cv_scp);
--   any other value is reported through nerv.error('no such dataset')
-- @return a one-element list {{reader = <nerv.HTKReader>, data = {...}}}, where
--   `data` maps stream names to dimensions: main_scp -> input_size,
--   phone_state -> 1
function trainer:get_readers(dataset)
    -- gconf keys per dataset; the dataset-specific alignment (tr_ali / cv_ali)
    -- takes precedence over the shared gconf.ali when it is set
    local keys_by_dataset = {
        train = {scp = "tr_scp", ali = "tr_ali"},
        validate = {scp = "cv_scp", ali = "cv_ali"}
    }
    local keys = keys_by_dataset[dataset]
    if keys == nil then
        nerv.error('no such dataset')
        return
    end

    local scp_file = gconf[keys.scp]
    local alignment = gconf[keys.ali] or gconf.ali
    local htk_reader = nerv.HTKReader(gconf, {
        id = "main_scp",
        scp_file = scp_file,
        conf_file = gconf.htk_conf,
        frm_ext = gconf.frm_ext,
        mlfs = {phone_state = alignment}
    })
    return {{reader = htk_reader, data = {main_scp = input_size, phone_state = 1}}}
end

--- Names of the reader streams in network-input order for training:
-- features first, then frame labels (presumably matching ce_output's
-- inputs 1 and 2). A new table is returned on every call.
function trainer:get_input_order()
    local order = {"main_scp", "phone_state"}
    return order
end

--- Names of the reader streams in input order for decoding: features only
-- (presumably matching softmax_output's single input). A new table is
-- returned on every call.
function trainer:get_decode_input_order()
    local order = {"main_scp"}
    return order
end

--- Average cross entropy per frame, as accumulated by the `ce_crit` layer.
-- NOTE(review): assuming the counters are plain numbers, total_frames == 0
-- yields inf/nan here rather than raising.
function trainer:get_error()
    local crit = self.layer_repo:get_layer("ce_crit")
    local total_ce, total_frames = crit.total_ce, crit.total_frames
    return total_ce / total_frames
end

--- Per-mini-batch hook: every 1000th batch counter (including cnt == 0)
-- prints the current criterion statistics via epoch_afterprocess.
-- @param cnt mini-batch counter
-- @param info unused here
function trainer:mini_batch_afterprocess(cnt, info)
    if cnt % 1000 ~= 0 then
        return
    end
    self:epoch_afterprocess()
end

--- Print the statistics accumulated by the `ce_crit` layer: total cross
-- entropy, correct frames, frame count, CE per frame and accuracy (%).
function trainer:epoch_afterprocess()
    local crit = self.layer_repo:get_layer("ce_crit")
    local ce, correct, frames = crit.total_ce, crit.total_correct, crit.total_frames
    nerv.info("*** training stat begin ***")
    nerv.printf("cross entropy:\t\t%.8f\n", ce)
    nerv.printf("correct:\t\t%d\n", correct)
    nerv.printf("frames:\t\t\t%d\n", frames)
    nerv.printf("err/frm:\t\t%.8f\n", ce / frames)
    nerv.printf("accuracy:\t\t%.3f%%\n", correct / frames * 100)
    nerv.info("*** training stat end ***")
end