#include <stdio.h>
#include <stdlib.h>
#include <lua.h>
#include <lualib.h>
#include <luaT/luaT.h>
#include "ModelSync.h"
#include "../threads/lib/luaTHRD.h"
/* luaT/luaTHRD type names for the four userdata types bound in this file.
 * Each name is used both to create the metatable (fastnn_init_modelsync)
 * and to check userdata arguments (luaTHRD_checkudata).
 * NOTE(review): these have external linkage and may be referenced from
 * other translation units; confirm before making them static. */
const char *fastnn_model_sync_tname = "fastnn.CModelSync";
const char *fastnn_xent_tname = "fastnn.CXent";
const char *fastnn_mse_tname = "fastnn.CMse";
const char *fastnn_global_option_tname = "fastnn.CGlobalOption";
/* Constructor for fastnn.CModelSync.
 *   new()    -> a fresh ModelSync (ModelSync_new)
 *   new(id)  -> a ModelSync obtained from an existing id (ModelSync_newWithId)
 * Any other argument count, or a NULL result, raises a Lua error.
 * Pushes the wrapped pointer as luaTHRD userdata and returns 1. */
static int model_sync_new(lua_State *L)
{
    ModelSync *ms;

    switch (lua_gettop(L)) {
    case 0:
        ms = ModelSync_new();
        break;
    case 1:
        ms = ModelSync_newWithId(luaL_checklong(L, 1));
        break;
    default:
        /* luaL_error never returns; `return` just makes that explicit. */
        return luaL_error(L, "modelsync: modelsync new invalid arguments");
    }

    if (ms == NULL)
        return luaL_error(L, "modelsync: modelsync new failed");

    luaTHRD_pushudata(L, ms, fastnn_model_sync_tname);
    return 1;
}
/* __tostring: renders a CModelSync as "fastnn.modelsync <id in hex>".
 * The text is truncated to STRLEN-1 characters by snprintf if needed. */
static int model_sync_tostring(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
    char buf[STRLEN];

    snprintf(buf, sizeof buf, "fastnn.modelsync <%lx>", ModelSync_id(ms));
    lua_pushstring(L, buf);
    return 1;
}
/* Lua: ms:id() -> integer returned by ModelSync_id. This value can be passed
 * back to the constructor as new(id). */
static int model_sync_id(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
    lua_Integer ident = ModelSync_id(ms);

    lua_pushinteger(L, ident);
    return 1;
}
/* Lua: ms:lockmodel(). Calls ModelSync_lockmodel and raises a Lua error if
 * it reports failure (non-zero). Returns nothing. */
static int model_sync_lockmodel(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);

    if (ModelSync_lockmodel(ms) != 0)
        return luaL_error(L, "modelsync: model lock failed");
    return 0;
}
/* Lua: ms:unlockmodel(). Calls ModelSync_unlockmodel and raises a Lua error
 * if it reports failure (non-zero). Returns nothing. */
static int model_sync_unlockmodel(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);

    if (ModelSync_unlockmodel(ms) != 0)
        return luaL_error(L, "modelsync: model unlock failed");
    return 0;
}
/* Lua: ms:lockstate(). Calls ModelSync_lockstate and raises a Lua error if
 * it reports failure (non-zero). Returns nothing. */
static int model_sync_lockstate(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);

    if (ModelSync_lockstate(ms) != 0)
        return luaL_error(L, "modelsync: state lock failed");
    return 0;
}
/* Lua: ms:unlockstate(). Calls ModelSync_unlockstate and raises a Lua error
 * if it reports failure (non-zero). Returns nothing. */
static int model_sync_unlockstate(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);

    if (ModelSync_unlockstate(ms) != 0)
        return luaL_error(L, "modelsync: state unlock failed");
    return 0;
}
/* Lua: ms:free(). Hands the wrapped pointer to ModelSync_free.
 * The CModelSync metatable is created with a NULL destructor, so this
 * explicit call is the only release path exposed to Lua.
 * NOTE(review): the userdata still holds the pointer afterwards; further
 * method calls on it would use a released object. */
