summaryrefslogtreecommitdiff
path: root/tnet_io/KaldiLib/MathAux.h
blob: c08e836faf8b4d97d3b20ac64ee9114468f5d618 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#ifndef TNet_MathAux_h
#define TNet_MathAux_h

#include <cfloat>
#include <cmath>
#include <cstdlib>
#include <stdexcept>


// Squares its argument.  NOTE: x is evaluated twice, so avoid side effects.
#if !defined(SQR)
# define SQR(x) ((x) * (x))
#endif


// Large negative stand-in for log(0) in log-domain arithmetic.
#if !defined(LOG_0)
# define LOG_0 (-1.0e10)
#endif

// Threshold below which LogAdd treats its result as log(0) and returns LOG_0.
#if !defined(LOG_MIN)
# define LOG_MIN   (0.5 * LOG_0)
#endif


// Fallback in case DBL_EPSILON is not already defined (normally from <cfloat>).
#ifndef DBL_EPSILON
#define DBL_EPSILON 2.2204460492503131e-16
#endif


// M_PI is not guaranteed by standard <cmath>; provide it if missing.
#ifndef M_PI
#  define M_PI 3.1415926535897932384626433832795
#endif

// log(2*pi).  Guarded like the other constants so an existing definition
// is not clobbered.
#ifndef M_LOG_2PI
#  define M_LOG_2PI 1.8378770664093454835606594728112
#endif


// Floating-point type and matching helpers, chosen at compile time:
// DOUBLEPRECISION non-zero -> double; otherwise (including undefined) -> float.
// swap4/swap8 are defined elsewhere (presumably byte-order swappers -- verify).
#if DOUBLEPRECISION
#  define FLOAT double
#  define EPSILON DBL_EPSILON
#  define FLOAT_FMT "%lg"
#  define swapFLOAT swap8
#  define _ABS  fabs
#  define _COS  cos
#  define _EXP  exp
#  define _LOG  log
#  define _SQRT sqrt
#else
#  define FLOAT float
#  define EPSILON FLT_EPSILON
#  define FLOAT_FMT "%g"
#  define swapFLOAT swap4
#  define _ABS  fabsf
#  define _COS  cosf
#  define _EXP  expf
#  define _LOG  logf
#  define _SQRT sqrtf
#endif

namespace TNet
{
  /// Pseudo-random number derived from rand(): (rand() + 1) / (RAND_MAX + 2).
  /// In exact arithmetic this lies strictly inside (0, 1), so log(frand()) is
  /// finite.  Because the arithmetic is done in float, the result can round to
  /// exactly 1.0f when RAND_MAX is large (e.g. 2^31 - 1); it never reaches 0.
  /// Seed with srand().
  inline float frand(){
    return (float(rand()) + 1.0f) / (float(RAND_MAX) + 2.0f);
  }

  /// Standard-normal sample via the Box-Muller transform:
  /// sqrt(-2 ln U1) * cos(2 pi U2), with U1, U2 drawn from frand().
  inline float gauss_rand(){
    return _SQRT( -2.0f * _LOG(frand()) ) * _COS(2.0f*float(M_PI)*frand());
  }

  /// log(DBL_EPSILON) (about -36.04).  LogAdd drops the smaller term when it
  /// is more than this many nats below the larger one: exp(diff) is then
  /// below double precision relative to 1.
  static const double gMinLogDiff = log(DBL_EPSILON);

  //***************************************************************************
  //***************************************************************************
  /// Log-domain addition: returns log(exp(x) + exp(y)) without forming
  /// exp(x) or exp(y) directly.
  /// @return LOG_0 when max(x, y) < LOG_MIN; max(x, y) when the other term is
  ///         negligible (diff < gMinLogDiff); otherwise the exact log-sum.
  inline double
  LogAdd(double x, double y)
  {
    double diff;

    // Arrange for x to hold the larger value and diff = smaller - larger <= 0.
    if (x < y) {
      diff = x - y;
      x = y;
    } else {
      diff = y - x;
    }

    double res;
    if (x >= LOG_MIN) {
      if (diff >= gMinLogDiff) {
        // log1p keeps full precision when exp(diff) is small, where
        // log(1.0 + exp(diff)) would round 1 + exp(diff) and lose digits.
        res = x + std::log1p(std::exp(diff));
      } else {
        res = x;
      }
    } else {
      res = LOG_0;
    }
    return res;
  }


  //***************************************************************************
  //***************************************************************************
  /// Log-domain subtraction: returns log(exp(x) - exp(y)).
  /// @return LOG_0 when y == x (the difference is zero).
  /// @throws std::runtime_error when y > x (the difference would be negative).
  inline double
  LogSub(double x, double y)
  {
    if (y >= x) {
      if (y == x) return LOG_0;
      else throw std::runtime_error("LogSub: cannot subtract a larger from a smaller number.");
    }

    double diff = y - x;  // Strictly negative here.

    // 1 - exp(diff) == -expm1(diff).  expm1 stays accurate as diff -> 0, whereas
    // 1.0 - exp(diff) cancels to exactly 0 and log() would then return -inf.
    double res = x + std::log(-std::expm1(diff));

    // Defensive: map NaN (e.g. from NaN inputs) or -inf to LOG_0.
    if (std::isnan(res) || (std::isinf(res) && res < 0.0))
      res = LOG_0;
    return res;
  }

} // namespace TNet


#endif