aboutsummaryrefslogtreecommitdiff
path: root/lua/network.lua
diff options
context:
space:
mode:
authorQi Liu <[email protected]>2016-03-11 20:11:00 +0800
committerQi Liu <[email protected]>2016-03-11 20:11:00 +0800
commite2a9af061db485d4388902d738c9d8be3f94ab34 (patch)
tree468d6c6afa0801f6a6bf794b3674f8814b8827f7 /lua/network.lua
parent2f46a5e2b37a054f482f76f4ac3d26b144cf988f (diff)
add recipe and fix bugs
Diffstat (limited to 'lua/network.lua')
-rw-r--r--lua/network.lua106
1 files changed, 0 insertions, 106 deletions
diff --git a/lua/network.lua b/lua/network.lua
deleted file mode 100644
index d106ba1..0000000
--- a/lua/network.lua
+++ /dev/null
@@ -1,106 +0,0 @@
-nerv.include('select_linear.lua')
-
-local nn = nerv.class('nerv.NN')
-
-function nn:__init(global_conf, train_data, val_data, layers, connections)
- self.gconf = global_conf
- self.network = self:get_network(layers, connections)
- self.train_data = self:get_data(train_data)
- self.val_data = self:get_data(val_data)
-end
-
-function nn:get_network(layers, connections)
- local layer_repo = nerv.LayerRepo(layers, self.gconf.pr, self.gconf)
- local graph = nerv.GraphLayer('graph', self.gconf,
- {dim_in = {1, self.gconf.vocab_size}, dim_out = {1},
- layer_repo = layer_repo, connections = connections})
- local network = nerv.Network('network', self.gconf,
- {network = graph, clip = self.gconf.clip})
- network:init(self.gconf.batch_size, self.gconf.chunk_size)
- return network
-end
-
-function nn:get_data(data)
- local err_output = {}
- local softmax_output = {}
- local output = {}
- for i = 1, self.gconf.chunk_size do
- err_output[i] = self.gconf.cumat_type(self.gconf.batch_size, 1)
- softmax_output[i] = self.gconf.cumat_type(self.gconf.batch_size, self.gconf.vocab_size)
- output[i] = self.gconf.cumat_type(self.gconf.batch_size, 1)
- end
- local ret = {}
- for i = 1, #data do
- ret[i] = {}
- ret[i].input = {}
- ret[i].output = {}
- ret[i].err_input = {}
- ret[i].err_output = {}
- for t = 1, self.gconf.chunk_size do
- ret[i].input[t] = {}
- ret[i].output[t] = {}
- ret[i].err_input[t] = {}
- ret[i].err_output[t] = {}
- ret[i].input[t][1] = data[i].input[t]
- ret[i].input[t][2] = data[i].output[t]
- ret[i].output[t][1] = output[t]
- local err_input = self.gconf.mmat_type(self.gconf.batch_size, 1)
- for j = 1, self.gconf.batch_size do
- if t <= data[i].seq_len[j] then
- err_input[j - 1][0] = 1
- else
- err_input[j - 1][0] = 0
- end
- end
- ret[i].err_input[t][1] = self.gconf.cumat_type.new_from_host(err_input)
- ret[i].err_output[t][1] = err_output[t]
- ret[i].err_output[t][2] = softmax_output[t]
- end
- ret[i].seq_length = data[i].seq_len
- ret[i].new_seq = {}
- for j = 1, self.gconf.batch_size do
- if data[i].seq_start[j] then
- table.insert(ret[i].new_seq, j)
- end
- end
- end
- return ret
-end
-
-function nn:process(data, do_train)
- local timer = self.gconf.timer
- local total_err = 0
- local total_frame = 0
- for id = 1, #data do
- data[id].do_train = do_train
- timer:tic('network')
- self.network:mini_batch_init(data[id])
- self.network:propagate()
- timer:toc('network')
- for t = 1, self.gconf.chunk_size do
- local tmp = data[id].output[t][1]:new_to_host()
- for i = 1, self.gconf.batch_size do
- if t <= data[id].seq_length[i] then
- total_err = total_err + math.log10(math.exp(tmp[i - 1][0]))
- total_frame = total_frame + 1
- end
- end
- end
- if do_train then
- timer:tic('network')
- self.network:back_propagate()
- self.network:update()
- timer:toc('network')
- end
- timer:tic('gc')
- collectgarbage('collect')
- timer:toc('gc')
- end
- return math.pow(10, - total_err / total_frame)
-end
-
-function nn:epoch()
- local train_error = self:process(self.train_data, true)
- local val_error = self:process(self.val_data, false)
- return train_error, val_error
-end