diff options
Diffstat (limited to 'nerv/main.lua')
-rw-r--r-- | nerv/main.lua | 73 |
1 files changed, 0 insertions, 73 deletions
diff --git a/nerv/main.lua b/nerv/main.lua deleted file mode 100644 index 7c82ebf..0000000 --- a/nerv/main.lua +++ /dev/null @@ -1,73 +0,0 @@ -local global_conf = { - cumat_type = nerv.CuMatrixFloat, - param_random = function() return 0 end, - lrate = 0.1, - wcost = 0, - momentum = 0.9, - batch_size = 2, -} - -local layer_repo = nerv.LayerRepo( - { - ['nerv.RNNLayer'] = { - rnn1 = {dim_in = {23}, dim_out = {26}}, - rnn2 = {dim_in = {26}, dim_out = {26}}, - }, - ['nerv.AffineLayer'] = { - input = {dim_in = {62}, dim_out = {23}}, - output = {dim_in = {26, 79}, dim_out = {79}}, - }, - ['nerv.SigmoidLayer'] = { - sigmoid = {dim_in = {23}, dim_out = {23}}, - }, - ['nerv.IdentityLayer'] = { - softmax = {dim_in = {79}, dim_out = {79}}, - }, - ['nerv.DuplicateLayer'] = { - dup = {dim_in = {79}, dim_out = {79, 79}}, - }, - }, nerv.ParamRepo(), global_conf) - -local connections = { - {'<input>[1]', 'input[1]', 0}, - {'input[1]', 'sigmoid[1]', 0}, - {'sigmoid[1]', 'rnn1[1]', 0}, - {'rnn1[1]', 'rnn2[1]', 0}, - {'rnn2[1]', 'output[1]', 0}, - {'output[1]', 'dup[1]', 0}, - {'dup[1]', 'output[2]', -1}, - {'dup[2]', 'softmax[1]', 0}, - {'softmax[1]', '<output>[1]', 0}, -} - -local graph = nerv.GraphLayer('graph', global_conf, {dim_in = {62}, dim_out = {79}, layer_repo = layer_repo, connections = connections}) - -local network = nerv.Network('network', global_conf, {network = graph}) - -local batch = global_conf.batch_size -local chunk = 5 -network:init(batch, chunk) - -local input = {} -local output = {} -local err_input = {} -local err_output = {} -local input_size = 62 -local output_size = 79 -for i = 1, chunk do - input[i] = {global_conf.cumat_type(batch, input_size)} - output[i] = {global_conf.cumat_type(batch, output_size)} - err_input[i] = {global_conf.cumat_type(batch, output_size)} - err_output[i] = {global_conf.cumat_type(batch, input_size)} -end - -for i = 1, 100 do - network:mini_batch_init({seq_length = {5, 3}, new_seq = {2}}) - network:propagate(input, output) - network:back_propagate(err_input, err_output, input, output) - network:update(err_input, input, output) -end - -local tmp = network:get_params() - -tmp:export('../../workspace/test.param') |