static int model_sync_free(lua_State *L)
{
    ModelSync_free(luaTHRD_checkudata(L, 1, fastnn_model_sync_tname));
    return 0;
}
/* Lua: ms:initbuffer(dim). Stores `dim` into ms->dim_, then calls
 * ModelSync_initBuffer. The value is not range-checked here. */
static int model_sync_initbuffer(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
    lua_Integer dim = luaL_checkinteger(L, 2);

    ms->dim_ = dim;
    ModelSync_initBuffer(ms);
    return 0;
}
/* Lua: ms:weightfromd(m), where m must be a nerv.CuMatrixFloat.
 * Forwards to ModelSync_weightfromd(ms, m).
 * NOTE(review): presumably copies device matrix data into the shared
 * buffer; confirm against ModelSync_weightfromd. */
static int model_sync_weightfromd(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
    Matrix *mat = luaT_checkudata(L, 2, "nerv.CuMatrixFloat");

    ModelSync_weightfromd(ms, mat);
    return 0;
}
/* Lua: ms:weighttod(m), where m must be a nerv.CuMatrixFloat.
 * Forwards to ModelSync_weighttod(ms, m).
 * NOTE(review): presumably the reverse of weightfromd (shared buffer into
 * device matrix); confirm against ModelSync_weighttod. */
static int model_sync_weighttod(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
    Matrix *mat = luaT_checkudata(L, 2, "nerv.CuMatrixFloat");

    ModelSync_weighttod(ms, mat);
    return 0;
}
/* Lua: ms:initialized() -> boolean view of ms->initialized_. */
static int model_sync_initialized(lua_State *L)
{
    const ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);

    lua_pushboolean(L, ms->initialized_);
    return 1;
}
/* Lua: ms:setpos(pos). Stores pos (narrowed to int) into ms->pos_. */
static int model_sync_setpos(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);

    ms->pos_ = (int)luaL_checkinteger(L, 2);
    return 0;
}
/* Lua: ms:threadcount() -> integer returned by ModelSync_threadcount. */
static int model_sync_threadcount(lua_State *L)
{
    ModelSync *ms = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
    lua_Integer count = ModelSync_threadcount(ms);

    lua_pushinteger(L, count);
    return 1;
}
/* Lua: ms:syncinc(). Forwards to ModelSync_syncinc; returns nothing. */
static int model_sync_syncinc(lua_State *L)
{
    ModelSync_syncinc(luaTHRD_checkudata(L, 1, fastnn_model_sync_tname));
    return 0;
}
/* Lua: ms:syncdec(). Forwards to ModelSync_syncdec; returns nothing. */
static int model_sync_syncdec(lua_State *L)
{
    ModelSync_syncdec(luaTHRD_checkudata(L, 1, fastnn_model_sync_tname));
    return 0;
}
//////////////////////////////////////////
/* Constructor for fastnn.CXent.
 *   new()                                -> Xent_new()
 *   new(id)                              -> Xent_newWithId(id)
 *   new(frames, correct, loss, entropy)  -> Xent_newWithParm(...)
 * Any other argument count, a negative frame/correct count, or a NULL
 * result raises a Lua error. Pushes the wrapped pointer as luaTHRD
 * userdata and returns 1. */
