-rw-r--r--  nerv/Makefile                         2
-rw-r--r--  nerv/layer/init.lua                   1
-rw-r--r--  nerv/layer/relu.lua                  33
-rw-r--r--  nerv/lib/matrix/generic/cukernel.cu  44
-rw-r--r--  nerv/lib/matrix/generic/cumatrix.c   19
-rw-r--r--  nerv/lib/matrix/generic/cumatrix.h    5
-rw-r--r--  nerv/lib/matrix/generic/mmatrix.c    39
-rw-r--r--  nerv/lib/matrix/generic/mmatrix.h    10
-rw-r--r--  nerv/matrix/generic/cumatrix.c        2
-rw-r--r--  nerv/matrix/generic/matrix.c         23
-rw-r--r--  nerv/matrix/generic/mmatrix.c         2
-rw-r--r--  nerv/test/cumatrix_func.out          12
-rw-r--r--  nerv/test/matrix_func.lua             3
-rw-r--r--  nerv/test/mmatrix_func.out           12
14 files changed, 206 insertions, 1 deletion
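
For orientation: the commit threads a ReLU activation through the whole stack, i.e. a Lua layer class, CUDA and CPU matrix primitives, the Lua bindings, and test coverage. A minimal usage sketch of the new layer (hypothetical: the id 'relu1', the gconf table and the 429-unit width are illustrative, assuming layer_conf carries dim_in/dim_out as in the other NERV layers):

    -- hypothetical wiring of the new layer; all names below are illustrative
    local gconf = {}                       -- assumed shape of global_conf
    local relu = nerv.ReluLayer('relu1', gconf,
                                {dim_in = {429}, dim_out = {429}})
    relu:init()                            -- errors unless dim_in[1] == dim_out[1]
    -- input/output are arrays of matrices, one entry per port (1/1 here):
    relu:propagate(input, output)          -- output[1] = max(input[1], 0)
    relu:back_propagate(bp_err, next_bp_err, input, output)
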
diff --git a/nerv/Makefile b/nerv/Makefile
index f74a92f..0d9934a 100644
--- a/nerv/Makefile
+++ b/nerv/Makefile
@@ -40,7 +40,7 @@ OBJS := $(CORE_OBJS) $(NERV_OBJS) $(LUAT_OBJS)
LIBS := $(INST_LIBDIR)/libnerv.so $(LIB_PATH)/libnervcore.so $(LIB_PATH)/libluaT.so
LUA_LIBS := matrix/init.lua io/init.lua init.lua \
layer/init.lua layer/affine.lua layer/sigmoid.lua layer/tanh.lua layer/softmax_ce.lua layer/softmax.lua \
- layer/lstmp.lua layer/projection.lua \
+ layer/lstmp.lua layer/projection.lua layer/relu.lua \
layer/window.lua layer/bias.lua layer/combiner.lua layer/mse.lua \
layer/elem_mul.lua layer/lstm.lua layer/lstm_gate.lua layer/dropout.lua layer/gru.lua \
layer/graph.lua layer/rnn.lua layer/duplicate.lua layer/identity.lua \
diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua
index 7521b7a..d175d02 100644
--- a/nerv/layer/init.lua
+++ b/nerv/layer/init.lua
@@ -280,6 +280,7 @@ nerv.include('duplicate.lua')
nerv.include('identity.lua')
nerv.include('projection.lua')
nerv.include('lstmp.lua')
+nerv.include('relu.lua')
-- The following lines are for backward compatibility, and will be removed in
-- the future. The use of these names is deprecated.
diff --git a/nerv/layer/relu.lua b/nerv/layer/relu.lua
new file mode 100644
index 0000000..b7951e7
--- /dev/null
+++ b/nerv/layer/relu.lua
@@ -0,0 +1,33 @@
+local ReluLayer = nerv.class('nerv.ReluLayer', 'nerv.Layer')
+
+function ReluLayer:__init(id, global_conf, layer_conf)
+ nerv.Layer.__init(self, id, global_conf, layer_conf)
+ self:check_dim_len(1, 1)
+end
+
+function ReluLayer:bind_params()
+end
+
+function ReluLayer:init()
+ if self.dim_in[1] ~= self.dim_out[1] then
+ nerv.error('mismatching dimensions of input and output')
+ end
+end
+
+function ReluLayer:batch_resize(batch_size)
+end
+
+function ReluLayer:update()
+end
+
+function ReluLayer:propagate(input, output)
+ output[1]:relu(input[1])
+end
+
+function ReluLayer:back_propagate(bp_err, next_bp_err, input, output)
+ next_bp_err[1]:relu_grad(bp_err[1], output[1])
+end
+
+function ReluLayer:get_params()
+ return nerv.ParamRepo({}, self.loc_type)
+end
diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu
index cf9d213..82bea14 100644
--- a/nerv/lib/matrix/generic/cukernel.cu
+++ b/nerv/lib/matrix/generic/cukernel.cu
@@ -90,6 +90,27 @@ __global__ void cudak_(tanh_grad)(const MATRIX_ELEM *output,
nerr[idx] = (1.0 - output[idx] * output[idx]) * err[idx];
}
+__global__ void cudak_(relu)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol, int stride) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ long idx;
+ if (i >= nrow || j >= ncol) return;
+ idx = j + i * stride;
+ b[idx] = a[idx] > 0 ? a[idx] : 0;
+}
+
+__global__ void cudak_(relu_grad)(const MATRIX_ELEM *output,
+ const MATRIX_ELEM *err,
+ MATRIX_ELEM *nerr,
+ int nrow, int ncol, int stride) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ long idx;
+ if (i >= nrow || j >= ncol) return;
+ idx = j + i * stride;
+ nerr[idx] = (output[idx] > 0 ? 1 : 0) * err[idx];
+}
__global__ void cudak_(softmax_final)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
const MATRIX_ELEM *max, const MATRIX_ELEM *deno,
int nrow, int ncol, int stride, int mstride) {
@@ -510,6 +531,29 @@ extern "C" {
cudaStreamSynchronize(0);
}
+ void cudak_(cuda_relu)(const Matrix *a, Matrix *b) {
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+ CEIL_DIV(b->nrow, threadsPerBlock.y));
+ cudak_(relu)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
+ b->stride / sizeof(MATRIX_ELEM));
+ cudaStreamSynchronize(0);
+ }
+
+ void cudak_(cuda_relu_grad)(const Matrix *output,
+ const Matrix *err, Matrix *nerr) {
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(nerr->ncol, threadsPerBlock.x),
+ CEIL_DIV(nerr->nrow, threadsPerBlock.y));
+ cudak_(relu_grad)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(output), MATRIX_ELEM_PTR(err),
+ MATRIX_ELEM_PTR(nerr),
+ nerr->nrow, nerr->ncol,
+ nerr->stride / sizeof(MATRIX_ELEM));
+ cudaStreamSynchronize(0);
+ }
+
void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b) {
dim3 block(CUDA_THREADS_NN, 1);
int ncol = a->ncol;
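
Both wrappers reuse the launch geometry of the existing sigmoid/tanh kernels: a 2-D grid of CUDA_THREADS_N x CUDA_THREADS_N thread blocks tiled over the output, with CEIL_DIV(a, b) presumably the usual integer (a + b - 1) / b. For the 3 x 4 test matrices and, say, 16 x 16 blocks, that is CEIL_DIV(4, 16) x CEIL_DIV(3, 16) = 1 x 1 block of 256 threads, and the in-kernel bounds check (i >= nrow || j >= ncol) masks off the 244 threads that fall outside the matrix.
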
diff --git a/nerv/lib/matrix/generic/cumatrix.c b/nerv/lib/matrix/generic/cumatrix.c
index bc5f285..432222a 100644
--- a/nerv/lib/matrix/generic/cumatrix.c
+++ b/nerv/lib/matrix/generic/cumatrix.c
@@ -117,6 +117,25 @@ void nerv_matrix_(tanh_grad)(Matrix *nerr, const Matrix *err, const Matrix *outp
NERV_SET_STATUS(status, NERV_NORMAL, 0);
}
+void nerv_matrix_(relu)(Matrix *a, const Matrix *b,
+ CuContext *context, Status *status) {
+ CHECK_SAME_DIMENSION(a, b, status);
+ PROFILE_START
+ cudak_(cuda_relu)(b, a);
+ PROFILE_STOP
+ NERV_SET_STATUS(status, NERV_NORMAL, 0);
+}
+
+void nerv_matrix_(relu_grad)(Matrix *nerr, const Matrix *err, const Matrix *output,
+ CuContext *context, Status *status) {
+ CHECK_SAME_DIMENSION(nerr, err, status);
+ CHECK_SAME_DIMENSION(nerr, output, status);
+ PROFILE_START
+ cudak_(cuda_relu_grad)(output, err, nerr);
+ PROFILE_STOP
+ NERV_SET_STATUS(status, NERV_NORMAL, 0);
+}
+
Matrix *nerv_matrix_(softmax)(Matrix *b, const Matrix *a,
CuContext *context, Status *status) {
Matrix *max, *max_idx;
diff --git a/nerv/lib/matrix/generic/cumatrix.h b/nerv/lib/matrix/generic/cumatrix.h
index 79bfc76..459513b 100644
--- a/nerv/lib/matrix/generic/cumatrix.h
+++ b/nerv/lib/matrix/generic/cumatrix.h
@@ -17,6 +17,11 @@ void nerv_matrix_(tanh)(Matrix *a, const Matrix *b,
void nerv_matrix_(tanh_grad)(Matrix *nerr, const Matrix *err,
const Matrix *output,
CuContext *context, Status *status);
+void nerv_matrix_(relu)(Matrix *a, const Matrix *b,
+ CuContext *context, Status *status);
+void nerv_matrix_(relu_grad)(Matrix *nerr, const Matrix *err,
+ const Matrix *output,
+ CuContext *context, Status *status);
Matrix *nerv_matrix_(softmax)(Matrix *b, const Matrix *a,
CuContext *context, Status *status);
diff --git a/nerv/lib/matrix/generic/mmatrix.c b/nerv/lib/matrix/generic/mmatrix.c
index ccfb2ce..e76d4fb 100644
--- a/nerv/lib/matrix/generic/mmatrix.c
+++ b/nerv/lib/matrix/generic/mmatrix.c
@@ -460,6 +460,45 @@ void nerv_matrix_(tanh_grad)(Matrix *nerr, const Matrix *err,
NERV_SET_STATUS(status, NERV_NORMAL, 0);
}
+void nerv_matrix_(relu)(Matrix *b, const Matrix *a,
+ MContext *context, Status *status) {
+ CHECK_SAME_DIMENSION(a, b, status);
+ int i, j;
+ size_t astride = a->stride, bstride = b->stride;
+ const MATRIX_ELEM *arow = MATRIX_ELEM_PTR(a);
+ MATRIX_ELEM *brow = MATRIX_ELEM_PTR(b);
+ for (i = 0; i < b->nrow; i++)
+ {
+ for (j = 0; j < b->ncol; j++)
+ brow[j] = arow[j] > 0 ? arow[j] : 0;
+ arow = MATRIX_NEXT_ROW_PTR(arow, astride);
+ brow = MATRIX_NEXT_ROW_PTR(brow, bstride);
+ }
+ NERV_SET_STATUS(status, NERV_NORMAL, 0);
+}
+
+void nerv_matrix_(relu_grad)(Matrix *nerr, const Matrix *err,
+ const Matrix *output,
+ MContext *context, Status *status) {
+ CHECK_SAME_DIMENSION(nerr, err, status);
+ CHECK_SAME_DIMENSION(nerr, output, status);
+ int i, j;
+ size_t nerr_stride = nerr->stride,
+ err_stride = err->stride,
+ out_stride = output->stride;
+ MATRIX_ELEM *nerr_row = MATRIX_ELEM_PTR(nerr);
+ const MATRIX_ELEM *err_row = MATRIX_ELEM_PTR(err),
+ *out_row = MATRIX_ELEM_PTR(output);
+ for (i = 0; i < nerr->nrow; i++)
+ {
+ for (j = 0; j < nerr->ncol; j++)
+ nerr_row[j] = (out_row[j] > 0 ? 1 : 0) * err_row[j];
+ nerr_row = MATRIX_NEXT_ROW_PTR(nerr_row, nerr_stride);
+ err_row = MATRIX_NEXT_ROW_PTR(err_row, err_stride);
+ out_row = MATRIX_NEXT_ROW_PTR(out_row, out_stride);
+ }
+ NERV_SET_STATUS(status, NERV_NORMAL, 0);
+}
void nerv_matrix_(expand_frm)(Matrix *a, const Matrix *b,
int cont, MContext *context, Status *status) {
if (a->nrow != b->nrow)
diff --git a/nerv/lib/matrix/generic/mmatrix.h b/nerv/lib/matrix/generic/mmatrix.h
index 41c39f6..7f494d6 100644
--- a/nerv/lib/matrix/generic/mmatrix.h
+++ b/nerv/lib/matrix/generic/mmatrix.h
@@ -13,6 +13,16 @@ void nerv_matrix_(sigmoid)(Matrix *a, const Matrix *b,
void nerv_matrix_(sigmoid_grad)(Matrix *nerr, const Matrix *err,
const Matrix *output,
MContext *context, Status *status);
+void nerv_matrix_(tanh)(Matrix *a, const Matrix *b,
+ MContext *context, Status *status);
+void nerv_matrix_(tanh_grad)(Matrix *nerr, const Matrix *err,
+ const Matrix *output,
+ MContext *context, Status *status);
+void nerv_matrix_(relu)(Matrix *a, const Matrix *b,
+ MContext *context, Status *status);
+void nerv_matrix_(relu_grad)(Matrix *nerr, const Matrix *err,
+ const Matrix *output,
+ MContext *context, Status *status);
Matrix *nerv_matrix_(softmax)(Matrix *b, const Matrix *a,
MContext *context, Status *status);
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c
index 540afd2..3481ede 100644
--- a/nerv/matrix/generic/cumatrix.c
+++ b/nerv/matrix/generic/cumatrix.c
@@ -230,6 +230,8 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"sigmoid_grad", nerv_matrix_(lua_sigmoid_grad)},
{"tanh", nerv_matrix_(lua_tanh)},
{"tanh_grad", nerv_matrix_(lua_tanh_grad)},
+ {"relu", nerv_matrix_(lua_relu)},
+ {"relu_grad", nerv_matrix_(lua_relu_grad)},
{"rand_uniform", nerv_matrix_(lua_rand_uniform)},
{"softmax", nerv_matrix_(lua_softmax)},
{"mul_elem", nerv_matrix_(lua_mul_elem)},
diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c
index 800408d..9f31b4b 100644
--- a/nerv/matrix/generic/matrix.c
+++ b/nerv/matrix/generic/matrix.c
@@ -430,4 +430,27 @@ static int nerv_matrix_(lua_tanh_grad)(lua_State *L) {
return 0;
}
+static int nerv_matrix_(lua_relu)(lua_State *L) {
+ Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 3);
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ nerv_matrix_(relu)(a, b, context, &status);
+ NERV_LUA_CHECK_STATUS(L, status);
+ return 0;
+}
+
+static int nerv_matrix_(lua_relu_grad)(lua_State *L) {
+ Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
+ Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname));
+ nerv_matrix_(relu_grad)(nerr, err, output, context, &status);
+ NERV_LUA_CHECK_STATUS(L, status);
+ return 0;
+}
+
#endif
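
Both bindings follow the destination-first convention of the neighbouring methods (self is the matrix written to), so at the Lua level the calls look like this hypothetical snippet, with all matrices of identical shape:

    b:relu(a)                    -- b = max(a, 0); cf. c:relu(a) in the test below
    nerr:relu_grad(err, output)  -- nerr = err gated by output > 0
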
diff --git a/nerv/matrix/generic/mmatrix.c b/nerv/matrix/generic/mmatrix.c
index c03aee4..530888b 100644
--- a/nerv/matrix/generic/mmatrix.c
+++ b/nerv/matrix/generic/mmatrix.c
@@ -122,6 +122,8 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"sigmoid_grad", nerv_matrix_(lua_sigmoid_grad)},
{"tanh", nerv_matrix_(lua_tanh)},
{"tanh_grad", nerv_matrix_(lua_tanh_grad)},
+ {"relu", nerv_matrix_(lua_relu)},
+ {"relu_grad", nerv_matrix_(lua_relu_grad)},
{"softmax", nerv_matrix_(lua_softmax)},
{"mul_elem", nerv_matrix_(lua_mul_elem)},
{"log_elem", nerv_matrix_(lua_log_elem)},
diff --git a/nerv/test/cumatrix_func.out b/nerv/test/cumatrix_func.out
index 44e9015..651476f 100644
--- a/nerv/test/cumatrix_func.out
+++ b/nerv/test/cumatrix_func.out
@@ -179,6 +179,10 @@
0.76159418 0.96402758 0.99505478 0.99932933
0.96402758 0.99505478 0.99932933 0.99990922
[nerv.CuMatrixFloat 3 x 4]
+0.00000000 0.00000000 0.00000000 0.01483566
+0.00000000 0.00000000 0.01483566 1.00201201
+0.00000000 0.01483566 1.00201201 2.00027227
+[nerv.CuMatrixFloat 3 x 4]
0.00000000 1.00000000 2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000 10.00000000 11.00000000 12.00000000 13.00000000 14.00000000 15.00000000 16.00000000 17.00000000 18.00000000 19.00000000 20.00000000 21.00000000 22.00000000 23.00000000 24.00000000 25.00000000 26.00000000 27.00000000 28.00000000 29.00000000 30.00000000 31.00000000 32.00000000 33.00000000 34.00000000 35.00000000 36.00000000 37.00000000 38.00000000 39.00000000
1.00000000 2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000 10.00000000 11.00000000 12.00000000 13.00000000 14.00000000 15.00000000 16.00000000 17.00000000 18.00000000 19.00000000 20.00000000 21.00000000 22.00000000 23.00000000 24.00000000 25.00000000 26.00000000 27.00000000 28.00000000 29.00000000 30.00000000 31.00000000 32.00000000 33.00000000 34.00000000 35.00000000 36.00000000 37.00000000 38.00000000 39.00000000 40.00000000
2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000 10.00000000 11.00000000 12.00000000 13.00000000 14.00000000 15.00000000 16.00000000 17.00000000 18.00000000 19.00000000 20.00000000 21.00000000 22.00000000 23.00000000 24.00000000 25.00000000 26.00000000 27.00000000 28.00000000 29.00000000 30.00000000 31.00000000 32.00000000 33.00000000 34.00000000 35.00000000 36.00000000 37.00000000 38.00000000 39.00000000 40.00000000 41.00000000
@@ -1530,6 +1534,10 @@
0.76159418 0.96402758 0.99505478 0.99932933
0.96402758 0.99505478 0.99932933 0.99990922
[nerv.CuMatrixFloat 3 x 4]
+0.00000000 0.00000000 0.00000000 0.01483566
+0.00000000 0.00000000 0.01483566 1.00201201
+0.00000000 0.01483566 1.00201201 2.00027227
+[nerv.CuMatrixFloat 3 x 4]
0.00000000 1.00000000 2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000
1.00000000 2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000 10.00000000
2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000 10.00000000 11.00000000
@@ -1961,3 +1969,7 @@
0.76159418 0.96402758 0.99505478 0.99932933
0.96402758 0.99505478 0.99932933 0.99990922
[nerv.CuMatrixFloat 3 x 4]
+0.00000000 0.00000000 0.00000000 0.01483566
+0.00000000 0.00000000 0.01483566 1.00201201
+0.00000000 0.01483566 1.00201201 2.00027227
+[nerv.CuMatrixFloat 3 x 4]
diff --git a/nerv/test/matrix_func.lua b/nerv/test/matrix_func.lua
index 817d463..90bb27f 100644
--- a/nerv/test/matrix_func.lua
+++ b/nerv/test/matrix_func.lua
@@ -164,6 +164,9 @@ function _test_all_shape(mat_type, m, n, k, fill)
local c = a:create()
c:tanh(a)
print(c)
+ a:add(a, c, 1.0, -3.0)
+ c:relu(a)
+ print(c)
end
function test_all(mat_type)
_test_all_shape(mat_type, 3, 4, 2, _pattern_fill)
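
For the record, the rows added to the .out files follow directly from this sequence: _pattern_fill sets a[i][j] = i + j = x (as the neighbouring tanh rows confirm), c:tanh(a) yields tanh(x), and a:add(a, c, 1.0, -3.0) leaves x - 3*tanh(x) in a. The final relu zeroes the non-positives: x = 0, 1, 2 give 0, 1 - 2.28478 and 2 - 2.89208 (all <= 0, hence 0), while x = 3, 4, 5 give 0.01484, 1.00201 and 2.00027, matching the three new rows in each expected-output file.
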
diff --git a/nerv/test/mmatrix_func.out b/nerv/test/mmatrix_func.out
index 721ee21..6244de4 100644
--- a/nerv/test/mmatrix_func.out
+++ b/nerv/test/mmatrix_func.out
@@ -179,6 +179,10 @@
0.76159418 0.96402758 0.99505478 0.99932933
0.96402758 0.99505478 0.99932933 0.99990922
[nerv.MMatrixFloat 3 x 4]
+0.00000000 0.00000000 0.00000000 0.01483560
+0.00000000 0.00000000 0.01483560 1.00201201
+0.00000000 0.01483560 1.00201201 2.00027227
+[nerv.MMatrixFloat 3 x 4]
0.00000000 1.00000000 2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000 10.00000000 11.00000000 12.00000000 13.00000000 14.00000000 15.00000000 16.00000000 17.00000000 18.00000000 19.00000000 20.00000000 21.00000000 22.00000000 23.00000000 24.00000000 25.00000000 26.00000000 27.00000000 28.00000000 29.00000000 30.00000000 31.00000000 32.00000000 33.00000000 34.00000000 35.00000000 36.00000000 37.00000000 38.00000000 39.00000000
1.00000000 2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000 10.00000000 11.00000000 12.00000000 13.00000000 14.00000000 15.00000000 16.00000000 17.00000000 18.00000000 19.00000000 20.00000000 21.00000000 22.00000000 23.00000000 24.00000000 25.00000000 26.00000000 27.00000000 28.00000000 29.00000000 30.00000000 31.00000000 32.00000000 33.00000000 34.00000000 35.00000000 36.00000000 37.00000000 38.00000000 39.00000000 40.00000000
2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000 10.00000000 11.00000000 12.00000000 13.00000000 14.00000000 15.00000000 16.00000000 17.00000000 18.00000000 19.00000000 20.00000000 21.00000000 22.00000000 23.00000000 24.00000000 25.00000000 26.00000000 27.00000000 28.00000000 29.00000000 30.00000000 31.00000000 32.00000000 33.00000000 34.00000000 35.00000000 36.00000000 37.00000000 38.00000000 39.00000000 40.00000000 41.00000000
@@ -1530,6 +1534,10 @@
0.76159418 0.96402758 0.99505478 0.99932933
0.96402758 0.99505478 0.99932933 0.99990922
[nerv.MMatrixFloat 3 x 4]
+0.00000000 0.00000000 0.00000000 0.01483560
+0.00000000 0.00000000 0.01483560 1.00201201
+0.00000000 0.01483560 1.00201201 2.00027227
+[nerv.MMatrixFloat 3 x 4]
0.00000000 1.00000000 2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000
1.00000000 2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000 10.00000000
2.00000000 3.00000000 4.00000000 5.00000000 6.00000000 7.00000000 8.00000000 9.00000000 10.00000000 11.00000000
@@ -1961,3 +1969,7 @@
0.76159418 0.96402758 0.99505478 0.99932933
0.96402758 0.99505478 0.99932933 0.99990922
[nerv.MMatrixFloat 3 x 4]
+0.00000000 0.00000000 0.00000000 0.01483560
+0.00000000 0.00000000 0.01483560 1.00201201
+0.00000000 0.01483560 1.00201201 2.00027227
+[nerv.MMatrixFloat 3 x 4]
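
One last detail: the CPU output here prints 0.01483560 where the CUDA output prints 0.01483566. For x = 3, x - 3*tanh(x) is a small difference of nearly equal numbers, so even a one-ulp gap between the host libm tanh and the CUDA device tanh (the likely cause) is magnified in the printed last digit; this is presumably why the expected outputs are kept in separate per-backend files.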