#ifndef NERV_LUA_MATRIX_H #define NERV_LUA_MATRIX_H #include "../lib/luaT/luaT.h" #define _MATRIX_GET_CONTEXT(L, p, tname, ctname) \ do { \ if (lua_gettop(L) < p) \ { \ luaT_pushmetatable(L, tname); \ lua_getfield(L, -1, "_default_context"); \ context = luaT_checkudata(L, -1, ctname); \ lua_pop(L, 2); \ } \ else \ { \ context = luaT_checkudata(L, p, ctname); \ } \ } while (0) extern const char *nerv_cuda_context_tname; extern const char *nerv_host_context_tname; extern const char *nerv_matrix_host_tname; #define MATRIX_GET_CONTEXT(L, p) _MATRIX_GET_CONTEXT(L, p, nerv_matrix_(tname), MATRIX_CONTEXT_TNAME) #define MMATRIX_GET_CONTEXT(L, p) _MATRIX_GET_CONTEXT(L, p, nerv_matrix_host_tname, nerv_host_context_tname) #endif