aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix
diff options
context:
space:
mode:
authorQi Liu <liuq901@163.com>2016-03-11 20:11:00 +0800
committerQi Liu <liuq901@163.com>2016-03-11 20:11:00 +0800
commite2a9af061db485d4388902d738c9d8be3f94ab34 (patch)
tree468d6c6afa0801f6a6bf794b3674f8814b8827f7 /nerv/matrix
parent2f46a5e2b37a054f482f76f4ac3d26b144cf988f (diff)
add recipe and fix bugs
Diffstat (limited to 'nerv/matrix')
-rw-r--r--nerv/matrix/init.lua18
1 files changed, 17 insertions, 1 deletions
diff --git a/nerv/matrix/init.lua b/nerv/matrix/init.lua
index cf85004..722c780 100644
--- a/nerv/matrix/init.lua
+++ b/nerv/matrix/init.lua
@@ -40,7 +40,8 @@ end
--- Assign each element in a matrix using the value returned by a callback `gen`.
-- @param gen the callback used to generated the values in the matrix, to which
-- the indices of row and column will be passed (e.g., `gen(i, j)`)
-function nerv.Matrix:generate(gen)
+
+function nerv.Matrix:_generate(gen)
if (self:dim() == 2) then
for i = 0, self:nrow() - 1 do
local row = self[i]
@@ -55,6 +56,21 @@ function nerv.Matrix:generate(gen)
end
end
+function nerv.Matrix:generate(gen)
+ local tmp
+ if nerv.is_type(self, 'nerv.CuMatrixFloat') then
+ tmp = nerv.MMatrixFloat(self:nrow(), self:ncol())
+ elseif nerv.is_type(self, 'nerv.CuMatrixDouble') then
+ tmp = nerv.MMatrixDouble(self:nrow(), self:ncol())
+ else
+ tmp = self
+ end
+ tmp:_generate(gen)
+ if nerv.is_type(self, 'nerv.CuMatrix') then
+ self:copy_fromh(tmp)
+ end
+end
+
--- Create a fresh new matrix of the same matrix type (as `self`).
-- @param nrow optional, the number of rows in the created matrix if specified,
-- otherwise `self:nrow()` will be used