From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- .../src/tools/openfst/include/fst/float-weight.h | 601 +++++++++++++++++++++ 1 file changed, 601 insertions(+) create mode 100644 kaldi_io/src/tools/openfst/include/fst/float-weight.h (limited to 'kaldi_io/src/tools/openfst/include/fst/float-weight.h') diff --git a/kaldi_io/src/tools/openfst/include/fst/float-weight.h b/kaldi_io/src/tools/openfst/include/fst/float-weight.h new file mode 100644 index 0000000..eb22638 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/float-weight.h @@ -0,0 +1,601 @@ +// float-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Float weight set and associated semiring operation definitions. +// + +#ifndef FST_LIB_FLOAT_WEIGHT_H__ +#define FST_LIB_FLOAT_WEIGHT_H__ + +#include +#include +#include +#include + +#include +#include + + +namespace fst { + +// numeric limits class +template +class FloatLimits { + public: + static const T PosInfinity() { + static const T pos_infinity = numeric_limits::infinity(); + return pos_infinity; + } + + static const T NegInfinity() { + static const T neg_infinity = -PosInfinity(); + return neg_infinity; + } + + static const T NumberBad() { + static const T number_bad = numeric_limits::quiet_NaN(); + return number_bad; + } + +}; + +// weight class to be templated on floating-points types +template +class FloatWeightTpl { + public: + FloatWeightTpl() {} + + FloatWeightTpl(T f) : value_(f) {} + + FloatWeightTpl(const FloatWeightTpl &w) : value_(w.value_) {} + + FloatWeightTpl &operator=(const FloatWeightTpl &w) { + value_ = w.value_; + return *this; + } + + istream &Read(istream &strm) { + return ReadType(strm, &value_); + } + + ostream &Write(ostream &strm) const { + return WriteType(strm, value_); + } + + size_t Hash() const { + union { + T f; + size_t s; + } u; + u.s = 0; + u.f = value_; + return u.s; + } + + const T &Value() const { return value_; } + + protected: + void SetValue(const T &f) { value_ = f; } + + inline static string GetPrecisionString() { + int64 size = sizeof(T); + if (size == sizeof(float)) return ""; + size *= CHAR_BIT; + + string result; + Int64ToStr(size, &result); + return result; + } + + private: + T value_; +}; + +// Single-precision float weight +typedef FloatWeightTpl FloatWeight; + +template +inline bool operator==(const FloatWeightTpl &w1, + const FloatWeightTpl &w2) { + // Volatile qualifier thwarts over-aggressive compiler optimizations + // that lead to problems esp. with NaturalLess(). + volatile T v1 = w1.Value(); + volatile T v2 = w2.Value(); + return v1 == v2; +} + +inline bool operator==(const FloatWeightTpl &w1, + const FloatWeightTpl &w2) { + return operator==(w1, w2); +} + +inline bool operator==(const FloatWeightTpl &w1, + const FloatWeightTpl &w2) { + return operator==(w1, w2); +} + +template +inline bool operator!=(const FloatWeightTpl &w1, + const FloatWeightTpl &w2) { + return !(w1 == w2); +} + +inline bool operator!=(const FloatWeightTpl &w1, + const FloatWeightTpl &w2) { + return operator!=(w1, w2); +} + +inline bool operator!=(const FloatWeightTpl &w1, + const FloatWeightTpl &w2) { + return operator!=(w1, w2); +} + +template +inline bool ApproxEqual(const FloatWeightTpl &w1, + const FloatWeightTpl &w2, + float delta = kDelta) { + return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta; +} + +template +inline ostream &operator<<(ostream &strm, const FloatWeightTpl &w) { + if (w.Value() == FloatLimits::PosInfinity()) + return strm << "Infinity"; + else if (w.Value() == FloatLimits::NegInfinity()) + return strm << "-Infinity"; + else if (w.Value() != w.Value()) // Fails for NaN + return strm << "BadNumber"; + else + return strm << w.Value(); +} + +template +inline istream &operator>>(istream &strm, FloatWeightTpl &w) { + string s; + strm >> s; + if (s == "Infinity") { + w = FloatWeightTpl(FloatLimits::PosInfinity()); + } else if (s == "-Infinity") { + w = FloatWeightTpl(FloatLimits::NegInfinity()); + } else { + char *p; + T f = strtod(s.c_str(), &p); + if (p < s.c_str() + s.size()) + strm.clear(std::ios::badbit); + else + w = FloatWeightTpl(f); + } + return strm; +} + + +// Tropical semiring: (min, +, inf, 0) +template +class TropicalWeightTpl : public FloatWeightTpl { + public: + using FloatWeightTpl::Value; + + typedef TropicalWeightTpl ReverseWeight; + + TropicalWeightTpl() : FloatWeightTpl() {} + + TropicalWeightTpl(T f) : FloatWeightTpl(f) {} + + TropicalWeightTpl(const TropicalWeightTpl &w) : FloatWeightTpl(w) {} + + static const TropicalWeightTpl Zero() { + return TropicalWeightTpl(FloatLimits::PosInfinity()); } + + static const TropicalWeightTpl One() { + return TropicalWeightTpl(0.0F); } + + static const TropicalWeightTpl NoWeight() { + return TropicalWeightTpl(FloatLimits::NumberBad()); } + + static const string &Type() { + static const string type = "tropical" + + FloatWeightTpl::GetPrecisionString(); + return type; + } + + bool Member() const { + // First part fails for IEEE NaN + return Value() == Value() && Value() != FloatLimits::NegInfinity(); + } + + TropicalWeightTpl Quantize(float delta = kDelta) const { + if (Value() == FloatLimits::NegInfinity() || + Value() == FloatLimits::PosInfinity() || + Value() != Value()) + return *this; + else + return TropicalWeightTpl(floor(Value()/delta + 0.5F) * delta); + } + + TropicalWeightTpl Reverse() const { return *this; } + + static uint64 Properties() { + return kLeftSemiring | kRightSemiring | kCommutative | + kPath | kIdempotent; + } +}; + +// Single precision tropical weight +typedef TropicalWeightTpl TropicalWeight; + +template +inline TropicalWeightTpl Plus(const TropicalWeightTpl &w1, + const TropicalWeightTpl &w2) { + if (!w1.Member() || !w2.Member()) + return TropicalWeightTpl::NoWeight(); + return w1.Value() < w2.Value() ? w1 : w2; +} + +inline TropicalWeightTpl Plus(const TropicalWeightTpl &w1, + const TropicalWeightTpl &w2) { + return Plus(w1, w2); +} + +inline TropicalWeightTpl Plus(const TropicalWeightTpl &w1, + const TropicalWeightTpl &w2) { + return Plus(w1, w2); +} + +template +inline TropicalWeightTpl Times(const TropicalWeightTpl &w1, + const TropicalWeightTpl &w2) { + if (!w1.Member() || !w2.Member()) + return TropicalWeightTpl::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f1 == FloatLimits::PosInfinity()) + return w1; + else if (f2 == FloatLimits::PosInfinity()) + return w2; + else + return TropicalWeightTpl(f1 + f2); +} + +inline TropicalWeightTpl Times(const TropicalWeightTpl &w1, + const TropicalWeightTpl &w2) { + return Times(w1, w2); +} + +inline TropicalWeightTpl Times(const TropicalWeightTpl &w1, + const TropicalWeightTpl &w2) { + return Times(w1, w2); +} + +template +inline TropicalWeightTpl Divide(const TropicalWeightTpl &w1, + const TropicalWeightTpl &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return TropicalWeightTpl::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f2 == FloatLimits::PosInfinity()) + return FloatLimits::NumberBad(); + else if (f1 == FloatLimits::PosInfinity()) + return FloatLimits::PosInfinity(); + else + return TropicalWeightTpl(f1 - f2); +} + +inline TropicalWeightTpl Divide(const TropicalWeightTpl &w1, + const TropicalWeightTpl &w2, + DivideType typ = DIVIDE_ANY) { + return Divide(w1, w2, typ); +} + +inline TropicalWeightTpl Divide(const TropicalWeightTpl &w1, + const TropicalWeightTpl &w2, + DivideType typ = DIVIDE_ANY) { + return Divide(w1, w2, typ); +} + + +// Log semiring: (log(e^-x + e^y), +, inf, 0) +template +class LogWeightTpl : public FloatWeightTpl { + public: + using FloatWeightTpl::Value; + + typedef LogWeightTpl ReverseWeight; + + LogWeightTpl() : FloatWeightTpl() {} + + LogWeightTpl(T f) : FloatWeightTpl(f) {} + + LogWeightTpl(const LogWeightTpl &w) : FloatWeightTpl(w) {} + + static const LogWeightTpl Zero() { + return LogWeightTpl(FloatLimits::PosInfinity()); + } + + static const LogWeightTpl One() { + return LogWeightTpl(0.0F); + } + + static const LogWeightTpl NoWeight() { + return LogWeightTpl(FloatLimits::NumberBad()); } + + static const string &Type() { + static const string type = "log" + FloatWeightTpl::GetPrecisionString(); + return type; + } + + bool Member() const { + // First part fails for IEEE NaN + return Value() == Value() && Value() != FloatLimits::NegInfinity(); + } + + LogWeightTpl Quantize(float delta = kDelta) const { + if (Value() == FloatLimits::NegInfinity() || + Value() == FloatLimits::PosInfinity() || + Value() != Value()) + return *this; + else + return LogWeightTpl(floor(Value()/delta + 0.5F) * delta); + } + + LogWeightTpl Reverse() const { return *this; } + + static uint64 Properties() { + return kLeftSemiring | kRightSemiring | kCommutative; + } +}; + +// Single-precision log weight +typedef LogWeightTpl LogWeight; +// Double-precision log weight +typedef LogWeightTpl Log64Weight; + +template +inline T LogExp(T x) { return log(1.0F + exp(-x)); } + +template +inline LogWeightTpl Plus(const LogWeightTpl &w1, + const LogWeightTpl &w2) { + T f1 = w1.Value(), f2 = w2.Value(); + if (f1 == FloatLimits::PosInfinity()) + return w2; + else if (f2 == FloatLimits::PosInfinity()) + return w1; + else if (f1 > f2) + return LogWeightTpl(f2 - LogExp(f1 - f2)); + else + return LogWeightTpl(f1 - LogExp(f2 - f1)); +} + +inline LogWeightTpl Plus(const LogWeightTpl &w1, + const LogWeightTpl &w2) { + return Plus(w1, w2); +} + +inline LogWeightTpl Plus(const LogWeightTpl &w1, + const LogWeightTpl &w2) { + return Plus(w1, w2); +} + +template +inline LogWeightTpl Times(const LogWeightTpl &w1, + const LogWeightTpl &w2) { + if (!w1.Member() || !w2.Member()) + return LogWeightTpl::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f1 == FloatLimits::PosInfinity()) + return w1; + else if (f2 == FloatLimits::PosInfinity()) + return w2; + else + return LogWeightTpl(f1 + f2); +} + +inline LogWeightTpl Times(const LogWeightTpl &w1, + const LogWeightTpl &w2) { + return Times(w1, w2); +} + +inline LogWeightTpl Times(const LogWeightTpl &w1, + const LogWeightTpl &w2) { + return Times(w1, w2); +} + +template +inline LogWeightTpl Divide(const LogWeightTpl &w1, + const LogWeightTpl &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return LogWeightTpl::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f2 == FloatLimits::PosInfinity()) + return FloatLimits::NumberBad(); + else if (f1 == FloatLimits::PosInfinity()) + return FloatLimits::PosInfinity(); + else + return LogWeightTpl(f1 - f2); +} + +inline LogWeightTpl Divide(const LogWeightTpl &w1, + const LogWeightTpl &w2, + DivideType typ = DIVIDE_ANY) { + return Divide(w1, w2, typ); +} + +inline LogWeightTpl Divide(const LogWeightTpl &w1, + const LogWeightTpl &w2, + DivideType typ = DIVIDE_ANY) { + return Divide(w1, w2, typ); +} + +// MinMax semiring: (min, max, inf, -inf) +template +class MinMaxWeightTpl : public FloatWeightTpl { + public: + using FloatWeightTpl::Value; + + typedef MinMaxWeightTpl ReverseWeight; + + MinMaxWeightTpl() : FloatWeightTpl() {} + + MinMaxWeightTpl(T f) : FloatWeightTpl(f) {} + + MinMaxWeightTpl(const MinMaxWeightTpl &w) : FloatWeightTpl(w) {} + + static const MinMaxWeightTpl Zero() { + return MinMaxWeightTpl(FloatLimits::PosInfinity()); + } + + static const MinMaxWeightTpl One() { + return MinMaxWeightTpl(FloatLimits::NegInfinity()); + } + + static const MinMaxWeightTpl NoWeight() { + return MinMaxWeightTpl(FloatLimits::NumberBad()); } + + static const string &Type() { + static const string type = "minmax" + + FloatWeightTpl::GetPrecisionString(); + return type; + } + + bool Member() const { + // Fails for IEEE NaN + return Value() == Value(); + } + + MinMaxWeightTpl Quantize(float delta = kDelta) const { + // If one of infinities, or a NaN + if (Value() == FloatLimits::NegInfinity() || + Value() == FloatLimits::PosInfinity() || + Value() != Value()) + return *this; + else + return MinMaxWeightTpl(floor(Value()/delta + 0.5F) * delta); + } + + MinMaxWeightTpl Reverse() const { return *this; } + + static uint64 Properties() { + return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath; + } +}; + +// Single-precision min-max weight +typedef MinMaxWeightTpl MinMaxWeight; + +// Min +template +inline MinMaxWeightTpl Plus( + const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { + if (!w1.Member() || !w2.Member()) + return MinMaxWeightTpl::NoWeight(); + return w1.Value() < w2.Value() ? w1 : w2; +} + +inline MinMaxWeightTpl Plus( + const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { + return Plus(w1, w2); +} + +inline MinMaxWeightTpl Plus( + const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { + return Plus(w1, w2); +} + +// Max +template +inline MinMaxWeightTpl Times( + const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { + if (!w1.Member() || !w2.Member()) + return MinMaxWeightTpl::NoWeight(); + return w1.Value() >= w2.Value() ? w1 : w2; +} + +inline MinMaxWeightTpl Times( + const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { + return Times(w1, w2); +} + +inline MinMaxWeightTpl Times( + const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { + return Times(w1, w2); +} + +// Defined only for special cases +template +inline MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, + const MinMaxWeightTpl &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return MinMaxWeightTpl::NoWeight(); + // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2 + return w1.Value() >= w2.Value() ? w1 : FloatLimits::NumberBad(); +} + +inline MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, + const MinMaxWeightTpl &w2, + DivideType typ = DIVIDE_ANY) { + return Divide(w1, w2, typ); +} + +inline MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, + const MinMaxWeightTpl &w2, + DivideType typ = DIVIDE_ANY) { + return Divide(w1, w2, typ); +} + +// +// WEIGHT CONVERTER SPECIALIZATIONS. +// + +// Convert to tropical +template <> +struct WeightConvert { + TropicalWeight operator()(LogWeight w) const { return w.Value(); } +}; + +template <> +struct WeightConvert { + TropicalWeight operator()(Log64Weight w) const { return w.Value(); } +}; + +// Convert to log +template <> +struct WeightConvert { + LogWeight operator()(TropicalWeight w) const { return w.Value(); } +}; + +template <> +struct WeightConvert { + LogWeight operator()(Log64Weight w) const { return w.Value(); } +}; + +// Convert to log64 +template <> +struct WeightConvert { + Log64Weight operator()(TropicalWeight w) const { return w.Value(); } +}; + +template <> +struct WeightConvert { + Log64Weight operator()(LogWeight w) const { return w.Value(); } +}; + +} // namespace fst + +#endif // FST_LIB_FLOAT_WEIGHT_H__ -- cgit v1.2.3