aboutsummaryrefslogblamecommitdiff
path: root/nerv/test/matrix_func.lua
blob: 817d46373d391beccf585a6b5d6106507816b3f4 (plain) (tree)

































































































































                                                     
                                      






























                                    



                            





                                                        
-- Build an m x n matrix of type `mat_type` whose entry (i, j) is i + j.
-- Indices are 0-based: rows 0..m-1, columns 0..n-1.
--
-- @param mat_type matrix constructor, called as mat_type(m, n)
-- @param m        number of rows
-- @param n        number of columns
-- @return the filled matrix
function _pattern_fill(mat_type, m, n)
    local a = mat_type(m, n)
    for i = 0, m - 1 do
        -- must be local: a bare `row = ...` would create/overwrite a global
        local row = a[i]
        for j = 0, n - 1 do
            row[j] = i + j
        end
    end
    return a
end

-- Smoke-test the matrix operations of `mat_type` for one set of shapes.
-- Each operation is run on small patterned inputs and the operands/results
-- are printed. Nothing is asserted, so the output has to be inspected by
-- eye (or diffed against a previous run).
--
-- @param mat_type matrix constructor, called as mat_type(nrow, ncol)
-- @param m, n     shape (m x n) of the element-wise test matrices
-- @param k        inner dimension of the mul test (m x k times k x n);
--                 also passed to expand_frm, where c has n * (2k + 1) columns
-- @param fill     fill(mat_type, nrow, ncol) -> matrix, used to build inputs
function _test_all_shape(mat_type, m, n, k, fill)
    local a = fill(mat_type, m, n)
    local b = fill(mat_type, m, n)
    local c = fill(mat_type, m, n)
    -- force a to the i + j pattern regardless of which `fill` was passed
    for i = 0, m - 1 do
        for j = 0, n - 1 do
            a[i][j] = i + j
        end
    end
    -- test sigmoid
    b:sigmoid(a)
    print(a)
    print(b)
    -- test add
    c:add(a, b, 1.0, 2.0)
    print(c)
    c:add(a, b, 1.0, -2.0)
    print(c)
    c:add(a, b, 1.0, 0.0)
    print(c)
    c:add(a, b, 0.0, 1.0)
    print(c)
    -- test mul
    a = fill(mat_type, m, k)
    b = fill(mat_type, k, n)
    print(a)
    print(b)
    c:mul(a, b, 1.0, 0.0, 'N', 'N')
    print(c)
    c:mul(a, b, 1.0, 0.5, 'N', 'N')
    print(c)
    -- c is n x m here; the 'T', 'T' flags presumably request b^T * a^T
    c = mat_type(n, m)
    c:mul(b, a, 1.0, 0.0, 'T', 'T')
    print(c)
    -- test colsum
    print(c:colsum())
    -- test colsame
    print(c:colsame(c))
    local d = c:create()
    d:copy_from(c)
    -- perturb one entry per column so d differs from c in every column
    for j = 0, m - 1 do
        d[math.min(j, c:nrow() - 1)][j] = -1
    end
    print(c:colsame(d))
    -- test rowsum
    print(c:rowsum())
    -- test rowmax
    print(c:rowmax())
    d:copy_from(c)
    for i = 0, n - 1 do
        d[i][0] = 9999
    end
    print(d:rowmax())
    -- test fill
    d:fill(0)
    for i = 0, n - 1 do
        d[i][math.min(i, d:ncol() - 1)] = 1.0
    end
    print(d)
    -- test rowmax_idx
    -- local: these must not leak into the global table
    local x, y = d:rowmax_idx()
    print(x)
    print(y)
    -- test trans: expected to reproduce the 'T', 'T' product printed above
    print(c)
    c:mul(b:trans(), a:trans(), 1.0, 0.0, 'N', 'N')
    print(c)
    -- test decompress
    c = mat_type(n, 1)
    for i = 0, n - 1 do
        c[i][0] = i
    end
    local e = c:decompress(n)
    print(e)
    -- test copy_from
    d = mat_type(n, n)
    d:copy_from(e)
    print(d)
    -- test add_row
    a = mat_type(1, n)
    for i = 0, n - 1 do
        a[0][i] = i
    end
    d:add_row(a, 0.5)
    print(d)
    -- test clip
    e:copy_from(d)
    e:clip(1, 2)
    print(e)
    -- test sigmoid_grad
    a = fill(mat_type, m, n)
    b = fill(mat_type, m, n)
    c = a:create()
    c:sigmoid_grad(a, b)
    print(c)
    -- test softmax
    for i = 0, m - 1 do
        a[i][math.min(i, a:ncol() - 1)] = a[i][0] * n
    end
    print(a)
    e = c:softmax(a)
    print(c)
    print(e)
    -- test mul_elem
    a:mul_elem(c, c)
    print(a)
    -- test log_elem
    a:log_elem(a)
    print(a)
    -- test copy_rows_from_by_idx
    local idx = mat_type(1, n)
    a = fill(mat_type, n, m)
    b = mat_type(n, m)
    -- idx holds the row numbers of a in reverse order: n - 1, ..., 0
    for i = 0, n - 1 do
        idx[0][i] = n - 1 - i
    end
    print(a)
    b:copy_rows_from_by_idx(a, idx)
    -- show the full reversed copy before b is replaced below
    print(b)
    b = mat_type(2, m)
    b:copy_rows_from_by_idx(a, idx, 2)
    print(a)
    print(b)
    -- test expand_frm
    a = mat_type(m, n)
    for i = 0, m - 1 do
        for j = 0, n - 1 do
            a[i][j] = i
        end
    end
    c = mat_type(m, n * (2 * k + 1))
    c:expand_frm(a, k)
    print(a)
    print(c)
    -- test rearrange_frm
    a = c:create()
    a:rearrange_frm(c, n)
    print(a)
    -- test scale_rows_by_row
    a = mat_type(n, m)
    a:fill(2)
    b = fill(mat_type, 1, m)
    print(a)
    a:scale_rows_by_row(b)
    print(a)
    -- test scale_rows_by_col
    a = fill(mat_type, m, n)
    b = fill(mat_type, m, 1)
    print(a)
    a:scale_rows_by_col(b)
    print(a)
    -- test tanh
    a = fill(mat_type, 3, 4)
    c = a:create()
    c:tanh(a)
    print(c)
end
-- Run the full matrix smoke test for `mat_type` at several shapes.
-- Each entry is {m, n, k}, forwarded to _test_all_shape together with the
-- i + j pattern filler.
function test_all(mat_type)
    local shapes = {
        { 3, 4, 2 },
        { 30, 40, 20 },
        { 10, 10, 10 },
    }
    for _, shape in ipairs(shapes) do
        _test_all_shape(mat_type, shape[1], shape[2], shape[3], _pattern_fill)
    end
end