diff options
Diffstat (limited to 'fastnn/lib')
-rw-r--r-- | fastnn/lib/ModelSync.c | 372 | ||||
-rw-r--r-- | fastnn/lib/ModelSync.h | 114 | ||||
-rw-r--r-- | fastnn/lib/modelsync.c | 525 | ||||
-rw-r--r-- | fastnn/lib/modelsync.lua | 102 |
4 files changed, 1113 insertions, 0 deletions
diff --git a/fastnn/lib/ModelSync.c b/fastnn/lib/ModelSync.c
new file mode 100644
index 0000000..bd511ea
--- /dev/null
+++ b/fastnn/lib/ModelSync.c
@@ -0,0 +1,372 @@
+/* Reference-counted state shared between fastnn worker threads. */
+
+#include "ModelSync.h"
+#include "../../nerv/lib/matrix/cuda_helper.h"
+#include "../../nerv/lib/matrix/generic/elem_type.h"
+#include "common.h"
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+/*
+ * Create a ModelSync with two fresh mutexes, no staging buffer and a
+ * reference count of 1. Returns NULL when the struct or a mutex is missing.
+ */
+ModelSync* ModelSync_new(void)
+{
+    ModelSync *self = (ModelSync*)malloc(sizeof(ModelSync));
+    if (NULL == self)
+        return NULL;
+    self->model_mutex = THMutex_new();
+    self->state_mutex = THMutex_new();
+    if (NULL == self->model_mutex || NULL == self->state_mutex)
+    {
+        /* Never hand out an object whose lock calls would touch NULL. */
+        if (NULL != self->model_mutex)
+            THMutex_free(self->model_mutex);
+        if (NULL != self->state_mutex)
+            THMutex_free(self->state_mutex);
+        free(self);
+        return NULL;
+    }
+    self->initialized_ = false;
+    self->dim_ = 0;
+    self->pos_ = 0;
+    self->data_ = NULL;
+    self->free_data_ = NULL;
+    self->refcount = 1;
+    self->threadcount = 0;
+    return self;
+}
+
+/* Reinterpret an id from ModelSync_id() as the object and take a reference. */
+ModelSync* ModelSync_newWithId(long id)
+{
+    ModelSync *self = (ModelSync*)id;
+    __sync_fetch_and_add(&self->refcount, 1);
+    return self;
+}
+
+/* Return the object's address as an integer id. */
+long ModelSync_id(ModelSync *self)
+{
+    return (long)(self);
+}
+
+/* Lock helpers: 0 on success, 1 when the THMutex call returns non-zero. */
+int ModelSync_lockmodel(ModelSync *self)
+{
+    return THMutex_lock(self->model_mutex) ? 1 : 0;
+}
+
+int ModelSync_unlockmodel(ModelSync *self)
+{
+    return THMutex_unlock(self->model_mutex) ? 1 : 0;
+}
+
+int ModelSync_lockstate(ModelSync *self)
+{
+    return THMutex_lock(self->state_mutex) ? 1 : 0;
+}
+
+int ModelSync_unlockstate(ModelSync *self)
+{
+    return THMutex_unlock(self->state_mutex) ? 1 : 0;
+}
+
+/*
+ * Drop one reference; the last owner releases the mutexes, the pinned
+ * buffer and the struct. Returns 1 for a NULL argument, otherwise 0.
+ */
+int ModelSync_free(ModelSync *self)
+{
+    if (NULL == self)
+        return 1;
+    if (__sync_fetch_and_add(&self->refcount, -1) == 1)
+    {
+        Status status;
+        /* Mutexes come from THMutex_new(); release them with THMutex_free(). */
+        THMutex_free(self->model_mutex);
+        THMutex_free(self->state_mutex);
+        if (NULL != self->free_data_)
+        {
+            CUDA_SAFE_SYNC_CALL(cudaFreeHost(self->free_data_), &status);
+        }
+        free(self);
+    }
+    return 0;
+}
+
+/*
+ * Allocate the pinned host staging buffer (dim_ floats plus 16 spare bytes)
+ * and point data_ at the first 16-byte aligned address inside it.
+ * Returns 0 on success, 1 on NULL self, negative dim_ or failed allocation.
+ */
+int ModelSync_initBuffer(ModelSync *self)
+{
+    void *free_data = NULL;
+    size_t size;
+    Status status;
+
+    if (NULL == self || self->dim_ < 0)
+        return 1;
+    size = (size_t)self->dim_ * sizeof(float) + 16;
+    CUDA_SAFE_SYNC_CALL(cudaHostAlloc((void**) &free_data, size, cudaHostAllocPortable), &status);
+    NERV_SET_STATUS(&status, NERV_NORMAL, 0);
+    if (NULL == free_data)
+        return 1;
+    /* NOTE(review): a buffer left by an earlier call is overwritten, not freed. */
+    self->free_data_ = (float*)free_data;
+    self->data_ = (float*)(((uintptr_t)free_data + 15) & ~(uintptr_t)0xF);
+    return 0;
+}
+
+/*
+ * Tell whether nrow * stride bytes starting at pos_ fit in the staging
+ * buffer of dim_ floats. pos_ and stride are byte counts.
+ */
+static int ModelSync_hasroom(const ModelSync *self, const Matrix *dm)
+{
+    size_t capacity, need;
+
+    if (NULL == self->data_ || self->pos_ < 0 || self->dim_ < 0)
+        return 0;
+    capacity = (size_t)self->dim_ * sizeof(float);
+    need = (size_t)dm->nrow * dm->stride;
+    if ((size_t)self->pos_ > capacity)
+        return 0;
+    return need <= capacity - (size_t)self->pos_;
+}
+
+/*
+ * Copy device matrix dm into the staging buffer at byte offset pos_ and
+ * advance pos_ by nrow * stride. Returns 0 on success, 1 when an argument
+ * is NULL, the buffer is not allocated, or the matrix does not fit.
+ */
+int ModelSync_weightfromd(ModelSync *self, Matrix *dm)
+{
+    char *host;
+    size_t width, pitch;
+    Status status;
+
+    if (NULL == self || NULL == dm || !ModelSync_hasroom(self, dm))
+        return 1;
+    /* Byte offsets need a char pointer; void* arithmetic is a GNU extension. */
+    host = (char*)self->data_ + self->pos_;
+    width = dm->ncol * sizeof(float);
+    pitch = dm->stride;
+    CUDA_SAFE_SYNC_CALL(cudaMemcpy2D(host, pitch, dm->data.f, pitch, width, dm->nrow, cudaMemcpyDeviceToHost), &status);
+    NERV_SET_STATUS(&status, NERV_NORMAL, 0);
+    self->pos_ += dm->nrow * dm->stride;
+    return 0;
+}
+
+/*
+ * Copy nrow * stride bytes at byte offset pos_ of the staging buffer into
+ * device matrix dm, advance pos_ and mark the object initialized.
+ * Return values match ModelSync_weightfromd().
+ */
+int ModelSync_weighttod(ModelSync *self, Matrix *dm)
+{
+    char *host;
+    size_t width, pitch;
+    Status status;
+
+    if (NULL == self || NULL == dm || !ModelSync_hasroom(self, dm))
+        return 1;
+    host = (char*)self->data_ + self->pos_;
+    width = dm->ncol * sizeof(float);
+    pitch = dm->stride;
+    CUDA_SAFE_SYNC_CALL(cudaMemcpy2D(dm->data.f, pitch, host, pitch, width, dm->nrow, cudaMemcpyHostToDevice), &status);
+    NERV_SET_STATUS(&status, NERV_NORMAL, 0);
+    self->pos_ += dm->nrow * dm->stride;
+    self->initialized_ = true;
+    return 0;
+}
+
+/* Atomically increment the worker counter. */
+void ModelSync_syncinc(ModelSync *self)
+{
+    __sync_fetch_and_add(&self->threadcount, 1);
+}
+
+/* Atomically decrement the worker counter. */
+void ModelSync_syncdec(ModelSync *self)
+{
+    __sync_fetch_and_add(&self->threadcount, -1);
+}
+
+/* Read the worker counter (plain, non-atomic load). */
+int ModelSync_threadcount(ModelSync *self)
+{
+    return self->threadcount;
+}
+
+/////////////////////////////////
+
+/* Allocate a zeroed Xent holding one reference; NULL when out of memory. */
+Xent* Xent_new(void)
+{
+    Xent *xent = (Xent*)calloc(1, sizeof(Xent));
+    if (NULL != xent)
+        xent->refcount = 1;
+    return xent;
+}
+
+/* Reinterpret an id from Xent_id() as the object and take a reference. */
+Xent* Xent_newWithId(long id)
+{
+    Xent *xent = (Xent*)id;
+    __sync_fetch_and_add(&xent->refcount, 1);
+    return xent;
+}
+
+/* Allocate an Xent preset to the given counters; NULL when out of memory. */
+Xent* Xent_newWithParm(size_t frames_, size_t correct_, double loss_, double entropy_)
+{
+    Xent *xent = (Xent*)malloc(sizeof(Xent));
+    if (NULL == xent)
+        return NULL;
+    xent->frames_ = frames_;
+    xent->correct_ = correct_;
+    xent->loss_ = loss_;
+    xent->entropy_ = entropy_;
+    xent->refcount = 1;
+    return xent;
+}
+
+/* Return the object's address as an integer id. */
+long Xent_id(Xent *xent)
+{
+    return (long)(xent);
+}
+
+/* Add every counter of b into a (no locking here) and return a. */
+Xent* Xent_add(Xent *a, Xent *b)
+{
+    a->frames_ += b->frames_;
+    a->correct_ += b->correct_;
+    a->loss_ += b->loss_;
+    a->entropy_ += b->entropy_;
+    return a;
+}
+
+/* Drop one reference and free the object when it was the last one. */
+void Xent_free(Xent *xent)
+{
+    if (NULL != xent && __sync_fetch_and_add(&xent->refcount, -1) == 1)
+        free(xent);
+}
+
+//////////////////////////////////
+
+/* Allocate a zeroed Mse holding one reference; NULL when out of memory. */
+Mse* Mse_new(void)
+{
+    Mse *mse = (Mse*)calloc(1, sizeof(Mse));
+    if (NULL != mse)
+        mse->refcount = 1;
+    return mse;
+}
+
+/* Reinterpret an id from Mse_id() as the object and take a reference. */
+Mse* Mse_newWithId(long id)
+{
+    Mse *mse = (Mse*)id;
+    __sync_fetch_and_add(&mse->refcount, 1);
+    return mse;
+}
+
+/* Allocate an Mse preset to the given counters; NULL when out of memory. */
+Mse* Mse_newWithParm(size_t frames_, double loss_)
+{
+    Mse *mse = (Mse*)malloc(sizeof(Mse));
+    if (NULL == mse)
+        return NULL;
+    mse->frames_ = frames_;
+    mse->loss_ = loss_;
+    mse->refcount = 1;
+    return mse;
+}
+
+/* Return the object's address as an integer id. */
+long Mse_id(Mse *mse)
+{
+    return (long)(mse);
+}
+
+/* Add the counters of b into a (no locking here) and return a. */
+Mse* Mse_add(Mse *a, Mse *b)
+{
+    a->frames_ += b->frames_;
+    a->loss_ += b->loss_;
+    return a;
+}
+
+/* Drop one reference and free the object when it was the last one. */
+void Mse_free(Mse *mse)
+{
+    if (NULL != mse && __sync_fetch_and_add(&mse->refcount, -1) == 1)
+        free(mse);
+}
+
+//////////////////////////////////
+
+/* Copy src into a STRLEN-byte field, truncating if needed; NULL becomes "". */
+static void GlobalOption_setstring(char *dst, const char *src)
+{
+    snprintf(dst, STRLEN, "%s", NULL != src ? src : "");
+}
+
+/*
+ * Allocate a zero-filled GlobalOption (empty strings) holding one
+ * reference; NULL when out of memory.
+ */
+GlobalOption* GlobalOption_new(void)
+{
+    GlobalOption *option = (GlobalOption*)calloc(1, sizeof(GlobalOption));
+    if (NULL != option)
+        option->refcount = 1;
+    return option;
+}
+
+/*
+ * Allocate a GlobalOption holding copies of the given values. Strings
+ * longer than STRLEN - 1 bytes are truncated. NULL when out of memory.
+ */
+GlobalOption* GlobalOption_newWithParm(int batch_size, float lrate, bool bp,const char *tr_scp, const char *cv_scp, const char *transf, const char *network)
+{
+    GlobalOption *option = (GlobalOption*)malloc(sizeof(GlobalOption));
+    if (NULL == option)
+        return NULL;
+    option->batch_size = batch_size;
+    option->lrate = lrate;
+    option->bp = bp;
+    GlobalOption_setstring(option->tr_scp, tr_scp);
+    GlobalOption_setstring(option->cv_scp, cv_scp);
+    GlobalOption_setstring(option->transf, transf);
+    GlobalOption_setstring(option->network, network);
+    option->refcount = 1;
+    return option;
+}
+
+/* Reinterpret an id from GlobalOption_id() as the object and take a reference. */
+GlobalOption* GlobalOption_newWithId(long id)
+{
+    GlobalOption *option = (GlobalOption*)id;
+    __sync_fetch_and_add(&option->refcount, 1);
+    return option;
+}
+
+/* Return the object's address as an integer id. */
+long GlobalOption_id(GlobalOption *option)
+{
+    return (long)(option);
+}
+
+/* Drop one reference and free the object when it was the last one. */
+void GlobalOption_free(GlobalOption *option)
+{
+    if (NULL != option && __sync_fetch_and_add(&option->refcount, -1) == 1)
+        free(option);
+}
diff --git a/fastnn/lib/ModelSync.h b/fastnn/lib/ModelSync.h
new file mode 100644
index 0000000..71216a0
--- /dev/null
+++ b/fastnn/lib/ModelSync.h
@@ -0,0 +1,114 @@
+/* Data structures shared between fastnn worker threads, and their C API. */
+
+#ifndef NERV_FASTNN_MODELSYNC_H
+#define NERV_FASTNN_MODELSYNC_H
+
+/* Size in bytes, terminator included, of every fixed string field below. */
+#define STRLEN 1024
+
+#include "../threads/lib/THThread.h"
+#include "matrix/matrix.h"
+#include "stdlib.h"
+#include "stdbool.h"
+
+typedef struct NnetParallelOptions_
+{
+    int num_threads;
+    int merge_size;
+    int num_merge;
+    int num_procs;
+    int threadid;
+    int myid;
+    int thread_level;
+    char merge_func[STRLEN];
+    char log_file[STRLEN];
+} NnetParallelOptions;
+
+typedef struct ModelSync_
+{
+    THMutex *model_mutex;   /* taken by ModelSync_lockmodel() */
+    THMutex *state_mutex;   /* taken by ModelSync_lockstate() */
+    bool initialized_;      /* set by ModelSync_weighttod() */
+    int dim_;               /* staging buffer capacity, in floats */
+    int pos_;               /* byte offset of the next matrix in data_ */
+    float *data_;           /* 16-byte aligned view into free_data_ */
+    float *free_data_;      /* pointer returned by cudaHostAlloc() */
+    int refcount;
+    int threadcount;
+}ModelSync;
+
+/* Except ModelSync_threadcount(), int results are 0 on success, 1 on failure. */
+ModelSync *ModelSync_new(void);
+ModelSync *ModelSync_newWithId(long id);
+int ModelSync_free(ModelSync *self);
+long ModelSync_id(ModelSync *self);
+int ModelSync_lockmodel(ModelSync *self);
+int ModelSync_unlockmodel(ModelSync *self);
+int ModelSync_lockstate(ModelSync *self);
+int ModelSync_unlockstate(ModelSync *self);
+int ModelSync_initBuffer(ModelSync *self);
+int ModelSync_weightfromd(ModelSync *self, Matrix *dm);
+int ModelSync_weighttod(ModelSync *self, Matrix *dm);
+int ModelSync_threadcount(ModelSync *self);
+void ModelSync_syncinc(ModelSync *self);
+void ModelSync_syncdec(ModelSync *self);
+
+typedef struct Xent_
+{
+    size_t frames_;
+    size_t correct_;
+    double loss_;
+    double entropy_;
+    int refcount;
+} Xent;
+
+Xent* Xent_new(void);
+Xent* Xent_newWithId(long id);
+Xent* Xent_newWithParm(size_t frames_, size_t correct_, double loss_, double entropy_);
+long Xent_id(Xent *xent);
+Xent* Xent_add(Xent *a, Xent *b);
+void Xent_free(Xent *xent);
+
+typedef struct Mse_
+{
+    size_t frames_;
+    double loss_;
+    int refcount;
+} Mse;
+
+Mse* Mse_new(void);
+Mse* Mse_newWithId(long id);
+Mse* Mse_newWithParm(size_t frames_, double loss_);
+long Mse_id(Mse *mse);
+Mse* Mse_add(Mse *a, Mse *b);
+void Mse_free(Mse *mse);
+
+typedef struct NnetUpdateState_
+{
+    int num_utter;
+    int num_nolabel;
+    int num_other_error;
+    long total_frames;
+    Xent xent;
+    Mse mse;
+} NnetUpdateState;
+
+typedef struct GlobalOption_
+{
+    int batch_size;
+    float lrate;
+    bool bp;
+    char tr_scp[STRLEN];
+    char cv_scp[STRLEN];
+    char transf[STRLEN];
+    char network[STRLEN];
+    int refcount;
+}GlobalOption;
+
+GlobalOption* GlobalOption_new(void);
+GlobalOption* GlobalOption_newWithParm(int batch_size, float lrate, bool bp, const char *tr_scp, const char *cv_scp, const char *transf, const char *network);
+GlobalOption* GlobalOption_newWithId(long id);
+long GlobalOption_id(GlobalOption *option);
+void GlobalOption_free(GlobalOption *option);
+
+#endif
diff --git a/fastnn/lib/modelsync.c b/fastnn/lib/modelsync.c
new file mode 100644
index 0000000..2b52752
--- /dev/null
+++ b/fastnn/lib/modelsync.c
@@ -0,0 +1,525 @@
+/* Lua bindings for the shared fastnn objects declared in ModelSync.h. */
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <lua.h>
+#include <lualib.h>
+#include <luaT/luaT.h>
+
+#include "ModelSync.h"
+#include "../threads/lib/luaTHRD.h"
+
+const char *fastnn_model_sync_tname = "fastnn.CModelSync";
+const char *fastnn_xent_tname = "fastnn.CXent";
+const char *fastnn_mse_tname = "fastnn.CMse";
+const char *fastnn_global_option_tname = "fastnn.CGlobalOption";
+
+/* new() creates an object; new(id) attaches to the object behind that id. */
+static int model_sync_new(lua_State *L)
+{
+    ModelSync *model_sync = NULL;
+    if(lua_gettop(L) == 0)
+    {
+        model_sync = ModelSync_new();
+    }
+    else if(lua_gettop(L) == 1)
+    {
+        long id = luaL_checklong(L, 1);
+        model_sync = ModelSync_newWithId(id);
+    }
+    else
+        luaL_error(L, "modelsync: modelsync new invalid arguments");
+    if(!model_sync)
+        luaL_error(L, "modelsync: modelsync new failed");
+
+    luaTHRD_pushudata(L, model_sync, fastnn_model_sync_tname);
+    return 1;
+}
+
+static int model_sync_tostring(lua_State *L)
+{
+    char str[STRLEN];
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    snprintf(str, STRLEN, "fastnn.modelsync <%lx>", ModelSync_id(model_sync));
+    lua_pushstring(L, str);
+    return 1;
+}
+
+static int model_sync_id(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    lua_pushinteger(L, ModelSync_id(model_sync));
+    return 1;
+}
+
+static int model_sync_lockmodel(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    if (ModelSync_lockmodel(model_sync))
+        luaL_error(L, "modelsync: model lock failed");
+    return 0;
+}
+
+static int model_sync_unlockmodel(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    if (ModelSync_unlockmodel(model_sync))
+        luaL_error(L, "modelsync: model unlock failed");
+    return 0;
+}
+
+static int model_sync_lockstate(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    if (ModelSync_lockstate(model_sync))
+        luaL_error(L, "modelsync: state lock failed");
+    return 0;
+}
+
+static int model_sync_unlockstate(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    if (ModelSync_unlockstate(model_sync))
+        luaL_error(L, "modelsync: state unlock failed");
+    return 0;
+}
+
+static int model_sync_free(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    ModelSync_free(model_sync);
+    return 0;
+}
+
+/* initbuffer(dim): store dim and allocate the staging buffer; errors on failure. */
+static int model_sync_initbuffer(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    model_sync->dim_ = luaL_checkinteger(L, 2);
+    if (ModelSync_initBuffer(model_sync))
+        luaL_error(L, "modelsync: init buffer failed");
+    return 0;
+}
+
+/* weightfromd(m): copy m into the staging buffer; errors instead of failing silently. */
+static int model_sync_weightfromd(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    Matrix *dm = luaT_checkudata(L, 2, "nerv.CuMatrixFloat");
+    if (ModelSync_weightfromd(model_sync, dm))
+        luaL_error(L, "modelsync: weightfromd failed");
+    return 0;
+}
+
+/* weighttod(m): copy the staging buffer into m; errors instead of failing silently. */
+static int model_sync_weighttod(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    Matrix *dm = luaT_checkudata(L, 2, "nerv.CuMatrixFloat");
+    if (ModelSync_weighttod(model_sync, dm))
+        luaL_error(L, "modelsync: weighttod failed");
+    return 0;
+}
+
+static int model_sync_initialized(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    lua_pushboolean(L, model_sync->initialized_);
+    return 1;
+}
+
+static int model_sync_setpos(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    int pos = luaL_checkinteger(L, 2);
+    model_sync->pos_ = pos;
+    return 0;
+}
+
+static int model_sync_threadcount(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    lua_pushinteger(L, ModelSync_threadcount(model_sync));
+    return 1;
+}
+
+static int model_sync_syncinc(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    ModelSync_syncinc(model_sync);
+    return 0;
+}
+
+static int model_sync_syncdec(lua_State *L)
+{
+    ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+    ModelSync_syncdec(model_sync);
+    return 0;
+}
+
+//////////////////////////////////////////
+
+/* new(), new(id) or new(frames, correct, loss, entropy). */
+static int xent_new(lua_State *L)
+{
+    Xent *xent = NULL;
+    if(lua_gettop(L) == 0)
+    {
+        xent = Xent_new();
+    }
+    else if(lua_gettop(L) == 1)
+    {
+        long id = luaL_checklong(L, 1);
+        xent = Xent_newWithId(id);
+    }
+    else if(lua_gettop(L) == 4)
+    {
+        size_t frames_, correct_;
+        double loss_, entropy_ ;
+        frames_ = luaL_checkinteger(L, 1);
+        correct_ = luaL_checkinteger(L, 2);
+        loss_ = luaL_checknumber(L, 3);
+        entropy_ = luaL_checknumber(L, 4);
+        xent = Xent_newWithParm(frames_, correct_, loss_, entropy_);
+    }
+    else
+        luaL_error(L, "xent: xent new invalid arguments");
+    if(!xent)
+        luaL_error(L, "xent: xent new failed");
+
+    luaTHRD_pushudata(L, xent, fastnn_xent_tname);
+    return 1;
+}
+
+static int xent_tostring(lua_State *L)
+{
+    char str[STRLEN];
+    Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+    snprintf(str, STRLEN, "fastnn.xent <%lx>", Xent_id(xent));
+    lua_pushstring(L, str);
+    return 1;
+}
+
+static int xent_totalframes(lua_State *L)
+{
+    Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+    lua_pushinteger(L, xent->frames_);
+    return 1;
+}
+
+static int xent_correct(lua_State *L)
+{
+    Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+    lua_pushinteger(L, xent->correct_);
+    return 1;
+}
+
+static int xent_loss(lua_State *L)
+{
+    Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+    lua_pushnumber(L, xent->loss_);
+    return 1;
+}
+
+static int xent_entropy(lua_State *L)
+{
+    Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+    lua_pushnumber(L, xent->entropy_);
+    return 1;
+}
+
+static int xent_id(lua_State *L)
+{
+    Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+    lua_pushinteger(L, Xent_id(xent));
+    return 1;
+}
+
+static int xent_free(lua_State *L)
+{
+    Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+    Xent_free(xent);
+    return 0;
+}
+
+static int xent_add(lua_State *L)
+{
+    Xent *a = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+    Xent *b = luaTHRD_checkudata(L, 2, fastnn_xent_tname);
+    Xent_add(a, b);
+    return 0;
+}
+
+/////////////////////////////////////////////
+
+/* new(), new(id) or new(frames, loss). */
+static int mse_new(lua_State *L)
+{
+    Mse *mse = NULL;
+    if(lua_gettop(L) == 0)
+    {
+        mse = Mse_new();
+    }
+    else if(lua_gettop(L) == 1)
+    {
+        long id = luaL_checklong(L, 1);
+        mse = Mse_newWithId(id);
+    }
+    else if(lua_gettop(L) == 2)
+    {
+        size_t frames_;
+        double loss_;
+        frames_ = luaL_checkinteger(L, 1);
+        loss_ = luaL_checknumber(L, 2);
+        mse = Mse_newWithParm(frames_, loss_);
+    }
+    else
+        luaL_error(L, "mse: mse new invalid arguments");
+    if(!mse)
+        luaL_error(L, "mse: mse new failed");
+
+    luaTHRD_pushudata(L, mse, fastnn_mse_tname);
+    return 1;
+}
+
+static int mse_tostring(lua_State *L)
+{
+    char str[STRLEN];
+    Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+    snprintf(str, STRLEN, "fastnn.mse <%lx>", Mse_id(mse));
+    lua_pushstring(L, str);
+    return 1;
+}
+
+static int mse_totalframes(lua_State *L)
+{
+    Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+    lua_pushinteger(L, mse->frames_);
+    return 1;
+}
+
+static int mse_loss(lua_State *L)
+{
+    Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+    lua_pushnumber(L, mse->loss_);
+    return 1;
+}
+
+static int mse_id(lua_State *L)
+{
+    Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+    lua_pushinteger(L, Mse_id(mse));
+    return 1;
+}
+
+static int mse_free(lua_State *L)
+{
+    Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+    Mse_free(mse);
+    return 0;
+}
+
+static int mse_add(lua_State *L)
+{
+    Mse *a = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+    Mse *b = luaTHRD_checkudata(L, 2, fastnn_mse_tname);
+    Mse_add(a, b);
+    return 0;
+}
+
+/////////////////////////////////////////////
+
+/*
+ * new(), new(id) or new(batch_size, lrate, bp, tr_scp, cv_scp, transf,
+ * network). Any call with more than three arguments takes the last form;
+ * string arguments that are absent or nil are stored as "".
+ */
+static int global_option_new(lua_State *L)
+{
+    GlobalOption *global_option = NULL;
+    if(lua_gettop(L) == 0)
+    {
+        global_option = GlobalOption_new();
+    }
+    else if(lua_gettop(L) == 1)
+    {
+        long id = luaL_checklong(L, 1);
+        global_option = GlobalOption_newWithId(id);
+    }
+    else if(lua_gettop(L) > 3)
+    {
+        int batch_size = luaL_checkinteger(L, 1);
+        float lrate = luaL_checknumber(L, 2);
+        bool bp = lua_toboolean(L, 3);
+        /* luaL_optstring never yields NULL, unlike lua_tostring on a missing slot. */
+        const char *tr_scp = luaL_optstring(L, 4, "");
+        const char *cv_scp = luaL_optstring(L, 5, "");
+        const char *transf = luaL_optstring(L, 6, "");
+        const char *network = luaL_optstring(L, 7, "");
+        global_option = GlobalOption_newWithParm(batch_size, lrate, bp, tr_scp, cv_scp, transf, network);
+    }
+    else
+        luaL_error(L, "global_option: global_option new invalid arguments");
+    if(!global_option)
+        luaL_error(L, "global_option: global_option new failed");
+
+    luaTHRD_pushudata(L, global_option, fastnn_global_option_tname);
+    return 1;
+}
+
+static int global_option_tostring(lua_State *L)
+{
+    char str[STRLEN];
+    GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+    snprintf(str, STRLEN, "fastnn.global_option <%lx>", GlobalOption_id(global_option));
+    lua_pushstring(L, str);
+    return 1;
+}
+
+static int global_option_id(lua_State *L)
+{
+    GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+    lua_pushinteger(L, GlobalOption_id(global_option));
+    return 1;
+}
+
+static int global_option_free(lua_State *L)
+{
+    GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+    GlobalOption_free(global_option);
+    return 0;
+}
+
+static int global_option_batch_size(lua_State *L)
+{
+    GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+    lua_pushinteger(L, global_option->batch_size);
+    return 1;
+}
+
+static int global_option_lrate(lua_State *L)
+{
+    GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+    lua_pushnumber(L, global_option->lrate);
+    return 1;
+}
+
+static int global_option_bp(lua_State *L)
+{
+    GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+    lua_pushboolean(L, global_option->bp);
+    return 1;
+}
+
+static int global_option_tr_scp(lua_State *L)
+{
+    GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+    lua_pushstring(L, global_option->tr_scp);
+    return 1;
+}
+
+static int global_option_cv_scp(lua_State *L)
+{
+    GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+    lua_pushstring(L, global_option->cv_scp);
+    return 1;
+}
+
+static int global_option_transf(lua_State *L)
+{
+    GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+    lua_pushstring(L, global_option->transf);
+    return 1;
+}
+
+static int global_option_network(lua_State *L)
+{
+    GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+    lua_pushstring(L, global_option->network);
+    return 1;
+}
+
+//////////////////////////////////////////////
+
+static const struct luaL_Reg model_sync__ [] = {
+    {"new", model_sync_new},
+    {"__tostring", model_sync_tostring},
+    {"id", model_sync_id},
+    {"lockmodel", model_sync_lockmodel},
+    {"unlockmodel", model_sync_unlockmodel},
+    {"lockstate", model_sync_lockstate},
+    {"unlockstate", model_sync_unlockstate},
+    {"initbuffer", model_sync_initbuffer},
+    {"setpos", model_sync_setpos},
+    {"initialized", model_sync_initialized},
+    {"weightfromd", model_sync_weightfromd},
+    {"weighttod", model_sync_weighttod},
+    {"threadcount", model_sync_threadcount},
+    {"syncinc", model_sync_syncinc},
+    {"syncdec", model_sync_syncdec},
+    {"free", model_sync_free},
+    {NULL, NULL}
+};
+
+static const struct luaL_Reg xent__ [] = {
+    {"new", xent_new},
+    {"__tostring", xent_tostring},
+    {"id", xent_id},
+    {"totalframes", xent_totalframes},
+    {"correct", xent_correct},
+    {"loss", xent_loss},
+    {"entropy", xent_entropy},
+    {"add", xent_add},
+    {"free", xent_free},
+    {NULL, NULL}
+};
+
+static const struct luaL_Reg mse__ [] = {
+    {"new", mse_new},
+    {"__tostring", mse_tostring},
+    {"id", mse_id},
+    {"totalframes", mse_totalframes},
+    {"loss", mse_loss},
+    {"add", mse_add},
+    {"free", mse_free},
+    {NULL, NULL}
+};
+
+static const struct luaL_Reg global_option__ [] = {
+    {"new", global_option_new},
+    {"__tostring", global_option_tostring},
+    {"id", global_option_id},
+    {"batch_size", global_option_batch_size},
+    {"lrate", global_option_lrate},
+    {"bp", global_option_bp},
+    {"tr_scp", global_option_tr_scp},
+    {"cv_scp", global_option_cv_scp},
+    {"transf", global_option_transf},
+    {"network", global_option_network},
+    {"free", global_option_free},
+    {NULL, NULL}
+};
+
+/*
+ * Register the four metatables and their method tables.
+ * NOTE(review): unlike the other three, the ModelSync metatable is created
+ * with a NULL fifth argument; confirm callers invoke free() explicitly.
+ */
+void fastnn_init_modelsync(lua_State *L)
+{
+    luaT_newmetatable(L, fastnn_model_sync_tname, NULL, model_sync_new, NULL, NULL);
+    luaL_register(L, NULL, model_sync__);
+    lua_pop(L, 1);
+
+    luaT_newmetatable(L, fastnn_xent_tname, NULL, xent_new, xent_free, NULL);
+    luaL_register(L, NULL, xent__);
+    lua_pop(L, 1);
+
+    luaT_newmetatable(L, fastnn_mse_tname, NULL, mse_new, mse_free, NULL);
+    luaL_register(L, NULL, mse__);
+    lua_pop(L, 1);
+
+    luaT_newmetatable(L, fastnn_global_option_tname, NULL, global_option_new, global_option_free, NULL);
+    luaL_register(L, NULL, global_option__);
+    lua_pop(L, 1);
+}
diff --git a/fastnn/lib/modelsync.lua b/fastnn/lib/modelsync.lua
new file mode 100644
index 0000000..a247562
--- /dev/null
+++ b/fastnn/lib/modelsync.lua
@@ -0,0 +1,102 @@
+-- Lua wrapper around the shared fastnn.CModelSync object.
+
+local C = require 'libfastnn'
+local T = require 'libthreads'
+
+local ModelSync = nerv.class("fastnn.ModelSync")
+
+fastnn.CModelSync = C.CModelSync
+fastnn.Thread = T.Thread
+
+-- Build the underlying C object, passing shareid through to its constructor.
+function ModelSync:__init(shareid)
+    self.modelsync = fastnn.CModelSync(shareid)
+end
+
+-- Sum nrow * nstride over every parameter whose trans is a nerv.Matrix.
+function ModelSync:GetDim(nnet)
+    local repo = nnet:get_params()
+    local params = repo.params
+    local dim = 0
+    for _, ref in pairs(params) do
+        if nerv.is_type(ref.trans, "nerv.Matrix") then
+            dim = dim + ref.trans:nrow() * ref.trans:nstride()
+        end
+    end
+    return dim
+end
+
+-- Under the model lock, allocate the shared buffer and fill it from nnet
+-- unless the C object already reports itself initialized.
+function ModelSync:Initialize(nnet)
+    self:LockModel()
+    -- pcall keeps a failing step from leaving the model mutex locked.
+    local ok, err = pcall(function()
+        if not self.modelsync:initialized() then
+            local dim = self:GetDim(nnet)
+            self.modelsync:initbuffer(dim)
+            self:WeightFromD(nnet)
+        end
+    end)
+    self:UnLockModel()
+    if not ok then
+        error(err, 0)
+    end
+end
+
+-- Rewind the buffer position, then copy each matrix parameter into the buffer.
+-- NOTE(review): pairs() order is unspecified; confirm every thread visits params in the same order.
+function ModelSync:WeightFromD(nnet)
+    local repo = nnet:get_params()
+    local params = repo.params
+    self.modelsync:setpos(0)
+    for _, ref in pairs(params) do
+        if nerv.is_type(ref.trans, "nerv.Matrix") then
+            self.modelsync:weightfromd(ref.trans)
+        end
+    end
+end
+
+-- Rewind the buffer position, then copy the buffer back into each matrix parameter.
+function ModelSync:WeightToD(nnet)
+    local repo = nnet:get_params()
+    local params = repo.params
+    self.modelsync:setpos(0)
+    for _, ref in pairs(params) do
+        if nerv.is_type(ref.trans, "nerv.Matrix") then
+            self.modelsync:weighttod(ref.trans)
+        end
+    end
+end
+
+function ModelSync:LockState()
+    self.modelsync:lockstate()
+end
+
+function ModelSync:UnLockState()
+    self.modelsync:unlockstate()
+end
+
+function ModelSync:LockModel()
+    self.modelsync:lockmodel()
+end
+
+function ModelSync:UnLockModel()
+    self.modelsync:unlockmodel()
+end
+
+function ModelSync:Id()
+    return self.modelsync:id()
+end
+
+function ModelSync:ThreadCount()
+    return self.modelsync:threadcount()
+end
+
+function ModelSync:SyncInc()
+    return self.modelsync:syncinc()
+end
+
+function ModelSync:SyncDec()
+    return self.modelsync:syncdec()
+end