aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix/init.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/matrix/init.lua')
-rw-r--r--nerv/matrix/init.lua18
1 files changed, 17 insertions, 1 deletions
diff --git a/nerv/matrix/init.lua b/nerv/matrix/init.lua
index cf85004..722c780 100644
--- a/nerv/matrix/init.lua
+++ b/nerv/matrix/init.lua
@@ -40,7 +40,8 @@ end
--- Assign each element in a matrix using the value returned by a callback `gen`.
-- @param gen the callback used to generated the values in the matrix, to which
-- the indices of row and column will be passed (e.g., `gen(i, j)`)
-function nerv.Matrix:generate(gen)
+
+function nerv.Matrix:_generate(gen)
if (self:dim() == 2) then
for i = 0, self:nrow() - 1 do
local row = self[i]
@@ -55,6 +56,21 @@ function nerv.Matrix:generate(gen)
end
end
+function nerv.Matrix:generate(gen)
+ local tmp
+ if nerv.is_type(self, 'nerv.CuMatrixFloat') then
+ tmp = nerv.MMatrixFloat(self:nrow(), self:ncol())
+ elseif nerv.is_type(self, 'nerv.CuMatrixDouble') then
+ tmp = nerv.MMatrixDouble(self:nrow(), self:ncol())
+ else
+ tmp = self
+ end
+ tmp:_generate(gen)
+ if nerv.is_type(self, 'nerv.CuMatrix') then
+ self:copy_fromh(tmp)
+ end
+end
+
--- Create a fresh new matrix of the same matrix type (as `self`).
-- @param nrow optional, the number of rows in the created matrix if specified,
-- otherwise `self:nrow()` will be used