diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h | 812 |
1 files changed, 812 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h b/kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h new file mode 100644 index 0000000..f927d65 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h @@ -0,0 +1,812 @@ +// lookahead-matcher.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes to add lookahead to FST matchers, useful e.g. for improving +// composition efficiency with certain inputs. + +#ifndef FST_LIB_LOOKAHEAD_MATCHER_H__ +#define FST_LIB_LOOKAHEAD_MATCHER_H__ + +#include <fst/add-on.h> +#include <fst/const-fst.h> +#include <fst/fst.h> +#include <fst/label-reachable.h> +#include <fst/matcher.h> + + +DECLARE_string(save_relabel_ipairs); +DECLARE_string(save_relabel_opairs); + +namespace fst { + +// LOOKAHEAD MATCHERS - these have the interface of Matchers (see +// matcher.h) and these additional methods: +// +// template <class F> +// class LookAheadMatcher { +// public: +// typedef F FST; +// typedef F::Arc Arc; +// typedef typename Arc::StateId StateId; +// typedef typename Arc::Label Label; +// typedef typename Arc::Weight Weight; +// +// // Required constructors. +// LookAheadMatcher(const F &fst, MatchType match_type); +// // If safe=true, the copy is thread-safe (except the lookahead Fst is +// // preserved). See Fst<>::Cop() for further doc. +// LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false); +// +// Below are methods for looking ahead for a match to a label and +// more generally, to a rational set. Each returns false if there is +// definitely not a match and returns true if there possibly is a +// match. + +// // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state +// // after possibly following epsilon transitions? +// bool LookAheadLabel(Label label) const; +// +// // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an +// // arbitrary rational set of strings, specified by an FST and a state +// // from which to begin the matching. If the lookahead FST is a +// // transducer, this looks on the side different from the matcher +// // 'match_type' (cf. composition). +// +// // Are there paths P from 's' in the lookahead FST that can be read from +// // the cur. matcher state? +// bool LookAheadFst(const Fst<Arc>& fst, StateId s); +// +// // Gives an estimate of the combined weight of the paths P in the +// // lookahead and matcher FSTs for the last call to LookAheadFst. +// // A trivial implementation returns Weight::One(). Non-trivial +// // implementations are useful for weight-pushing in composition. +// Weight LookAheadWeight() const; +// +// // Is there is a single non-epsilon arc found in the lookahead FST +// // that begins P (after possibly following any epsilons) in the last +// // call LookAheadFst? If so, return true and copy it to '*arc', o.w. +// // return false. A trivial implementation returns false. Non-trivial +// // implementations are useful for label-pushing in composition. +// bool LookAheadPrefix(Arc *arc); +// +// // Optionally pre-specifies the lookahead FST that will be passed +// // to LookAheadFst() for possible precomputation. If copy is true, +// // then 'fst' is a copy of the FST used in the previous call to +// // this method (useful to avoid unnecessary updates). +// void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false); +// +// }; + +// +// LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h): +// +// Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT. +const uint32 kInputLookAheadMatcher = 0x00000010; + +// Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT. +const uint32 kOutputLookAheadMatcher = 0x00000020; + +// A non-trivial implementation of LookAheadWeight() method defined and +// should be used? +const uint32 kLookAheadWeight = 0x00000040; + +// A non-trivial implementation of LookAheadPrefix() method defined and +// should be used? +const uint32 kLookAheadPrefix = 0x00000080; + +// Look-ahead of matcher FST non-epsilon arcs? +const uint32 kLookAheadNonEpsilons = 0x00000100; + +// Look-ahead of matcher FST epsilon arcs? +const uint32 kLookAheadEpsilons = 0x00000200; + +// Ignore epsilon paths for the lookahead prefix? Note this gives +// correct results in composition only with an appropriate composition +// filter since it depends on the filter blocking the ignored paths. +const uint32 kLookAheadNonEpsilonPrefix = 0x00000400; + +// For LabelLookAheadMatcher, save relabeling data to file +const uint32 kLookAheadKeepRelabelData = 0x00000800; + +// Flags used for lookahead matchers. +const uint32 kLookAheadFlags = 0x00000ff0; + +// LookAhead Matcher interface, templated on the Arc definition; used +// for lookahead matcher specializations that are returned by the +// InitMatcher() Fst method. +template <class A> +class LookAheadMatcherBase : public MatcherBase<A> { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + LookAheadMatcherBase() + : weight_(Weight::One()), + prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {} + + virtual ~LookAheadMatcherBase() {} + + bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); } + + bool LookAheadFst(const Fst<Arc> &fst, StateId s) { + return LookAheadFst_(fst, s); + } + + Weight LookAheadWeight() const { return weight_; } + + bool LookAheadPrefix(Arc *arc) const { + if (prefix_arc_.nextstate != kNoStateId) { + *arc = prefix_arc_; + return true; + } else { + return false; + } + } + + virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0; + + protected: + void SetLookAheadWeight(const Weight &w) { weight_ = w; } + + void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; } + + void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; } + + private: + virtual bool LookAheadLabel_(Label label) const = 0; + virtual bool LookAheadFst_(const Fst<Arc> &fst, + StateId s) = 0; // This must set l.a. weight and + // prefix if non-trivial. + Weight weight_; // Look-ahead weight + Arc prefix_arc_; // Look-ahead prefix arc +}; + + +// Don't really lookahead, just declare future looks good regardless. +template <class M> +class TrivialLookAheadMatcher + : public LookAheadMatcherBase<typename M::FST::Arc> { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + TrivialLookAheadMatcher(const FST &fst, MatchType match_type) + : matcher_(fst, match_type) {} + + TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher, + bool safe = false) + : matcher_(lmatcher.matcher_, safe) {} + + // General matcher methods + TrivialLookAheadMatcher<M> *Copy(bool safe = false) const { + return new TrivialLookAheadMatcher<M>(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + void SetState(StateId s) { return matcher_.SetState(s); } + bool Find(Label label) { return matcher_.Find(label); } + bool Done() const { return matcher_.Done(); } + const Arc& Value() const { return matcher_.Value(); } + void Next() { matcher_.Next(); } + virtual const FST &GetFst() const { return matcher_.GetFst(); } + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + uint32 Flags() const { + return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher; + } + + // Look-ahead methods. + bool LookAheadLabel(Label label) const { return true; } + bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; } + Weight LookAheadWeight() const { return Weight::One(); } + bool LookAheadPrefix(Arc *arc) const { return false; } + void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {} + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } + + bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { + return LookAheadFst(fst, s); + } + + Weight LookAheadWeight_() const { return LookAheadWeight(); } + bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); } + + M matcher_; +}; + +// Look-ahead of one transition. Template argument F accepts flags to +// control behavior. +template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons | + kLookAheadWeight | kLookAheadPrefix> +class ArcLookAheadMatcher + : public LookAheadMatcherBase<typename M::FST::Arc> { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef NullAddOn MatcherData; + + using LookAheadMatcherBase<Arc>::LookAheadWeight; + using LookAheadMatcherBase<Arc>::SetLookAheadPrefix; + using LookAheadMatcherBase<Arc>::SetLookAheadWeight; + using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix; + + ArcLookAheadMatcher(const FST &fst, MatchType match_type, + MatcherData *data = 0) + : matcher_(fst, match_type), + fst_(matcher_.GetFst()), + lfst_(0), + s_(kNoStateId) {} + + ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher, + bool safe = false) + : matcher_(lmatcher.matcher_, safe), + fst_(matcher_.GetFst()), + lfst_(lmatcher.lfst_), + s_(kNoStateId) {} + + // General matcher methods + ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const { + return new ArcLookAheadMatcher<M, F>(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + + void SetState(StateId s) { + s_ = s; + matcher_.SetState(s); + } + + bool Find(Label label) { return matcher_.Find(label); } + bool Done() const { return matcher_.Done(); } + const Arc& Value() const { return matcher_.Value(); } + void Next() { matcher_.Next(); } + const FST &GetFst() const { return fst_; } + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + uint32 Flags() const { + return matcher_.Flags() | kInputLookAheadMatcher | + kOutputLookAheadMatcher | F; + } + + // Writable matcher methods + MatcherData *GetData() const { return 0; } + + // Look-ahead methods. + bool LookAheadLabel(Label label) const { return matcher_.Find(label); } + + // Checks if there is a matching (possibly super-final) transition + // at (s_, s). + bool LookAheadFst(const Fst<Arc> &fst, StateId s); + + void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { + lfst_ = &fst; + } + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } + bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { + return LookAheadFst(fst, s); + } + + mutable M matcher_; + const FST &fst_; // Matcher FST + const Fst<Arc> *lfst_; // Look-ahead FST + StateId s_; // Matcher state +}; + +template <class M, uint32 F> +bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) { + if (&fst != lfst_) + InitLookAheadFst(fst); + + bool ret = false; + ssize_t nprefix = 0; + if (F & kLookAheadWeight) + SetLookAheadWeight(Weight::Zero()); + if (F & kLookAheadPrefix) + ClearLookAheadPrefix(); + if (fst_.Final(s_) != Weight::Zero() && + lfst_->Final(s) != Weight::Zero()) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + ++nprefix; + if (F & kLookAheadWeight) + SetLookAheadWeight(Plus(LookAheadWeight(), + Times(fst_.Final(s_), lfst_->Final(s)))); + ret = true; + } + if (matcher_.Find(kNoLabel)) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + ++nprefix; + if (F & kLookAheadWeight) + for (; !matcher_.Done(); matcher_.Next()) + SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight)); + ret = true; + } + for (ArcIterator< Fst<Arc> > aiter(*lfst_, s); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + Label label = kNoLabel; + switch (matcher_.Type(false)) { + case MATCH_INPUT: + label = arc.olabel; + break; + case MATCH_OUTPUT: + label = arc.ilabel; + break; + default: + FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type"; + return true; + } + if (label == 0) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + if (!(F & kLookAheadNonEpsilonPrefix)) + ++nprefix; + if (F & kLookAheadWeight) + SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight)); + ret = true; + } else if (matcher_.Find(label)) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + for (; !matcher_.Done(); matcher_.Next()) { + ++nprefix; + if (F & kLookAheadWeight) + SetLookAheadWeight(Plus(LookAheadWeight(), + Times(arc.weight, + matcher_.Value().weight))); + if ((F & kLookAheadPrefix) && nprefix == 1) + SetLookAheadPrefix(arc); + } + ret = true; + } + } + if (F & kLookAheadPrefix) { + if (nprefix == 1) + SetLookAheadWeight(Weight::One()); // Avoids double counting. + else + ClearLookAheadPrefix(); + } + return ret; +} + + +// Template argument F accepts flags to control behavior. +// It must include precisely one of KInputLookAheadMatcher or +// KOutputLookAheadMatcher. +template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight | + kLookAheadPrefix | kLookAheadNonEpsilonPrefix | + kLookAheadKeepRelabelData, + class S = DefaultAccumulator<typename M::Arc> > +class LabelLookAheadMatcher + : public LookAheadMatcherBase<typename M::FST::Arc> { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef LabelReachableData<Label> MatcherData; + + using LookAheadMatcherBase<Arc>::LookAheadWeight; + using LookAheadMatcherBase<Arc>::SetLookAheadPrefix; + using LookAheadMatcherBase<Arc>::SetLookAheadWeight; + using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix; + + LabelLookAheadMatcher(const FST &fst, MatchType match_type, + MatcherData *data = 0, S *s = 0) + : matcher_(fst, match_type), + lfst_(0), + label_reachable_(0), + s_(kNoStateId), + error_(false) { + if (!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) { + FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F; + error_ = true; + } + bool reach_input = match_type == MATCH_INPUT; + if (data) { + if (reach_input == data->ReachInput()) + label_reachable_ = new LabelReachable<Arc, S>(data, s); + } else if ((reach_input && (F & kInputLookAheadMatcher)) || + (!reach_input && (F & kOutputLookAheadMatcher))) { + label_reachable_ = new LabelReachable<Arc, S>( + fst, reach_input, s, F & kLookAheadKeepRelabelData); + } + } + + LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher, + bool safe = false) + : matcher_(lmatcher.matcher_, safe), + lfst_(lmatcher.lfst_), + label_reachable_( + lmatcher.label_reachable_ ? + new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0), + s_(kNoStateId), + error_(lmatcher.error_) {} + + ~LabelLookAheadMatcher() { + delete label_reachable_; + } + + // General matcher methods + LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const { + return new LabelLookAheadMatcher<M, F, S>(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + + void SetState(StateId s) { + if (s_ == s) + return; + s_ = s; + match_set_state_ = false; + reach_set_state_ = false; + } + + bool Find(Label label) { + if (!match_set_state_) { + matcher_.SetState(s_); + match_set_state_ = true; + } + return matcher_.Find(label); + } + + bool Done() const { return matcher_.Done(); } + const Arc& Value() const { return matcher_.Value(); } + void Next() { matcher_.Next(); } + const FST &GetFst() const { return matcher_.GetFst(); } + + uint64 Properties(uint64 inprops) const { + uint64 outprops = matcher_.Properties(inprops); + if (error_ || (label_reachable_ && label_reachable_->Error())) + outprops |= kError; + return outprops; + } + + uint32 Flags() const { + if (label_reachable_ && label_reachable_->GetData()->ReachInput()) + return matcher_.Flags() | F | kInputLookAheadMatcher; + else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) + return matcher_.Flags() | F | kOutputLookAheadMatcher; + else + return matcher_.Flags(); + } + + // Writable matcher methods + MatcherData *GetData() const { + return label_reachable_ ? label_reachable_->GetData() : 0; + }; + + // Look-ahead methods. + bool LookAheadLabel(Label label) const { + if (label == 0) + return true; + + if (label_reachable_) { + if (!reach_set_state_) { + label_reachable_->SetState(s_); + reach_set_state_ = true; + } + return label_reachable_->Reach(label); + } else { + return true; + } + } + + // Checks if there is a matching (possibly super-final) transition + // at (s_, s). + template <class L> + bool LookAheadFst(const L &fst, StateId s); + + void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { + lfst_ = &fst; + if (label_reachable_) + label_reachable_->ReachInit(fst, copy); + } + + template <class L> + void InitLookAheadFst(const L& fst, bool copy = false) { + lfst_ = static_cast<const Fst<Arc> *>(&fst); + if (label_reachable_) + label_reachable_->ReachInit(fst, copy); + } + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } + bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { + return LookAheadFst(fst, s); + } + + mutable M matcher_; + const Fst<Arc> *lfst_; // Look-ahead FST + LabelReachable<Arc, S> *label_reachable_; // Label reachability info + StateId s_; // Matcher state + bool match_set_state_; // matcher_.SetState called? + mutable bool reach_set_state_; // reachable_.SetState called? + bool error_; +}; + +template <class M, uint32 F, class S> +template <class L> inline +bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) { + if (static_cast<const Fst<Arc> *>(&fst) != lfst_) + InitLookAheadFst(fst); + + SetLookAheadWeight(Weight::One()); + ClearLookAheadPrefix(); + + if (!label_reachable_) + return true; + + label_reachable_->SetState(s_, s); + reach_set_state_ = true; + + bool compute_weight = F & kLookAheadWeight; + bool compute_prefix = F & kLookAheadPrefix; + + bool reach_input = Type(false) == MATCH_OUTPUT; + ArcIterator<L> aiter(fst, s); + bool reach_arc = label_reachable_->Reach(&aiter, 0, + internal::NumArcs(*lfst_, s), + reach_input, compute_weight); + Weight lfinal = internal::Final(*lfst_, s); + bool reach_final = lfinal != Weight::Zero() && label_reachable_->ReachFinal(); + if (reach_arc) { + ssize_t begin = label_reachable_->ReachBegin(); + ssize_t end = label_reachable_->ReachEnd(); + if (compute_prefix && end - begin == 1 && !reach_final) { + aiter.Seek(begin); + SetLookAheadPrefix(aiter.Value()); + compute_weight = false; + } else if (compute_weight) { + SetLookAheadWeight(label_reachable_->ReachWeight()); + } + } + if (reach_final && compute_weight) + SetLookAheadWeight(reach_arc ? + Plus(LookAheadWeight(), lfinal) : lfinal); + + return reach_arc || reach_final; +} + + +// Label-lookahead relabeling class. +template <class A> +class LabelLookAheadRelabeler { + public: + typedef typename A::Label Label; + typedef LabelReachableData<Label> MatcherData; + typedef AddOnPair<MatcherData, MatcherData> D; + + // Relabels matcher Fst - initialization function object. + template <typename I> + LabelLookAheadRelabeler(I **impl); + + // Relabels arbitrary Fst. Class L should be a label-lookahead Fst. + template <class L> + static void Relabel(MutableFst<A> *fst, const L &mfst, + bool relabel_input) { + typename L::Impl *impl = mfst.GetImpl(); + D *data = impl->GetAddOn(); + LabelReachable<A> reachable(data->First() ? + data->First() : data->Second()); + reachable.Relabel(fst, relabel_input); + } + + // Returns relabeling pairs (cf. relabel.h::Relabel()). + // Class L should be a label-lookahead Fst. + // If 'avoid_collisions' is true, extra pairs are added to + // ensure no collisions when relabeling automata that have + // labels unseen here. + template <class L> + static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs, + bool avoid_collisions = false) { + typename L::Impl *impl = mfst.GetImpl(); + D *data = impl->GetAddOn(); + LabelReachable<A> reachable(data->First() ? + data->First() : data->Second()); + reachable.RelabelPairs(pairs, avoid_collisions); + } +}; + +template <class A> +template <typename I> inline +LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) { + Fst<A> &fst = (*impl)->GetFst(); + D *data = (*impl)->GetAddOn(); + const string name = (*impl)->Type(); + bool is_mutable = fst.Properties(kMutable, false); + MutableFst<A> *mfst = 0; + if (is_mutable) { + mfst = static_cast<MutableFst<A> *>(&fst); + } else { + mfst = new VectorFst<A>(fst); + data->IncrRefCount(); + delete *impl; + } + if (data->First()) { // reach_input + LabelReachable<A> reachable(data->First()); + reachable.Relabel(mfst, true); + if (!FLAGS_save_relabel_ipairs.empty()) { + vector<pair<Label, Label> > pairs; + reachable.RelabelPairs(&pairs, true); + WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs); + } + } else { + LabelReachable<A> reachable(data->Second()); + reachable.Relabel(mfst, false); + if (!FLAGS_save_relabel_opairs.empty()) { + vector<pair<Label, Label> > pairs; + reachable.RelabelPairs(&pairs, true); + WriteLabelPairs(FLAGS_save_relabel_opairs, pairs); + } + } + if (!is_mutable) { + *impl = new I(*mfst, name); + (*impl)->SetAddOn(data); + delete mfst; + data->DecrRefCount(); + } +} + + +// Generic lookahead matcher, templated on the FST definition +// - a wrapper around pointer to specific one. +template <class F> +class LookAheadMatcher { + public: + typedef F FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef LookAheadMatcherBase<Arc> LBase; + + LookAheadMatcher(const F &fst, MatchType match_type) { + base_ = fst.InitMatcher(match_type); + if (!base_) + base_ = new SortedMatcher<F>(fst, match_type); + lookahead_ = false; + } + + LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) { + base_ = matcher.base_->Copy(safe); + lookahead_ = matcher.lookahead_; + } + + ~LookAheadMatcher() { delete base_; } + + // General matcher methods + LookAheadMatcher<F> *Copy(bool safe = false) const { + return new LookAheadMatcher<F>(*this, safe); + } + + MatchType Type(bool test) const { return base_->Type(test); } + void SetState(StateId s) { base_->SetState(s); } + bool Find(Label label) { return base_->Find(label); } + bool Done() const { return base_->Done(); } + const Arc& Value() const { return base_->Value(); } + void Next() { base_->Next(); } + const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); } + + uint64 Properties(uint64 props) const { return base_->Properties(props); } + + uint32 Flags() const { return base_->Flags(); } + + // Look-ahead methods + bool LookAheadLabel(Label label) const { + if (LookAheadCheck()) { + LBase *lbase = static_cast<LBase *>(base_); + return lbase->LookAheadLabel(label); + } else { + return true; + } + } + + bool LookAheadFst(const Fst<Arc> &fst, StateId s) { + if (LookAheadCheck()) { + LBase *lbase = static_cast<LBase *>(base_); + return lbase->LookAheadFst(fst, s); + } else { + return true; + } + } + + Weight LookAheadWeight() const { + if (LookAheadCheck()) { + LBase *lbase = static_cast<LBase *>(base_); + return lbase->LookAheadWeight(); + } else { + return Weight::One(); + } + } + + bool LookAheadPrefix(Arc *arc) const { + if (LookAheadCheck()) { + LBase *lbase = static_cast<LBase *>(base_); + return lbase->LookAheadPrefix(arc); + } else { + return false; + } + } + + void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { + if (LookAheadCheck()) { + LBase *lbase = static_cast<LBase *>(base_); + lbase->InitLookAheadFst(fst, copy); + } + } + + private: + bool LookAheadCheck() const { + if (!lookahead_) { + lookahead_ = base_->Flags() & + (kInputLookAheadMatcher | kOutputLookAheadMatcher); + if (!lookahead_) { + FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined"; + } + } + return lookahead_; + } + + MatcherBase<Arc> *base_; + mutable bool lookahead_; + + void operator=(const LookAheadMatcher<Arc> &); // disallow +}; + +} // namespace fst + +#endif // FST_LIB_LOOKAHEAD_MATCHER_H__ |