diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/signed-log-weight.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/signed-log-weight.h | 367 |
1 files changed, 367 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/signed-log-weight.h b/kaldi_io/src/tools/openfst/include/fst/signed-log-weight.h new file mode 100644 index 0000000..61adefb --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/signed-log-weight.h @@ -0,0 +1,367 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Kasturi Rangan Raghavan) +// \file +// LogWeight along with sign information that represents the value X in the +// linear domain as <sign(X), -ln(|X|)> +// The sign is a TropicalWeight: +// positive, TropicalWeight.Value() > 0.0, recommended value 1.0 +// negative, TropicalWeight.Value() <= 0.0, recommended value -1.0 + +#ifndef FST_LIB_SIGNED_LOG_WEIGHT_H_ +#define FST_LIB_SIGNED_LOG_WEIGHT_H_ + +#include <fst/float-weight.h> +#include <fst/pair-weight.h> + + +namespace fst { +template <class T> +class SignedLogWeightTpl + : public PairWeight<TropicalWeight, LogWeightTpl<T> > { + public: + typedef TropicalWeight X1; + typedef LogWeightTpl<T> X2; + using PairWeight<X1, X2>::Value1; + using PairWeight<X1, X2>::Value2; + + using PairWeight<X1, X2>::Reverse; + using PairWeight<X1, X2>::Quantize; + using PairWeight<X1, X2>::Member; + + typedef SignedLogWeightTpl<T> ReverseWeight; + + SignedLogWeightTpl() : PairWeight<X1, X2>() {} + + SignedLogWeightTpl(const SignedLogWeightTpl<T>& w) + : PairWeight<X1, X2> (w) { } + + SignedLogWeightTpl(const PairWeight<X1, X2>& w) + : PairWeight<X1, X2> (w) { } + + SignedLogWeightTpl(const X1& x1, const X2& x2) + : PairWeight<X1, X2>(x1, x2) { } + + static const SignedLogWeightTpl<T> &Zero() { + static const SignedLogWeightTpl<T> zero(X1(1.0), X2::Zero()); + return zero; + } + + static const SignedLogWeightTpl<T> &One() { + static const SignedLogWeightTpl<T> one(X1(1.0), X2::One()); + return one; + } + + static const SignedLogWeightTpl<T> &NoWeight() { + static const SignedLogWeightTpl<T> no_weight(X1(1.0), X2::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string type = "signed_log_" + X1::Type() + "_" + X2::Type(); + return type; + } + + ProductWeight<X1, X2> Quantize(float delta = kDelta) const { + return PairWeight<X1, X2>::Quantize(); + } + + ReverseWeight Reverse() const { + return PairWeight<X1, X2>::Reverse(); + } + + bool Member() const { + return PairWeight<X1, X2>::Member(); + } + + static uint64 Properties() { + // not idempotent nor path + return kLeftSemiring | kRightSemiring | kCommutative; + } + + size_t Hash() const { + size_t h1; + if (Value2() == X2::Zero() || Value1().Value() > 0.0) + h1 = TropicalWeight(1.0).Hash(); + else + h1 = TropicalWeight(-1.0).Hash(); + size_t h2 = Value2().Hash(); + const int lshift = 5; + const int rshift = CHAR_BIT * sizeof(size_t) - 5; + return h1 << lshift ^ h1 >> rshift ^ h2; + } +}; + +template <class T> +inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return SignedLogWeightTpl<T>::NoWeight(); + bool s1 = w1.Value1().Value() > 0.0; + bool s2 = w2.Value1().Value() > 0.0; + T f1 = w1.Value2().Value(); + T f2 = w2.Value2().Value(); + if (f1 == FloatLimits<T>::PosInfinity()) + return w2; + else if (f2 == FloatLimits<T>::PosInfinity()) + return w1; + else if (f1 == f2) { + if (s1 == s2) + return SignedLogWeightTpl<T>(w1.Value1(), (f2 - log(2.0F))); + else + return SignedLogWeightTpl<T>::Zero(); + } else if (f1 > f2) { + if (s1 == s2) { + return SignedLogWeightTpl<T>( + w1.Value1(), (f2 - log(1.0F + exp(f2 - f1)))); + } else { + return SignedLogWeightTpl<T>( + w2.Value1(), (f2 - log(1.0F - exp(f2 - f1)))); + } + } else { + if (s2 == s1) { + return SignedLogWeightTpl<T>( + w2.Value1(), (f1 - log(1.0F + exp(f1 - f2)))); + } else { + return SignedLogWeightTpl<T>( + w1.Value1(), (f1 - log(1.0F - exp(f1 - f2)))); + } + } +} + +template <class T> +inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2) { + SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2()); + return Plus(w1, minus_w2); +} + +template <class T> +inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return SignedLogWeightTpl<T>::NoWeight(); + bool s1 = w1.Value1().Value() > 0.0; + bool s2 = w2.Value1().Value() > 0.0; + T f1 = w1.Value2().Value(); + T f2 = w2.Value2().Value(); + if (s1 == s2) + return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 + f2)); + else + return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 + f2)); +} + +template <class T> +inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return SignedLogWeightTpl<T>::NoWeight(); + bool s1 = w1.Value1().Value() > 0.0; + bool s2 = w2.Value1().Value() > 0.0; + T f1 = w1.Value2().Value(); + T f2 = w2.Value2().Value(); + if (f2 == FloatLimits<T>::PosInfinity()) + return SignedLogWeightTpl<T>(TropicalWeight(1.0), + FloatLimits<T>::NumberBad()); + else if (f1 == FloatLimits<T>::PosInfinity()) + return SignedLogWeightTpl<T>(TropicalWeight(1.0), + FloatLimits<T>::PosInfinity()); + else if (s1 == s2) + return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 - f2)); + else + return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 - f2)); +} + +template <class T> +inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2, + float delta = kDelta) { + bool s1 = w1.Value1().Value() > 0.0; + bool s2 = w2.Value1().Value() > 0.0; + if (s1 == s2) { + return ApproxEqual(w1.Value2(), w2.Value2(), delta); + } else { + return w1.Value2() == LogWeightTpl<T>::Zero() + && w2.Value2() == LogWeightTpl<T>::Zero(); + } +} + +template <class T> +inline bool operator==(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2) { + bool s1 = w1.Value1().Value() > 0.0; + bool s2 = w2.Value1().Value() > 0.0; + if (s1 == s2) + return w1.Value2() == w2.Value2(); + else + return (w1.Value2() == LogWeightTpl<T>::Zero()) && + (w2.Value2() == LogWeightTpl<T>::Zero()); +} + + +// Single-precision signed-log weight +typedef SignedLogWeightTpl<float> SignedLogWeight; +// Double-precision signed-log weight +typedef SignedLogWeightTpl<double> SignedLog64Weight; + +// +// WEIGHT CONVERTER SPECIALIZATIONS. +// + +template <class W1, class W2> +bool SignedLogConvertCheck(W1 w) { + if (w.Value1().Value() < 0.0) { + FSTERROR() << "WeightConvert: can't convert weight from \"" + << W1::Type() << "\" to \"" << W2::Type(); + return false; + } + return true; +} + +// Convert to tropical +template <> +struct WeightConvert<SignedLogWeight, TropicalWeight> { + TropicalWeight operator()(SignedLogWeight w) const { + if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(w)) + return TropicalWeight::NoWeight(); + return w.Value2().Value(); + } +}; + +template <> +struct WeightConvert<SignedLog64Weight, TropicalWeight> { + TropicalWeight operator()(SignedLog64Weight w) const { + if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(w)) + return TropicalWeight::NoWeight(); + return w.Value2().Value(); + } +}; + +// Convert to log +template <> +struct WeightConvert<SignedLogWeight, LogWeight> { + LogWeight operator()(SignedLogWeight w) const { + if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(w)) + return LogWeight::NoWeight(); + return w.Value2().Value(); + } +}; + +template <> +struct WeightConvert<SignedLog64Weight, LogWeight> { + LogWeight operator()(SignedLog64Weight w) const { + if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(w)) + return LogWeight::NoWeight(); + return w.Value2().Value(); + } +}; + +// Convert to log64 +template <> +struct WeightConvert<SignedLogWeight, Log64Weight> { + Log64Weight operator()(SignedLogWeight w) const { + if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(w)) + return Log64Weight::NoWeight(); + return w.Value2().Value(); + } +}; + +template <> +struct WeightConvert<SignedLog64Weight, Log64Weight> { + Log64Weight operator()(SignedLog64Weight w) const { + if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(w)) + return Log64Weight::NoWeight(); + return w.Value2().Value(); + } +}; + +// Convert to signed log +template <> +struct WeightConvert<TropicalWeight, SignedLogWeight> { + SignedLogWeight operator()(TropicalWeight w) const { + TropicalWeight x1 = 1.0; + LogWeight x2 = w.Value(); + return SignedLogWeight(x1, x2); + } +}; + +template <> +struct WeightConvert<LogWeight, SignedLogWeight> { + SignedLogWeight operator()(LogWeight w) const { + TropicalWeight x1 = 1.0; + LogWeight x2 = w.Value(); + return SignedLogWeight(x1, x2); + } +}; + +template <> +struct WeightConvert<Log64Weight, SignedLogWeight> { + SignedLogWeight operator()(Log64Weight w) const { + TropicalWeight x1 = 1.0; + LogWeight x2 = w.Value(); + return SignedLogWeight(x1, x2); + } +}; + +template <> +struct WeightConvert<SignedLog64Weight, SignedLogWeight> { + SignedLogWeight operator()(SignedLog64Weight w) const { + TropicalWeight x1 = w.Value1(); + LogWeight x2 = w.Value2().Value(); + return SignedLogWeight(x1, x2); + } +}; + +// Convert to signed log64 +template <> +struct WeightConvert<TropicalWeight, SignedLog64Weight> { + SignedLog64Weight operator()(TropicalWeight w) const { + TropicalWeight x1 = 1.0; + Log64Weight x2 = w.Value(); + return SignedLog64Weight(x1, x2); + } +}; + +template <> +struct WeightConvert<LogWeight, SignedLog64Weight> { + SignedLog64Weight operator()(LogWeight w) const { + TropicalWeight x1 = 1.0; + Log64Weight x2 = w.Value(); + return SignedLog64Weight(x1, x2); + } +}; + +template <> +struct WeightConvert<Log64Weight, SignedLog64Weight> { + SignedLog64Weight operator()(Log64Weight w) const { + TropicalWeight x1 = 1.0; + Log64Weight x2 = w.Value(); + return SignedLog64Weight(x1, x2); + } +}; + +template <> +struct WeightConvert<SignedLogWeight, SignedLog64Weight> { + SignedLog64Weight operator()(SignedLogWeight w) const { + TropicalWeight x1 = w.Value1(); + Log64Weight x2 = w.Value2().Value(); + return SignedLog64Weight(x1, x2); + } +}; + +} // namespace fst + +#endif // FST_LIB_SIGNED_LOG_WEIGHT_H_ |