aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lm_sampler.lua
blob: 9d31f174d967b1f8e68ccceaadc1357f9f63debf (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
-- Sentence sampler for a trained recurrent language model: repeatedly feeds
-- sampled words back into the network and collects complete sentences.
local LMSampler = nerv.class('nerv.LMSampler')

--- Construct a sampler bound to a global configuration.
-- Caches the batch/chunk sizes and the vocabulary's sentence-end word so the
-- sampling loop does not have to look them up repeatedly.
-- @param global_conf nerv global configuration table; must provide
--        batch_size, chunk_size and vocab
function LMSampler:__init(global_conf)
    local gconf = global_conf
    local vocab = gconf.vocab

    self.log_pre = "LMSampler"
    self.gconf = gconf
    self.batch_size = gconf.batch_size
    -- chunk_size is the upper bound on the length of a sampled sentence
    self.chunk_size = gconf.chunk_size
    self.vocab = vocab
    self.sen_end_token = vocab.sen_end_token
    self.sen_end_id = vocab:get_word_str(self.sen_end_token).id

    -- flips to true once load_dagL has allocated the network buffers
    self.loaded = false
end

--- Bind a DAG layer to the sampler and allocate all reusable buffers.
-- Must be called before lm_sample_rnn_dagL.
-- @param dagL a nerv DAG layer whose inputs are {previous word ids, hidden
--        state} and whose outputs are {per-word scores, next hidden state}
function LMSampler:load_dagL(dagL)
    -- Use the configuration captured in __init rather than the bare global
    -- `global_conf` the original code read; this removes the hidden
    -- dependency on the caller's global variable name.
    local gconf = self.gconf

    nerv.printf("%s loading dagL\n", self.log_pre)

    self.dagL = dagL
    self.dagL:init(self.batch_size)

    -- input 1: previous word id per stream (matrix side is 0-based, hence
    -- the -1); seeded with the sentence-end token so each stream starts a
    -- fresh sentence. input 2: recurrent hidden activation, zero-initialized.
    self.dagL_inputs = {}
    self.dagL_inputs[1] = gconf.cumat_type(self.batch_size, 1)
    self.dagL_inputs[1]:fill(self.sen_end_id - 1)
    self.dagL_inputs[2] = gconf.cumat_type(self.batch_size, gconf.hidden_size)
    self.dagL_inputs[2]:fill(0)

    -- output 1: per-word scores over the vocabulary; output 2: next hidden state
    self.dagL_outputs = {}
    self.dagL_outputs[1] = gconf.cumat_type(self.batch_size, self.vocab:size())
    self.dagL_outputs[2] = gconf.cumat_type(self.batch_size, gconf.hidden_size)

    -- device softmax buffer, device prefix-sum buffer, and the host copy of
    -- the prefix sums that sample_to_store reads
    self.smout_d = gconf.cumat_type(self.batch_size, self.vocab:size())
    self.ssout_d = gconf.cumat_type(self.batch_size, self.vocab:size())
    self.ssout_h = gconf.mmat_type(self.batch_size, self.vocab:size())

    -- per-stream sentence under construction; every stream begins with the
    -- sentence-end token carrying probability 0
    self.store = {}
    for i = 1, self.batch_size do
        self.store[i] = {
            { w = self.sen_end_token, id = self.sen_end_id, p = 0 },
        }
    end
    self.repo = {} -- finished sentences accumulate here

    self.loaded = true
end

--- Draw one next word for every stream in the batch and append it to that
-- stream's sentence in self.store (inverse-CDF sampling). --private
-- @param ssout host matrix (batch_size x vocab size) where each row holds
--        the running prefix sum of the softmax output, so values along a
--        row are nondecreasing and the last column should be ~1.0
function LMSampler:sample_to_store(ssout) --private
    for i = 1, self.batch_size do
        local ran = math.random() -- uniform draw for inverse-CDF sampling
        local id = 1
        local low = 0
        local high = ssout:ncol() - 1
        -- sanity check: a proper CDF row must end at ~1
        if ssout[i - 1][high] < 0.9999 or ssout[i - 1][high] > 1.0001 then
            nerv.error("%s ERROR, softmax output summation(%f) seems to have some problem", self.log_pre, ssout[i - 1][high])
        end
        -- binary search for the smallest column whose cumulative mass covers
        -- ran; if even column 0 does, id stays 1 (the first vocab word)
        if ssout[i - 1][low] < ran then
            while low + 1 < high do
                local mid = math.floor((low + high) / 2)
                if ssout[i - 1][mid] < ran then
                    low = mid
                else
                    high = mid
                end
            end
            id = high + 1 -- matrix columns are 0-based, word ids are 1-based
        end
        --[[
        local s = 0
        local id = self.vocab:size()
        for j = 0, self.vocab:size() - 1 do
            s = s + smout[i - 1][j]
            if s >= ran then 
                id = j + 1
                break
            end
        end
        ]]--
        -- force a sentence end once the stream nears chunk_size, so sampled
        -- sentences never exceed the configured chunk length
        if #self.store[i] >= self.chunk_size - 2 then
            id = self.sen_end_id
        end
        local tmp = {}
        tmp.w = self.vocab:get_word_id(id).str
        tmp.id = id
        -- recover the word's own probability from adjacent prefix sums
        if id == 1 then
            tmp.p = ssout[i - 1][id - 1]
        else
            tmp.p = ssout[i - 1][id - 1] - ssout[i - 1][id - 2] 
        end
        table.insert(self.store[i], tmp)
    end
end

--- Sample at least sample_num complete sentences from the RNN, one network
-- step per word across the whole batch, feeding each sampled word and the
-- hidden state back in as the next input.
-- @param sample_num number of sentences to return
-- @param p_conf extra options (not used in this method)
-- @return array of sample_num sentences, each an array of
--         {w = word string, id = word id, p = word probability} entries
--         starting with the sentence-end token
function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
    assert(self.loaded == true) -- load_dagL must have been called first

    local dagL = self.dagL
    local inputs = self.dagL_inputs
    local outputs = self.dagL_outputs
    
    -- keep stepping the network until enough finished sentences accumulate
    while #self.repo < sample_num do
        dagL:propagate(inputs, outputs)
        inputs[2]:copy_fromd(outputs[2]) --copy hidden activation

        -- softmax on device, prefix-sum the rows, then copy to host so
        -- sample_to_store can binary-search the CDF rows
        self.smout_d:softmax(outputs[1])
        self.ssout_d:prefixsum_row(self.smout_d)
        self.ssout_d:copy_toh(self.ssout_h)
        
        self:sample_to_store(self.ssout_h)
        for i = 1, self.batch_size do
            -- feed the freshly sampled word back as next input (0-based id)
            inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1
            if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end
                -- keep only sentences with at least one real word between
                -- the leading and trailing end tokens
                if #self.store[i] >= 3 then
                    self.repo[#self.repo + 1] = self.store[i]
                end
                inputs[2][i - 1]:fill(0) -- reset hidden state for the new sentence
                self.store[i] = {}
                self.store[i][1] = {}
                self.store[i][1].w = self.sen_end_token
                self.store[i][1].id = self.sen_end_id
                self.store[i][1].p = 0
            end
        end

        collectgarbage("collect")                                              
    end

    -- hand back sample_num sentences, popping from the end of the repo
    local res = {}
    for i = 1, sample_num do
        res[i] = self.repo[#self.repo]
        self.repo[#self.repo] = nil
    end
    return res
end