static int xent_new(lua_State *L)
{
    Xent *xent = NULL;
    const int nargs = lua_gettop(L);

    if (nargs == 0) {
        xent = Xent_new();
    } else if (nargs == 1) {
        long id = luaL_checklong(L, 1);
        xent = Xent_newWithId(id);
    } else if (nargs == 4) {
        lua_Integer frames = luaL_checkinteger(L, 1);
        lua_Integer correct = luaL_checkinteger(L, 2);
        double loss = luaL_checknumber(L, 3);
        double entropy = luaL_checknumber(L, 4);

        /* The counts are passed on as size_t. Without these checks a
         * negative Lua integer would silently wrap to a huge count. */
        luaL_argcheck(L, frames >= 0, 1, "frame count must be non-negative");
        luaL_argcheck(L, correct >= 0, 2, "correct count must be non-negative");
        xent = Xent_newWithParm((size_t)frames, (size_t)correct, loss, entropy);
    } else {
        return luaL_error(L, "xent: xent new invalid arguments");
    }

    if (!xent)
        return luaL_error(L, "xent: xent new failed");

    luaTHRD_pushudata(L, xent, fastnn_xent_tname);
    return 1;
}
/* __tostring: renders a CXent as "fastnn.xent <id in hex>". */
static int xent_tostring(lua_State *L)
{
    Xent *x = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
    char buf[STRLEN];

    snprintf(buf, sizeof buf, "fastnn.xent <%lx>", Xent_id(x));
    lua_pushstring(L, buf);
    return 1;
}
/* Lua: x:totalframes() -> integer value of x->frames_. */
static int xent_totalframes(lua_State *L)
{
    const Xent *x = luaTHRD_checkudata(L, 1, fastnn_xent_tname);

    lua_pushinteger(L, x->frames_);
    return 1;
}
/* Lua: x:correct() -> integer value of x->correct_. */
static int xent_correct(lua_State *L)
{
    const Xent *x = luaTHRD_checkudata(L, 1, fastnn_xent_tname);

    lua_pushinteger(L, x->correct_);
    return 1;
}
/* Lua: x:loss() -> number value of x->loss_. */
static int xent_loss(lua_State *L)
{
    const Xent *x = luaTHRD_checkudata(L, 1, fastnn_xent_tname);

    lua_pushnumber(L, x->loss_);
    return 1;
}
/* Lua: x:entropy() -> number value of x->entropy_. */
static int xent_entropy(lua_State *L)
{
    const Xent *x = luaTHRD_checkudata(L, 1, fastnn_xent_tname);

    lua_pushnumber(L, x->entropy_);
    return 1;
}
/* Lua: x:id() -> integer returned by Xent_id. This value is accepted by
 * the constructor as new(id). */
static int xent_id(lua_State *L)
{
    Xent *x = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
    lua_Integer ident = Xent_id(x);

    lua_pushinteger(L, ident);
    return 1;
}
/* Lua: x:free(). Hands the wrapped pointer to Xent_free.
 * This function is also passed as the destructor to luaT_newmetatable for
 * fastnn.CXent.
 * NOTE(review): an explicit free() followed by collection would call
 * Xent_free twice on the same pointer; confirm Xent_free is
 * reference-counted. */
static int xent_free(lua_State *L)
{
    Xent_free(luaTHRD_checkudata(L, 1, fastnn_xent_tname));
    return 0;
}
/* Lua: dst:add(src). Forwards to Xent_add(dst, src); returns nothing.
 * NOTE(review): presumably accumulates src's statistics into dst; confirm
 * against Xent_add. */
static int xent_add(lua_State *L)
{
    Xent *dst = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
    Xent *src = luaTHRD_checkudata(L, 2, fastnn_xent_tname);

    Xent_add(dst, src);
    return 0;
}
/////////////////////////////////////////////
/* Constructor for fastnn.CMse.
 *   new()             -> Mse_new()
 *   new(id)           -> Mse_newWithId(id)
 *   new(frames, loss) -> Mse_newWithParm(frames, loss)
 * Any other argument count, a negative frame count, or a NULL result
 * raises a Lua error. Pushes the wrapped pointer as luaTHRD userdata and
 * returns 1. */
