#ifdef NERV_GENERIC_CUKERNEL
#include <assert.h>
#include <stdio.h>
#include "../matrix.h"
#include "cuda.h"
#include "float.h"
#include "curand.h"

#define CUDA_THREADS_N 16
#define CUDA_THREADS_NN ((CUDA_THREADS_N) * (CUDA_THREADS_N))
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

__global__ void cudak_(log_elem)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                 int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    MATRIX_ELEM tmp;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    tmp = a[idx];
    /* clamp to FLT_MIN so log() never sees zero */
    if (tmp < FLT_MIN)
        tmp = FLT_MIN;
    b[idx] = log(tmp);
}

__global__ void cudak_(thres_mask)(MATRIX_ELEM *a, MATRIX_ELEM *b,
                                   double thres, double low, double high,
                                   int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    if (b[idx] < thres)
        a[idx] = low;
    else
        a[idx] = high;
}

__global__ void cudak_(mul_elem)(const MATRIX_ELEM *a, const MATRIX_ELEM *b,
                                 MATRIX_ELEM *c,
                                 int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    c[idx] = a[idx] * b[idx];
}

__global__ void cudak_(sigmoid)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    b[idx] = 1.0 / (1.0 + exp(-a[idx]));
}

__global__ void cudak_(sigmoid_grad)(const MATRIX_ELEM *output,
                                     const MATRIX_ELEM *err, MATRIX_ELEM *nerr,
                                     int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    nerr[idx] = output[idx] * (1.0 - output[idx]) * err[idx];
}

__global__ void cudak_(tanh)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                             int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    //b[idx] = (exp(a[idx]) - exp(-a[idx])) / (exp(a[idx]) + exp(-a[idx])); //could cause nan
    b[idx] = tanh(a[idx]);
}

__global__ void cudak_(tanh_grad)(const MATRIX_ELEM *output,
                                  const MATRIX_ELEM *err, MATRIX_ELEM *nerr,
                                  int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    nerr[idx] = (1.0 - output[idx] * output[idx]) * err[idx];
}

__global__ void cudak_(relu)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                             int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    b[idx] = a[idx] > 0 ? a[idx] : 0;
}

__global__ void cudak_(relu_grad)(const MATRIX_ELEM *output,
                                  const MATRIX_ELEM *err, MATRIX_ELEM *nerr,
                                  int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    nerr[idx] = (output[idx] > 0 ? 1 : 0) * err[idx];
}
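/* Numerically stable softmax: the per-row maximum is subtracted before
 * exponentiation so that exp() cannot overflow, i.e.
 *   softmax(a)[i][j] = exp(a[i][j] - max_j a[i][j]) / sum_j exp(a[i][j] - max_j a[i][j]).
 * softmax_final performs only the final elementwise step; the per-row max
 * and denominator are produced by the block-reduce kernels further down. */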
__global__ void cudak_(softmax_final)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                      const MATRIX_ELEM *max, const MATRIX_ELEM *deno,
                                      int nrow, int ncol, int stride, int mstride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx;
    if (i >= nrow || j >= ncol) return;
    idx = j + i * stride;
    b[idx] = exp(a[idx] - max[0 + i * mstride]) / deno[0 + i * mstride];
}
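/* The block_reduce_* kernels below implement a shared-memory tree reduction:
 * each block loads up to blockDim elements into the dynamically sized shared
 * array and halves the active range in every iteration.  The halving loop
 * only covers all elements when the block dimension is a power of two, which
 * CUDA_THREADS_NN = 256 guarantees; the launch must also pass the dynamic
 * shared memory size (blockDim elements of MATRIX_ELEM, twice that for the
 * value+index variant). */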
__global__ void cudak_(block_reduce_rowsum)(const MATRIX_ELEM *input,
                                            MATRIX_ELEM *output,
                                            const int istride, const int ostride,
                                            const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
    __syncthreads();
    for (int offset = blockDim.x >> 1; offset; offset >>= 1)
    {
        if (threadIdx.x < offset)
            cudak_(arr)[threadIdx.x] += cudak_(arr)[threadIdx.x + offset];
        __syncthreads();
    }
    if (threadIdx.x == 0)
        output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
}

__global__ void cudak_(block_reduce_colsum)(const MATRIX_ELEM *input,
                                            MATRIX_ELEM *output,
                                            const int istride, const int ostride,
                                            const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    cudak_(arr)[threadIdx.y] = i < n ? input[blockIdx.x + istride * i] : 0;
    __syncthreads();
    for (int offset = blockDim.y >> 1; offset; offset >>= 1)
    {
        if (threadIdx.y < offset)
            cudak_(arr)[threadIdx.y] += cudak_(arr)[threadIdx.y + offset];
        __syncthreads();
    }
    if (threadIdx.y == 0)
        output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
}

__global__ void cudak_(block_reduce_colsame)(const MATRIX_ELEM *input,
                                             const MATRIX_ELEM *ref_input,
                                             MATRIX_ELEM *output,
                                             const int istride, const int ostride,
                                             const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    cudak_(arr)[threadIdx.y] = (i < n && input[blockIdx.x + istride * i] ==
                                ref_input[blockIdx.x + istride * i]) ? 1.0 : 0;
    __syncthreads();
    for (int offset = blockDim.y >> 1; offset; offset >>= 1)
    {
        if (threadIdx.y < offset)
            cudak_(arr)[threadIdx.y] += cudak_(arr)[threadIdx.y + offset];
        __syncthreads();
    }
    if (threadIdx.y == 0)
        output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
}

__global__ void cudak_(block_reduce_softmax_rowsum)(const MATRIX_ELEM *input,
                                                    MATRIX_ELEM *output,
                                                    const MATRIX_ELEM *max,
                                                    const int istride, const int ostride,
                                                    const int mstride, const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    cudak_(arr)[threadIdx.x] = j < n ? exp(input[j + istride * blockIdx.y] -
                                           max[0 + mstride * blockIdx.y]) : 0;
    __syncthreads();
    for (int offset = blockDim.x >> 1; offset; offset >>= 1)
    {
        if (threadIdx.x < offset)
            cudak_(arr)[threadIdx.x] += cudak_(arr)[threadIdx.x + offset];
        __syncthreads();
    }
    if (threadIdx.x == 0)
        output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
}

__global__ void cudak_(block_reduce_rowmax)(const MATRIX_ELEM *input,
                                            MATRIX_ELEM *output,
                                            const int istride, const int ostride,
                                            const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : -FLT_MAX;
    __syncthreads();
    for (int offset = blockDim.x >> 1; offset; offset >>= 1)
    {
        if (threadIdx.x < offset)
        {
            MATRIX_ELEM l = cudak_(arr)[threadIdx.x],
                        r = cudak_(arr)[threadIdx.x + offset];
            if (r > l)
                cudak_(arr)[threadIdx.x] = r;
        }
        __syncthreads();
    }
    if (threadIdx.x == 0)
        output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
}

__global__ void cudak_(block_reduce_rowmax_idx)(const MATRIX_ELEM *input,
                                                const MATRIX_ELEM *idx_input,
                                                MATRIX_ELEM *output,
                                                MATRIX_ELEM *idx_output,
                                                const int istride, const int ostride,
                                                const int n) {
    extern __shared__ MATRIX_ELEM cudak_(arr)[];
    /* shared memory holds the values followed by their column indices */
    MATRIX_ELEM *arr_val = cudak_(arr);
    MATRIX_ELEM *arr_idx = arr_val + blockDim.x;
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    arr_val[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : -FLT_MAX;
    arr_idx[threadIdx.x] = j < n ? idx_input[j + istride * blockIdx.y] : 0;
    __syncthreads();
    for (int offset = blockDim.x >> 1; offset; offset >>= 1)
    {
        if (threadIdx.x < offset)
        {
            MATRIX_ELEM l = arr_val[threadIdx.x],
                        r = arr_val[threadIdx.x + offset];
            if (r > l)
            {
                arr_val[threadIdx.x] = r;
                arr_idx[threadIdx.x] = arr_idx[threadIdx.x + offset];
            }
        }
        __syncthreads();
    }
    if (threadIdx.x == 0)
    {
        output[blockIdx.x + ostride * blockIdx.y] = arr_val[0];
        idx_output[blockIdx.x + ostride * blockIdx.y] = arr_idx[0];
    }
}

__global__ void cudak_(add_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                int nrow, int ncol, int stride, double beta) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    b[j + i * stride] += beta * a[j];
}

__global__ void cudak_(add_col)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                int nrow, int ncol,
                                int astride, int bstride, double beta) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    b[j + i * bstride] += beta * a[i * astride];
}

__global__ void cudak_(fill)(MATRIX_ELEM *a,
                             int nrow, int ncol, int stride, double val) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    a[j + i * stride] = val;
}

__global__ void cudak_(diagonalize)(MATRIX_ELEM *a,
                                    int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    /* zero everything off the diagonal */
    if (i >= nrow || j >= ncol || i == j) return;
    a[j + i * stride] = 0;
}

__global__ void cudak_(clip)(MATRIX_ELEM *a,
                             int nrow, int ncol, int stride,
                             double val_1, double val_2) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    if (a[j + i * stride] > val_2)
        a[j + i * stride] = val_2;
    else if (a[j + i * stride] < val_1)
        a[j + i * stride] = val_1;
}
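/* The update_select_rows kernels below implement, for each selected row i_c
 * of c,
 *   c[i_c] = c[i_c] * (1 - beta * alpha) + a[i] * alpha.
 * Several rows of a may map to the same i_c, so the delta
 * c[i_c] * (-beta * alpha) + a[i] * alpha is accumulated with an atomic add
 * rather than a plain store; atomicAdd_nvidia is assumed to be a wrapper
 * (defined elsewhere in the tree) around atomicAdd for MATRIX_ELEM. */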
#ifdef __NERV_FUTURE_CUDA_7
__global__ void cudak_(update_select_rows_by_rowidx)(MATRIX_ELEM *c, const MATRIX_ELEM *a,
                                                     const MATRIX_ELEM *idx,
                                                     int nrow_a, int ncol_a, int nrow_c,
                                                     int stride_c, int stride_a,
                                                     double alpha, double beta) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow_a || j >= ncol_a) return;
    int i_c = lrintf(idx[i]);
    /*
    if (i_c < 0 || i_c >= nrow_c) {
        printf("ERROR inside kernel update_select_rows, i_c(%d) out of range!", i_c);
    }
    */
    //critical: i_c could conflict among threads (same index in the idx array), so atomicAdd is used
    //c[j + i_c * stride_c] = c[j + i_c * stride_c] * (1 - beta * alpha) + a[j + i * stride_a] * alpha;
    atomicAdd_nvidia(c + j + i_c * stride_c,
                     c[j + i_c * stride_c] * (- beta * alpha) +
                     a[j + i * stride_a] * alpha);
}

__global__ void cudak_(update_select_rows_by_colidx)(MATRIX_ELEM *c, const MATRIX_ELEM *a,
                                                     const MATRIX_ELEM *idx,
                                                     int nrow_a, int ncol_a, int nrow_c,
                                                     int stride_c, int stride_a, int stride_idx,
                                                     double alpha, double beta) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow_a || j >= ncol_a) return;
    int i_c = lrintf(idx[stride_idx * i]);
    /*
    if (i_c < 0 || i_c >= nrow_c) {
        printf("ERROR inside kernel update_select_rows, i_c(%d) out of range!", i_c);
    }
    */
    //critical: i_c could conflict among threads (same index in the idx array), so atomicAdd is used
    //c[j + i_c * stride_c] = c[j + i_c * stride_c] * (1 - beta * alpha) + a[j + i * stride_a] * alpha;
    atomicAdd_nvidia(c + j + i_c * stride_c,
                     c[j + i_c * stride_c] * (- beta * alpha) +
                     a[j + i * stride_a] * alpha);
}
#endif

__global__ void cudak_(expand_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                   int nrow, int ncol,
                                   int enrow, int encol,
                                   int stride, int estride,
                                   int context) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    int ridx;
    if (i >= enrow || j >= encol) return;
    ridx = i + j / ncol - context;
    /* clamp the context window at the first and last rows */
    if (ridx < 0) ridx = 0;
    else if (ridx >= nrow) ridx = nrow - 1;
    b[j + i * estride] = a[j % ncol + ridx * stride];
}

__global__ void cudak_(rearrange_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                      int nrow, int ncol,
                                      int stride, int step, int orig_dim) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    b[j + i * stride] = a[j / step + (j % step) * orig_dim + i * stride];
}

__global__ void cudak_(set_values_by_mask)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                           int nrow, int ncol,
                                           int astride, int bstride,
                                           double val) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol || a[i * astride] != 0.0) return;
    b[j + i * bstride] = val;
}

__global__ void cudak_(scale_rows_by_col)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                          int nrow, int ncol,
                                          int astride, int bstride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    b[j + i * bstride] *= a[i * astride];
}

__global__ void cudak_(scale_rows_by_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                          int nrow, int ncol,
                                          int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    b[j + i * stride] *= a[j];
}

__global__ void cudak_(decompress)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                   int nrow, int ncol,
                                   int stride_a, int stride_b) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    /* one-hot expansion: only selected positions are written, so the output
     * is expected to be zero-filled beforehand */
    b[lrintf(a[j + i * stride_a]) + i * stride_b] = 1.0;
}
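/* Row gathering: gen_col_idx materializes a matrix of column indices so that
 * value/index pairs can be fed through the max-with-argmax reduction above,
 * and the copy_rows kernels gather rows of a through a float-encoded index
 * vector rounded with lrintf().  Indices are not bounds-checked on the fast
 * path (see the NOTE below). */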
__global__ void cudak_(gen_col_idx)(MATRIX_ELEM *b,
                                    int nrow, int ncol, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    b[j + i * stride] = j;
}

__global__ void cudak_(copy_rows_by_idx)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                         const MATRIX_ELEM *idx,
                                         int nrow, int ncol, int a_nrow, int stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    /*
    int k = lrintf(idx[i]);
    if (k < 0 || k >= a_nrow) {
        printf("error in kernel copy_rows_by_idx k(%d) out of range\n", k);
    }
    b[j + i * stride] = a[j + k * stride];
    */
    /* NOTE: in most cases it is guaranteed
     * the idx is within the range, checking
     * would bring some overhead. */
    b[j + i * stride] = a[j + lrintf(idx[i]) * stride];
}

__global__ void cudak_(copy_rows_by_colidx)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                            const MATRIX_ELEM *idx,
                                            int nrow, int ncol, int a_nrow,
                                            int stride, int idx_stride) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= nrow || j >= ncol) return;
    int k = lrintf(idx[i * idx_stride]);
    /*
    if (k < 0 || k >= a_nrow) {
        printf("error in kernel copy_rows_by_colidx k(%d) out of range\n", k);
    }
    */
    b[j + i * stride] = a[j + k * stride];
}

__global__ void cudak_(prefixsum_row_reduce)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                             int nrow, int ncol,
                                             int stride_a, int stride_b,
                                             int offset) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    long idx_a, idx_b;
    if (i >= nrow || j >= ncol) return;
    idx_b = j + i * stride_b;
    idx_a = j + i * stride_a;
    /* one scan step: add the element `offset` positions to the left */
    if (j >= offset)
        b[idx_b] = a[idx_a] + a[idx_a - offset];
    else
        b[idx_b] = a[idx_a];
}
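/* Host-side wrappers.  Each wrapper derives a 2-D launch grid from the
 * output matrix with CEIL_DIV, converts byte strides to element strides
 * (stride / sizeof(MATRIX_ELEM)), launches on the default stream, and calls
 * cudaStreamSynchronize(0) so results are visible when the wrapper returns. */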
extern "C" {
#include "../cukernel.h"
void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(log_elem)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_mul_elem)(const Matrix *a, const Matrix *b, Matrix *c) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(mul_elem)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), MATRIX_ELEM_PTR(c),
         b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(sigmoid)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_sigmoid_grad)(const Matrix *output,
                               const Matrix *err, Matrix *nerr) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(nerr->ncol, threadsPerBlock.x),
                   CEIL_DIV(nerr->nrow, threadsPerBlock.y));
    cudak_(sigmoid_grad)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(output), MATRIX_ELEM_PTR(err), MATRIX_ELEM_PTR(nerr),
         nerr->nrow, nerr->ncol, nerr->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_rand_uniform)(const Matrix *a, CuContext *context) {
#ifdef MATRIX_USE_FLOAT
    curandGenerateUniform(context->curand_gen, MATRIX_ELEM_PTR(a),
                          a->nrow * a->stride / sizeof(MATRIX_ELEM));
#endif
#ifdef MATRIX_USE_DOUBLE
    curandGenerateUniformDouble(context->curand_gen, MATRIX_ELEM_PTR(a),
                                a->nrow * a->stride / sizeof(MATRIX_ELEM));
#endif
}

void cudak_(cuda_thres_mask)(const Matrix *a, const Matrix *b,
                             double thres, double low, double high) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
                   CEIL_DIV(a->nrow, threadsPerBlock.y));
    cudak_(thres_mask)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), thres, low, high,
         a->nrow, a->ncol, a->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_tanh)(const Matrix *a, Matrix *b) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(tanh)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_tanh_grad)(const Matrix *output,
                            const Matrix *err, Matrix *nerr) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(nerr->ncol, threadsPerBlock.x),
                   CEIL_DIV(nerr->nrow, threadsPerBlock.y));
    cudak_(tanh_grad)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(output), MATRIX_ELEM_PTR(err), MATRIX_ELEM_PTR(nerr),
         nerr->nrow, nerr->ncol, nerr->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_relu)(const Matrix *a, Matrix *b) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(relu)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_relu_grad)(const Matrix *output,
                            const Matrix *err, Matrix *nerr) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(nerr->ncol, threadsPerBlock.x),
                   CEIL_DIV(nerr->nrow, threadsPerBlock.y));
    cudak_(relu_grad)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(output), MATRIX_ELEM_PTR(err), MATRIX_ELEM_PTR(nerr),
         nerr->nrow, nerr->ncol, nerr->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b) {
    dim3 block(CUDA_THREADS_NN, 1);
    int ncol = a->ncol;
    int blocks_per_row = CEIL_DIV(ncol, block.x);
    dim3 grid(blocks_per_row, a->nrow);
    MATRIX_ELEM *res;
    size_t stride;
    cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
    cudak_(block_reduce_rowsum)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>>
        (MATRIX_ELEM_PTR(a), res,
         a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
         ncol);
    ncol = blocks_per_row;
    assert((unsigned long)ncol <= block.x);
    grid.x = 1;
    cudaStreamSynchronize(0);
    cudak_(block_reduce_rowsum)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>>
        (res, MATRIX_ELEM_PTR(b),
         stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
         ncol);
    cudaStreamSynchronize(0);
    cudaFree(res);
}

void cudak_(cuda_colsame)(const Matrix *a, const Matrix *ref, Matrix *b) {
    dim3 block(1, CUDA_THREADS_NN);
    int nrow = a->nrow;
    int blocks_per_col = CEIL_DIV(nrow, block.y);
    dim3 grid(a->ncol, blocks_per_col);
    MATRIX_ELEM *res;
    size_t stride;
    cudaMallocPitch(&res, &stride, a->ncol * sizeof(MATRIX_ELEM), blocks_per_col);
    cudak_(block_reduce_colsame)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(ref), res,
         a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
         nrow);
    nrow = blocks_per_col;
    assert((unsigned long)nrow <= block.y);
    grid.y = 1;
    cudaStreamSynchronize(0);
    cudak_(block_reduce_colsum)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>>
        (res, MATRIX_ELEM_PTR(b),
         stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
         nrow);
    cudaStreamSynchronize(0);
    cudaFree(res);
}

void cudak_(cuda_colsum)(const Matrix *a, Matrix *b) {
    dim3 block(1, CUDA_THREADS_NN);
    int nrow = a->nrow;
    int blocks_per_col = CEIL_DIV(nrow, block.y);
    dim3 grid(a->ncol, blocks_per_col);
    MATRIX_ELEM *res;
    size_t stride;
    cudaMallocPitch(&res, &stride, a->ncol * sizeof(MATRIX_ELEM), blocks_per_col);
    cudak_(block_reduce_colsum)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>>
        (MATRIX_ELEM_PTR(a), res,
         a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
         nrow);
    nrow = blocks_per_col;
    assert((unsigned long)nrow <= block.y);
    grid.y = 1;
    cudaStreamSynchronize(0);
    cudak_(block_reduce_colsum)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>>
        (res, MATRIX_ELEM_PTR(b),
         stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
         nrow);
    cudaStreamSynchronize(0);
    cudaFree(res);
}
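/* The reduction drivers (cuda_rowsum/cuda_colsame/cuda_colsum above,
 * cuda_softmax_denominator/cuda_rowmax/cuda_rowmax_idx below) reduce in two
 * passes: pass one writes one partial result per block into a pitched
 * temporary, pass two folds the partials with a single block.  This only
 * works when the partials fit into one block, i.e. the reduced dimension is
 * at most CUDA_THREADS_NN * CUDA_THREADS_NN = 65536, which the asserts
 * enforce. */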
void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max,
                                const Matrix *deno, Matrix *b) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(softmax_final)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         MATRIX_ELEM_PTR(max), MATRIX_ELEM_PTR(deno),
         b->nrow, b->ncol,
         b->stride / sizeof(MATRIX_ELEM),
         max->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_softmax_denominator)(const Matrix *a, const Matrix *max, Matrix *b) {
    dim3 block(CUDA_THREADS_NN, 1);
    int ncol = a->ncol;
    int blocks_per_row = CEIL_DIV(ncol, block.x);
    dim3 grid(blocks_per_row, a->nrow);
    MATRIX_ELEM *res;
    size_t stride;
    assert(max->ncol == 1);
    cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
    cudak_(block_reduce_softmax_rowsum)
        <<<grid, block, block.x * sizeof(MATRIX_ELEM)>>>
        (MATRIX_ELEM_PTR(a), res, MATRIX_ELEM_PTR(max),
         a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
         max->stride / sizeof(MATRIX_ELEM),
         ncol);
    ncol = blocks_per_row;
    assert((unsigned long)ncol <= block.x);
    grid.x = 1;
    cudaStreamSynchronize(0);
    cudak_(block_reduce_rowsum)
        <<<grid, block, block.x * sizeof(MATRIX_ELEM)>>>
        (res, MATRIX_ELEM_PTR(b),
         stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
         ncol);
    cudaStreamSynchronize(0);
    cudaFree(res);
}

void cudak_(cuda_rowmax)(const Matrix *a, Matrix *b) {
    dim3 block(CUDA_THREADS_NN, 1);
    int ncol = a->ncol;
    int blocks_per_row = CEIL_DIV(ncol, block.x);
    dim3 grid(blocks_per_row, a->nrow);
    MATRIX_ELEM *res;
    size_t stride;
    cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
    cudak_(block_reduce_rowmax)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>>
        (MATRIX_ELEM_PTR(a), res,
         a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
         ncol);
    ncol = blocks_per_row;
    assert((unsigned long)ncol <= block.x);
    grid.x = 1;
    cudaStreamSynchronize(0);
    cudak_(block_reduce_rowmax)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>>
        (res, MATRIX_ELEM_PTR(b),
         stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
         ncol);
    cudaStreamSynchronize(0);
    cudaFree(res);
}
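/* cuda_rowmax_idx tracks the argmax alongside the max: a full matrix of
 * column indices is generated first, then value/index pairs go through the
 * same two-pass reduction, so each launch needs twice the shared memory
 * (2 * block.x * sizeof(MATRIX_ELEM)).  Note that `stride` is reused for all
 * three pitched allocations: the kernel indexes the column-index matrix with
 * a->stride, so its pitch is assumed to equal a->stride, while res and
 * res_idx share the last pitch returned. */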
void cudak_(cuda_rowmax_idx)(const Matrix *a, Matrix *b, Matrix *b_idx) {
    dim3 block(CUDA_THREADS_NN, 1);
    int ncol = a->ncol;
    int blocks_per_row = CEIL_DIV(ncol, block.x);
    dim3 grid(blocks_per_row, a->nrow);
    MATRIX_ELEM *a_idx, *res, *res_idx;
    size_t stride;
    cudaMallocPitch(&a_idx, &stride, a->stride, a->nrow);
    cudak_(gen_col_idx)<<<grid, block>>>(a_idx, a->nrow, ncol,
                                         stride / sizeof(MATRIX_ELEM));
    cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
    cudaMallocPitch(&res_idx, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
    cudaStreamSynchronize(0);
    cudak_(block_reduce_rowmax_idx)
        <<<grid, block, 2 * block.x * sizeof(MATRIX_ELEM)>>>
        (MATRIX_ELEM_PTR(a), a_idx, res, res_idx,
         a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
         ncol);
    ncol = blocks_per_row;
    assert((unsigned long)ncol <= block.x);
    grid.x = 1;
    cudaStreamSynchronize(0);
    cudak_(block_reduce_rowmax_idx)
        <<<grid, block, 2 * block.x * sizeof(MATRIX_ELEM)>>>
        (res, res_idx, MATRIX_ELEM_PTR(b), MATRIX_ELEM_PTR(b_idx),
         stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
         ncol);
    cudaStreamSynchronize(0);
    cudaFree(a_idx);
    cudaFree(res);
    cudaFree(res_idx);
}

/* in-place calc */
void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(add_row)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM), beta);
    cudaStreamSynchronize(0);
}

void cudak_(cuda_add_col)(const Matrix *a, Matrix *b, double beta) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(add_col)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         b->nrow, b->ncol,
         a->stride / sizeof(MATRIX_ELEM),
         b->stride / sizeof(MATRIX_ELEM),
         beta);
    cudaStreamSynchronize(0);
}

void cudak_(cuda_fill)(Matrix *a, double val) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
                   CEIL_DIV(a->nrow, threadsPerBlock.y));
    cudak_(fill)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a),
         a->nrow, a->ncol, a->stride / sizeof(MATRIX_ELEM), val);
    cudaStreamSynchronize(0);
}

void cudak_(cuda_diagonalize)(Matrix *a) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
                   CEIL_DIV(a->nrow, threadsPerBlock.y));
    cudak_(diagonalize)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a),
         a->nrow, a->ncol, a->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_clip)(Matrix *a, double val_1, double val_2) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
                   CEIL_DIV(a->nrow, threadsPerBlock.y));
    cudak_(clip)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a),
         a->nrow, a->ncol, a->stride / sizeof(MATRIX_ELEM),
         val_1, val_2);
    cudaStreamSynchronize(0);
}

#ifdef __NERV_FUTURE_CUDA_7
void cudak_(cuda_update_select_rows_by_rowidx)(Matrix *c, const Matrix *a,
                                               const Matrix *idx,
                                               double alpha, double beta) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
                   CEIL_DIV(a->nrow, threadsPerBlock.y));
    cudak_(update_select_rows_by_rowidx)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(c), MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(idx),
         a->nrow, a->ncol, c->nrow,
         c->stride / sizeof(MATRIX_ELEM), a->stride / sizeof(MATRIX_ELEM),
         alpha, beta);
    cudaStreamSynchronize(0);
}

void cudak_(cuda_update_select_rows_by_colidx)(Matrix *c, const Matrix *a,
                                               const Matrix *idx,
                                               double alpha, double beta) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
                   CEIL_DIV(a->nrow, threadsPerBlock.y));
    cudak_(update_select_rows_by_colidx)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(c), MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(idx),
         a->nrow, a->ncol, c->nrow,
         c->stride / sizeof(MATRIX_ELEM), a->stride / sizeof(MATRIX_ELEM),
         idx->stride / sizeof(MATRIX_ELEM),
         alpha, beta);
    cudaStreamSynchronize(0);
}
#endif

void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(expand_frm)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         a->nrow, a->ncol,
         b->nrow, b->ncol,
         a->stride / sizeof(MATRIX_ELEM),
         b->stride / sizeof(MATRIX_ELEM),
         context);
    cudaStreamSynchronize(0);
}

void cudak_(cuda_rearrange_frm)(const Matrix *a, Matrix *b, int step) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(rearrange_frm)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM),
         step, b->ncol / step);
    cudaStreamSynchronize(0);
}

void cudak_(cuda_scale_rows_by_col)(const Matrix *a, Matrix *b) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(scale_rows_by_col)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         b->nrow, b->ncol,
         a->stride / sizeof(MATRIX_ELEM),
         b->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_set_values_by_mask)(const Matrix *a, Matrix *b, double val) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(set_values_by_mask)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         b->nrow, b->ncol,
         a->stride / sizeof(MATRIX_ELEM),
         b->stride / sizeof(MATRIX_ELEM),
         val);
    cudaStreamSynchronize(0);
}

void cudak_(cuda_scale_rows_by_row)(const Matrix *a, Matrix *b) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    cudak_(scale_rows_by_row)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}
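/* cuda_prefixsum_row computes an inclusive row-wise prefix sum as a
 * Hillis-Steele scan: the pass with offset 2^k adds the element 2^k
 * positions to the left, ping-ponging between two pitched temporaries while
 * the offset doubles up to ncol / 2; the final pass (which also applies the
 * last offset) writes into b.  The whole scan takes O(log ncol) kernel
 * launches over the full matrix. */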
void cudak_(cuda_prefixsum_row)(const Matrix *a, Matrix *b) {
    dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
                   CEIL_DIV(b->nrow, threadsPerBlock.y));
    MATRIX_ELEM *tmp[2];
    size_t tmp_stride[2];
    cudaMallocPitch(tmp, tmp_stride + 0, a->ncol * sizeof(MATRIX_ELEM), a->nrow);
    cudaMallocPitch(tmp + 1, tmp_stride + 1, a->ncol * sizeof(MATRIX_ELEM), a->nrow);
    int offset = 1;
    cudak_(prefixsum_row_reduce)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), tmp[0], b->nrow, b->ncol,
         a->stride / sizeof(MATRIX_ELEM),
         tmp_stride[0] / sizeof(MATRIX_ELEM), offset);
    int pin = 0, pout = 1;
    for (offset = 2; offset <= a->ncol / 2; offset *= 2)
    {
        cudak_(prefixsum_row_reduce)<<<numBlocks, threadsPerBlock>>>
            (tmp[pin], tmp[pout], b->nrow, b->ncol,
             tmp_stride[pin] / sizeof(MATRIX_ELEM),
             tmp_stride[pout] / sizeof(MATRIX_ELEM), offset);
        pin = 1 - pin;
        pout = 1 - pout;
    }
    cudak_(prefixsum_row_reduce)<<<numBlocks, threadsPerBlock>>>
        (tmp[pin], MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
         tmp_stride[pin] / sizeof(MATRIX_ELEM),
         b->stride / sizeof(MATRIX_ELEM), offset);
    cudaFree(tmp[0]);
    cudaFree(tmp[1]);
    cudaStreamSynchronize(0);
}

void cudak_(cuda_decompress)(const Matrix *a, Matrix *b) {
    dim3 threadsPerBlock(1, CUDA_THREADS_NN);
    dim3 numBlocks(1, CEIL_DIV(a->nrow, threadsPerBlock.y));
    cudak_(decompress)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         a->nrow, a->ncol,
         a->stride / sizeof(MATRIX_ELEM),
         b->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_copy_rows_by_idx)(const Matrix *a, Matrix *b,
                                   const Matrix *idx, int idx_begin) {
    dim3 threadsPerBlock(CUDA_THREADS_NN, 1);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x), b->nrow);
    cudak_(copy_rows_by_idx)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         MATRIX_ELEM_PTR(idx) + idx_begin,
         b->nrow, b->ncol, a->nrow,
         b->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}

void cudak_(cuda_copy_rows_by_colidx)(const Matrix *a, Matrix *b,
                                      const Matrix *idx, int idx_begin) {
    dim3 threadsPerBlock(CUDA_THREADS_NN, 1);
    dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x), b->nrow);
    cudak_(copy_rows_by_colidx)<<<numBlocks, threadsPerBlock>>>
        (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
         MATRIX_ELEM_PTR(idx) + idx_begin,
         b->nrow, b->ncol, a->nrow,
         b->stride / sizeof(MATRIX_ELEM),
         idx->stride / sizeof(MATRIX_ELEM));
    cudaStreamSynchronize(0);
}
}
#endif
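/* A minimal sketch (the exact file and macro names are assumptions, following
 * NERV's usual generic-inclusion convention) of how this generic file is
 * instantiated once per element type, with cudak_() pasting the type into
 * every kernel name:
 *
 *   #define NERV_GENERIC_CUKERNEL
 *   #define cudak_(NAME) cudak_float_ ## NAME
 *   #define MATRIX_USE_FLOAT
 *   #include "generic/cukernel.cu"
 *   #undef cudak_
 *   #undef MATRIX_USE_FLOAT
 *   #define cudak_(NAME) cudak_double_ ## NAME
 *   #define MATRIX_USE_DOUBLE
 *   #include "generic/cukernel.cu"
 *
 * MATRIX_ELEM and MATRIX_ELEM_PTR are expected to be selected accordingly in
 * ../matrix.h. */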