about summary refs log tree commit diff
path: root/nerv/examples/network_debug/network.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/network_debug/network.lua')
-rw-r--r-- nerv/examples/network_debug/network.lua | 120
1 file changed, 57 insertions(+), 63 deletions(-)
diff --git a/nerv/examples/network_debug/network.lua b/nerv/examples/network_debug/network.lua
index 5518e27..386c3b0 100644
--- a/nerv/examples/network_debug/network.lua
+++ b/nerv/examples/network_debug/network.lua
@@ -2,11 +2,17 @@ nerv.include('select_linear.lua')
local nn = nerv.class('nerv.NN')
-function nn:__init(global_conf, train_data, val_data, layers, connections)
+function nn:__init(global_conf, layers, connections)
self.gconf = global_conf
self.network = self:get_network(layers, connections)
- self.train_data = self:get_data(train_data)
- self.val_data = self:get_data(val_data)
+
+ self.output = {}
+ self.err_output = {}
+ for i = 1, self.gconf.chunk_size do
+ self.output[i] = {self.gconf.cumat_type(self.gconf.batch_size, 1)}
+ self.err_output[i] = {self.gconf.cumat_type(self.gconf.batch_size, 1)}
+ self.err_output[i][2] = self.gconf.cumat_type(self.gconf.batch_size, 1)
+ end
end
function nn:get_network(layers, connections)
@@ -20,79 +26,67 @@ function nn:get_network(layers, connections)
return network
end
-function nn:get_data(data)
- local err_output = {}
- local softmax_output = {}
- local output = {}
- for i = 1, self.gconf.chunk_size do
- err_output[i] = self.gconf.cumat_type(self.gconf.batch_size, 1)
- softmax_output[i] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.vocab_size)
- output[i] = self.gconf.cumat_type(self.gconf.batch_size, 1)
- end
- local ret = {}
- for i = 1, #data do
- ret[i] = {}
- ret[i].input = {}
- ret[i].output = {}
- ret[i].err_input = {}
- ret[i].err_output = {}
- for t = 1, self.gconf.chunk_size do
- ret[i].input[t] = {}
- ret[i].output[t] = {}
- ret[i].err_input[t] = {}
- ret[i].err_output[t] = {}
- ret[i].input[t][1] = data[i].input[t]
- ret[i].input[t][2] = data[i].output[t]
- ret[i].output[t][1] = output[t]
- local err_input = self.gconf.mmat_type(self.gconf.batch_size, 1)
- for j = 1, self.gconf.batch_size do
- if t <= data[i].seq_len[j] then
- err_input[j - 1][0] = 1
- else
- err_input[j - 1][0] = 0
+function nn:process(data, do_train, reader)
+ local timer = self.gconf.timer
+ local buffer = nerv.SeqBuffer(self.gconf, {
+ batch_size = self.gconf.batch_size, chunk_size = self.gconf.chunk_size,
+ readers = {reader},
+ })
+ local total_err = 0
+ local total_frame = 0
+ self.network:epoch_init()
+ while true do
+ timer:tic('IO')
+ data = buffer:get_data()
+ if data == nil then
+ break
+ end
+ local err_input = {}
+ if do_train then
+ for t = 1, self.gconf.chunk_size do
+ local tmp = self.gconf.mmat_type(self.gconf.batch_size, 1)
+ for i = 1, self.gconf.batch_size do
+ if t <= data.seq_length[i] then
+ tmp[i - 1][0] = 1
+ else
+ tmp[i - 1][0] = 0
+ end
end
+ err_input[t] = {self.gconf.cumat_type.new_from_host(tmp)}
end
- ret[i].err_input[t][1] = self.gconf.cumat_type.new_from_host(err_input)
- ret[i].err_output[t][1] = err_output[t]
- ret[i].err_output[t][2] = softmax_output[t]
end
- ret[i].seq_length = data[i].seq_len
- ret[i].new_seq = {}
- for j = 1, self.gconf.batch_size do
- if data[i].seq_start[j] then
- table.insert(ret[i].new_seq, j)
- end
+ local info = {input = {}, output = self.output, err_input = err_input, do_train = do_train,
+ err_output = self.err_output, seq_length = data.seq_length, new_seq = data.new_seq}
+ for t = 1, self.gconf.chunk_size do
+ info.input[t] = {data.data['input'][t]}
+ info.input[t][2] = data.data['label'][t]
end
- end
- return ret
-end
+ timer:toc('IO')
-function nn:process(data, do_train)
- local timer = self.gconf.timer
- local total_err = 0
- local total_frame = 0
- self.network:epoch_init()
- for id = 1, #data do
- data[id].do_train = do_train
timer:tic('network')
- self.network:mini_batch_init(data[id])
+ self.network:mini_batch_init(info)
self.network:propagate()
timer:toc('network')
+
+ timer:tic('IO')
for t = 1, self.gconf.chunk_size do
- local tmp = data[id].output[t][1]:new_to_host()
+ local tmp = info.output[t][1]:new_to_host()
for i = 1, self.gconf.batch_size do
- if t <= data[id].seq_length[i] then
- total_err = total_err + math.log10(math.exp(tmp[i - 1][0]))
- total_frame = total_frame + 1
- end
+ total_err = total_err + math.log10(math.exp(tmp[i - 1][0]))
end
end
+ for i = 1, self.gconf.batch_size do
+ total_frame = total_frame + info.seq_length[i]
+ end
+ timer:toc('IO')
+
+ timer:tic('network')
if do_train then
- timer:tic('network')
self.network:back_propagate()
self.network:update()
- timer:toc('network')
end
+ timer:toc('network')
+
timer:tic('gc')
collectgarbage('collect')
timer:toc('gc')
@@ -100,11 +94,11 @@ function nn:process(data, do_train)
return math.pow(10, - total_err / total_frame)
end
-function nn:epoch()
- local train_error = self:process(self.train_data, true)
+function nn:epoch(train_reader, val_reader)
+ local train_error = self:process(self.train_data, true, train_reader)
local tmp = self.gconf.dropout_rate
self.gconf.dropout_rate = 0
- local val_error = self:process(self.val_data, false)
+ local val_error = self:process(self.val_data, false, val_reader)
self.gconf.dropout_rate = tmp
return train_error, val_error
end