static int mse_new(lua_State *L)
{
    Mse *mse = NULL;
    const int nargs = lua_gettop(L);

    if (nargs == 0) {
        mse = Mse_new();
    } else if (nargs == 1) {
        long id = luaL_checklong(L, 1);
        mse = Mse_newWithId(id);
    } else if (nargs == 2) {
        lua_Integer frames = luaL_checkinteger(L, 1);
        double loss = luaL_checknumber(L, 2);

        /* frames is passed on as size_t. Without this check a negative
         * Lua integer would silently wrap to a huge count. */
        luaL_argcheck(L, frames >= 0, 1, "frame count must be non-negative");
        mse = Mse_newWithParm((size_t)frames, loss);
    } else {
        return luaL_error(L, "mse: mse new invalid arguments");
    }

    if (!mse)
        return luaL_error(L, "mse: mse new failed");

    luaTHRD_pushudata(L, mse, fastnn_mse_tname);
    return 1;
}
/* __tostring: renders a CMse as "fastnn.mse <id in hex>". */
static int mse_tostring(lua_State *L)
{
    Mse *m = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
    char buf[STRLEN];

    snprintf(buf, sizeof buf, "fastnn.mse <%lx>", Mse_id(m));
    lua_pushstring(L, buf);
    return 1;
}
/* Lua: m:totalframes() -> integer value of m->frames_. */
static int mse_totalframes(lua_State *L)
{
    const Mse *m = luaTHRD_checkudata(L, 1, fastnn_mse_tname);

    lua_pushinteger(L, m->frames_);
    return 1;
}
/* Lua: m:loss() -> number value of m->loss_. */
static int mse_loss(lua_State *L)
{
    const Mse *m = luaTHRD_checkudata(L, 1, fastnn_mse_tname);

    lua_pushnumber(L, m->loss_);
    return 1;
}
/* Lua: m:id() -> integer returned by Mse_id. This value is accepted by the
 * constructor as new(id). */
static int mse_id(lua_State *L)
{
    Mse *m = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
    lua_Integer ident = Mse_id(m);

    lua_pushinteger(L, ident);
    return 1;
}
/* Lua: m:free(). Hands the wrapped pointer to Mse_free.
 * This function is also passed as the destructor to luaT_newmetatable for
 * fastnn.CMse.
 * NOTE(review): confirm Mse_free tolerates being reached both explicitly
 * and via collection (e.g. it is reference-counted). */
static int mse_free(lua_State *L)
{
    Mse_free(luaTHRD_checkudata(L, 1, fastnn_mse_tname));
    return 0;
}
/* Lua: dst:add(src). Forwards to Mse_add(dst, src); returns nothing.
 * NOTE(review): presumably accumulates src into dst; confirm against
 * Mse_add. */
static int mse_add(lua_State *L)
{
    Mse *dst = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
    Mse *src = luaTHRD_checkudata(L, 2, fastnn_mse_tname);

    Mse_add(dst, src);
    return 0;
}
/////////////////////////////////////////////
/* Constructor for fastnn.CGlobalOption.
 *   new()   -> GlobalOption_new()
 *   new(id) -> GlobalOption_newWithId(id)
 *   new(batch_size, lrate, bp, tr_scp, cv_scp, transf, network)
 *           -> GlobalOption_newWithParm(...), accepted for ANY count >= 4
 * Counts 2 and 3, or a NULL result, raise a Lua error. Pushes the wrapped
 * pointer as luaTHRD userdata and returns 1. */
static int global_option_new(lua_State *L)
{
    GlobalOption *opt;
    const int nargs = lua_gettop(L);

    if (nargs == 0) {
        opt = GlobalOption_new();
    } else if (nargs == 1) {
        opt = GlobalOption_newWithId(luaL_checklong(L, 1));
    } else if (nargs >= 4) {
        /* Each argument is fetched in its own statement. This fixes the
         * order of the checks (and so which error fires first); C leaves
         * the evaluation order of call arguments unspecified. */
        int batch_size = luaL_checkinteger(L, 1);
        float lrate = luaL_checknumber(L, 2);
        bool bp = lua_toboolean(L, 3);
        /* NOTE(review): with 4-6 arguments the missing trailing slots
         * yield NULL from lua_tostring. Those strings are also owned by
         * Lua; confirm GlobalOption_newWithParm copies them and handles
         * NULL. */
        const char *tr_scp = lua_tostring(L, 4);
        const char *cv_scp = lua_tostring(L, 5);
        const char *transf = lua_tostring(L, 6);
        const char *network = lua_tostring(L, 7);

        opt = GlobalOption_newWithParm(batch_size, lrate, bp,
                                       tr_scp, cv_scp, transf, network);
    } else {
        return luaL_error(L, "global_option: global_option new invalid arguments");
    }

    if (opt == NULL)
        return luaL_error(L, "global_option: global_option new failed");

    luaTHRD_pushudata(L, opt, fastnn_global_option_tname);
    return 1;
}
/* __tostring: renders a CGlobalOption as
 * "fastnn.global_option <id in hex>". */
