aboutsummaryrefslogblamecommitdiff
path: root/nerv/matrix/generic/cumatrix.c
blob: f8b80382554c429fc4c6c349e823136aff8bf54f (plain) (tree)
1
2
3
4
5
6
7
8
9
10
                            
                                               

                                                                                   
                                            
                                                

                             

                                            
                                              
 

                                                   


                                                                








                                                           


                                                           




                                                              











                                                           


















                                                                
                                              

                                                        
                                                           
                                                                        



                                                              

                                                                     


             

                                                     
                                                           
                                                                        



                                                              

                                                                   


             
                                                       
                  
                                                           





                                                                     
                                     

             
 
                                                

                                                                   
                                                           
                                                                        
                                                                            

                                                                  

                                                                      


             





                                                                   

                                                                        



                                     











                                                                           
                           
                                                                         

                                                                        





                                                                   














                                                                                


                                     
      
 
































                                                                      
                                                       






                                                 
                       

                                                 

                                                



                                             
                                     


                                                     

                                               
                                                     


                                             
                                                 
                                                                         
                                                                         
                                                                        
                                                                               



                                                               
                                                       
                                                   
                           

                                                                                     
      




                                                        


                     

 
      
#ifdef NERV_GENERIC_CUMATRIX
#include "../../lib/matrix/generic/elem_type.h"
#define MATRIX_DATA_WRITE(L, data, idx, val) cuda_matrix_(write)(L, data, idx, val)
#define MATRIX_DATA_READ(L, data, idx) cuda_matrix_(read)(L, data, idx)
#define MATRIX_INIT(L) cuda_matrix_(init)(L)
#define MATRIX_BASE_TNAME nerv_matrix_cuda_tname
#define NERV_GENERIC_MATRIX
#define NERV_GENERIC_CUKERNEL
#include "../../lib/common.h"
#include "../../lib/matrix/generic/matrix.h"
#include "../../lib/matrix/generic/cumatrix.h"

#define BLAS_OP_N CUBLAS_OP_N
/*
 * Translate a transpose flag character (as passed in from Lua) into a
 * cuBLAS operation code.
 *   'T' or 't'      -> CUBLAS_OP_T
 *   any other char  -> CUBLAS_OP_N
 * NOTE(review): BLAS_OP_N above presumably lets shared generic matrix code
 * name the "no transpose" op without depending on cuBLAS directly — confirm.
 */
static int nerv_matrix_(lua_get_blas_op)(char ch) {
    switch (ch) {
        case 'T':
        case 't':
            return CUBLAS_OP_T;
        default:
            return CUBLAS_OP_N;
    }
}

/*
 * Lua binding for nerv_matrix_(prefixsum_row).
 * Stack slot 1 and slot 2 must both be matrices of this generic type; they are
 * validated with luaT_checkudata (slot 1 first, then slot 2) and handed to the
 * native routine in the same order. The resulting Status is passed to
 * NERV_LUA_CHECK_STATUS. Pushes no values (returns 0 to Lua).
 * NOTE(review): which argument is written vs. read is decided by
 * nerv_matrix_(prefixsum_row), not visible here — confirm before relying on it.
 */
static int nerv_matrix_(lua_prefixsum_row)(lua_State *L) {
    Matrix *first = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *second = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Status st;

    nerv_matrix_(prefixsum_row)(first, second, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0; /* no Lua return values */
}

/*
 * Lua binding for nerv_matrix_(thres_mask).
 * Expected Lua arguments, read in this exact order:
 *   1, 2 : matrices of this generic type (checked via luaT_checkudata)
 *   3    : threshold value
 *   4    : "low" value
 *   5    : "high" value
 * Arguments 3-5 are read with luaL_checknumber and converted to MATRIX_ELEM.
 * All five are forwarded unchanged to the native routine, and its Status is
 * handed to NERV_LUA_CHECK_STATUS. Pushes no values (returns 0 to Lua).
 * NOTE(review): how low/high are applied relative to the threshold is defined
 * by nerv_matrix_(thres_mask), not shown here — confirm in the kernel code.
 */
static int nerv_matrix_(lua_thres_mask)(lua_State *L) {
    Matrix *m1 = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *m2 = luaT_checkudata(L, 2, nerv_matrix_(tname));
    MATRIX_ELEM threshold = luaL_checknumber(L, 3);
    MATRIX_ELEM low_val = luaL_checknumber(L, 4);
    MATRIX_ELEM high_val = luaL_checknumber(L, 5);
    Status st;

    nerv_matrix_(thres_mask)(m1, m2, threshold, low_val, high_val, &st);
    NERV_LUA_CHECK_STATUS(L, st);
    return 0; /* no Lua return values */
}

static int nerv_matrix_(lua_rand_uniform)(lua_State *L) {
    Status