Diffstat (limited to 'nerv/main.lua')
-rw-r--r--  nerv/main.lua  36  ++++++++++++++++++++++++++++++------
1 file changed, 30 insertions(+), 6 deletions(-)
diff --git a/nerv/main.lua b/nerv/main.lua
index 5cb7d07..865aba0 100644
--- a/nerv/main.lua
+++ b/nerv/main.lua
@@ -1,8 +1,10 @@
-print 'Hello'
-
 local global_conf = {
     cumat_type = nerv.CuMatrixFloat,
     param_random = function() return 0 end,
+    lrate = 0.1,
+    wcost = 0,
+    momentum = 0.9,
+    batch_size = 2,
 }
 
 local layer_repo = nerv.LayerRepo(
@@ -11,13 +13,13 @@ local layer_repo = nerv.LayerRepo(
             rnn = {dim_in = {23}, dim_out = {26}},
         },
         ['nerv.AffineLayer'] = {
-            input = {dim_in = {20}, dim_out = {23}},
+            input = {dim_in = {62}, dim_out = {23}},
             output = {dim_in = {26, 79}, dim_out = {79}},
         },
         ['nerv.SigmoidLayer'] = {
             sigmoid = {dim_in = {23}, dim_out = {23}},
         },
-        ['nerv.SoftmaxLayer'] = {
+        ['nerv.IdentityLayer'] = {
             softmax = {dim_in = {79}, dim_out = {79}},
         },
         ['nerv.DuplicateLayer'] = {
@@ -36,8 +38,30 @@ local connections = {
     {'softmax[1]', '<output>[1]', 0},
 }
 
-local graph = nerv.GraphLayer('graph', global_conf, {dim_in = {20}, dim_out = {79}, layer_repo = layer_repo, connections = connections})
+local graph = nerv.GraphLayer('graph', global_conf, {dim_in = {62}, dim_out = {79}, layer_repo = layer_repo, connections = connections})
 local network = nerv.Network('network', global_conf, {network = graph})
 
-network:init(2,5)
+local batch = global_conf.batch_size
+local chunk = 5
+network:init(batch, chunk)
+
+local input = {}
+local output = {}
+local err_input = {}
+local err_output = {}
+local input_size = 62
+local output_size = 79
+for i = 1, chunk do
+    input[i] = {global_conf.cumat_type(batch, input_size)}
+    output[i] = {global_conf.cumat_type(batch, output_size)}
+    err_input[i] = {global_conf.cumat_type(batch, output_size)}
+    err_output[i] = {global_conf.cumat_type(batch, input_size)}
+end
+
+for i = 1, 100 do
+    network:mini_batch_init({seq_length = {5, 3}, new_seq = {2}})
+    network:propagate(input, output)
+    network:back_propagate(err_input, err_output, input, output)
+    network:update(err_input, input, output)
+end
 
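A note on the layout of the buffers the new code allocates, for readers unfamiliar with nerv's Network API. The sketch below is plain Lua and is not part of the patch; make_matrix is a made-up stand-in for global_conf.cumat_type (nerv.CuMatrixFloat), used only to make the indexing convention visible:

-- Stand-in for nerv.CuMatrixFloat(rows, cols): a rows-by-cols table of zeros.
-- (Hypothetical helper for illustration; nerv allocates this on the GPU.)
local function make_matrix(rows, cols)
    local m = {}
    for r = 1, rows do
        m[r] = {}
        for c = 1, cols do
            m[r][c] = 0
        end
    end
    return m
end

local batch, chunk, input_size = 2, 5, 62
local input = {}
for t = 1, chunk do
    -- One entry per time step t of the chunk. The inner table indexes the
    -- graph's input ports (only one here); each matrix has one row per
    -- parallel stream and one column per feature dimension.
    input[t] = {make_matrix(batch, input_size)}
end

As I read the mini_batch_init call, seq_length = {5, 3} declares that stream 1 carries valid frames for all 5 steps of the chunk while stream 2 carries only 3, and new_seq = {2} marks stream 2 as beginning a fresh sequence in this chunk, so its recurrent state is reset rather than carried over from the previous chunk.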