aboutsummaryrefslogblamecommitdiff
path: root/nerv/lib/matrix/generic/cumatrix.c
blob: 772b78d0986bd938a7057eb5a3ecca1f176b4327 (plain) (tree)
1
2
3
4
5
6
7
8
9








                                                                                    
                         


















                                                                       
                                            

























                                                                     
                                            






                                                                        
                                            








                                                                       
                                            






                                                                           
                                        

                                                       
                                        




                                                   
                                        











                                                  
                                            




                                                         
                                        



                              
                                            




                                                         
                                        



                              
                                            





                                                           
                                        




                                             
                                            




                                                         
                                        



                              
                                            





                                                                  
                                        

                                                    
                                        







                                           
                                            










                                                                   
                                            





                                                                   
                                            

















                                                                
                                            

















                                                                
                                            

















                                                                
                                            




                                                               
                                        












                                                                       
                                            









                                                                        
                                            






                                                                         
                                            









                                                                                 
                                        




                                  
                                            































                                                                               
                                            











                                                               
                                            










                                                                             
                                            










                                                                
                                            










                                                                
                                            



                                                                  
                                            








                                                                              
                                            



                   
#ifdef NERV_GENERIC_CUMATRIX
#include "matrix.h"
#include "elem_type.h"
/* Storage hooks for this CUDA matrix instantiation: matrix data release and
 * allocation are routed to cuda_matrix_(free) / cuda_matrix_(alloc).
 * cuda_matrix_(...) is a name-building macro defined outside this view
 * (presumably in matrix.h / elem_type.h, selecting the element-type-specific
 * helper) -- TODO confirm. */
/* MATRIX_DATA_FREE(ptr, status): expands to cuda_matrix_(free)(ptr, status).
 * Presumably releases the matrix's data buffer and reports failure through
 * status -- confirm against cuda_matrix_(free). */
#define MATRIX_DATA_FREE(ptr, status) cuda_matrix_(free)(ptr, status)
/* MATRIX_DATA_ALLOC(dptr, stride, width, height, status): forwards all five
 * arguments unchanged to cuda_matrix_(alloc).
 * NOTE(review): dptr and stride look like out-parameters. They would be the
 * allocated pointer and the row stride in bytes; that reading is consistent
 * with the "stride / sizeof(MATRIX_ELEM)" leading-dimension computation used
 * in this file. Verify against cuda_matrix_(alloc). */
#define MATRIX_DATA_ALLOC(dptr, stride, width, height, status) \
                            cuda_matrix_(alloc)(dptr, stride, width, height, status)

/* Switches defined before common.h and cukernel.h are included below.
 * Presumably they enable the generic-matrix / generic-CUDA-kernel sections
 * of those headers for this instantiation -- TODO confirm. Keep them ahead of
 * those #include lines. */
#define NERV_GENERIC_MATRIX
#define NERV_GENERIC_CUKERNEL
#include "../../common.h"
#include "../cukernel.h"
#include "../cuda_helper.h"

void nerv_matrix_(add)(Matrix *c, const Matrix *a, const Matrix *b,
                            MATRIX_ELEM alpha, MATRIX_ELEM beta,
                            Status *status) {
    CHECK_SAME_DIMENSION(a, b, status);
    CHECK_SAME_DIMENSION(a, c, status);
    PROFILE_START
    CUBLAS_SAFE_SYNC_CALL(
            NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
                a->ncol, a->nrow,
                &alpha,
                MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
                &beta,
                MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM),
                MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM)),
            status);
    PROFILE_STOP
    NERV_SET_STATUS(status, NERV_NORMAL,</