-- Smoke test for nerv.LMSeqReader: builds a vocabulary from a text file,
-- reads the same file back in chunk_size x batch_size batches (se_mode on)
-- and prints every input token next to its label.
require 'lmptb.lmseqreader'
require 'lmptb.lmutil'

local printf = nerv.printf

local test_fn = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text"
--local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"

local vocab = nerv.LMVocab()
vocab:build_file(test_fn)

local chunk_size = 20
local batch_size = 3

local global_conf = {
    lrate = 1,
    wcost = 1e-6,
    momentum = 0,
    cumat_type = nerv.CuMatrixFloat,
    mmat_type = nerv.MMatrixFloat,
    hidden_size = 20,
    chunk_size = chunk_size,
    batch_size = batch_size,
    max_iter = 18,
    param_random = function() return (math.random() / 5 - 0.1) end,
    independent = true,

    -- This script never defines a training file, so train_fn is left unset
    -- (nil) explicitly rather than read from an undefined global.
    train_fn = nil,
    test_fn = test_fn,
    sche_log_pre = "[SCHEDULER]:",
    log_w_num = 10, -- give a message when log_w_num words have been processed
    timer = nerv.Timer(),
    vocab = vocab,
}

local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab, {["se_mode"] = true})
reader:open_file(test_fn)

-- Buffers handed to reader:get_batch. inputs_s / labels_s are not created
-- here; they are read in the loop below, so get_batch is expected to fill them.
local feeds = {
    flags_now = {},
    inputs_m = {},
    flagsPack_now = {},
}
for j = 1, chunk_size do
    -- Per time step: a (batch_size x 1) matrix and a
    -- (batch_size x vocabulary size) matrix.
    feeds.inputs_m[j] = {
        global_conf.cumat_type(batch_size, 1),
        global_conf.cumat_type(batch_size, global_conf.vocab:size()),
    }
    feeds.flags_now[j] = {}
end

-- Keep pulling batches until get_batch returns false.
while true do
    local r = reader:get_batch(feeds)
    if r == false then
        break
    end
    -- One output line per time step j; each entry is "<input>[L(<label>)]"
    -- for batch stream i.
    for j = 1, chunk_size do
        for i = 1, batch_size do
            printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i])
            --vocab:get_word_str(input[i][j]).id
        end
        printf("\n")
    end
    printf("\n")
end