diff options
Diffstat (limited to 'fastnn/lib/ModelSync.h')
-rw-r--r-- | fastnn/lib/ModelSync.h | 119 |
1 files changed, 119 insertions, 0 deletions
diff --git a/fastnn/lib/ModelSync.h b/fastnn/lib/ModelSync.h new file mode 100644 index 0000000..71216a0 --- /dev/null +++ b/fastnn/lib/ModelSync.h @@ -0,0 +1,119 @@ + +#ifndef NERV_FASTNN_MODELSYNC_H +#define NERV_FASTNN_MODELSYNC_H + +#define STRLEN 1024 + +#include "../threads/lib/THThread.h" +#include "matrix/matrix.h" +#include "stdlib.h" +#include "stdbool.h" + +typedef struct NnetParallelOptions_ +{ + int num_threads; + int merge_size; + int num_merge; + int num_procs; + int threadid; + int myid; + int thread_level; + char merge_func[STRLEN]; + char log_file[STRLEN]; +} NnetParallelOptions; + + +typedef struct ModelSync_ +{ + THMutex *model_mutex; + THMutex *state_mutex; + bool initialized_; + int dim_; + int pos_; + float *data_; + float *free_data_; + int refcount; + int threadcount; +}ModelSync; + +ModelSync *ModelSync_new(void); +ModelSync *ModelSync_newWithId(long id); +int ModelSync_free(ModelSync *self); +long ModelSync_id(ModelSync *self); +int ModelSync_lockmodel(ModelSync *self); +int ModelSync_unlockmodel(ModelSync *self); +int ModelSync_lockstate(ModelSync *self); +int ModelSync_unlockstate(ModelSync *self); +int ModelSync_initBuffer(ModelSync *self); +int ModelSync_weightfromd(ModelSync *self, Matrix *dm); +int ModelSync_weighttod(ModelSync *self, Matrix *dm); +int ModelSync_threadcount(ModelSync *self); +void ModelSync_syncinc(ModelSync *self); +void ModelSync_syncdec(ModelSync *self); + +typedef struct Xent_ +{ + size_t frames_; + size_t correct_; + double loss_; + double entropy_; + int refcount; +} Xent; + +Xent* Xent_new(); +Xent* Xent_newWithId(long id); +Xent* Xent_newWithParm(size_t frames_, size_t correct_, double loss_, double entropy_); +long Xent_id(Xent *xent); +Xent* Xent_add(Xent *a, Xent *b); +void Xent_free(Xent *xent); + +typedef struct Mse_ +{ + size_t frames_; + double loss_; + int refcount; +} Mse; + +Mse* Mse_new(); +Mse* Mse_newWithId(long id); +Mse* Mse_newWithParm(size_t frames_, double loss_); +long Mse_id(Mse *mse); +Mse* Mse_add(Mse *a, Mse *b); +void Mse_free(Mse *mse); + +typedef struct NnetUpdateState_ +{ + int num_utter; + int num_nolabel; + int num_other_error; + long total_frames; + Xent xent; + Mse mse; +} NnetUpdateState; + +typedef struct GlobalOption_ +{ + int batch_size; + float lrate; + bool bp; + char tr_scp[STRLEN]; + char cv_scp[STRLEN]; + char transf[STRLEN]; + char network[STRLEN]; + int refcount; +}GlobalOption; + + +GlobalOption* GlobalOption_new(); +GlobalOption* GlobalOption_newWithParm(int batch_size, float lrate, bool bp, const char *tr_scp, const char *cv_scp, const char *transf, const char *network); +GlobalOption* GlobalOption_newWithId(long id); +long GlobalOption_id(GlobalOption *option); +void GlobalOption_free(GlobalOption *option); + + + + +#endif + + + |