// lookahead-matcher.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: riley@google.com (Michael Riley) // // \file // Classes to add lookahead to FST matchers, useful e.g. for improving // composition efficiency with certain inputs. #ifndef FST_LIB_LOOKAHEAD_MATCHER_H__ #define FST_LIB_LOOKAHEAD_MATCHER_H__ #include #include #include #include #include DECLARE_string(save_relabel_ipairs); DECLARE_string(save_relabel_opairs); namespace fst { // LOOKAHEAD MATCHERS - these have the interface of Matchers (see // matcher.h) and these additional methods: // // template // class LookAheadMatcher { // public: // typedef F FST; // typedef F::Arc Arc; // typedef typename Arc::StateId StateId; // typedef typename Arc::Label Label; // typedef typename Arc::Weight Weight; // // // Required constructors. // LookAheadMatcher(const F &fst, MatchType match_type); // // If safe=true, the copy is thread-safe (except the lookahead Fst is // // preserved). See Fst<>::Cop() for further doc. // LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false); // // Below are methods for looking ahead for a match to a label and // more generally, to a rational set. Each returns false if there is // definitely not a match and returns true if there possibly is a // match. // // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state // // after possibly following epsilon transitions? 
// bool LookAheadLabel(Label label) const; // // // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an // // arbitrary rational set of strings, specified by an FST and a state // // from which to begin the matching. If the lookahead FST is a // // transducer, this looks on the side different from the matcher // // 'match_type' (cf. composition). // // // Are there paths P from 's' in the lookahead FST that can be read from // // the cur. matcher state? // bool LookAheadFst(const Fst& fst, StateId s); // // // Gives an estimate of the combined weight of the paths P in the // // lookahead and matcher FSTs for the last call to LookAheadFst. // // A trivial implementation returns Weight::One(). Non-trivial // // implementations are useful for weight-pushing in composition. // Weight LookAheadWeight() const; // // // Is there is a single non-epsilon arc found in the lookahead FST // // that begins P (after possibly following any epsilons) in the last // // call LookAheadFst? If so, return true and copy it to '*arc', o.w. // // return false. A trivial implementation returns false. Non-trivial // // implementations are useful for label-pushing in composition. // bool LookAheadPrefix(Arc *arc); // // // Optionally pre-specifies the lookahead FST that will be passed // // to LookAheadFst() for possible precomputation. If copy is true, // // then 'fst' is a copy of the FST used in the previous call to // // this method (useful to avoid unnecessary updates). // void InitLookAheadFst(const Fst& fst, bool copy = false); // // }; // // LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h): // // Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT. const uint32 kInputLookAheadMatcher = 0x00000010; // Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT. const uint32 kOutputLookAheadMatcher = 0x00000020; // A non-trivial implementation of LookAheadWeight() method defined and // should be used? 
const uint32 kLookAheadWeight = 0x00000040;

// A non-trivial implementation of LookAheadPrefix() method defined and
// should be used?
const uint32 kLookAheadPrefix = 0x00000080;

// Look-ahead of matcher FST non-epsilon arcs?
const uint32 kLookAheadNonEpsilons = 0x00000100;

// Look-ahead of matcher FST epsilon arcs?
const uint32 kLookAheadEpsilons = 0x00000200;

// Ignore epsilon paths for the lookahead prefix? Note this gives
// correct results in composition only with an appropriate composition
// filter since it depends on the filter blocking the ignored paths.
const uint32 kLookAheadNonEpsilonPrefix = 0x00000400;

// For LabelLookAheadMatcher, save relabeling data to file
const uint32 kLookAheadKeepRelabelData = 0x00000800;

// Flags used for lookahead matchers.
// (0x00000ff0 is the bitwise OR of every look-ahead flag bit, 0x10-0x800.)
const uint32 kLookAheadFlags = 0x00000ff0;

// NOTE(review): in this copy of the file the angle-bracketed template text
// appears to have been stripped. Template parameter lists are missing
// (e.g. "template class LookAheadMatcherBase" refers to an undeclared 'A'),
// as are template arguments (MatcherBase, Fst, LookAheadMatcherBase,
// static_cast, vector, ...). Restore them from upstream OpenFst before
// compiling.

