aboutsummaryrefslogblamecommitdiff
path: root/nerv/lib/matrix/generic/cukernel.cu
blob: cf9d213a51f8f6a708dd1bd8124d7540d9ad9fcb (plain) (tree)
1
2
3
4
5
6
7





                            
                   















                                                                       
                                                                                                          





                                                                 
                        




                      
































                                                                            






                                                                  

                                                                                             













                                                                     































































































































































                                                                                       







                                                             











                                                                                         
                           
                                                                                                                  
                                                                                                                        



                                                  
      


                                                                                     
      
                                                                                                   
                                                                                                       
                                                                                                                     
 






                                                                                                                                        
      


                                                                                     
      



                                                                                                                     
      
 























                                                                           








                                                                                


































                                                                               
                                                                              
                                                           
                                                                                 


                                                  
      




                                                                           




                                                       

 






                                                                                                 
      


                                                                              
      


                                          













                                                                                     
 














































                                                                      
                                                                         
                               
                                                                                                                  

                                
                                                                                                                        


              
                                                                                                           



                                                             

                                                                                     


                                 






















                                                                      






























































































































































































                                                                                                 









                                                             








                                                                   
    
                           
                                                                                                                              


                                                             
                                                                              
                                                                           
                                                                         


                                                           









                                                                                                                              
      
 




































                                                                           












                                                                                  









                                                                     

































                                                                                                       









                                                                

                                                                  
                                                                           



                                                                      
                                             
                                                                        

                                 










                                                                                                           

      
#ifdef NERV_GENERIC_CUKERNEL
#include <assert.h>
#include <stdio.h>
#include "../matrix.h"
#include "cuda.h"
#include "float.h"
#include "curand.h"
/* Side length of the square 2D thread block (CUDA_THREADS_N x CUDA_THREADS_N)
 * used by the element-wise launchers below. */
#define CUDA_THREADS_N 16
/* Thread count (16 * 16 = 256) of the 1D blocks used by the row/column
 * reduction launchers; the reduction kernels halve this repeatedly, so it
 * should stay a power of two. */
#define CUDA_THREADS_NN ((CUDA_THREADS_N) * (CUDA_THREADS_N))
/* Integer ceiling of a / b (for non-negative a and positive b): number of
 * blocks of size b needed to cover a items. */
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
/* Element-wise natural log: b[i][j] = log(a[i][j]), with inputs below FLT_MIN
 * (zero, negatives, denormals) clamped to FLT_MIN first so the result is a
 * large finite negative value instead of -inf/NaN.
 * Launch: 2D grid, x covers ncol (j), y covers nrow (i).
 * a and b share `stride`, measured in elements. */
__global__ void cudak_(log_elem)(const MATRIX_ELEM *a, MATRIX_ELEM *b, 
                                int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    MATRIX_ELEM tmp;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    idx = j + (long)i * stride;
    tmp = a[idx];
    if (tmp < FLT_MIN) tmp = FLT_MIN;
    b[idx] = log(tmp);
}

/* Threshold mask: a[i][j] = (b[i][j] < thres) ? low : high.
 * Only a is written; b is read-only, so it is now taken as const.
 * Launch: 2D grid, x covers ncol (j), y covers nrow (i).
 * a and b share `stride`, in elements. */
__global__ void cudak_(thres_mask)(MATRIX_ELEM *a, const MATRIX_ELEM *b, double thres, double low, double high, 
                                int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    idx = j + (long)i * stride;
    if (b[idx] < thres) 
        a[idx] = low;
    else
        a[idx] = high;
}

/* Element-wise (Hadamard) product: c[i][j] = a[i][j] * b[i][j].
 * Launch: 2D grid, x covers ncol (j), y covers nrow (i).
 * a, b and c are all addressed with the single `stride` (elements). */
__global__ void cudak_(mul_elem)(const MATRIX_ELEM *a, const MATRIX_ELEM *b,
                                MATRIX_ELEM *c, 
                                int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    idx = j + (long)i * stride;
    c[idx] = a[idx] * b[idx];
}

/* Element-wise logistic sigmoid: b[i][j] = 1 / (1 + exp(-a[i][j])).
 * Launch: 2D grid, x covers ncol (j), y covers nrow (i).
 * a and b share `stride`, in elements. */
__global__ void cudak_(sigmoid)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                        int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    idx = j + (long)i * stride;
    b[idx] = 1.0 / (1.0 + exp(-a[idx]));
}

/* Back-propagate through a sigmoid given its forward output y:
 * nerr[i][j] = y * (1 - y) * err[i][j].
 * Launch: 2D grid, x covers ncol (j), y covers nrow (i).
 * output, err and nerr share `stride`, in elements. */
__global__ void cudak_(sigmoid_grad)(const MATRIX_ELEM *output,
                                    const MATRIX_ELEM *err,
                                    MATRIX_ELEM *nerr,
                                    int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    idx = j + (long)i * stride;
    nerr[idx] = output[idx] * (1.0 - output[idx]) * err[idx];
}

/* Element-wise hyperbolic tangent: b[i][j] = tanh(a[i][j]).
 * Launch: 2D grid, x covers ncol (j), y covers nrow (i).
 * a and b share `stride`, in elements. */
__global__ void cudak_(tanh)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                        int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    idx = j + (long)i * stride;
    /* Use the library tanh rather than the explicit
     * (exp(x) - exp(-x)) / (exp(x) + exp(-x)) form, which can produce NaN
     * (inf / inf) for large |x|. */
    b[idx] = tanh(a[idx]);
}

/* Back-propagate through tanh given its forward output y:
 * nerr[i][j] = (1 - y * y) * err[i][j].
 * Launch: 2D grid, x covers ncol (j), y covers nrow (i).
 * output, err and nerr share `stride`, in elements. */
__global__ void cudak_(tanh_grad)(const MATRIX_ELEM *output,
                                    const MATRIX_ELEM *err,
                                    MATRIX_ELEM *nerr,
                                    int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    idx = j + (long)i * stride;
    nerr[idx] = (1.0 - output[idx] * output[idx]) * err[idx];
}

/* Final softmax step: b[i][j] = exp(a[i][j] - max[i]) / deno[i].
 * max and deno hold one value per row (column 0, row step `mstride`); deno is
 * indexed with the same mstride as max, so both are assumed to share a
 * stride — TODO confirm at call sites.
 * Launch: 2D grid, x covers ncol (j), y covers nrow (i).
 * a and b share `stride`, in elements. */
__global__ void cudak_(softmax_final)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                        const MATRIX_ELEM *max, const MATRIX_ELEM *deno,
                        int nrow, int ncol, int stride, int mstride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx, midx;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: int products overflow past INT_MAX */
    idx = j + (long)i * stride;
    midx = (long)i * mstride;
    b[idx] = exp(a[idx] - max[midx]) / deno[midx];
}

/* One pass of a row-wise sum reduction.
 * Grid: x = slice index within a row, y = row; block: 1D of blockDim.x
 * threads with blockDim.x * sizeof(MATRIX_ELEM) bytes of dynamic shared
 * memory. Block (bx, row) writes the sum of input[row][bx*blockDim.x ...
 * bx*blockDim.x + blockDim.x - 1] (columns >= n count as 0) to
 * output[row][bx]. The tree halves the active range each step, so
 * blockDim.x is expected to be a power of two (launchers in this file use
 * CUDA_THREADS_NN). Strides are in elements. */