static int global_option_tostring(lua_State *L)
{
    GlobalOption *opt = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
    char buf[STRLEN];

    snprintf(buf, sizeof buf, "fastnn.global_option <%lx>", GlobalOption_id(opt));
    lua_pushstring(L, buf);
    return 1;
}
/* Lua: opt:id() -> integer returned by GlobalOption_id. This value is
 * accepted by the constructor as new(id). */
static int global_option_id(lua_State *L)
{
    GlobalOption *opt = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
    lua_Integer ident = GlobalOption_id(opt);

    lua_pushinteger(L, ident);
    return 1;
}
/* Lua: opt:free(). Hands the wrapped pointer to GlobalOption_free.
 * This function is also passed as the destructor to luaT_newmetatable for
 * fastnn.CGlobalOption.
 * NOTE(review): confirm GlobalOption_free tolerates being reached both
 * explicitly and via collection. */
static int global_option_free(lua_State *L)
{
    GlobalOption_free(luaTHRD_checkudata(L, 1, fastnn_global_option_tname));
    return 0;
}
/* Lua: opt:batch_size() -> integer value of opt->batch_size. */
static int global_option_batch_size(lua_State *L)
{
    const GlobalOption *opt = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);

    lua_pushinteger(L, opt->batch_size);
    return 1;
}
/* Lua: opt:lrate() -> number value of opt->lrate. */
static int global_option_lrate(lua_State *L)
{
    const GlobalOption *opt = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);

    lua_pushnumber(L, opt->lrate);
    return 1;
}
/* Lua: opt:bp() -> boolean value of opt->bp. */
static int global_option_bp(lua_State *L)
{
    const GlobalOption *opt = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);

    lua_pushboolean(L, opt->bp);
    return 1;
}
/* Lua: opt:tr_scp() -> string opt->tr_scp. lua_pushstring pushes nil if
 * the field is NULL. */
static int global_option_tr_scp(lua_State *L)
{
    const GlobalOption *opt = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);

    lua_pushstring(L, opt->tr_scp);
    return 1;
}
/* Lua: opt:cv_scp() -> string opt->cv_scp. lua_pushstring pushes nil if
 * the field is NULL. */
static int global_option_cv_scp(lua_State *L)
{
    const GlobalOption *opt = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);

    lua_pushstring(L, opt->cv_scp);
    return 1;
}
/* Lua: opt:transf() -> string opt->transf. lua_pushstring pushes nil if
 * the field is NULL. */
static int global_option_transf(lua_State *L)
{
    const GlobalOption *opt = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);

    lua_pushstring(L, opt->transf);
    return 1;
}
/* Lua: opt:network() -> string opt->network. lua_pushstring pushes nil if
 * the field is NULL. */
static int global_option_network(lua_State *L)
{
    const GlobalOption *opt = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);

    lua_pushstring(L, opt->network);
    return 1;
}
//////////////////////////////////////////////
/* Method table installed into the fastnn.CModelSync metatable.
 * Each name appears exactly once. The previous table listed "threadcount"
 * twice, and the second entry merely overwrote the first. */
