aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/m-tests/lmseqreader_test.lua')
-rw-r--r--nerv/examples/lmptb/m-tests/lmseqreader_test.lua52
1 files changed, 52 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
new file mode 100644
index 0000000..cbcdcbe
--- /dev/null
+++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua
@@ -0,0 +1,52 @@
+require 'lmptb.lmseqreader'
+require 'lmptb.lmutil'
+
+local printf = nerv.printf
+
+local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text"
+--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
+local vocab = nerv.LMVocab()
+vocab:build_file(test_fn)
+local chunk_size = 5
+local batch_size = 3
+local global_conf = {
+ lrate = 1, wcost = 1e-6, momentum = 0,
+ cumat_type = nerv.CuMatrixFloat,
+ mmat_type = nerv.CuMatrixFloat,
+
+ hidden_size = 20,
+ chunk_size = chunk_size,
+ batch_size = batch_size,
+ max_iter = 18,
+ param_random = function() return (math.random() / 5 - 0.1) end,
+ independent = true,
+
+ train_fn = train_fn,
+ test_fn = test_fn,
+ sche_log_pre = "[SCHEDULER]:",
+ log_w_num = 10, --give a message when log_w_num words have been processed
+ timer = nerv.Timer(),
+
+ vocab = vocab
+}
+
+local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab)
+reader:open_file(test_fn)
+local feeds = {}
+feeds.flags_now = {}
+feeds.inputs_m = {}
+for j = 1, chunk_size do
+ feeds.inputs_m[j] = {global_conf.cumat_type(batch_size, 1), global_conf.cumat_type(batch_size, global_conf.vocab:size())}
+ feeds.flags_now[j] = {}
+end
+while (1) do
+ local r = reader:get_batch(feeds)
+ if (r == false) then break end
+ for j = 1, chunk_size, 1 do
+ for i = 1, batch_size, 1 do
+ printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
+ end
+ printf("\n")
+ end
+ printf("\n")
+end