// factor-weight.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: allauzen@google.com (Cyril Allauzen) // // \file // Classes to factor weights in an FST. #ifndef FST_LIB_FACTOR_WEIGHT_H__ #define FST_LIB_FACTOR_WEIGHT_H__ #include #include using std::tr1::unordered_map; using std::tr1::unordered_multimap; #include #include using std::pair; using std::make_pair; #include using std::vector; #include #include namespace fst { const uint32 kFactorFinalWeights = 0x00000001; const uint32 kFactorArcWeights = 0x00000002; template struct FactorWeightOptions : CacheOptions { typedef typename Arc::Label Label; float delta; uint32 mode; // factor arc weights and/or final weights Label final_ilabel; // input label of arc created when factoring final w's Label final_olabel; // output label of arc created when factoring final w's FactorWeightOptions(const CacheOptions &opts, float d, uint32 m = kFactorArcWeights | kFactorFinalWeights, Label il = 0, Label ol = 0) : CacheOptions(opts), delta(d), mode(m), final_ilabel(il), final_olabel(ol) {} explicit FactorWeightOptions( float d, uint32 m = kFactorArcWeights | kFactorFinalWeights, Label il = 0, Label ol = 0) : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {} FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights, Label il = 0, Label ol = 0) : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {} }; // A factor iterator takes as argument a weight w and returns a // sequence of pairs of weights (xi,yi) such that the sum of the // products xi times yi is equal to w. If w is fully factored, // the iterator should return nothing. // // template // class FactorIterator { // public: // FactorIterator(W w); // bool Done() const; // void Next(); // pair Value() const; // void Reset(); // } // Factor trivially. template class IdentityFactor { public: IdentityFactor(const W &w) {} bool Done() const { return true; } void Next() {} pair Value() const { return make_pair(W::One(), W::One()); } // unused void Reset() {} }; // Factor a StringWeight w as 'ab' where 'a' is a label. template class StringFactor { public: StringFactor(const StringWeight &w) : weight_(w), done_(w.Size() <= 1) {} bool Done() const { return done_; } void Next() { done_ = true; } pair< StringWeight, StringWeight > Value() const { StringWeightIterator iter(weight_); StringWeight w1(iter.Value()); StringWeight w2; for (iter.Next(); !iter.Done(); iter.Next()) w2.PushBack(iter.Value()); return make_pair(w1, w2); } void Reset() { done_ = weight_.Size() <= 1; } private: StringWeight weight_; bool done_; }; // Factor a GallicWeight using StringFactor. template class GallicFactor { public: GallicFactor(const GallicWeight &w) : weight_(w), done_(w.Value1().Size() <= 1) {} bool Done() const { return done_; } void Next() { done_ = true; } pair< GallicWeight, GallicWeight > Value() const { StringFactor iter(weight_.Value1()); GallicWeight w1(iter.Value().first, weight_.Value2()); GallicWeight w2(iter.Value().second, W::One()); return make_pair(w1, w2); } void Reset() { done_ = weight_.Value1().Size() <= 1; } private: GallicWeight weight_; bool done_; }; // Implementation class for FactorWeight template class FactorWeightFstImpl : public CacheImpl { public: using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using CacheBaseImpl< CacheState >::PushArc; using CacheBaseImpl< CacheState >::HasStart; using CacheBaseImpl< CacheState >::HasFinal; using CacheBaseImpl< CacheState >::HasArcs; using CacheBaseImpl< CacheState >::SetArcs; using CacheBaseImpl< CacheState >::SetFinal; using CacheBaseImpl< CacheState >::SetStart; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef F FactorIterator; struct Element { Element() {} Element(StateId s, Weight w) : state(s), weight(w) {} StateId state; // Input state Id Weight weight; // Residual weight }; FactorWeightFstImpl(const Fst &fst, const FactorWeightOptions &opts) : CacheImpl(opts), fst_(fst.Copy()), delta_(opts.delta), mode_(opts.mode), final_ilabel_(opts.final_ilabel), final_olabel_(opts.final_olabel) { SetType("factor_weight"); uint64 props = fst.Properties(kFstProperties, false); SetProperties(FactorWeightProperties(props), kCopyProperties); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); if (mode_ == 0) LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: " << "factoring neither arc weights nor final weights."; } FactorWeightFstImpl(const FactorWeightFstImpl &impl) : CacheImpl(impl), fst_(impl.fst_->Copy(true)), delta_(impl.delta_), mode_(impl.mode_), final_ilabel_(impl.final_ilabel_), final_olabel_(impl.final_olabel_) { SetType("factor_weight"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); } ~FactorWeightFstImpl() { delete fst_; } StateId Start() { if (!HasStart()) { StateId s = fst_->Start(); if (s == kNoStateId) return kNoStateId; StateId start = FindState(Element(fst_->Start(), Weight::One())); SetStart(start); } return CacheImpl::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) { const Element &e = elements_[s]; // TODO: fix so cast is unnecessary Weight w = e.state == kNoStateId ? e.weight : (Weight) Times(e.weight, fst_->Final(e.state)); FactorIterator f(w); if (!(mode_ & kFactorFinalWeights) || f.Done()) SetFinal(s, w); else SetFinal(s, Weight::Zero()); } return CacheImpl::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumOutputEpsilons(s); } uint64 Properties() const { return Properties(kFstProperties); } // Set error if found; return FST impl properties. uint64 Properties(uint64 mask) const { if ((mask & kError) && fst_->Properties(kError, false)) SetProperties(kError, kError); return FstImpl::Properties(mask); } void InitArcIterator(StateId s, ArcIteratorData *data) { if (!HasArcs(s)) Expand(s); CacheImpl::InitArcIterator(s, data); } // Find state corresponding to an element. Create new state // if element not found. StateId FindState(const Element &e) { if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) { while (unfactored_.size() <= e.state) unfactored_.push_back(kNoStateId); if (unfactored_[e.state] == kNoStateId) { unfactored_[e.state] = elements_.size(); elements_.push_back(e); } return unfactored_[e.state]; } else { typename ElementMap::iterator eit = element_map_.find(e); if (eit != element_map_.end()) { return (*eit).second; } else { StateId s = elements_.size(); elements_.push_back(e); element_map_.insert(pair(e, s)); return s; } } } // Computes the outgoing transitions from a state, creating new destination // states as needed. void Expand(StateId s) { Element e = elements_[s]; if (e.state != kNoStateId) { for (ArcIterator< Fst > ait(*fst_, e.state); !ait.Done(); ait.Next()) { const A &arc = ait.Value(); Weight w = Times(e.weight, arc.weight); FactorIterator fit(w); if (!(mode_ & kFactorArcWeights) || fit.Done()) { StateId d = FindState(Element(arc.nextstate, Weight::One())); PushArc(s, Arc(arc.ilabel, arc.olabel, w, d)); } else { for (; !fit.Done(); fit.Next()) { const pair &p = fit.Value(); StateId d = FindState(Element(arc.nextstate, p.second.Quantize(delta_))); PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d)); } } } } if ((mode_ & kFactorFinalWeights) && ((e.state == kNoStateId) || (fst_->Final(e.state) != Weight::Zero()))) { Weight w = e.state == kNoStateId ? e.weight : Times(e.weight, fst_->Final(e.state)); for (FactorIterator fit(w); !fit.Done(); fit.Next()) { const pair &p = fit.Value(); StateId d = FindState(Element(kNoStateId, p.second.Quantize(delta_))); PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d)); } } SetArcs(s); } private: static const size_t kPrime = 7853; // Equality function for Elements, assume weights have been quantized. class ElementEqual { public: bool operator()(const Element &x, const Element &y) const { return x.state == y.state && x.weight == y.weight; } }; // Hash function for Elements to Fst states. class ElementKey { public: size_t operator()(const Element &x) const { return static_cast(x.state * kPrime + x.weight.Hash()); } private: }; typedef unordered_map ElementMap; const Fst *fst_; float delta_; uint32 mode_; // factoring arc and/or final weights Label final_ilabel_; // ilabel of arc created when factoring final w's Label final_olabel_; // olabel of arc created when factoring final w's vector elements_; // mapping Fst state to Elements ElementMap element_map_; // mapping Elements to Fst state // mapping between old/new 'StateId' for states that do not need to // be factored when 'mode_' is '0' or 'kFactorFinalWeights' vector unfactored_; void operator=(const FactorWeightFstImpl &); // disallow }; template const size_t FactorWeightFstImpl::kPrime; // FactorWeightFst takes as template parameter a FactorIterator as // defined above. The result of weight factoring is a transducer // equivalent to the input whose path weights have been factored // according to the FactorIterator. States and transitions will be // added as necessary. The algorithm is a generalization to arbitrary // weights of the second step of the input epsilon-normalization // algorithm due to Mohri, "Generic epsilon-removal and input // epsilon-normalization algorithms for weighted transducers", // International Journal of Computer Science 13(1): 129-143 (2002). // // This class attaches interface to implementation and handles // reference counting, delegating most methods to ImplToFst. template class FactorWeightFst : public ImplToFst< FactorWeightFstImpl > { public: friend class ArcIterator< FactorWeightFst >; friend class StateIterator< FactorWeightFst >; typedef A Arc; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState State; typedef FactorWeightFstImpl Impl; FactorWeightFst(const Fst &fst) : ImplToFst(new Impl(fst, FactorWeightOptions())) {} FactorWeightFst(const Fst &fst, const FactorWeightOptions &opts) : ImplToFst(new Impl(fst, opts)) {} // See Fst<>::Copy() for doc. FactorWeightFst(const FactorWeightFst &fst, bool copy) : ImplToFst(fst, copy) {} // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc. virtual FactorWeightFst *Copy(bool copy = false) const { return new FactorWeightFst(*this, copy); } virtual inline void InitStateIterator(StateIteratorData *data) const; virtual void InitArcIterator(StateId s, ArcIteratorData *data) const { GetImpl()->InitArcIterator(s, data); } private: // Makes visible to friends. Impl *GetImpl() const { return ImplToFst::GetImpl(); } void operator=(const FactorWeightFst &fst); // Disallow }; // Specialization for FactorWeightFst. template class StateIterator< FactorWeightFst > : public CacheStateIterator< FactorWeightFst > { public: explicit StateIterator(const FactorWeightFst &fst) : CacheStateIterator< FactorWeightFst >(fst, fst.GetImpl()) {} }; // Specialization for FactorWeightFst. template class ArcIterator< FactorWeightFst > : public CacheArcIterator< FactorWeightFst > { public: typedef typename A::StateId StateId; ArcIterator(const FactorWeightFst &fst, StateId s) : CacheArcIterator< FactorWeightFst >(fst.GetImpl(), s) { if (!fst.GetImpl()->HasArcs(s)) fst.GetImpl()->Expand(s); } private: DISALLOW_COPY_AND_ASSIGN(ArcIterator); }; template inline void FactorWeightFst::InitStateIterator(StateIteratorData *data) const { data->base = new StateIterator< FactorWeightFst >(*this); } } // namespace fst #endif // FST_LIB_FACTOR_WEIGHT_H__