diff options
Diffstat (limited to 'nerv/examples/lmptb/m-tests/lmseqreader_test.lua')
-rw-r--r-- | nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua new file mode 100644 index 0000000..cbcdcbe --- /dev/null +++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua @@ -0,0 +1,52 @@ +require 'lmptb.lmseqreader' +require 'lmptb.lmutil' + +local printf = nerv.printf + +local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text" +--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt" +local vocab = nerv.LMVocab() +vocab:build_file(test_fn) +local chunk_size = 5 +local batch_size = 3 +local global_conf = { + lrate = 1, wcost = 1e-6, momentum = 0, + cumat_type = nerv.CuMatrixFloat, + mmat_type = nerv.CuMatrixFloat, + + hidden_size = 20, + chunk_size = chunk_size, + batch_size = batch_size, + max_iter = 18, + param_random = function() return (math.random() / 5 - 0.1) end, + independent = true, + + train_fn = train_fn, + test_fn = test_fn, + sche_log_pre = "[SCHEDULER]:", + log_w_num = 10, --give a message when log_w_num words have been processed + timer = nerv.Timer(), + + vocab = vocab +} + +local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab) +reader:open_file(test_fn) +local feeds = {} +feeds.flags_now = {} +feeds.inputs_m = {} +for j = 1, chunk_size do + feeds.inputs_m[j] = {global_conf.cumat_type(batch_size, 1), global_conf.cumat_type(batch_size, global_conf.vocab:size())} + feeds.flags_now[j] = {} +end +while (1) do + local r = reader:get_batch(feeds) + if (r == false) then break end + for j = 1, chunk_size, 1 do + for i = 1, batch_size, 1 do + printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id + end + printf("\n") + end + printf("\n") +end |