--- Matrix methods written in Lua; the remaining matrix operations live in
-- the C extension.
-- @author Ted Yin

--- The base class for all matrices.
-- @type nerv.Matrix

--- Convert the matrix object to a string.
-- A 2-D matrix is rendered one line per row; any other matrix is rendered
-- on a single line. Every element is formatted with `self.fmt` when it is
-- set (otherwise `"%.8f "`), and the text ends with a
-- `[<typename> <nrow> x <ncol>]` footer.
function nerv.Matrix:__tostring__()
    local nr, nc = self:nrow(), self:ncol()
    local cell_fmt = self.fmt or "%.8f "
    local pieces = {}
    local function emit(text)
        pieces[#pieces + 1] = text
    end

    if self:dim() == 2 then
        for i = 0, nr - 1 do
            local row = self[i]
            for j = 0, nc - 1 do
                emit(string.format(cell_fmt, row[j]))
            end
            emit("\n")
        end
    else
        for j = 0, nc - 1 do
            emit(string.format(cell_fmt, self[j]))
        end
        emit("\n")
    end

    emit(string.format("[%s %d x %d]", self.__typename, nr, nc))
    return table.concat(pieces)
end

-- Internal helper for `generate`: fills `self` in place using 0-based
-- indices. A 2-D matrix receives `gen(i, j)` at row i, column j; any other
-- matrix receives `gen(j)` at position j.
function nerv.Matrix:_generate(gen)
    local nc = self:ncol()
    if self:dim() ~= 2 then
        for j = 0, nc - 1 do
            self[j] = gen(j)
        end
        return
    end
    for i = 0, self:nrow() - 1 do
        local row = self[i]
        for j = 0, nc - 1 do
            row[j] = gen(i, j)
        end
    end
end

--- Assign each element in a matrix using the value returned by a callback `gen`.
-- CUDA float/double matrices are filled through a temporary host matrix of
-- the matching element type, which is then copied back with `copy_fromh`.
-- @param gen the callback used to generate the values in the matrix, to which
-- the indices of row and column will be passed (e.g., `gen(i, j)`)
function nerv.Matrix:generate(gen)
    local host = self
    if nerv.is_type(self, 'nerv.CuMatrixFloat') then
        host = nerv.MMatrixFloat(self:nrow(), self:ncol())
    elseif nerv.is_type(self, 'nerv.CuMatrixDouble') then
        host = nerv.MMatrixDouble(self:nrow(), self:ncol())
    end
    host:_generate(gen)
    if nerv.is_type(self, 'nerv.CuMatrix') then
        self:copy_fromh(host)
    end
end

--- Create a fresh new matrix of the same matrix type (as `self`).
-- @param nrow optional, the number of rows in the created matrix if specified,
-- otherwise `self:nrow()` will be used
-- @param ncol optional, the number of columns in the created matrix if
-- specified, otherwise `self:ncol()` will be used
-- @return the new matrix, built by calling `self.__constructor(nrow, ncol)`
function nerv.Matrix:create(nrow, ncol)
    return self.__constructor(nrow or self:nrow(), ncol or self:ncol())
end

-- Integer host matrices print without a fractional part (`self.fmt` is read
-- by `nerv.Matrix:__tostring__`).
nerv.MMatrixInt.fmt = "%d "

--- Operator overloading of `+`.
-- @param b the right-hand operand
-- @return a new matrix created by `self:create()` (same shape as `self`)
-- holding `self + b`
function nerv.Matrix:__add__(b)
    -- `local` keeps the result out of the global environment.
    local c = self:create()
    c:add(self, b, 1.0, 1.0)
    return c
end

--- Operator overloading of `-`.
-- @param b the right-hand operand
-- @return a new matrix created by `self:create()` (same shape as `self`)
-- holding `self - b`
function nerv.Matrix:__sub__(b)
    -- `local` keeps the result out of the global environment.
    local c = self:create()
    c:add(self, b, 1.0, -1.0)
    return c
end

--- Operator overloading of `*`.
-- @param b the right-hand operand; its column count sets the result's width
-- @return a new `self:nrow()` x `b:ncol()` matrix of type `self.__typename`
-- holding `self * b`
function nerv.Matrix:__mul__(b)
    -- `local` keeps the result out of the global environment.
    local c = nerv.get_type(self.__typename)(self:nrow(), b:ncol())
    -- Presumably c = 1.0 * (self * b) + 0.0 * c with neither operand
    -- transposed ('N'); confirm against the C `mul` binding.
    c:mul(self, b, 1.0, 0.0, 'N', 'N')
    return c
end

--- A wrapper function for `copy_from`: `a:copy_to(b, ...)` performs
-- `b:copy_from(a, ...)`.
-- @param b the destination matrix
-- @param ... extra arguments forwarded unchanged to `b:copy_from`
function nerv.Matrix:copy_to(b, ...)
    b:copy_from(self, ...)
end

--- The base class for all device (in-GPU) matrices.
-- @type nerv.CuMatrix

--- A wrapper function for `copy_fromd`.
-- NOTE(review): this aliases `nerv.Matrix.copy_to`, which calls
-- `b:copy_from(self, ...)`; confirm that resolves to the device-side copy.
nerv.CuMatrix.copy_tod = nerv.Matrix.copy_to

--- CUDA float matrices.
-- @type nerv.CuMatrixFloat

--- Create a CUDA matrix copy of the host matrix (in memory).
-- @param mat the host matrix
-- @return a new `nerv.CuMatrixFloat` with the shape of `mat`, filled via
-- `copy_fromh`
function nerv.CuMatrixFloat.new_from_host(mat)
    local res = nerv.CuMatrixFloat(mat:nrow(), mat:ncol())
    res:copy_fromh(mat)
    return res
end

--- Create a host matrix copy of the CUDA matrix.
-- @return a new `nerv.MMatrixFloat` with the shape of `self`, filled via
-- `copy_toh`
function nerv.CuMatrixFloat:new_to_host()
    local res = nerv.MMatrixFloat(self:nrow(), self:ncol())
    self:copy_toh(res)
    return res
end

--- CUDA double matrices.
-- @type nerv.CuMatrixDouble

--- Create a CUDA matrix copy of the host matrix (in memory).
-- @param mat the host matrix
-- @return a new `nerv.CuMatrixDouble` with the shape of `mat`, filled via
-- `copy_fromh`
function nerv.CuMatrixDouble.new_from_host(mat)
    local res = nerv.CuMatrixDouble(mat:nrow(), mat:ncol())
    res:copy_fromh(mat)
    return res
end

--- Create a host matrix copy of the CUDA matrix.
-- @return a new `nerv.MMatrixDouble` with the shape of `self`, filled via
-- `copy_toh`
function nerv.CuMatrixDouble:new_to_host()
    local rows, cols = self:nrow(), self:ncol()
    local host = nerv.MMatrixDouble(rows, cols)
    self:copy_toh(host)
    return host
end

--- The base class for all host (in-memory) matrices.
-- @type nerv.MMatrix

--- Copy `self` into `b`.
-- Shares the implementation of `nerv.Matrix.copy_to`, so
-- `a:copy_toh(b, ...)` performs `b:copy_from(a, ...)`.
nerv.MMatrix.copy_toh = nerv.Matrix.copy_to

--- Fill `self` from the device matrix `b`.
-- The work is delegated to the device side as `b:copy_toh(self, ...)`.
-- @param b the source matrix (a `nerv.CuMatrix`)
-- @param ... extra arguments forwarded unchanged to `copy_toh`
nerv.MMatrix.copy_fromd = function(self, b, ...)
    b:copy_toh(self, ...)
end

--- Copy `self` into the device matrix `b`.
-- The work is delegated to the device side as `b:copy_fromh(self, ...)`.
-- @param b the destination matrix (a `nerv.CuMatrix`)
-- @param ... extra arguments forwarded unchanged to `copy_fromh`
nerv.MMatrix.copy_tod = function(self, b, ...)
    b:copy_fromh(self, ...)
end