about summary refs log tree commit diff
path: root/nerv/main.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/main.lua')
-rw-r--r--nerv/main.lua12
1 file changed, 9 insertions, 3 deletions
diff --git a/nerv/main.lua b/nerv/main.lua
index 865aba0..7c82ebf 100644
--- a/nerv/main.lua
+++ b/nerv/main.lua
@@ -10,7 +10,8 @@ local global_conf = {
local layer_repo = nerv.LayerRepo(
{
['nerv.RNNLayer'] = {
- rnn = {dim_in = {23}, dim_out = {26}},
+ rnn1 = {dim_in = {23}, dim_out = {26}},
+ rnn2 = {dim_in = {26}, dim_out = {26}},
},
['nerv.AffineLayer'] = {
input = {dim_in = {62}, dim_out = {23}},
@@ -30,8 +31,9 @@ local layer_repo = nerv.LayerRepo(
local connections = {
{'<input>[1]', 'input[1]', 0},
{'input[1]', 'sigmoid[1]', 0},
- {'sigmoid[1]', 'rnn[1]', 0},
- {'rnn[1]', 'output[1]', 0},
+ {'sigmoid[1]', 'rnn1[1]', 0},
+ {'rnn1[1]', 'rnn2[1]', 0},
+ {'rnn2[1]', 'output[1]', 0},
{'output[1]', 'dup[1]', 0},
{'dup[1]', 'output[2]', -1},
{'dup[2]', 'softmax[1]', 0},
@@ -65,3 +67,7 @@ for i = 1, 100 do
network:back_propagate(err_input, err_output, input, output)
network:update(err_input, input, output)
end
+
+local tmp = network:get_params()
+
+tmp:export('../../workspace/test.param')