aboutsummaryrefslogtreecommitdiff
path: root/lua
diff options
context:
space:
mode:
authorQi Liu <[email protected]>2016-03-11 13:32:00 +0800
committerQi Liu <[email protected]>2016-03-11 13:32:00 +0800
commitf26288ba61d3d16866e1b227a71e7d9c46923436 (patch)
treeea41bb08994d9d2ee59c3ac5f3ec2c41bcaac6d2 /lua
parent05fcde5bf0caa1ceb70fef02fc88eda6f00c5ed5 (diff)
update mini_batch_init
Diffstat (limited to 'lua')
-rw-r--r--lua/config.lua4
-rw-r--r--lua/main.lua4
-rw-r--r--lua/network.lua30
-rw-r--r--lua/reader.lua3
4 files changed, 24 insertions, 17 deletions
diff --git a/lua/config.lua b/lua/config.lua
index 9d73b64..1ec1198 100644
--- a/lua/config.lua
+++ b/lua/config.lua
@@ -12,7 +12,7 @@ function get_global_conf()
layer_num = 1,
chunk_size = 15,
batch_size = 20,
- max_iter = 1,
+ max_iter = 3,
param_random = function() return (math.random() / 5 - 0.1) end,
dropout = 0.5,
timer = nerv.Timer(),
@@ -34,7 +34,7 @@ function get_layers(global_conf)
output = {dim_in = {global_conf.hidden_size}, dim_out = {global_conf.vocab_size}, pr = pr}
},
['nerv.SoftmaxCELayer'] = {
- softmax = {dim_in = {global_conf.vocab_size, global_conf.vocab_size}, dim_out = {1}},
+ softmax = {dim_in = {global_conf.vocab_size, global_conf.vocab_size}, dim_out = {1}, compressed = true},
},
}
for i = 1, global_conf.layer_num do
diff --git a/lua/main.lua b/lua/main.lua
index 684efac..ce0270a 100644
--- a/lua/main.lua
+++ b/lua/main.lua
@@ -9,7 +9,7 @@ local timer = global_conf.timer
timer:tic('IO')
local data_path = 'nerv/nerv/examples/lmptb/PTBdata/'
-local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.train.txt.adds')
+local train_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds')
local val_reader = nerv.Reader(data_path .. 'vocab', data_path .. 'ptb.valid.txt.adds')
local train_data = train_reader:get_all_batch(global_conf)
@@ -41,3 +41,5 @@ for i = 1, global_conf.max_iter do
end
timer:toc('global')
timer:check('global')
+timer:check('network')
+timer:check('gc')
diff --git a/lua/network.lua b/lua/network.lua
index 6280f24..0c11321 100644
--- a/lua/network.lua
+++ b/lua/network.lua
@@ -57,12 +57,11 @@ function nn:get_data(data)
ret[i].err_output[t][1] = err_output[t]
ret[i].err_output[t][2] = softmax_output[t]
end
- ret[i].info = {}
- ret[i].info.seq_length = data[i].seq_len
- ret[i].info.new_seq = {}
+ ret[i].seq_length = data[i].seq_len
+ ret[i].new_seq = {}
for j = 1, self.gconf.batch_size do
if data[i].seq_start[j] then
- table.insert(ret[i].info.new_seq, j)
+ table.insert(ret[i].new_seq, j)
end
end
end
@@ -70,34 +69,39 @@ function nn:get_data(data)
end
function nn:process(data, do_train)
+ local timer = self.gconf.timer
local total_err = 0
local total_frame = 0
for id = 1, #data do
if do_train then
self.gconf.dropout_rate = self.gconf.dropout
+ data[id].do_train = true
else
self.gconf.dropout_rate = 0
+ data[id].do_train = false
end
- self.network:mini_batch_init(data[id].info)
- local input = {}
- for t = 1, self.gconf.chunk_size do
- input[t] = {data[id].input[t][1], data[id].input[t][2]:decompress(self.gconf.vocab_size)}
- end
- self.network:propagate(input, data[id].output)
+ timer:tic('network')
+ self.network:mini_batch_init(data[id])
+ self.network:propagate()
+ timer:toc('network')
for t = 1, self.gconf.chunk_size do
local tmp = data[id].output[t][1]:new_to_host()
for i = 1, self.gconf.batch_size do
- if t <= data[id].info.seq_length[i] then
+ if t <= data[id].seq_length[i] then
total_err = total_err + math.log10(math.exp(tmp[i - 1][0]))
total_frame = total_frame + 1
end
end
end
if do_train then
- self.network:back_propagate(data[id].err_input, data[id].err_output, input, data[id].output)
- self.network:update(data[id].err_input, input, data[id].output)
+ timer:tic('network')
+ self.network:back_propagate()
+ self.network:update()
+ timer:toc('network')
end
+ timer:tic('gc')
collectgarbage('collect')
+ timer:toc('gc')
end
return math.pow(10, - total_err / total_frame)
end
diff --git a/lua/reader.lua b/lua/reader.lua
index 2e51a9c..0c7bcb6 100644
--- a/lua/reader.lua
+++ b/lua/reader.lua
@@ -58,7 +58,8 @@ function Reader:get_all_batch(global_conf)
for i = 1, global_conf.batch_size do
pos[i] = nil
end
- while true do
+ --while true do
+ for i = 1, 100 do
local input = {}
local output = {}
for i = 1, global_conf.chunk_size do