__global__ void cudak_(block_reduce_rowsum)(const MATRIX_ELEM *input,
                                            MATRIX_ELEM *output,
                                            const int istride, const int ostride,
                                            const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    const unsigned int tid = threadIdx.x;
    const unsigned int row = blockIdx.y;
    int col = blockIdx.x * blockDim.x + tid;
    MATRIX_ELEM v = 0;
    if (col < n)
        v = input[col + istride * row];
    cudak_(arr)[tid] = v;
    __syncthreads();
    /* every thread runs every iteration so __syncthreads() is never divergent */
    unsigned int half = blockDim.x >> 1;
    while (half != 0) {
        if (tid < half)
            cudak_(arr)[tid] += cudak_(arr)[tid + half];
        __syncthreads();
        half >>= 1;
    }
    if (tid == 0)
        output[blockIdx.x + ostride * row] = cudak_(arr)[0];
}

/* One pass of a column-wise sum reduction.
 * Grid: x = column, y = slice index within the column; block: 1 x blockDim.y
 * threads with blockDim.y * sizeof(MATRIX_ELEM) bytes of dynamic shared
 * memory. Block (col, by) writes the sum of rows by*blockDim.y ...
 * by*blockDim.y + blockDim.y - 1 of that column (rows >= n count as 0) to
 * output[by][col]. blockDim.y is expected to be a power of two (launchers in
 * this file use CUDA_THREADS_NN). Strides are in elements. */
__global__ void cudak_(block_reduce_colsum)(const MATRIX_ELEM *input,
                                MATRIX_ELEM *output,
                                const int istride, const int ostride,
                                const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    /* widen before multiplying: istride * i in int overflows past INT_MAX */
    cudak_(arr)[threadIdx.y] = i < n ? input[blockIdx.x + (long)istride * i] : 0;
    __syncthreads();
    for (int offset = blockDim.y >> 1;  offset; offset >>= 1)
    {
        if (threadIdx.y < offset)
            cudak_(arr)[threadIdx.y] += cudak_(arr)[threadIdx.y + offset];
        __syncthreads();
    }
    if (threadIdx.y == 0)
        output[blockIdx.x + (long)ostride * blockIdx.y] = cudak_(arr)[0];
}

/* First pass of a per-column "count equal entries" reduction.
 * Each thread contributes 1.0 when input and ref_input hold the same value
 * at its (row, column), 0 otherwise (and 0 for rows >= n). The values are
 * then summed exactly like block_reduce_colsum: grid x = column, y = slice;
 * block 1 x blockDim.y (power of two) with blockDim.y * sizeof(MATRIX_ELEM)
 * bytes of dynamic shared memory; block (col, by) writes its count to
 * output[by][col]. ref_input is addressed with input's istride, so both are
 * assumed to share a layout — TODO confirm at call sites. */
__global__ void cudak_(block_reduce_colsame)(const MATRIX_ELEM *input,
                                            const MATRIX_ELEM *ref_input,
                                            MATRIX_ELEM *output,
                                            const int istride, const int ostride,
                                            const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    /* widen before multiplying: istride * i in int overflows past INT_MAX;
     * idx is only dereferenced when i < n (short-circuit below) */
    long idx = blockIdx.x + (long)istride * i;
    cudak_(arr)[threadIdx.y] = (i < n && input[idx] == ref_input[idx]) ? 1.0 : 0;
    __syncthreads();
    for (int offset = blockDim.y >> 1;  offset; offset >>= 1)
    {
        if (threadIdx.y < offset)
            cudak_(arr)[threadIdx.y] += cudak_(arr)[threadIdx.y + offset];
        __syncthreads();
    }
    if (threadIdx.y == 0)
        output[blockIdx.x + (long)ostride * blockIdx.y] = cudak_(arr)[0];
}

/* First pass of the softmax denominator: per row, sum exp(x - max[row]).
 * max holds one value per row (column 0, row step `mstride`). Grid: x =
 * slice index within a row, y = row; block: 1D of blockDim.x threads
 * (power of two) with blockDim.x * sizeof(MATRIX_ELEM) bytes of dynamic
 * shared memory. Columns >= n contribute 0. Block (bx, row) writes its
 * partial sum to output[row][bx]. Strides are in elements. */
__global__ void cudak_(block_reduce_softmax_rowsum)(const MATRIX_ELEM *input,
                                        MATRIX_ELEM *output,
                                        const MATRIX_ELEM *max,
                                        const int istride, const int ostride,
                                        const int mstride, const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    const unsigned int tid = threadIdx.x;
    const unsigned int row = blockIdx.y;
    int col = blockIdx.x * blockDim.x + tid;
    MATRIX_ELEM term = 0;
    if (col < n)
        term = exp(input[col + istride * row] - max[mstride * row]);
    cudak_(arr)[tid] = term;
    __syncthreads();
    for (unsigned int half = blockDim.x / 2; half != 0; half /= 2) {
        if (tid < half)
            cudak_(arr)[tid] += cudak_(arr)[tid + half];
        __syncthreads();
    }
    if (tid == 0)
        output[blockIdx.x + ostride * row] = cudak_(arr)[0];
}

/* One pass of a row-wise max reduction.
 * Grid: x = slice index within a row, y = row; block: 1D of blockDim.x
 * threads (power of two) with blockDim.x * sizeof(MATRIX_ELEM) bytes of
 * dynamic shared memory. Columns >= n are padded with -FLT_MAX. Block
 * (bx, row) writes the maximum of its slice to output[row][bx]. A slot is
 * only replaced by a strictly greater value. Strides are in elements. */
__global__ void cudak_(block_reduce_rowmax)(const MATRIX_ELEM *input,
                                            MATRIX_ELEM *output,
                                            const int istride, const int ostride,
                                            const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    const unsigned int tid = threadIdx.x;
    const unsigned int row = blockIdx.y;
    int col = blockIdx.x * blockDim.x + tid;
    cudak_(arr)[tid] = (col < n) ? input[col + istride * row] : -FLT_MAX;
    __syncthreads();
    for (unsigned int half = blockDim.x / 2; half != 0; half /= 2) {
        if (tid < half && cudak_(arr)[tid + half] > cudak_(arr)[tid])
            cudak_(arr)[tid] = cudak_(arr)[tid + half];
        __syncthreads();
    }
    if (tid == 0)
        output[blockIdx.x + ostride * row] = cudak_(arr)[0];
}

/* One pass of a row-wise argmax reduction that carries an index alongside
 * each value (idx_input supplies the index for each input element; it is
 * addressed with input's istride).
 * Dynamic shared memory must hold 2 * blockDim.x elements: values in
 * [0, blockDim.x), indices in [blockDim.x, 2 * blockDim.x).
 * Grid: x = slice index within a row, y = row; blockDim.x a power of two.
 * Columns >= n are padded with value -FLT_MAX and index 0. A slot is only
 * replaced by a strictly greater value, so among equal maxima the lower
 * slot wins. Block (bx, row) writes to output[row][bx] / idx_output[row][bx]. */
__global__ void cudak_(block_reduce_rowmax_idx)(const MATRIX_ELEM *input,
                                                const MATRIX_ELEM *idx_input,
                                                MATRIX_ELEM *output,
                                                MATRIX_ELEM *idx_output,
                                                const int istride, const int ostride,
                                                const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    MATRIX_ELEM *vals = cudak_(arr);
    MATRIX_ELEM *idxs = cudak_(arr) + blockDim.x;
    const unsigned int tid = threadIdx.x;
    int col = blockIdx.x * blockDim.x + tid;
    bool in_range = col < n;
    vals[tid] = in_range ? input[col + istride * blockIdx.y] : -FLT_MAX;
    idxs[tid] = in_range ? idx_input[col + istride * blockIdx.y] : 0;
    __syncthreads();
    for (unsigned int half = blockDim.x / 2; half != 0; half /= 2) {
        if (tid < half && vals[tid + half] > vals[tid]) {
            vals[tid] = vals[tid + half];
            idxs[tid] = idxs[tid + half];
        }
        __syncthreads();
    }
    if (tid == 0) {
        const unsigned int out = blockIdx.x + ostride * blockIdx.y;
        output[out] = vals[0];
        idx_output[out] = idxs[0];
    }
}

