aboutsummaryrefslogtreecommitdiff
path: root/nerv/main.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/main.lua')
-rw-r--r--nerv/main.lua18
1 files changed, 14 insertions, 4 deletions
diff --git a/nerv/main.lua b/nerv/main.lua
index 0633e87..5cb7d07 100644
--- a/nerv/main.lua
+++ b/nerv/main.lua
@@ -17,6 +17,12 @@ local layer_repo = nerv.LayerRepo(
['nerv.SigmoidLayer'] = {
sigmoid = {dim_in = {23}, dim_out = {23}},
},
+ ['nerv.SoftmaxLayer'] = {
+ softmax = {dim_in = {79}, dim_out = {79}},
+ },
+ ['nerv.DuplicateLayer'] = {
+ dup = {dim_in = {79}, dim_out = {79, 79}},
+ },
}, nerv.ParamRepo(), global_conf)
local connections = {
@@ -24,10 +30,14 @@ local connections = {
{'input[1]', 'sigmoid[1]', 0},
{'sigmoid[1]', 'rnn[1]', 0},
{'rnn[1]', 'output[1]', 0},
- {'output[1]', 'output[2]', 1},
- {'output[1]', '<output>[1]', 0},
+ {'output[1]', 'dup[1]', 0},
+ {'dup[1]', 'output[2]', -1},
+ {'dup[2]', 'softmax[1]', 0},
+ {'softmax[1]', '<output>[1]', 0},
}
-local graph = nerv.GraphLayer('network', global_conf, {dim_in = {20}, dim_out = {79}, layer_repo = layer_repo, connections = connections})
+local graph = nerv.GraphLayer('graph', global_conf, {dim_in = {20}, dim_out = {79}, layer_repo = layer_repo, connections = connections})
+
+local network = nerv.Network('network', global_conf, {network = graph})
-local network = nerv.Network(graph)
+network:init(2,5)