-- nerv/io/frm_buffer.lua
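-- Frame-level data buffer: accumulates rows produced by one or more readers
-- into one large matrix per data stream and serves them back in fixed-size
-- mini-batches, optionally shuffled.
--
-- Fields read from buffer_conf by __init:
--   batch_size  : number of rows per mini-batch
--   buffer_size : requested capacity in rows (rounded down to a multiple of
--                 batch_size)
--   randomize   : if true, rows are served in a randomly permuted order
--   consume     : if true, one final, smaller batch may be emitted when the
--                 data runs out
--   use_gpu     : if true, the buffer itself lives in device memory (only
--                 supported together with GPU training)
--   readers     : list of {reader = <reader object>, data = {id = width, ...}}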
local FrmBuffer = nerv.class("nerv.FrmBuffer", "nerv.DataBuffer")

function FrmBuffer:__init(global_conf, buffer_conf)
    self.gconf = global_conf
    self.batch_size = buffer_conf.batch_size
    self.buffer_size = math.floor(buffer_conf.buffer_size /
                                    self.batch_size) * self.batch_size
    self.randomize = buffer_conf.randomize
    self.consume = buffer_conf.consume
    local cumat_type = global_conf.cumat_type
    if self.gconf.use_cpu then
        self.output_mat_type = self.gconf.mmat_type
    else
        self.output_mat_type = self.gconf.cumat_type
    end
    if buffer_conf.use_gpu then
        self.mat_type = cumat_type
        if self.gconf.use_cpu then
            -- gpu buffer -> cpu training
            nerv.error("not implemeted")
        else
            -- gpu buffer -> gpu training
            self.copy_rows_from_by_idx = cumat_type.copy_rows_fromd_by_idx
            self.copy_from = cumat_type.copy_fromd
        end
        self.perm_gen = function (x)
            return cumat_type.new_from_host(nerv.MMatrixFloat.perm_gen(x))
        end
    else
        self.mat_type = global_conf.mmat_type
        if self.gconf.use_cpu then
            -- cpu buffer -> cpu training
            self.copy_rows_from_by_idx = self.gconf.mmat_type.copy_rows_fromh_by_idx
            self.copy_from = self.gconf.mmat_type.copy_fromh
        else
            -- cpu buffer -> gpu training
            self.copy_rows_from_by_idx = cumat_type.copy_rows_fromh_by_idx
            self.copy_from = cumat_type.copy_fromh
        end
        self.perm_gen = nerv.MMatrixFloat.perm_gen
    end
    self.copy_from_reader = self.mat_type.copy_fromh
    self.head = 0
    self.tail = 0
    self.readers = {}
    for i, reader_spec in ipairs(buffer_conf.readers) do
        local buffs = {}
        for id, width in pairs(reader_spec.data) do
            buffs[id] = {data = self.mat_type(self.buffer_size, width),
                        leftover = nil,
                        width = width}
        end
        table.insert(self.readers, {buffs = buffs,
                                    reader = reader_spec.reader,
                                    tail = 0,
                                    has_leftover = false})
    end
end

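-- Refill the buffer from every reader: first flush any leftover rows kept
-- from the previous fill, then keep calling reader:get_data() until the
-- buffer is full or the reader is exhausted. Rows that do not fit are
-- stashed as leftovers for the next call. Returns true iff at least one
-- full batch worth of rows is available.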
function FrmBuffer:saturate()
    local buffer_size = self.buffer_size
    self.head = 0
    self.tail = buffer_size
    for i, reader in ipairs(self.readers) do
        reader.tail = 0
        if reader.has_leftover then
            local lrow
            for id, buff in pairs(reader.buffs) do
                lrow = buff.leftover:nrow()
                if lrow > buffer_size then
                    nerv.error("buffer size is too small to contain leftovers")
                end
                buff.data:copy_from(buff.leftover, 0, lrow)
                buff.leftover = nil
            end
            nerv.info("buffer leftover: %d\n", lrow)
            reader.tail = lrow
            reader.has_leftover = false
        end
        while reader.tail < buffer_size do
            local data = reader.reader:get_data()
            if data == nil then
                break
            end
            local drow = nil
            for id, d in pairs(data) do
                if drow == nil then
                    drow = d:nrow()
                elseif d:nrow() ~= drow then
                    nerv.error("reader provides with inconsistent rows of data")
                end
            end
            local remain = buffer_size - reader.tail
            if drow > remain then
                for id, buff in pairs(reader.buffs) do
                    local d = data[id]
                    if d == nil then
                        nerv.error("reader does not provide data for %s", id)
                    end
                    buff.leftover = self.mat_type(drow - remain,
                                                  buff.width)
                    self.copy_from_reader(buff.leftover, d, remain, drow)
                end
                drow = remain
                reader.has_leftover = true
            end
            for id, buff in pairs(reader.buffs) do
                self.copy_from_reader(buff.data, data[id], 0, drow, reader.tail)
            end
            reader.tail = reader.tail + drow
        end
        self.tail = math.min(self.tail, reader.tail)
    end
    self.rand_map = self.perm_gen(self.tail) -- generate shuffled index
    collectgarbage("collect")
    return self.tail >= self.batch_size
end

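-- Return the next mini-batch as {seq_length = ..., new_seq = {}, data = {}},
-- refilling the buffer via saturate() when it runs dry. Returns nil when no
-- full batch can be built, unless `consume` is set, in which case one final,
-- smaller batch is emitted before nil.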
function FrmBuffer:get_data()
    local batch_size = self.batch_size
    if self.head >= self.tail then -- buffer is empty
        local t = os.clock()
        if (not self:saturate()) and (not self.consume) then
            return nil -- the remaining data cannot build a batch
        end
        if self.tail == self.head then
            return nil -- nothing left
        end
        nerv.info("%.3fs to fill the buffer", os.clock() - t)
    end
    if self.head + batch_size > self.tail and (not self.consume) then
        return nil -- the remaining data cannot build a batch
    end
    local actual_batch_size = math.min(batch_size, self.tail - self.head)
    local res = {seq_length = table.vector(actual_batch_size, 1),
                 new_seq = {},
                 data = {}}
    for i, reader in ipairs(self.readers) do
        for id, buff in pairs(reader.buffs) do
            local batch = self.output_mat_type(actual_batch_size, buff.width)
            if self.randomize then
                self.copy_rows_from_by_idx(batch, buff.data, self.rand_map, self.head)
            else
                self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
            end
            res.data[id] = {batch}
        end
    end
    self.head = self.head + actual_batch_size
    return res
end
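
-- A minimal usage sketch, kept as a comment so the module itself stays
-- side-effect free. `gconf`, `some_reader` and the feature width 429 are
-- illustrative assumptions, not names defined by this file; any object with
-- a get_data() method returning {id = matrix, ...} should work as a reader.
--
--     local buffer = nerv.FrmBuffer(gconf,
--                                   {batch_size = gconf.batch_size,
--                                    buffer_size = 81920,
--                                    randomize = true,
--                                    consume = false,
--                                    use_gpu = false,
--                                    readers = {{reader = some_reader,
--                                                data = {main_scp = 429}}}})
--     while true do
--         local d = buffer:get_data()
--         if d == nil then
--             break
--         end
--         -- d.data["main_scp"][1] is an (actual_batch_size x 429) matrix
--     end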