/* Broadcast-add a row vector: b[i][j] += beta * a[j] for every row i.
 * Launch: 2D grid, x covers ncol (j), y covers nrow (i); `stride` (elements)
 * applies to b only. */
__global__ void cudak_(add_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                int nrow, int ncol, int stride, double beta) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    b[j + (long)i * stride] += beta * a[j];
}

/* Set every element of the nrow x ncol matrix a to val (converted to
 * MATRIX_ELEM). Launch: 2D grid, x covers ncol, y covers nrow; `stride` is
 * in elements. Padding columns beyond ncol are untouched. */
__global__ void cudak_(fill)(MATRIX_ELEM *a,
                            int nrow, int ncol, int stride, double val) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    a[j + (long)i * stride] = val;
}

/* Zero every off-diagonal element (i != j) of a in place, leaving the
 * diagonal as is. Launch: 2D grid, x covers ncol, y covers nrow; `stride`
 * is in elements. */
__global__ void cudak_(diagonalize)(MATRIX_ELEM *a,
                            int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol || i == j) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    a[j + (long)i * stride] = 0;
}

/* Clamp a in place to [val_1, val_2]: values above val_2 become val_2,
 * otherwise values below val_1 become val_1. The upper bound is tested
 * first, so it wins if val_1 > val_2. NaNs fail both comparisons and are
 * left as is. Launch: 2D grid, x covers ncol, y covers nrow; `stride` is
 * in elements. */
__global__ void cudak_(clip)(MATRIX_ELEM *a,
                            int nrow, int ncol, int stride, double val_1, double val_2) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    idx = j + (long)i * stride;
    if (a[idx] > val_2)
        a[idx] = val_2;
    else if (a[idx] < val_1)
        a[idx] = val_1;
}

#ifdef __NERV_FUTURE_CUDA_7
/* For every row i of a (nrow_a x ncol_a), update row i_c = lrintf(idx[i]) of c:
 *   c[i_c][j] += c[i_c][j] * (-beta * alpha) + a[i][j] * alpha
 * which is c_row = c_row * (1 - beta * alpha) + a_row * alpha when a target
 * row is hit once. idx is read contiguously (idx[i]), so it is presumably a
 * 1 x nrow_a row vector — TODO confirm. Target rows are not range-checked
 * (the check below is commented out); nrow_c is used only by that check.
 * Launch: 2D grid, x covers ncol_a (j), y covers nrow_a (i). Strides are in
 * elements. Only compiled when __NERV_FUTURE_CUDA_7 is defined.
 * NOTE(review): the addend reads c[...] non-atomically before
 * atomicAdd_nvidia, so when several rows of a map to the same target the
 * decay term may be computed from a stale value. The offsets
 * (i_c * stride_c, i * stride_a) are also int products that can overflow
 * for very large matrices. */
__global__ void cudak_(update_select_rows_by_rowidx)(MATRIX_ELEM *c, const MATRIX_ELEM *a, const MATRIX_ELEM *idx,
                            int nrow_a, int ncol_a, int nrow_c, int stride_c, int stride_a, double alpha, double beta) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow_a || j >= ncol_a) return;
    int i_c = lrintf(idx[i]);
    /*
    if (i_c < 0 || i_c >= nrow_c) {
        printf("ERROR inside kernel update_select_rows, i_c(%d) out of range!", i_c);
    }
    */
    // Different i may share the same i_c (repeated entries in idx), so the write
    // goes through atomicAdd_nvidia instead of a plain store. The non-atomic
    // equivalent would be:
    //c[j + i_c * stride_c] = c[j + i_c * stride_c] * (1 - beta * alpha) + a[j + i * stride_a] * alpha;
    atomicAdd_nvidia(c + j + i_c * stride_c, c[j + i_c * stride_c] * (- beta * alpha) + a[j + i * stride_a] * alpha);
}

/* Same update as update_select_rows_by_rowidx, but the target row of row i
 * is read with a stride, i_c = lrintf(idx[stride_idx * i]), so idx is
 * presumably an nrow_a x 1 column vector — TODO confirm:
 *   c[i_c][j] += c[i_c][j] * (-beta * alpha) + a[i][j] * alpha
 * Target rows are not range-checked (the check below is commented out);
 * nrow_c is used only by that check. Launch: 2D grid, x covers ncol_a (j),
 * y covers nrow_a (i). Strides are in elements. Only compiled when
 * __NERV_FUTURE_CUDA_7 is defined.
 * NOTE(review): the non-atomic read of c[...] in the addend can race with
 * other threads targeting the same row, and the int offsets can overflow
 * for very large matrices. */
__global__ void cudak_(update_select_rows_by_colidx)(MATRIX_ELEM *c, const MATRIX_ELEM *a, const MATRIX_ELEM *idx,
                            int nrow_a, int ncol_a, int nrow_c, int stride_c, int stride_a, int stride_idx, double alpha, double beta) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow_a || j >= ncol_a) return;
    int i_c = lrintf(idx[stride_idx * i]);
    /*
    if (i_c < 0 || i_c >= nrow_c) {
        printf("ERROR inside kernel update_select_rows, i_c(%d) out of range!", i_c);
    }
    */
    // Different i may share the same i_c (repeated entries in idx), so the write
    // goes through atomicAdd_nvidia instead of a plain store. The non-atomic
    // equivalent would be:
    //c[j + i_c * stride_c] = c[j + i_c * stride_c] * (1 - beta * alpha) + a[j + i * stride_a] * alpha;
    atomicAdd_nvidia(c + j + i_c * stride_c, c[j + i_c * stride_c] * (- beta * alpha) + a[j + i * stride_a] * alpha);
}
#endif

/* Context (frame) expansion: row i of b is the concatenation of the
 * ncol-wide rows a[i - context], ..., a[i + context]. Column j of b reads
 * column j % ncol of source row i + j / ncol - context. Source rows are
 * clamped to [0, nrow - 1], which repeats the first and last frames at
 * the edges. b is enrow x encol (presumably encol == ncol * (2*context+1) —
 * TODO confirm at the call site). Launch: 2D grid, x covers encol, y covers
 * enrow. stride (a) and estride (b) are in elements. */
__global__ void cudak_(expand_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                    int nrow, int ncol,
                                    int enrow, int encol,
                                    int stride, int estride,
                                    int context) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    int ridx;
    if (i >= enrow || j >= encol) return;
    ridx = i + j / ncol - context;
    if (ridx < 0) ridx = 0;
    else if (ridx >= nrow) ridx = nrow - 1;
    /* widen before multiplying: row * stride in int overflows past INT_MAX */
    b[j + (long)i * estride] = a[j % ncol + (long)ridx * stride];
}

/* Per-row column permutation:
 *   b[i][j] = a[i][j / step + (j % step) * orig_dim]
 * With ncol == step * orig_dim (presumably; TODO confirm), this turns a row
 * stored as `step` consecutive blocks of orig_dim values into the
 * interleaved layout (dimension-major, step-minor). a and b share `stride`
 * (elements). a and b must not alias: other threads read a while b is
 * written. Launch: 2D grid, x covers ncol, y covers nrow. */
__global__ void cudak_(rearrange_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                    int nrow, int ncol,
                                    int stride, int step, int orig_dim) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long row;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    row = (long)i * stride;
    b[j + row] = a[j / step + (j % step) * orig_dim + row];
}

