#ifndef NERV_FASTNN_MODELSYNC_H #define NERV_FASTNN_MODELSYNC_H #define STRLEN 1024 #include "../threads/lib/THThread.h" #include "matrix/matrix.h" #include "stdlib.h" #include "stdbool.h" typedef struct NnetParallelOptions_ { int num_threads; int merge_size; int num_merge; int num_procs; int threadid; int myid; int thread_level; char merge_func[STRLEN]; char log_file[STRLEN]; } NnetParallelOptions; typedef struct ModelSync_ { THMutex *model_mutex; THMutex *state_mutex; bool initialized_; int dim_; int pos_; float *data_; float *free_data_; int refcount; int threadcount; }ModelSync; ModelSync *ModelSync_new(void); ModelSync *ModelSync_newWithId(long id); int ModelSync_free(ModelSync *self); long ModelSync_id(ModelSync *self); int ModelSync_lockmodel(ModelSync *self); int ModelSync_unlockmodel(ModelSync *self); int ModelSync_lockstate(ModelSync *self); int ModelSync_unlockstate(ModelSync *self); int ModelSync_initBuffer(ModelSync *self); int ModelSync_weightfromd(ModelSync *self, Matrix *dm); int ModelSync_weighttod(ModelSync *self, Matrix *dm); int ModelSync_threadcount(ModelSync *self); void ModelSync_syncinc(ModelSync *self); void ModelSync_syncdec(ModelSync *self); typedef struct Xent_ { size_t frames_; size_t correct_; double loss_; double entropy_; int refcount; } Xent; Xent* Xent_new(); Xent* Xent_newWithId(long id); Xent* Xent_newWithParm(size_t frames_, size_t correct_, double loss_, double entropy_); long Xent_id(Xent *xent); Xent* Xent_add(Xent *a, Xent *b); void Xent_free(Xent *xent); typedef struct Mse_ { size_t frames_; double loss_; int refcount; } Mse; Mse* Mse_new(); Mse* Mse_newWithId(long id); Mse* Mse_newWithParm(size_t frames_, double loss_); long Mse_id(Mse *mse); Mse* Mse_add(Mse *a, Mse *b); void Mse_free(Mse *mse); typedef struct NnetUpdateState_ { int num_utter; int num_nolabel; int num_other_error; long total_frames; Xent xent; Mse mse; } NnetUpdateState; typedef struct GlobalOption_ { int batch_size; float lrate; bool bp; char tr_scp[STRLEN]; char cv_scp[STRLEN]; char transf[STRLEN]; char network[STRLEN]; int refcount; }GlobalOption; GlobalOption* GlobalOption_new(); GlobalOption* GlobalOption_newWithParm(int batch_size, float lrate, bool bp, const char *tr_scp, const char *cv_scp, const char *transf, const char *network); GlobalOption* GlobalOption_newWithId(long id); long GlobalOption_id(GlobalOption *option); void GlobalOption_free(GlobalOption *option); #endif