aboutsummaryrefslogtreecommitdiff
path: root/fastnn/lib
diff options
context:
space:
mode:
Diffstat (limited to 'fastnn/lib')
-rw-r--r--fastnn/lib/ModelSync.c305
-rw-r--r--fastnn/lib/ModelSync.h119
-rw-r--r--fastnn/lib/modelsync.c532
-rw-r--r--fastnn/lib/modelsync.lua107
4 files changed, 1063 insertions, 0 deletions
diff --git a/fastnn/lib/ModelSync.c b/fastnn/lib/ModelSync.c
new file mode 100644
index 0000000..bd511ea
--- /dev/null
+++ b/fastnn/lib/ModelSync.c
@@ -0,0 +1,305 @@
+
+#include "ModelSync.h"
+#include "../../nerv/lib/matrix/cuda_helper.h"
+#include "../../nerv/lib/matrix/generic/elem_type.h"
+#include "common.h"
+#include <string.h>
+
+
+ModelSync* ModelSync_new(void)
+{
+ ModelSync *self = (ModelSync*)malloc(sizeof(ModelSync));
+ if (NULL != self)
+ {
+ self->model_mutex = THMutex_new();
+ self->state_mutex = THMutex_new();
+ self->initialized_ = false;
+ self->dim_ = 0;
+ self->pos_ = 0;
+ self->data_ = NULL;
+ self->free_data_ = NULL;
+ self->data_ = NULL;
+ self->refcount = 1;
+ self->threadcount = 0;
+ }
+ return self;
+}
+
+ModelSync* ModelSync_newWithId(long id)
+{
+ ModelSync *self = (ModelSync*)id;
+ __sync_fetch_and_add(&self->refcount, 1);
+ return self;
+}
+
+long ModelSync_id(ModelSync *self)
+{
+ return (long)(self);
+}
+
+int ModelSync_lockmodel(ModelSync *self)
+{
+ if(THMutex_lock(self->model_mutex))
+ return 1;
+ return 0;
+}
+
+int ModelSync_unlockmodel(ModelSync *self)
+{
+ if(THMutex_unlock(self->model_mutex))
+ return 1;
+ return 0;
+
+}
+int ModelSync_lockstate(ModelSync *self)
+{
+ if(THMutex_lock(self->state_mutex))
+ return 1;
+ return 0;
+}
+
+int ModelSync_unlockstate(ModelSync *self)
+{
+ if(THMutex_unlock(self->state_mutex))
+ return 1;
+ return 0;
+}
+
+int ModelSync_free(ModelSync *self)
+{
+ if (NULL != self && __sync_fetch_and_add(&self->refcount, -1) == 1)
+ {
+ free(self->model_mutex);
+ free(self->state_mutex);
+ Status status;
+ CUDA_SAFE_SYNC_CALL(cudaFreeHost(self->free_data_), &status);
+ free(self);
+ }
+}
+
+int ModelSync_initBuffer(ModelSync *self)
+{
+ if (NULL != self)
+ {
+ void *free_data = NULL, *data = NULL;
+ size_t size = self->dim_ * sizeof(float)+16;
+ Status status;
+ CUDA_SAFE_SYNC_CALL(cudaHostAlloc((void**) &free_data, size, cudaHostAllocPortable), &status);
+ NERV_SET_STATUS(&status, NERV_NORMAL, 0);
+
+ data = (free_data ? (void *)( (((unsigned long)*(&free_data)) + 15) & ~0xFUL ) : NULL) ;
+ if (NULL != data)
+ {
+ self->data_ = (float*)(data);
+ self->free_data_ = (float*)(free_data);
+ }
+ return 0;
+ }
+ return 1;
+}
+
+int ModelSync_weightfromd(ModelSync *self, Matrix *dm)
+{
+
+ if (NULL != self && NULL != dm)
+ {
+ void *host_data_ = (void*)self->data_;
+ size_t width = dm->ncol * sizeof(float);
+ size_t src_pitch = dm->stride;
+ size_t dst_pitch = src_pitch;
+ Status status;
+
+ CUDA_SAFE_SYNC_CALL(cudaMemcpy2D(host_data_+self->pos_, dst_pitch, dm->data.f, src_pitch, width, dm->nrow, cudaMemcpyDeviceToHost), &status);
+ NERV_SET_STATUS(&status, NERV_NORMAL, 0);
+ self->pos_ += dm->nrow * dm->stride;
+ return 0;
+ }
+ return 1;
+
+}
+
+
+int ModelSync_weighttod(ModelSync *self, Matrix *dm)
+{
+
+ if (NULL != self && NULL != dm)
+ {
+ void *host_data_ = (void*)self->data_;
+ size_t width = dm->ncol * sizeof(float);
+ size_t dst_pitch = dm->stride;
+ size_t src_pitch = dst_pitch;
+ Status status;
+
+ CUDA_SAFE_SYNC_CALL(cudaMemcpy2D(dm->data.f, dst_pitch, host_data_+self->pos_, src_pitch, width, dm->nrow, cudaMemcpyHostToDevice), &status);
+ NERV_SET_STATUS(&status, NERV_NORMAL, 0);
+
+ self->pos_ += dm->nrow * dm->stride;
+ self->initialized_ = true;
+ return 0;
+ }
+ return 1;
+}
+
+void ModelSync_syncinc(ModelSync *self)
+{
+ __sync_fetch_and_add(&self->threadcount, 1);
+}
+
+void ModelSync_syncdec(ModelSync *self)
+{
+ __sync_fetch_and_add(&self->threadcount, -1);
+}
+
+int ModelSync_threadcount(ModelSync *self)
+{
+ return self->threadcount;
+}
+
+/////////////////////////////////
+
+Xent* Xent_new()
+{
+ Xent *xent = (Xent*)malloc(sizeof(Xent));
+ memset(xent, 0, sizeof(Xent));
+ xent->refcount = 1;
+ return xent;
+}
+
+Xent* Xent_newWithId(long id)
+{
+ Xent *xent = (Xent*)id;
+ __sync_fetch_and_add(&xent->refcount, 1);
+ return xent;
+}
+
+Xent* Xent_newWithParm(size_t frames_, size_t correct_, double loss_, double entropy_)
+{
+ Xent *xent = (Xent*)malloc(sizeof(Xent));
+ xent->frames_ = frames_;
+ xent->correct_ = correct_;
+ xent->loss_ = loss_;
+ xent->entropy_ = entropy_;
+ xent->refcount = 1;
+ return xent;
+}
+
+long Xent_id(Xent *xent)
+{
+ return (long)(xent);
+}
+
+Xent* Xent_add(Xent *a, Xent *b)
+{
+ a->frames_ += b->frames_;
+ a->correct_ += b->correct_;
+ a->loss_ += b->loss_;
+ a->entropy_ += b->entropy_;
+ return a;
+}
+
+void Xent_free(Xent *xent)
+{
+ if (NULL != xent && __sync_fetch_and_add(&xent->refcount, -1) == 1)
+ {
+ free(xent);
+ xent = NULL;
+ }
+}
+
+
+//////////////////////////////////
+
+Mse* Mse_new()
+{
+ Mse *mse = (Mse*)malloc(sizeof(Mse));
+ memset(mse, 0, sizeof(Mse));
+ mse->refcount = 1;
+ return mse;
+}
+
+Mse* Mse_newWithId(long id)
+{
+ Mse *mse = (Mse*)id;
+ __sync_fetch_and_add(&mse->refcount, 1);
+ return mse;
+}
+
+Mse* Mse_newWithParm(size_t frames_, double loss_)
+{
+ Mse *mse = (Mse*)malloc(sizeof(Mse));
+ mse->frames_ = frames_;
+ mse->loss_ = loss_;
+ mse->refcount = 1;
+ return mse;
+}
+
+
+long Mse_id(Mse *mse)
+{
+ return (long)(mse);
+}
+
+Mse* Mse_add(Mse *a, Mse *b)
+{
+ a->frames_ += b->frames_;
+ a->loss_ += b->loss_;
+ return a;
+}
+
+void Mse_free(Mse *mse)
+{
+ if (NULL != mse && __sync_fetch_and_add(&mse->refcount, -1) == 1)
+ {
+ free(mse);
+ mse = NULL;
+ }
+}
+
+//////////////////////////////////
+
+GlobalOption* GlobalOption_new()
+{
+ GlobalOption *option = (GlobalOption*)malloc(sizeof(GlobalOption));
+ option->refcount = 1;
+ return option;
+}
+
+GlobalOption* GlobalOption_newWithParm(int batch_size, float lrate, bool bp,const char *tr_scp, const char *cv_scp, const char *transf, const char *network)
+{
+ GlobalOption *option = (GlobalOption*)malloc(sizeof(GlobalOption));
+ option->batch_size = batch_size;
+ option->lrate = lrate;
+ option->bp = bp;
+ strncpy(option->tr_scp, tr_scp, strlen(tr_scp)+1);
+ strncpy(option->cv_scp, cv_scp, strlen(cv_scp)+1);
+ strncpy(option->transf, transf, strlen(transf)+1);
+ strncpy(option->network, network, strlen(network)+1);
+ option->refcount = 1;
+
+ return option;
+}
+
+GlobalOption* GlobalOption_newWithId(long id)
+{
+ GlobalOption *option = (GlobalOption*)id;
+ __sync_fetch_and_add(&option->refcount, 1);
+ return option;
+}
+
+
+
+long GlobalOption_id(GlobalOption *option)
+{
+ return (long)(option);
+}
+
+void GlobalOption_free(GlobalOption *option)
+{
+ if (NULL != option && __sync_fetch_and_add(&option->refcount, -1) == 1)
+ {
+ free(option);
+ option = NULL;
+ }
+}
+
+
diff --git a/fastnn/lib/ModelSync.h b/fastnn/lib/ModelSync.h
new file mode 100644
index 0000000..71216a0
--- /dev/null
+++ b/fastnn/lib/ModelSync.h
@@ -0,0 +1,119 @@
+
+#ifndef NERV_FASTNN_MODELSYNC_H
+#define NERV_FASTNN_MODELSYNC_H
+
+#define STRLEN 1024
+
+#include "../threads/lib/THThread.h"
+#include "matrix/matrix.h"
+#include "stdlib.h"
+#include "stdbool.h"
+
+typedef struct NnetParallelOptions_
+{
+ int num_threads;
+ int merge_size;
+ int num_merge;
+ int num_procs;
+ int threadid;
+ int myid;
+ int thread_level;
+ char merge_func[STRLEN];
+ char log_file[STRLEN];
+} NnetParallelOptions;
+
+
+typedef struct ModelSync_
+{
+ THMutex *model_mutex;
+ THMutex *state_mutex;
+ bool initialized_;
+ int dim_;
+ int pos_;
+ float *data_;
+ float *free_data_;
+ int refcount;
+ int threadcount;
+}ModelSync;
+
+ModelSync *ModelSync_new(void);
+ModelSync *ModelSync_newWithId(long id);
+int ModelSync_free(ModelSync *self);
+long ModelSync_id(ModelSync *self);
+int ModelSync_lockmodel(ModelSync *self);
+int ModelSync_unlockmodel(ModelSync *self);
+int ModelSync_lockstate(ModelSync *self);
+int ModelSync_unlockstate(ModelSync *self);
+int ModelSync_initBuffer(ModelSync *self);
+int ModelSync_weightfromd(ModelSync *self, Matrix *dm);
+int ModelSync_weighttod(ModelSync *self, Matrix *dm);
+int ModelSync_threadcount(ModelSync *self);
+void ModelSync_syncinc(ModelSync *self);
+void ModelSync_syncdec(ModelSync *self);
+
+typedef struct Xent_
+{
+ size_t frames_;
+ size_t correct_;
+ double loss_;
+ double entropy_;
+ int refcount;
+} Xent;
+
+Xent* Xent_new();
+Xent* Xent_newWithId(long id);
+Xent* Xent_newWithParm(size_t frames_, size_t correct_, double loss_, double entropy_);
+long Xent_id(Xent *xent);
+Xent* Xent_add(Xent *a, Xent *b);
+void Xent_free(Xent *xent);
+
+typedef struct Mse_
+{
+ size_t frames_;
+ double loss_;
+ int refcount;
+} Mse;
+
+Mse* Mse_new();
+Mse* Mse_newWithId(long id);
+Mse* Mse_newWithParm(size_t frames_, double loss_);
+long Mse_id(Mse *mse);
+Mse* Mse_add(Mse *a, Mse *b);
+void Mse_free(Mse *mse);
+
+typedef struct NnetUpdateState_
+{
+ int num_utter;
+ int num_nolabel;
+ int num_other_error;
+ long total_frames;
+ Xent xent;
+ Mse mse;
+} NnetUpdateState;
+
+typedef struct GlobalOption_
+{
+ int batch_size;
+ float lrate;
+ bool bp;
+ char tr_scp[STRLEN];
+ char cv_scp[STRLEN];
+ char transf[STRLEN];
+ char network[STRLEN];
+ int refcount;
+}GlobalOption;
+
+
+GlobalOption* GlobalOption_new();
+GlobalOption* GlobalOption_newWithParm(int batch_size, float lrate, bool bp, const char *tr_scp, const char *cv_scp, const char *transf, const char *network);
+GlobalOption* GlobalOption_newWithId(long id);
+long GlobalOption_id(GlobalOption *option);
+void GlobalOption_free(GlobalOption *option);
+
+
+
+
+#endif
+
+
+
diff --git a/fastnn/lib/modelsync.c b/fastnn/lib/modelsync.c
new file mode 100644
index 0000000..2b52752
--- /dev/null
+++ b/fastnn/lib/modelsync.c
@@ -0,0 +1,532 @@
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <lua.h>
+#include <lualib.h>
+#include <luaT/luaT.h>
+
+
+
+#include "ModelSync.h"
+#include "../threads/lib/luaTHRD.h"
+
+const char *fastnn_model_sync_tname = "fastnn.CModelSync";
+const char *fastnn_xent_tname = "fastnn.CXent";
+const char *fastnn_mse_tname = "fastnn.CMse";
+const char *fastnn_global_option_tname = "fastnn.CGlobalOption";
+
+static int model_sync_new(lua_State *L)
+{
+ ModelSync *model_sync = NULL;
+ if(lua_gettop(L) == 0)
+ {
+ model_sync = ModelSync_new();
+ }
+ else if(lua_gettop(L) == 1)
+ {
+ long id = luaL_checklong(L, 1);
+ model_sync = ModelSync_newWithId(id);
+ }
+ else
+ luaL_error(L, "modelsync: modelsync new invalid arguments");
+ if(!model_sync)
+ luaL_error(L, "modelsync: modelsync new failed");
+
+ luaTHRD_pushudata(L, model_sync, fastnn_model_sync_tname);
+ return 1;
+}
+
+
+static int model_sync_tostring(lua_State *L)
+{
+ char str[STRLEN];
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ snprintf(str, STRLEN, "fastnn.modelsync <%lx>", ModelSync_id(model_sync));
+ lua_pushstring(L, str);
+ return 1;
+}
+
+static int model_sync_id(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ lua_pushinteger(L, ModelSync_id(model_sync));
+ return 1;
+}
+
+static int model_sync_lockmodel(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ if (ModelSync_lockmodel(model_sync))
+ luaL_error(L, "modelsync: model lock failed");
+ return 0;
+}
+
+static int model_sync_unlockmodel(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ if (ModelSync_unlockmodel(model_sync))
+ luaL_error(L, "modelsync: model unlock failed");
+ return 0;
+}
+
+static int model_sync_lockstate(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ if (ModelSync_lockstate(model_sync))
+ luaL_error(L, "modelsync: state lock failed");
+ return 0;
+}
+
+static int model_sync_unlockstate(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ if (ModelSync_unlockstate(model_sync))
+ luaL_error(L, "modelsync: state unlock failed");
+ return 0;
+}
+
+static int model_sync_free(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ ModelSync_free(model_sync);
+ return 0;
+}
+
+static int model_sync_initbuffer(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ model_sync->dim_ = luaL_checkinteger(L, 2);
+ ModelSync_initBuffer(model_sync);
+ return 0;
+}
+
+static int model_sync_weightfromd(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ Matrix *dm = luaT_checkudata(L, 2, "nerv.CuMatrixFloat");
+ ModelSync_weightfromd(model_sync, dm);
+ return 0;
+}
+
+static int model_sync_weighttod(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ Matrix *dm = luaT_checkudata(L, 2, "nerv.CuMatrixFloat");
+ ModelSync_weighttod(model_sync, dm);
+ return 0;
+}
+
+static int model_sync_initialized(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ lua_pushboolean(L, model_sync->initialized_);
+ return 1;
+}
+
+static int model_sync_setpos(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ int pos = luaL_checkinteger(L, 2);
+ model_sync->pos_ = pos;
+ return 0;
+}
+
+static int model_sync_threadcount(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ lua_pushinteger(L, ModelSync_threadcount(model_sync));
+ return 1;
+}
+
+static int model_sync_syncinc(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ ModelSync_syncinc(model_sync);
+ return 0;
+}
+
+static int model_sync_syncdec(lua_State *L)
+{
+ ModelSync *model_sync = luaTHRD_checkudata(L, 1, fastnn_model_sync_tname);
+ ModelSync_syncdec(model_sync);
+ return 0;
+}
+//////////////////////////////////////////
+
+static int xent_new(lua_State *L)
+{
+ Xent *xent = NULL;
+ if(lua_gettop(L) == 0)
+ {
+ xent = Xent_new();
+ }
+ else if(lua_gettop(L) == 1)
+ {
+ long id = luaL_checklong(L, 1);
+ xent = Xent_newWithId(id);
+ }
+ else if(lua_gettop(L) == 4)
+ {
+ size_t frames_, correct_;
+ double loss_, entropy_ ;
+ frames_ = luaL_checkinteger(L, 1);
+ correct_ = luaL_checkinteger(L, 2);
+ loss_ = luaL_checknumber(L, 3);
+ entropy_ = luaL_checknumber(L, 4);
+ xent = Xent_newWithParm(frames_, correct_, loss_, entropy_);
+ }
+ else
+ luaL_error(L, "xent: xent new invalid arguments");
+ if(!xent)
+ luaL_error(L, "xent: xent new failed");
+
+ luaTHRD_pushudata(L, xent, fastnn_xent_tname);
+ return 1;
+}
+
+static int xent_tostring(lua_State *L)
+{
+ char str[STRLEN];
+ Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+ snprintf(str, STRLEN, "fastnn.xent <%lx>", Xent_id(xent));
+ lua_pushstring(L, str);
+ return 1;
+}
+
+static int xent_totalframes(lua_State *L)
+{
+ Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+ lua_pushinteger(L, xent->frames_);
+ return 1;
+}
+
+static int xent_correct(lua_State *L)
+{
+ Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+ lua_pushinteger(L, xent->correct_);
+ return 1;
+}
+
+static int xent_loss(lua_State *L)
+{
+ Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+ lua_pushnumber(L, xent->loss_);
+ return 1;
+}
+
+static int xent_entropy(lua_State *L)
+{
+ Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+ lua_pushnumber(L, xent->entropy_);
+ return 1;
+}
+
+static int xent_id(lua_State *L)
+{
+ Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+ lua_pushinteger(L, Xent_id(xent));
+ return 1;
+}
+
+static int xent_free(lua_State *L)
+{
+ Xent *xent = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+ Xent_free(xent);
+ return 0;
+}
+
+static int xent_add(lua_State *L)
+{
+ Xent *a = luaTHRD_checkudata(L, 1, fastnn_xent_tname);
+ Xent *b = luaTHRD_checkudata(L, 2, fastnn_xent_tname);
+ Xent_add(a, b);
+ return 0;
+}
+
+/////////////////////////////////////////////
+
+static int mse_new(lua_State *L)
+{
+ Mse *mse = NULL;
+ if(lua_gettop(L) == 0)
+ {
+ mse = Mse_new();
+ }
+ else if(lua_gettop(L) == 1)
+ {
+ long id = luaL_checklong(L, 1);
+ mse = Mse_newWithId(id);
+ }
+ else if(lua_gettop(L) == 2)
+ {
+ size_t frames_;
+ double loss_;
+ frames_ = luaL_checkinteger(L, 1);
+ loss_ = luaL_checknumber(L, 2);
+ mse = Mse_newWithParm(frames_, loss_);
+ }
+ else
+ luaL_error(L, "mse: mse new invalid arguments");
+ if(!mse)
+ luaL_error(L, "mse: mse new failed");
+
+ luaTHRD_pushudata(L, mse, fastnn_mse_tname);
+ return 1;
+
+}
+
+static int mse_tostring(lua_State *L)
+{
+ char str[STRLEN];
+ Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+ snprintf(str, STRLEN, "fastnn.mse <%lx>", Mse_id(mse));
+ lua_pushstring(L, str);
+ return 1;
+}
+
+static int mse_totalframes(lua_State *L)
+{
+ Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+ lua_pushinteger(L, mse->frames_);
+ return 1;
+}
+
+static int mse_loss(lua_State *L)
+{
+ Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+ lua_pushnumber(L, mse->loss_);
+ return 1;
+}
+
+static int mse_id(lua_State *L)
+{
+ Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+ lua_pushinteger(L, Mse_id(mse));
+ return 1;
+}
+
+static int mse_free(lua_State *L)
+{
+ Mse *mse = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+ Mse_free(mse);
+ return 0;
+}
+
+static int mse_add(lua_State *L)
+{
+ Mse *a = luaTHRD_checkudata(L, 1, fastnn_mse_tname);
+ Mse *b = luaTHRD_checkudata(L, 2, fastnn_mse_tname);
+ Mse_add(a, b);
+ return 0;
+}
+/////////////////////////////////////////////
+
+static int global_option_new(lua_State *L)
+{
+ GlobalOption *global_option = NULL;
+ if(lua_gettop(L) == 0)
+ {
+ global_option = GlobalOption_new();
+ }
+ else if(lua_gettop(L) == 1)
+ {
+ long id = luaL_checklong(L, 1);
+ global_option = GlobalOption_newWithId(id);
+ }
+ else if(lua_gettop(L) > 3)
+ {
+ int batch_size = luaL_checkinteger(L, 1);
+ float lrate = luaL_checknumber(L, 2);
+ bool bp = lua_toboolean(L, 3);
+ const char *tr_scp = lua_tostring(L, 4);
+ const char *cv_scp = lua_tostring(L, 5);
+ const char *transf = lua_tostring(L, 6);
+ const char *network = lua_tostring(L, 7);
+ global_option = GlobalOption_newWithParm(batch_size, lrate, bp, tr_scp, cv_scp, transf, network);
+ }
+ else
+ luaL_error(L, "global_option: global_option new invalid arguments");
+ if(!global_option)
+ luaL_error(L, "global_option: global_option new failed");
+
+ luaTHRD_pushudata(L, global_option, fastnn_global_option_tname);
+ return 1;
+
+}
+
+static int global_option_tostring(lua_State *L)
+{
+ char str[STRLEN];
+ GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+ snprintf(str, STRLEN, "fastnn.global_option <%lx>", GlobalOption_id(global_option));
+ lua_pushstring(L, str);
+ return 1;
+}
+
+static int global_option_id(lua_State *L)
+{
+ GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+ lua_pushinteger(L, GlobalOption_id(global_option));
+ return 1;
+}
+
+static int global_option_free(lua_State *L)
+{
+ GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+ GlobalOption_free(global_option);
+ return 0;
+}
+
+static int global_option_batch_size(lua_State *L)
+{
+ GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+ lua_pushinteger(L, global_option->batch_size);
+ return 1;
+}
+
+static int global_option_lrate(lua_State *L)
+{
+ GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+ lua_pushnumber(L, global_option->lrate);
+ return 1;
+}
+
+static int global_option_bp(lua_State *L)
+{
+ GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+ lua_pushboolean(L, global_option->bp);
+ return 1;
+}
+
+static int global_option_tr_scp(lua_State *L)
+{
+ GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+ lua_pushstring(L, global_option->tr_scp);
+ return 1;
+}
+
+static int global_option_cv_scp(lua_State *L)
+{
+ GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+ lua_pushstring(L, global_option->cv_scp);
+ return 1;
+}
+
+static int global_option_transf(lua_State *L)
+{
+ GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+ lua_pushstring(L, global_option->transf);
+ return 1;
+}
+
+static int global_option_network(lua_State *L)
+{
+ GlobalOption *global_option = luaTHRD_checkudata(L, 1, fastnn_global_option_tname);
+ lua_pushstring(L, global_option->network);
+ return 1;
+}
+
+
+//////////////////////////////////////////////
+
+static const struct luaL_Reg model_sync__ [] = {
+ {"new", model_sync_new},
+ {"__tostring", model_sync_tostring},
+ {"id", model_sync_id},
+ {"lockmodel", model_sync_lockmodel},
+ {"unlockmodel", model_sync_unlockmodel},
+ {"lockstate", model_sync_lockstate},
+ {"unlockstate", model_sync_unlockstate},
+ {"initbuffer", model_sync_initbuffer},
+ {"setpos", model_sync_setpos},
+ {"initialized", model_sync_initialized},
+ {"weightfromd", model_sync_weightfromd},
+ {"weighttod", model_sync_weighttod},
+ {"threadcount", model_sync_threadcount},
+ {"syncinc", model_sync_syncinc},
+ {"syncdec", model_sync_syncdec},
+ {"threadcount", model_sync_threadcount},
+ {"free", model_sync_free},
+ {NULL, NULL}
+};
+
+
+static const struct luaL_Reg xent__ [] = {
+ {"new", xent_new},
+ {"__tostring", xent_tostring},
+ {"id", xent_id},
+ {"totalframes", xent_totalframes},
+ {"correct", xent_correct},
+ {"loss", xent_loss},
+ {"entropy", xent_entropy},
+ {"add", xent_add},
+ {"free", xent_free},
+ {NULL, NULL}
+};
+
+static const struct luaL_Reg mse__ [] = {
+ {"new", mse_new},
+ {"__tostring", mse_tostring},
+ {"id", mse_id},
+ {"totalframes", mse_totalframes},
+ {"loss", mse_loss},
+ {"add", mse_add},
+ {"free", mse_free},
+ {NULL, NULL}
+};
+
+
+static const struct luaL_Reg global_option__ [] = {
+ {"new", global_option_new},
+ {"__tostring", global_option_tostring},
+ {"id", global_option_id},
+ {"batch_size", global_option_batch_size},
+ {"lrate", global_option_lrate},
+ {"bp", global_option_bp},
+ {"tr_scp", global_option_tr_scp},
+ {"cv_scp", global_option_cv_scp},
+ {"transf", global_option_transf},
+ {"network", global_option_network},
+ {"free", global_option_free},
+ {NULL, NULL}
+};
+
+void fastnn_init_modelsync(lua_State *L)
+{
+ luaT_newmetatable(L, fastnn_model_sync_tname, NULL, model_sync_new, NULL, NULL);
+ luaL_register(L, NULL, model_sync__);
+ lua_pop(L, 1);
+
+ luaT_newmetatable(L, fastnn_xent_tname, NULL, xent_new, xent_free, NULL);
+ luaL_register(L, NULL, xent__);
+ lua_pop(L, 1);
+
+ luaT_newmetatable(L, fastnn_mse_tname, NULL, mse_new, mse_free, NULL);
+ luaL_register(L, NULL, mse__);
+ lua_pop(L, 1);
+
+ luaT_newmetatable(L, fastnn_global_option_tname, NULL, global_option_new, global_option_free, NULL);
+ luaL_register(L, NULL, global_option__);
+ lua_pop(L, 1);
+ /*
+ printf("%s %lx\n", model_sync__[13].name, model_sync__[13].func);
+
+ if(!luaL_newmetatable(L, fastnn_model_sync_tname))
+ luaL_error(L, "fastnn: fastnn.modelsync type already exists");
+ luaL_setfuncs(L, model_sync__, 0);
+ lua_pushstring(L, "__index");
+ lua_pushvalue(L, -2);
+ lua_rawset(L, -3);
+ lua_pop(L, 1);
+
+ printf("%s %lx\n", model_sync__[13].name, model_sync__[13].func);
+
+ lua_pushstring(L, "modelsync");
+ luaTHRD_pushctortable(L, model_sync_new, fastnn_model_sync_tname);
+ lua_rawset(L, -3);
+ printf("%s %lx\n", model_sync__[13].name, model_sync__[13].func);
+ */
+}
+
+
diff --git a/fastnn/lib/modelsync.lua b/fastnn/lib/modelsync.lua
new file mode 100644
index 0000000..a247562
--- /dev/null
+++ b/fastnn/lib/modelsync.lua
@@ -0,0 +1,107 @@
+
+local C = require 'libfastnn'
+local T = require 'libthreads'
+
+local ModelSync = nerv.class("fastnn.ModelSync")
+
+fastnn.CModelSync = C.CModelSync
+fastnn.Thread = T.Thread
+
+
+function ModelSync:__init(shareid)
+ self.modelsync = fastnn.CModelSync(shareid)
+-- print(self.modelsync.initbuffer)
+ --print(self.modelsync.setpos)
+ --print(self.modelsync.initialized)
+ --print(self.modelsync.weightfromd)
+-- print(self.modelsync.weighttod)
+-- print(self.modelsync.aaaa)
+-- print(self.modelsync.bbbb)
+-- print(self.modelsync.cccc)
+end
+
+function ModelSync:GetDim(nnet)
+
+ local repo = nnet:get_params()
+ local params = repo.params
+ local dim = 0
+ for pid, ref in pairs(params) do
+ if nerv.is_type(ref.trans, "nerv.Matrix") then
+ dim = dim + ref.trans:nrow() * ref.trans:nstride()
+ end
+ end
+
+ return dim
+end
+
+
+function ModelSync:Initialize(nnet)
+
+ self:LockModel()
+
+ if not self.modelsync:initialized() then
+ dim = self:GetDim(nnet)
+ self.modelsync:initbuffer(dim)
+ self:WeightFromD(nnet)
+ end
+
+ self:UnLockModel()
+end
+
+function ModelSync:WeightFromD(nnet)
+ local repo = nnet:get_params()
+ local params = repo.params
+ self.modelsync:setpos(0)
+ for pid, ref in pairs(params) do
+ if nerv.is_type(ref.trans, "nerv.Matrix") then
+ self.modelsync:weightfromd(ref.trans)
+ end
+ end
+end
+
+function ModelSync:WeightToD(nnet)
+ local repo = nnet:get_params()
+ local params = repo.params
+ self.modelsync:setpos(0)
+ for pid, ref in pairs(params) do
+ if nerv.is_type(ref.trans, "nerv.Matrix") then
+ self.modelsync:weighttod(ref.trans)
+ end
+ end
+end
+
+function ModelSync:LockState()
+ self.modelsync:lockstate()
+end
+
+function ModelSync:UnLockState()
+ self.modelsync:unlockstate()
+end
+
+
+function ModelSync:LockModel()
+ self.modelsync:lockmodel()
+end
+
+
+function ModelSync:UnLockModel()
+ self.modelsync:unlockmodel()
+end
+
+function ModelSync:Id()
+ return self.modelsync:id()
+end
+
+function ModelSync:ThreadCount()
+ return self.modelsync:threadcount()
+end
+
+function ModelSync:SyncInc()
+ return self.modelsync:syncinc()
+end
+
+function ModelSync:SyncDec()
+ return self.modelsync:syncdec()
+end
+
+