/* Row mask: for every row i whose mask entry a[i] (column 0 of a, row step
 * astride) equals 0, set all ncol entries of row i of b to val. Rows with a
 * non-zero mask are left unchanged. Launch: 2D grid, x covers ncol, y covers
 * nrow. Strides are in elements. */
__global__ void cudak_(set_values_by_mask)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                        int nrow, int ncol,
                                        int astride, int bstride, double val) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    /* widen before multiplying: i * stride in int overflows past INT_MAX;
     * the bounds test short-circuits before a is read */
    if (i >= nrow || j >= ncol || a[(long)i * astride] != 0.0) return;
    b[j + (long)i * bstride] = val;
}

/* Scale each row of b by a per-row factor: b[i][j] *= a[i], where a[i] is
 * column 0 of a (row step astride). Launch: 2D grid, x covers ncol, y covers
 * nrow. Strides are in elements. */
__global__ void cudak_(scale_rows_by_col)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                        int nrow, int ncol,
                                        int astride, int bstride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    b[j + (long)i * bstride] *= a[(long)i * astride];
}

/* Scale each column of b by a per-column factor: b[i][j] *= a[j], where a is
 * read as a contiguous row vector. `stride` (elements) applies to b only.
 * Launch: 2D grid, x covers ncol, y covers nrow. */
__global__ void cudak_(scale_rows_by_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                        int nrow, int ncol,
                                        int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    b[j + (long)i * stride] *= a[j];
}

/* Index-to-one-hot: for every entry a[i][j], set b[i][lrintf(a[i][j])] = 1.0.
 * Other entries of b are not cleared here. The rounded index is not
 * range-checked, so it must lie within b's row. Launch: 2D grid, x covers
 * ncol (of a), y covers nrow. Strides are in elements. */
__global__ void cudak_(decompress)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                    int nrow, int ncol,
                                    int stride_a, int stride_b) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    b[lrintf(a[j + (long)i * stride_a]) + (long)i * stride_b] = 1.0;
}

/* Fill b with its own column indices: b[i][j] = j. Launch: 2D grid, x covers
 * ncol, y covers nrow; `stride` is in elements. */
__global__ void cudak_(gen_col_idx)(MATRIX_ELEM *b,
                                    int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    b[j + (long)i * stride] = j;
}

/* Row gather: b[i] = a[lrintf(idx[i])] for i in [0, nrow). idx is read
 * contiguously (idx[i]), so it is presumably a 1 x nrow row vector — TODO
 * confirm. a and b share `stride` (elements).
 * Launch: 2D grid, x covers ncol, y covers nrow.
 * NOTE: the source row is not range-checked against a_nrow. In most cases
 * the caller guarantees idx is in range, and a per-element check (which
 * would be the only use of a_nrow) adds overhead. */
__global__ void cudak_(copy_rows_by_idx)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                    const MATRIX_ELEM *idx,
                                    int nrow, int ncol, int a_nrow, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    /* lrintf returns long, so the source offset is already 64-bit on LP64;
     * widen the destination offset too (i * stride in int overflows) */
    b[j + (long)i * stride] = a[j + lrintf(idx[i]) * stride];
}

/* Row gather with a strided index vector: b[i] = a[lrintf(idx[i * idx_stride])].
 * idx is presumably an nrow x 1 column vector — TODO confirm. a and b share
 * `stride` (elements). The source row is not range-checked against a_nrow.
 * Launch: 2D grid, x covers ncol, y covers nrow. */
__global__ void cudak_(copy_rows_by_colidx)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                    const MATRIX_ELEM *idx,
                                    int nrow, int ncol, int a_nrow, int stride, int idx_stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    /* keep the row index in a long (lrintf's return type) and widen before
     * multiplying: k * stride / i * stride in int overflow past INT_MAX */
    long k = lrintf(idx[(long)i * idx_stride]);
    b[j + (long)i * stride] = a[j + k * stride];
}

/* One step of a row-wise prefix sum:
 *   b[i][j] = a[i][j] + a[i][j - offset]   if j >= offset
 *   b[i][j] = a[i][j]                      otherwise
 * Presumably the caller runs it repeatedly with offset = 1, 2, 4, ...,
 * ping-ponging between a and b, to build an inclusive scan — TODO confirm at
 * the call site. a and b must not alias within a step. Launch: 2D grid,
 * x covers ncol, y covers nrow. Strides are in elements. */
__global__ void cudak_(prefixsum_row_reduce)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                        int nrow, int ncol, int stride_a, int stride_b, int offset) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx_a, idx_b;
    if (i >= nrow || j >= ncol) return;
    /* widen before multiplying: i * stride in int overflows past INT_MAX */
    idx_b = j + (long)i * stride_b;
    idx_a = j + (long)i * stride_a;
    if (j >= offset) 
        b[idx_b] = a[idx_a] + a[idx_a - offset];
    else
        b[idx_b] = a[idx_a];
}

