aboutsummaryrefslogblamecommitdiff
path: root/nerv/matrix/generic/cumatrix.c
blob: cb559019721a91225f1eb0c9719919434dc65a82 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
                            
                                               

                                                                                   
                                            
                                                

                             

                                            
                                              
 

                                                   


                                                                


                                                           




                                                              











                                                           


















                                                                
                                              

                                                        
                                                           
                                                                        



                                                              

                                                                     


             

                                                     
                                                           
                                                                        



                                                              

                                                                   


             
                                                       
                  
                                                           





                                                                     
                                     

             
 
                                                

                                                                   
                                                           
                                                                        
                                                                            

                                                                  

                                                                      


             





                                                                   

                                                                        



                                     











                                                                           
                           
                                                                         

                                                                        





                                                                   














                                                                                


                                     
      
 
































                                                                      
                                                       






                                                 
                       

                                                 

                                                



                                             
                                     


                                                     

                                               
                                                     


                                             
                                                 
                                                                         
                                                                         
                                                                        
                                                                               



                                                               
                           

                                                                                     
      




                                                        


                     

 
      
#ifdef NERV_GENERIC_CUMATRIX
#include "../../lib/matrix/generic/elem_type.h"
/* Hooks for the generic matrix code: element writes and reads are routed
 * to the cuda_matrix_(write)/cuda_matrix_(read) helpers defined later in
 * this file, and MATRIX_INIT to cuda_matrix_(init). */
#define MATRIX_DATA_WRITE(L, data, idx, val) cuda_matrix_(write)(L, data, idx, val)
#define MATRIX_DATA_READ(L, data, idx) cuda_matrix_(read)(L, data, idx)
#define MATRIX_INIT(L) cuda_matrix_(init)(L)
/* Type name used as the base for this matrix flavour. */
#define MATRIX_BASE_TNAME nerv_matrix_cuda_tname
/* NOTE(review): the two switches below presumably enable the generic
 * sections of the headers included next -- confirm against those headers. */
#define NERV_GENERIC_MATRIX
#define NERV_GENERIC_CUKERNEL
#include "../../lib/common.h"
#include "../../lib/matrix/generic/matrix.h"
#include "../../lib/matrix/generic/cumatrix.h"

/* Name of the "no transpose" operation for this backend.
 * NOTE(review): presumably consumed by the shared matrix.c -- confirm. */
#define BLAS_OP_N CUBLAS_OP_N
/* Map a one-character transpose flag to a cuBLAS operation code:
 * 'T' or 't' selects CUBLAS_OP_T, every other character CUBLAS_OP_N. */
static int nerv_matrix_(lua_get_blas_op)(char ch) {
    switch (ch) {
    case 'T':
    case 't':
        return CUBLAS_OP_T;
    default:
        return CUBLAS_OP_N;
    }
}

/* Lua entry point for "thres_mask".
 * Stack: 1 = matrix, 2 = matrix, 3 = threshold, 4 = low value, 5 = high value.
 * Arguments are validated in stack order and forwarded to
 * nerv_matrix_(thres_mask); the resulting status is handed to
 * NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_thres_mask)(lua_State *L) {
    Status st;
    Matrix *self;
    Matrix *other;
    MATRIX_ELEM threshold;
    MATRIX_ELEM lo;
    MATRIX_ELEM hi;

    self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    other = luaT_checkudata(L, 2, nerv_matrix_(tname));
    threshold = luaL_checknumber(L, 3);
    lo = luaL_checknumber(L, 4);
    hi = luaL_checknumber(L, 5);

    nerv_matrix_(thres_mask)(self, other, threshold, lo, hi, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0;
}

/* Lua entry point for "rand_uniform".
 * Stack: 1 = matrix. Forwards the matrix to nerv_matrix_(rand_uniform) and
 * hands the resulting status to NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_rand_uniform)(lua_State *L) {
    Status st;
    Matrix *self;

    self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    nerv_matrix_(rand_uniform)(self, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0;
}

/* Lua entry point for "tanh".
 * Stack: 1 = matrix, 2 = matrix. Both are validated in stack order and
 * forwarded to nerv_matrix_(tanh); the resulting status is handed to
 * NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_tanh)(lua_State *L) {
    Status st;
    Matrix *self;
    Matrix *other;

    self = luaT_checkudata(L, 1, nerv_matrix_(tname));
    other = luaT_checkudata(L, 2, nerv_matrix_(tname));
    nerv_matrix_(tanh)(self, other, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0;
}

/* Lua entry point for "tanh_grad".
 * Stack: 1, 2, 3 = matrices, passed on in that order to
 * nerv_matrix_(tanh_grad) as (nerr, err, output). The resulting status is
 * handed to NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_tanh_grad)(lua_State *L) {
    Status st;
    Matrix *nerr_mat;
    Matrix *err_mat;
    Matrix *out_mat;

    nerr_mat = luaT_checkudata(L, 1, nerv_matrix_(tname));
    err_mat = luaT_checkudata(L, 2, nerv_matrix_(tname));
    out_mat = luaT_checkudata(L, 3, nerv_matrix_(tname));
    nerv_matrix_(tanh_grad)(nerr_mat, err_mat, out_mat, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0;
}

extern const char *MATRIX_CUMATRIX_HOST_TNAME;
/* Lua entry point for "copy_fromh".
 * Stack: 1 = matrix of this type, 2 = matrix of type
 * MATRIX_CUMATRIX_HOST_TNAME, then up to three optional integers:
 *   3 = b_begin (default 0), 4 = b_end (default b->nrow),
 *   5 = a_begin (default 0).
 * Forwards to nerv_matrix_(copy_fromh) and hands the resulting status to
 * NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_copy_fromh)(lua_State *L) {
    Status st;
    Matrix *dev = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *host = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
    int argc = lua_gettop(L);
    int src_begin = 0;
    int src_end;
    int dst_begin = 0;

    /* Optional arguments are checked in stack order, only when present. */
    if (argc > 2)
        src_begin = luaL_checkinteger(L, 3);
    if (argc > 3)
        src_end = luaL_checkinteger(L, 4);
    else
        src_end = host->nrow;
    if (argc > 4)
        dst_begin = luaL_checkinteger(L, 5);

    nerv_matrix_(copy_fromh)(dev, host, dst_begin, src_begin, src_end, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0;
}

