1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
|
--LMSampler: draws complete sample sentences from a trained RNN language model,
--one batch of parallel streams at a time.
local LMSampler = nerv.class('nerv.LMSampler')
--- Construct a sampler bound to a global configuration.
-- @param global_conf table supplying batch_size, chunk_size and vocab;
--   chunk_size is the longest sentence length we will ever sample.
function LMSampler:__init(global_conf)
    local gconf = global_conf
    self.log_pre = "LMSampler"
    self.gconf = gconf
    self.batch_size = gconf.batch_size
    self.chunk_size = gconf.chunk_size --largest sample sentence length
    self.vocab = gconf.vocab
    -- cache the sentence-end token and its vocabulary id; they delimit samples
    self.sen_end_token = self.vocab.sen_end_token
    self.sen_end_id = self.vocab:get_word_str(self.sen_end_token).id
    -- no DAG layer attached yet; load_dagL must run before sampling
    self.loaded = false
end
--- Attach and initialize the DAG layer, then allocate all work buffers.
-- Sets up input/output device matrices, softmax/prefix-sum scratch space,
-- and seeds every batch stream's sentence store with a sentence-end token.
-- @param dagL the network DAG layer used to propagate one step per word
function LMSampler:load_dagL(dagL)
    nerv.printf("%s loading dagL\n", self.log_pre)
    local gconf = self.gconf
    local cumat = gconf.cumat_type

    self.dagL = dagL
    dagL:init(self.batch_size)

    -- input 1: previous word id (0-based) per stream; input 2: hidden state
    local inputs = {
        cumat(gconf.batch_size, 1),
        cumat(gconf.batch_size, gconf.hidden_size),
    }
    inputs[1]:fill(self.sen_end_id - 1)
    inputs[2]:fill(0)
    self.dagL_inputs = inputs

    -- output 1: per-word scores; output 2: next hidden state
    self.dagL_outputs = {
        cumat(gconf.batch_size, gconf.vocab:size()),
        cumat(gconf.batch_size, gconf.hidden_size),
    }

    -- softmax (device), its row prefix sums (device), and a host mirror
    self.smout_d = cumat(self.batch_size, self.vocab:size())
    self.ssout_d = cumat(self.batch_size, self.vocab:size())
    self.ssout_h = gconf.mmat_type(self.batch_size, self.vocab:size())

    -- each stream starts with a sentinel sentence-end entry
    self.store = {}
    for i = 1, self.batch_size do
        self.store[i] = { { w = self.sen_end_token, id = self.sen_end_id, p = 0 } }
    end
    self.repo = {} -- finished sentences accumulate here
    self.loaded = true
end
--Draw one word for every batch stream via inverse-CDF sampling and append it
--to that stream's sentence under construction. (private)
--ssout is the host matrix of row-wise prefix sums of the softmax output, so
--ssout[r][j] (rows/columns 0-based) is the total mass of word ids 1..j+1.
function LMSampler:sample_to_store(ssout) --private
for i = 1, self.batch_size do
local ran = math.random()
local id = 1
local low = 0
local high = ssout:ncol() - 1
--sanity check: the final prefix-sum entry is the full probability mass,
--which must be ~1.0 if the softmax/prefix-sum pipeline is healthy
if ssout[i - 1][high] < 0.9999 or ssout[i - 1][high] > 1.0001 then
nerv.error("%s ERROR, softmax output summation(%f) seems to have some problem", self.log_pre, ssout[i - 1][high])
end
--if ran falls below the first entry, word id 1 is chosen (the default);
--otherwise binary-search for the smallest column whose prefix sum >= ran
if ssout[i - 1][low] < ran then
while low + 1 < high do
local mid = math.floor((low + high) / 2)
if ssout[i - 1][mid] < ran then
low = mid
else
high = mid
end
end
--invariant on exit: ssout[low] < ran <= ssout[high]; column high
--(0-based) corresponds to 1-based word id high + 1
id = high + 1
end
--[[
local s = 0
local id = self.vocab:size()
for j = 0, self.vocab:size() - 1 do
s = s + smout[i - 1][j]
if s >= ran then
id = j + 1
break
end
end
]]--
--force a sentence end once the stream is about to exceed chunk_size
if #self.store[i] >= self.chunk_size - 2 then
id = self.sen_end_id
end
local tmp = {}
tmp.w = self.vocab:get_word_id(id).str
tmp.id = id
--recover the word's own probability from adjacent prefix sums;
--id == 1 has no predecessor column, so its prefix sum is its probability
if id == 1 then
tmp.p = ssout[i - 1][id - 1]
else
tmp.p = ssout[i - 1][id - 1] - ssout[i - 1][id - 2]
end
table.insert(self.store[i], tmp)
end
end
--Sample sentences from the RNN until at least sample_num are collected,
--then return exactly sample_num of them (taken from the end of the repo).
--Each returned sentence is a list of {w = word string, id = word id,
--p = word probability} entries beginning with the sentence-end sentinel.
--p_conf is accepted but not used by this implementation.
function LMSampler:lm_sample_rnn_dagL(sample_num, p_conf)
assert(self.loaded == true)
local dagL = self.dagL
local inputs = self.dagL_inputs
local outputs = self.dagL_outputs
while #self.repo < sample_num do
--one propagate step produces scores for the next word of every stream
dagL:propagate(inputs, outputs)
inputs[2]:copy_fromd(outputs[2]) --copy hidden activation
--softmax on device, then row prefix sums, then copy to host for sampling
self.smout_d:softmax(outputs[1])
self.ssout_d:prefixsum_row(self.smout_d)
self.ssout_d:copy_toh(self.ssout_h)
self:sample_to_store(self.ssout_h)
for i = 1, self.batch_size do
--feed the sampled word (0-based id) back in as the next input
inputs[1][i - 1][0] = self.store[i][#self.store[i]].id - 1
if self.store[i][#self.store[i]].id == self.sen_end_id then --meet a sentence end
--keep only sentences with at least one real word between the
--leading sentinel and the closing sentence-end token
if #self.store[i] >= 3 then
self.repo[#self.repo + 1] = self.store[i]
end
--reset this stream: zero its hidden state and restart its store
--with the sentence-end sentinel entry
inputs[2][i - 1]:fill(0)
self.store[i] = {}
self.store[i][1] = {}
self.store[i][1].w = self.sen_end_token
self.store[i][1].id = self.sen_end_id
self.store[i][1].p = 0
end
end
collectgarbage("collect")
end
--pop sample_num sentences off the end of the repo; leftovers stay for
--the next call
local res = {}
for i = 1, sample_num do
res[i] = self.repo[#self.repo]
self.repo[#self.repo] = nil
end
return res
end
|