-- nerv/examples/seq_trainer.lua
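--
-- Sequence-training driver: load the initial parameters, build the network
-- from a user-supplied config script (dofile(arg[1]) below), then train per
-- utterance, updating the parameters after every forward/backward pass.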
function build_trainer(ifname)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, nil, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()
    local iterative_trainer = function (prefix, scp_file, bp)
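        -- prefix: filename prefix for exporting updated parameters (nil skips export)
        -- scp_file: utterance list handed to make_readers
        -- bp: unused in this body (backprop runs whenever propagate() succeeds)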
        local readers = make_readers(scp_file, layer_repo)
        -- initialize the network (batch size 1; it is resized per utterance below)
        network:init(1)
        gconf.cnt = 0
        for ri = 1, #readers, 1 do
            while true do
                local data = readers[ri].reader:get_data()
                if data == nil then
                    break
                end
                -- print stats periodically
                gconf.cnt = gconf.cnt + 1
                if gconf.cnt == 1000 then
                    print_stat(layer_repo)
                    nerv.CuMatrix.print_profile()
                    nerv.CuMatrix.clear_profile()
                    gconf.cnt = 0
                end
                local input = {}
                for i, e in ipairs(input_order) do
                    local id = e.id
                    if data[id] == nil then
                        nerv.error("input data %s not found", id)
                    end
                    local transformed
                    if e.global_transf then
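                        -- copy the host (CPU) matrix into a CUDA matrix before
                        -- applying the global transform on the GPU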
                        local batch = gconf.cumat_type(data[id]:nrow(), data[id]:ncol())
                        batch:copy_fromh(data[id])
                        transformed = nerv.speech_utils.global_transf(
                            batch, global_transf, gconf.frm_ext or 0, 0, gconf)
                    else
                        transformed = data[id]
                    end
                    table.insert(input, transformed)
                end
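                -- buffer receiving the error signal propagated back to the
                -- network input; shaped like the input and otherwise unused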
                local err_output = {input[1]:create()}
                network:batch_resize(input[1]:nrow())
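                -- propagate() returning true presumably means the utterance yielded
                -- a valid sequence-level criterion; only then backprop and update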
                if network:propagate(input, {{}}) == true then
                    network:back_propagate({{}}, err_output, input, {{}})
                    -- important: compensates the 1 / (1 - momentum) momentum gain
                    -- in the parameter update, so each per-utterance step uses the
                    -- raw lrate
                    gconf.batch_size = 1.0 - gconf.momentum
                    network:update({{}}, input, {{}})
                end
                -- collect garbage promptly to free GPU memory
                collectgarbage("collect")
            end
        end
        print_stat(layer_repo)
        nerv.CuMatrix.print_profile()
        nerv.CuMatrix.clear_profile()
        local accu = get_accuracy(layer_repo)
        if prefix ~= nil then
            nerv.info("writing back...")
            local fname = string.format("%s_tr%.3f.nerv", prefix, accu)
            network:get_params():export(fname, nil)
        end
        return accu
    end
    return iterative_trainer
end

dofile(arg[1])
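-- the config script loaded above must define the globals used in this file:
-- gconf, make_layer_repo, get_network, get_global_transf, get_input_order,
-- make_readers, print_stat and get_accuracy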

local pf0 = gconf.initialized_param
local trainer = build_trainer(pf0)

local i = 1
nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
-- name the checkpoint after the initial parameter file's basename (extension
-- stripped), plus a timestamp, the iteration number and the learning rate
local basename = string.gsub(string.gsub(pf0[1], "(.*/)(.*)", "%2"),
                             "(.*)%..*", "%1")
local prefix = string.format("%s_%s_iter_%d_lr%f", basename,
                             os.date("%Y%m%d%H%M%S"), i, gconf.lrate)
local accu_tr = trainer(prefix, gconf.tr_scp, true)
nerv.info("[TR] training set %d: %.3f", i, accu_tr)