aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix/init.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/matrix/init.lua')
-rw-r--r--nerv/matrix/init.lua44
1 files changed, 35 insertions, 9 deletions
diff --git a/nerv/matrix/init.lua b/nerv/matrix/init.lua
index da76e1b..722c780 100644
--- a/nerv/matrix/init.lua
+++ b/nerv/matrix/init.lua
@@ -40,7 +40,8 @@ end
--- Assign each element in a matrix using the value returned by a callback `gen`.
-- @param gen the callback used to generated the values in the matrix, to which
-- the indices of row and column will be passed (e.g., `gen(i, j)`)
-function nerv.Matrix:generate(gen)
+
+function nerv.Matrix:_generate(gen)
if (self:dim() == 2) then
for i = 0, self:nrow() - 1 do
local row = self[i]
@@ -55,6 +56,21 @@ function nerv.Matrix:generate(gen)
end
end
+function nerv.Matrix:generate(gen)
+ local tmp
+ if nerv.is_type(self, 'nerv.CuMatrixFloat') then
+ tmp = nerv.MMatrixFloat(self:nrow(), self:ncol())
+ elseif nerv.is_type(self, 'nerv.CuMatrixDouble') then
+ tmp = nerv.MMatrixDouble(self:nrow(), self:ncol())
+ else
+ tmp = self
+ end
+ tmp:_generate(gen)
+ if nerv.is_type(self, 'nerv.CuMatrix') then
+ self:copy_fromh(tmp)
+ end
+end
+
--- Create a fresh new matrix of the same matrix type (as `self`).
-- @param nrow optional, the number of rows in the created matrix if specified,
-- otherwise `self:nrow()` will be used
@@ -87,6 +103,17 @@ function nerv.Matrix:__mul__(b)
return c
end
+--- A wrapper function for `copy_from`
+function nerv.Matrix:copy_to(b, ...)
+ b:copy_from(self, ...)
+end
+
+--- The base class for all device (in-GPU) matrices
+-- @type nerv.CuMatrix
+
+--- A wrapper function for `copy_fromd`
+nerv.CuMatrix.copy_tod = nerv.Matrix.copy_to
+
--- CUDA float matrices
-- @type nerv.CuMatrixFloat
@@ -127,15 +154,14 @@ end
-- @type nerv.MMatrix
--- A wrapper function for `copy_fromh`
-function nerv.MMatrix:copy_toh(b, ...)
- b:copy_fromh(self, ...)
-end
+nerv.MMatrix.copy_toh = nerv.Matrix.copy_to
---- Print profiling info of host matrices
-function nerv.MMatrix.print_profile()
- nerv.info("mmatrix profile not available")
+--- A wrapper function for `nerv.CuMatrix` copy
+function nerv.MMatrix:copy_fromd(b, ...)
+ b:copy_toh(self, ...)
end
---- Clear profiling info of host matrices
-function nerv.MMatrix.clear_profile()
+--- A wrapper function for `nerv.CuMatrix` copy
+function nerv.MMatrix:copy_tod(b, ...)
+ b:copy_fromh(self, ...)
end