diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/float-weight.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/float-weight.h | 601 |
1 files changed, 601 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/float-weight.h b/kaldi_io/src/tools/openfst/include/fst/float-weight.h new file mode 100644 index 0000000..eb22638 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/float-weight.h @@ -0,0 +1,601 @@ +// float-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Float weight set and associated semiring operation definitions. +// + +#ifndef FST_LIB_FLOAT_WEIGHT_H__ +#define FST_LIB_FLOAT_WEIGHT_H__ + +#include <limits> +#include <climits> +#include <sstream> +#include <string> + +#include <fst/util.h> +#include <fst/weight.h> + + +namespace fst { + +// numeric limits class +template <class T> +class FloatLimits { + public: + static const T PosInfinity() { + static const T pos_infinity = numeric_limits<T>::infinity(); + return pos_infinity; + } + + static const T NegInfinity() { + static const T neg_infinity = -PosInfinity(); + return neg_infinity; + } + + static const T NumberBad() { + static const T number_bad = numeric_limits<T>::quiet_NaN(); + return number_bad; + } + +}; + +// weight class to be templated on floating-points types +template <class T = float> +class FloatWeightTpl { + public: + FloatWeightTpl() {} + + FloatWeightTpl(T f) : value_(f) {} + + FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {} + + FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) { + value_ = w.value_; + return *this; + } + + istream &Read(istream &strm) { + return ReadType(strm, &value_); + } + + ostream &Write(ostream &strm) const { + return WriteType(strm, value_); + } + + size_t Hash() const { + union { + T f; + size_t s; + } u; + u.s = 0; + u.f = value_; + return u.s; + } + + const T &Value() const { return value_; } + + protected: + void SetValue(const T &f) { value_ = f; } + + inline static string GetPrecisionString() { + int64 size = sizeof(T); + if (size == sizeof(float)) return ""; + size *= CHAR_BIT; + + string result; + Int64ToStr(size, &result); + return result; + } + + private: + T value_; +}; + +// Single-precision float weight +typedef FloatWeightTpl<float> FloatWeight; + +template <class T> +inline bool operator==(const FloatWeightTpl<T> &w1, + const FloatWeightTpl<T> &w2) { + // Volatile qualifier thwarts over-aggressive compiler optimizations + // that lead to problems esp. with NaturalLess(). + volatile T v1 = w1.Value(); + volatile T v2 = w2.Value(); + return v1 == v2; +} + +inline bool operator==(const FloatWeightTpl<double> &w1, + const FloatWeightTpl<double> &w2) { + return operator==<double>(w1, w2); +} + +inline bool operator==(const FloatWeightTpl<float> &w1, + const FloatWeightTpl<float> &w2) { + return operator==<float>(w1, w2); +} + +template <class T> +inline bool operator!=(const FloatWeightTpl<T> &w1, + const FloatWeightTpl<T> &w2) { + return !(w1 == w2); +} + +inline bool operator!=(const FloatWeightTpl<double> &w1, + const FloatWeightTpl<double> &w2) { + return operator!=<double>(w1, w2); +} + +inline bool operator!=(const FloatWeightTpl<float> &w1, + const FloatWeightTpl<float> &w2) { + return operator!=<float>(w1, w2); +} + +template <class T> +inline bool ApproxEqual(const FloatWeightTpl<T> &w1, + const FloatWeightTpl<T> &w2, + float delta = kDelta) { + return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta; +} + +template <class T> +inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) { + if (w.Value() == FloatLimits<T>::PosInfinity()) + return strm << "Infinity"; + else if (w.Value() == FloatLimits<T>::NegInfinity()) + return strm << "-Infinity"; + else if (w.Value() != w.Value()) // Fails for NaN + return strm << "BadNumber"; + else + return strm << w.Value(); +} + +template <class T> +inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) { + string s; + strm >> s; + if (s == "Infinity") { + w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity()); + } else if (s == "-Infinity") { + w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity()); + } else { + char *p; + T f = strtod(s.c_str(), &p); + if (p < s.c_str() + s.size()) + strm.clear(std::ios::badbit); + else + w = FloatWeightTpl<T>(f); + } + return strm; +} + + +// Tropical semiring: (min, +, inf, 0) +template <class T> +class TropicalWeightTpl : public FloatWeightTpl<T> { + public: + using FloatWeightTpl<T>::Value; + + typedef TropicalWeightTpl<T> ReverseWeight; + + TropicalWeightTpl() : FloatWeightTpl<T>() {} + + TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {} + + TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} + + static const TropicalWeightTpl<T> Zero() { + return TropicalWeightTpl<T>(FloatLimits<T>::PosInfinity()); } + + static const TropicalWeightTpl<T> One() { + return TropicalWeightTpl<T>(0.0F); } + + static const TropicalWeightTpl<T> NoWeight() { + return TropicalWeightTpl<T>(FloatLimits<T>::NumberBad()); } + + static const string &Type() { + static const string type = "tropical" + + FloatWeightTpl<T>::GetPrecisionString(); + return type; + } + + bool Member() const { + // First part fails for IEEE NaN + return Value() == Value() && Value() != FloatLimits<T>::NegInfinity(); + } + + TropicalWeightTpl<T> Quantize(float delta = kDelta) const { + if (Value() == FloatLimits<T>::NegInfinity() || + Value() == FloatLimits<T>::PosInfinity() || + Value() != Value()) + return *this; + else + return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); + } + + TropicalWeightTpl<T> Reverse() const { return *this; } + + static uint64 Properties() { + return kLeftSemiring | kRightSemiring | kCommutative | + kPath | kIdempotent; + } +}; + +// Single precision tropical weight +typedef TropicalWeightTpl<float> TropicalWeight; + +template <class T> +inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1, + const TropicalWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return TropicalWeightTpl<T>::NoWeight(); + return w1.Value() < w2.Value() ? w1 : w2; +} + +inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1, + const TropicalWeightTpl<float> &w2) { + return Plus<float>(w1, w2); +} + +inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1, + const TropicalWeightTpl<double> &w2) { + return Plus<double>(w1, w2); +} + +template <class T> +inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1, + const TropicalWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return TropicalWeightTpl<T>::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f1 == FloatLimits<T>::PosInfinity()) + return w1; + else if (f2 == FloatLimits<T>::PosInfinity()) + return w2; + else + return TropicalWeightTpl<T>(f1 + f2); +} + +inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1, + const TropicalWeightTpl<float> &w2) { + return Times<float>(w1, w2); +} + +inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1, + const TropicalWeightTpl<double> &w2) { + return Times<double>(w1, w2); +} + +template <class T> +inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1, + const TropicalWeightTpl<T> &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return TropicalWeightTpl<T>::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f2 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::NumberBad(); + else if (f1 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::PosInfinity(); + else + return TropicalWeightTpl<T>(f1 - f2); +} + +inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1, + const TropicalWeightTpl<float> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<float>(w1, w2, typ); +} + +inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1, + const TropicalWeightTpl<double> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<double>(w1, w2, typ); +} + + +// Log semiring: (log(e^-x + e^y), +, inf, 0) +template <class T> +class LogWeightTpl : public FloatWeightTpl<T> { + public: + using FloatWeightTpl<T>::Value; + + typedef LogWeightTpl ReverseWeight; + + LogWeightTpl() : FloatWeightTpl<T>() {} + + LogWeightTpl(T f) : FloatWeightTpl<T>(f) {} + + LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} + + static const LogWeightTpl<T> Zero() { + return LogWeightTpl<T>(FloatLimits<T>::PosInfinity()); + } + + static const LogWeightTpl<T> One() { + return LogWeightTpl<T>(0.0F); + } + + static const LogWeightTpl<T> NoWeight() { + return LogWeightTpl<T>(FloatLimits<T>::NumberBad()); } + + static const string &Type() { + static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString(); + return type; + } + + bool Member() const { + // First part fails for IEEE NaN + return Value() == Value() && Value() != FloatLimits<T>::NegInfinity(); + } + + LogWeightTpl<T> Quantize(float delta = kDelta) const { + if (Value() == FloatLimits<T>::NegInfinity() || + Value() == FloatLimits<T>::PosInfinity() || + Value() != Value()) + return *this; + else + return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); + } + + LogWeightTpl<T> Reverse() const { return *this; } + + static uint64 Properties() { + return kLeftSemiring | kRightSemiring | kCommutative; + } +}; + +// Single-precision log weight +typedef LogWeightTpl<float> LogWeight; +// Double-precision log weight +typedef LogWeightTpl<double> Log64Weight; + +template <class T> +inline T LogExp(T x) { return log(1.0F + exp(-x)); } + +template <class T> +inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1, + const LogWeightTpl<T> &w2) { + T f1 = w1.Value(), f2 = w2.Value(); + if (f1 == FloatLimits<T>::PosInfinity()) + return w2; + else if (f2 == FloatLimits<T>::PosInfinity()) + return w1; + else if (f1 > f2) + return LogWeightTpl<T>(f2 - LogExp(f1 - f2)); + else + return LogWeightTpl<T>(f1 - LogExp(f2 - f1)); +} + +inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1, + const LogWeightTpl<float> &w2) { + return Plus<float>(w1, w2); +} + +inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1, + const LogWeightTpl<double> &w2) { + return Plus<double>(w1, w2); +} + +template <class T> +inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1, + const LogWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return LogWeightTpl<T>::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f1 == FloatLimits<T>::PosInfinity()) + return w1; + else if (f2 == FloatLimits<T>::PosInfinity()) + return w2; + else + return LogWeightTpl<T>(f1 + f2); +} + +inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1, + const LogWeightTpl<float> &w2) { + return Times<float>(w1, w2); +} + +inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1, + const LogWeightTpl<double> &w2) { + return Times<double>(w1, w2); +} + +template <class T> +inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1, + const LogWeightTpl<T> &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return LogWeightTpl<T>::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f2 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::NumberBad(); + else if (f1 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::PosInfinity(); + else + return LogWeightTpl<T>(f1 - f2); +} + +inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1, + const LogWeightTpl<float> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<float>(w1, w2, typ); +} + +inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1, + const LogWeightTpl<double> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<double>(w1, w2, typ); +} + +// MinMax semiring: (min, max, inf, -inf) +template <class T> +class MinMaxWeightTpl : public FloatWeightTpl<T> { + public: + using FloatWeightTpl<T>::Value; + + typedef MinMaxWeightTpl<T> ReverseWeight; + + MinMaxWeightTpl() : FloatWeightTpl<T>() {} + + MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {} + + MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} + + static const MinMaxWeightTpl<T> Zero() { + return MinMaxWeightTpl<T>(FloatLimits<T>::PosInfinity()); + } + + static const MinMaxWeightTpl<T> One() { + return MinMaxWeightTpl<T>(FloatLimits<T>::NegInfinity()); + } + + static const MinMaxWeightTpl<T> NoWeight() { + return MinMaxWeightTpl<T>(FloatLimits<T>::NumberBad()); } + + static const string &Type() { + static const string type = "minmax" + + FloatWeightTpl<T>::GetPrecisionString(); + return type; + } + + bool Member() const { + // Fails for IEEE NaN + return Value() == Value(); + } + + MinMaxWeightTpl<T> Quantize(float delta = kDelta) const { + // If one of infinities, or a NaN + if (Value() == FloatLimits<T>::NegInfinity() || + Value() == FloatLimits<T>::PosInfinity() || + Value() != Value()) + return *this; + else + return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); + } + + MinMaxWeightTpl<T> Reverse() const { return *this; } + + static uint64 Properties() { + return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath; + } +}; + +// Single-precision min-max weight +typedef MinMaxWeightTpl<float> MinMaxWeight; + +// Min +template <class T> +inline MinMaxWeightTpl<T> Plus( + const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return MinMaxWeightTpl<T>::NoWeight(); + return w1.Value() < w2.Value() ? w1 : w2; +} + +inline MinMaxWeightTpl<float> Plus( + const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { + return Plus<float>(w1, w2); +} + +inline MinMaxWeightTpl<double> Plus( + const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { + return Plus<double>(w1, w2); +} + +// Max +template <class T> +inline MinMaxWeightTpl<T> Times( + const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return MinMaxWeightTpl<T>::NoWeight(); + return w1.Value() >= w2.Value() ? w1 : w2; +} + +inline MinMaxWeightTpl<float> Times( + const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { + return Times<float>(w1, w2); +} + +inline MinMaxWeightTpl<double> Times( + const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { + return Times<double>(w1, w2); +} + +// Defined only for special cases +template <class T> +inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1, + const MinMaxWeightTpl<T> &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return MinMaxWeightTpl<T>::NoWeight(); + // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2 + return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::NumberBad(); +} + +inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1, + const MinMaxWeightTpl<float> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<float>(w1, w2, typ); +} + +inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1, + const MinMaxWeightTpl<double> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<double>(w1, w2, typ); +} + +// +// WEIGHT CONVERTER SPECIALIZATIONS. +// + +// Convert to tropical +template <> +struct WeightConvert<LogWeight, TropicalWeight> { + TropicalWeight operator()(LogWeight w) const { return w.Value(); } +}; + +template <> +struct WeightConvert<Log64Weight, TropicalWeight> { + TropicalWeight operator()(Log64Weight w) const { return w.Value(); } +}; + +// Convert to log +template <> +struct WeightConvert<TropicalWeight, LogWeight> { + LogWeight operator()(TropicalWeight w) const { return w.Value(); } +}; + +template <> +struct WeightConvert<Log64Weight, LogWeight> { + LogWeight operator()(Log64Weight w) const { return w.Value(); } +}; + +// Convert to log64 +template <> +struct WeightConvert<TropicalWeight, Log64Weight> { + Log64Weight operator()(TropicalWeight w) const { return w.Value(); } +}; + +template <> +struct WeightConvert<LogWeight, Log64Weight> { + Log64Weight operator()(LogWeight w) const { return w.Value(); } +}; + +} // namespace fst + +#endif // FST_LIB_FLOAT_WEIGHT_H__ |