From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- kaldi_io/src/tools/openfst/include/fst/matcher.h | 1205 ++++++++++++++++++++++ 1 file changed, 1205 insertions(+) create mode 100644 kaldi_io/src/tools/openfst/include/fst/matcher.h (limited to 'kaldi_io/src/tools/openfst/include/fst/matcher.h') diff --git a/kaldi_io/src/tools/openfst/include/fst/matcher.h b/kaldi_io/src/tools/openfst/include/fst/matcher.h new file mode 100644 index 0000000..89ed9be --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/matcher.h @@ -0,0 +1,1205 @@ +// matcher.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Classes to allow matching labels leaving FST states. + +#ifndef FST_LIB_MATCHER_H__ +#define FST_LIB_MATCHER_H__ + +#include +#include + +#include // for all internal FST accessors + + +namespace fst { + +// MATCHERS - these can find and iterate through requested labels at +// FST states. In the simplest form, these are just some associative +// map or search keyed on labels. More generally, they may +// implement matching special labels that represent sets of labels +// such as 'sigma' (all), 'rho' (rest), or 'phi' (fail). +// The Matcher interface is: +// +// template +// class Matcher { +// public: +// typedef F FST; +// typedef F::Arc Arc; +// typedef typename Arc::StateId StateId; +// typedef typename Arc::Label Label; +// typedef typename Arc::Weight Weight; +// +// // Required constructors. +// Matcher(const F &fst, MatchType type); +// // If safe=true, the copy is thread-safe. See Fst<>::Copy() +// // for further doc. +// Matcher(const Matcher &matcher, bool safe = false); +// +// // If safe=true, the copy is thread-safe. See Fst<>::Copy() +// // for further doc. +// Matcher *Copy(bool safe = false) const; +// +// // Returns the match type that can be provided (depending on +// // compatibility of the input FST). It is either +// // the requested match type, MATCH_NONE, or MATCH_UNKNOWN. +// // If 'test' is false, a constant time test is performed, but +// // MATCH_UNKNOWN may be returned. If 'test' is true, +// // a definite answer is returned, but may involve more costly +// // computation (e.g., visiting the Fst). +// MatchType Type(bool test) const; +// // Specifies the current state. +// void SetState(StateId s); +// +// // This finds matches to a label at the current state. +// // Returns true if a match found. kNoLabel matches any +// // 'non-consuming' transitions, e.g., epsilon transitions, +// // which do not require a matching symbol. +// bool Find(Label label); +// // These iterate through any matches found: +// bool Done() const; // No more matches. +// const A& Value() const; // Current arc (when !Done) +// void Next(); // Advance to next arc (when !Done) +// // Initially and after SetState() the iterator methods +// // have undefined behavior until Find() is called. +// +// // Return matcher FST. +// const F& GetFst() const; +// // This specifies the known Fst properties as viewed from this +// // matcher. It takes as argument the input Fst's known properties. +// uint64 Properties(uint64 props) const; +// }; + +// +// MATCHER FLAGS (see also kLookAheadFlags in lookahead-matcher.h) +// +// Matcher prefers being used as the matching side in composition. +const uint32 kPreferMatch = 0x00000001; + +// Matcher needs to be used as the matching side in composition. +const uint32 kRequireMatch = 0x00000002; + +// Flags used for basic matchers (see also lookahead.h). +const uint32 kMatcherFlags = kPreferMatch | kRequireMatch; + +// Matcher interface, templated on the Arc definition; used +// for matcher specializations that are returned by the +// InitMatcher Fst method. +template +class MatcherBase { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + virtual ~MatcherBase() {} + + virtual MatcherBase *Copy(bool safe = false) const = 0; + virtual MatchType Type(bool test) const = 0; + void SetState(StateId s) { SetState_(s); } + bool Find(Label label) { return Find_(label); } + bool Done() const { return Done_(); } + const A& Value() const { return Value_(); } + void Next() { Next_(); } + virtual const Fst &GetFst() const = 0; + virtual uint64 Properties(uint64 props) const = 0; + virtual uint32 Flags() const { return 0; } + private: + virtual void SetState_(StateId s) = 0; + virtual bool Find_(Label label) = 0; + virtual bool Done_() const = 0; + virtual const A& Value_() const = 0; + virtual void Next_() = 0; +}; + + +// A matcher that expects sorted labels on the side to be matched. +// If match_type == MATCH_INPUT, epsilons match the implicit self loop +// Arc(kNoLabel, 0, Weight::One(), current_state) as well as any +// actual epsilon transitions. If match_type == MATCH_OUTPUT, then +// Arc(0, kNoLabel, Weight::One(), current_state) is instead matched. +template +class SortedMatcher : public MatcherBase { + public: + typedef F FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + // Labels >= binary_label will be searched for by binary search, + // o.w. linear search is used. + SortedMatcher(const F &fst, MatchType match_type, + Label binary_label = 1) + : fst_(fst.Copy()), + s_(kNoStateId), + aiter_(0), + match_type_(match_type), + binary_label_(binary_label), + match_label_(kNoLabel), + narcs_(0), + loop_(kNoLabel, 0, Weight::One(), kNoStateId), + error_(false) { + switch(match_type_) { + case MATCH_INPUT: + case MATCH_NONE: + break; + case MATCH_OUTPUT: + swap(loop_.ilabel, loop_.olabel); + break; + default: + FSTERROR() << "SortedMatcher: bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + } + + SortedMatcher(const SortedMatcher &matcher, bool safe = false) + : fst_(matcher.fst_->Copy(safe)), + s_(kNoStateId), + aiter_(0), + match_type_(matcher.match_type_), + binary_label_(matcher.binary_label_), + match_label_(kNoLabel), + narcs_(0), + loop_(matcher.loop_), + error_(matcher.error_) {} + + virtual ~SortedMatcher() { + if (aiter_) + delete aiter_; + delete fst_; + } + + virtual SortedMatcher *Copy(bool safe = false) const { + return new SortedMatcher(*this, safe); + } + + virtual MatchType Type(bool test) const { + if (match_type_ == MATCH_NONE) + return match_type_; + + uint64 true_prop = match_type_ == MATCH_INPUT ? + kILabelSorted : kOLabelSorted; + uint64 false_prop = match_type_ == MATCH_INPUT ? + kNotILabelSorted : kNotOLabelSorted; + uint64 props = fst_->Properties(true_prop | false_prop, test); + + if (props & true_prop) + return match_type_; + else if (props & false_prop) + return MATCH_NONE; + else + return MATCH_UNKNOWN; + } + + void SetState(StateId s) { + if (s_ == s) + return; + s_ = s; + if (match_type_ == MATCH_NONE) { + FSTERROR() << "SortedMatcher: bad match type"; + error_ = true; + } + if (aiter_) + delete aiter_; + aiter_ = new ArcIterator(*fst_, s); + aiter_->SetFlags(kArcNoCache, kArcNoCache); + narcs_ = internal::NumArcs(*fst_, s); + loop_.nextstate = s; + } + + bool Find(Label match_label) { + exact_match_ = true; + if (error_) { + current_loop_ = false; + match_label_ = kNoLabel; + return false; + } + current_loop_ = match_label == 0; + match_label_ = match_label == kNoLabel ? 0 : match_label; + if (Search()) { + return true; + } else { + return current_loop_; + } + } + + // Positions matcher to the first position where inserting + // match_label would maintain the sort order. + void LowerBound(Label match_label) { + exact_match_ = false; + current_loop_ = false; + if (error_) { + match_label_ = kNoLabel; + return; + } + match_label_ = match_label; + Search(); + } + + // After Find(), returns false if no more exact matches. + // After LowerBound(), returns false if no more arcs. + bool Done() const { + if (current_loop_) + return false; + if (aiter_->Done()) + return true; + if (!exact_match_) + return false; + aiter_->SetFlags( + match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + Label label = match_type_ == MATCH_INPUT ? + aiter_->Value().ilabel : aiter_->Value().olabel; + return label != match_label_; + } + + const Arc& Value() const { + if (current_loop_) { + return loop_; + } + aiter_->SetFlags(kArcValueFlags, kArcValueFlags); + return aiter_->Value(); + } + + void Next() { + if (current_loop_) + current_loop_ = false; + else + aiter_->Next(); + } + + virtual const F &GetFst() const { return *fst_; } + + virtual uint64 Properties(uint64 inprops) const { + uint64 outprops = inprops; + if (error_) outprops |= kError; + return outprops; + } + + size_t Position() const { return aiter_ ? aiter_->Position() : 0; } + + private: + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + bool Search(); + + const F *fst_; + StateId s_; // Current state + ArcIterator *aiter_; // Iterator for current state + MatchType match_type_; // Type of match to perform + Label binary_label_; // Least label for binary search + Label match_label_; // Current label to be matched + size_t narcs_; // Current state arc count + Arc loop_; // For non-consuming symbols + bool current_loop_; // Current arc is the implicit loop + bool exact_match_; // Exact match or lower bound? + bool error_; // Error encountered + + void operator=(const SortedMatcher &); // Disallow +}; + +// Returns true iff match to match_label_. Positions arc iterator at +// lower bound regardless. +template inline +bool SortedMatcher::Search() { + aiter_->SetFlags( + match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + if (match_label_ >= binary_label_) { + // Binary search for match. + size_t low = 0; + size_t high = narcs_; + while (low < high) { + size_t mid = (low + high) / 2; + aiter_->Seek(mid); + Label label = match_type_ == MATCH_INPUT ? + aiter_->Value().ilabel : aiter_->Value().olabel; + if (label > match_label_) { + high = mid; + } else if (label < match_label_) { + low = mid + 1; + } else { + // find first matching label (when non-determinism) + for (size_t i = mid; i > low; --i) { + aiter_->Seek(i - 1); + label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel : + aiter_->Value().olabel; + if (label != match_label_) { + aiter_->Seek(i); + return true; + } + } + return true; + } + } + aiter_->Seek(low); + return false; + } else { + // Linear search for match. + for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) { + Label label = match_type_ == MATCH_INPUT ? + aiter_->Value().ilabel : aiter_->Value().olabel; + if (label == match_label_) { + return true; + } + if (label > match_label_) + break; + } + return false; + } +} + + +// Specifies whether during matching we rewrite both the input and output sides. +enum MatcherRewriteMode { + MATCHER_REWRITE_AUTO = 0, // Rewrites both sides iff acceptor. + MATCHER_REWRITE_ALWAYS, + MATCHER_REWRITE_NEVER +}; + + +// For any requested label that doesn't match at a state, this matcher +// considers all transitions that match the label 'rho_label' (rho = +// 'rest'). Each such rho transition found is returned with the +// rho_label rewritten as the requested label (both sides if an +// acceptor, or if 'rewrite_both' is true and both input and output +// labels of the found transition are 'rho_label'). If 'rho_label' is +// kNoLabel, this special matching is not done. RhoMatcher is +// templated itself on a matcher, which is used to perform the +// underlying matching. By default, the underlying matcher is +// constructed by RhoMatcher. The user can instead pass in this +// object; in that case, RhoMatcher takes its ownership. +template +class RhoMatcher : public MatcherBase { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + RhoMatcher(const FST &fst, + MatchType match_type, + Label rho_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = 0) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + rho_label_(rho_label), + error_(false) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "RhoMatcher: bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + if (rho_label == 0) { + FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label"; + rho_label_ = kNoLabel; + error_ = true; + } + + if (rewrite_mode == MATCHER_REWRITE_AUTO) + rewrite_both_ = fst.Properties(kAcceptor, true); + else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) + rewrite_both_ = true; + else + rewrite_both_ = false; + } + + RhoMatcher(const RhoMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + rho_label_(matcher.rho_label_), + rewrite_both_(matcher.rewrite_both_), + error_(matcher.error_) {} + + virtual ~RhoMatcher() { + delete matcher_; + } + + virtual RhoMatcher *Copy(bool safe = false) const { + return new RhoMatcher(*this, safe); + } + + virtual MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { + matcher_->SetState(s); + has_rho_ = rho_label_ != kNoLabel; + } + + bool Find(Label match_label) { + if (match_label == rho_label_ && rho_label_ != kNoLabel) { + FSTERROR() << "RhoMatcher::Find: bad label (rho)"; + error_ = true; + return false; + } + if (matcher_->Find(match_label)) { + rho_match_ = kNoLabel; + return true; + } else if (has_rho_ && match_label != 0 && match_label != kNoLabel && + (has_rho_ = matcher_->Find(rho_label_))) { + rho_match_ = match_label; + return true; + } else { + return false; + } + } + + bool Done() const { return matcher_->Done(); } + + const Arc& Value() const { + if (rho_match_ == kNoLabel) { + return matcher_->Value(); + } else { + rho_arc_ = matcher_->Value(); + if (rewrite_both_) { + if (rho_arc_.ilabel == rho_label_) + rho_arc_.ilabel = rho_match_; + if (rho_arc_.olabel == rho_label_) + rho_arc_.olabel = rho_match_; + } else if (match_type_ == MATCH_INPUT) { + rho_arc_.ilabel = rho_match_; + } else { + rho_arc_.olabel = rho_match_; + } + return rho_arc_; + } + } + + void Next() { matcher_->Next(); } + + virtual const FST &GetFst() const { return matcher_->GetFst(); } + + virtual uint64 Properties(uint64 props) const; + + virtual uint32 Flags() const { + if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) + return matcher_->Flags(); + return matcher_->Flags() | kRequireMatch; + } + + private: + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + M *matcher_; + MatchType match_type_; // Type of match requested + Label rho_label_; // Label that represents the rho transition + bool rewrite_both_; // Rewrite both sides when both are 'rho_label_' + bool has_rho_; // Are there possibly rhos at the current state? + Label rho_match_; // Current label that matches rho transition + mutable Arc rho_arc_; // Arc to return when rho match + bool error_; // Error encountered + + void operator=(const RhoMatcher &); // Disallow +}; + +template inline +uint64 RhoMatcher::Properties(uint64 inprops) const { + uint64 outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (match_type_ == MATCH_INPUT) { + if (rewrite_both_) { + return outprops & ~(kODeterministic | kNonODeterministic | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & ~(kODeterministic | kAcceptor | kString | + kILabelSorted | kNotILabelSorted); + } + } else if (match_type_ == MATCH_OUTPUT) { + if (rewrite_both_) { + return outprops & ~(kIDeterministic | kNonIDeterministic | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & ~(kIDeterministic | kAcceptor | kString | + kOLabelSorted | kNotOLabelSorted); + } + } else { + // Shouldn't ever get here. + FSTERROR() << "RhoMatcher:: bad match type: " << match_type_; + return 0; + } +} + + +// For any requested label, this matcher considers all transitions +// that match the label 'sigma_label' (sigma = "any"), and this in +// additions to transitions with the requested label. Each such sigma +// transition found is returned with the sigma_label rewritten as the +// requested label (both sides if an acceptor, or if 'rewrite_both' is +// true and both input and output labels of the found transition are +// 'sigma_label'). If 'sigma_label' is kNoLabel, this special +// matching is not done. SigmaMatcher is templated itself on a +// matcher, which is used to perform the underlying matching. By +// default, the underlying matcher is constructed by SigmaMatcher. +// The user can instead pass in this object; in that case, +// SigmaMatcher takes its ownership. +template +class SigmaMatcher : public MatcherBase { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + SigmaMatcher(const FST &fst, + MatchType match_type, + Label sigma_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = 0) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + sigma_label_(sigma_label), + error_(false) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "SigmaMatcher: bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + if (sigma_label == 0) { + FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label"; + sigma_label_ = kNoLabel; + error_ = true; + } + + if (rewrite_mode == MATCHER_REWRITE_AUTO) + rewrite_both_ = fst.Properties(kAcceptor, true); + else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) + rewrite_both_ = true; + else + rewrite_both_ = false; + } + + SigmaMatcher(const SigmaMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + sigma_label_(matcher.sigma_label_), + rewrite_both_(matcher.rewrite_both_), + error_(matcher.error_) {} + + virtual ~SigmaMatcher() { + delete matcher_; + } + + virtual SigmaMatcher *Copy(bool safe = false) const { + return new SigmaMatcher(*this, safe); + } + + virtual MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { + matcher_->SetState(s); + has_sigma_ = + sigma_label_ != kNoLabel ? matcher_->Find(sigma_label_) : false; + } + + bool Find(Label match_label) { + match_label_ = match_label; + if (match_label == sigma_label_ && sigma_label_ != kNoLabel) { + FSTERROR() << "SigmaMatcher::Find: bad label (sigma)"; + error_ = true; + return false; + } + if (matcher_->Find(match_label)) { + sigma_match_ = kNoLabel; + return true; + } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel && + matcher_->Find(sigma_label_)) { + sigma_match_ = match_label; + return true; + } else { + return false; + } + } + + bool Done() const { + return matcher_->Done(); + } + + const Arc& Value() const { + if (sigma_match_ == kNoLabel) { + return matcher_->Value(); + } else { + sigma_arc_ = matcher_->Value(); + if (rewrite_both_) { + if (sigma_arc_.ilabel == sigma_label_) + sigma_arc_.ilabel = sigma_match_; + if (sigma_arc_.olabel == sigma_label_) + sigma_arc_.olabel = sigma_match_; + } else if (match_type_ == MATCH_INPUT) { + sigma_arc_.ilabel = sigma_match_; + } else { + sigma_arc_.olabel = sigma_match_; + } + return sigma_arc_; + } + } + + void Next() { + matcher_->Next(); + if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) && + (match_label_ > 0)) { + matcher_->Find(sigma_label_); + sigma_match_ = match_label_; + } + } + + virtual const FST &GetFst() const { return matcher_->GetFst(); } + + virtual uint64 Properties(uint64 props) const; + + virtual uint32 Flags() const { + if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) + return matcher_->Flags(); + // kRequireMatch temporarily disabled until issues + // in //speech/gaudi/annotation/util/denorm are resolved. + // return matcher_->Flags() | kRequireMatch; + return matcher_->Flags(); + } + +private: + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + M *matcher_; + MatchType match_type_; // Type of match requested + Label sigma_label_; // Label that represents the sigma transition + bool rewrite_both_; // Rewrite both sides when both are 'sigma_label_' + bool has_sigma_; // Are there sigmas at the current state? + Label sigma_match_; // Current label that matches sigma transition + mutable Arc sigma_arc_; // Arc to return when sigma match + Label match_label_; // Label being matched + bool error_; // Error encountered + + void operator=(const SigmaMatcher &); // disallow +}; + +template inline +uint64 SigmaMatcher::Properties(uint64 inprops) const { + uint64 outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (rewrite_both_) { + return outprops & ~(kIDeterministic | kNonIDeterministic | + kODeterministic | kNonODeterministic | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted | + kString); + } else if (match_type_ == MATCH_INPUT) { + return outprops & ~(kIDeterministic | kNonIDeterministic | + kODeterministic | kNonODeterministic | + kILabelSorted | kNotILabelSorted | + kString | kAcceptor); + } else if (match_type_ == MATCH_OUTPUT) { + return outprops & ~(kIDeterministic | kNonIDeterministic | + kODeterministic | kNonODeterministic | + kOLabelSorted | kNotOLabelSorted | + kString | kAcceptor); + } else { + // Shouldn't ever get here. + FSTERROR() << "SigmaMatcher:: bad match type: " << match_type_; + return 0; + } +} + + +// For any requested label that doesn't match at a state, this matcher +// considers the *unique* transition that matches the label 'phi_label' +// (phi = 'fail'), and recursively looks for a match at its +// destination. When 'phi_loop' is true, if no match is found but a +// phi self-loop is found, then the phi transition found is returned +// with the phi_label rewritten as the requested label (both sides if +// an acceptor, or if 'rewrite_both' is true and both input and output +// labels of the found transition are 'phi_label'). If 'phi_label' is +// kNoLabel, this special matching is not done. PhiMatcher is +// templated itself on a matcher, which is used to perform the +// underlying matching. By default, the underlying matcher is +// constructed by PhiMatcher. The user can instead pass in this +// object; in that case, PhiMatcher takes its ownership. +// Warning: phi non-determinism not supported (for simplicity). +template +class PhiMatcher : public MatcherBase { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + PhiMatcher(const FST &fst, + MatchType match_type, + Label phi_label = kNoLabel, + bool phi_loop = true, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = 0) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + phi_label_(phi_label), + state_(kNoStateId), + phi_loop_(phi_loop), + error_(false) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "PhiMatcher: bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + + if (rewrite_mode == MATCHER_REWRITE_AUTO) + rewrite_both_ = fst.Properties(kAcceptor, true); + else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) + rewrite_both_ = true; + else + rewrite_both_ = false; + } + + PhiMatcher(const PhiMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + phi_label_(matcher.phi_label_), + rewrite_both_(matcher.rewrite_both_), + state_(kNoStateId), + phi_loop_(matcher.phi_loop_), + error_(matcher.error_) {} + + virtual ~PhiMatcher() { + delete matcher_; + } + + virtual PhiMatcher *Copy(bool safe = false) const { + return new PhiMatcher(*this, safe); + } + + virtual MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { + matcher_->SetState(s); + state_ = s; + has_phi_ = phi_label_ != kNoLabel; + } + + bool Find(Label match_label); + + bool Done() const { return matcher_->Done(); } + + const Arc& Value() const { + if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) { + return matcher_->Value(); + } else if (phi_match_ == 0) { // Virtual epsilon loop + phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_); + if (match_type_ == MATCH_OUTPUT) + swap(phi_arc_.ilabel, phi_arc_.olabel); + return phi_arc_; + } else { + phi_arc_ = matcher_->Value(); + phi_arc_.weight = Times(phi_weight_, phi_arc_.weight); + if (phi_match_ != kNoLabel) { // Phi loop match + if (rewrite_both_) { + if (phi_arc_.ilabel == phi_label_) + phi_arc_.ilabel = phi_match_; + if (phi_arc_.olabel == phi_label_) + phi_arc_.olabel = phi_match_; + } else if (match_type_ == MATCH_INPUT) { + phi_arc_.ilabel = phi_match_; + } else { + phi_arc_.olabel = phi_match_; + } + } + return phi_arc_; + } + } + + void Next() { matcher_->Next(); } + + virtual const FST &GetFst() const { return matcher_->GetFst(); } + + virtual uint64 Properties(uint64 props) const; + + virtual uint32 Flags() const { + if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) + return matcher_->Flags(); + return matcher_->Flags() | kRequireMatch; + } + +private: + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + M *matcher_; + MatchType match_type_; // Type of match requested + Label phi_label_; // Label that represents the phi transition + bool rewrite_both_; // Rewrite both sides when both are 'phi_label_' + bool has_phi_; // Are there possibly phis at the current state? + Label phi_match_; // Current label that matches phi loop + mutable Arc phi_arc_; // Arc to return + StateId state_; // State where looking for matches + Weight phi_weight_; // Product of the weights of phi transitions taken + bool phi_loop_; // When true, phi self-loop are allowed and treated + // as rho (required for Aho-Corasick) + bool error_; // Error encountered + + void operator=(const PhiMatcher &); // disallow +}; + +template inline +bool PhiMatcher::Find(Label match_label) { + if (match_label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) { + FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_; + error_ = true; + return false; + } + matcher_->SetState(state_); + phi_match_ = kNoLabel; + phi_weight_ = Weight::One(); + if (phi_label_ == 0) { // When 'phi_label_ == 0', + if (match_label == kNoLabel) // there are no more true epsilon arcs, + return false; + if (match_label == 0) { // but virtual eps loop need to be returned + if (!matcher_->Find(kNoLabel)) { + return matcher_->Find(0); + } else { + phi_match_ = 0; + return true; + } + } + } + if (!has_phi_ || match_label == 0 || match_label == kNoLabel) + return matcher_->Find(match_label); + StateId state = state_; + while (!matcher_->Find(match_label)) { + // Look for phi transition (if phi_label_ == 0, we need to look + // for -1 to avoid getting the virtual self-loop) + if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) + return false; + if (phi_loop_ && matcher_->Value().nextstate == state) { + phi_match_ = match_label; + return true; + } + phi_weight_ = Times(phi_weight_, matcher_->Value().weight); + state = matcher_->Value().nextstate; + matcher_->Next(); + if (!matcher_->Done()) { + FSTERROR() << "PhiMatcher: phi non-determinism not supported"; + error_ = true; + } + matcher_->SetState(state); + } + return true; +} + +template inline +uint64 PhiMatcher::Properties(uint64 inprops) const { + uint64 outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (match_type_ == MATCH_INPUT) { + if (phi_label_ == 0) { + outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; + outprops |= kNoEpsilons | kNoIEpsilons; + } + if (rewrite_both_) { + return outprops & ~(kODeterministic | kNonODeterministic | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & ~(kODeterministic | kAcceptor | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } + } else if (match_type_ == MATCH_OUTPUT) { + if (phi_label_ == 0) { + outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; + outprops |= kNoEpsilons | kNoOEpsilons; + } + if (rewrite_both_) { + return outprops & ~(kIDeterministic | kNonIDeterministic | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & ~(kIDeterministic | kAcceptor | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } + } else { + // Shouldn't ever get here. + FSTERROR() << "PhiMatcher:: bad match type: " << match_type_; + return 0; + } +} + + +// +// MULTI-EPS MATCHER FLAGS +// + +// Return multi-epsilon arcs for Find(kNoLabel). +const uint32 kMultiEpsList = 0x00000001; + +// Return a kNolabel loop for Find(multi_eps). +const uint32 kMultiEpsLoop = 0x00000002; + +// MultiEpsMatcher: allows treating multiple non-0 labels as +// non-consuming labels in addition to 0 that is always +// non-consuming. Precise behavior controlled by 'flags' argument. By +// default, the underlying matcher is constructed by +// MultiEpsMatcher. The user can instead pass in this object; in that +// case, MultiEpsMatcher takes its ownership iff 'own_matcher' is +// true. +template +class MultiEpsMatcher { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + MultiEpsMatcher(const FST &fst, MatchType match_type, + uint32 flags = (kMultiEpsLoop | kMultiEpsList), + M *matcher = 0, bool own_matcher = true) + : matcher_(matcher ? matcher : new M(fst, match_type)), + flags_(flags), + own_matcher_(matcher ? own_matcher : true) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + MultiEpsMatcher(const MultiEpsMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + flags_(matcher.flags_), + own_matcher_(true), + multi_eps_labels_(matcher.multi_eps_labels_), + loop_(matcher.loop_) { + loop_.nextstate = kNoStateId; + } + + ~MultiEpsMatcher() { + if (own_matcher_) + delete matcher_; + } + + MultiEpsMatcher *Copy(bool safe = false) const { + return new MultiEpsMatcher(*this, safe); + } + + MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { + matcher_->SetState(s); + loop_.nextstate = s; + } + + bool Find(Label match_label); + + bool Done() const { + return done_; + } + + const Arc& Value() const { + return current_loop_ ? loop_ : matcher_->Value(); + } + + void Next() { + if (!current_loop_) { + matcher_->Next(); + done_ = matcher_->Done(); + if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) { + ++multi_eps_iter_; + while ((multi_eps_iter_ != multi_eps_labels_.End()) && + !matcher_->Find(*multi_eps_iter_)) + ++multi_eps_iter_; + if (multi_eps_iter_ != multi_eps_labels_.End()) + done_ = false; + else + done_ = !matcher_->Find(kNoLabel); + + } + } else { + done_ = true; + } + } + + const FST &GetFst() const { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_->Properties(props); } + + uint32 Flags() const { return matcher_->Flags(); } + + void AddMultiEpsLabel(Label label) { + if (label == 0) { + FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; + } else { + multi_eps_labels_.Insert(label); + } + } + + void RemoveMultiEpsLabel(Label label) { + if (label == 0) { + FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; + } else { + multi_eps_labels_.Erase(label); + } + } + + void ClearMultiEpsLabels() { + multi_eps_labels_.Clear(); + } + +private: + M *matcher_; + uint32 flags_; + bool own_matcher_; // Does this class delete the matcher? + + // Multi-eps label set + CompactSet multi_eps_labels_; + typename CompactSet::const_iterator multi_eps_iter_; + + bool current_loop_; // Current arc is the implicit loop + mutable Arc loop_; // For non-consuming symbols + bool done_; // Matching done + + void operator=(const MultiEpsMatcher &); // Disallow +}; + +template inline +bool MultiEpsMatcher::Find(Label match_label) { + multi_eps_iter_ = multi_eps_labels_.End(); + current_loop_ = false; + bool ret; + if (match_label == 0) { + ret = matcher_->Find(0); + } else if (match_label == kNoLabel) { + if (flags_ & kMultiEpsList) { + // return all non-consuming arcs (incl. epsilon) + multi_eps_iter_ = multi_eps_labels_.Begin(); + while ((multi_eps_iter_ != multi_eps_labels_.End()) && + !matcher_->Find(*multi_eps_iter_)) + ++multi_eps_iter_; + if (multi_eps_iter_ != multi_eps_labels_.End()) + ret = true; + else + ret = matcher_->Find(kNoLabel); + } else { + // return all epsilon arcs + ret = matcher_->Find(kNoLabel); + } + } else if ((flags_ & kMultiEpsLoop) && + multi_eps_labels_.Find(match_label) != multi_eps_labels_.End()) { + // return 'implicit' loop + current_loop_ = true; + ret = true; + } else { + ret = matcher_->Find(match_label); + } + done_ = !ret; + return ret; +} + + +// Generic matcher, templated on the FST definition +// - a wrapper around pointer to specific one. +// Here is a typical use: \code +// Matcher matcher(fst, MATCH_INPUT); +// matcher.SetState(state); +// if (matcher.Find(label)) +// for (; !matcher.Done(); matcher.Next()) { +// StdArc &arc = matcher.Value(); +// ... +// } \endcode +template +class Matcher { + public: + typedef F FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + Matcher(const F &fst, MatchType match_type) { + base_ = fst.InitMatcher(match_type); + if (!base_) + base_ = new SortedMatcher(fst, match_type); + } + + Matcher(const Matcher &matcher, bool safe = false) { + base_ = matcher.base_->Copy(safe); + } + + // Takes ownership of the provided matcher + Matcher(MatcherBase* base_matcher) { base_ = base_matcher; } + + ~Matcher() { delete base_; } + + Matcher *Copy(bool safe = false) const { + return new Matcher(*this, safe); + } + + MatchType Type(bool test) const { return base_->Type(test); } + void SetState(StateId s) { base_->SetState(s); } + bool Find(Label label) { return base_->Find(label); } + bool Done() const { return base_->Done(); } + const Arc& Value() const { return base_->Value(); } + void Next() { base_->Next(); } + const F &GetFst() const { return static_cast(base_->GetFst()); } + uint64 Properties(uint64 props) const { return base_->Properties(props); } + uint32 Flags() const { return base_->Flags() & kMatcherFlags; } + + private: + MatcherBase *base_; + + void operator=(const Matcher &); // disallow +}; + +} // namespace fst + + + +#endif // FST_LIB_MATCHER_H__ -- cgit v1.2.3