/* fastnn/lib/ModelSync.h (blob 71216a0351ec9320bd639b20f000108e38a5cd7a) */






















































































































                                                                                                                                                               
#ifndef NERV_FASTNN_MODELSYNC_H
#define NERV_FASTNN_MODELSYNC_H

#define STRLEN  1024

#include <stdbool.h>
#include <stdlib.h>

#include "../threads/lib/THThread.h"
#include "matrix/matrix.h"

/*
 * Configuration for parallel NNet training.
 *
 * NOTE(review): field meanings below are inferred from their names only;
 * confirm against the code that fills this struct in.
 */
typedef struct NnetParallelOptions_ NnetParallelOptions;

struct NnetParallelOptions_
{
	int  num_threads;          /* presumably number of worker threads */
	int  merge_size;           /* presumably frames/updates between merges */
	int  num_merge;
	int  num_procs;            /* presumably number of processes */
	int  threadid;
	int  myid;
	int  thread_level;
	char merge_func[STRLEN];   /* fixed STRLEN-char buffer; name of merge strategy (assumed) */
	char log_file[STRLEN];     /* fixed STRLEN-char buffer; log file path (assumed) */
};


/*
 * Shared model buffer guarded by two mutexes (model_mutex, state_mutex).
 *
 * NOTE(review): the roles of the fields are inferred from names and from
 * the accessor prototypes declared below; confirm against ModelSync.c.
 */
typedef struct ModelSync_ ModelSync;

struct ModelSync_
{
	THMutex *model_mutex;   /* presumably taken by ModelSync_lockmodel */
	THMutex *state_mutex;   /* presumably taken by ModelSync_lockstate */
	bool     initialized_;
	int      dim_;          /* presumably element count of data_ */
	int      pos_;
	float   *data_;
	float   *free_data_;    /* presumably the raw allocation data_ points into */
	int      refcount;
	int      threadcount;
};

/* Construction, lookup by id, and release. */
ModelSync *ModelSync_new(void);
ModelSync *ModelSync_newWithId(long id);
int        ModelSync_free(ModelSync *self);
long       ModelSync_id(ModelSync *self);

/* Lock/unlock the model and state mutexes. */
int ModelSync_lockmodel(ModelSync *self);
int ModelSync_unlockmodel(ModelSync *self);
int ModelSync_lockstate(ModelSync *self);
int ModelSync_unlockstate(ModelSync *self);

/* Buffer setup and transfer to/from a Matrix (direction inferred from names). */
int ModelSync_initBuffer(ModelSync *self);
int ModelSync_weightfromd(ModelSync *self, Matrix *dm);
int ModelSync_weighttod(ModelSync *self, Matrix *dm);

/* Participating-thread counter. */
int  ModelSync_threadcount(ModelSync *self);
void ModelSync_syncinc(ModelSync *self);
void ModelSync_syncdec(ModelSync *self);

/*
 * Cross-entropy statistics accumulator.
 *
 * NOTE(review): field meanings are inferred from names; confirm against the
 * implementation of Xent_add.
 */
typedef struct Xent_
{
	size_t frames_;    /* presumably number of frames seen */
	size_t correct_;   /* presumably number of correctly classified frames */
	double loss_;
	double entropy_;
	int    refcount;
} Xent;

/*
 * Xent_new takes no arguments. It is declared with (void) so that, before
 * C23, this is a real prototype (an empty "()" turns off argument checking).
 * This matches ModelSync_new(void).
 */
Xent *Xent_new(void);
Xent *Xent_newWithId(long id);
Xent *Xent_newWithParm(size_t frames_, size_t correct_, double loss_, double entropy_);
long  Xent_id(Xent *xent);
Xent *Xent_add(Xent *a, Xent *b);   /* NOTE(review): whether a or a new object is returned is not visible here */
void  Xent_free(Xent *xent);

/*
 * Mean-squared-error statistics accumulator.
 *
 * NOTE(review): field meanings are inferred from names; confirm against the
 * implementation of Mse_add.
 */
typedef struct Mse_
{
	size_t frames_;    /* presumably number of frames seen */
	double loss_;
	int    refcount;
} Mse;

/*
 * Mse_new takes no arguments. It is declared with (void) so that, before
 * C23, this is a real prototype (an empty "()" turns off argument checking).
 * This matches ModelSync_new(void).
 */
Mse *Mse_new(void);
Mse *Mse_newWithId(long id);
Mse *Mse_newWithParm(size_t frames_, double loss_);
long Mse_id(Mse *mse);
Mse *Mse_add(Mse *a, Mse *b);   /* NOTE(review): whether a or a new object is returned is not visible here */
void Mse_free(Mse *mse);

/*
 * Per-pass update counters plus embedded Xent/Mse statistics.
 *
 * NOTE(review): counter meanings are inferred from names; confirm against
 * the training loop that updates them.
 */
typedef struct NnetUpdateState_ NnetUpdateState;

struct NnetUpdateState_
{
	int  num_utter;         /* presumably utterances processed */
	int  num_nolabel;       /* presumably utterances skipped for missing labels */
	int  num_other_error;
	long total_frames;
	Xent xent;              /* held by value, including its own refcount field */
	Mse  mse;               /* held by value, including its own refcount field */
};

/*
 * Global training options. The string fields are fixed STRLEN-char buffers.
 *
 * NOTE(review): field meanings are inferred from names; confirm against the
 * code that populates this struct.
 */
typedef struct GlobalOption_
{
	int   batch_size;
	float lrate;            /* presumably learning rate */
	bool  bp;               /* presumably "do back-propagation" flag */
	char  tr_scp[STRLEN];   /* presumably training list path */
	char  cv_scp[STRLEN];   /* presumably cross-validation list path */
	char  transf[STRLEN];
	char  network[STRLEN];
	int   refcount;
} GlobalOption;

/*
 * GlobalOption_new takes no arguments. It is declared with (void) so that,
 * before C23, this is a real prototype (an empty "()" turns off argument
 * checking). This matches ModelSync_new(void).
 */
GlobalOption *GlobalOption_new(void);

/*
 * The string arguments are const char *. Presumably their contents are copied
 * into the STRLEN-sized buffers and the caller keeps ownership.
 * NOTE(review): confirm this, and check how inputs longer than STRLEN-1
 * characters are handled.
 */
GlobalOption *GlobalOption_newWithParm(int batch_size, float lrate, bool bp, const char *tr_scp, const char *cv_scp, const char *transf, const char *network);
GlobalOption *GlobalOption_newWithId(long id);
long          GlobalOption_id(GlobalOption *option);
void          GlobalOption_free(GlobalOption *option);




#endif