// LookAhead Matcher interface, templated on the Arc definition; used
// for lookahead matcher specializations that are returned by the
// InitMatcher() Fst method.
//
// Holds the look-ahead weight and prefix arc; derived classes update them
// through the protected setters and implement the private pure virtuals.
template class LookAheadMatcherBase : public MatcherBase {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;

  // Starts with weight One() and no prefix arc (a prefix arc whose
  // nextstate is kNoStateId means "no prefix").
  LookAheadMatcherBase()
      : weight_(Weight::One()),
        prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {}

  virtual ~LookAheadMatcherBase() {}

  // Forwards to the derived class's LookAheadLabel_().
  bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); }

  // Forwards to the derived class's LookAheadFst_().
  bool LookAheadFst(const Fst &fst, StateId s) {
    return LookAheadFst_(fst, s);
  }

  // Returns the weight last stored with SetLookAheadWeight() (One()
  // initially).
  Weight LookAheadWeight() const { return weight_; }

  // If a prefix arc is stored, copies it to '*arc' and returns true;
  // otherwise returns false and leaves '*arc' untouched.
  bool LookAheadPrefix(Arc *arc) const {
    if (prefix_arc_.nextstate != kNoStateId) {
      *arc = prefix_arc_;
      return true;
    } else {
      return false;
    }
  }

  virtual void InitLookAheadFst(const Fst& fst, bool copy = false) = 0;

 protected:
  void SetLookAheadWeight(const Weight &w) { weight_ = w; }

  void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; }

  // Only nextstate is reset: it is the field LookAheadPrefix() tests.
  void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }

 private:
  virtual bool LookAheadLabel_(Label label) const = 0;
  // This must set the look-ahead weight and prefix if non-trivial.
  virtual bool LookAheadFst_(const Fst &fst,
                             StateId s) = 0;

  Weight weight_;    // Look-ahead weight
  Arc prefix_arc_;   // Look-ahead prefix arc
};

// Don't really lookahead, just declare future looks good regardless.
//
// Matching is delegated to the wrapped matcher M. Every look-ahead query
// answers "possibly" (true), the look-ahead weight is always One(), and
// there is never a prefix arc. Flags() advertises both
// kInputLookAheadMatcher and kOutputLookAheadMatcher.
template class TrivialLookAheadMatcher : public LookAheadMatcherBase {
 public:
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
      : matcher_(fst, match_type) {}

  TrivialLookAheadMatcher(const TrivialLookAheadMatcher &lmatcher,
                          bool safe = false)
      : matcher_(lmatcher.matcher_, safe) {}

  // General matcher methods
  TrivialLookAheadMatcher *Copy(bool safe = false) const {
    return new TrivialLookAheadMatcher(*this, safe);
  }

  MatchType Type(bool test) const { return matcher_.Type(test); }
  void SetState(StateId s) { return matcher_.SetState(s); }
  bool Find(Label label) { return matcher_.Find(label); }
  bool Done() const { return matcher_.Done(); }
  const Arc& Value() const { return matcher_.Value(); }
  void Next() { matcher_.Next(); }
  virtual const FST &GetFst() const { return matcher_.GetFst(); }
  uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
  uint32 Flags() const {
    return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
  }

  // Look-ahead methods.
  bool LookAheadLabel(Label label) const { return true;  }
  bool LookAheadFst(const Fst &fst, StateId s) {return true; }
  Weight LookAheadWeight() const { return Weight::One(); }
  bool LookAheadPrefix(Arc *arc) const { return false; }
  void InitLookAheadFst(const Fst& fst, bool copy = false) {}

 private:
  // This allows base class virtual access to non-virtual derived-
  // class members of the same name. It makes the derived class more
  // efficient to use but unsafe to further derive.
  virtual void SetState_(StateId s) { SetState(s); }
  virtual bool Find_(Label label) { return Find(label); }
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }

  bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }

  bool LookAheadFst_(const Fst &fst, StateId s) {
    return LookAheadFst(fst, s);
  }

  // NOTE(review): LookAheadMatcherBase declares no LookAheadWeight_() or
  // LookAheadPrefix_() virtuals, so these two private forwarders override
  // nothing and nothing in this class calls them; they look like dead code.
  Weight LookAheadWeight_() const { return LookAheadWeight(); }
  bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); }

  M matcher_;
};

