aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-05-19 15:01:38 +0800
committerDeterminant <[email protected]>2015-05-19 15:01:38 +0800
commite9b8855c894daa4e6749acfe891f68b3ed8ed481 (patch)
tree5a3ea5e89bd475dc4312d379ffc7bf9121862dbb
parent9b6606504241f27a9d42b96f535bf5f2c2918161 (diff)
add double precision matrix implementation
-rw-r--r--Makefile2
-rw-r--r--cumatrix_example.lua19
-rw-r--r--matrix/cukernel.cu196
-rw-r--r--matrix/cukernel.h13
-rw-r--r--matrix/cumatrix.c163
-rw-r--r--matrix/generic/cukernel.cu184
-rw-r--r--matrix/generic/cumatrix.c143
-rw-r--r--matrix/generic/elem_type.h11
-rw-r--r--matrix/generic/matrix.c83
-rw-r--r--matrix/generic/matrix.h1
-rw-r--r--matrix/generic/mmatrix.c (renamed from matrix/matrix.c)25
-rw-r--r--matrix/init.c13
-rw-r--r--matrix/init.lua22
-rw-r--r--matrix/mmatrix.c5
-rw-r--r--matrix_example.lua16
-rw-r--r--mmatrix_example.lua9
16 files changed, 482 insertions, 423 deletions
diff --git a/Makefile b/Makefile
index 9f6413e..927dfaf 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
.PHONY: all clean luajit
-OBJS := oop_example.o nerv.o luaT.o common.o matrix/matrix.o matrix/cumatrix.o matrix/init.o matrix/cukernel.o
+OBJS := oop_example.o nerv.o luaT.o common.o matrix/mmatrix.o matrix/cumatrix.o matrix/init.o matrix/cukernel.o
LIBS := libnerv.so
LUA_LIBS := matrix/init.lua nerv.lua
INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK
diff --git a/cumatrix_example.lua b/cumatrix_example.lua
index c2f2139..f26d7f4 100644
--- a/cumatrix_example.lua
+++ b/cumatrix_example.lua
@@ -1,15 +1,16 @@
m = 10
n = 10
-t = nerv.FloatCuMatrix(m, n)
--- print(t)
-a = t[1]
+fm = nerv.FloatCuMatrix(m, n)
+dm = nerv.DoubleCuMatrix(m, n)
for i = 0, m - 1 do
for j = 0, n - 1 do
--- t[i][j] = i + j
- t[i][j] = math.random(10)
+ -- local t = math.random(10)
+ t = i / (j + 1)
+ fm[i][j] = t
+ dm[i][j] = t
end
end
-print(t)
-print(t:colsum())
-print(t:colmax())
-print(t:softmax())
+print(fm)
+print(fm:softmax())
+print(dm)
+print(dm:softmax())
diff --git a/matrix/cukernel.cu b/matrix/cukernel.cu
index ee6d871..1f97b41 100644
--- a/matrix/cukernel.cu
+++ b/matrix/cukernel.cu
@@ -1,177 +1,19 @@
-#include <assert.h>
-#include <stdio.h>
-#include "generic/matrix.h"
-#include "cuda.h"
-#define CUDA_THREADS_N 16
-#define CUDA_THREADS_NN (16 * 16)
-#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
-__global__ void sigmoid(const float *a, float *b,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- b[idx] = 1.0 / (1.0 + exp(-a[idx]));
-}
-
-__global__ void softmax_final(const float *a, float *b,
- const float *max, const float *deno,
- int nrow, int ncol, int stride, int mstride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- b[idx] = exp(a[idx] - max[0 + i * mstride]) / deno[0 + i * mstride];
-}
-
-__global__ void block_reduce_sum(const float *input, float *output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ float arr[];
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- arr[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- arr[threadIdx.x] += arr[threadIdx.x + offset];
- __syncthreads();
- }
- if (threadIdx.x == 0)
- output[blockIdx.x + ostride * blockIdx.y] = arr[0];
-}
-
-__global__ void block_reduce_softmax_sum(const float *input, float *output,
- const float *max,
- const int istride, const int ostride,
- const int mstride, const int n) {
- extern __shared__ float arr[];
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- arr[threadIdx.x] = j < n ? exp(input[j + istride * blockIdx.y] - \
- max[0 + mstride * blockIdx.y]) : 0;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- arr[threadIdx.x] += arr[threadIdx.x + offset];
- __syncthreads();
- }
- if (threadIdx.x == 0)
- output[blockIdx.x + ostride * blockIdx.y] = arr[0];
-}
-
-__global__ void block_reduce_max(const float *input, float *output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ float arr[];
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- arr[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- {
- float l = arr[threadIdx.x],
- r = arr[threadIdx.x + offset];
- if (r > l) arr[threadIdx.x] = r;
- }
- __syncthreads();
- }
- if (threadIdx.x == 0)
- output[blockIdx.x + ostride * blockIdx.y] = arr[0];
-}
-
-extern "C" {
-#include "cukernel.h"
- void cuda_sigmoid(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N,
- CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- sigmoid<<<numBlocks, threadsPerBlock>>>(a->data.f, b->data.f, b->nrow, b->ncol,
- b->stride / sizeof(float));
- }
-
- void cuda_colsum(const Matrix *a, Matrix *b) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- float *res;
- size_t stride;
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(float), a->nrow);
- block_reduce_sum<<<grid, block, block.x * sizeof(float)>>> \
- (a->data.f, res,
- a->stride / sizeof(float), stride / sizeof(float),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- block_reduce_sum<<<grid, block, block.x * sizeof(float)>>> \
- (res, b->data.f,
- stride / sizeof(float), b->stride / sizeof(float),
- ncol);
- cudaFree(res);
- }
-
- void cuda_softmax_final(const Matrix *a, const Matrix *max,
- const Matrix *deno, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N,
- CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- softmax_final<<<numBlocks, threadsPerBlock>>>(a->data.f, b->data.f,
- max->data.f, deno->data.f,
- b->nrow, b->ncol,
- b->stride / sizeof(float),
- max->stride / sizeof(float));
- }
-
- void cuda_softmax_denominator(const Matrix *a, const Matrix *max, Matrix *b) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- float *res;
- size_t stride;
- assert(max->ncol == 1);
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(float), a->nrow);
- block_reduce_softmax_sum<<<grid, block, block.x * sizeof(float)>>> \
- (a->data.f, res, max->data.f,
- a->stride / sizeof(float), stride / sizeof(float),
- max->stride / sizeof(float),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- block_reduce_sum<<<grid, block, block.x * sizeof(float)>>> \
- (res, b->data.f,
- stride / sizeof(float), b->stride / sizeof(float),
- ncol);
- cudaFree(res);
- }
-
- void cuda_colmax(const Matrix *a, Matrix *b) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- float *res;
- size_t stride;
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(float), a->nrow);
- block_reduce_max<<<grid, block, block.x * sizeof(float)>>> \
- (a->data.f, res,
- a->stride / sizeof(float), stride / sizeof(float),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- block_reduce_max<<<grid, block, block.x * sizeof(float)>>> \
- (res, b->data.f,
- stride / sizeof(float), b->stride / sizeof(float),
- ncol);
- cudaFree(res);
- }
-}
+#define NERV_GENERIC_CUKERNEL
+
+#define cudak_(NAME) cudak_float_ ## NAME
+#define MATRIX_USE_FLOAT
+#include "generic/elem_type.h"
+#include "generic/cukernel.cu"
+#undef cudak_
+#undef MATRIX_USE_FLOAT
+#undef MATRIX_ELEM
+#undef MATRIX_ELEM_PTR
+
+#define cudak_(NAME) cudak_double_ ## NAME
+#define MATRIX_USE_DOUBLE
+#include "generic/elem_type.h"
+#include "generic/cukernel.cu"
+#undef cudak_
+#undef MATRIX_USE_DOUBLE
+#undef MATRIX_ELEM
+#undef MATRIX_ELEM_PTR
diff --git a/matrix/cukernel.h b/matrix/cukernel.h
index 9c13558..ea81e5a 100644
--- a/matrix/cukernel.h
+++ b/matrix/cukernel.h
@@ -1,8 +1,7 @@
-#ifndef NERV_CUKERNEL_H
-#define NERV_CUKERNEL_H
-void cuda_sigmoid(const Matrix *a, Matrix *b);
-void cuda_colsum(const Matrix *a, Matrix *b);
-void cuda_colmax(const Matrix *a, Matrix *b);
-void cuda_softmax_denominator(const Matrix *a, const Matrix *max, Matrix *b);
-void cuda_softmax_final(const Matrix *a, const Matrix *max, const Matrix *deno, Matrix *b);
+#ifdef NERV_GENERIC_CUKERNEL
+void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b);
+void cudak_(cuda_colsum)(const Matrix *a, Matrix *b);
+void cudak_(cuda_colmax)(const Matrix *a, Matrix *b);
+void cudak_(cuda_softmax_denominator)(const Matrix *a, const Matrix *max, Matrix *b);
+void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max, const Matrix *deno, Matrix *b);
#endif
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
index aa10571..90a6703 100644
--- a/matrix/cumatrix.c
+++ b/matrix/cumatrix.c
@@ -1,139 +1,24 @@
-#define MATRIX_DATA_FREE(ptr) cuda_float_array_free(ptr)
-#define MATRIX_DATA_ALLOC(dptr, stride, width, height) cuda_float_array_alloc(dptr, stride, width, height)
-#define MATRIX_DATA_WRITE(data, idx, val) cuda_float_array_write(data, idx, val)
-#define MATRIX_DATA_READ(data, idx) cuda_float_array_read(data, idx)
-#define MATRIX_INIT(L) cuda_float_init(L)
-#define NERV_GENERIC_MATRIX
-#define nerv_float_matrix_(NAME) nerv_float_matrix_cuda_ ## NAME
-#include "../common.h"
-#include "generic/matrix.h"
-#include "cukernel.h"
-#include "cuda.h"
-#include "cuda_runtime.h"
-#include "driver_types.h"
-#include "cublas_v2.h"
-
-const char *nerv_float_matrix_(tname) = "nerv.FloatCuMatrix";
-static cublasHandle_t cublas_handle;
-
-Matrix *nerv_float_matrix_(new_)(long nrow, long ncol);
-static int nerv_float_matrix_(add)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_float_matrix_(tname));
- Matrix *c;
- long nrow, ncol;
- if (!(a->nrow == b->nrow && a->ncol == b->ncol))
- nerv_error(L, "Matrices should be of the same dimension");
- nrow = a->nrow;
- ncol = a->ncol;
- c = nerv_float_matrix_(new_)(nrow, ncol);
- float alpha = 1.0f, beta = 1.0f;
- cublasSgeam(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
- ncol, nrow,
- &alpha,
- a->data.f, a->stride / sizeof(float),
- &beta,
- b->data.f, b->stride / sizeof(float),
- c->data.f, c->stride / sizeof(float));
- luaT_pushudata(L, c, nerv_float_matrix_(tname));
- return 1;
-}
-
-static int nerv_float_matrix_(mul)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_float_matrix_(tname));
- Matrix *c;
- if (a->ncol != b->nrow)
- nerv_error(L, "Wrong dimension of multipliers");
- c = nerv_float_matrix_(new_)(a->nrow, b->ncol);
- float alpha = 1.0f, beta = 0.0f;
- cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
- b->ncol, a->nrow, b->nrow,
- &alpha,
- b->data.f, b->stride / sizeof(float),
- a->data.f, a->stride / sizeof(float),
- &beta,
- c->data.f, c->stride / sizeof(float));
- luaT_pushudata(L, c, nerv_float_matrix_(tname));
- return 1;
-}
-
-static int nerv_float_matrix_(sigmoid)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
- Matrix *b = nerv_float_matrix_(new_)(a->nrow, a->ncol);
- cuda_sigmoid(a, b);
- luaT_pushudata(L, b, nerv_float_matrix_(tname));
- return 1;
-}
-
-static int nerv_float_matrix_(softmax)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
- Matrix *max = nerv_float_matrix_(new_)(a->nrow, 1);
- Matrix *dno = nerv_float_matrix_(new_)(a->nrow, 1);
- Matrix *b = nerv_float_matrix_(new_)(a->nrow, a->ncol);
- cuda_colmax(a, max);
- cuda_softmax_denominator(a, max, dno);
- cuda_softmax_final(a, max, dno, b);
- luaT_pushudata(L, b, nerv_float_matrix_(tname));
- return 1;
-}
-
-static int nerv_float_matrix_(colsum)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
- Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
- cuda_colsum(a, b);
- luaT_pushudata(L, b, nerv_float_matrix_(tname));
- return 1;
-}
-
-static int nerv_float_matrix_(colmax)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
- Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
- cuda_colmax(a, b);
- luaT_pushudata(L, b, nerv_float_matrix_(tname));
- return 1;
-}
-
-static const luaL_Reg nerv_float_matrix_(extra_methods)[] = {
- {"__add__", nerv_float_matrix_(add)},
- {"__mul__", nerv_float_matrix_(mul)},
- {"sigmoid", nerv_float_matrix_(sigmoid)},
- {"softmax", nerv_float_matrix_(softmax)},
- {"colsum", nerv_float_matrix_(colsum)},
- {"colmax", nerv_float_matrix_(colmax)},
- {NULL, NULL}
-};
-
-static void cuda_float_init(lua_State *L) {
- luaN_append_methods(L, nerv_float_matrix_(extra_methods));
- cublasCreate(&cublas_handle);
-}
-
-static void cuda_float_array_free(float *ptr) {
- cudaFree(ptr);
-}
-
-static void cuda_float_array_alloc(float **dptr, size_t *stride,
- long width, long height) {
- cudaMallocPitch((void **)dptr, stride, width, height);
-}
-
-static float cuda_float_array_read(float *data, int idx) {
- float res;
- cudaMemcpy(&res, data + idx, sizeof(float), cudaMemcpyDeviceToHost);
- return res;
-}
-
-static void cuda_float_array_write(float *data, int idx, float val) {
- cudaMemcpy(data + idx, &val, sizeof(float), cudaMemcpyHostToDevice);
-}
-
-int nerv_float_matrix_(get_elem)(lua_State *L) {
- return nerv_error_method_not_implemented(L);
-}
-
-int nerv_float_matrix_(set_elem)(lua_State *L) {
- return nerv_error_method_not_implemented(L);
-}
-
-#include "generic/matrix.c"
+#define NERV_GENERIC_CUMATRIX
+
+#define MATRIX_USE_FLOAT
+#define cuda_matrix_(NAME) cuda_matrix_float_ ## NAME
+#define nerv_matrix_(NAME) nerv_matrix_float_cuda_ ## NAME
+#define cudak_(NAME) cudak_float_ ## NAME
+#define NERV_CUBLAS_(NAME) cublasS##NAME
+const char *nerv_matrix_(tname) = "nerv.FloatCuMatrix";
+#include "generic/cumatrix.c"
+#undef NERV_CUBLAS_
+#undef cudak_
+#undef nerv_matrix_
+#undef cuda_matrix_
+#undef MATRIX_USE_FLOAT
+#undef MATRIX_ELEM
+#undef MATRIX_ELEM_PTR
+
+#define MATRIX_USE_DOUBLE
+#define cuda_matrix_(NAME) cuda_matrix_double_ ## NAME
+#define nerv_matrix_(NAME) nerv_matrix_double_cuda_ ## NAME
+#define cudak_(NAME) cudak_double_ ## NAME
+#define NERV_CUBLAS_(NAME) cublasD##NAME
+const char *nerv_matrix_(tname) = "nerv.DoubleCuMatrix";
+#include "generic/cumatrix.c"
diff --git a/matrix/generic/cukernel.cu b/matrix/generic/cukernel.cu
new file mode 100644
index 0000000..a37ccf4
--- /dev/null
+++ b/matrix/generic/cukernel.cu
@@ -0,0 +1,184 @@
+#ifdef NERV_GENERIC_CUKERNEL
+#include <assert.h>
+#include <stdio.h>
+#include "matrix.h"
+#include "cuda.h"
+#define CUDA_THREADS_N 16
+#define CUDA_THREADS_NN (16 * 16)
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+__global__ void cudak_(sigmoid)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol, int stride) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ long idx;
+ if (i >= nrow || j >= ncol) return;
+ idx = j + i * stride;
+ b[idx] = 1.0 / (1.0 + exp(-a[idx]));
+}
+
+__global__ void cudak_(softmax_final)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ const MATRIX_ELEM *max, const MATRIX_ELEM *deno,
+ int nrow, int ncol, int stride, int mstride) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ long idx;
+ if (i >= nrow || j >= ncol) return;
+ idx = j + i * stride;
+ b[idx] = exp(a[idx] - max[0 + i * mstride]) / deno[0 + i * mstride];
+}
+
+__global__ void cudak_(block_reduce_sum)(const MATRIX_ELEM *input,
+ MATRIX_ELEM *output,
+ const int istride, const int ostride,
+ const int n) {
+ extern __shared__ MATRIX_ELEM cudak_(arr)[];
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
+ __syncthreads();
+ for (int offset = blockDim.x >> 1; offset; offset >>= 1)
+ {
+ if (threadIdx.x < offset)
+ cudak_(arr)[threadIdx.x] += cudak_(arr)[threadIdx.x + offset];
+ __syncthreads();
+ }
+ if (threadIdx.x == 0)
+ output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
+}
+
+__global__ void cudak_(block_reduce_softmax_sum)(const MATRIX_ELEM *input,
+ MATRIX_ELEM *output,
+ const MATRIX_ELEM *max,
+ const int istride, const int ostride,
+ const int mstride, const int n) {
+ extern __shared__ MATRIX_ELEM cudak_(arr)[];
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ cudak_(arr)[threadIdx.x] = j < n ? exp(input[j + istride * blockIdx.y] - \
+ max[0 + mstride * blockIdx.y]) : 0;
+ __syncthreads();
+ for (int offset = blockDim.x >> 1; offset; offset >>= 1)
+ {
+ if (threadIdx.x < offset)
+ cudak_(arr)[threadIdx.x] += cudak_(arr)[threadIdx.x + offset];
+ __syncthreads();
+ }
+ if (threadIdx.x == 0)
+ output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
+}
+
+__global__ void cudak_(block_reduce_max)(const MATRIX_ELEM *input,
+ MATRIX_ELEM *output,
+ const int istride, const int ostride,
+ const int n) {
+ extern __shared__ MATRIX_ELEM cudak_(arr)[];
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
+ __syncthreads();
+ for (int offset = blockDim.x >> 1; offset; offset >>= 1)
+ {
+ if (threadIdx.x < offset)
+ {
+ MATRIX_ELEM l = cudak_(arr)[threadIdx.x],
+ r = cudak_(arr)[threadIdx.x + offset];
+ if (r > l) cudak_(arr)[threadIdx.x] = r;
+ }
+ __syncthreads();
+ }
+ if (threadIdx.x == 0)
+ output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
+}
+
+extern "C" {
+#include "../cukernel.h"
+ void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b) {
+ dim3 threadsPerBlock(CUDA_THREADS_N,
+ CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+ CEIL_DIV(b->nrow, threadsPerBlock.y));
+ cudak_(sigmoid)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
+ b->stride / sizeof(MATRIX_ELEM));
+ }
+
+ void cudak_(cuda_colsum)(const Matrix *a, Matrix *b) {
+ dim3 block(CUDA_THREADS_NN, 1);
+ int ncol = a->ncol;
+ int blocks_per_row = CEIL_DIV(ncol, block.x);
+ dim3 grid(blocks_per_row, a->nrow);
+ MATRIX_ELEM *res;
+ size_t stride;
+ cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
+ cudak_(block_reduce_sum)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
+ (MATRIX_ELEM_PTR(a), res,
+ a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
+ ncol);
+ ncol = blocks_per_row;
+ assert((unsigned long)ncol <= block.x);
+ grid.x = 1;
+ cudak_(block_reduce_sum)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
+ (res, MATRIX_ELEM_PTR(b),
+ stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
+ ncol);
+ cudaFree(res);
+ }
+
+ void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max,
+ const Matrix *deno, Matrix *b) {
+ dim3 threadsPerBlock(CUDA_THREADS_N,
+ CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+ CEIL_DIV(b->nrow, threadsPerBlock.y));
+ cudak_(softmax_final)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
+ MATRIX_ELEM_PTR(max), MATRIX_ELEM_PTR(deno),
+ b->nrow, b->ncol,
+ b->stride / sizeof(MATRIX_ELEM),
+ max->stride / sizeof(MATRIX_ELEM));
+ }
+
+ void cudak_(cuda_softmax_denominator)(const Matrix *a, const Matrix *max, Matrix *b) {
+ dim3 block(CUDA_THREADS_NN, 1);
+ int ncol = a->ncol;
+ int blocks_per_row = CEIL_DIV(ncol, block.x);
+ dim3 grid(blocks_per_row, a->nrow);
+ MATRIX_ELEM *res;
+ size_t stride;
+ assert(max->ncol == 1);
+ cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
+ cudak_(block_reduce_softmax_sum)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
+ (MATRIX_ELEM_PTR(a), res, MATRIX_ELEM_PTR(max),
+ a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
+ max->stride / sizeof(MATRIX_ELEM),
+ ncol);
+ ncol = blocks_per_row;
+ assert((unsigned long)ncol <= block.x);
+ grid.x = 1;
+ cudak_(block_reduce_sum)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
+ (res, MATRIX_ELEM_PTR(b),
+ stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
+ ncol);
+ cudaFree(res);
+ }
+
+ void cudak_(cuda_colmax)(const Matrix *a, Matrix *b) {
+ dim3 block(CUDA_THREADS_NN, 1);
+ int ncol = a->ncol;
+ int blocks_per_row = CEIL_DIV(ncol, block.x);
+ dim3 grid(blocks_per_row, a->nrow);
+ MATRIX_ELEM *res;
+ size_t stride;
+ cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
+ cudak_(block_reduce_max)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
+ (MATRIX_ELEM_PTR(a), res,
+ a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
+ ncol);
+ ncol = blocks_per_row;
+ assert((unsigned long)ncol <= block.x);
+ grid.x = 1;
+ cudak_(block_reduce_max)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
+ (res, MATRIX_ELEM_PTR(b),
+ stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
+ ncol);
+ cudaFree(res);
+ }
+}
+#endif
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
new file mode 100644
index 0000000..f0ef99d
--- /dev/null
+++ b/matrix/generic/cumatrix.c
@@ -0,0 +1,143 @@
+#ifdef NERV_GENERIC_CUMATRIX
+#include "matrix.h"
+#include "elem_type.h"
+
+#define MATRIX_DATA_FREE(ptr) cuda_matrix_(free)(ptr)
+#define MATRIX_DATA_ALLOC(dptr, stride, width, height) \
+ cuda_matrix_(alloc)(dptr, stride, width, height)
+#define MATRIX_DATA_WRITE(data, idx, val) cuda_matrix_(write)(data, idx, val)
+#define MATRIX_DATA_READ(data, idx) cuda_matrix_(read)(data, idx)
+#define MATRIX_INIT(L) cuda_matrix_(init)(L)
+#define NERV_GENERIC_MATRIX
+#define NERV_GENERIC_CUKERNEL
+#include "../../common.h"
+#include "../cukernel.h"
+#include "cuda.h"
+#include "cuda_runtime.h"
+#include "driver_types.h"
+#include "cublas_v2.h"
+
+static cublasHandle_t cublas_handle;
+
+Matrix *nerv_matrix_(new_)(long nrow, long ncol);
+static int nerv_matrix_(add)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ Matrix *c;
+ long nrow, ncol;
+ if (!(a->nrow == b->nrow && a->ncol == b->ncol))
+ nerv_error(L, "Matrices should be of the same dimension");
+ nrow = a->nrow;
+ ncol = a->ncol;
+ c = nerv_matrix_(new_)(nrow, ncol);
+ MATRIX_ELEM alpha = 1.0f, beta = 1.0f;
+ NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
+ ncol, nrow,
+ &alpha,
+ MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
+ &beta,
+ MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM),
+ MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM));
+ luaT_pushudata(L, c, nerv_matrix_(tname));
+ return 1;
+}
+
+static int nerv_matrix_(mul)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ Matrix *c;
+ if (a->ncol != b->nrow)
+ nerv_error(L, "Wrong dimension of multipliers");
+ c = nerv_matrix_(new_)(a->nrow, b->ncol);
+ MATRIX_ELEM alpha = 1.0f, beta = 0.0f;
+ NERV_CUBLAS_(gemm)(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
+ b->ncol, a->nrow, b->nrow,
+ &alpha,
+ MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM),
+ MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
+ &beta,
+ MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM));
+ luaT_pushudata(L, c, nerv_matrix_(tname));
+ return 1;
+}
+
+static int nerv_matrix_(sigmoid)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = nerv_matrix_(new_)(a->nrow, a->ncol);
+ cudak_(cuda_sigmoid)(a, b);
+ luaT_pushudata(L, b, nerv_matrix_(tname));
+ return 1;
+}
+
+static int nerv_matrix_(softmax)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *max = nerv_matrix_(new_)(a->nrow, 1);
+ Matrix *dno = nerv_matrix_(new_)(a->nrow, 1);
+ Matrix *b = nerv_matrix_(new_)(a->nrow, a->ncol);
+ cudak_(cuda_colmax)(a, max);
+ cudak_(cuda_softmax_denominator)(a, max, dno);
+ cudak_(cuda_softmax_final)(a, max, dno, b);
+ luaT_pushudata(L, b, nerv_matrix_(tname));
+ return 1;
+}
+
+static int nerv_matrix_(colsum)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = nerv_matrix_(new_)(a->nrow, 1);
+ cudak_(cuda_colsum)(a, b);
+ luaT_pushudata(L, b, nerv_matrix_(tname));
+ return 1;
+}
+
+static int nerv_matrix_(colmax)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = nerv_matrix_(new_)(a->nrow, 1);
+ cudak_(cuda_colmax)(a, b);
+ luaT_pushudata(L, b, nerv_matrix_(tname));
+ return 1;
+}
+
+static const luaL_Reg nerv_matrix_(extra_methods)[] = {
+ {"__add__", nerv_matrix_(add)},
+ {"__mul__", nerv_matrix_(mul)},
+ {"sigmoid", nerv_matrix_(sigmoid)},
+ {"softmax", nerv_matrix_(softmax)},
+ {"colsum", nerv_matrix_(colsum)},
+ {"colmax", nerv_matrix_(colmax)},
+ {NULL, NULL}
+};
+
+static void cuda_matrix_(init)(lua_State *L) {
+ luaN_append_methods(L, nerv_matrix_(extra_methods));
+ cublasCreate(&cublas_handle);
+}
+
+static void cuda_matrix_(free)(MATRIX_ELEM *ptr) {
+ cudaFree(ptr);
+}
+
+static void cuda_matrix_(alloc)(MATRIX_ELEM **dptr, size_t *stride,
+ long width, long height) {
+ cudaMallocPitch((void **)dptr, stride, width, height);
+}
+
+static MATRIX_ELEM cuda_matrix_(read)(MATRIX_ELEM *data, int idx) {
+ MATRIX_ELEM res;
+ cudaMemcpy(&res, data + idx, sizeof(MATRIX_ELEM), cudaMemcpyDeviceToHost);
+ return res;
+}
+
+static void cuda_matrix_(write)(MATRIX_ELEM *data, int idx, MATRIX_ELEM val) {
+ cudaMemcpy(data + idx, &val, sizeof(MATRIX_ELEM), cudaMemcpyHostToDevice);
+}
+
+int nerv_matrix_(get_elem)(lua_State *L) {
+ return nerv_error_method_not_implemented(L);
+}
+
+int nerv_matrix_(set_elem)(lua_State *L) {
+ return nerv_error_method_not_implemented(L);
+}
+
+#include "matrix.c"
+#endif
diff --git a/matrix/generic/elem_type.h b/matrix/generic/elem_type.h
new file mode 100644
index 0000000..8f80306
--- /dev/null
+++ b/matrix/generic/elem_type.h
@@ -0,0 +1,11 @@
+#ifdef MATRIX_USE_FLOAT
+
+#define MATRIX_ELEM float
+#define MATRIX_ELEM_PTR(self) ((self)->data.f)
+
+#elif defined(MATRIX_USE_DOUBLE)
+
+#define MATRIX_ELEM double
+#define MATRIX_ELEM_PTR(self) ((self)->data.d)
+
+#endif
diff --git a/matrix/generic/matrix.c b/matrix/generic/matrix.c
index 9ced397..f0f81a9 100644
--- a/matrix/generic/matrix.c
+++ b/matrix/generic/matrix.c
@@ -3,59 +3,61 @@
#include "matrix.h"
extern const char *nerv_matrix_tname;
-extern const char *nerv_float_matrix_(tname);
+extern const char *nerv_matrix_(tname);
-void nerv_float_matrix_(data_free)(Matrix *self) {
+void nerv_matrix_(data_free)(Matrix *self) {
if (--(*self->data_ref) == 0)
- MATRIX_DATA_FREE(self->data.f);
+ MATRIX_DATA_FREE(MATRIX_ELEM_PTR(self));
}
-void nerv_float_matrix_(data_retain)(Matrix *self) {
+void nerv_matrix_(data_retain)(Matrix *self) {
(*self->data_ref)++;
}
-Matrix *nerv_float_matrix_(new_)(long nrow, long ncol) {
+Matrix *nerv_matrix_(new_)(long nrow, long ncol) {
Matrix *self = (Matrix *)malloc(sizeof(Matrix));
self->nrow = nrow;
self->ncol = ncol;
self->nmax = self->nrow * self->ncol;
- MATRIX_DATA_ALLOC(&self->data.f, &self->stride, sizeof(float) * self->ncol, self->nrow);
+ MATRIX_DATA_ALLOC(&MATRIX_ELEM_PTR(self), &self->stride,
+ sizeof(MATRIX_ELEM) * self->ncol, self->nrow);
self->data_ref = (long *)malloc(sizeof(long));
*self->data_ref = 0;
- nerv_float_matrix_(data_retain)(self);
+ nerv_matrix_(data_retain)(self);
return self;
}
-int nerv_float_matrix_(new)(lua_State *L) {
- luaT_pushudata(L, nerv_float_matrix_(new_)(luaL_checkinteger(L, 1),
+int nerv_matrix_(new)(lua_State *L) {
+ luaT_pushudata(L, nerv_matrix_(new_)(luaL_checkinteger(L, 1),
luaL_checkinteger(L, 2)),
- nerv_float_matrix_(tname));
+ nerv_matrix_(tname));
return 1;
}
-int nerv_float_matrix_(destroy)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
- nerv_float_matrix_(data_free)(self);
+int nerv_matrix_(destroy)(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ nerv_matrix_(data_free)(self);
return 0;
}
-int nerv_float_matrix_(get_elem)(lua_State *L);
-int nerv_float_matrix_(set_elem)(lua_State *L);
+int nerv_matrix_(get_elem)(lua_State *L);
+int nerv_matrix_(set_elem)(lua_State *L);
-static Matrix *nerv_float_matrix_(getrow)(Matrix *self, int row) {
+static Matrix *nerv_matrix_(getrow)(Matrix *self, int row) {
Matrix *prow = (Matrix *)malloc(sizeof(Matrix));
prow->ncol = self->ncol;
prow->nrow = 1;
prow->stride = self->stride;
prow->nmax = prow->ncol;
- prow->data.f = (float *)((char *)self->data.f + row * self->stride);
+ MATRIX_ELEM_PTR(prow) = \
+ (MATRIX_ELEM *)((char *)MATRIX_ELEM_PTR(self) + row * self->stride);
prow->data_ref = self->data_ref;
- nerv_float_matrix_(data_retain)(self);
+ nerv_matrix_(data_retain)(self);
return prow;
}
-static int nerv_float_matrix_(newindex)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+static int nerv_matrix_(newindex)(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
if (lua_isnumber(L, 2))
{
int idx = luaL_checkinteger(L, 2);
@@ -63,7 +65,8 @@ static int nerv_float_matrix_(newindex)(lua_State *L) {
{
if (idx < 0 || idx >= self->ncol)
nerv_error(L, "index must be within range [0, %d)", self->ncol);
- MATRIX_DATA_WRITE(self->data.f, idx, luaL_checknumber(L, 3));
+ MATRIX_DATA_WRITE(MATRIX_ELEM_PTR(self), idx,
+ luaL_checknumber(L, 3));
}
else
nerv_error(L, "cannot assign a scalar to row vector");
@@ -78,8 +81,8 @@ static int nerv_float_matrix_(newindex)(lua_State *L) {
}
-static int nerv_float_matrix_(index)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+static int nerv_matrix_(index)(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
if (lua_isnumber(L, 2))
{
int idx = luaL_checkinteger(L, 2);
@@ -87,13 +90,13 @@ static int nerv_float_matrix_(index)(lua_State *L) {
{
if (idx < 0 || idx >= self->ncol)
nerv_error(L, "index must be within range [0, %d)", self->ncol);
- lua_pushnumber(L, MATRIX_DATA_READ(self->data.f, idx));
+ lua_pushnumber(L, MATRIX_DATA_READ(MATRIX_ELEM_PTR(self), idx));
}
else
{
if (idx < 0 || idx >= self->nrow)
nerv_error(L, "index must be within range [0, %d)", self->nrow);
- luaT_pushudata(L, nerv_float_matrix_(getrow)(self, idx), nerv_float_matrix_(tname));
+ luaT_pushudata(L, nerv_matrix_(getrow)(self, idx), nerv_matrix_(tname));
}
lua_pushboolean(L, 1);
return 2;
@@ -105,33 +108,33 @@ static int nerv_float_matrix_(index)(lua_State *L) {
}
}
-static int nerv_float_matrix_(ncol)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+static int nerv_matrix_(ncol)(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
lua_pushinteger(L, self->ncol);
return 1;
}
-static int nerv_float_matrix_(nrow)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+static int nerv_matrix_(nrow)(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
lua_pushinteger(L, self->nrow);
return 1;
}
-static const luaL_Reg nerv_float_matrix_(methods)[] = {
- {"get_elem", nerv_float_matrix_(get_elem)},
- {"set_elem", nerv_float_matrix_(set_elem)},
- {"ncol", nerv_float_matrix_(ncol)},
- {"nrow", nerv_float_matrix_(nrow)},
- {"__index__", nerv_float_matrix_(index)},
- {"__newindex__", nerv_float_matrix_(newindex)},
+static const luaL_Reg nerv_matrix_(methods)[] = {
+ {"get_elem", nerv_matrix_(get_elem)},
+ {"set_elem", nerv_matrix_(set_elem)},
+ {"ncol", nerv_matrix_(ncol)},
+ {"nrow", nerv_matrix_(nrow)},
+ {"__index__", nerv_matrix_(index)},
+ {"__newindex__", nerv_matrix_(newindex)},
{NULL, NULL}
};
-void nerv_float_matrix_(init)(lua_State *L) {
- luaT_newmetatable(L, nerv_float_matrix_(tname), nerv_matrix_tname,
- nerv_float_matrix_(new), nerv_float_matrix_(destroy), NULL);
- luaL_register(L, NULL, nerv_float_matrix_(methods));
+void nerv_matrix_(init)(lua_State *L) {
+ luaT_newmetatable(L, nerv_matrix_(tname), nerv_matrix_tname,
+ nerv_matrix_(new), nerv_matrix_(destroy), NULL);
+ luaL_register(L, NULL, nerv_matrix_(methods));
#ifdef MATRIX_INIT
MATRIX_INIT(L);
#endif
diff --git a/matrix/generic/matrix.h b/matrix/generic/matrix.h
index 264859b..276ca5c 100644
--- a/matrix/generic/matrix.h
+++ b/matrix/generic/matrix.h
@@ -1,6 +1,7 @@
#ifndef NERV_GENERIC_MATRIX_H
#define NERV_GENERIC_MATRIX_H
+#include <stddef.h>
typedef struct Matrix {
size_t stride; /* size of a row */
long ncol, nrow, nmax; /* dimension of the matrix */
diff --git a/matrix/matrix.c b/matrix/generic/mmatrix.c
index b392f56..ac71c3d 100644
--- a/matrix/matrix.c
+++ b/matrix/generic/mmatrix.c
@@ -1,23 +1,25 @@
+#ifdef NERV_GENERIC_MMATRIX
+#include "matrix.h"
+#include "elem_type.h"
#define MATRIX_DATA_FREE(ptr) free(ptr)
-#define MATRIX_DATA_ALLOC(dptr, stride, width, height) host_float_array_alloc(dptr, stride, width, height)
+#define MATRIX_DATA_ALLOC(dptr, stride, width, height) \
+ host_matrix_(alloc)(dptr, stride, width, height)
#define MATRIX_DATA_STRIDE(ncol) (sizeof(float) * (ncol))
#define MATRIX_DATA_WRITE(data, idx, val) (data[idx] = val)
#define MATRIX_DATA_READ(data, idx) (data[idx])
#define NERV_GENERIC_MATRIX
-#define nerv_float_matrix_(NAME) nerv_float_matrix_host_ ## NAME
-#include "../common.h"
-#include "generic/matrix.h"
+#include "../../common.h"
-const char *nerv_float_matrix_(tname) = "nerv.FloatMatrix";
+const char *nerv_matrix_(tname) = "nerv.FloatMMatrix";
-static void host_float_array_alloc(float **dptr, size_t *stride,
+static void host_matrix_(alloc)(float **dptr, size_t *stride,
long width, long height) {
*dptr = (float *)malloc(width * height);
*stride = width;
}
-int nerv_float_matrix_(get_elem)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+int nerv_matrix_(get_elem)(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
int idx = luaL_checkinteger(L, 2);
if (idx < 0 || idx >= self->nmax)
nerv_error(L, "index must be within range [0, %d)", self->nmax);
@@ -25,8 +27,8 @@ int nerv_float_matrix_(get_elem)(lua_State *L) {
return 1;
}
-int nerv_float_matrix_(set_elem)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+int nerv_matrix_(set_elem)(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
int idx = luaL_checkinteger(L, 2);
float v = luaL_checknumber(L, 3);
if (idx < 0 || idx >= self->nmax)
@@ -35,4 +37,5 @@ int nerv_float_matrix_(set_elem)(lua_State *L) {
return 0;
}
-#include "generic/matrix.c"
+#include "matrix.c"
+#endif
diff --git a/matrix/init.c b/matrix/init.c
index e723f55..fb1c287 100644
--- a/matrix/init.c
+++ b/matrix/init.c
@@ -2,8 +2,11 @@
#include "generic/matrix.h"
const char *nerv_matrix_tname = "nerv.Matrix";
-void nerv_float_matrix_host_init(lua_State *L);
-void nerv_float_matrix_cuda_init(lua_State *L);
+void nerv_matrix_float_host_init(lua_State *L);
+void nerv_matrix_float_cuda_init(lua_State *L);
+void nerv_matrix_double_host_init(lua_State *L);
+void nerv_matrix_double_cuda_init(lua_State *L);
+
static const luaL_Reg matrix_methods[] = {
{"__tostring__", nerv_error_method_not_implemented },
{"__add__", nerv_error_method_not_implemented },
@@ -17,6 +20,8 @@ void nerv_matrix_init(lua_State *L) {
luaT_newmetatable(L, nerv_matrix_tname, NULL, NULL, NULL, NULL);
luaL_register(L, NULL, matrix_methods);
lua_pop(L, 1);
- nerv_float_matrix_host_init(L);
- nerv_float_matrix_cuda_init(L);
+ nerv_matrix_float_host_init(L);
+ nerv_matrix_float_cuda_init(L);
+/* nerv_matrix_double_host_init(L); */
+ nerv_matrix_double_cuda_init(L);
}
diff --git a/matrix/init.lua b/matrix/init.lua
index 59b8384..d6aab73 100644
--- a/matrix/init.lua
+++ b/matrix/init.lua
@@ -1,4 +1,4 @@
-function nerv.FloatCuMatrix:__tostring__()
+function nerv.Matrix:__tostring__()
local ncol = self:ncol()
local nrow = self:nrow()
local strt = {}
@@ -12,27 +12,11 @@ function nerv.FloatCuMatrix:__tostring__()
for row = 0, nrow - 1 do
local rp = self[row]
for col = 0, ncol - 1 do
- table.insert(strt, string.format("%f ", rp[col]))
+ table.insert(strt, string.format("%.10f ", rp[col]))
end
table.insert(strt, "\n")
end
end
- table.insert(strt, string.format("[Float Matrix %d x %d]", nrow, ncol))
- return table.concat(strt)
-end
-
-function nerv.FloatMatrix:__tostring__()
- local ncol = self:ncol()
- local nrow = self:nrow()
- local i = 0
- local strt = {}
- for row = 0, nrow - 1 do
- for col = 0, ncol - 1 do
- table.insert(strt, string.format("%f ", self:get_elem(i)))
- i = i + 1
- end
- table.insert(strt, "\n")
- end
- table.insert(strt, string.format("[Float Matrix %d x %d]", nrow, ncol))
+ table.insert(strt, string.format("[Matrix %d x %d]", nrow, ncol))
return table.concat(strt)
end
diff --git a/matrix/mmatrix.c b/matrix/mmatrix.c
new file mode 100644
index 0000000..f616d51
--- /dev/null
+++ b/matrix/mmatrix.c
@@ -0,0 +1,5 @@
+#define NERV_GENERIC_MMATRIX
+#define MATRIX_USE_FLOAT
+#define host_matrix_(NAME) host_matrix_float_ ## NAME
+#define nerv_matrix_(NAME) nerv_matrix_float_host_ ## NAME
+#include "generic/mmatrix.c"
diff --git a/matrix_example.lua b/matrix_example.lua
deleted file mode 100644
index 361e126..0000000
--- a/matrix_example.lua
+++ /dev/null
@@ -1,16 +0,0 @@
-t = nerv.FloatMatrix(2, 3)
-print(t:get_elem(1))
-t:set_elem(1, 3.23432);
-print(t:get_elem(1))
-print(t)
-t = nerv.FloatMatrix(10, 20)
-t:set_elem(1, 3.34);
-print(t)
-a = t[1]
-for i = 0, 9 do
- for j = 0, 19 do
- t[i][j] = i + j
- end
-end
-print(t)
-print(a)
diff --git a/mmatrix_example.lua b/mmatrix_example.lua
new file mode 100644
index 0000000..5b34779
--- /dev/null
+++ b/mmatrix_example.lua
@@ -0,0 +1,9 @@
+t = nerv.FloatMMatrix(5, 10)
+a = t[1]
+for i = 0, 4 do
+ for j = 0, 9 do
+ t[i][j] = i + j
+ end
+end
+print(t)
+print(a)