// float-weight.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: riley@google.com (Michael Riley) // // \file // Float weight set and associated semiring operation definitions. // #ifndef FST_LIB_FLOAT_WEIGHT_H__ #define FST_LIB_FLOAT_WEIGHT_H__ #include #include #include #include #include #include namespace fst { // numeric limits class template class FloatLimits { public: static const T PosInfinity() { static const T pos_infinity = numeric_limits::infinity(); return pos_infinity; } static const T NegInfinity() { static const T neg_infinity = -PosInfinity(); return neg_infinity; } static const T NumberBad() { static const T number_bad = numeric_limits::quiet_NaN(); return number_bad; } }; // weight class to be templated on floating-points types template class FloatWeightTpl { public: FloatWeightTpl() {} FloatWeightTpl(T f) : value_(f) {} FloatWeightTpl(const FloatWeightTpl &w) : value_(w.value_) {} FloatWeightTpl &operator=(const FloatWeightTpl &w) { value_ = w.value_; return *this; } istream &Read(istream &strm) { return ReadType(strm, &value_); } ostream &Write(ostream &strm) const { return WriteType(strm, value_); } size_t Hash() const { union { T f; size_t s; } u; u.s = 0; u.f = value_; return u.s; } const T &Value() const { return value_; } protected: void SetValue(const T &f) { value_ = f; } inline static string GetPrecisionString() { int64 size = sizeof(T); if (size == sizeof(float)) return ""; size *= CHAR_BIT; string result; Int64ToStr(size, &result); return result; } private: T value_; }; // Single-precision float weight typedef FloatWeightTpl FloatWeight; template inline bool operator==(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { // Volatile qualifier thwarts over-aggressive compiler optimizations // that lead to problems esp. with NaturalLess(). volatile T v1 = w1.Value(); volatile T v2 = w2.Value(); return v1 == v2; } inline bool operator==(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator==(w1, w2); } inline bool operator==(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator==(w1, w2); } template inline bool operator!=(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return !(w1 == w2); } inline bool operator!=(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator!=(w1, w2); } inline bool operator!=(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator!=(w1, w2); } template inline bool ApproxEqual(const FloatWeightTpl &w1, const FloatWeightTpl &w2, float delta = kDelta) { return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta; } template inline ostream &operator<<(ostream &strm, const FloatWeightTpl &w) { if (w.Value() == FloatLimits::PosInfinity()) return strm << "Infinity"; else if (w.Value() == FloatLimits::NegInfinity()) return strm << "-Infinity"; else if (w.Value() != w.Value()) // Fails for NaN return strm << "BadNumber"; else return strm << w.Value(); } template inline istream &operator>>(istream &strm, FloatWeightTpl &w) { string s; strm >> s; if (s == "Infinity") { w = FloatWeightTpl(FloatLimits::PosInfinity()); } else if (s == "-Infinity") { w = FloatWeightTpl(FloatLimits::NegInfinity()); } else { char *p; T f = strtod(s.c_str(), &p); if (p < s.c_str() + s.size()) strm.clear(std::ios::badbit); else w = FloatWeightTpl(f); } return strm; } // Tropical semiring: (min, +, inf, 0) template class TropicalWeightTpl : public FloatWeightTpl { public: using FloatWeightTpl::Value; typedef TropicalWeightTpl ReverseWeight; TropicalWeightTpl() : FloatWeightTpl() {} TropicalWeightTpl(T f) : FloatWeightTpl(f) {} TropicalWeightTpl(const TropicalWeightTpl &w) : FloatWeightTpl(w) {} static const TropicalWeightTpl Zero() { return TropicalWeightTpl(FloatLimits::PosInfinity()); } static const TropicalWeightTpl One() { return TropicalWeightTpl(0.0F); } static const TropicalWeightTpl NoWeight() { return TropicalWeightTpl(FloatLimits::NumberBad()); } static const string &Type() { static const string type = "tropical" + FloatWeightTpl::GetPrecisionString(); return type; } bool Member() const { // First part fails for IEEE NaN return Value() == Value() && Value() != FloatLimits::NegInfinity(); } TropicalWeightTpl Quantize(float delta = kDelta) const { if (Value() == FloatLimits::NegInfinity() || Value() == FloatLimits::PosInfinity() || Value() != Value()) return *this; else return TropicalWeightTpl(floor(Value()/delta + 0.5F) * delta); } TropicalWeightTpl Reverse() const { return *this; } static uint64 Properties() { return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent; } }; // Single precision tropical weight typedef TropicalWeightTpl TropicalWeight; template inline TropicalWeightTpl Plus(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { if (!w1.Member() || !w2.Member()) return TropicalWeightTpl::NoWeight(); return w1.Value() < w2.Value() ? w1 : w2; } inline TropicalWeightTpl Plus(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Plus(w1, w2); } inline TropicalWeightTpl Plus(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Plus(w1, w2); } template inline TropicalWeightTpl Times(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { if (!w1.Member() || !w2.Member()) return TropicalWeightTpl::NoWeight(); T f1 = w1.Value(), f2 = w2.Value(); if (f1 == FloatLimits::PosInfinity()) return w1; else if (f2 == FloatLimits::PosInfinity()) return w2; else return TropicalWeightTpl(f1 + f2); } inline TropicalWeightTpl Times(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Times(w1, w2); } inline TropicalWeightTpl Times(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Times(w1, w2); } template inline TropicalWeightTpl Divide(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2, DivideType typ = DIVIDE_ANY) { if (!w1.Member() || !w2.Member()) return TropicalWeightTpl::NoWeight(); T f1 = w1.Value(), f2 = w2.Value(); if (f2 == FloatLimits::PosInfinity()) return FloatLimits::NumberBad(); else if (f1 == FloatLimits::PosInfinity()) return FloatLimits::PosInfinity(); else return TropicalWeightTpl(f1 - f2); } inline TropicalWeightTpl Divide(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } inline TropicalWeightTpl Divide(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } // Log semiring: (log(e^-x + e^y), +, inf, 0) template class LogWeightTpl : public FloatWeightTpl { public: using FloatWeightTpl::Value; typedef LogWeightTpl ReverseWeight; LogWeightTpl() : FloatWeightTpl() {} LogWeightTpl(T f) : FloatWeightTpl(f) {} LogWeightTpl(const LogWeightTpl &w) : FloatWeightTpl(w) {} static const LogWeightTpl Zero() { return LogWeightTpl(FloatLimits::PosInfinity()); } static const LogWeightTpl One() { return LogWeightTpl(0.0F); } static const LogWeightTpl NoWeight() { return LogWeightTpl(FloatLimits::NumberBad()); } static const string &Type() { static const string type = "log" + FloatWeightTpl::GetPrecisionString(); return type; } bool Member() const { // First part fails for IEEE NaN return Value() == Value() && Value() != FloatLimits::NegInfinity(); } LogWeightTpl Quantize(float delta = kDelta) const { if (Value() == FloatLimits::NegInfinity() || Value() == FloatLimits::PosInfinity() || Value() != Value()) return *this; else return LogWeightTpl(floor(Value()/delta + 0.5F) * delta); } LogWeightTpl Reverse() const { return *this; } static uint64 Properties() { return kLeftSemiring | kRightSemiring | kCommutative; } }; // Single-precision log weight typedef LogWeightTpl LogWeight; // Double-precision log weight typedef LogWeightTpl Log64Weight; template inline T LogExp(T x) { return log(1.0F + exp(-x)); } template inline LogWeightTpl Plus(const LogWeightTpl &w1, const LogWeightTpl &w2) { T f1 = w1.Value(), f2 = w2.Value(); if (f1 == FloatLimits::PosInfinity()) return w2; else if (f2 == FloatLimits::PosInfinity()) return w1; else if (f1 > f2) return LogWeightTpl(f2 - LogExp(f1 - f2)); else return LogWeightTpl(f1 - LogExp(f2 - f1)); } inline LogWeightTpl Plus(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Plus(w1, w2); } inline LogWeightTpl Plus(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Plus(w1, w2); } template inline LogWeightTpl Times(const LogWeightTpl &w1, const LogWeightTpl &w2) { if (!w1.Member() || !w2.Member()) return LogWeightTpl::NoWeight(); T f1 = w1.Value(), f2 = w2.Value(); if (f1 == FloatLimits::PosInfinity()) return w1; else if (f2 == FloatLimits::PosInfinity()) return w2; else return LogWeightTpl(f1 + f2); } inline LogWeightTpl Times(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Times(w1, w2); } inline LogWeightTpl Times(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Times(w1, w2); } template inline LogWeightTpl Divide(const LogWeightTpl &w1, const LogWeightTpl &w2, DivideType typ = DIVIDE_ANY) { if (!w1.Member() || !w2.Member()) return LogWeightTpl::NoWeight(); T f1 = w1.Value(), f2 = w2.Value(); if (f2 == FloatLimits::PosInfinity()) return FloatLimits::NumberBad(); else if (f1 == FloatLimits::PosInfinity()) return FloatLimits::PosInfinity(); else return LogWeightTpl(f1 - f2); } inline LogWeightTpl Divide(const LogWeightTpl &w1, const LogWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } inline LogWeightTpl Divide(const LogWeightTpl &w1, const LogWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } // MinMax semiring: (min, max, inf, -inf) template class MinMaxWeightTpl : public FloatWeightTpl { public: using FloatWeightTpl::Value; typedef MinMaxWeightTpl ReverseWeight; MinMaxWeightTpl() : FloatWeightTpl() {} MinMaxWeightTpl(T f) : FloatWeightTpl(f) {} MinMaxWeightTpl(const MinMaxWeightTpl &w) : FloatWeightTpl(w) {} static const MinMaxWeightTpl Zero() { return MinMaxWeightTpl(FloatLimits::PosInfinity()); } static const MinMaxWeightTpl One() { return MinMaxWeightTpl(FloatLimits::NegInfinity()); } static const MinMaxWeightTpl NoWeight() { return MinMaxWeightTpl(FloatLimits::NumberBad()); } static const string &Type() { static const string type = "minmax" + FloatWeightTpl::GetPrecisionString(); return type; } bool Member() const { // Fails for IEEE NaN return Value() == Value(); } MinMaxWeightTpl Quantize(float delta = kDelta) const { // If one of infinities, or a NaN if (Value() == FloatLimits::NegInfinity() || Value() == FloatLimits::PosInfinity() || Value() != Value()) return *this; else return MinMaxWeightTpl(floor(Value()/delta + 0.5F) * delta); } MinMaxWeightTpl Reverse() const { return *this; } static uint64 Properties() { return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath; } }; // Single-precision min-max weight typedef MinMaxWeightTpl MinMaxWeight; // Min template inline MinMaxWeightTpl Plus( const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl::NoWeight(); return w1.Value() < w2.Value() ? w1 : w2; } inline MinMaxWeightTpl Plus( const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Plus(w1, w2); } inline MinMaxWeightTpl Plus( const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Plus(w1, w2); } // Max template inline MinMaxWeightTpl Times( const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl::NoWeight(); return w1.Value() >= w2.Value() ? w1 : w2; } inline MinMaxWeightTpl Times( const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Times(w1, w2); } inline MinMaxWeightTpl Times( const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Times(w1, w2); } // Defined only for special cases template inline MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2, DivideType typ = DIVIDE_ANY) { if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl::NoWeight(); // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2 return w1.Value() >= w2.Value() ? w1 : FloatLimits::NumberBad(); } inline MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } inline MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } // // WEIGHT CONVERTER SPECIALIZATIONS. // // Convert to tropical template <> struct WeightConvert { TropicalWeight operator()(LogWeight w) const { return w.Value(); } }; template <> struct WeightConvert { TropicalWeight operator()(Log64Weight w) const { return w.Value(); } }; // Convert to log template <> struct WeightConvert { LogWeight operator()(TropicalWeight w) const { return w.Value(); } }; template <> struct WeightConvert { LogWeight operator()(Log64Weight w) const { return w.Value(); } }; // Convert to log64 template <> struct WeightConvert { Log64Weight operator()(TropicalWeight w) const { return w.Value(); } }; template <> struct WeightConvert { Log64Weight operator()(LogWeight w) const { return w.Value(); } }; } // namespace fst #endif // FST_LIB_FLOAT_WEIGHT_H__