extern "C" {
#include "../cukernel.h"
    void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b) {
        /* b = log(max(a, FLT_MIN)) element-wise over b's nrow x ncol extent,
         * one thread per element in CUDA_THREADS_N^2 tiles. a is addressed
         * with b's stride. Blocks the host on the default stream. */
        dim3 block(CUDA_THREADS_N, CUDA_THREADS_N);
        dim3 grid(CEIL_DIV(b->ncol, block.x), CEIL_DIV(b->nrow, block.y));
        const int ld = b->stride / sizeof(MATRIX_ELEM);
        cudak_(log_elem)<<<grid, block>>>(MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
                                          b->nrow, b->ncol, ld);
        cudaStreamSynchronize(0);
    }

    void cudak_(cuda_mul_elem)(const Matrix *a, const Matrix *b,
                                Matrix *c) {
        /* c = a .* b element-wise over b's nrow x ncol extent.
         * NOTE(review): only b's stride reaches the kernel, so a and c are
         * assumed to share b's layout — confirm at call sites. */
        dim3 block(CUDA_THREADS_N, CUDA_THREADS_N);
        dim3 grid(CEIL_DIV(b->ncol, block.x), CEIL_DIV(b->nrow, block.y));
        const int ld = b->stride / sizeof(MATRIX_ELEM);
        cudak_(mul_elem)<<<grid, block>>>(MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
                                          MATRIX_ELEM_PTR(c),
                                          b->nrow, b->ncol, ld);
        cudaStreamSynchronize(0);
    }

    void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b) {
        /* b = 1 / (1 + exp(-a)) element-wise over b's nrow x ncol extent;
         * a is addressed with b's stride. Blocks on the default stream. */
        dim3 block(CUDA_THREADS_N, CUDA_THREADS_N);
        dim3 grid(CEIL_DIV(b->ncol, block.x), CEIL_DIV(b->nrow, block.y));
        const int ld = b->stride / sizeof(MATRIX_ELEM);
        cudak_(sigmoid)<<<grid, block>>>(MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
                                         b->nrow, b->ncol, ld);
        cudaStreamSynchronize(0);
    }

    void cudak_(cuda_sigmoid_grad)(const Matrix *output,
                                    const Matrix *err, Matrix *nerr) {
        /* nerr = output .* (1 - output) .* err over nerr's nrow x ncol extent;
         * output and err are addressed with nerr's stride. */
        dim3 block(CUDA_THREADS_N, CUDA_THREADS_N);
        dim3 grid(CEIL_DIV(nerr->ncol, block.x), CEIL_DIV(nerr->nrow, block.y));
        const int ld = nerr->stride / sizeof(MATRIX_ELEM);
        cudak_(sigmoid_grad)<<<grid, block>>>(MATRIX_ELEM_PTR(output),
                                              MATRIX_ELEM_PTR(err),
                                              MATRIX_ELEM_PTR(nerr),
                                              nerr->nrow, nerr->ncol, ld);
        cudaStreamSynchronize(0);
    }

    void cudak_(cuda_rand_uniform)(const Matrix *a, CuContext *context) {
        /* Fill a's whole allocation (nrow rows of `stride` bytes, padding
         * included) with uniform random numbers from context->curand_gen.
         * The float or double generator is chosen by MATRIX_USE_FLOAT /
         * MATRIX_USE_DOUBLE. The curand status is discarded and no stream
         * sync is done here. */
        size_t count = a->nrow * a->stride / sizeof(MATRIX_ELEM);
        #ifdef MATRIX_USE_FLOAT
        curandGenerateUniform(context->curand_gen, MATRIX_ELEM_PTR(a), count);
        #endif
        #ifdef MATRIX_USE_DOUBLE
        curandGenerateUniformDouble(context->curand_gen, MATRIX_ELEM_PTR(a), count);
        #endif
    }

    void cudak_(cuda_thres_mask)(const Matrix *a, const Matrix *b, double thres, double low, double high) {
        /* a = (b < thres) ? low : high element-wise over a's nrow x ncol
         * extent. a is written even though it is declared const here (its
         * element pointer is used). b is addressed with a's stride. */
        dim3 block(CUDA_THREADS_N, CUDA_THREADS_N);
        dim3 grid(CEIL_DIV(a->ncol, block.x), CEIL_DIV(a->nrow, block.y));
        const int ld = a->stride / sizeof(MATRIX_ELEM);
        cudak_(thres_mask)<<<grid, block>>>(MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
                                            thres, low, high,
                                            a->nrow, a->ncol, ld);
        cudaStreamSynchronize(0);
    }

    void cudak_(cuda_tanh)(const Matrix *a, Matrix *b) {
        /* b = tanh(a) element-wise over b's nrow x ncol extent; a is
         * addressed with b's stride. Blocks on the default stream. */
        dim3 block(CUDA_THREADS_N, CUDA_THREADS_N);
        dim3 grid(CEIL_DIV(b->ncol, block.x), CEIL_DIV(b->nrow, block.y));
        const int ld = b->stride / sizeof(MATRIX_ELEM);
        cudak_(tanh)<<<grid, block>>>(MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
                                      b->nrow, b->ncol, ld);
        cudaStreamSynchronize(0);
    }

    void cudak_(cuda_tanh_grad)(const Matrix *output,
                                    const Matrix *err, Matrix *nerr) {
        /* nerr = (1 - output^2) .* err over nerr's nrow x ncol extent;
         * output and err are addressed with nerr's stride. */
        dim3 block(CUDA_THREADS_N, CUDA_THREADS_N);
        dim3 grid(CEIL_DIV(nerr->ncol, block.x), CEIL_DIV(nerr->nrow, block.y));
        const int ld = nerr->stride / sizeof(MATRIX_ELEM);
        cudak_(tanh_grad)<<<grid, block>>>(MATRIX_ELEM_PTR(output),
                                           MATRIX_ELEM_PTR(err),
                                           MATRIX_ELEM_PTR(nerr),
                                           nerr->nrow, nerr->ncol, ld);
        cudaStreamSynchronize(0);
    }

    void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b) {
        /* b[i] = sum_j a[i][j], as a two-pass reduction.
         * Pass 1 sums every CUDA_THREADS_NN-wide slice of each row into a
         * pitched scratch matrix (one column per slice). Pass 2 sums those
         * partials with a single block per row, so a->ncol may not exceed
         * CUDA_THREADS_NN^2 (asserted after pass 1). a->nrow becomes grid.y,
         * so it is bounded by the device's grid.y limit. */
        const int width = a->ncol;
        dim3 block(CUDA_THREADS_NN, 1);
        const int parts = CEIL_DIV(width, block.x);
        dim3 grid(parts, a->nrow);
        const size_t smem = block.x * sizeof(MATRIX_ELEM);
        MATRIX_ELEM *partial;
        size_t pitch;
        cudaMallocPitch(&partial, &pitch, parts * sizeof(MATRIX_ELEM), a->nrow);
        cudak_(block_reduce_rowsum)<<<grid, block, smem>>>
            (MATRIX_ELEM_PTR(a), partial,
             a->stride / sizeof(MATRIX_ELEM), pitch / sizeof(MATRIX_ELEM),
             width);
        assert((unsigned long)parts <= block.x);
        grid.x = 1;
        cudaStreamSynchronize(0);
        cudak_(block_reduce_rowsum)<<<grid, block, smem>>>
            (partial, MATRIX_ELEM_PTR(b),
             pitch / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
             parts);
        cudaStreamSynchronize(0);
        cudaFree(partial);
    }

    void cudak_(cuda_colsame)(const Matrix *a, const Matrix *ref, Matrix *b) {
        /* b[j] = number of rows i with a[i][j] == ref[i][j], in two passes.
         * Pass 1 (block_reduce_colsame) counts matches in CUDA_THREADS_NN-row
         * slices of each column into a pitched scratch matrix. Pass 2
         * (block_reduce_colsum) adds the partial counts with one block per
         * column, so a->nrow may not exceed CUDA_THREADS_NN^2 (asserted
         * after pass 1). ref is read with a's stride. */
        const int height = a->nrow;
        dim3 block(1, CUDA_THREADS_NN);
        const int parts = CEIL_DIV(height, block.y);
        dim3 grid(a->ncol, parts);
        const size_t smem = block.y * sizeof(MATRIX_ELEM);
        MATRIX_ELEM *partial;
        size_t pitch;
        cudaMallocPitch(&partial, &pitch, a->ncol * sizeof(MATRIX_ELEM), parts);
        cudak_(block_reduce_colsame)<<<grid, block, smem>>>
            (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(ref), partial,
             a->stride / sizeof(MATRIX_ELEM), pitch / sizeof(MATRIX_ELEM),
             height);
        assert((unsigned long)parts <= block.y);
        grid.y = 1;
        cudaStreamSynchronize(0);
        cudak_(block_reduce_colsum)<<<grid, block, smem>>>
            (partial, MATRIX_ELEM_PTR(b),
             pitch / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
             parts);
        cudaStreamSynchronize(0);
        cudaFree(partial);
    }

    void cudak_(cuda_colsum)(const Matrix *a, Matrix *b) {
        /* b[j] = sum_i a[i][j], as a two-pass reduction.
         * Pass 1 sums CUDA_THREADS_NN-row slices of each column into a
         * pitched scratch matrix (one row per slice). Pass 2 sums those
         * partials with one block per column, so a->nrow may not exceed
         * CUDA_THREADS_NN^2 (asserted after pass 1). */
        const int height = a->nrow;
        dim3 block(1, CUDA_THREADS_NN);
        const int parts = CEIL_DIV(height, block.y);
        dim3 grid(a->ncol, parts);
        const size_t smem = block.y * sizeof(MATRIX_ELEM);
        MATRIX_ELEM *partial;
        size_t pitch;
        cudaMallocPitch(&partial, &pitch, a->ncol * sizeof(MATRIX_ELEM), parts);
        cudak_(block_reduce_colsum)<<<grid, block, smem>>>
            (MATRIX_ELEM_PTR(a), partial,
             a->stride / sizeof(MATRIX_ELEM), pitch / sizeof(MATRIX_ELEM),
             height);
        assert((unsigned long)parts <= block.y);
        grid.y = 1;
        cudaStreamSynchronize(0);
        cudak_(block_reduce_colsum)<<<grid, block, smem>>>
            (partial, MATRIX_ELEM_PTR(b),
             pitch / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
             parts);
        cudaStreamSynchronize(0);
        cudaFree(partial);
    }

    void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max,
                            const Matrix *deno, Matrix *b) {
        /* b[i][j] = exp(a[i][j] - max[i]) / deno[i] over b's nrow x ncol
         * extent. a is addressed with b's stride; max and deno are per-row
         * column vectors addressed with max's stride. */
        dim3 block(CUDA_THREADS_N, CUDA_THREADS_N);
        dim3 grid(CEIL_DIV(b->ncol, block.x), CEIL_DIV(b->nrow, block.y));
        const int ld = b->stride / sizeof(MATRIX_ELEM);
        const int ld_max = max->stride / sizeof(MATRIX_ELEM);
        cudak_(softmax_final)<<<grid, block>>>(MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
                                               MATRIX_ELEM_PTR(max), MATRIX_ELEM_PTR(deno),
                                               b->nrow, b->ncol, ld, ld_max);
        cudaStreamSynchronize(0);
    }

    void cudak_(cuda_softmax_denominator)(const Matrix *a, const Matrix *max, Matrix *b) {
        /* b[i] = sum_j exp(a[i][j] - max[i]), where max is a column vector
         * (asserted: max->ncol == 1). Pass 1 (block_reduce_softmax_rowsum)
         * produces per-slice partial sums in a pitched scratch matrix. Pass 2
         * (block_reduce_rowsum) adds them with one block per row, so a->ncol
         * may not exceed CUDA_THREADS_NN^2 (asserted after pass 1). */
        const int width = a->ncol;
        dim3 block(CUDA_THREADS_NN, 1);
        const int parts = CEIL_DIV(width, block.x);
        dim3 grid(parts, a->nrow);
        const size_t smem = block.x * sizeof(MATRIX_ELEM);
        MATRIX_ELEM *partial;
        size_t pitch;
        assert(max->ncol == 1);
        cudaMallocPitch(&partial, &pitch, parts * sizeof(MATRIX_ELEM), a->nrow);
        cudak_(block_reduce_softmax_rowsum)<<<grid, block, smem>>>
            (MATRIX_ELEM_PTR(a), partial, MATRIX_ELEM_PTR(max),
             a->stride / sizeof(MATRIX_ELEM), pitch / sizeof(MATRIX_ELEM),
             max->stride / sizeof(MATRIX_ELEM),
             width);
        assert((unsigned long)parts <= block.x);
        grid.x = 1;
        cudaStreamSynchronize(0);
        cudak_(block_reduce_rowsum)<<<grid, block, smem>>>
            (partial, MATRIX_ELEM_PTR(b),
             pitch / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
             parts);
        cudaStreamSynchronize(0);
        cudaFree(partial);
    }

    /* Per-row maximum of a, written to b.
     *
     * Two launches of block_reduce_rowmax (kernel defined elsewhere), each
     * with a CUDA_THREADS_NN x 1 thread block:
     *   pass 1: grid (chunks, a->nrow) over a, writing a temporary pitched
     *           buffer of a->nrow rows by `chunks` columns;
     *   pass 2: grid (1, a->nrow) over that buffer, writing b.
     * Pass 2 uses a single block per row, so `chunks` must not exceed the
     * block width (asserted). Dynamic shared memory is one MATRIX_ELEM per
     * thread. Returns after both passes have completed. */
    void cudak_(cuda_rowmax)(const Matrix *a, Matrix *b) {
        const int cols = a->ncol;
        dim3 threads(CUDA_THREADS_NN, 1);
        const int chunks = CEIL_DIV(cols, threads.x);
        const size_t shmem = threads.x * sizeof(MATRIX_ELEM);
        dim3 first_grid(chunks, a->nrow);
        dim3 final_grid(1, a->nrow);
        MATRIX_ELEM *partial;
        size_t partial_pitch;

        cudaMallocPitch(&partial, &partial_pitch,
                        chunks * sizeof(MATRIX_ELEM), a->nrow);
        /* pass 1: a -> per-chunk maxima */
        cudak_(block_reduce_rowmax)<<<first_grid, threads, shmem>>>(
            MATRIX_ELEM_PTR(a), partial,
            a->stride / sizeof(MATRIX_ELEM), partial_pitch / sizeof(MATRIX_ELEM),
            cols);
        /* pass 2 must cover every chunk with a single block */
        assert((unsigned long)chunks <= threads.x);
        cudaStreamSynchronize(0);
        /* pass 2: per-chunk maxima -> b */
        cudak_(block_reduce_rowmax)<<<final_grid, threads, shmem>>>(
            partial, MATRIX_ELEM_PTR(b),
            partial_pitch / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
            chunks);
        cudaStreamSynchronize(0);
        cudaFree(partial);
    }

    /* Per-row maximum of a together with the matching column index.
     *
     * A helper matrix a_idx is filled by gen_col_idx (kernel defined
     * elsewhere; presumably a_idx[i][j] = j). Then block_reduce_rowmax_idx
     * reduces (value, index) pairs in two passes with a CUDA_THREADS_NN x 1
     * thread block and 2 * CUDA_THREADS_NN MATRIX_ELEMs of dynamic shared
     * memory:
     *   pass 1: grid (blocks_per_row, a->nrow), (a, a_idx) -> (res, res_idx);
     *   pass 2: grid (1, a->nrow), (res, res_idx) -> (b, b_idx).
     * Pass 2 uses one block per row, so blocks_per_row must not exceed the
     * block width (asserted).
     *
     * Each kernel receives ONE input stride and ONE output stride for a pair
     * of buffers, so every allocation's pitch is tracked separately and the
     * required equalities are asserted:
     *   - pass 1 reads a_idx with a->stride, so a_idx's pitch must equal it;
     *   - res and res_idx share one stride in both passes.
     * NOTE(review): pass 2 likewise writes b and b_idx with b->stride only;
     * b_idx is assumed to have the same stride as b -- confirm at callers. */
    void cudak_(cuda_rowmax_idx)(const Matrix *a, Matrix *b, Matrix *b_idx) {
        dim3 block(CUDA_THREADS_NN, 1);
        int ncol = a->ncol;
        int blocks_per_row = CEIL_DIV(ncol, block.x);
        dim3 grid(blocks_per_row, a->nrow);
        MATRIX_ELEM *a_idx, *res, *res_idx;
        size_t idx_stride, res_stride, res_idx_stride;
        cudaMallocPitch(&a_idx, &idx_stride, a->stride, a->nrow);
        /* pass 1 indexes a_idx with a->stride, so the pitches must agree */
        assert(idx_stride == (size_t)a->stride);
        cudak_(gen_col_idx)<<<grid, block>>>(a_idx, a->nrow, ncol,
                                              idx_stride / sizeof(MATRIX_ELEM));
        cudaMallocPitch(&res, &res_stride,
                        blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
        cudaMallocPitch(&res_idx, &res_idx_stride,
                        blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
        /* both passes address res and res_idx through a single stride */
        assert(res_stride == res_idx_stride);
        cudaStreamSynchronize(0);
        cudak_(block_reduce_rowmax_idx)<<<grid, block,
                                        2 * block.x * sizeof(MATRIX_ELEM)>>> \
            (MATRIX_ELEM_PTR(a), a_idx, res, res_idx,
             a->stride / sizeof(MATRIX_ELEM), res_stride / sizeof(MATRIX_ELEM),
             ncol);
        ncol = blocks_per_row;
        assert((unsigned long)ncol <= block.x);
        grid.x = 1;
        cudaStreamSynchronize(0);
        cudak_(block_reduce_rowmax_idx)<<<grid, block,
                                        2 * block.x * sizeof(MATRIX_ELEM)>>> \
            (res, res_idx, MATRIX_ELEM_PTR(b), MATRIX_ELEM_PTR(b_idx),
             res_stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
             ncol);
        cudaStreamSynchronize(0);
        cudaFree(a_idx);
        cudaFree(res);
        cudaFree(res_idx);
    }

    /* in-place calc */
    /* Launches add_row (kernel defined elsewhere) over a 2-D grid covering b,
     * CUDA_THREADS_N x CUDA_THREADS_N threads per block, modifying b in place.
     * Presumably adds beta * a to every row of b. Only b's stride is passed,
     * which is consistent with a being a single row vector -- confirm against
     * the kernel. Blocks until the kernel finishes. */
    void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(b->ncol, threads.x),
                          CEIL_DIV(b->nrow, threads.y));
        const size_t b_step = b->stride / sizeof(MATRIX_ELEM);

        cudak_(add_row)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
            b_step, beta);
        cudaStreamSynchronize(0);
    }

    /* Launches the fill kernel (defined elsewhere) over a 2-D grid covering
     * a, CUDA_THREADS_N x CUDA_THREADS_N threads per block. Presumably sets
     * every element of a to val, in place. Blocks until the kernel finishes. */
    void cudak_(cuda_fill)(Matrix *a, double val) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(a->ncol, threads.x),
                          CEIL_DIV(a->nrow, threads.y));
        const size_t a_step = a->stride / sizeof(MATRIX_ELEM);

        cudak_(fill)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), a->nrow, a->ncol, a_step, val);
        cudaStreamSynchronize(0);
    }

    /* Launches the diagonalize kernel (defined elsewhere; its exact effect
     * lives there) on a, in place, over a 2-D grid covering a with
     * CUDA_THREADS_N x CUDA_THREADS_N threads per block. Blocks until the
     * kernel finishes. */
    void cudak_(cuda_diagonalize)(Matrix *a) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(a->ncol, threads.x),
                          CEIL_DIV(a->nrow, threads.y));
        const size_t a_step = a->stride / sizeof(MATRIX_ELEM);

        cudak_(diagonalize)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), a->nrow, a->ncol, a_step);
        cudaStreamSynchronize(0);
    }

    /* Launches the clip kernel (defined elsewhere) on a, in place, over a 2-D
     * grid covering a with CUDA_THREADS_N x CUDA_THREADS_N threads per block.
     * Presumably clamps each element into the range given by val_1 / val_2 --
     * confirm the bound order against the kernel. Blocks until it finishes. */
    void cudak_(cuda_clip)(Matrix *a, double val_1, double val_2) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(a->ncol, threads.x),
                          CEIL_DIV(a->nrow, threads.y));
        const size_t a_step = a->stride / sizeof(MATRIX_ELEM);

        cudak_(clip)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), a->nrow, a->ncol, a_step, val_1, val_2);
        cudaStreamSynchronize(0);
    }
    