/* Lua entry point for "copy_toh".
 * Stack: 1 = matrix of this type, 2 = matrix of type
 * MATRIX_CUMATRIX_HOST_TNAME, then up to three optional integers:
 *   3 = a_begin (default 0), 4 = a_end (default a->nrow),
 *   5 = b_begin (default 0).
 * Forwards to nerv_matrix_(copy_toh) and hands the resulting status to
 * NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_copy_toh)(lua_State *L) {
    Status st;
    Matrix *dev = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *host = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
    int argc = lua_gettop(L);
    int src_begin = 0;
    int src_end;
    int dst_begin = 0;

    /* Optional arguments are checked in stack order, only when present. */
    if (argc > 2)
        src_begin = luaL_checkinteger(L, 3);
    if (argc > 3)
        src_end = luaL_checkinteger(L, 4);
    else
        src_end = dev->nrow;
    if (argc > 4)
        dst_begin = luaL_checkinteger(L, 5);

    nerv_matrix_(copy_toh)(dev, host, src_begin, src_end, dst_begin, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0;
}

/* Lua entry point for "copy_fromd" (also registered as "copy_from").
 * Stack: 1 and 2 = matrices of this type, then up to three optional
 * integers:
 *   3 = b_begin (default 0), 4 = b_end (default b->nrow),
 *   5 = a_begin (default 0).
 * Forwards to nerv_matrix_(copy_fromd) and hands the resulting status to
 * NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_copy_fromd)(lua_State *L) {
    Status st;
    Matrix *target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *source = luaT_checkudata(L, 2, nerv_matrix_(tname));
    int argc = lua_gettop(L);
    int src_begin = 0;
    int src_end;
    int dst_begin = 0;

    /* Optional arguments are checked in stack order, only when present. */
    if (argc > 2)
        src_begin = luaL_checkinteger(L, 3);
    if (argc > 3)
        src_end = luaL_checkinteger(L, 4);
    else
        src_end = source->nrow;
    if (argc > 4)
        dst_begin = luaL_checkinteger(L, 5);

    nerv_matrix_(copy_fromd)(target, source, dst_begin, src_begin, src_end, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0;
}

