diff options
-rw-r--r-- | nerv/examples/lmptb/lstmlm_ptb_main.lua | 18 | ||||
-rw-r--r-- | nerv/examples/lmptb/m-tests/sutil_test.lua | 2 |
2 files changed, 11 insertions, 9 deletions
diff --git a/nerv/examples/lmptb/lstmlm_ptb_main.lua b/nerv/examples/lmptb/lstmlm_ptb_main.lua index 333fa96..a49e5c2 100644 --- a/nerv/examples/lmptb/lstmlm_ptb_main.lua +++ b/nerv/examples/lmptb/lstmlm_ptb_main.lua @@ -146,22 +146,22 @@ function prepare_tnn(global_conf, layerRepo) {"selectL1[1]", "lstmL1[1]", 0}, {"lstmL1[2]", "lstmL1[3]", 1}, - {"lstmL1[1]", "dropoutL1[1]", 0}, - {"dropoutL1[1]", "combinerL1[1]", 0}, + {"lstmL1[1]", "combinerL1[1]", 0}, {"combinerL1[1]", "lstmL1[2]", 1}, - - {"combinerL"..global_conf.layer_num.."[2]", "outputL[1]", 0}, + {"combinerL1[2]", "dropoutL1[1]", 0}, + + {"dropoutL"..global_conf.layer_num.."[1]", "outputL[1]", 0}, {"outputL[1]", "softmaxL[1]", 0}, {"<input>[2]", "softmaxL[2]", 0}, {"softmaxL[1]", "<output>[1]", 0} } for l = 2, global_conf.layer_num do - table.insert(connections_t, {"combinerL"..(l-1).."[2]", "lstmL"..l.."[1]", 0}) + table.insert(connections_t, {"dropoutL"..(l-1).."[1]", "lstmL"..l.."[1]", 0}) table.insert(connections_t, {"lstmL"..l.."[2]", "lstmL"..l.."[3]", 1}) - table.insert(connections_t, {"lstmL"..l.."[1]", "dropoutL"..l.."[1]", 0}) - table.insert(connections_t, {"dropoutL"..l.."[1]", "combinerL"..l.."[1]", 0}) + table.insert(connections_t, {"lstmL"..l.."[1]", "combinerL"..l.."[1]", 0}) table.insert(connections_t, {"combinerL"..l.."[1]", "lstmL"..l.."[2]", 1}) + table.insert(connections_t, {"combinerL"..l.."[2]", "dropoutL"..l.."[1]", 0}) end --[[ @@ -207,14 +207,14 @@ global_conf = { mmat_type = nerv.MMatrixFloat, nn_act_default = 0, - hidden_size = 600, + hidden_size = 650, layer_num = 2, chunk_size = 15, batch_size = 20, max_iter = 45, lr_decay = 1.003, decay_iter = 10, - param_random = function() return (math.random() / 5 - 0.1) end, + param_random = function() return (math.random() / 50 - 0.01) end, dropout_str = "0.5", train_fn = train_fn, diff --git a/nerv/examples/lmptb/m-tests/sutil_test.lua b/nerv/examples/lmptb/m-tests/sutil_test.lua index 058de7e..95660d9 100644 --- a/nerv/examples/lmptb/m-tests/sutil_test.lua +++ b/nerv/examples/lmptb/m-tests/sutil_test.lua @@ -8,3 +8,5 @@ print("!!!") for p, v in pairs(coms) do print(p,v) end +nerv.sss = "sss" +print(nerv.sss) |