author    txh18 <cloudygooseg@gmail.com>    2015-11-03 18:36:43 +0800
committer txh18 <cloudygooseg@gmail.com>    2015-11-03 18:36:43 +0800
commit    d18122af2f57b8dd81db49385484f0e51d167a23 (patch)
tree      935bed09505675f9ad8af61d29e226222f8c70e8 /nerv/examples/lmptb/m-tests
parent    4d349c4b8639e074aafa7d4245231bf1f3decae6 (diff)
still working on TNN
Diffstat (limited to 'nerv/examples/lmptb/m-tests')
-rw-r--r--  nerv/examples/lmptb/m-tests/dagl_test.lua        | 18
-rw-r--r--  nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 36
2 files changed, 36 insertions(+), 18 deletions(-)
diff --git a/nerv/examples/lmptb/m-tests/dagl_test.lua b/nerv/examples/lmptb/m-tests/dagl_test.lua
index 5e90551..a50107d 100644
--- a/nerv/examples/lmptb/m-tests/dagl_test.lua
+++ b/nerv/examples/lmptb/m-tests/dagl_test.lua
@@ -98,14 +98,11 @@ end
--global_conf: table
--layerRepo: nerv.LayerRepo
---Returns: a nerv.TDAGLayer
+--Returns: a nerv.TNN
function prepare_dagLayer(global_conf, layerRepo)
- printf("%s Initing daglayer ...\n", global_conf.sche_log_pre)
+ printf("%s Initing TNN ...\n", global_conf.sche_log_pre)
--input: input_w, input_w, ... input_w_now, last_activation
- local dim_in_t = {}
- dim_in_t[1] = 1 --input to select_linear layer
- dim_in_t[2] = global_conf.vocab:size() --input to softmax label
local connections_t = {
{"<input>[1]", "selectL1[1]", 0},
{"selectL1[1]", "recurrentL1[1]", 0},
@@ -124,11 +121,11 @@ function prepare_dagLayer(global_conf, layerRepo)
end
]]--
- local dagL = nerv.TDAGLayer("dagL", global_conf, {["dim_in"] = dim_in_t, ["dim_out"] = {1}, ["sub_layers"] = layerRepo,
+ local tnn = nerv.TNN("TNN", global_conf, {["dim_in"] = {1, global_conf.vocab:size()}, ["dim_out"] = {1}, ["sub_layers"] = layerRepo,
["connections"] = connections_t,
})
- printf("%s Initing DAGLayer end.\n", global_conf.sche_log_pre)
- return dagL
+ printf("%s Initing TNN end.\n", global_conf.sche_log_pre)
+ return tnn
end
train_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
@@ -160,7 +157,6 @@ global_conf["vocab"] = vocab
global_conf.vocab:build_file(global_conf.train_fn, false)
local paramRepo = prepare_parameters(global_conf, true)
local layerRepo = prepare_layers(global_conf, paramRepo)
-local dagL = prepare_dagLayer(global_conf, layerRepo)
-dagL:init(global_conf.batch_size, global_conf.chunk_size)
-
+local tnn = prepare_dagLayer(global_conf, layerRepo)
+tnn:init(global_conf.batch_size, global_conf.chunk_size)
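
(For orientation, not part of the commit: each entry in connections_t reads as {source_port, destination_port, time_shift}. Assuming the third field is a time shift, a recurrent edge would carry an activation from step t-1 into step t; the sketch below illustrates that convention. The sigmoidL1 name and the delayed edge are hypothetical, not taken from this diff.)

    -- Sketch only: a delayed connection under the assumed
    -- {source, destination, time_shift} convention.
    local connections_t = {
        {"<input>[1]",   "selectL1[1]",    0}, -- word id into the embedding layer
        {"selectL1[1]",  "recurrentL1[1]", 0}, -- same time step, no delay
        {"sigmoidL1[1]", "recurrentL1[2]", 1}, -- hypothetical: step t-1's hidden
                                               -- activation fed back into step t
    }
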
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
index 504698f..cbcdcbe 100644
--- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -1,4 +1,5 @@
require 'lmptb.lmseqreader'
+require 'lmptb.lmutil'
local printf = nerv.printf
@@ -8,15 +9,36 @@ local vocab = nerv.LMVocab()
vocab:build_file(test_fn)
local chunk_size = 5
local batch_size = 3
-local reader = nerv.LMSeqReader({}, batch_size, chunk_size, vocab)
+local global_conf = {
+ lrate = 1, wcost = 1e-6, momentum = 0,
+ cumat_type = nerv.CuMatrixFloat,
+ mmat_type = nerv.CuMatrixFloat,
+
+ hidden_size = 20,
+ chunk_size = chunk_size,
+ batch_size = batch_size,
+ max_iter = 18,
+ param_random = function() return (math.random() / 5 - 0.1) end,
+ independent = true,
+
+ train_fn = train_fn,
+ test_fn = test_fn,
+ sche_log_pre = "[SCHEDULER]:",
+ log_w_num = 10, --give a message when log_w_num words have been processed
+ timer = nerv.Timer(),
+
+ vocab = vocab
+}
+
+local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab)
reader:open_file(test_fn)
-local input = {}
-local label = {}
-for i = 1, batch_size, 1 do
- input[i] = {}
- label[i] = {}
-end
local feeds = {}
+feeds.flags_now = {}
+feeds.inputs_m = {}
+for j = 1, chunk_size do
+ feeds.inputs_m[j] = {global_conf.cumat_type(batch_size, 1), global_conf.cumat_type(batch_size, global_conf.vocab:size())}
+ feeds.flags_now[j] = {}
+end
while (1) do
local r = reader:get_batch(feeds)
if (r == false) then break end
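
(A minimal sketch, not part of the commit, of consuming a fetched batch inside the loop body: it only touches the feeds fields this test itself allocates, and assumes get_batch fills feeds.flags_now[j][i] with a per-stream flag for time step j of stream i.)

    -- Sketch only: dump the reader's per-stream flags for each time step
    -- of the chunk to verify where sequences start and end.
    for j = 1, chunk_size do
        for i = 1, batch_size do
            printf("%s ", tostring(feeds.flags_now[j][i]))
        end
        printf("\n")
    end
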