aboutsummaryrefslogtreecommitdiff
path: root/matrix
diff options
context:
space:
mode:
Diffstat (limited to 'matrix')
-rw-r--r--matrix/init.lua13
1 files changed, 13 insertions, 0 deletions
diff --git a/matrix/init.lua b/matrix/init.lua
index 8f626dc..08080a9 100644
--- a/matrix/init.lua
+++ b/matrix/init.lua
@@ -38,3 +38,16 @@ function nerv.CuMatrix:__mul__(b)
c:mul(self, b, 'N', 'N')
return c
end
+
+function nerv.CuMatrixFloat.new_from_host(mat)
+ local res = nerv.CuMatrixFloat(mat:nrow(), mat:ncol())
+ res:copy_from(mat)
+ print(res)
+ return res
+end
+
+function nerv.CuMatrixFloat:new_to_host()
+ local res = nerv.MMatrixFloat(self:nrow(), self:ncol())
+ self:copy_to(res)
+ return res
+end