diff options
Diffstat (limited to 'nerv/main.lua')
-rw-r--r-- | nerv/main.lua | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/nerv/main.lua b/nerv/main.lua index 0633e87..5cb7d07 100644 --- a/nerv/main.lua +++ b/nerv/main.lua @@ -17,6 +17,12 @@ local layer_repo = nerv.LayerRepo( ['nerv.SigmoidLayer'] = { sigmoid = {dim_in = {23}, dim_out = {23}}, }, + ['nerv.SoftmaxLayer'] = { + softmax = {dim_in = {79}, dim_out = {79}}, + }, + ['nerv.DuplicateLayer'] = { + dup = {dim_in = {79}, dim_out = {79, 79}}, + }, }, nerv.ParamRepo(), global_conf) local connections = { @@ -24,10 +30,14 @@ local connections = { {'input[1]', 'sigmoid[1]', 0}, {'sigmoid[1]', 'rnn[1]', 0}, {'rnn[1]', 'output[1]', 0}, - {'output[1]', 'output[2]', 1}, - {'output[1]', '<output>[1]', 0}, + {'output[1]', 'dup[1]', 0}, + {'dup[1]', 'output[2]', -1}, + {'dup[2]', 'softmax[1]', 0}, + {'softmax[1]', '<output>[1]', 0}, } -local graph = nerv.GraphLayer('network', global_conf, {dim_in = {20}, dim_out = {79}, layer_repo = layer_repo, connections = connections}) +local graph = nerv.GraphLayer('graph', global_conf, {dim_in = {20}, dim_out = {79}, layer_repo = layer_repo, connections = connections}) + +local network = nerv.Network('network', global_conf, {network = graph}) -local network = nerv.Network(graph) +network:init(2,5) |