summaryrefslogblamecommitdiff
path: root/kaldi_decode/src/asr_propagator.lua
blob: ab18d6dcfb40b67524d8f66eb9848eca1a8ce77d (plain) (tree)
1
2
3
4
5


                                                             

                                                                                  










                                                                                             
 
                                          


                                  
                                       
                                    


                                                               
                                                            

                                                             





                                                                        
                                  
 
                                                            
                                                    
 



                                                                          
                               
                         
           





                                                                                 

                          
                            



















                                                                         
 

                               
                                               


                                                             
                                               


                                                                                                  
               















                                                                                     
               
           


                                                                               
           
                               
                                 

                                                                              

       
                           


                              
                  
                                       
                                             
                        
                                                              


               
                                 

                   
-- Replace the builtin `print` with a version that goes through `io.write`,
-- and point the default output file at /dev/null, so that `print` calls made
-- after this point (by this script or by modules it loads) are discarded.
-- NOTE(review): presumably this keeps stdout clean for the process embedding
-- this script — confirm. Unlike the builtin, no trailing newline is written.
-- Every argument goes through `tostring` (as the builtin `print` does), so
-- booleans, tables and embedded nils no longer make `table.concat` raise.
print = function(...)
    local n = select("#", ...)
    local parts = {...}
    for i = 1, n do
        parts[i] = tostring(parts[i])
    end
    io.write(table.concat(parts, "\t", 1, n))
end
io.output('/dev/null')
-- path and cpath are correctly set by `path.sh`
-- Optionally activate the luarocks loader (ignored if luarocks is absent) and
-- select the "nerv" scm-1 rock context.
local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
require 'nerv'
nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
nerv.info("automatically initialize a default MContext...")
nerv.MMatrix._default_context = nerv.MContext()
nerv.info("the default MContext is ok")
-- only for backward compatibility, will be removed in the future
-- Adds `print_profile` / `clear_profile` to `cls`, both forwarding to the
-- context stored in `cls._default_context` at the time of the call to this
-- helper.
local function _add_profile_method(cls)
    local c = cls._default_context
    cls.print_profile = function () c:print_profile() end
    cls.clear_profile = function () c:clear_profile() end
end
_add_profile_method(nerv.MMatrix)

--- Build a closure that decodes one utterance per call.
-- Imports parameters from `ifname`, builds the decode network via the
-- `nerv.Trainer` hooks and the global `gconf` table, and streams input data
-- through a `nerv.SeqBuffer`.
-- @param ifname parameter file passed to `nerv.ParamRepo():import`
-- @param feature forwarded to `trainer.make_decode_readers`
-- @return a function that, on each call, returns `utt_id, utt_matrix` for the
--   next utterance (one row per non-hole output frame), or `"", nil` once the
--   buffer is exhausted.
-- NOTE(review): the row assembly assumes `gconf.batch_size == 1` (as forced
-- by `init`): only row 0 of each mask entry is inspected and each output
-- matrix is copied to a single destination row offset.
function build_propagator(ifname, feature)
    -- FIXME: this is still a hack
    local trainer = nerv.Trainer
    ----
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, gconf)
    local layer_repo = trainer.make_layer_repo(nil, param_repo)
    local network = trainer.get_decode_network(nil, layer_repo)
    local input_order = trainer.get_decode_input_order(nil)
    local input_name = gconf.decode_input_name or "main_scp"
    local readers = trainer.make_decode_readers(nil, feature)
    local buffer = nerv.SeqBuffer(gconf, {
                                        buffer_size = gconf.buffer_size,
                                        batch_size = gconf.batch_size,
                                        chunk_size = gconf.chunk_size,
                                        randomize = gconf.randomize,
                                        readers = readers,
                                })

    network = nerv.Network("nt", gconf, {network = network})
    network:init(gconf.batch_size, gconf.chunk_size)

    -- The first chunk is read eagerly; it begins the first utterance.
    local prev_data = buffer:get_data() or nerv.error("no data in buffer")
    local terminate = false
    -- Position of `input_name` within `input_order`; selects the reader whose
    -- current key is reported as the utterance id.
    local input_pos = nil
    for i, v in ipairs(input_order) do
        if v == input_name then
            input_pos = i
        end
    end
    if input_pos == nil then
        nerv.error("input name %s not found in the input order list", input_name)
    end
    local batch_propagator = function()
        if terminate then
            return "", nil
        end
        network:epoch_init()
        local accu_output = {}
        -- NOTE(review): assumes `readers` is ordered like `input_order` — confirm
        local utt_id = readers[input_pos].reader.key
        if utt_id == nil then
            nerv.error("no key found.")
        end
        while true do
            local d
            if prev_data then
                d = prev_data
                prev_data = nil
            else
                d = buffer:get_data()
                if d == nil then
                    terminate = true
                    break
                elseif #d.new_seq > 0 then
                    prev_data = d -- the first data of the next utterance
                    break
                end
            end

            -- Gather the chunk's inputs in the order the network expects.
            local input = {}
            for _, id in ipairs(input_order) do
                if d.data[id] == nil then
                    nerv.error("input data %s not found", id)
                end
                table.insert(input, d.data[id])
            end
            -- Exactly one output matrix per time step of the chunk. This must
            -- not scale with the number of inputs: the mask below is indexed
            -- by time step.
            local output = {{}}
            for t = 1, gconf.chunk_size do
                output[1][t] = gconf.mmat_type(gconf.batch_size, network.dim_out[1])
            end
            network:mini_batch_init({seq_length = d.seq_length,
                                    new_seq = d.new_seq,
                                    do_train = false,
                                    input = input,
                                    output = output,
                                    err_input = {},
                                    err_output = {}})
            network:propagate()
            -- Keep only frames whose mask is positive; `gconf.mask` is read
            -- after `propagate` (presumably populated by the network — confirm).
            for t, v in ipairs(output[1]) do
                if gconf.mask[t][0][0] > 0 then -- is not a hole
                    table.insert(accu_output, v)
                end
            end
        end
        if #accu_output == 0 then
            nerv.error("no output frames produced for utterance %s", tostring(utt_id))
        end
        local utt_matrix = gconf.mmat_type(#accu_output, accu_output[1]:ncol())
        for i, v in ipairs(accu_output) do
            utt_matrix:copy_from(v, 0, v:nrow(), i - 1)
        end
        collectgarbage("collect")
        nerv.info("propagated %d features of an utterance", utt_matrix:nrow())
        return utt_id, utt_matrix
    end

    return batch_propagator
end

--- Entry point: load the decoder configuration and build `propagator`.
-- @param config path of a Lua config file, executed with `dofile`; it is
--   expected to define the global `gconf` table (TODO confirm for all configs)
-- @param feature forwarded unchanged to `build_propagator`
-- Side effects: overrides `gconf.mmat_type`, `gconf.use_cpu` and
-- `gconf.batch_size`, and assigns the global `propagator` used by `feed`.
function init(config, feature)
    dofile(config)
    -- Decoding runs on the CPU with host float matrices, one sequence per batch.
    gconf.mmat_type = nerv.MMatrixFloat
    gconf.use_cpu = true
    gconf.batch_size = 1
    local param_file = gconf.decode_param
    propagator = build_propagator(param_file, feature)
end

--- Decode the next utterance using the global `propagator` built by `init`.
-- @return exactly two values: the utterance id and its output matrix, as
--   produced by `propagator` (extra return values are dropped).
function feed()
    local utt_id, feat_mat = propagator()
    return utt_id, feat_mat
end