aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix/init.lua
diff options
context:
space:
mode:
authorTed Yin <Determinant@users.noreply.github.com>2016-03-12 13:17:38 +0800
committerTed Yin <Determinant@users.noreply.github.com>2016-03-12 13:17:38 +0800
commit2b03555ea53a47e87d03a79feb866c868d424f01 (patch)
tree63cd01ee70d056d3a159a1e7d9aa4ea6f414d213 /nerv/matrix/init.lua
parente8b1007d99691c08dd1b71f5733eb3cd2827dc64 (diff)
parent2660af7f6a9ac243a8ad38bf3375ef0fd292bf52 (diff)
Merge pull request #31 from liuq901/master
modfiy param generate & rewrite LSTM layer
Diffstat (limited to 'nerv/matrix/init.lua')
-rw-r--r--nerv/matrix/init.lua18
1 files changed, 17 insertions, 1 deletions
diff --git a/nerv/matrix/init.lua b/nerv/matrix/init.lua
index cf85004..722c780 100644
--- a/nerv/matrix/init.lua
+++ b/nerv/matrix/init.lua
@@ -40,7 +40,8 @@ end
--- Assign each element in a matrix using the value returned by a callback `gen`.
-- @param gen the callback used to generated the values in the matrix, to which
-- the indices of row and column will be passed (e.g., `gen(i, j)`)
-function nerv.Matrix:generate(gen)
+
+function nerv.Matrix:_generate(gen)
if (self:dim() == 2) then
for i = 0, self:nrow() - 1 do
local row = self[i]
@@ -55,6 +56,21 @@ function nerv.Matrix:generate(gen)
end
end
+function nerv.Matrix:generate(gen)
+ local tmp
+ if nerv.is_type(self, 'nerv.CuMatrixFloat') then
+ tmp = nerv.MMatrixFloat(self:nrow(), self:ncol())
+ elseif nerv.is_type(self, 'nerv.CuMatrixDouble') then
+ tmp = nerv.MMatrixDouble(self:nrow(), self:ncol())
+ else
+ tmp = self
+ end
+ tmp:_generate(gen)
+ if nerv.is_type(self, 'nerv.CuMatrix') then
+ self:copy_fromh(tmp)
+ end
+end
+
--- Create a fresh new matrix of the same matrix type (as `self`).
-- @param nrow optional, the number of rows in the created matrix if specified,
-- otherwise `self:nrow()` will be used