aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/lmptb/lmfeeder.lua
blob: 34631bfa0e2a1e6d969f9e093c2c13be4034f8fd (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
require 'lmptb.lmvocab'

local Feeder = nerv.class("nerv.LMFeeder")

local printf = nerv.printf

--global_conf: table
--batch_size: int
--vocab: nerv.LMVocab
function Feeder:__init(global_conf, batch_size, vocab)
    self.gconf = global_conf
    self.fh = nil --file handle to read, nil means currently no file
    self.batch_size = batch_size
    self.log_pre = "[LOG]LMFeeder:"
    self.vocab = vocab
    self.streams = nil
end

--fn: string
--Initialize all streams
function Feeder:open_file(fn)
    if (self.fh ~= nil) then
        nerv.error("%s error: in open_file, file handle not nil.")
    end
    printf("%s opening file %s...\n", self.log_pre, fn)
    self.fh = io.open(fn, "r")
    self.streams = {}
    for i = 1, self.batch_size, 1 do
        self.streams[i] = {["store"] = {self.vocab.sen_end_token}, ["head"] = 1, ["tail"] = 1}
    end
end
   
--id: int
--Refresh stream id,  read a line from file
function Feeder:refresh_stream(id)
    if (self.streams[id] == nil) then
        nerv.error("stream %d does not exit.", id)
    end
    local st = self.streams[id]
    if (st.store[st.head] ~= nil) then return end
    if (self.fh == nil) then return end
    local list = self.vocab:read_line(self.fh)
    if (list == nil) then --file has end
        printf("%s file expires, closing.\n", self.log_pre)
        self.fh:close() 
        self.fh = nil 
        return 
    end
    for i = 1, #list, 1 do
        st.tail = st.tail + 1
        st.store[st.tail] = list[i]
    end
end

--Returns: nil/table
--If gets something, return a list of string, vocab.null_token indicates end of string
function Feeder:get_batch()
    local got_new = false
    local list = {}
    for i = 1, self.batch_size, 1 do
        self:refresh_stream(i)
        local st = self.streams[i]
        list[i] = st.store[st.head]
        if (list[i] == nil) then list[i] = self.vocab.null_token end
        if (list[i] ~= nil and list[i] ~= self.vocab.null_token)then
            got_new = true
            st.store[st.head] = nil
            st.head = st.head + 1
        end 
    end
    if (got_new == false) then
        return nil
    else
        return list
    end
end

--[[
do
    local test_fn = "/home/slhome/txh18/workspace/nerv-project/some-text"
    --local test_fn = "/home/slhome/txh18/workspace/nerv-project/nerv/examples/lmptb/PTBdata/ptb.train.txt"
    local vocab = nerv.LMVocab()
    vocab:build_file(test_fn)
    local batch_size = 3
    local feeder = nerv.LMFeeder({}, batch_size, vocab)
    feeder:open_file(test_fn)
    while (1) do
        local list = feeder:get_batch()
        if (list == nil) then break end
        for i = 1, batch_size, 1 do
            printf("%s(%d) ", list[i], vocab:get_word_str(list[i]).id) 
        end
        printf("\n")
    end
end
]]--