From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- .../tools/openfst/include/fst/lookahead-matcher.h | 812 +++++++++++++++++++++ 1 file changed, 812 insertions(+) create mode 100644 kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h (limited to 'kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h') diff --git a/kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h b/kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h new file mode 100644 index 0000000..f927d65 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h @@ -0,0 +1,812 @@ +// lookahead-matcher.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Classes to add lookahead to FST matchers, useful e.g. for improving +// composition efficiency with certain inputs. + +#ifndef FST_LIB_LOOKAHEAD_MATCHER_H__ +#define FST_LIB_LOOKAHEAD_MATCHER_H__ + +#include +#include +#include +#include +#include + + +DECLARE_string(save_relabel_ipairs); +DECLARE_string(save_relabel_opairs); + +namespace fst { + +// LOOKAHEAD MATCHERS - these have the interface of Matchers (see +// matcher.h) and these additional methods: +// +// template +// class LookAheadMatcher { +// public: +// typedef F FST; +// typedef F::Arc Arc; +// typedef typename Arc::StateId StateId; +// typedef typename Arc::Label Label; +// typedef typename Arc::Weight Weight; +// +// // Required constructors. +// LookAheadMatcher(const F &fst, MatchType match_type); +// // If safe=true, the copy is thread-safe (except the lookahead Fst is +// // preserved). See Fst<>::Cop() for further doc. +// LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false); +// +// Below are methods for looking ahead for a match to a label and +// more generally, to a rational set. Each returns false if there is +// definitely not a match and returns true if there possibly is a +// match. + +// // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state +// // after possibly following epsilon transitions? +// bool LookAheadLabel(Label label) const; +// +// // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an +// // arbitrary rational set of strings, specified by an FST and a state +// // from which to begin the matching. If the lookahead FST is a +// // transducer, this looks on the side different from the matcher +// // 'match_type' (cf. composition). +// +// // Are there paths P from 's' in the lookahead FST that can be read from +// // the cur. matcher state? +// bool LookAheadFst(const Fst& fst, StateId s); +// +// // Gives an estimate of the combined weight of the paths P in the +// // lookahead and matcher FSTs for the last call to LookAheadFst. +// // A trivial implementation returns Weight::One(). Non-trivial +// // implementations are useful for weight-pushing in composition. +// Weight LookAheadWeight() const; +// +// // Is there is a single non-epsilon arc found in the lookahead FST +// // that begins P (after possibly following any epsilons) in the last +// // call LookAheadFst? If so, return true and copy it to '*arc', o.w. +// // return false. A trivial implementation returns false. Non-trivial +// // implementations are useful for label-pushing in composition. +// bool LookAheadPrefix(Arc *arc); +// +// // Optionally pre-specifies the lookahead FST that will be passed +// // to LookAheadFst() for possible precomputation. If copy is true, +// // then 'fst' is a copy of the FST used in the previous call to +// // this method (useful to avoid unnecessary updates). +// void InitLookAheadFst(const Fst& fst, bool copy = false); +// +// }; + +// +// LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h): +// +// Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT. +const uint32 kInputLookAheadMatcher = 0x00000010; + +// Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT. +const uint32 kOutputLookAheadMatcher = 0x00000020; + +// A non-trivial implementation of LookAheadWeight() method defined and +// should be used? +const uint32 kLookAheadWeight = 0x00000040; + +// A non-trivial implementation of LookAheadPrefix() method defined and +// should be used? +const uint32 kLookAheadPrefix = 0x00000080; + +// Look-ahead of matcher FST non-epsilon arcs? +const uint32 kLookAheadNonEpsilons = 0x00000100; + +// Look-ahead of matcher FST epsilon arcs? +const uint32 kLookAheadEpsilons = 0x00000200; + +// Ignore epsilon paths for the lookahead prefix? Note this gives +// correct results in composition only with an appropriate composition +// filter since it depends on the filter blocking the ignored paths. +const uint32 kLookAheadNonEpsilonPrefix = 0x00000400; + +// For LabelLookAheadMatcher, save relabeling data to file +const uint32 kLookAheadKeepRelabelData = 0x00000800; + +// Flags used for lookahead matchers. +const uint32 kLookAheadFlags = 0x00000ff0; + +// LookAhead Matcher interface, templated on the Arc definition; used +// for lookahead matcher specializations that are returned by the +// InitMatcher() Fst method. +template +class LookAheadMatcherBase : public MatcherBase { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + LookAheadMatcherBase() + : weight_(Weight::One()), + prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {} + + virtual ~LookAheadMatcherBase() {} + + bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); } + + bool LookAheadFst(const Fst &fst, StateId s) { + return LookAheadFst_(fst, s); + } + + Weight LookAheadWeight() const { return weight_; } + + bool LookAheadPrefix(Arc *arc) const { + if (prefix_arc_.nextstate != kNoStateId) { + *arc = prefix_arc_; + return true; + } else { + return false; + } + } + + virtual void InitLookAheadFst(const Fst& fst, bool copy = false) = 0; + + protected: + void SetLookAheadWeight(const Weight &w) { weight_ = w; } + + void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; } + + void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; } + + private: + virtual bool LookAheadLabel_(Label label) const = 0; + virtual bool LookAheadFst_(const Fst &fst, + StateId s) = 0; // This must set l.a. weight and + // prefix if non-trivial. + Weight weight_; // Look-ahead weight + Arc prefix_arc_; // Look-ahead prefix arc +}; + + +// Don't really lookahead, just declare future looks good regardless. +template +class TrivialLookAheadMatcher + : public LookAheadMatcherBase { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + TrivialLookAheadMatcher(const FST &fst, MatchType match_type) + : matcher_(fst, match_type) {} + + TrivialLookAheadMatcher(const TrivialLookAheadMatcher &lmatcher, + bool safe = false) + : matcher_(lmatcher.matcher_, safe) {} + + // General matcher methods + TrivialLookAheadMatcher *Copy(bool safe = false) const { + return new TrivialLookAheadMatcher(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + void SetState(StateId s) { return matcher_.SetState(s); } + bool Find(Label label) { return matcher_.Find(label); } + bool Done() const { return matcher_.Done(); } + const Arc& Value() const { return matcher_.Value(); } + void Next() { matcher_.Next(); } + virtual const FST &GetFst() const { return matcher_.GetFst(); } + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + uint32 Flags() const { + return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher; + } + + // Look-ahead methods. + bool LookAheadLabel(Label label) const { return true; } + bool LookAheadFst(const Fst &fst, StateId s) {return true; } + Weight LookAheadWeight() const { return Weight::One(); } + bool LookAheadPrefix(Arc *arc) const { return false; } + void InitLookAheadFst(const Fst& fst, bool copy = false) {} + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } + + bool LookAheadFst_(const Fst &fst, StateId s) { + return LookAheadFst(fst, s); + } + + Weight LookAheadWeight_() const { return LookAheadWeight(); } + bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); } + + M matcher_; +}; + +// Look-ahead of one transition. Template argument F accepts flags to +// control behavior. +template +class ArcLookAheadMatcher + : public LookAheadMatcherBase { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef NullAddOn MatcherData; + + using LookAheadMatcherBase::LookAheadWeight; + using LookAheadMatcherBase::SetLookAheadPrefix; + using LookAheadMatcherBase::SetLookAheadWeight; + using LookAheadMatcherBase::ClearLookAheadPrefix; + + ArcLookAheadMatcher(const FST &fst, MatchType match_type, + MatcherData *data = 0) + : matcher_(fst, match_type), + fst_(matcher_.GetFst()), + lfst_(0), + s_(kNoStateId) {} + + ArcLookAheadMatcher(const ArcLookAheadMatcher &lmatcher, + bool safe = false) + : matcher_(lmatcher.matcher_, safe), + fst_(matcher_.GetFst()), + lfst_(lmatcher.lfst_), + s_(kNoStateId) {} + + // General matcher methods + ArcLookAheadMatcher *Copy(bool safe = false) const { + return new ArcLookAheadMatcher(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + + void SetState(StateId s) { + s_ = s; + matcher_.SetState(s); + } + + bool Find(Label label) { return matcher_.Find(label); } + bool Done() const { return matcher_.Done(); } + const Arc& Value() const { return matcher_.Value(); } + void Next() { matcher_.Next(); } + const FST &GetFst() const { return fst_; } + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + uint32 Flags() const { + return matcher_.Flags() | kInputLookAheadMatcher | + kOutputLookAheadMatcher | F; + } + + // Writable matcher methods + MatcherData *GetData() const { return 0; } + + // Look-ahead methods. + bool LookAheadLabel(Label label) const { return matcher_.Find(label); } + + // Checks if there is a matching (possibly super-final) transition + // at (s_, s). + bool LookAheadFst(const Fst &fst, StateId s); + + void InitLookAheadFst(const Fst& fst, bool copy = false) { + lfst_ = &fst; + } + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } + bool LookAheadFst_(const Fst &fst, StateId s) { + return LookAheadFst(fst, s); + } + + mutable M matcher_; + const FST &fst_; // Matcher FST + const Fst *lfst_; // Look-ahead FST + StateId s_; // Matcher state +}; + +template +bool ArcLookAheadMatcher::LookAheadFst(const Fst &fst, StateId s) { + if (&fst != lfst_) + InitLookAheadFst(fst); + + bool ret = false; + ssize_t nprefix = 0; + if (F & kLookAheadWeight) + SetLookAheadWeight(Weight::Zero()); + if (F & kLookAheadPrefix) + ClearLookAheadPrefix(); + if (fst_.Final(s_) != Weight::Zero() && + lfst_->Final(s) != Weight::Zero()) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + ++nprefix; + if (F & kLookAheadWeight) + SetLookAheadWeight(Plus(LookAheadWeight(), + Times(fst_.Final(s_), lfst_->Final(s)))); + ret = true; + } + if (matcher_.Find(kNoLabel)) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + ++nprefix; + if (F & kLookAheadWeight) + for (; !matcher_.Done(); matcher_.Next()) + SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight)); + ret = true; + } + for (ArcIterator< Fst > aiter(*lfst_, s); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + Label label = kNoLabel; + switch (matcher_.Type(false)) { + case MATCH_INPUT: + label = arc.olabel; + break; + case MATCH_OUTPUT: + label = arc.ilabel; + break; + default: + FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type"; + return true; + } + if (label == 0) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + if (!(F & kLookAheadNonEpsilonPrefix)) + ++nprefix; + if (F & kLookAheadWeight) + SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight)); + ret = true; + } else if (matcher_.Find(label)) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + for (; !matcher_.Done(); matcher_.Next()) { + ++nprefix; + if (F & kLookAheadWeight) + SetLookAheadWeight(Plus(LookAheadWeight(), + Times(arc.weight, + matcher_.Value().weight))); + if ((F & kLookAheadPrefix) && nprefix == 1) + SetLookAheadPrefix(arc); + } + ret = true; + } + } + if (F & kLookAheadPrefix) { + if (nprefix == 1) + SetLookAheadWeight(Weight::One()); // Avoids double counting. + else + ClearLookAheadPrefix(); + } + return ret; +} + + +// Template argument F accepts flags to control behavior. +// It must include precisely one of KInputLookAheadMatcher or +// KOutputLookAheadMatcher. +template > +class LabelLookAheadMatcher + : public LookAheadMatcherBase { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef LabelReachableData *fst, const L &mfst, + bool relabel_input) { + typename L::Impl *impl = mfst.GetImpl(); + D *data = impl->GetAddOn(); + LabelReachable reachable(data->First() ? + data->First() : data->Second()); + reachable.Relabel(fst, relabel_input); + } + + // Returns relabeling pairs (cf. relabel.h::Relabel()). + // Class L should be a label-lookahead Fst. + // If 'avoid_collisions' is true, extra pairs are added to + // ensure no collisions when relabeling automata that have + // labels unseen here. + template + static void RelabelPairs(const L &mfst, vector > *pairs, + bool avoid_collisions = false) { + typename L::Impl *impl = mfst.GetImpl(); + D *data = impl->GetAddOn(); + LabelReachable reachable(data->First() ? + data->First() : data->Second()); + reachable.RelabelPairs(pairs, avoid_collisions); + } +}; + +template +template inline +LabelLookAheadRelabeler::LabelLookAheadRelabeler(I **impl) { + Fst &fst = (*impl)->GetFst(); + D *data = (*impl)->GetAddOn(); + const string name = (*impl)->Type(); + bool is_mutable = fst.Properties(kMutable, false); + MutableFst *mfst = 0; + if (is_mutable) { + mfst = static_cast *>(&fst); + } else { + mfst = new VectorFst(fst); + data->IncrRefCount(); + delete *impl; + } + if (data->First()) { // reach_input + LabelReachable reachable(data->First()); + reachable.Relabel(mfst, true); + if (!FLAGS_save_relabel_ipairs.empty()) { + vector > pairs; + reachable.RelabelPairs(&pairs, true); + WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs); + } + } else { + LabelReachable reachable(data->Second()); + reachable.Relabel(mfst, false); + if (!FLAGS_save_relabel_opairs.empty()) { + vector > pairs; + reachable.RelabelPairs(&pairs, true); + WriteLabelPairs(FLAGS_save_relabel_opairs, pairs); + } + } + if (!is_mutable) { + *impl = new I(*mfst, name); + (*impl)->SetAddOn(data); + delete mfst; + data->DecrRefCount(); + } +} + + +// Generic lookahead matcher, templated on the FST definition +// - a wrapper around pointer to specific one. +template +class LookAheadMatcher { + public: + typedef F FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef LookAheadMatcherBase LBase; + + LookAheadMatcher(const F &fst, MatchType match_type) { + base_ = fst.InitMatcher(match_type); + if (!base_) + base_ = new SortedMatcher(fst, match_type); + lookahead_ = false; + } + + LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false) { + base_ = matcher.base_->Copy(safe); + lookahead_ = matcher.lookahead_; + } + + ~LookAheadMatcher() { delete base_; } + + // General matcher methods + LookAheadMatcher *Copy(bool safe = false) const { + return new LookAheadMatcher(*this, safe); + } + + MatchType Type(bool test) const { return base_->Type(test); } + void SetState(StateId s) { base_->SetState(s); } + bool Find(Label label) { return base_->Find(label); } + bool Done() const { return base_->Done(); } + const Arc& Value() const { return base_->Value(); } + void Next() { base_->Next(); } + const F &GetFst() const { return static_cast(base_->GetFst()); } + + uint64 Properties(uint64 props) const { return base_->Properties(props); } + + uint32 Flags() const { return base_->Flags(); } + + // Look-ahead methods + bool LookAheadLabel(Label label) const { + if (LookAheadCheck()) { + LBase *lbase = static_cast(base_); + return lbase->LookAheadLabel(label); + } else { + return true; + } + } + + bool LookAheadFst(const Fst &fst, StateId s) { + if (LookAheadCheck()) { + LBase *lbase = static_cast(base_); + return lbase->LookAheadFst(fst, s); + } else { + return true; + } + } + + Weight LookAheadWeight() const { + if (LookAheadCheck()) { + LBase *lbase = static_cast(base_); + return lbase->LookAheadWeight(); + } else { + return Weight::One(); + } + } + + bool LookAheadPrefix(Arc *arc) const { + if (LookAheadCheck()) { + LBase *lbase = static_cast(base_); + return lbase->LookAheadPrefix(arc); + } else { + return false; + } + } + + void InitLookAheadFst(const Fst& fst, bool copy = false) { + if (LookAheadCheck()) { + LBase *lbase = static_cast(base_); + lbase->InitLookAheadFst(fst, copy); + } + } + + private: + bool LookAheadCheck() const { + if (!lookahead_) { + lookahead_ = base_->Flags() & + (kInputLookAheadMatcher | kOutputLookAheadMatcher); + if (!lookahead_) { + FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined"; + } + } + return lookahead_; + } + + MatcherBase *base_; + mutable bool lookahead_; + + void operator=(const LookAheadMatcher &); // disallow +}; + +} // namespace fst + +#endif // FST_LIB_LOOKAHEAD_MATCHER_H__ -- cgit v1.2.3