aboutsummaryrefslogblamecommitdiff
path: root/matrix/generic/cukernel.cu
blob: 1d8b983d0a8b98a420aea25ac42971e692636a9c (plain) (tree)
1
2
3
4
5
6
7
8





                            
                                                             
                                              




















                                                                            









                                                                     











                                                                     










                                                                           
                                                                     
















                                                                           


















                                                                             


















                                                                              
                                                                     




















                                                                           







                                                                              







                                                                         































                                                                           
 

                        
                                                            
                                                             








                                                                
                                                             







                                                                
                                                           
                                                             






                                                                      

                                                                      
                                                             








                                                               
                                                          






                                                                                      
                                                                                     





                                                                           
                                                                                     





                                                                           





















                                                                                      

                                                                       
                                                             


















                                                                                          

                                                              






                                                                           

                                                              





                                                                           
                                                          






                                                                                      
                                                                                     





                                                                           
                                                                                     




                                                                           


                                                                        
                                                             





                                                                      

                                                   
                                                             





                                                            































                                                                           

      
#ifdef NERV_GENERIC_CUKERNEL
#include <assert.h>
#include <stdio.h>
#include "matrix.h"
#include "cuda.h"
/* Thread count along one block dimension. Presumably the launch code uses
 * CUDA_THREADS_N x CUDA_THREADS_N blocks -- TODO confirm against the callers. */
#define CUDA_THREADS_N 16
/* CUDA_THREADS_N squared. */
#define CUDA_THREADS_NN ((CUDA_THREADS_N) * (CUDA_THREADS_N))
/* Integer division of a by b rounded up (for non-negative a and positive b).
 * Note that b is evaluated twice, so it must not have side effects. */
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
/*
 * Element-wise logarithm: b[i][j] = log(a[i][j]).
 *
 * Each thread handles one element; column j comes from the x dimension of the
 * launch and row i from the y dimension. Threads that fall outside the
 * nrow x ncol area return without touching memory.
 *
 * a      - input matrix
 * b      - output matrix, addressed with the same row layout as a
 * nrow   - number of rows
 * ncol   - number of columns
 * stride - distance between the starts of consecutive rows, in elements
 */
__global__ void cudak_(log_elem)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    /* Widen before multiplying: i * stride evaluated in int can overflow for
     * large matrices before the result ever reaches the long variable. */
    idx = (long)i * stride + j;
    b[idx] = log(a[idx]);
}

/*
 * Element-wise (Hadamard) product: c[i][j] = a[i][j] * b[i][j].
 *
 * Each thread handles one element; column j comes from the x dimension of the
 * launch and row i from the y dimension. Threads that fall outside the
 * nrow x ncol area return without touching memory.
 *
 * a, b   - input matrices
 * c      - output matrix
 * nrow   - number of rows
 * ncol   - number of columns
 * stride - distance between the starts of consecutive rows, in elements;
 *          the same value is used to address a, b and c
 */
__global__ void cudak_(mul_elem)(const MATRIX_ELEM *a, const MATRIX_ELEM *b,
                                MATRIX_ELEM *c,
                                int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    /* Widen before multiplying: i * stride evaluated in int can overflow for
     * large matrices before the result ever reaches the long variable. */
    idx = (long)i * stride + j;
    c[idx] = a[idx] * b[idx];
}

/*
 * Element-wise logistic sigmoid: b[i][j] = 1 / (1 + exp(-a[i][j])).
 *
 * Each thread handles one element; column j comes from the x dimension of the
 * launch and row i from the y dimension. Threads that fall outside the
 * nrow x ncol area return without touching memory.
 *
 * a      - input matrix
 * b      - output matrix, addressed with the same row layout as a
 * nrow   - number of rows
 * ncol   - number of columns
 * stride - distance between the starts of consecutive rows, in elements
 */
__global__ void cudak_(sigmoid)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                        int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    /* Widen before multiplying: i * stride evaluated in int can overflow for
     * large matrices before the result ever reaches the long variable. */
    idx = (long)i * stride + j;
    /* The 1.0 literals are double, so the addition and division are carried
     * out in double; kept as-is so results stay bit-identical. */
    b[idx] = 1.0 / (1.0 + exp(-a[idx]));
}

__global__ void cudak_(sigmoid_grad)(const MATRIX_ELEM *output,
                                    const MATRIX_ELEM *err,
                                    MATRIX_ELEM *nerr,
                                    int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    nerr[idx] = output[idx] * (1.0 - output[idx]) * err