extern const char *nerv_matrix_host_float_tname;
/* Lua entry point for "copy_rows_fromh_by_idx".
 * Stack: 1 = matrix of this type, 2 = matrix of type
 * MATRIX_CUMATRIX_HOST_TNAME, 3 = index matrix of type
 * nerv_matrix_host_float_tname, 4 = optional integer offset (default 0).
 * Forwards to nerv_matrix_(copy_rows_fromh_by_idx) and hands the resulting
 * status to NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_copy_rows_fromh_by_idx)(lua_State *L) {
    Status status;
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
    const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_float_tname);
    /* The offset is validated only when a fourth argument is supplied. */
    int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
    nerv_matrix_(copy_rows_fromh_by_idx)(a, b, idx, b_begin, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* Lua entry point for "copy_rows_fromd_by_idx" (also registered as
 * "copy_rows_from_by_idx").
 * Stack: 1, 2, 3 = matrices of this type (destination, source, index),
 * 4 = optional integer idx_begin (default 0).
 * Forwards to nerv_matrix_(copy_rows_fromd_by_idx) and hands the resulting
 * status to NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_copy_rows_fromd_by_idx)(lua_State *L) {
    Status status;
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
    const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
    /* The offset is validated only when a fourth argument is supplied. */
    int idx_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
    nerv_matrix_(copy_rows_fromd_by_idx)(a, b, idx, idx_begin, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

/* Lua entry point for "copy_rows_fromd_by_colidx".
 * Stack: 1, 2, 3 = matrices of this type (destination, source, index),
 * 4 = optional integer idx_begin (default 0).
 * Forwards to nerv_matrix_(copy_rows_fromd_by_colidx) and hands the
 * resulting status to NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_copy_rows_fromd_by_colidx)(lua_State *L) {
    Status status;
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    const Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
    const Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
    /* The offset is validated only when a fourth argument is supplied. */
    int idx_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
    nerv_matrix_(copy_rows_fromd_by_colidx)(a, b, idx, idx_begin, &status);
    NERV_LUA_CHECK_STATUS(L, status);
    return 0;
}

#ifdef __NERV_FUTURE_CUDA_7
/* Lua entry point for "update_select_rows_by_rowidx".
 * Updates selected rows of c; per the original author's note:
 *   c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha
 * Stack: 1 = c, 2 = a, 3 = idx (all matrices of this type),
 * 4 = alpha, 5 = beta (numbers).
 * The resulting status is handed to NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_update_select_rows_by_rowidx)(lua_State *L) {
    Status st;
    Matrix *target;
    const Matrix *delta;
    const Matrix *row_idx;
    MATRIX_ELEM alpha;
    MATRIX_ELEM beta;

    target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    delta = luaT_checkudata(L, 2, nerv_matrix_(tname));
    row_idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
    alpha = luaL_checknumber(L, 4);
    beta = luaL_checknumber(L, 5);

    nerv_matrix_(update_select_rows_by_rowidx)(target, delta, row_idx,
                                               alpha, beta, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0;
}

/* Lua entry point for "update_select_rows_by_colidx".
 * Updates selected rows of c; per the original author's note:
 *   c[idx[i]] = c[idx[i]] * (1 - beta * alpha) + a[i] * alpha
 * Stack: 1 = c, 2 = a, 3 = idx (all matrices of this type),
 * 4 = alpha, 5 = beta (numbers).
 * The resulting status is handed to NERV_LUA_CHECK_STATUS. Pushes no results. */
static int nerv_matrix_(lua_update_select_rows_by_colidx)(lua_State *L) {
    Status st;
    Matrix *target;
    const Matrix *delta;
    const Matrix *col_idx;
    MATRIX_ELEM alpha;
    MATRIX_ELEM beta;

    target = luaT_checkudata(L, 1, nerv_matrix_(tname));
    delta = luaT_checkudata(L, 2, nerv_matrix_(tname));
    col_idx = luaT_checkudata(L, 3, nerv_matrix_(tname));
    alpha = luaL_checknumber(L, 4);
    beta = luaL_checknumber(L, 5);

    nerv_matrix_(update_select_rows_by_colidx)(target, delta, col_idx,
                                               alpha, beta, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0;
}
#endif

/* Element read from Lua is not provided for this matrix type: the call is
 * delegated to nerv_error_method_not_implemented and its result returned. */
int nerv_matrix_(lua_get_elem)(lua_State *L) {
    return nerv_error_method_not_implemented(L);
}

/* Element write from Lua is not provided for this matrix type: the call is
 * delegated to nerv_error_method_not_implemented and its result returned. */
int nerv_matrix_(lua_set_elem)(lua_State *L) {
    return nerv_error_method_not_implemented(L);
}

/* Fetch a single element from data[idx] into host memory.
 * The copy is a cudaMemcpy with cudaMemcpyDeviceToHost; on any result other
 * than cudaSuccess an error is raised through nerv_error. The device is
 * synchronized before the value is returned.
 * Backs the MATRIX_DATA_READ macro. */
static MATRIX_ELEM cuda_matrix_(read)(lua_State *L, MATRIX_ELEM *data,
                                    int idx) {
    MATRIX_ELEM value;
    cudaError_t rc = cudaMemcpy(&value, data + idx,
                                sizeof(MATRIX_ELEM), cudaMemcpyDeviceToHost);

    if (rc != cudaSuccess)
        nerv_error(L, "cuda error: error while reading element");
    cudaDeviceSynchronize();
    return value;
}

/* Store a single host value into data[idx].
 * The copy is a cudaMemcpy with cudaMemcpyHostToDevice; on any result other
 * than cudaSuccess an error is raised through nerv_error. The device is
 * synchronized before returning.
 * Backs the MATRIX_DATA_WRITE macro. */
static void cuda_matrix_(write)(lua_State *L, MATRIX_ELEM *data,
                                int idx, MATRIX_ELEM val) {
    cudaError_t rc = cudaMemcpy(data + idx, &val,
                                sizeof(MATRIX_ELEM), cudaMemcpyHostToDevice);

    if (rc != cudaSuccess)
        nerv_error(L, "cuda error: error while writing element");
    cudaDeviceSynchronize();
}

/* Forward declaration: MATRIX_INIT(L) expands to cuda_matrix_(init)(L), and
 * the definition only appears after matrix.c has been included.
 * NOTE(review): matrix.c presumably invokes MATRIX_INIT and defines the
 * shared lua_* methods referenced by the method table -- confirm. */
static void cuda_matrix_(init)(lua_State *L);
#include "matrix.c"

/* Method table appended by cuda_matrix_(init) via luaN_append_methods.
 * Each entry maps a Lua-visible method name to its C entry point; the
 * table is terminated by a {NULL, NULL} sentinel. Entries whose handlers
 * are not defined in this file are expected to come from the included
 * matrix.c (NOTE(review): confirm). */
static const luaL_Reg nerv_matrix_(extra_methods)[] = {
    {"colsum", nerv_matrix_(lua_colsum)},
    {"colsame", nerv_matrix_(lua_colsame)},
    {"rowsum", nerv_matrix_(lua_rowsum)},
    {"rowmax", nerv_matrix_(lua_rowmax)},
    {"rowmax_idx", nerv_matrix_(lua_rowmax_idx)},
    {"trans", nerv_matrix_(lua_trans)},
    {"decompress", nerv_matrix_(lua_decompress)},
    /* in-place calc */
    {"copy_fromh", nerv_matrix_(lua_copy_fromh)},
    {"copy_fromd", nerv_matrix_(lua_copy_fromd)},
    /* alias for copy_fromd */
    {"copy_from", nerv_matrix_(lua_copy_fromd)},
    {"copy_toh", nerv_matrix_(lua_copy_toh)},
    {"add", nerv_matrix_(lua_add)},
    {"mul", nerv_matrix_(lua_mul)},
    {"add_row", nerv_matrix_(lua_add_row)},
    {"clip", nerv_matrix_(lua_clip)},
    {"fill", nerv_matrix_(lua_fill)},
    {"sigmoid", nerv_matrix_(lua_sigmoid)},
    {"sigmoid_grad", nerv_matrix_(lua_sigmoid_grad)},
    {"tanh", nerv_matrix_(lua_tanh)},
    {"tanh_grad", nerv_matrix_(lua_tanh_grad)},
    {"rand_uniform", nerv_matrix_(lua_rand_uniform)},
    {"softmax", nerv_matrix_(lua_softmax)},
    {"mul_elem", nerv_matrix_(lua_mul_elem)},
    {"log_elem", nerv_matrix_(lua_log_elem)},
    {"thres_mask", nerv_matrix_(lua_thres_mask)},
    {"copy_rows_fromh_by_idx", nerv_matrix_(lua_copy_rows_fromh_by_idx)},
    {"copy_rows_fromd_by_idx", nerv_matrix_(lua_copy_rows_fromd_by_idx)},
    /* alias for copy_rows_fromd_by_idx */
    {"copy_rows_from_by_idx", nerv_matrix_(lua_copy_rows_fromd_by_idx)},
    {"copy_rows_fromd_by_colidx", nerv_matrix_(lua_copy_rows_fromd_by_colidx)},
    {"expand_frm", nerv_matrix_(lua_expand_frm)},
    {"rearrange_frm", nerv_matrix_(lua_rearrange_frm)},
    {"scale_rows_by_row", nerv_matrix_(lua_scale_rows_by_row)},
    {"scale_rows_by_col", nerv_matrix_(lua_scale_rows_by_col)},
    /* registered only when the matching handlers are compiled in */
#ifdef __NERV_FUTURE_CUDA_7
    {"update_select_rows_by_rowidx", nerv_matrix_(lua_update_select_rows_by_rowidx)},
    {"update_select_rows_by_colidx", nerv_matrix_(lua_update_select_rows_by_colidx)},
#endif
    {NULL, NULL}
};

/* Type initialisation hook reached through MATRIX_INIT(L).
 * Appends the extra_methods table with luaN_append_methods, then runs the
 * optional CUMATRIX_INIT(L) hook when the including file defines it. */
static void cuda_matrix_(init)(lua_State *L) {
    luaN_append_methods(L, nerv_matrix_(extra_methods));
#ifdef CUMATRIX_INIT
    CUMATRIX_INIT(L);
#endif
}

#endif