--nerv/examples/lmptb/m-tests/tnn_test.lua

require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
require 'rnn.init'
require 'lmptb.lmseqreader'
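
--how to run (a sketch; the launcher name is an assumption, the script itself only
--reads arg[1] to pick a dataset):
--  nerv tnn_test.lua ptb   --full PTB setup (data paths configured below)
--  nerv tnn_test.lua       --small smoke test on m-tests/some-text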

--[[global function rename]]--
printf = nerv.printf
--[[global function rename ends]]--

--global_conf: table
--first_time: bool
--Returns: a ParamRepo
function prepare_parameters(global_conf, first_time)
    printf("%s preparing parameters...\n", global_conf.sche_log_pre) 
    
    if (first_time) then
        ltp_ih = nerv.LinearTransParam("ltp_ih", global_conf)  
        ltp_ih.trans = global_conf.cumat_type(global_conf.vocab:size(), global_conf.hidden_size) --row 0 is reserved for the zero index; the other rows correspond to vocab indices, which start from 1
        ltp_ih.trans:generate(global_conf.param_random)

        ltp_hh = nerv.LinearTransParam("ltp_hh", global_conf)
        ltp_hh.trans = global_conf.cumat_type(global_conf.hidden_size, global_conf.hidden_size)
        ltp_hh.trans:generate(global_conf.param_random) 

        ltp_ho = nerv.LinearTransParam("ltp_ho", global_conf)
        ltp_ho.trans = global_conf.cumat_type(global_conf.hidden_size, global_conf.vocab:size())
        ltp_ho.trans:generate(global_conf.param_random)

        bp_h = nerv.BiasParam("bp_h", global_conf)
        bp_h.trans = global_conf.cumat_type(1, global_conf.hidden_size)
        bp_h.trans:generate(global_conf.param_random)

        bp_o = nerv.BiasParam("bp_o", global_conf)
        bp_o.trans = global_conf.cumat_type(1, global_conf.vocab:size())
        bp_o.trans:generate(global_conf.param_random)

        local f = nerv.ChunkFile(global_conf.param_fn, 'w')
        f:write_chunk(ltp_ih)
        f:write_chunk(ltp_hh)
        f:write_chunk(ltp_ho)
        f:write_chunk(bp_h)
        f:write_chunk(bp_o)
        f:close()
    end
    
    local paramRepo = nerv.ParamRepo()
    paramRepo:import({global_conf.param_fn}, nil, global_conf)

    printf("%s preparing parameters end.\n", global_conf.sche_log_pre)

    return paramRepo
end

--global_conf: table
--Returns: nerv.LayerRepo
function prepare_layers(global_conf, paramRepo)
    printf("%s preparing layers...\n", global_conf.sche_log_pre)

    --local recurrentLconfig = {{["bp"] = "bp_h", ["ltp_hh"] = "ltp_hh"}, {["dim_in"] = {global_conf.hidden_size, global_conf.hidden_size}, ["dim_out"] = {global_conf.hidden_size}, ["break_id"] = global_conf.vocab:get_sen_entry().id, ["independent"] = global_conf.independent, ["clip"] = 10}}
    local recurrentLconfig = {{["bp"] = "bp_h", ["ltp_hh"] = "ltp_hh"}, {["dim_in"] = {global_conf.hidden_size, global_conf.hidden_size}, ["dim_out"] = {global_conf.hidden_size}, ["clip"] = 10}}
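
    --Each layer spec below is a pair of tables: the first binds the layer's
    --parameter slots to names in the ParamRepo (e.g. "bp" -> "bp_h"), the second
    --holds the layer configuration (dim_in/dim_out plus layer-specific options
    --such as "clip" or "lambda").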

    local layers = {
        ["nerv.AffineRecurrentLayer"] = {
            ["recurrentL1"] = recurrentLconfig, 
        },

        ["nerv.SelectLinearLayer"] = {
            ["selectL1"] = {{["ltp"] = "ltp_ih"}, {["dim_in"] = {1}, ["dim_out"] = {global_conf.hidden_size}}},
        },
        
        ["nerv.SigmoidLayer"] = {
            ["sigmoidL1"] = {{}, {["dim_in"] = {global_conf.hidden_size}, ["dim_out"] = {global_conf.hidden_size}}}
        },
        
        ["nerv.CombinerLayer"] = {
            ["combinerL1"] = {{}, {["dim_in"] = {global_conf.hidden_size}, ["dim_out"] = {global_conf.hidden_size, global_conf.hidden_size}, ["lambda"] = {1}}}
        },

        ["nerv.AffineLayer"] = {
            ["outputL"] = {{["ltp"] = "ltp_ho", ["bp"] = "bp_o"}, {["dim_in"] = {global_conf.hidden_size}, ["dim_out"] = {global_conf.vocab:size()}}},
        },

        ["nerv.SoftmaxCELayerT"] = {
            ["softmaxL"] = {{}, {["dim_in"] = {global_conf.vocab:size(), global_conf.vocab:size()}, ["dim_out"] = {1}}},
        },
    }
    
    --[[ --we do not need those in the new rnn framework
    printf("%s adding %d bptt layers...\n", global_conf.sche_log_pre, global_conf.bptt)
    for i = 1, global_conf.bptt do
        layers["nerv.IndRecurrentLayer"]["recurrentL" .. (i + 1)] = recurrentLconfig 
        layers["nerv.SigmoidLayer"]["sigmoidL" .. (i + 1)] = {{}, {["dim_in"] = {global_conf.hidden_size}, ["dim_out"] = {global_conf.hidden_size}}}
        layers["nerv.SelectLinearLayer"]["selectL" .. (i + 1)] = {{["ltp"] = "ltp_ih"}, {["dim_in"] = {1}, ["dim_out"] = {global_conf.hidden_size}}}
    end
    --]]

    local layerRepo = nerv.LayerRepo(layers, paramRepo, global_conf)
    printf("%s preparing layers end.\n", global_conf.sche_log_pre)
    return layerRepo
end

--global_conf: table
--layerRepo: nerv.LayerRepo
--Returns: a nerv.TNN
function prepare_tnn(global_conf, layerRepo)
    printf("%s Generate and initing TNN ...\n", global_conf.sche_log_pre)

    --network inputs per frame: <input>[1] is the current word index (fed to selectL1), <input>[2] is the label for the CE layer (fed to softmaxL)
    local connections_t = {
        {"<input>[1]", "selectL1[1]", 0},
        {"selectL1[1]", "recurrentL1[1]", 0},  
        {"recurrentL1[1]", "sigmoidL1[1]", 0},
        {"sigmoidL1[1]", "combinerL1[1]", 0},
        {"combinerL1[1]", "recurrentL1[2]", 1},
        {"combinerL1[2]", "outputL[1]", 0},
        {"outputL[1]", "softmaxL[1]", 0},
        {"<input>[2]", "softmaxL[2]", 0},
        {"softmaxL[1]", "<output>[1]", 0}
    }
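
    --Each connection above is {source_port, destination_port, time_shift}: a shift
    --of 0 links layers within the same frame, while the shift of 1 on the
    --combinerL1[1] -> recurrentL1[2] edge delays the hidden activation by one time
    --step, which is what makes the network recurrent.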

    --[[
    printf("%s printing DAG connections:\n", global_conf.sche_log_pre)
    for key, value in pairs(connections_t) do
        printf("\t%s->%s\n", key, value)
    end
    ]]--

    local tnn = nerv.TNN("TNN", global_conf, {["dim_in"] = {1, global_conf.vocab:size()}, ["dim_out"] = {1}, ["sub_layers"] = layerRepo,
            ["connections"] = connections_t, 
        })

    tnn:init(global_conf.batch_size, global_conf.chunk_size)

    printf("%s Initing TNN end.\n", global_conf.sche_log_pre)
    return tnn
end

function load_net(global_conf)
    local paramRepo = prepare_parameters(global_conf, false)
    local layerRepo = prepare_layers(global_conf, paramRepo)
    local tnn = prepare_tnn(global_conf, layerRepo)
    return tnn, paramRepo
end

--Returns: LMResult
function lm_process_file(global_conf, fn, tnn, do_train)
    local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
    reader:open_file(fn)
    local result = nerv.LMResult(global_conf, global_conf.vocab)
    result:init("rnn")
    
    tnn:flush_all() --caution: will also flush the inputs from the reader!

    local next_log_wcn = global_conf.log_w_num

    global_conf.fz = 0  --debug counters, printed after the loop
    global_conf.fz2 = 0

    while true do
        local r, feeds

        r, feeds = tnn:getFeedFromReader(reader)
        if (r == false) then break end
    
        for t = 1, global_conf.chunk_size do
            tnn.err_inputs_m[t][1]:fill(1)
            for i = 1, global_conf.batch_size do
                if (bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0) then
                    tnn.err_inputs_m[t][1][i - 1][0] = 0
                end
            end
        end
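
        --The loops above set the error weight to 1 for every (time, batch) slot and
        --zero it for slots whose HAS_LABEL flag is unset, so frames without a label
        --(e.g. padding after the data runs out) do not contribute to the gradient.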

        --[[
        for j = 1, global_conf.chunk_size, 1 do
            for i = 1, global_conf.batch_size, 1 do
                printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i])   --vocab:get_word_str(input[i][j]).id
            end
            printf("\n")
        end
        printf("\n")
        ]]--

        tnn:net_propagate()
 
        if (do_train == true) then
            tnn:net_backpropagate(false) 
            tnn:net_backpropagate(true)
        end
 
        for t = 1, global_conf.chunk_size, 1 do
            for i = 1, global_conf.batch_size, 1 do
                if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
                    result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
                end
            end            
        end
        if (result["rnn"].cn_w > next_log_wcn) then
            next_log_wcn = next_log_wcn + global_conf.log_w_num
            printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date()) 
            printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
            nerv.LMUtil.wait(0.1)
        end
         
        --[[
        for t = 1, global_conf.chunk_size do
            print(tnn.outputs_m[t][1])
        end
        ]]--

        tnn:moveRightToNextMB()

        collectgarbage("collect")                                              

        --break --debug
    end
    
    print("gconf.fz", global_conf.fz)
    print("gconf.fz2", global_conf.fz2)

    printf("%s Displaying result:\n", global_conf.sche_log_pre)
    printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
    printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)
    
    return result
end

local train_fn, valid_fn, test_fn, global_conf
local set = arg[1] --dataset selector: "ptb" for the full PTB setup, anything else for the small smoke test

if (set == "ptb") then

data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata'
train_fn = data_dir .. '/ptb.train.txt.adds'
valid_fn = data_dir .. '/ptb.valid.txt.adds'
test_fn = data_dir .. '/ptb.test.txt.adds'

global_conf = {
    lrate = 1, wcost = 1e-6, momentum = 0,
    cumat_type = nerv.CuMatrixFloat,
    mmat_type = nerv.MMatrixFloat,
    nn_act_default = 0, 

    hidden_size = 200,
    chunk_size = 5,
    batch_size = 10, 
    max_iter = 20,
    param_random = function() return (math.random() / 5 - 0.1) end,

    train_fn = train_fn,
    valid_fn = valid_fn,
    test_fn = test_fn,
    sche_log_pre = "[SCHEDULER]:",
    log_w_num = 40000, --give a message when log_w_num words have been processed
    timer = nerv.Timer()
}

else

valid_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
train_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'

global_conf = {
    lrate = 1, wcost = 1e-6, momentum = 0,
    cumat_type = nerv.CuMatrixFloat,
    mmat_type = nerv.MMatrixFloat,
    nn_act_default = 0, 

    hidden_size = 20,
    chunk_size = 2,
    batch_size = 3, 
    max_iter = 3,
    param_random = function() return (math.random() / 5 - 0.1) end,

    train_fn = train_fn,
    valid_fn = valid_fn,
    test_fn = test_fn,
    sche_log_pre = "[SCHEDULER]:",
    log_w_num = 10, --give a message when log_w_num words have been processed
    timer = nerv.Timer()
}

end

global_conf.work_dir = '/home/slhome/txh18/workspace/nerv/play/dagL_test'
global_conf.train_fn_shuf = global_conf.work_dir .. '/train_fn_shuf'
global_conf.train_fn_shuf_bak = global_conf.train_fn_shuf .. '_bak'
global_conf.param_fn = global_conf.work_dir .. "/params"

printf("%s creating work_dir...\n", global_conf.sche_log_pre)
os.execute("mkdir -p "..global_conf.work_dir)
os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf)

local vocab = nerv.LMVocab()
global_conf["vocab"] = vocab
global_conf.vocab:build_file(global_conf.train_fn, false)

prepare_parameters(global_conf, true) --randomly generate parameters

print("===INITIAL VALIDATION===") 
local tnn, paramRepo = load_net(global_conf)
local result = lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --do_train = false: no parameter update
nerv.LMUtil.wait(3)
local ppl_rec = {}
local lr_rec = {}
ppl_rec[0] = result:ppl_net("rnn")
local ppl_last = ppl_rec[0]
lr_rec[0] = 0
print()
local lr_half = false 
for iter = 1, global_conf.max_iter, 1 do
    tnn, paramRepo = load_net(global_conf) 
    printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate) 
    global_conf.sche_log_pre = "[SCHEDULER ITER"..iter.." LR"..global_conf.lrate.."]:" 
    lm_process_file(global_conf, global_conf.train_fn_shuf, tnn, true) --do_train = true: update parameters
    --shuffling training file
    os.execute('cp ' .. global_conf.train_fn_shuf .. ' ' .. global_conf.train_fn_shuf_bak)
    os.execute('cat ' .. global_conf.train_fn_shuf_bak .. ' | sort -R --random-source=/dev/zero > ' .. global_conf.train_fn_shuf)
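    --note: with --random-source=/dev/zero GNU sort's -R ordering is deterministic,
    --so the "shuffle" produces the same order every time; fine for a reproducible
    --test, but not a true reshuffle between iterations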
    printf("===VALIDATION %d===\n", iter) 
    result = lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --do_train = false: no parameter update
    ppl_rec[iter] = result:ppl_net("rnn") 
    lr_rec[iter] = global_conf.lrate 
    if (ppl_last / ppl_rec[iter] < 1.03 or lr_half == true) then 
        global_conf.lrate = (global_conf.lrate * 0.6)
        lr_half = true 
    end 
    if (ppl_rec[iter] < ppl_last) then 
        printf("%s saving net to file %s...\n", global_conf.sche_log_pre, global_conf.param_fn) 
        paramRepo:export(global_conf.param_fn, nil) 
        ppl_last = ppl_rec[iter] 
    else 
        printf("%s PPL did not improve, rejected...\n", global_conf.sche_log_pre) 
    end 
    printf("\n") 
    nerv.LMUtil.wait(2) 
end
printf("===VALIDATION PPL record===\n") 
for i = 0, #ppl_rec do printf("<ITER%d LR%.5f: %.3f> ", i, lr_rec[i], ppl_rec[i]) end 
printf("\n") 
printf("===FINAL TEST===\n") 
global_conf.sche_log_pre = "[SCHEDULER FINAL_TEST]:" 
tnn, paramRepo = load_net(global_conf) 
lm_process_file(global_conf, global_conf.test_fn, tnn, false) --do_train = false: no parameter update