#ifdef __NERV_FUTURE_CUDA_7
    /* Launches update_select_rows_by_rowidx (kernel defined elsewhere) over a
     * 2-D grid covering a, CUDA_THREADS_N x CUDA_THREADS_N threads per block.
     * Presumably updates the rows of c selected by idx from a, weighted by
     * alpha/beta. Only c's and a's strides are passed, so idx is presumably a
     * single row of indices -- confirm against the kernel. Blocks until the
     * kernel finishes. */
    void cudak_(cuda_update_select_rows_by_rowidx)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(a->ncol, threads.x),
                          CEIL_DIV(a->nrow, threads.y));
        const size_t c_step = c->stride / sizeof(MATRIX_ELEM);
        const size_t a_step = a->stride / sizeof(MATRIX_ELEM);

        cudak_(update_select_rows_by_rowidx)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(c), MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(idx),
            a->nrow, a->ncol, c->nrow, c_step,
            a_step, alpha, beta);
        cudaStreamSynchronize(0);
    }
    /* Column-index variant of update_select_rows: same launch geometry as the
     * row-index version (2-D grid covering a, CUDA_THREADS_N x CUDA_THREADS_N
     * threads per block), but idx's stride is passed too, so idx may be a
     * column of indices. Kernel semantics are defined elsewhere. Blocks until
     * the kernel finishes. */
    void cudak_(cuda_update_select_rows_by_colidx)(Matrix *c, const Matrix *a, const Matrix *idx, double alpha, double beta) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(a->ncol, threads.x),
                          CEIL_DIV(a->nrow, threads.y));
        const size_t c_step = c->stride / sizeof(MATRIX_ELEM);
        const size_t a_step = a->stride / sizeof(MATRIX_ELEM);
        const size_t idx_step = idx->stride / sizeof(MATRIX_ELEM);

        cudak_(update_select_rows_by_colidx)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(c), MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(idx),
            a->nrow, a->ncol, c->nrow, c_step,
            a_step, idx_step, alpha, beta);
        cudaStreamSynchronize(0);
    }