// Look-ahead of one transition. Template argument F accepts flags to
// control behavior.
template class ArcLookAheadMatcher : public LookAheadMatcherBase {
 public:
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;
  typedef NullAddOn MatcherData;

  using LookAheadMatcherBase::LookAheadWeight;
  using LookAheadMatcherBase::SetLookAheadPrefix;
  using LookAheadMatcherBase::SetLookAheadWeight;
  using LookAheadMatcherBase::ClearLookAheadPrefix;

  // 'data' is accepted for interface uniformity but is not used.
  ArcLookAheadMatcher(const FST &fst, MatchType match_type,
                      MatcherData *data = 0)
      : matcher_(fst, match_type),
        fst_(matcher_.GetFst()),
        lfst_(0),
        s_(kNoStateId) {}

  // Shares 'lmatcher's look-ahead FST pointer; the matcher state is reset
  // to kNoStateId.
  ArcLookAheadMatcher(const ArcLookAheadMatcher &lmatcher,
                      bool safe = false)
      : matcher_(lmatcher.matcher_, safe),
        fst_(matcher_.GetFst()),
        lfst_(lmatcher.lfst_),
        s_(kNoStateId) {}

  // General matcher methods
  ArcLookAheadMatcher *Copy(bool safe = false) const {
    return new ArcLookAheadMatcher(*this, safe);
  }

  MatchType Type(bool test) const { return matcher_.Type(test); }

  // Remembers 's' (LookAheadFst() reads it as s_) and positions the
  // wrapped matcher there.
  void SetState(StateId s) {
    s_ = s;
    matcher_.SetState(s);
  }

  bool Find(Label label) { return matcher_.Find(label); }
  bool Done() const { return matcher_.Done(); }
  const Arc& Value() const { return matcher_.Value(); }
  void Next() { matcher_.Next(); }
  const FST &GetFst() const { return fst_; }
  uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
  uint32 Flags() const {
    return matcher_.Flags() | kInputLookAheadMatcher |
        kOutputLookAheadMatcher | F;
  }

  // Writable matcher methods
  MatcherData *GetData() const { return 0; }

  // Look-ahead methods.
  // matcher_ is declared mutable so Find() can be called from this const
  // method.
  bool LookAheadLabel(Label label) const { return matcher_.Find(label); }

  // Checks if there is a matching (possibly super-final) transition
  // at (s_, s).
  bool LookAheadFst(const Fst &fst, StateId s);

  // Only remembers 'fst'; 'copy' is ignored.
  void InitLookAheadFst(const Fst& fst, bool copy = false) {
    lfst_ = &fst;
  }

 private:
  // This allows base class virtual access to non-virtual derived-
  // class members of the same name. It makes the derived class more
  // efficient to use but unsafe to further derive.
  virtual void SetState_(StateId s) { SetState(s); }
  virtual bool Find_(Label label) { return Find(label); }
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }

  bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }

  bool LookAheadFst_(const Fst &fst, StateId s) {
    return LookAheadFst(fst, s);
  }

  mutable M matcher_;
  const FST &fst_;         // Matcher FST
  const Fst *lfst_;        // Look-ahead FST
  StateId s_;              // Matcher state
};

