author     Determinant <ted.sybil@gmail.com>  2015-06-05 10:58:57 +0800
committer  Determinant <ted.sybil@gmail.com>  2015-06-05 10:58:57 +0800
commit     df737041e4a9f3f55978cc74db9a9cea27fa9fa0 (patch)
tree       d656820be286550bc548f7c5ed4b1dcfecf3691c
parent     ea6f2990f99dd9ded6a0e74d75a3ec84900a2518 (diff)
add profiling; add ce accuracy; several other changes
-rw-r--r--  Makefile                      2
-rw-r--r--  common.c                     55
-rw-r--r--  common.h                     25
-rw-r--r--  examples/test_nn_lib.lua     15
-rw-r--r--  io/chunk_file.c               4
-rw-r--r--  io/sgd_buffer.lua            14
-rw-r--r--  layer/softmax_ce.lua          4
-rw-r--r--  matrix/cuda_helper.h         26
-rw-r--r--  matrix/cukernel.h             2
-rw-r--r--  matrix/cumatrix.c            34
-rw-r--r--  matrix/generic/cukernel.cu  138
-rw-r--r--  matrix/generic/cumatrix.c    97
-rw-r--r--  matrix/generic/mmatrix.c      4
-rw-r--r--  matrix/init.c                 5
-rw-r--r--  matrix/mmatrix.c             29
m---------  speech                        0
16 files changed, 416 insertions(+), 38 deletions(-)
diff --git a/Makefile b/Makefile
index f0d319f..5c6fa7b 100644
--- a/Makefile
+++ b/Makefile
@@ -16,7 +16,7 @@ CUDA_BASE := /usr/local/cuda-6.5
CUDA_INCLUDE := -I $(CUDA_BASE)/include/
INCLUDE += $(CUDA_INCLUDE)
LDFLAGS := -L$(CUDA_BASE)/lib64/ -Wl,-rpath=$(CUDA_BASE)/lib64/ -lcudart -lcublas
-CFLAGS := -Wall -Wextra
+CFLAGS := -Wall -Wextra -O2
OBJ_DIR := $(BUILD_DIR)/objs
LUA_DIR := $(BUILD_DIR)/lua
LIB_DIR := $(BUILD_DIR)/lib
diff --git a/common.c b/common.c
index c60c1ec..355d7ff 100644
--- a/common.c
+++ b/common.c
@@ -1,5 +1,3 @@
-#ifndef NERV_COMMON_H
-#define NERV_COMMON_H
#include "common.h"
#include <stdarg.h>
int nerv_error(lua_State *L, const char *err_mesg_fmt, ...) {
@@ -24,4 +22,55 @@ void luaN_append_methods(lua_State *L, const luaL_Reg *mlist) {
lua_setfield(L, -2, mlist->name);
}
}
-#endif
+
+HashMap *hashmap_create(size_t size, HashKey_t hfunc, HashMapCmp_t cmp) {
+ HashMap *res = (HashMap *)malloc(sizeof(HashMap));
+ res->bucket = calloc(size, sizeof(HashNode));
+ res->cmp = cmp;
+ res->hfunc = hfunc;
+ res->size = size;
+ return res;
+}
+
+void *hashmap_getval(HashMap *h, const char *key) {
+ size_t idx = h->hfunc(key) % h->size;
+ HashNode *ptr;
+ for (ptr = h->bucket[idx]; ptr; ptr = ptr->next)
+ {
+ if (!h->cmp(ptr->key, key))
+ return ptr->val;
+ }
+ return NULL;
+}
+
+void hashmap_setval(HashMap *h, const char *key, void *val) {
+ size_t idx = h->hfunc(key) % h->size;
+ HashNode *ptr = malloc(sizeof(HashNode));
+ ptr->next = h->bucket[idx];
+ h->bucket[idx] = ptr;
+ ptr->key = key;
+ ptr->val = val;
+}
+
+void hashmap_clear(HashMap *h) {
+ size_t i;
+ for (i = 0; i < h->size; i++)
+ {
+ HashNode *ptr, *nptr;
+ for (ptr = h->bucket[i]; ptr; ptr = nptr)
+ {
+ nptr = ptr->next;
+ free(ptr->val);
+ free(ptr);
+ }
+ h->bucket[i] = NULL;
+ }
+}
+
+size_t bkdr_hash(const char *key) {
+ unsigned int seed = 131;
+ unsigned int res = 0;
+ while (*key)
+ res = res * seed + *key++;
+ return res;
+}
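
The hash map added above is used later in this commit (matrix/cumatrix.c) to accumulate per-kernel profiling totals keyed by function name. A minimal standalone sketch of the intended usage, assuming strcmp-style comparator semantics (hashmap_getval treats a zero comparator result as a match); hashmap_demo is a hypothetical name for illustration:

    #include <string.h>
    #include "common.h"

    void hashmap_demo(void) {
        /* 123457 buckets mirrors the PROFILE_HASHMAP_SIZE used later in
           this commit; any size works */
        HashMap *h = hashmap_create(123457, bkdr_hash, strcmp);
        float *v = malloc(sizeof(float));
        *v = 0.5f;
        /* keys are stored by pointer, not copied: use literals or storage
           that outlives the map */
        hashmap_setval(h, "gemm", v);
        float *got = hashmap_getval(h, "gemm"); /* NULL when key is absent */
        if (got)
            printf("gemm: %f\n", *got);
        hashmap_clear(h); /* frees every value and node; buckets are kept */
    }

Note that hashmap_setval always prepends a new node, so re-inserting a key shadows the old entry rather than replacing it; the accumulation pattern below (accu_profile) therefore looks the value up first and only inserts on a miss.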
diff --git a/common.h b/common.h
index 51e90ee..8be19b0 100644
--- a/common.h
+++ b/common.h
@@ -1,3 +1,5 @@
+#ifndef NERV_COMMON_H
+#define NERV_COMMON_H
#include "lua.h"
#include "lauxlib.h"
#include "lualib.h"
@@ -5,6 +7,29 @@
#include <stdio.h>
#include <stdlib.h>
+typedef struct HashNode {
+ const char *key;
+ void *val;
+ struct HashNode *next;
+} HashNode;
+
+typedef int (*HashMapCmp_t)(const char *a, const char *b);
+typedef size_t (*HashKey_t)(const char *key);
+
+typedef struct HashMap {
+ HashNode **bucket;
+ HashMapCmp_t cmp;
+ HashKey_t hfunc;
+ size_t size;
+} HashMap;
+
+HashMap *hashmap_create(size_t size, HashKey_t hfunc, HashMapCmp_t cmp);
+void *hashmap_getval(HashMap *h, const char *key);
+void hashmap_setval(HashMap *h, const char *key, void *val);
+
+size_t bkdr_hash(const char *key);
+
int nerv_error(lua_State *L, const char *err_mesg_fmt, ...);
int nerv_error_method_not_implemented(lua_State *L);
void luaN_append_methods(lua_State *L, const luaL_Reg *mlist);
+#endif
diff --git a/examples/test_nn_lib.lua b/examples/test_nn_lib.lua
index 9600917..04fd7d6 100644
--- a/examples/test_nn_lib.lua
+++ b/examples/test_nn_lib.lua
@@ -116,7 +116,8 @@ tnet_reader = nerv.TNetReader(gconf,
buffer = nerv.SGDBuffer(gconf,
{
- buffer_size = 8192,
+ buffer_size = 81920,
+ -- randomize = true,
readers = {
{ reader = tnet_reader,
data = {main_scp = 429, ref = 1}}
@@ -126,10 +127,11 @@ buffer = nerv.SGDBuffer(gconf,
sm = sublayer_repo:get_layer("softmax_ce0")
main = layer_repo:get_layer("main")
main:init(gconf.batch_size)
-cnt = 0
+gconf.cnt = 0
for data in buffer.get_data, buffer do
- if cnt == 1000 then break end
- cnt = cnt + 1
+ if gconf.cnt == 1000 then break end
+ gconf.cnt = gconf.cnt + 1
+
input = {data.main_scp, data.ref}
output = {}
err_input = {}
@@ -140,7 +142,10 @@ for data in buffer.get_data, buffer do
main:update(err_input, input, output)
nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce)
- nerv.utils.printf("frames: %.8f\n", sm.total_frames)
+ nerv.utils.printf("correct: %d\n", sm.total_correct)
+ nerv.utils.printf("frames: %d\n", sm.total_frames)
nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames)
+ nerv.utils.printf("accuracy: %.8f\n", sm.total_correct / sm.total_frames)
collectgarbage("collect")
end
+nerv.Matrix.print_profile()
diff --git a/io/chunk_file.c b/io/chunk_file.c
index ce346c5..4e987b7 100644
--- a/io/chunk_file.c
+++ b/io/chunk_file.c
@@ -268,7 +268,7 @@ int nerv_chunk_file_handle_destroy(lua_State *L) {
return 0;
}
-static int nerv_chunk_destroy(lua_State *L) {
+static int nerv_chunk_info_destroy(lua_State *L) {
ChunkInfo *pci = luaT_checkudata(L, 1, nerv_chunk_info_tname);
free(pci);
return 0;
@@ -298,7 +298,7 @@ void nerv_chunk_file_init(lua_State *L) {
luaT_newmetatable(L, nerv_chunk_file_handle_tname, NULL,
NULL, nerv_chunk_file_handle_destroy, NULL);
luaT_newmetatable(L, nerv_chunk_info_tname, NULL,
- NULL, nerv_chunk_destroy, NULL);
+ NULL, nerv_chunk_info_destroy, NULL);
luaT_newmetatable(L, nerv_chunk_data_tname, NULL,
NULL, nerv_chunk_data_destroy, NULL);
}
diff --git a/io/sgd_buffer.lua b/io/sgd_buffer.lua
index dadcf67..bf72744 100644
--- a/io/sgd_buffer.lua
+++ b/io/sgd_buffer.lua
@@ -4,6 +4,10 @@ function SGDBuffer:__init(global_conf, buffer_conf)
self.gconf = global_conf
self.buffer_size = math.floor(buffer_conf.buffer_size /
global_conf.batch_size) * global_conf.batch_size
+ self.randomize = buffer_conf.randomize
+ if self.randomize == nil then
+ self.randomize = false
+ end
self.head = 0
self.tail = 0
self.readers = {}
@@ -35,7 +39,9 @@ function SGDBuffer:saturate()
nerv.error("buffer size is too small to contain leftovers")
end
buff.data:copy_from(buff.leftover, 0, lrow)
+ buff.leftover = nil
end
+ nerv.utils.printf("leftover: %d\n", lrow)
reader.tail = lrow
reader.has_leftover = false
end
@@ -73,6 +79,8 @@ function SGDBuffer:saturate()
end
self.tail = math.min(self.tail, reader.tail)
end
+ self.rand_map = nerv.MMatrixInt.perm_gen(self.tail) -- generate shuffled index
+ collectgarbage("collect")
return self.tail >= self.gconf.batch_size
end
@@ -90,7 +98,11 @@ function SGDBuffer:get_data()
for i, reader in ipairs(self.readers) do
for id, buff in pairs(reader.buffs) do
local batch = self.gconf.cumat_type(batch_size, buff.width)
- batch:copy_fromh(buff.data, self.head, self.head + batch_size)
+ if self.randomize then
+ batch:copy_rows_fromh_by_idx(buff.data, self.rand_map, self.head)
+ else
+ batch:copy_fromh(buff.data, self.head, self.head + batch_size)
+ end
res[id] = batch
end
end
diff --git a/layer/softmax_ce.lua b/layer/softmax_ce.lua
index cf98c45..cd57010 100644
--- a/layer/softmax_ce.lua
+++ b/layer/softmax_ce.lua
@@ -17,6 +17,7 @@ function SoftmaxCELayer:init()
nerv.error("mismatching dimensions of previous network output and labels")
end
self.total_ce = 0.0
+ self.total_correct = 0
self.total_frames = 0
end
@@ -27,7 +28,7 @@ end
function SoftmaxCELayer:propagate(input, output)
local soutput = input[1]:create() -- temporary value for calc softmax
self.soutput = soutput
- soutput:softmax(input[1])
+ local classified = soutput:softmax(input[1])
local ce = soutput:create()
ce:log_elem(soutput)
local label = input[2]
@@ -38,6 +39,7 @@ function SoftmaxCELayer:propagate(input, output)
-- add total ce
self.total_ce = self.total_ce - ce:rowsum():colsum()[0]
self.total_frames = self.total_frames + soutput:nrow()
+ self.total_correct = self.total_correct + classified:colsame(input[2])[0]
end
function SoftmaxCELayer:back_propagate(next_bp_err, bp_err, input, output)
diff --git a/matrix/cuda_helper.h b/matrix/cuda_helper.h
index cedc643..5e5f2ad 100644
--- a/matrix/cuda_helper.h
+++ b/matrix/cuda_helper.h
@@ -1,17 +1,23 @@
#ifndef NERV_CUDA_HELPER_H
#define NERV_CUDA_HELPER_H
+#include "cuda.h"
+#include "cuda_runtime.h"
+#include "driver_types.h"
+#include "cublas_v2.h"
#define CUBLAS_SAFE_CALL(call) \
do { \
cublasStatus_t err = (call); \
if (err != CUBLAS_STATUS_SUCCESS) \
- nerv_error(L, "cumatrix cublas error: %s", cublasGetErrorString(err)); \
+ nerv_error(L, "cumatrix cublas error: %s at %s:%d", \
+ cublasGetErrorString(err), __FILE__, __LINE__); \
} while (0)
#define CUDA_SAFE_CALL(call) \
do { \
cudaError_t err = (call); \
if (err != cudaSuccess) \
- nerv_error(L, "cumatrix CUDA error: %s", cudaGetErrorString(err)); \
+ nerv_error(L, "cumatrix CUDA error: %s at %s:%d", \
+ cudaGetErrorString(err), __FILE__, __LINE__); \
} while (0)
#define CUDA_SAFE_SYNC_CALL(call) \
@@ -52,4 +58,20 @@ static const char *cublasGetErrorString(cublasStatus_t err) {
}
return "<unknown>";
}
+
+#define PROFILE_START \
+ do { \
+ cudaEvent_t start, stop; \
+ cudaEventCreate(&start); \
+ cudaEventCreate(&stop); \
+ cudaEventRecord(start, 0);
+#define PROFILE_STOP \
+ cudaEventRecord(stop, 0); \
+ cudaEventSynchronize(stop); \
+ float milliseconds = 0; \
+ cudaEventElapsedTime(&milliseconds, start, stop); \
+ accu_profile(__func__, milliseconds / 1000); \
+ } while (0);
+
+#define PROFILE_END
#endif
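
PROFILE_START opens a do { ... } block that PROFILE_STOP closes (PROFILE_STOP supplies its own trailing semicolon), so the pair must bracket the timed statements at a single scope; the measured time is funneled into accu_profile() keyed by __func__. The macros expand to the standard CUDA event-timing pattern; a self-contained sketch of that pattern, independent of the nerv macros:

    #include <cuda_runtime.h>

    /* Time an asynchronous region with CUDA events, as the macros above do. */
    static float elapsed_seconds(void (*region)(void)) {
        cudaEvent_t start, stop;
        float ms = 0;
        cudaEventCreate(&start);
        cudaEventCreate(&stop);
        cudaEventRecord(start, 0);   /* enqueue start marker on stream 0 */
        region();                    /* the kernels/memcpys being measured */
        cudaEventRecord(stop, 0);
        cudaEventSynchronize(stop);  /* block until the region has finished */
        cudaEventElapsedTime(&ms, start, stop);
        cudaEventDestroy(start);
        cudaEventDestroy(stop);
        return ms / 1000.0f;         /* seconds, matching accu_profile()'s unit */
    }

One caveat worth noting: the committed macros never call cudaEventDestroy, so each timed call leaks one event pair; the sketch above releases them.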
diff --git a/matrix/cukernel.h b/matrix/cukernel.h
index 7d2168e..23398c8 100644
--- a/matrix/cukernel.h
+++ b/matrix/cukernel.h
@@ -5,7 +5,9 @@ void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b);
void cudak_(cuda_sigmoid_grad)(const Matrix *output, const Matrix *err, Matrix *nerr);
void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b);
void cudak_(cuda_rowmax)(const Matrix *a, Matrix *b);
+void cudak_(cuda_rowmax_idx)(const Matrix *a, Matrix *b, Matrix *idx);
void cudak_(cuda_colsum)(const Matrix *a, Matrix *b);
+void cudak_(cuda_colsame)(const Matrix *a, const Matrix *ref, Matrix *b);
void cudak_(cuda_softmax_denominator)(const Matrix *a, const Matrix *max, Matrix *b);
void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max, const Matrix *deno, Matrix *b);
void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta);
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
index 51a3681..4ebc5ff 100644
--- a/matrix/cumatrix.c
+++ b/matrix/cumatrix.c
@@ -1,4 +1,38 @@
#define NERV_GENERIC_CUMATRIX
+#include "../common.h"
+#include "cuda_helper.h"
+static cublasHandle_t cublas_handle;
+static HashMap *profile;
+
+int print_profile(lua_State *L) {
+ size_t i;
+ fprintf(stderr, "*** [nerv cumatrix profile] **\n");
+ for (i = 0; i < profile->size; i++)
+ {
+ HashNode *ptr;
+ for (ptr = profile->bucket[i]; ptr; ptr = ptr->next)
+ {
+ fprintf(stderr, "%s:\t%.6f\n", ptr->key, *(float *)ptr->val);
+ }
+ }
+ return 0;
+}
+
+int clear_profile(lua_State *L) {
+ hashmap_clear(profile);
+ return 0;
+}
+
+void accu_profile(const char *name, float delta) {
+ float *val = hashmap_getval(profile, name);
+ if (!val)
+ {
+ val = malloc(sizeof(float));
+ *val = 0;
+ hashmap_setval(profile, name, val);
+ }
+ *val += delta;
+}
#define MATRIX_USE_FLOAT
#define cuda_matrix_(NAME) cuda_matrix_float_##NAME
diff --git a/matrix/generic/cukernel.cu b/matrix/generic/cukernel.cu
index 05a1e78..fdab356 100644
--- a/matrix/generic/cukernel.cu
+++ b/matrix/generic/cukernel.cu
@@ -3,6 +3,7 @@
#include <stdio.h>
#include "matrix.h"
#include "cuda.h"
+#include "float.h"
#define CUDA_THREADS_N 16
#define CUDA_THREADS_NN ((CUDA_THREADS_N) * (CUDA_THREADS_N))
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
@@ -11,9 +12,12 @@ __global__ void cudak_(log_elem)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
int j = blockIdx.x * blockDim.x + threadIdx.x;
int i = blockIdx.y * blockDim.y + threadIdx.y;
long idx;
+ MATRIX_ELEM tmp;
if (i >= nrow || j >= ncol) return;
idx = j + i * stride;
- b[idx] = log(a[idx]);
+ tmp = a[idx];
+ if(tmp < FLT_MIN) tmp = FLT_MIN;
+ b[idx] = log(tmp);
}
__global__ void cudak_(mul_elem)(const MATRIX_ELEM *a, const MATRIX_ELEM *b,
@@ -61,9 +65,9 @@ __global__ void cudak_(softmax_final)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
}
__global__ void cudak_(block_reduce_rowsum)(const MATRIX_ELEM *input,
- MATRIX_ELEM *output,
- const int istride, const int ostride,
- const int n) {
+ MATRIX_ELEM *output,
+ const int istride, const int ostride,
+ const int n) {
extern __shared__ MATRIX_ELEM cudak_(arr)[];
int j = blockIdx.x * blockDim.x + threadIdx.x;
cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
@@ -96,6 +100,26 @@ __global__ void cudak_(block_reduce_colsum)(const MATRIX_ELEM *input,
output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
}
+__global__ void cudak_(block_reduce_colsame)(const MATRIX_ELEM *input,
+ const MATRIX_ELEM *ref_input,
+ MATRIX_ELEM *output,
+ const int istride, const int ostride,
+ const int n) {
+ extern __shared__ MATRIX_ELEM cudak_(arr)[];
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ cudak_(arr)[threadIdx.y] = (i < n && input[blockIdx.x + istride * i] == \
+ ref_input[blockIdx.x + istride * i]) ? 1.0 : 0;
+ __syncthreads();
+ for (int offset = blockDim.y >> 1; offset; offset >>= 1)
+ {
+ if (threadIdx.y < offset)
+ cudak_(arr)[threadIdx.y] += cudak_(arr)[threadIdx.y + offset];
+ __syncthreads();
+ }
+ if (threadIdx.y == 0)
+ output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
+}
+
__global__ void cudak_(block_reduce_softmax_rowsum)(const MATRIX_ELEM *input,
MATRIX_ELEM *output,
const MATRIX_ELEM *max,
@@ -117,9 +141,9 @@ __global__ void cudak_(block_reduce_softmax_rowsum)(const MATRIX_ELEM *input,
}
__global__ void cudak_(block_reduce_rowmax)(const MATRIX_ELEM *input,
- MATRIX_ELEM *output,
- const int istride, const int ostride,
- const int n) {
+ MATRIX_ELEM *output,
+ const int istride, const int ostride,
+ const int n) {
extern __shared__ MATRIX_ELEM cudak_(arr)[];
int j = blockIdx.x * blockDim.x + threadIdx.x;
cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
@@ -129,8 +153,9 @@ __global__ void cudak_(block_reduce_rowmax)(const MATRIX_ELEM *input,
if (threadIdx.x < offset)
{
MATRIX_ELEM l = cudak_(arr)[threadIdx.x],
- r = cudak_(arr)[threadIdx.x + offset];
- if (r > l) cudak_(arr)[threadIdx.x] = r;
+ r = cudak_(arr)[threadIdx.x + offset];
+ if (r > l)
+ cudak_(arr)[threadIdx.x] = r;
}
__syncthreads();
}
@@ -138,6 +163,40 @@ __global__ void cudak_(block_reduce_rowmax)(const MATRIX_ELEM *input,
output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
}
+__global__ void cudak_(block_reduce_rowmax_idx)(const MATRIX_ELEM *input,
+ const MATRIX_ELEM *idx_input,
+ MATRIX_ELEM *output,
+ MATRIX_ELEM *idx_output,
+ const int istride, const int ostride,
+ const int n) {
+ extern __shared__ MATRIX_ELEM cudak_(arr)[];
+ MATRIX_ELEM *arr_val = cudak_(arr);
+ MATRIX_ELEM *arr_idx = arr_val + blockDim.x;
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ arr_val[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
+ arr_idx[threadIdx.x] = j < n ? idx_input[j + istride * blockIdx.y] : 0;
+ __syncthreads();
+ for (int offset = blockDim.x >> 1; offset; offset >>= 1)
+ {
+ if (threadIdx.x < offset)
+ {
+ MATRIX_ELEM l = arr_val[threadIdx.x],
+ r = arr_val[threadIdx.x + offset];
+ if (r > l)
+ {
+ arr_val[threadIdx.x] = r;
+ arr_idx[threadIdx.x] = arr_idx[threadIdx.x + offset];
+ }
+ }
+ __syncthreads();
+ }
+ if (threadIdx.x == 0)
+ {
+ output[blockIdx.x + ostride * blockIdx.y] = arr_val[0];
+ idx_output[blockIdx.x + ostride * blockIdx.y] = arr_idx[0];
+ }
+}
+
__global__ void cudak_(add_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
int nrow, int ncol, int stride, double beta) {
int j = blockIdx.x * blockDim.x + threadIdx.x;
@@ -196,6 +255,14 @@ __global__ void cudak_(decompress)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
b[lrintf(a[j + i * stride_a]) + i * stride_b] = 1.0;
}
+__global__ void cudak_(gen_col_idx)(MATRIX_ELEM *b,
+ int nrow, int ncol, int stride) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ if (i >= nrow || j >= ncol) return;
+ b[j + i * stride] = j;
+}
+
extern "C" {
#include "../cukernel.h"
void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b) {
@@ -261,10 +328,32 @@ extern "C" {
cudaFree(res);
}
+ void cudak_(cuda_colsame)(const Matrix *a, const Matrix *ref, Matrix *b) {
+ dim3 block(1, CUDA_THREADS_NN);
+ int nrow = a->nrow;
+ int blocks_per_col = CEIL_DIV(nrow, block.y);
+ dim3 grid(a->ncol, blocks_per_col);
+ MATRIX_ELEM *res;
+ size_t stride;
+ cudaMallocPitch(&res, &stride, a->ncol * sizeof(MATRIX_ELEM), blocks_per_col);
+ cudak_(block_reduce_colsame)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(ref), res,
+ a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
+ nrow);
+ nrow = blocks_per_col;
+ assert((unsigned long)nrow <= block.y);
+ grid.y = 1;
+ cudak_(block_reduce_colsum)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>> \
+ (res, MATRIX_ELEM_PTR(b),
+ stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
+ nrow);
+ cudaFree(res);
+ }
+
void cudak_(cuda_colsum)(const Matrix *a, Matrix *b) {
dim3 block(1, CUDA_THREADS_NN);
int nrow = a->nrow;
- int blocks_per_col = CEIL_DIV(nrow, block.x);
+ int blocks_per_col = CEIL_DIV(nrow, block.y);
dim3 grid(a->ncol, blocks_per_col);
MATRIX_ELEM *res;
size_t stride;
@@ -344,6 +433,35 @@ extern "C" {
cudaFree(res);
}
+ void cudak_(cuda_rowmax_idx)(const Matrix *a, Matrix *b, Matrix *b_idx) {
+ dim3 block(CUDA_THREADS_NN, 1);
+ int ncol = a->ncol;
+ int blocks_per_row = CEIL_DIV(ncol, block.x);
+ dim3 grid(blocks_per_row, a->nrow);
+ MATRIX_ELEM *a_idx, *res, *res_idx;
+ size_t stride;
+ cudaMallocPitch(&a_idx, &stride, a->stride, a->nrow);
+ cudak_(gen_col_idx)<<<grid, block>>>(a_idx, a->nrow, ncol, stride / sizeof(MATRIX_ELEM));
+ cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
+ cudaMallocPitch(&res_idx, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
+ cudak_(block_reduce_rowmax_idx)<<<grid, block,
+ 2 * block.x * sizeof(MATRIX_ELEM)>>> \
+ (MATRIX_ELEM_PTR(a), a_idx, res, res_idx,
+ a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
+ ncol);
+ cudaFree(a_idx);
+ ncol = blocks_per_row;
+ assert((unsigned long)ncol <= block.x);
+ grid.x = 1;
+ cudak_(block_reduce_rowmax_idx)<<<grid, block,
+ 2 * block.x * sizeof(MATRIX_ELEM)>>> \
+ (res, res_idx, MATRIX_ELEM_PTR(b), MATRIX_ELEM_PTR(b_idx),
+ stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
+ ncol);
+ cudaFree(res);
+ cudaFree(res_idx);
+ }
+
/* in-place calc */
void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta) {
dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
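
For reference, the semantics of the two new reductions in plain CPU form, assuming dense row-major storage without padding (the real matrices are pitched, hence the stride / sizeof(MATRIX_ELEM) arithmetic in the kernels above); colsame counts, per column, the rows where a matches ref, and rowmax_idx records each row's maximum together with the column where it first occurs, indices stored as floating-point MATRIX_ELEM values:

    /* CPU reference for cuda_colsame: b is 1 x ncol. */
    static void colsame_ref(const float *a, const float *ref, float *b,
                            int nrow, int ncol) {
        for (int j = 0; j < ncol; j++) {
            float same = 0;
            for (int i = 0; i < nrow; i++)
                if (a[i * ncol + j] == ref[i * ncol + j])
                    same += 1.0f;
            b[j] = same;
        }
    }

    /* CPU reference for cuda_rowmax_idx: b and idx are nrow x 1. */
    static void rowmax_idx_ref(const float *a, float *b, float *idx,
                               int nrow, int ncol) {
        for (int i = 0; i < nrow; i++) {
            float best = a[i * ncol];
            int best_j = 0;
            for (int j = 1; j < ncol; j++)
                if (a[i * ncol + j] > best) {
                    best = a[i * ncol + j];
                    best_j = j;
                }
            b[i] = best;
            idx[i] = (float)best_j; /* stored as MATRIX_ELEM, like the kernel */
        }
    }

The GPU versions compute this in two passes of a shared-memory tree reduction: the first launch folds CUDA_THREADS_NN-wide blocks into per-block partials, and the second folds those partials, with the asserts checking that the partials fit in a single block.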
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index 373fc42..8e7d34f 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -11,15 +11,11 @@
#define MATRIX_BASE_TNAME nerv_matrix_cuda_tname
#define NERV_GENERIC_MATRIX
#define NERV_GENERIC_CUKERNEL
+#define PROFILE_HASHMAP_SIZE 123457
#include "../../common.h"
#include "../cukernel.h"
-#include "cuda.h"
-#include "cuda_runtime.h"
-#include "driver_types.h"
-#include "cublas_v2.h"
#include "../cuda_helper.h"
-
-static cublasHandle_t cublas_handle;
+#include <string.h>
Matrix *nerv_matrix_(new_)(lua_State *L, long nrow, long ncol);
void nerv_matrix_(data_free)(lua_State *L, Matrix *self);
@@ -27,6 +23,7 @@ void nerv_matrix_(data_free)(lua_State *L, Matrix *self);
static void nerv_matrix_(add_)(lua_State *L, const Matrix *a, const Matrix *b,
const Matrix *c,
MATRIX_ELEM alpha, MATRIX_ELEM beta) {
+ PROFILE_START
CUBLAS_SAFE_CALL(
NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
a->ncol, a->nrow,
@@ -35,6 +32,7 @@ static void nerv_matrix_(add_)(lua_State *L, const Matrix *a, const Matrix *b,
&beta,
MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM),
MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM)));
+ PROFILE_STOP
}
static int nerv_matrix_(add)(lua_State *L) {
@@ -75,6 +73,7 @@ static int nerv_matrix_(mul)(lua_State *L) {
nerv_error(L, "Wrong dimension of multipliers");
/* MATRIX_ELEM alpha = 1.0f, beta = 0.0f; */
/* Because matrix in Nerv is row-major, here b comes first */
+ PROFILE_START
CUBLAS_SAFE_CALL(
NERV_CUBLAS_(gemm)(cublas_handle, tb, ta,
bn, am, bm,
@@ -83,6 +82,7 @@ static int nerv_matrix_(mul)(lua_State *L) {
MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
&beta,
MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM)));
+ PROFILE_STOP
return 0;
}
@@ -97,7 +97,9 @@ static int nerv_matrix_(sigmoid)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
CHECK_SAME_DIMENSION(a, b);
+ PROFILE_START
cudak_(cuda_sigmoid)(b, a);
+ PROFILE_STOP
return 0;
}
@@ -107,30 +109,38 @@ static int nerv_matrix_(sigmoid_grad)(lua_State *L) {
Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname));
CHECK_SAME_DIMENSION(nerr, err);
CHECK_SAME_DIMENSION(nerr, output);
+ PROFILE_START
cudak_(cuda_sigmoid_grad)(output, err, nerr);
+ PROFILE_STOP
return 0;
}
static int nerv_matrix_(softmax)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *max;
+ Matrix *max, *max_idx;
Matrix *dno;
CHECK_SAME_DIMENSION(a, b);
max = nerv_matrix_(new_)(L, a->nrow, 1);
+ max_idx = nerv_matrix_(new_)(L, a->nrow, 1);
dno = nerv_matrix_(new_)(L, a->nrow, 1);
- cudak_(cuda_rowmax)(a, max);
+ PROFILE_START
+ cudak_(cuda_rowmax_idx)(a, max, max_idx);
cudak_(cuda_softmax_denominator)(a, max, dno);
cudak_(cuda_softmax_final)(a, max, dno, b);
+ PROFILE_STOP
nerv_matrix_(data_free)(L, max);
nerv_matrix_(data_free)(L, dno);
- return 0;
+ luaT_pushudata(L, max_idx, nerv_matrix_(tname));
+ return 1;
}
static int nerv_matrix_(rowsum)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = nerv_matrix_(new_)(L, a->nrow, 1);
+ PROFILE_START
cudak_(cuda_rowsum)(a, b);
+ PROFILE_STOP
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
}
@@ -138,7 +148,21 @@ static int nerv_matrix_(rowsum)(lua_State *L) {
static int nerv_matrix_(colsum)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = nerv_matrix_(new_)(L, 1, a->ncol);
+ PROFILE_START
cudak_(cuda_colsum)(a, b);
+ PROFILE_STOP
+ luaT_pushudata(L, b, nerv_matrix_(tname));
+ return 1;
+}
+
+static int nerv_matrix_(colsame)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *ref = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ Matrix *b = nerv_matrix_(new_)(L, 1, a->ncol);
+ CHECK_SAME_DIMENSION(a, ref);
+ PROFILE_START
+ cudak_(cuda_colsame)(a, ref, b);
+ PROFILE_STOP
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
}
@@ -146,11 +170,24 @@ static int nerv_matrix_(colsum)(lua_State *L) {
static int nerv_matrix_(rowmax)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = nerv_matrix_(new_)(L, a->nrow, 1);
+ PROFILE_START
cudak_(cuda_rowmax)(a, b);
+ PROFILE_STOP
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
}
+static int nerv_matrix_(rowmax_idx)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = nerv_matrix_(new_)(L, a->nrow, 1);
+ Matrix *idx = nerv_matrix_(new_)(L, a->nrow, 1);
+ PROFILE_START
+ cudak_(cuda_rowmax_idx)(a, b, idx);
+ PROFILE_STOP
+ luaT_pushudata(L, b, nerv_matrix_(tname));
+ luaT_pushudata(L, idx, nerv_matrix_(tname));
+ return 2;
+}
static int nerv_matrix_(add_row)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
@@ -160,14 +197,18 @@ static int nerv_matrix_(add_row)(lua_State *L) {
nerv_error(L, "the number of columns is not the same");
if (a->nrow != 1)
nerv_error(L, "a row vector is expected");
+ PROFILE_START
cudak_(cuda_add_row)(a, b, beta);
+ PROFILE_STOP
return 0;
}
static int nerv_matrix_(fill)(lua_State *L) {
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
double val = luaL_checknumber(L, 2);
+ PROFILE_START
cudak_(cuda_fill)(self, val);
+ PROFILE_STOP
return 0;
}
@@ -183,11 +224,13 @@ static int nerv_matrix_(copy_fromd)(lua_State *L) {
nerv_error(L, "invalid copy interval");
if (a->ncol != b->ncol)
nerv_error(L, "matrices should be of the same dimension");
+ PROFILE_START
CUDA_SAFE_SYNC_CALL(
cudaMemcpy2D(MATRIX_ROW_PTR(a, a_begin), a->stride,
MATRIX_ROW_PTR(b, b_begin), b->stride,
sizeof(MATRIX_ELEM) * b->ncol, b_end - b_begin,
cudaMemcpyDeviceToDevice));
+ PROFILE_STOP
return 0;
}
@@ -204,11 +247,13 @@ static int nerv_matrix_(copy_fromh)(lua_State *L) {
nerv_error(L, "invalid copy interval");
if (a->ncol != b->ncol)
nerv_error(L, "matrices should be of the same dimension");
+ PROFILE_START
CUDA_SAFE_SYNC_CALL(
cudaMemcpy2D(MATRIX_ROW_PTR(a, a_begin), a->stride,
MATRIX_ROW_PTR(b, b_begin), b->stride,
sizeof(MATRIX_ELEM) * b->ncol, b_end - b_begin,
cudaMemcpyHostToDevice));
+ PROFILE_STOP
return 0;
}
@@ -224,11 +269,13 @@ static int nerv_matrix_(copy_toh)(lua_State *L) {
nerv_error(L, "invalid copy interval");
if (b->ncol != a->ncol)
nerv_error(L, "matrices should be of the same dimension");
+ PROFILE_START
CUDA_SAFE_SYNC_CALL(
cudaMemcpy2D(MATRIX_ROW_PTR(b, b_begin), b->stride,
MATRIX_ROW_PTR(a, a_begin), a->stride,
sizeof(MATRIX_ELEM) * a->ncol, a_end - a_begin,
cudaMemcpyDeviceToHost));
+ PROFILE_STOP
return 0;
}
@@ -237,6 +284,7 @@ static int nerv_matrix_(trans)(lua_State *L) {
Matrix *b = nerv_matrix_(new_)(L, a->ncol, a->nrow);
MATRIX_ELEM alpha = 1, beta = 0;
/* FIXME: possible memory leak when lua error is raised */
+ PROFILE_START
CUBLAS_SAFE_CALL(
NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T,
a->nrow, a->ncol,
@@ -245,6 +293,7 @@ static int nerv_matrix_(trans)(lua_State *L) {
&beta,
MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM)));
+ PROFILE_STOP
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
}
@@ -255,7 +304,9 @@ static int nerv_matrix_(mul_elem)(lua_State *L) {
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
CHECK_SAME_DIMENSION(a, b);
CHECK_SAME_DIMENSION(a, c);
+ PROFILE_START
cudak_(cuda_mul_elem)(a, b, c);
+ PROFILE_STOP
return 0;
}
@@ -263,7 +314,9 @@ static int nerv_matrix_(log_elem)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
CHECK_SAME_DIMENSION(a, b);
+ PROFILE_START
cudak_(cuda_log_elem)(a, b);
+ PROFILE_STOP
return 0;
}
@@ -274,8 +327,10 @@ static int nerv_matrix_(decompress)(lua_State *L) {
if (a->ncol != 1)
nerv_error(L, "the compressed matrix must be a column vector");
b = nerv_matrix_(new_)(L, a->nrow, orig_col);
+ PROFILE_START
cudak_(cuda_fill)(b, 0.0);
cudak_(cuda_decompress)(a, b);
+ PROFILE_STOP
luaT_pushudata(L, b, nerv_matrix_(tname));
return 1;
}
@@ -285,21 +340,25 @@ static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
Matrix *idx = luaT_checkudata