-- kaldi_decode/src/asr_propagator.lua

-- silence this script's standard output: reroute print() through io.write()
-- and redirect the default output stream to /dev/null
print = function(...) io.write(table.concat({...}, "\t")) end
io.output('/dev/null')
-- path and cpath are correctly set by `path.sh`
-- if the LuaRocks loader is available, register the "nerv" (scm-1) rock context
local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
require 'nerv'
nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
nerv.info("automatically initialize a default MContext...")
nerv.MMatrix._default_context = nerv.MContext()
nerv.info("the default MContext is ok")
-- only for backward compatibility, will be removed in the future
local function _add_profile_method(cls)
    local c = cls._default_context
    cls.print_profile = function () c:print_profile() end
    cls.clear_profile = function () c:clear_profile() end
end
_add_profile_method(nerv.MMatrix)

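-- build and return a closure that fetches the next utterance from the feature
-- readers, propagates it through the acoustic network and returns the pair
-- (utterance key, output matrix); make_layer_repo(), get_decode_network(),
-- get_global_transf(), get_decode_input_order() and make_decode_readers() are
-- expected to be defined by the configuration file loaded in init()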
function build_propagator(ifname, feature)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_decode_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_decode_input_order()
    local readers = make_decode_readers(feature, layer_repo)

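    -- wrap the decoding graph and the global transform into runnable networks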
    network = nerv.Network("nt", gconf, {network = network})
    global_transf = nerv.Network("gt", gconf, {network = global_transf})

    local batch_propagator = function()
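        -- poll each reader in turn until one yields data for the next utterance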
        local data = nil
        for ri = 1, #readers do
            data = readers[ri].reader:get_data()
            if data ~= nil then
                break
            end
        end

        if data == nil then
            return "", nil
        end

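        -- one mini-batch covers the whole utterance: take the row (frame) count
        -- of the first input stream as batch_size and (re-)initialize both
        -- networks for it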
        gconf.batch_size = data[input_order[1].id]:nrow()
        global_transf:init(gconf.batch_size, 1)
        global_transf:epoch_init()
        network:init(gconf.batch_size, 1)
        network:epoch_init()

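        -- assemble the network input in the configured order, applying the
        -- global transform to the streams that request it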
        local input = {}
        for i, e in ipairs(input_order) do
            local id = e.id
            if data[id] == nil then
                nerv.error("input data %s not found", id)
            end
            local transformed
            if e.global_transf then
                transformed = nerv.speech_utils.global_transf(data[id],
                                    global_transf,
                                    gconf.frm_ext or 0, 0,
                                    gconf)
            else
                transformed = data[id]
            end
            table.insert(input, {transformed})
        end
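        -- allocate the output buffer and run a forward pass only (no training,
        -- no error propagation)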
        local output = {{nerv.MMatrixFloat(input[1][1]:nrow(), network.dim_out[1])}}
        network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
                                new_seq = {},
                                do_train = false,
                                input = input,
                                output = output,
                                err_input = {},
                                err_output = {}})
        network:propagate()

        local utt = data["key"]
        if utt == nil then
            nerv.error("no key found.")
        end

        collectgarbage("collect")
        return utt, output[1][1]
    end

    return batch_propagator
end

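-- initialization entry point: load the decoding configuration (expected to
-- define gconf and the helper functions used above), force single-precision
-- CPU matrices for decoding and build the propagator closure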
function init(config, feature)
    dofile(config)
    gconf.mmat_type = nerv.MMatrixFloat
    gconf.use_cpu = true -- use CPU to decode
    trainer = build_propagator(gconf.decode_param, feature)
end

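-- return the next (utterance key, output matrix) pair; ("", nil) is returned
-- once all readers are exhausted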
function feed()
    local utt, mat = trainer()
    return utt, mat
end
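
-- Usage sketch (illustrative only; the configuration file and feature
-- specifier below are hypothetical placeholders): the host program is expected
-- to call init() once and then poll feed() until the returned key is empty.
--
--   init("decode.lua", "ark:feats.ark")
--   while true do
--       local utt, mat = feed()
--       if utt == "" then break end           -- all readers exhausted
--       -- hand `mat` to the decoder as the network output for utterance `utt`
--   end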