static const struct luaL_Reg model_sync__ [] = {
    {"new", model_sync_new},
    {"__tostring", model_sync_tostring},
    {"id", model_sync_id},
    {"lockmodel", model_sync_lockmodel},
    {"unlockmodel", model_sync_unlockmodel},
    {"lockstate", model_sync_lockstate},
    {"unlockstate", model_sync_unlockstate},
    {"initbuffer", model_sync_initbuffer},
    {"setpos", model_sync_setpos},
    {"initialized", model_sync_initialized},
    {"weightfromd", model_sync_weightfromd},
    {"weighttod", model_sync_weighttod},
    {"threadcount", model_sync_threadcount},
    {"syncinc", model_sync_syncinc},
    {"syncdec", model_sync_syncdec},
    {"free", model_sync_free},
    {NULL, NULL}    /* sentinel required by luaL_register */
};
/* Method table installed into the fastnn.CXent metatable. */
static const struct luaL_Reg xent__ [] = {
    { .name = "new",         .func = xent_new },
    { .name = "__tostring",  .func = xent_tostring },
    { .name = "id",          .func = xent_id },
    { .name = "totalframes", .func = xent_totalframes },
    { .name = "correct",     .func = xent_correct },
    { .name = "loss",        .func = xent_loss },
    { .name = "entropy",     .func = xent_entropy },
    { .name = "add",         .func = xent_add },
    { .name = "free",        .func = xent_free },
    { .name = NULL,          .func = NULL }    /* luaL_register sentinel */
};
/* Method table installed into the fastnn.CMse metatable. */
static const struct luaL_Reg mse__ [] = {
    { .name = "new",         .func = mse_new },
    { .name = "__tostring",  .func = mse_tostring },
    { .name = "id",          .func = mse_id },
    { .name = "totalframes", .func = mse_totalframes },
    { .name = "loss",        .func = mse_loss },
    { .name = "add",         .func = mse_add },
    { .name = "free",        .func = mse_free },
    { .name = NULL,          .func = NULL }    /* luaL_register sentinel */
};
/* Method table installed into the fastnn.CGlobalOption metatable. */
static const struct luaL_Reg global_option__ [] = {
    { .name = "new",        .func = global_option_new },
    { .name = "__tostring", .func = global_option_tostring },
    { .name = "id",         .func = global_option_id },
    { .name = "batch_size", .func = global_option_batch_size },
    { .name = "lrate",      .func = global_option_lrate },
    { .name = "bp",         .func = global_option_bp },
    { .name = "tr_scp",     .func = global_option_tr_scp },
    { .name = "cv_scp",     .func = global_option_cv_scp },
    { .name = "transf",     .func = global_option_transf },
    { .name = "network",    .func = global_option_network },
    { .name = "free",       .func = global_option_free },
    { .name = NULL,         .func = NULL }    /* luaL_register sentinel */
};
void fastnn_init_modelsync(lua_State *L)
{
luaT_newmetatable(L, fastnn_model_sync_tname, NULL, model_sync_new, NULL, NULL);
luaL_register(L, NULL, model_sync__);
lua_pop(L, 1);
luaT_newmetatable(L, fastnn_xent_tname, NULL, xent_new, xent_free, NULL);
luaL_register(L, NULL, xent__);
lua_pop(L, 1);
luaT_newmetatable(L, fastnn_mse_tname, NULL, mse_new, mse_free, NULL);
luaL_register(L, NULL, mse__);
lua_pop(L, 1);
luaT_newmetatable(L, fastnn_global_option_tname, NULL, global_option_new, global_option_free, NULL);
luaL_register(L, NULL, global_option__);
lua_pop(L, 1);
/*
printf("%s %lx\n", model_sync__[13].name, model_sync__[13].func);
if(!luaL_newmetatable(L, fastnn_model_sync_tname))
luaL_error(L, "fastnn: fastnn.modelsync type already exists");
luaL_setfuncs(L, model_sync__, 0);
lua_pushstring(L, "__index");
lua_pushvalue(L, -2);
lua_rawset(L, -3);
lua_pop(L, 1);
printf("%s %lx\n", model_sync__[13].name, model_sync__[13].func);
lua_pushstring(L, "modelsync");
luaTHRD_pushctortable(L, model_sync_new, fastnn_model_sync_tname);
lua_rawset(L, -3);
printf("%s %lx\n", model_sync__[13].name, model_sync__[13].func);
*/
}