#endif

    /* Launches expand_frm (kernel defined elsewhere) over a 2-D grid covering
     * the output b, CUDA_THREADS_N x CUDA_THREADS_N threads per block. The
     * kernel receives both shapes, both strides (in elements) and `context`;
     * presumably it builds context-expanded frames of a into b. Blocks until
     * the kernel finishes. */
    void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(b->ncol, threads.x),
                          CEIL_DIV(b->nrow, threads.y));
        const size_t a_step = a->stride / sizeof(MATRIX_ELEM);
        const size_t b_step = b->stride / sizeof(MATRIX_ELEM);

        cudak_(expand_frm)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
            a->nrow, a->ncol,
            b->nrow, b->ncol,
            a_step, b_step,
            context);
        cudaStreamSynchronize(0);
    }

    /* Launches rearrange_frm (kernel defined elsewhere) over a 2-D grid
     * covering b, CUDA_THREADS_N x CUDA_THREADS_N threads per block. Besides
     * `step` it passes b->ncol / step (integer division). Only b's stride is
     * passed, so a is assumed to share b's layout -- confirm at callers.
     * Blocks until the kernel finishes. */
    void cudak_(cuda_rearrange_frm)(const Matrix *a, Matrix *b, int step) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(b->ncol, threads.x),
                          CEIL_DIV(b->nrow, threads.y));
        const size_t b_step = b->stride / sizeof(MATRIX_ELEM);

        cudak_(rearrange_frm)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
            b->nrow, b->ncol, b_step,
            step, b->ncol / step);
        cudaStreamSynchronize(0);
    }

    /* Launches scale_rows_by_col (kernel defined elsewhere) over a 2-D grid
     * covering b, CUDA_THREADS_N x CUDA_THREADS_N threads per block, passing
     * b's shape and both strides (in elements). Presumably scales each row of
     * b by the corresponding entry of the column a. Blocks until the kernel
     * finishes. */
    void cudak_(cuda_scale_rows_by_col)(const Matrix *a, Matrix *b) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(b->ncol, threads.x),
                          CEIL_DIV(b->nrow, threads.y));
        const size_t a_step = a->stride / sizeof(MATRIX_ELEM);
        const size_t b_step = b->stride / sizeof(MATRIX_ELEM);

        cudak_(scale_rows_by_col)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
            b->nrow, b->ncol,
            a_step, b_step);
        cudaStreamSynchronize(0);
    }

    /* Launches set_values_by_mask (kernel defined elsewhere) over a 2-D grid
     * covering b, CUDA_THREADS_N x CUDA_THREADS_N threads per block, passing
     * b's shape, both strides (in elements) and val. Presumably writes val
     * into the elements of b selected by the mask a. Blocks until the kernel
     * finishes. */
    void cudak_(cuda_set_values_by_mask)(const Matrix *a, Matrix *b, double val) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(b->ncol, threads.x),
                          CEIL_DIV(b->nrow, threads.y));
        const size_t mask_step = a->stride / sizeof(MATRIX_ELEM);
        const size_t b_step = b->stride / sizeof(MATRIX_ELEM);

        cudak_(set_values_by_mask)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
            b->nrow, b->ncol,
            mask_step, b_step,
            val);
        cudaStreamSynchronize(0);
    }

    /* Launches scale_rows_by_row (kernel defined elsewhere) over a 2-D grid
     * covering b, CUDA_THREADS_N x CUDA_THREADS_N threads per block. Only b's
     * stride is passed, consistent with a being a single row -- confirm
     * against the kernel. Blocks until the kernel finishes. */
    void cudak_(cuda_scale_rows_by_row)(const Matrix *a, Matrix *b) {
        const dim3 threads(CUDA_THREADS_N, CUDA_THREADS_N);
        const dim3 blocks(CEIL_DIV(b->ncol, threads.x),
                          CEIL_DIV(b->nrow, threads.y));
        const size_t b_step = b->stride / sizeof(MATRIX_ELEM);

        cudak_(scale_rows_by_row)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
            b->nrow, b->ncol, b_step);
        cudaStreamSynchronize(0);
    }

    /* Row-wise prefix sum of a, written to b. The per-pass arithmetic lives in
     * prefixsum_row_reduce (defined elsewhere); whether the scan is inclusive
     * is determined there.
     *
     * Runs prefixsum_row_reduce repeatedly with offsets 1, 2, 4, ...:
     *   - the first pass reads a and writes tmp[0] with offset 1;
     *   - middle passes ping-pong between tmp[0] and tmp[1] while
     *     offset <= a->ncol / 2;
     *   - the last pass reads the most recently written scratch buffer and
     *     writes b, using the first power of two greater than a->ncol / 2.
     * Every pass uses a 2-D grid covering b with CUDA_THREADS_N x
     * CUDA_THREADS_N threads per block.
     *
     * The scratch buffers are sized from a while the passes iterate over b's
     * shape, so a and b are assumed to have the same dimensions -- TODO
     * confirm at call sites.
     *
     * The stream is synchronized before the scratch buffers are released, the
     * same order the other wrappers in this file use, rather than relying on
     * cudaFree's implicit synchronization while passes may still be running. */
    void cudak_(cuda_prefixsum_row)(const Matrix *a, Matrix *b) {
        dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
        dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                CEIL_DIV(b->nrow, threadsPerBlock.y));

        /* two scratch buffers, alternately used as pass source/destination */
        MATRIX_ELEM *tmp[2];
        size_t tmp_stride[2];
        cudaMallocPitch(tmp, tmp_stride + 0, a->ncol * sizeof(MATRIX_ELEM), a->nrow);
        cudaMallocPitch(tmp + 1, tmp_stride + 1, a->ncol * sizeof(MATRIX_ELEM), a->nrow);

        /* first pass: a -> tmp[0], offset 1 */
        int offset = 1;
        cudak_(prefixsum_row_reduce)<<<numBlocks, threadsPerBlock>>> \
            (MATRIX_ELEM_PTR(a), tmp[0], b->nrow, b->ncol,
            a->stride / sizeof(MATRIX_ELEM), tmp_stride[0] / sizeof(MATRIX_ELEM), offset);
        int pin = 0, pout = 1;

        /* middle passes: doubling offsets, swapping source and destination */
        for (offset = 2; offset <= a->ncol / 2; offset *= 2) {
            cudak_(prefixsum_row_reduce)<<<numBlocks, threadsPerBlock>>> \
                (tmp[pin], tmp[pout], b->nrow, b->ncol,
                tmp_stride[pin] / sizeof(MATRIX_ELEM), tmp_stride[pout] / sizeof(MATRIX_ELEM), offset);
            pin = 1 - pin;
            pout = 1 - pout;
        }

        /* last pass: most recent scratch buffer -> b */
        cudak_(prefixsum_row_reduce)<<<numBlocks, threadsPerBlock>>> \
            (tmp[pin], MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
            tmp_stride[pin] / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM), offset);

        /* wait for every pass before releasing the buffers they use */
        cudaStreamSynchronize(0);
        cudaFree(tmp[0]);
        cudaFree(tmp[1]);
    }

    /* Launches the decompress kernel (defined elsewhere) with 1 x
     * CUDA_THREADS_NN thread blocks and a grid of 1 x ceil(a->nrow /
     * CUDA_THREADS_NN) blocks, i.e. the launch is laid out along a's rows.
     * The kernel receives a's shape and both strides (in elements). Blocks
     * until the kernel finishes. */
    void cudak_(cuda_decompress)(const Matrix *a, Matrix *b) {
        const dim3 threads(1, CUDA_THREADS_NN);
        const dim3 blocks(1, CEIL_DIV(a->nrow, threads.y));
        const size_t a_step = a->stride / sizeof(MATRIX_ELEM);
        const size_t b_step = b->stride / sizeof(MATRIX_ELEM);

        cudak_(decompress)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
            a->nrow, a->ncol,
            a_step, b_step);
        cudaStreamSynchronize(0);
    }

    /* Launches copy_rows_by_idx (kernel defined elsewhere) with
     * CUDA_THREADS_NN x 1 thread blocks: grid x covers b's columns and grid y
     * has one block row per row of b. The index pointer is advanced by
     * idx_begin elements before the launch. Only b's stride is passed, so a is
     * assumed to share b's stride; a->nrow is passed, presumably for bounds
     * checking the indices -- confirm against the kernel. Blocks until the
     * kernel finishes. */
    void cudak_(cuda_copy_rows_by_idx)(const Matrix *a, Matrix *b,
                                        const Matrix *idx, int idx_begin) {
        const dim3 threads(CUDA_THREADS_NN, 1);
        const dim3 blocks(CEIL_DIV(b->ncol, threads.x), b->nrow);
        const size_t b_step = b->stride / sizeof(MATRIX_ELEM);

        cudak_(copy_rows_by_idx)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
            MATRIX_ELEM_PTR(idx) + idx_begin,
            b->nrow, b->ncol, a->nrow, b_step);
        cudaStreamSynchronize(0);
    }

    /* Column-index variant of copy_rows_by_idx: same launch geometry
     * (CUDA_THREADS_NN x 1 thread blocks; grid x over b's columns, one block
     * row per row of b), but idx's stride is also passed, so idx may be a
     * column of indices. Blocks until the kernel finishes.
     * NOTE(review): idx_begin is added to the element pointer unscaled. If the
     * kernel steps through idx by idx_stride per row, a nonzero idx_begin
     * skips idx_begin ELEMENTS rather than idx_begin rows -- confirm the
     * intended meaning against the kernel and callers. */
    void cudak_(cuda_copy_rows_by_colidx)(const Matrix *a, Matrix *b,
                                        const Matrix *idx, int idx_begin) {
        const dim3 threads(CUDA_THREADS_NN, 1);
        const dim3 blocks(CEIL_DIV(b->ncol, threads.x), b->nrow);
        const size_t b_step = b->stride / sizeof(MATRIX_ELEM);
        const size_t idx_step = idx->stride / sizeof(MATRIX_ELEM);

        cudak_(copy_rows_by_colidx)<<<blocks, threads>>>(
            MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
            MATRIX_ELEM_PTR(idx) + idx_begin,
            b->nrow, b->ncol, a->nrow, b_step, idx_step);
        cudaStreamSynchronize(0);
    }
}
#endif