summaryrefslogtreecommitdiff
path: root/lua/network.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua/network.lua')
-rw-r--r--lua/network.lua30
1 files changed, 17 insertions, 13 deletions
diff --git a/lua/network.lua b/lua/network.lua
index 6280f24..0c11321 100644
--- a/lua/network.lua
+++ b/lua/network.lua
@@ -57,12 +57,11 @@ function nn:get_data(data)
ret[i].err_output[t][1] = err_output[t]
ret[i].err_output[t][2] = softmax_output[t]
end
- ret[i].info = {}
- ret[i].info.seq_length = data[i].seq_len
- ret[i].info.new_seq = {}
+ ret[i].seq_length = data[i].seq_len
+ ret[i].new_seq = {}
for j = 1, self.gconf.batch_size do
if data[i].seq_start[j] then
- table.insert(ret[i].info.new_seq, j)
+ table.insert(ret[i].new_seq, j)
end
end
end
@@ -70,34 +69,39 @@ function nn:get_data(data)
end
function nn:process(data, do_train)
+ local timer = self.gconf.timer
local total_err = 0
local total_frame = 0
for id = 1, #data do
if do_train then
self.gconf.dropout_rate = self.gconf.dropout
+ data[id].do_train = true
else
self.gconf.dropout_rate = 0
+ data[id].do_train = false
end
- self.network:mini_batch_init(data[id].info)
- local input = {}
- for t = 1, self.gconf.chunk_size do
- input[t] = {data[id].input[t][1], data[id].input[t][2]:decompress(self.gconf.vocab_size)}
- end
- self.network:propagate(input, data[id].output)
+ timer:tic('network')
+ self.network:mini_batch_init(data[id])
+ self.network:propagate()
+ timer:toc('network')
for t = 1, self.gconf.chunk_size do
local tmp = data[id].output[t][1]:new_to_host()
for i = 1, self.gconf.batch_size do
- if t <= data[id].info.seq_length[i] then
+ if t <= data[id].seq_length[i] then
total_err = total_err + math.log10(math.exp(tmp[i - 1][0]))
total_frame = total_frame + 1
end
end
end
if do_train then
- self.network:back_propagate(data[id].err_input, data[id].err_output, input, data[id].output)
- self.network:update(data[id].err_input, input, data[id].output)
+ timer:tic('network')
+ self.network:back_propagate()
+ self.network:update()
+ timer:toc('network')
end
+ timer:tic('gc')
collectgarbage('collect')
+ timer:toc('gc')
end
return math.pow(10, - total_err / total_frame)
end