aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix/matrix.h
blob: 788f596aa5afda1120c1d95434e120e83212d79b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#ifndef NERV_LUA_MATRIX_H
#define NERV_LUA_MATRIX_H
#include "../lib/luaT/luaT.h"
/*
 * Store a context userdata pointer into the caller's variable `context`.
 *
 *   L      - Lua state
 *   p      - stack index at which an optional context argument may sit
 *   tname  - type name handed to luaT_pushmetatable; the "_default_context"
 *            field is read from whatever that call leaves on top of the stack
 *   ctname - type name the context value is checked against (luaT_checkudata)
 *
 * If the stack holds fewer than p values (lua_gettop(L) < p), the
 * "_default_context" field is fetched and checked, then two stack slots are
 * popped again. Otherwise the value at index p is checked and used.
 *
 * The macro assigns to an identifier literally named `context`; it is not a
 * parameter, so a compatible pointer variable of that name must be in scope
 * at every expansion site.
 *
 * p is parenthesized wherever it is expanded so that an argument containing
 * an operator of lower precedence than `<` (e.g. `cond ? 2 : 3`) is compared
 * as a whole instead of being regrouped around the comparison.
 *
 * NOTE(review): the default path assumes luaT_pushmetatable actually pushed a
 * table for tname; its result is not inspected here - confirm that tname is
 * always registered before this macro runs.
 */
#define _MATRIX_GET_CONTEXT(L, p, tname, ctname) \
    do { \
        if (lua_gettop(L) < (p)) \
        { \
            luaT_pushmetatable(L, tname); \
            lua_getfield(L, -1, "_default_context"); \
            context = luaT_checkudata(L, -1, ctname); \
            lua_pop(L, 2); \
        } \
        else \
        { \
            context = luaT_checkudata(L, (p), ctname); \
        } \
    } while (0)

/*
 * Type-name strings defined in other translation units.
 * nerv_matrix_host_tname and nerv_host_context_tname are the tname/ctname
 * arguments that MMATRIX_GET_CONTEXT forwards to _MATRIX_GET_CONTEXT.
 * nerv_cuda_context_tname is only declared here; nothing in this header
 * uses it.
 */
extern const char *nerv_cuda_context_tname;
extern const char *nerv_host_context_tname;
extern const char *nerv_matrix_host_tname;
/*
 * MATRIX_GET_CONTEXT(L, p): context lookup parameterized by the expansion
 * site. The metatable type name is nerv_matrix_(tname) and the context type
 * name is MATRIX_CONTEXT_TNAME; neither is defined in this header, so both
 * must be available wherever this macro is expanded.
 */
#define MATRIX_GET_CONTEXT(L, p) \
    _MATRIX_GET_CONTEXT(L, p, \
                        nerv_matrix_(tname), \
                        MATRIX_CONTEXT_TNAME)

/*
 * MMATRIX_GET_CONTEXT(L, p): the same lookup with the type names fixed to
 * nerv_matrix_host_tname (metatable) and nerv_host_context_tname (context).
 */
#define MMATRIX_GET_CONTEXT(L, p) \
    _MATRIX_GET_CONTEXT(L, p, \
                        nerv_matrix_host_tname, \
                        nerv_host_context_tname)
#endif