// Returns whether anything leaving look-ahead state 's' can be matched one
// transition ahead from matcher state s_. Without kLookAheadWeight or
// kLookAheadPrefix in F, it returns true at the first possible match.
// Otherwise it visits every candidate:
//  - with kLookAheadWeight it accumulates the combined weight;
//  - with kLookAheadPrefix it keeps the prefix arc only when exactly one
//    candidate path was counted.
template
bool ArcLookAheadMatcher::LookAheadFst(const Fst &fst, StateId s) {
  if (&fst != lfst_)
    InitLookAheadFst(fst);

  bool ret = false;
  ssize_t nprefix = 0;  // Number of candidate prefix paths counted so far.
  if (F & kLookAheadWeight)
    SetLookAheadWeight(Weight::Zero());
  if (F & kLookAheadPrefix)
    ClearLookAheadPrefix();
  // Both states are final: a match through the super-final transition.
  if (fst_.Final(s_) != Weight::Zero() &&
      lfst_->Final(s) != Weight::Zero()) {
    if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
      return true;
    ++nprefix;
    if (F & kLookAheadWeight)
      SetLookAheadWeight(Plus(LookAheadWeight(),
                              Times(fst_.Final(s_), lfst_->Final(s))));
    ret = true;
  }
  // Transitions the matcher returns for kNoLabel also count as a match
  // (presumably its implicit epsilon transitions; see matcher.h).
  if (matcher_.Find(kNoLabel)) {
    if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
      return true;
    ++nprefix;
    if (F & kLookAheadWeight)
      for (; !matcher_.Done(); matcher_.Next())
        SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
    ret = true;
  }
  for (ArcIterator< Fst > aiter(*lfst_, s);
       !aiter.Done();
       aiter.Next()) {
    const Arc &arc = aiter.Value();
    Label label = kNoLabel;
    // Read the look-ahead arc on the side opposite to the match type.
    switch (matcher_.Type(false)) {
      case MATCH_INPUT:
        label = arc.olabel;
        break;
      case MATCH_OUTPUT:
        label = arc.ilabel;
        break;
      default:
        FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type";
        return true;
    }
    if (label == 0) {
      // Epsilon on the look-ahead side: matches without a matcher arc. It
      // is not counted toward the prefix under kLookAheadNonEpsilonPrefix.
      if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
        return true;
      if (!(F & kLookAheadNonEpsilonPrefix))
        ++nprefix;
      if (F & kLookAheadWeight)
        SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
      ret = true;
    } else if (matcher_.Find(label)) {
      if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
        return true;
      for (; !matcher_.Done(); matcher_.Next()) {
        ++nprefix;
        if (F & kLookAheadWeight)
          SetLookAheadWeight(Plus(LookAheadWeight(),
                                  Times(arc.weight,
                                        matcher_.Value().weight)));
        if ((F & kLookAheadPrefix) && nprefix == 1)
          SetLookAheadPrefix(arc);
      }
      ret = true;
    }
  }
  // A prefix arc is reported only if exactly one candidate path was counted.
  if (F & kLookAheadPrefix) {
    if (nprefix == 1)
      SetLookAheadWeight(Weight::One());  // Avoids double counting.
    else
      ClearLookAheadPrefix();
  }
  return ret;
}

// Template argument F accepts flags to control behavior.
// It must include precisely one of kInputLookAheadMatcher or
// kOutputLookAheadMatcher.
template > class LabelLookAheadMatcher : public LookAheadMatcherBase {
 public:
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;
  // NOTE(review): this copy of the file is truncated/corrupted here. The
  // declaration below runs straight from LabelLookAheadMatcher's typedefs
  // into the parameter list of an unnamed relabeling member function (its
  // body ends in reachable.Relabel(...)). The rest of LabelLookAheadMatcher
  // and the opening of the class that follows (presumably
  // LabelLookAheadRelabeler) are missing. Restore them from upstream OpenFst.
  typedef LabelReachableData *fst, const L &mfst, bool relabel_input) {
    typename L::Impl *impl = mfst.GetImpl();
    D *data = impl->GetAddOn();
    // Uses whichever reachability data is present, preferring First().
    LabelReachable reachable(data->First() ?
                             data->First() : data->Second());
    reachable.Relabel(fst, relabel_input);
  }

  // Returns relabeling pairs (cf. relabel.h::Relabel()).
  // Class L should be a label-lookahead Fst.
  // If 'avoid_collisions' is true, extra pairs are added to
  // ensure no collisions when relabeling automata that have
  // labels unseen here.
  template
  static void RelabelPairs(const L &mfst, vector > *pairs,
                           bool avoid_collisions = false) {
    typename L::Impl *impl = mfst.GetImpl();
    D *data = impl->GetAddOn();
    LabelReachable reachable(data->First() ?
                             data->First() : data->Second());
    reachable.RelabelPairs(pairs, avoid_collisions);
  }
};

