diff options
Diffstat (limited to 'nerv/main.lua')
-rw-r--r-- | nerv/main.lua | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/nerv/main.lua b/nerv/main.lua index 865aba0..7c82ebf 100644 --- a/nerv/main.lua +++ b/nerv/main.lua @@ -10,7 +10,8 @@ local global_conf = { local layer_repo = nerv.LayerRepo( { ['nerv.RNNLayer'] = { - rnn = {dim_in = {23}, dim_out = {26}}, + rnn1 = {dim_in = {23}, dim_out = {26}}, + rnn2 = {dim_in = {26}, dim_out = {26}}, }, ['nerv.AffineLayer'] = { input = {dim_in = {62}, dim_out = {23}}, @@ -30,8 +31,9 @@ local layer_repo = nerv.LayerRepo( local connections = { {'<input>[1]', 'input[1]', 0}, {'input[1]', 'sigmoid[1]', 0}, - {'sigmoid[1]', 'rnn[1]', 0}, - {'rnn[1]', 'output[1]', 0}, + {'sigmoid[1]', 'rnn1[1]', 0}, + {'rnn1[1]', 'rnn2[1]', 0}, + {'rnn2[1]', 'output[1]', 0}, {'output[1]', 'dup[1]', 0}, {'dup[1]', 'output[2]', -1}, {'dup[2]', 'softmax[1]', 0}, @@ -65,3 +67,7 @@ for i = 1, 100 do network:back_propagate(err_input, err_output, input, output) network:update(err_input, input, output) end + +local tmp = network:get_params() + +tmp:export('../../workspace/test.param') |