-rw-r--r--  Makefile                        5
-rw-r--r--  examples/test_dnn_layers.lua    2
-rw-r--r--  examples/test_nn_lib.lua       91
-rw-r--r--  io/init.lua                    22
-rw-r--r--  io/sgd_buffer.lua              99
-rw-r--r--  layer/affine.lua                2
-rw-r--r--  layer/init.lua                 24
-rw-r--r--  layer/softmax_ce.lua           20
-rw-r--r--  matrix/cuda_helper.h            2
-rw-r--r--  matrix/cukernel.h               1
-rw-r--r--  matrix/generic/cukernel.cu     19
-rw-r--r--  matrix/generic/cumatrix.c      87
-rw-r--r--  matrix/generic/mmatrix.c       20
-rw-r--r--  matrix/init.lua                 3
-rw-r--r--  nn/layer_dag.lua               59
m---------  speech                          0
16 files changed, 361 insertions, 95 deletions
diff --git a/Makefile b/Makefile
index 934235f..f0d319f 100644
--- a/Makefile
+++ b/Makefile
@@ -9,7 +9,8 @@ LUA_LIBS := matrix/init.lua io/init.lua nerv.lua \
pl/utils.lua pl/compat.lua \
layer/init.lua layer/affine.lua layer/sigmoid.lua layer/softmax_ce.lua \
layer/window.lua layer/bias.lua \
- nn/init.lua nn/layer_repo.lua nn/param_repo.lua nn/layer_dag.lua
+ nn/init.lua nn/layer_repo.lua nn/param_repo.lua nn/layer_dag.lua \
+ io/sgd_buffer.lua
INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK
CUDA_BASE := /usr/local/cuda-6.5
CUDA_INCLUDE := -I $(CUDA_BASE)/include/
@@ -53,7 +54,7 @@ $(OBJ_DIR)/matrix/cukernel.o: matrix/generic/cukernel.cu
speech:
-mkdir -p build/objs/speech/tnet_io
- $(MAKE) -C speech/ BUILD_DIR=$(BUILD_DIR) LIB_DIR=$(LIB_DIR) OBJ_DIR=$(CURDIR)/build/objs/speech/
+ $(MAKE) -C speech/ BUILD_DIR=$(BUILD_DIR) LIB_DIR=$(LIB_DIR) OBJ_DIR=$(CURDIR)/build/objs/speech/ LUA_DIR=$(LUA_DIR)
clean:
-rm -rf $(OBJ_DIR)
diff --git a/examples/test_dnn_layers.lua b/examples/test_dnn_layers.lua
index 6e4d98d..f306807 100644
--- a/examples/test_dnn_layers.lua
+++ b/examples/test_dnn_layers.lua
@@ -3,7 +3,7 @@ require 'layer.sigmoid'
require 'layer.softmax_ce'
global_conf = {lrate = 0.8, wcost = 1e-6,
- momentum = 0.9, mat_type = nerv.CuMatrixFloat}
+ momentum = 0.9, cumat_type = nerv.CuMatrixFloat}
pf = nerv.ChunkFile("affine.param", "r")
ltp = pf:read_chunk("a", global_conf)
diff --git a/examples/test_nn_lib.lua b/examples/test_nn_lib.lua
index ec338fe..9600917 100644
--- a/examples/test_nn_lib.lua
+++ b/examples/test_nn_lib.lua
@@ -1,14 +1,24 @@
--- require 'layer.affine'
--- require 'layer.sigmoid'
--- require 'layer.softmax_ce'
-
+require 'speech.init'
gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9,
- mat_type = nerv.CuMatrixFloat,
- batch_size = 10}
+ cumat_type = nerv.CuMatrixFloat,
+ mmat_type = nerv.MMatrixFloat,
+ batch_size = 256}
-param_repo = nerv.ParamRepo({"converted.nerv"})
+param_repo = nerv.ParamRepo({"converted.nerv", "global_transf.nerv"})
sublayer_repo = nerv.LayerRepo(
{
+ -- global transf
+ ["nerv.BiasLayer"] =
+ {
+ blayer1 = {{bias = "bias1"}, {dim_in = {429}, dim_out = {429}}},
+ blayer2 = {{bias = "bias2"}, {dim_in = {429}, dim_out = {429}}}
+ },
+ ["nerv.WindowLayer"] =
+ {
+ wlayer1 = {{window = "window1"}, {dim_in = {429}, dim_out = {429}}},
+ wlayer2 = {{window = "window2"}, {dim_in = {429}, dim_out = {429}}}
+ },
+ -- biased linearity
["nerv.AffineLayer"] =
{
affine0 = {{ltp = "affine0_ltp", bp = "affine0_bp"},
@@ -40,7 +50,7 @@ sublayer_repo = nerv.LayerRepo(
},
["nerv.SoftmaxCELayer"] =
{
- softmax_ce0 = {{}, {dim_in = {3001, 3001}, dim_out = {}}}
+ softmax_ce0 = {{}, {dim_in = {3001, 1}, dim_out = {}, compressed = true}}
}
}, param_repo, gconf)
@@ -48,8 +58,19 @@ layer_repo = nerv.LayerRepo(
{
["nerv.DAGLayer"] =
{
+ global_transf = {{}, {
+ dim_in = {429}, dim_out = {429},
+ sub_layers = sublayer_repo,
+ connections = {
+ ["<input>[1]"] = "blayer1[1]",
+ ["blayer1[1]"] = "wlayer1[1]",
+ ["wlayer1[1]"] = "blayer2[1]",
+ ["blayer2[1]"] = "wlayer2[1]",
+ ["wlayer2[1]"] = "<output>[1]"
+ }
+ }},
main = {{}, {
- dim_in = {429, 3001}, dim_out = {},
+ dim_in = {429, 1}, dim_out = {},
sub_layers = sublayer_repo,
connections = {
["<input>[1]"] = "affine0[1]",
@@ -74,24 +95,52 @@ layer_repo = nerv.LayerRepo(
}
}, param_repo, gconf)
-df = nerv.ChunkFile("input.param", "r")
-label = nerv.CuMatrixFloat(10, 3001)
-label:fill(0)
-for i = 0, 9 do
- label[i][i] = 1.0
-end
+tnet_reader = nerv.TNetReader(gconf,
+ {
+ id = "main_scp",
+ scp_file = "/slfs1/users/mfy43/swb_ivec/train_bp.scp",
+-- scp_file = "t.scp",
+ conf_file = "/slfs1/users/mfy43/swb_ivec/plp_0_d_a.conf",
+ frm_ext = 5,
+ mlfs = {
+ ref = {
+ file = "/slfs1/users/mfy43/swb_ivec/ref.mlf",
+ format = "map",
+ format_arg = "/slfs1/users/mfy43/swb_ivec/dict",
+ dir = "*/",
+ ext = "lab"
+ }
+ },
+ global_transf = layer_repo:get_layer("global_transf")
+ })
+
+buffer = nerv.SGDBuffer(gconf,
+ {
+ buffer_size = 8192,
+ readers = {
+ { reader = tnet_reader,
+ data = {main_scp = 429, ref = 1}}
+ }
+ })
-input = {df:read_chunk("input", gconf).trans, label}
-output = {}
-err_input = {}
-err_output = {input[1]:create()}
sm = sublayer_repo:get_layer("softmax_ce0")
main = layer_repo:get_layer("main")
-main:init()
-for i = 0, 3 do
+main:init(gconf.batch_size)
+cnt = 0
+for data in buffer.get_data, buffer do
+ if cnt == 1000 then break end
+ cnt = cnt + 1
+ input = {data.main_scp, data.ref}
+ output = {}
+ err_input = {}
+ err_output = {input[1]:create()}
+
main:propagate(input, output)
main:back_propagate(err_output, err_input, input, output)
main:update(err_input, input, output)
+
nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce)
nerv.utils.printf("frames: %.8f\n", sm.total_frames)
+ nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames)
+ collectgarbage("collect")
end
diff --git a/io/init.lua b/io/init.lua
index 4a663a7..9bbd51a 100644
--- a/io/init.lua
+++ b/io/init.lua
@@ -28,3 +28,25 @@ function nerv.ChunkFile:read_chunk(id, global_conf)
chunk:read(self:get_chunkdata(id))
return chunk
end
+
+local DataReader = nerv.class("nerv.DataReader")
+
+function DataReader:__init(global_conf, reader_conf)
+ nerv.error_method_not_implemented()
+end
+
+function DataReader:get_data()
+ nerv.error_method_not_implemented()
+end
+
+local DataBuffer = nerv.class("nerv.DataBuffer")
+
+function DataBuffer:__init(global_conf, buffer_conf)
+ nerv.error_method_not_implemented()
+end
+
+function DataBuffer:get_batch()
+ nerv.error_method_not_implemented()
+end
+
+require 'io.sgd_buffer'
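
DataReader and DataBuffer above are abstract bases: concrete data sources
subclass them and override the stubs. A minimal sketch of a conforming
reader (the DummyReader name, its conf fields, and the fixed frame count
are hypothetical, for illustration only):

    local DummyReader = nerv.class("nerv.DummyReader", "nerv.DataReader")

    function DummyReader:__init(global_conf, reader_conf)
        self.gconf = global_conf
        self.left = reader_conf.n       -- how many utterances to emit
        self.width = reader_conf.width  -- feature dimension
    end

    function DummyReader:get_data()
        if self.left == 0 then
            return nil                  -- exhausted, like a real reader
        end
        self.left = self.left - 1
        -- one utterance of 10 frames; contents left unset in this sketch
        return {dummy = self.gconf.mmat_type(10, self.width)}
    end

SGDBuffer (below) consumes any table of such readers, matching the keys of
the returned table against the widths declared in its data spec.
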
diff --git a/io/sgd_buffer.lua b/io/sgd_buffer.lua
new file mode 100644
index 0000000..dadcf67
--- /dev/null
+++ b/io/sgd_buffer.lua
@@ -0,0 +1,99 @@
+local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")
+
+function SGDBuffer:__init(global_conf, buffer_conf)
+ self.gconf = global_conf
+ self.buffer_size = math.floor(buffer_conf.buffer_size /
+ global_conf.batch_size) * global_conf.batch_size
+ self.head = 0
+ self.tail = 0
+ self.readers = {}
+ for i, reader_spec in ipairs(buffer_conf.readers) do
+ local buffs = {}
+ for id, width in pairs(reader_spec.data) do
+ buffs[id] = {data = global_conf.mmat_type(self.buffer_size, width),
+ leftover = {},
+ width = width}
+ end
+ table.insert(self.readers, {buffs = buffs,
+ reader = reader_spec.reader,
+ tail = 0,
+ has_leftover = false})
+ end
+end
+
+function SGDBuffer:saturate()
+ local buffer_size = self.buffer_size
+ self.head = 0
+ self.tail = buffer_size
+ for i, reader in ipairs(self.readers) do
+ reader.tail = 0
+ if reader.has_leftover then
+ local lrow
+ for id, buff in pairs(reader.buffs) do
+ lrow = buff.leftover:nrow()
+ if lrow > buffer_size then
+ nerv.error("buffer size is too small to contain leftovers")
+ end
+ buff.data:copy_from(buff.leftover, 0, lrow)
+ end
+ reader.tail = lrow
+ reader.has_leftover = false
+ end
+ while reader.tail < buffer_size do
+ local data = reader.reader:get_data()
+ if data == nil then
+ break
+ end
+ local drow = nil
+ for id, d in pairs(data) do
+ if drow == nil then
+ drow = d:nrow()
+ elseif d:nrow() ~= drow then
+ nerv.error("reader provides inconsistent numbers of rows across data fields")
+ end
+ end
+ local remain = buffer_size - reader.tail
+ if drow > remain then
+ for id, buff in pairs(reader.buffs) do
+ local d = data[id]
+ if d == nil then
+ nerv.error("reader does not provide data for %s", id)
+ end
+ buff.leftover = self.gconf.mmat_type(drow - remain,
+ buff.width)
+ buff.leftover:copy_from(d, remain, drow)
+ end
+ drow = remain
+ reader.has_leftover = true
+ end
+ for id, buff in pairs(reader.buffs) do
+ buff.data:copy_from(data[id], 0, drow, reader.tail)
+ end
+ reader.tail = reader.tail + drow
+ end
+ self.tail = math.min(self.tail, reader.tail)
+ end
+ return self.tail >= self.gconf.batch_size
+end
+
+function SGDBuffer:get_data()
+ local batch_size = self.gconf.batch_size
+ if self.head >= self.tail then -- buffer is empty
+ if not self:saturate() then
+ return nil -- the remaining data cannot build a batch
+ end
+ end
+ if self.head + batch_size > self.tail then
+ return nil -- the remaining data cannot build a batch
+ end
+ local res = {}
+ for i, reader in ipairs(self.readers) do
+ for id, buff in pairs(reader.buffs) do
+ local batch = self.gconf.cumat_type(batch_size, buff.width)
+ batch:copy_fromh(buff.data, self.head, self.head + batch_size)
+ res[id] = batch
+ end
+ end
+ self.head = self.head + batch_size
+ return res
+end
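
SGDBuffer rounds its capacity down to a whole number of batches, refills
itself reader-by-reader in saturate(), and keeps any rows that overflow the
buffer as per-field leftover matrices, replayed at the front of the next
refill; get_data() then slices one batch_size window per call onto the
device. A minimal consumption sketch, assuming a gconf carrying
batch_size/mmat_type/cumat_type and a reader r that yields
{feat = <host matrix>} tables:

    local buffer = nerv.SGDBuffer(gconf,
        {
            buffer_size = 4096, -- rounded down to a multiple of batch_size
            readers = {
                {reader = r, data = {feat = 429}}
            }
        })
    for batch in buffer.get_data, buffer do
        -- batch.feat is a freshly allocated batch_size x 429 device matrix
    end
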
diff --git a/layer/affine.lua b/layer/affine.lua
index 90a1d16..59a0e91 100644
--- a/layer/affine.lua
+++ b/layer/affine.lua
@@ -4,7 +4,7 @@ local BiasParam = nerv.class('nerv.BiasParam', 'nerv.MatrixParam')
local AffineLayer = nerv.class('nerv.AffineLayer', 'nerv.Layer')
function MatrixParam:read(pcdata)
- self.trans = self.gconf.mat_type.new_from_host(
+ self.trans = self.gconf.cumat_type.new_from_host(
nerv.MMatrixFloat.load(pcdata))
end
diff --git a/layer/init.lua b/layer/init.lua
index c8c691b..38bcd7f 100644
--- a/layer/init.lua
+++ b/layer/init.lua
@@ -2,50 +2,50 @@
local Param = nerv.class('nerv.Param')
-function nerv.Param:__init(id, global_conf)
+function Param:__init(id, global_conf)
self.id = id
self.gconf = global_conf
end
-function nerv.Param:get_info()
+function Param:get_info()
return self.info
end
-function nerv.Param:set_info(info)
+function Param:set_info(info)
self.info = info
end
-function nerv.Param:read(pfhandle)
+function Param:read(pfhandle)
nerv.error_method_not_implemented()
end
-function nerv.Param:write(pfhandle)
+function Param:write(pfhandle)
nerv.error_method_not_implemented()
end
local Layer = nerv.class('nerv.Layer')
-function nerv.Layer:__init(id, global_conf, ...)
+function Layer:__init(id, global_conf, layer_conf)
nerv.error_method_not_implemented()
end
-function nerv.Layer:init(id)
+function Layer:init(id)
nerv.error_method_not_implemented()
end
-function nerv.Layer:update(bp_err, input, output)
+function Layer:update(bp_err, input, output)
nerv.error_method_not_implemented()
end
-function nerv.Layer:propagate(input, output)
+function Layer:propagate(input, output)
nerv.error_method_not_implemented()
end
-function nerv.Layer:back_propagate(next_bp_err, bp_err, input, output)
+function Layer:back_propagate(next_bp_err, bp_err, input, output)
nerv.error_method_not_implemented()
end
-function nerv.Layer:check_dim_len(len_in, len_out)
+function Layer:check_dim_len(len_in, len_out)
local expected_in = #self.dim_in
local expected_out = #self.dim_out
if len_in > 0 and expected_in ~= len_in then
@@ -58,7 +58,7 @@ function nerv.Layer:check_dim_len(len_in, len_out)
end
end
-function nerv.Layer:get_dim()
+function Layer:get_dim()
return self.dim_in, self.dim_out
end
diff --git a/layer/softmax_ce.lua b/layer/softmax_ce.lua
index 09eb3a9..cf98c45 100644
--- a/layer/softmax_ce.lua
+++ b/layer/softmax_ce.lua
@@ -5,6 +5,10 @@ function SoftmaxCELayer:__init(id, global_conf, layer_conf)
self.gconf = global_conf
self.dim_in = layer_conf.dim_in
self.dim_out = layer_conf.dim_out
+ self.compressed = layer_conf.compressed
+ if self.compressed == nil then
+ self.compressed = false
+ end
self:check_dim_len(2, -1) -- two inputs: nn output and label
end
@@ -26,15 +30,21 @@ function SoftmaxCELayer:propagate(input, output)
soutput:softmax(input[1])
local ce = soutput:create()
ce:log_elem(soutput)
- ce:mul_elem(ce, input[2])
--- print(input[1][0])
--- print(soutput[1][0])
- -- add total ce
+ local label = input[2]
+ if self.compressed then
+ label = label:decompress(input[1]:ncol())
+ end
+ ce:mul_elem(ce, label)
+ -- add total ce
self.total_ce = self.total_ce - ce:rowsum():colsum()[0]
self.total_frames = self.total_frames + soutput:nrow()
end
function SoftmaxCELayer:back_propagate(next_bp_err, bp_err, input, output)
-- softmax output - label
- next_bp_err[1]:add(self.soutput, input[2], 1.0, -1.0)
+ local label = input[2]
+ if self.compressed then
+ label = label:decompress(input[1]:ncol())
+ end
+ next_bp_err[1]:add(self.soutput, label, 1.0, -1.0)
end
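
With compressed = true, SoftmaxCELayer takes its labels as an N x 1 column
of class indices rather than an N x C one-hot matrix, expanding them on the
fly via decompress() (added to cumatrix.c below). A minimal sketch of the
equivalence, assuming 4 classes, in the same element-assignment idiom the
old example script used:

    local labels = nerv.CuMatrixFloat(3, 1) -- one class index per frame
    labels[0][0] = 2
    labels[1][0] = 0
    labels[2][0] = 3
    local onehot = labels:decompress(4)     -- 3 x 4, ones at columns 2, 0, 3
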
diff --git a/matrix/cuda_helper.h b/matrix/cuda_helper.h
index c0fa618..cedc643 100644
--- a/matrix/cuda_helper.h
+++ b/matrix/cuda_helper.h
@@ -23,7 +23,7 @@
#define CHECK_SAME_DIMENSION(a, b) \
do { \
if (!(a->nrow == b->nrow && a->ncol == b->ncol)) \
- nerv_error(L, "Matrices should be of the same dimension"); \
+ nerv_error(L, "matrices should be of the same dimension"); \
} while (0)
static const char *cublasGetErrorString(cublasStatus_t err) {
diff --git a/matrix/cukernel.h b/matrix/cukernel.h
index 178b7d3..7d2168e 100644
--- a/matrix/cukernel.h
+++ b/matrix/cukernel.h
@@ -13,4 +13,5 @@ void cudak_(cuda_fill)(Matrix *a, double val);
void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context);
void cudak_(cuda_rearrange_frm)(const Matrix *a, Matrix *b, int step);
void cudak_(cuda_scale_row)(const Matrix *a, Matrix *b);
+void cudak_(cuda_decompress)(const Matrix *a, Matrix *b);
#endif
diff --git a/matrix/generic/cukernel.cu b/matrix/generic/cukernel.cu
index 1d8b983..05a1e78 100644
--- a/matrix/generic/cukernel.cu
+++ b/matrix/generic/cukernel.cu
@@ -187,6 +187,15 @@ __global__ void cudak_(scale_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
b[j + i * stride] *= a[j];
}
+__global__ void cudak_(decompress)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol,
+ int stride_a, int stride_b) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ if (i >= nrow || j >= ncol) return;
+ b[lrintf(a[j + i * stride_a]) + i * stride_b] = 1.0;
+}
+
extern "C" {
#include "../cukernel.h"
void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b) {
@@ -385,5 +394,15 @@ extern "C" {
(MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
}
+
+ void cudak_(cuda_decompress)(const Matrix *a, Matrix *b) {
+ dim3 threadsPerBlock(1, CUDA_THREADS_NN);
+ dim3 numBlocks(1, CEIL_DIV(a->nrow, threadsPerBlock.y));
+ cudak_(decompress)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
+ a->nrow, a->ncol,
+ a->stride / sizeof(MATRIX_ELEM),
+ b->stride / sizeof(MATRIX_ELEM));
+ }
}
#endif
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index 0df1bd7..373fc42 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -74,7 +74,8 @@ static int nerv_matrix_(mul)(lua_State *L) {
if (an != bm)
nerv_error(L, "Wrong dimension of multipliers");
/* MATRIX_ELEM alpha = 1.0f, beta = 0.0f; */
- CUBLAS_SAFE_CALL( //Because matrix in Nerv is row-major, here b comes first
+ /* Because matrix in Nerv is row-major, here b comes first */
+ CUBLAS_SAFE_CALL(
NERV_CUBLAS_(gemm)(cublas_handle, tb, ta,
bn, am, bm,
&alpha,
@@ -113,9 +114,11 @@ static int nerv_matrix_(sigmoid_grad)(lua_State *L) {
static int nerv_matrix_(softmax)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *max = nerv_matrix_(new_)(L, a->nrow, 1);
- Matrix *dno = nerv_matrix_(new_)(L, a->nrow, 1);
+ Matrix *max;
+ Matrix *dno;
CHECK_SAME_DIMENSION(a, b);
+ max = nerv_matrix_(new_)(L, a->nrow, 1);
+ dno = nerv_matrix_(new_)(L, a->nrow, 1);
cudak_(cuda_rowmax)(a, max);
cudak_(cuda_softmax_denominator)(a, max, dno);
cudak_(cuda_softmax_final)(a, max, dno, b);
@@ -168,26 +171,22 @@ static int nerv_matrix_(fill)(lua_State *L) {
return 0;
}
-static int nerv_matrix_(copy_fromd)(lua_State *L) {
+static int nerv_matrix_(copy_fromd)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- CHECK_SAME_DIMENSION(a, b);
- CUDA_SAFE_SYNC_CALL(
- cudaMemcpy2D(MATRIX_ELEM_PTR(a), a->stride,
- MATRIX_ELEM_PTR(b), b->stride,
- sizeof(MATRIX_ELEM) * b->ncol, b->nrow,
- cudaMemcpyDeviceToDevice));
- return 0;
-}
-
-static int nerv_matrix_(copy_tod)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- CHECK_SAME_DIMENSION(a, b);
+ int nargs = lua_gettop(L);
+ int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
+ int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
+ int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
+ if (!(0 <= b_begin && b_begin < b_end && b_end <= b->nrow &&
+ a_begin + b_end - b_begin <= a->nrow))
+ nerv_error(L, "invalid copy interval");
+ if (a->ncol != b->ncol)
+ nerv_error(L, "matrices should be of the same dimension");
CUDA_SAFE_SYNC_CALL(
- cudaMemcpy2D(MATRIX_ELEM_PTR(b), b->stride,
- MATRIX_ELEM_PTR(a), a->stride,
- sizeof(MATRIX_ELEM) * a->ncol, a->nrow,
+ cudaMemcpy2D(MATRIX_ROW_PTR(a, a_begin), a->stride,
+ MATRIX_ROW_PTR(b, b_begin), b->stride,
+ sizeof(MATRIX_ELEM) * b->ncol, b_end - b_begin,
cudaMemcpyDeviceToDevice));
return 0;
}
@@ -196,11 +195,19 @@ extern const char *MATRIX_CUMATRIX_HOST_TNAME;
static int nerv_matrix_(copy_fromh)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
- CHECK_SAME_DIMENSION(a, b);
+ int nargs = lua_gettop(L);
+ int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
+ int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
+ int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
+ if (!(0 <= b_begin && b_begin < b_end && b_end <= b->nrow &&
+ a_begin + b_end - b_begin <= a->nrow))
+ nerv_error(L, "invalid copy interval");
+ if (a->ncol != b->ncol)
+ nerv_error(L, "matrices should be of the same dimension");
CUDA_SAFE_SYNC_CALL(
- cudaMemcpy2D(MATRIX_ELEM_PTR(a), a->stride,
- MATRIX_ELEM_PTR(b), b->stride,
- sizeof(MATRIX_ELEM) * b->ncol, b->nrow,
+ cudaMemcpy2D(MATRIX_ROW_PTR(a, a_begin), a->stride,
+ MATRIX_ROW_PTR(b, b_begin), b->stride,
+ sizeof(MATRIX_ELEM) * b->ncol, b_end - b_begin,
cudaMemcpyHostToDevice));
return 0;
}
@@ -208,11 +215,19 @@ static int nerv_matrix_(copy_fromh)(lua_State *L) {
static int nerv_matrix_(copy_toh)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
- CHECK_SAME_DIMENSION(a, b);
+ int nargs = lua_gettop(L);
+ int a_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
+ int a_end = nargs > 3 ? luaL_checkinteger(L, 4) : a->nrow;
+ int b_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
+ if (!(0 <= a_begin && a_begin < a_end && a_end <= a->nrow &&
+ b_begin + a_end - a_begin <= b->nrow))
+ nerv_error(L, "invalid copy interval");
+ if (b->ncol != a->ncol)
+ nerv_error(L, "matrices should be of the same dimension");
CUDA_SAFE_SYNC_CALL(
- cudaMemcpy2D(MATRIX_ELEM_PTR(b), b->stride,
- MATRIX_ELEM_PTR(a), a->stride,
- sizeof(MATRIX_ELEM) * a->ncol, a->nrow,
+ cudaMemcpy2D(MATRIX_ROW_PTR(b, b_begin), b->stride,
+ MATRIX_ROW_PTR(a, a_begin), a->stride,
+ sizeof(MATRIX_ELEM) * a->ncol, a_end - a_begin,
cudaMemcpyDeviceToHost));
return 0;
}
@@ -221,6 +236,7 @@ static int nerv_matrix_(trans)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = nerv_matrix_(new_)(L, a->ncol, a->nrow);
MATRIX_ELEM alpha = 1, beta = 0;
+ /* FIXME: possible memory leak when lua error is raised */
CUBLAS_SAFE_CALL(
NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T,
a->nrow, a->ncol,
@@ -251,6 +267,19 @@ static int nerv_matrix_(log_elem)(lua_State *L) {
return 0;
}
+static int nerv_matrix_(decompress)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b;
+ int orig_col = luaL_checkinteger(L, 2);
+ if (a->ncol != 1)
+ nerv_error(L, "the compressed matrix must be a column vector");
+ b = nerv_matrix_(new_)(L, a->nrow, orig_col);
+ cudak_(cuda_fill)(b, 0.0);
+ cudak_(cuda_decompress)(a, b);
+ luaT_pushudata(L, b, nerv_matrix_(tname));
+ return 1;
+}
+
extern const char *nerv_matrix_host_int_tname;
static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
@@ -322,11 +351,11 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"rowsum", nerv_matrix_(rowsum)},
{"rowmax", nerv_matrix_(rowmax)},
{"trans", nerv_matrix_(trans)},
+ {"decompress", nerv_matrix_(decompress)},
/* in-place calc */
{"copy_fromh", nerv_matrix_(copy_fromh)},
{"copy_fromd", nerv_matrix_(copy_fromd)},
{"copy_toh", nerv_matrix_(copy_toh)},
- {"copy_tod", nerv_matrix_(copy_tod)},
{"add", nerv_matrix_(add)},
{"mul", nerv_matrix_(mul)},
{"add_row", nerv_matrix_(add_row)},
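
The copy family now accepts optional row ranges:
a:copy_fromh(b, b_begin, b_end, a_begin) copies rows [b_begin, b_end) of
host matrix b into device matrix a starting at row a_begin, and falls back
to a full same-shape copy when the extra arguments are omitted (copy_fromd
applies the same range arguments to its device source, copy_toh to the
device matrix being read; the redundant copy_tod is gone). This is what
lets SGDBuffer slice one batch out of its host-side pool. A minimal sketch,
with hypothetical shapes:

    local host = nerv.MMatrixFloat(100, 429) -- host-side pool
    local dev = nerv.CuMatrixFloat(10, 429)  -- room for one batch
    dev:copy_fromh(host, 20, 30)             -- host rows [20, 30) -> dev rows [0, 10)
    dev:copy_fromh(host, 20, 30, 0)          -- same, destination offset explicit
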
diff --git a/matrix/generic/mmatrix.c b/matrix/generic/mmatrix.c
index 3a9ae79..4b722f3 100644
--- a/matrix/generic/mmatrix.c
+++ b/matrix/generic/mmatrix.c
@@ -11,6 +11,7 @@
#define NERV_GENERIC_MATRIX
#include "../../common.h"
#include "../../io/chunk_file.h"
+#include <string.h>
static void host_matrix_(alloc)(lua_State *L,
MATRIX_ELEM **dptr, size_t *stride,
@@ -96,10 +97,27 @@ int nerv_matrix_(save)(lua_State *L) {
return 0;
}
-
+static int nerv_matrix_(copy_from)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ int nargs = lua_gettop(L);
+ int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
+ int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
+ int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
+ if (!(0 <= b_begin && b_begin < b_end && b_end <= b->nrow &&
+ a_begin + b_end - b_begin <= a->nrow))
+ nerv_error(L, "invalid copy interval");
+ if (a->ncol != b->ncol)
+ nerv_error(L, "matrices should be of the same dimension");
+ memmove(MATRIX_ROW_PTR(a, a_begin),
+ MATRIX_ROW_PTR(b, b_begin),
+ sizeof(MATRIX_ELEM) * b->ncol * (b_end - b_begin));
+ return 0;
+}
static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"load", nerv_matrix_(load)},
{"save", nerv_matrix_(save)},
+ {"copy_from", nerv_matrix_(copy_from)},
{NULL, NULL}
};
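
copy_from is the host-to-host counterpart with the same
(src, b_begin, b_end, a_begin) convention, done as a single memmove over
the contiguous block of rows; saturate() above uses it to splice leftover
rows back to the front of the buffer. A minimal sketch, with hypothetical
shapes:

    local a = nerv.MMatrixFloat(8, 4)
    local b = nerv.MMatrixFloat(8, 4)
    a:copy_from(b, 2, 5)    -- rows [2, 5) of b -> rows [0, 3) of a
    a:copy_from(b, 2, 5, 3) -- same source rows, landing at row 3 of a
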
diff --git a/matrix/init.lua b/matrix/init.lua
index f309f81..9637391 100644
--- a/matrix/init.lua
+++ b/matrix/init.lua
@@ -22,7 +22,8 @@ function nerv.Matrix:__tostring__()
table.insert(strt, "\n")
end
end
- table.insert(strt, string.format("[Matrix %d x %d]", nrow, ncol))
+ table.insert(strt, string.format(
+ "[%s %d x %d]", self.__typename, nrow, ncol))
return table.concat(strt)
end
diff --git a/nn/layer_dag.lua b/nn/layer_dag.lua
index 1ab18fa..4ee829e 100644
--- a/nn/layer_dag.lua
+++ b/nn/layer_dag.lua
@@ -44,6 +44,7 @@ function nerv.DAGLayer:__init(id, global_conf, layer_conf)
local outputs = {}
local dim_in = layer_conf.dim_in
local dim_out = layer_conf.dim_out
+ local parsed_conn = {}
for from, to in pairs(layer_conf.connections) do
local id_from, port_from = parse_id(from)
local id_to, port_to = parse_id(to)
@@ -76,32 +77,18 @@ function nerv.DAGLayer:__init(id, global_conf, layer_conf)
if output_dim[port_from] ~= input_dim[port_to] then
nerv.error("mismatching data dimension between %s and %s", from, to)
end
- local mid = global_conf.mat_type(global_conf.batch_size,
- output_dim[port_from])
- local err_mid = mid:create()
-
- ref_from.outputs[port_from] = mid
- ref_to.inputs[port_to] = mid
-
- ref_from.err_inputs[port_from] = err_mid
- ref_to.err_outputs[port_to] = err_mid
+ table.insert(parsed_conn,
+ {{ref_from, port_from}, {ref_to, port_to}})
table.insert(ref_from.next_layers, ref_to) -- add edge
ref_to.in_deg = ref_to.in_deg + 1 -- increase the in-degree of the target layer
end
end
- self.layers = layers
- self.inputs = inputs
- self.outputs = outputs
- self.dim_in = dim_in
- self.dim_out = dim_out
-end
-function nerv.DAGLayer:init(id) -- topology sort
local queue = {}
local l = 1
local r = 1
- for id, ref in pairs(self.layers) do
+ for id, ref in pairs(layers) do
if ref.in_deg == 0 then
table.insert(queue, ref)
nerv.utils.printf("adding source layer: %s\n", id)
@@ -126,20 +113,50 @@ function nerv.DAGLayer:init(id) -- topology sort
for i = 1, #queue do
nerv.utils.printf("queued layer: %s\n", queue[i].layer.id)
end
- self.queue = queue
- for id, ref in pairs(self.layers) do
+
+ for id, ref in pairs(layers) do
-- check whether the graph is connected
if ref.visited == false then
nerv.utils.printf("warning: layer %s is ignored\n", id)
end
+ end
+
+ self.layers = layers
+ self.inputs = inputs
+ self.outputs = outputs
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+ self.parsed_conn = parsed_conn
+ self.queue = queue
+ self.gconf = global_conf
+end
+
+function nerv.DAGLayer:init(batch_size) -- allocate inter-layer matrices (graph already sorted in __init)
+ for i, conn in ipairs(self.parsed_conn) do
+ local _, output_dim
+ local ref_from, port_from, ref_to, port_to
+ ref_from, port_from = unpack(conn[1])
+ ref_to, port_to = unpack(conn[2])
+ _, output_dim = ref_from.layer:get_dim()
+ local mid = self.gconf.cumat_type(batch_size,
+ output_dim[port_from])
+ local err_mid = mid:create()
+
+ ref_from.outputs[port_from] = mid
+ ref_to.inputs[port_to] = mid
+
+ ref_from.err_inputs[port_from] = err_mid
+ ref_to.err_outputs[port_to] = err_mid
+ end
+ for id, ref in pairs(self.layers) do
for i = 1, ref.input_len do
if ref.inputs[i] == nil then
- nerv.error("dangling port %d of layer %s", i, id)
+ nerv.error("dangling input port %d of layer %s", i, id)
end
end
for i = 1, ref.output_len do
if ref.outputs[i] == nil then
- nerv.error("dangling port %d of layer %s", i, id)
+ nerv.error("dangling output port %d of layer %s", i, id)
end
end
-- initialize sub layers
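
Connection parsing and the topological sort now happen once in __init,
while init(batch_size) only (re)allocates the inter-layer matrices, so a
DAG can be set up again for a different batch size without re-sorting the
graph. A minimal sketch of the new calling convention (dag obtained as in
the example script above; the re-initialization for a short tail batch is a
hypothetical use):

    local dag = layer_repo:get_layer("main")
    dag:init(gconf.batch_size) -- allocate batch_size-row mid/err buffers
    -- ... run full batches ...
    -- dag:init(tail_rows)     -- hypothetically, re-size for a final short batch
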
diff --git a/speech b/speech
-Subproject 0c6ca6a17f06821cd5d612f489ca6cb68c2c4d5
+Subproject a753eca0121ac3ec81ed76bd719d3f1cb952268