aboutsummaryrefslogtreecommitdiff
path: root/fastnn/lib/ModelSync.h
diff options
context:
space:
mode:
Diffstat (limited to 'fastnn/lib/ModelSync.h')
-rw-r--r--fastnn/lib/ModelSync.h119
1 files changed, 119 insertions, 0 deletions
diff --git a/fastnn/lib/ModelSync.h b/fastnn/lib/ModelSync.h
new file mode 100644
index 0000000..71216a0
--- /dev/null
+++ b/fastnn/lib/ModelSync.h
@@ -0,0 +1,119 @@
+
+#ifndef NERV_FASTNN_MODELSYNC_H
+#define NERV_FASTNN_MODELSYNC_H
+
+#define STRLEN 1024
+
+#include "../threads/lib/THThread.h"
+#include "matrix/matrix.h"
+#include "stdlib.h"
+#include "stdbool.h"
+
+typedef struct NnetParallelOptions_
+{
+ int num_threads;
+ int merge_size;
+ int num_merge;
+ int num_procs;
+ int threadid;
+ int myid;
+ int thread_level;
+ char merge_func[STRLEN];
+ char log_file[STRLEN];
+} NnetParallelOptions;
+
+
+typedef struct ModelSync_
+{
+ THMutex *model_mutex;
+ THMutex *state_mutex;
+ bool initialized_;
+ int dim_;
+ int pos_;
+ float *data_;
+ float *free_data_;
+ int refcount;
+ int threadcount;
+}ModelSync;
+
+ModelSync *ModelSync_new(void);
+ModelSync *ModelSync_newWithId(long id);
+int ModelSync_free(ModelSync *self);
+long ModelSync_id(ModelSync *self);
+int ModelSync_lockmodel(ModelSync *self);
+int ModelSync_unlockmodel(ModelSync *self);
+int ModelSync_lockstate(ModelSync *self);
+int ModelSync_unlockstate(ModelSync *self);
+int ModelSync_initBuffer(ModelSync *self);
+int ModelSync_weightfromd(ModelSync *self, Matrix *dm);
+int ModelSync_weighttod(ModelSync *self, Matrix *dm);
+int ModelSync_threadcount(ModelSync *self);
+void ModelSync_syncinc(ModelSync *self);
+void ModelSync_syncdec(ModelSync *self);
+
+typedef struct Xent_
+{
+ size_t frames_;
+ size_t correct_;
+ double loss_;
+ double entropy_;
+ int refcount;
+} Xent;
+
+Xent* Xent_new();
+Xent* Xent_newWithId(long id);
+Xent* Xent_newWithParm(size_t frames_, size_t correct_, double loss_, double entropy_);
+long Xent_id(Xent *xent);
+Xent* Xent_add(Xent *a, Xent *b);
+void Xent_free(Xent *xent);
+
+typedef struct Mse_
+{
+ size_t frames_;
+ double loss_;
+ int refcount;
+} Mse;
+
+Mse* Mse_new();
+Mse* Mse_newWithId(long id);
+Mse* Mse_newWithParm(size_t frames_, double loss_);
+long Mse_id(Mse *mse);
+Mse* Mse_add(Mse *a, Mse *b);
+void Mse_free(Mse *mse);
+
+typedef struct NnetUpdateState_
+{
+ int num_utter;
+ int num_nolabel;
+ int num_other_error;
+ long total_frames;
+ Xent xent;
+ Mse mse;
+} NnetUpdateState;
+
+typedef struct GlobalOption_
+{
+ int batch_size;
+ float lrate;
+ bool bp;
+ char tr_scp[STRLEN];
+ char cv_scp[STRLEN];
+ char transf[STRLEN];
+ char network[STRLEN];
+ int refcount;
+}GlobalOption;
+
+
+GlobalOption* GlobalOption_new();
+GlobalOption* GlobalOption_newWithParm(int batch_size, float lrate, bool bp, const char *tr_scp, const char *cv_scp, const char *transf, const char *network);
+GlobalOption* GlobalOption_newWithId(long id);
+long GlobalOption_id(GlobalOption *option);
+void GlobalOption_free(GlobalOption *option);
+
+
+
+
+#endif
+
+
+