diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/factor-weight.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/factor-weight.h | 475 |
1 files changed, 0 insertions, 475 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/factor-weight.h b/kaldi_io/src/tools/openfst/include/fst/factor-weight.h deleted file mode 100644 index 685155c..0000000 --- a/kaldi_io/src/tools/openfst/include/fst/factor-weight.h +++ /dev/null @@ -1,475 +0,0 @@ -// factor-weight.h - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Copyright 2005-2010 Google, Inc. -// Author: [email protected] (Cyril Allauzen) -// -// \file -// Classes to factor weights in an FST. - -#ifndef FST_LIB_FACTOR_WEIGHT_H__ -#define FST_LIB_FACTOR_WEIGHT_H__ - -#include <algorithm> -#include <tr1/unordered_map> -using std::tr1::unordered_map; -using std::tr1::unordered_multimap; -#include <string> -#include <utility> -using std::pair; using std::make_pair; -#include <vector> -using std::vector; - -#include <fst/cache.h> -#include <fst/test-properties.h> - - -namespace fst { - -const uint32 kFactorFinalWeights = 0x00000001; -const uint32 kFactorArcWeights = 0x00000002; - -template <class Arc> -struct FactorWeightOptions : CacheOptions { - typedef typename Arc::Label Label; - float delta; - uint32 mode; // factor arc weights and/or final weights - Label final_ilabel; // input label of arc created when factoring final w's - Label final_olabel; // output label of arc created when factoring final w's - - FactorWeightOptions(const CacheOptions &opts, float d, - uint32 m = kFactorArcWeights | kFactorFinalWeights, - Label il = 0, Label ol = 0) - : CacheOptions(opts), delta(d), mode(m), final_ilabel(il), - final_olabel(ol) {} - - explicit FactorWeightOptions( - float d, uint32 m = kFactorArcWeights | kFactorFinalWeights, - Label il = 0, Label ol = 0) - : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {} - - FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights, - Label il = 0, Label ol = 0) - : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {} -}; - - -// A factor iterator takes as argument a weight w and returns a -// sequence of pairs of weights (xi,yi) such that the sum of the -// products xi times yi is equal to w. If w is fully factored, -// the iterator should return nothing. -// -// template <class W> -// class FactorIterator { -// public: -// FactorIterator(W w); -// bool Done() const; -// void Next(); -// pair<W, W> Value() const; -// void Reset(); -// } - - -// Factor trivially. -template <class W> -class IdentityFactor { - public: - IdentityFactor(const W &w) {} - bool Done() const { return true; } - void Next() {} - pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused - void Reset() {} -}; - - -// Factor a StringWeight w as 'ab' where 'a' is a label. -template <typename L, StringType S = STRING_LEFT> -class StringFactor { - public: - StringFactor(const StringWeight<L, S> &w) - : weight_(w), done_(w.Size() <= 1) {} - - bool Done() const { return done_; } - - void Next() { done_ = true; } - - pair< StringWeight<L, S>, StringWeight<L, S> > Value() const { - StringWeightIterator<L, S> iter(weight_); - StringWeight<L, S> w1(iter.Value()); - StringWeight<L, S> w2; - for (iter.Next(); !iter.Done(); iter.Next()) - w2.PushBack(iter.Value()); - return make_pair(w1, w2); - } - - void Reset() { done_ = weight_.Size() <= 1; } - - private: - StringWeight<L, S> weight_; - bool done_; -}; - - -// Factor a GallicWeight using StringFactor. -template <class L, class W, StringType S = STRING_LEFT> -class GallicFactor { - public: - GallicFactor(const GallicWeight<L, W, S> &w) - : weight_(w), done_(w.Value1().Size() <= 1) {} - - bool Done() const { return done_; } - - void Next() { done_ = true; } - - pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const { - StringFactor<L, S> iter(weight_.Value1()); - GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2()); - GallicWeight<L, W, S> w2(iter.Value().second, W::One()); - return make_pair(w1, w2); - } - - void Reset() { done_ = weight_.Value1().Size() <= 1; } - - private: - GallicWeight<L, W, S> weight_; - bool done_; -}; - - -// Implementation class for FactorWeight -template <class A, class F> -class FactorWeightFstImpl - : public CacheImpl<A> { - public: - using FstImpl<A>::SetType; - using FstImpl<A>::SetProperties; - using FstImpl<A>::SetInputSymbols; - using FstImpl<A>::SetOutputSymbols; - - using CacheBaseImpl< CacheState<A> >::PushArc; - using CacheBaseImpl< CacheState<A> >::HasStart; - using CacheBaseImpl< CacheState<A> >::HasFinal; - using CacheBaseImpl< CacheState<A> >::HasArcs; - using CacheBaseImpl< CacheState<A> >::SetArcs; - using CacheBaseImpl< CacheState<A> >::SetFinal; - using CacheBaseImpl< CacheState<A> >::SetStart; - - typedef A Arc; - typedef typename A::Label Label; - typedef typename A::Weight Weight; - typedef typename A::StateId StateId; - typedef F FactorIterator; - - struct Element { - Element() {} - - Element(StateId s, Weight w) : state(s), weight(w) {} - - StateId state; // Input state Id - Weight weight; // Residual weight - }; - - FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts) - : CacheImpl<A>(opts), - fst_(fst.Copy()), - delta_(opts.delta), - mode_(opts.mode), - final_ilabel_(opts.final_ilabel), - final_olabel_(opts.final_olabel) { - SetType("factor_weight"); - uint64 props = fst.Properties(kFstProperties, false); - SetProperties(FactorWeightProperties(props), kCopyProperties); - - SetInputSymbols(fst.InputSymbols()); - SetOutputSymbols(fst.OutputSymbols()); - - if (mode_ == 0) - LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: " - << "factoring neither arc weights nor final weights."; - } - - FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl) - : CacheImpl<A>(impl), - fst_(impl.fst_->Copy(true)), - delta_(impl.delta_), - mode_(impl.mode_), - final_ilabel_(impl.final_ilabel_), - final_olabel_(impl.final_olabel_) { - SetType("factor_weight"); - SetProperties(impl.Properties(), kCopyProperties); - SetInputSymbols(impl.InputSymbols()); - SetOutputSymbols(impl.OutputSymbols()); - } - - ~FactorWeightFstImpl() { - delete fst_; - } - - StateId Start() { - if (!HasStart()) { - StateId s = fst_->Start(); - if (s == kNoStateId) - return kNoStateId; - StateId start = FindState(Element(fst_->Start(), Weight::One())); - SetStart(start); - } - return CacheImpl<A>::Start(); - } - - Weight Final(StateId s) { - if (!HasFinal(s)) { - const Element &e = elements_[s]; - // TODO: fix so cast is unnecessary - Weight w = e.state == kNoStateId - ? e.weight - : (Weight) Times(e.weight, fst_->Final(e.state)); - FactorIterator f(w); - if (!(mode_ & kFactorFinalWeights) || f.Done()) - SetFinal(s, w); - else - SetFinal(s, Weight::Zero()); - } - return CacheImpl<A>::Final(s); - } - - size_t NumArcs(StateId s) { - if (!HasArcs(s)) - Expand(s); - return CacheImpl<A>::NumArcs(s); - } - - size_t NumInputEpsilons(StateId s) { - if (!HasArcs(s)) - Expand(s); - return CacheImpl<A>::NumInputEpsilons(s); - } - - size_t NumOutputEpsilons(StateId s) { - if (!HasArcs(s)) - Expand(s); - return CacheImpl<A>::NumOutputEpsilons(s); - } - - uint64 Properties() const { return Properties(kFstProperties); } - - // Set error if found; return FST impl properties. - uint64 Properties(uint64 mask) const { - if ((mask & kError) && fst_->Properties(kError, false)) - SetProperties(kError, kError); - return FstImpl<Arc>::Properties(mask); - } - - void InitArcIterator(StateId s, ArcIteratorData<A> *data) { - if (!HasArcs(s)) - Expand(s); - CacheImpl<A>::InitArcIterator(s, data); - } - - - // Find state corresponding to an element. Create new state - // if element not found. - StateId FindState(const Element &e) { - if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) { - while (unfactored_.size() <= e.state) - unfactored_.push_back(kNoStateId); - if (unfactored_[e.state] == kNoStateId) { - unfactored_[e.state] = elements_.size(); - elements_.push_back(e); - } - return unfactored_[e.state]; - } else { - typename ElementMap::iterator eit = element_map_.find(e); - if (eit != element_map_.end()) { - return (*eit).second; - } else { - StateId s = elements_.size(); - elements_.push_back(e); - element_map_.insert(pair<const Element, StateId>(e, s)); - return s; - } - } - } - - // Computes the outgoing transitions from a state, creating new destination - // states as needed. - void Expand(StateId s) { - Element e = elements_[s]; - if (e.state != kNoStateId) { - for (ArcIterator< Fst<A> > ait(*fst_, e.state); - !ait.Done(); - ait.Next()) { - const A &arc = ait.Value(); - Weight w = Times(e.weight, arc.weight); - FactorIterator fit(w); - if (!(mode_ & kFactorArcWeights) || fit.Done()) { - StateId d = FindState(Element(arc.nextstate, Weight::One())); - PushArc(s, Arc(arc.ilabel, arc.olabel, w, d)); - } else { - for (; !fit.Done(); fit.Next()) { - const pair<Weight, Weight> &p = fit.Value(); - StateId d = FindState(Element(arc.nextstate, - p.second.Quantize(delta_))); - PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d)); - } - } - } - } - - if ((mode_ & kFactorFinalWeights) && - ((e.state == kNoStateId) || - (fst_->Final(e.state) != Weight::Zero()))) { - Weight w = e.state == kNoStateId - ? e.weight - : Times(e.weight, fst_->Final(e.state)); - for (FactorIterator fit(w); - !fit.Done(); - fit.Next()) { - const pair<Weight, Weight> &p = fit.Value(); - StateId d = FindState(Element(kNoStateId, - p.second.Quantize(delta_))); - PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d)); - } - } - SetArcs(s); - } - - private: - static const size_t kPrime = 7853; - - // Equality function for Elements, assume weights have been quantized. - class ElementEqual { - public: - bool operator()(const Element &x, const Element &y) const { - return x.state == y.state && x.weight == y.weight; - } - }; - - // Hash function for Elements to Fst states. - class ElementKey { - public: - size_t operator()(const Element &x) const { - return static_cast<size_t>(x.state * kPrime + x.weight.Hash()); - } - private: - }; - - typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap; - - const Fst<A> *fst_; - float delta_; - uint32 mode_; // factoring arc and/or final weights - Label final_ilabel_; // ilabel of arc created when factoring final w's - Label final_olabel_; // olabel of arc created when factoring final w's - vector<Element> elements_; // mapping Fst state to Elements - ElementMap element_map_; // mapping Elements to Fst state - // mapping between old/new 'StateId' for states that do not need to - // be factored when 'mode_' is '0' or 'kFactorFinalWeights' - vector<StateId> unfactored_; - - void operator=(const FactorWeightFstImpl<A, F> &); // disallow -}; - -template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime; - - -// FactorWeightFst takes as template parameter a FactorIterator as -// defined above. The result of weight factoring is a transducer -// equivalent to the input whose path weights have been factored -// according to the FactorIterator. States and transitions will be -// added as necessary. The algorithm is a generalization to arbitrary -// weights of the second step of the input epsilon-normalization -// algorithm due to Mohri, "Generic epsilon-removal and input -// epsilon-normalization algorithms for weighted transducers", -// International Journal of Computer Science 13(1): 129-143 (2002). -// -// This class attaches interface to implementation and handles -// reference counting, delegating most methods to ImplToFst. -template <class A, class F> -class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > { - public: - friend class ArcIterator< FactorWeightFst<A, F> >; - friend class StateIterator< FactorWeightFst<A, F> >; - - typedef A Arc; - typedef typename A::Weight Weight; - typedef typename A::StateId StateId; - typedef CacheState<A> State; - typedef FactorWeightFstImpl<A, F> Impl; - - FactorWeightFst(const Fst<A> &fst) - : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {} - - FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions<A> &opts) - : ImplToFst<Impl>(new Impl(fst, opts)) {} - - // See Fst<>::Copy() for doc. - FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy) - : ImplToFst<Impl>(fst, copy) {} - - // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc. - virtual FactorWeightFst<A, F> *Copy(bool copy = false) const { - return new FactorWeightFst<A, F>(*this, copy); - } - - virtual inline void InitStateIterator(StateIteratorData<A> *data) const; - - virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { - GetImpl()->InitArcIterator(s, data); - } - - private: - // Makes visible to friends. - Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } - - void operator=(const FactorWeightFst<A, F> &fst); // Disallow -}; - - -// Specialization for FactorWeightFst. -template<class A, class F> -class StateIterator< FactorWeightFst<A, F> > - : public CacheStateIterator< FactorWeightFst<A, F> > { - public: - explicit StateIterator(const FactorWeightFst<A, F> &fst) - : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {} -}; - - -// Specialization for FactorWeightFst. -template <class A, class F> -class ArcIterator< FactorWeightFst<A, F> > - : public CacheArcIterator< FactorWeightFst<A, F> > { - public: - typedef typename A::StateId StateId; - - ArcIterator(const FactorWeightFst<A, F> &fst, StateId s) - : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) { - if (!fst.GetImpl()->HasArcs(s)) - fst.GetImpl()->Expand(s); - } - - private: - DISALLOW_COPY_AND_ASSIGN(ArcIterator); -}; - -template <class A, class F> inline -void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const -{ - data->base = new StateIterator< FactorWeightFst<A, F> >(*this); -} - - -} // namespace fst - -#endif // FST_LIB_FACTOR_WEIGHT_H__ |