// Relabels the FST held by '*impl' using the reachability data in its
// add-on. The input-side data (First()) is used if present, otherwise the
// output-side data (Second()).
//
// If the FST is not mutable, a VectorFst copy is relabeled instead, and
// '*impl' is replaced by a new I built from that copy, with the same type
// name and add-on data.
//
// When FLAGS_save_relabel_ipairs / FLAGS_save_relabel_opairs is non-empty,
// the relabeling pairs are also passed to WriteLabelPairs().
template
template
inline LabelLookAheadRelabeler::LabelLookAheadRelabeler(I **impl) {
  // NOTE(review): 'fst' is obtained from '*impl' and is not used after the
  // 'delete *impl' below; keep it that way.
  Fst &fst = (*impl)->GetFst();
  D *data = (*impl)->GetAddOn();
  // Held by value: '*impl' may be deleted below before 'name' is used.
  const string name = (*impl)->Type();
  bool is_mutable = fst.Properties(kMutable, false);
  MutableFst *mfst = 0;
  if (is_mutable) {
    mfst = static_cast *>(&fst);
  } else {
    mfst = new VectorFst(fst);
    // Presumably pins 'data' across 'delete *impl'; released again below,
    // after the new impl has taken it via SetAddOn().
    data->IncrRefCount();
    delete *impl;
  }
  if (data->First()) {  // reach_input
    LabelReachable reachable(data->First());
    reachable.Relabel(mfst, true);
    if (!FLAGS_save_relabel_ipairs.empty()) {
      vector > pairs;
      reachable.RelabelPairs(&pairs, true);
      WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs);
    }
  } else {
    LabelReachable reachable(data->Second());
    reachable.Relabel(mfst, false);
    if (!FLAGS_save_relabel_opairs.empty()) {
      vector > pairs;
      reachable.RelabelPairs(&pairs, true);
      WriteLabelPairs(FLAGS_save_relabel_opairs, pairs);
    }
  }
  if (!is_mutable) {
    *impl = new I(*mfst, name);
    (*impl)->SetAddOn(data);
    delete mfst;
    data->DecrRefCount();
  }
}

// Generic lookahead matcher, templated on the FST definition
// - a wrapper around pointer to specific one.
template class LookAheadMatcher { public: typedef F FST; typedef typename F::Arc Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; typedef LookAheadMatcherBase LBase; LookAheadMatcher(const F &fst, MatchType match_type) { base_ = fst.InitMatcher(match_type); if (!base_) base_ = new SortedMatcher(fst, match_type); lookahead_ = false; } LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false) { base_ = matcher.base_->Copy(safe); lookahead_ = matcher.lookahead_; } ~LookAheadMatcher() { delete base_; } // General matcher methods LookAheadMatcher *Copy(bool safe = false) const { return new LookAheadMatcher(*this, safe); } MatchType Type(bool test) const { return base_->Type(test); } void SetState(StateId s) { base_->SetState(s); } bool Find(Label label) { return base_->Find(label); } bool Done() const { return base_->Done(); } const Arc& Value() const { return base_->Value(); } void Next() { base_->Next(); } const F &GetFst() const { return static_cast(base_->GetFst()); } uint64 Properties(uint64 props) const { return base_->Properties(props); } uint32 Flags() const { return base_->Flags(); } // Look-ahead methods bool LookAheadLabel(Label label) const { if (LookAheadCheck()) { LBase *lbase = static_cast(base_); return lbase->LookAheadLabel(label); } else { return true; } } bool LookAheadFst(const Fst &fst, StateId s) { if (LookAheadCheck()) { LBase *lbase = static_cast(base_); return lbase->LookAheadFst(fst, s); } else { return true; } } Weight LookAheadWeight() const { if (LookAheadCheck()) { LBase *lbase = static_cast(base_); return lbase->LookAheadWeight(); } else { return Weight::One(); } } bool LookAheadPrefix(Arc *arc) const { if (LookAheadCheck()) { LBase *lbase = static_cast(base_); return lbase->LookAheadPrefix(arc); } else { return false; } } void InitLookAheadFst(const Fst& fst, bool copy = false) { if (LookAheadCheck()) { LBase *lbase = static_cast(base_); 
lbase->InitLookAheadFst(fst, copy); } } private: bool LookAheadCheck() const { if (!lookahead_) { lookahead_ = base_->Flags() & (kInputLookAheadMatcher | kOutputLookAheadMatcher); if (!lookahead_) { FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined"; } } return lookahead_; } MatcherBase *base_; mutable bool lookahead_; void operator=(const LookAheadMatcher &); // disallow }; } // namespace fst #endif // FST_LIB_LOOKAHEAD_MATCHER_H__