#include #include #include #include #include #include "ModelSync.h" #include "../threads/lib/luaTHRD.h" const char *fastnn_model_sync_tname = "fastnn.CModelSync"; const char *fastnn_xent_tname = "fastnn.CXent"; const char *fastnn_mse_tname = "fastnn.CMse"; const char *fastnn_global_option_tname = "fastnn.CGlobalOption"; static int model_sync_new(lua_State *L) { ModelSync *model_sync = NULL; if(lua_gettop(L) == 0) { model_sync = ModelSync_new(); } else if(lua_gettop(L) == 1) { long id = luaL_checklong(L, 1); model_sync = ModelSync_newWithId(id); } else luaL_error(L, "modelsync: modelsync new invalid arguments"); if(!model_sync) luaL_error(L, "modelsync: modelsync new failed"); luaTHRD_pushudata(L, model_sync, fastnn_model_sync_tname); return 1; } static int model_sync_tostring(lua_State *L) { char str[STRLEN]; ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); snprintf(str, STRLEN, "fastnn.modelsync <%lx>", ModelSync_id(model_sync)); lua_pushstring(L, str); return 1; } static int model_sync_id(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); lua_pushinteger(L, ModelSync_id(model_sync)); return 1; } static int model_sync_lockmodel(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); if (ModelSync_lockmodel(model_sync)) luaL_error(L, "modelsync: model lock failed"); return 0; } static int model_sync_unlockmodel(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); if (ModelSync_unlockmodel(model_sync)) luaL_error(L, "modelsync: model unlock failed"); return 0; } static int model_sync_lockstate(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); if (ModelSync_lockstate(model_sync)) luaL_error(L, "modelsync: state lock failed"); return 0; } static int model_sync_unlockstate(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); if (ModelSync_unlockstate(model_sync)) luaL_error(L, "modelsync: state unlock failed"); return 0; } static int model_sync_free(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); ModelSync_free(model_sync); return 0; } static int model_sync_initbuffer(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); model_sync->dim_ = luaL_checkinteger(L, 2); ModelSync_initBuffer(model_sync); return 0; } static int model_sync_weightfromd(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); Matrix *dm = luaT_checkudata(L, 2, "nerv.CuMatrixFloat"); ModelSync_weightfromd(model_sync, dm); return 0; } static int model_sync_weighttod(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); Matrix *dm = luaT_checkudata(L, 2, "nerv.CuMatrixFloat"); ModelSync_weighttod(model_sync, dm); return 0; } static int model_sync_initialized(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); lua_pushboolean(L, model_sync->initialized_); return 1; } static int model_sync_setpos(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); int pos = luaL_checkinteger(L, 2); model_sync->pos_ = pos; return 0; } static int model_sync_threadcount(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); lua_pushinteger(L, ModelSync_threadcount(model_sync)); return 1; } static int model_sync_syncinc(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); ModelSync_syncinc(model_sync); return 0; } static int model_sync_syncdec(lua_State *L) { ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname); ModelSync_syncdec(model_sync); return 0; } ////////////////////////////////////////// static int xent_new(lua_State *L) { Xent *xent = NULL; if(lua_gettop(L) == 0) { xent = Xent_new(); } else if(lua_gettop(L) == 1) { long id = luaL_checklong(L, 1); xent = Xent_newWithId(id); } else if(lua_gettop(L) == 4) { size_t frames_, correct_; double loss_, entropy_ ; frames_ = luaL_checkinteger(L, 1); correct_ = luaL_checkinteger(L, 2); loss_ = luaL_checknumber(L, 3); entropy_ = luaL_checknumber(L, 4); xent = Xent_newWithParm(frames_, correct_, loss_, entropy_); } else luaL_error(L, "xent: xent new invalid arguments"); if(!xent) luaL_error(L, "xent: xent new failed"); luaTHRD_pushudata(L, xent, fastnn_xent_tname); return 1; } static int xent_tostring(lua_State *L) { char str[STRLEN]; Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname); snprintf(str, STRLEN, "fastnn.xent <%lx>", Xent_id(xent)); lua_pushstring(L, str); return 1; } static int xent_totalframes(lua_State *L) { Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname); lua_pushinteger(L, xent->frames_); return 1; } static int xent_correct(lua_State *L) { Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname); lua_pushinteger(L, xent->correct_); return 1; } static int xent_loss(lua_State *L) { Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname); lua_pushnumber(L, xent->loss_); return 1; } static int xent_entropy(lua_State *L) { Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname); lua_pushnumber(L, xent->entropy_); return 1; } static int xent_id(lua_State *L) { Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname); lua_pushinteger(L, Xent_id(xent)); return 1; } static int xent_free(lua_State *L) { Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname); Xent_free(xent); return 0; } static int xent_add(lua_State *L) { Xent *a = luaTHRD_checkudata(L, 1, fastnn_xent_tname); Xent *b = luaTHRD_checkudata(L, 2, fastnn_xent_tname); Xent_add(a, b); return 0; } ///////////////////////////////////////////// static int mse_new(lua_State *L) { Mse *mse = NULL; if(lua_gettop(L) == 0) { mse = Mse_new(); } else if(lua_gettop(L) == 1) { long id = luaL_checklong(L, 1); mse = Mse_newWithId(id); } else if(lua_gettop(L) == 2) { size_t frames_; double loss_; frames_ = luaL_checkinteger(L, 1); loss_ = luaL_checknumber(L, 2); mse = Mse_newWithParm(frames_, loss_); } else luaL_error(L, "mse: mse new invalid arguments"); if(!mse) luaL_error(L, "mse: mse new failed"); luaTHRD_pushudata(L, mse, fastnn_mse_tname); return 1; } static int mse_tostring(lua_State *L) { char str[STRLEN]; Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname); snprintf(str, STRLEN, "fastnn.mse <%lx>", Mse_id(mse)); lua_pushstring(L, str); return 1; } static int mse_totalframes(lua_State *L) { Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname); lua_pushinteger(L, mse->frames_); return 1; } static int mse_loss(lua_State *L) { Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname); lua_pushnumber(L, mse->loss_); return 1; } static int mse_id(lua_State *L) { Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname); lua_pushinteger(L, Mse_id(mse)); return 1; } static int mse_free(lua_State *L) { Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname); Mse_free(mse); return 0; } static int mse_add(lua_State *L) { Mse *a = luaTHRD_checkudata(L, 1, fastnn_mse_tname); Mse *b = luaTHRD_checkudata(L, 2, fastnn_mse_tname); Mse_add(a, b); return 0; } ///////////////////////////////////////////// static int global_option_new(lua_State *L) { GlobalOption *global_option = NULL; if(lua_gettop(L) == 0) { global_option = GlobalOption_new(); } else if(lua_gettop(L) == 1) { long id = luaL_checklong(L, 1); global_option = GlobalOption_newWithId(id); } else if(lua_gettop(L) > 3) { int batch_size = luaL_checkinteger(L, 1); float lrate = luaL_checknumber(L, 2); bool bp = lua_toboolean(L, 3); const char *tr_scp = lua_tostring(L, 4); const char *cv_scp = lua_tostring(L, 5); const char *transf = lua_tostring(L, 6); const char *network = lua_tostring(L, 7); global_option = GlobalOption_newWithParm(batch_size, lrate, bp, tr_scp, cv_scp, transf, network); } else luaL_error(L, "global_option: global_option new invalid arguments"); if(!global_option) luaL_error(L, "global_option: global_option new failed"); luaTHRD_pushudata(L, global_option, fastnn_global_option_tname); return 1; } static int global_option_tostring(lua_State *L) { char str[STRLEN]; GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname); snprintf(str, STRLEN, "fastnn.global_option <%lx>", GlobalOption_id(global_option)); lua_pushstring(L, str); return 1; } static int global_option_id(lua_State *L) { GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname); lua_pushinteger(L, GlobalOption_id(global_option)); return 1; } static int global_option_free(lua_State *L) { GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname); GlobalOption_free(global_option); return 0; } static int global_option_batch_size(lua_State *L) { GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname); lua_pushinteger(L, global_option->batch_size); return 1; } static int global_option_lrate(lua_State *L) { GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname); lua_pushnumber(L, global_option->lrate); return 1; } static int global_option_bp(lua_State *L) { GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname); lua_pushboolean(L, global_option->bp); return 1; } static int global_option_tr_scp(lua_State *L) { GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname); lua_pushstring(L, global_option->tr_scp); return 1; } static int global_option_cv_scp(lua_State *L) { GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname); lua_pushstring(L, global_option->cv_scp); return 1; } static int global_option_transf(lua_State *L) { GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname); lua_pushstring(L, global_option->transf); return 1; } static int global_option_network(lua_State *L) { GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname); lua_pushstring(L, global_option->network); return 1; } ////////////////////////////////////////////// static const struct luaL_Reg model_sync__ [] = { {"new", model_sync_new}, {"__tostring", model_sync_tostring}, {"id", model_sync_id}, {"lockmodel", model_sync_lockmodel}, {"unlockmodel", model_sync_unlockmodel}, {"lockstate", model_sync_lockstate}, {"unlockstate", model_sync_unlockstate}, {"initbuffer", model_sync_initbuffer}, {"setpos", model_sync_setpos}, {"initialized", model_sync_initialized}, {"weightfromd", model_sync_weightfromd}, {"weighttod", model_sync_weighttod}, {"threadcount", model_sync_threadcount}, {"syncinc", model_sync_syncinc}, {"syncdec", model_sync_syncdec}, {"threadcount", model_sync_threadcount}, {"free", model_sync_free}, {NULL, NULL} }; static const struct luaL_Reg xent__ [] = { {"new", xent_new}, {"__tostring", xent_tostring}, {"id", xent_id}, {"totalframes", xent_totalframes}, {"correct", xent_correct}, {"loss", xent_loss}, {"entropy", xent_entropy}, {"add", xent_add}, {"free", xent_free}, {NULL, NULL} }; static const struct luaL_Reg mse__ [] = { {"new", mse_new}, {"__tostring", mse_tostring}, {"id", mse_id}, {"totalframes", mse_totalframes}, {"loss", mse_loss}, {"add", mse_add}, {"free", mse_free}, {NULL, NULL} }; static const struct luaL_Reg global_option__ [] = { {"new", global_option_new}, {"__tostring", global_option_tostring}, {"id", global_option_id}, {"batch_size", global_option_batch_size}, {"lrate", global_option_lrate}, {"bp", global_option_bp}, {"tr_scp", global_option_tr_scp}, {"cv_scp", global_option_cv_scp}, {"transf", global_option_transf}, {"network", global_option_network}, {"free", global_option_free}, {NULL, NULL} }; void fastnn_init_modelsync(lua_State *L) { luaT_newmetatable(L, fastnn_model_sync_tname, NULL, model_sync_new, NULL, NULL); luaL_register(L, NULL, model_sync__); lua_pop(L, 1); luaT_newmetatable(L, fastnn_xent_tname, NULL, xent_new, xent_free, NULL); luaL_register(L, NULL, xent__); lua_pop(L, 1); luaT_newmetatable(L, fastnn_mse_tname, NULL, mse_new, mse_free, NULL); luaL_register(L, NULL, mse__); lua_pop(L, 1); luaT_newmetatable(L, fastnn_global_option_tname, NULL, global_option_new, global_option_free, NULL); luaL_register(L, NULL, global_option__); lua_pop(L, 1); /* printf("%s %lx\n", model_sync__[13].name, model_sync__[13].func); if(!luaL_newmetatable(L, fastnn_model_sync_tname)) luaL_error(L, "fastnn: fastnn.modelsync type already exists"); luaL_setfuncs(L, model_sync__, 0); lua_pushstring(L, "__index"); lua_pushvalue(L, -2); lua_rawset(L, -3); lua_pop(L, 1); printf("%s %lx\n", model_sync__[13].name, model_sync__[13].func); lua_pushstring(L, "modelsync"); luaTHRD_pushctortable(L, model_sync_new, fastnn_model_sync_tname); lua_rawset(L, -3); printf("%s %lx\n", model_sync__[13].name, model_sync__[13].func); */ }