summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/replace.h
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/replace.h')
-rw-r--r--kaldi_io/src/tools/openfst/include/fst/replace.h1453
1 files changed, 0 insertions, 1453 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/replace.h b/kaldi_io/src/tools/openfst/include/fst/replace.h
deleted file mode 100644
index ef5f6cc..0000000
--- a/kaldi_io/src/tools/openfst/include/fst/replace.h
+++ /dev/null
@@ -1,1453 +0,0 @@
-// replace.h
-
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// Copyright 2005-2010 Google, Inc.
-// Author: [email protected] (Johan Schalkwyk)
-//
-// \file
-// Functions and classes for the recursive replacement of Fsts.
-//
-
-#ifndef FST_LIB_REPLACE_H__
-#define FST_LIB_REPLACE_H__
-
-#include <tr1/unordered_map>
-using std::tr1::unordered_map;
-using std::tr1::unordered_multimap;
-#include <set>
-#include <string>
-#include <utility>
-using std::pair; using std::make_pair;
-#include <vector>
-using std::vector;
-
-#include <fst/cache.h>
-#include <fst/expanded-fst.h>
-#include <fst/fst.h>
-#include <fst/matcher.h>
-#include <fst/replace-util.h>
-#include <fst/state-table.h>
-#include <fst/test-properties.h>
-
-namespace fst {
-
-//
-// REPLACE STATE TUPLES AND TABLES
-//
-// The replace state table has the form
-//
-// template <class A, class P>
-// class ReplaceStateTable {
-// public:
-// typedef A Arc;
-// typedef P PrefixId;
-// typedef typename A::StateId StateId;
-// typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
-// typedef typename A::Label Label;
-//
-// // Required constuctor
-// ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples,
-// Label root);
-//
-// // Required copy constructor that does not copy state
-// ReplaceStateTable(const ReplaceStateTable<A,P> &table);
-//
-// // Lookup state ID by tuple. If it doesn't exist, then add it.
-// StateId FindState(const StateTuple &tuple);
-//
-// // Lookup state tuple by ID.
-// const StateTuple &Tuple(StateId id) const;
-// };
-
-
-// \struct ReplaceStateTuple
-// \brief Tuple of information that uniquely defines a state in replace
-template <class S, class P>
-struct ReplaceStateTuple {
- typedef S StateId;
- typedef P PrefixId;
-
- ReplaceStateTuple()
- : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {}
-
- ReplaceStateTuple(PrefixId p, StateId f, StateId s)
- : prefix_id(p), fst_id(f), fst_state(s) {}
-
- PrefixId prefix_id; // index in prefix table
- StateId fst_id; // current fst being walked
- StateId fst_state; // current state in fst being walked, not to be
- // confused with the state_id of the combined fst
-};
-
-
-// Equality of replace state tuples.
-template <class S, class P>
-inline bool operator==(const ReplaceStateTuple<S, P>& x,
- const ReplaceStateTuple<S, P>& y) {
- return x.prefix_id == y.prefix_id &&
- x.fst_id == y.fst_id &&
- x.fst_state == y.fst_state;
-}
-
-
-// \class ReplaceRootSelector
-// Functor returning true for tuples corresponding to states in the root FST
-template <class S, class P>
-class ReplaceRootSelector {
- public:
- bool operator()(const ReplaceStateTuple<S, P> &tuple) const {
- return tuple.prefix_id == 0;
- }
-};
-
-
-// \class ReplaceFingerprint
-// Fingerprint for general replace state tuples.
-template <class S, class P>
-class ReplaceFingerprint {
- public:
- ReplaceFingerprint(const vector<uint64> *size_array)
- : cumulative_size_array_(size_array) {}
-
- uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const {
- return tuple.prefix_id * (cumulative_size_array_->back()) +
- cumulative_size_array_->at(tuple.fst_id - 1) +
- tuple.fst_state;
- }
-
- private:
- const vector<uint64> *cumulative_size_array_;
-};
-
-
-// \class ReplaceFstStateFingerprint
-// Useful when the fst_state uniquely define the tuple.
-template <class S, class P>
-class ReplaceFstStateFingerprint {
- public:
- uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const {
- return tuple.fst_state;
- }
-};
-
-
-// \class ReplaceHash
-// A generic hash function for replace state tuples.
-template <typename S, typename P>
-class ReplaceHash {
- public:
- size_t operator()(const ReplaceStateTuple<S, P>& t) const {
- return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1;
- }
- private:
- static const size_t kPrime0;
- static const size_t kPrime1;
-};
-
-template <typename S, typename P>
-const size_t ReplaceHash<S, P>::kPrime0 = 7853;
-
-template <typename S, typename P>
-const size_t ReplaceHash<S, P>::kPrime1 = 7867;
-
-template <class A, class T> class ReplaceFstMatcher;
-
-
-// \class VectorHashReplaceStateTable
-// A two-level state table for replace.
-// Warning: calls CountStates to compute the number of states of each
-// component Fst.
-template <class A, class P = ssize_t>
-class VectorHashReplaceStateTable {
- public:
- typedef A Arc;
- typedef typename A::StateId StateId;
- typedef typename A::Label Label;
- typedef P PrefixId;
- typedef ReplaceStateTuple<StateId, P> StateTuple;
- typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>,
- ReplaceRootSelector<StateId, P>,
- ReplaceFstStateFingerprint<StateId, P>,
- ReplaceFingerprint<StateId, P> > StateTable;
-
- VectorHashReplaceStateTable(
- const vector<pair<Label, const Fst<A>*> > &fst_tuples,
- Label root) : root_size_(0) {
- cumulative_size_array_.push_back(0);
- for (size_t i = 0; i < fst_tuples.size(); ++i) {
- if (fst_tuples[i].first == root) {
- root_size_ = CountStates(*(fst_tuples[i].second));
- cumulative_size_array_.push_back(cumulative_size_array_.back());
- } else {
- cumulative_size_array_.push_back(cumulative_size_array_.back() +
- CountStates(*(fst_tuples[i].second)));
- }
- }
- state_table_ = new StateTable(
- new ReplaceRootSelector<StateId, P>,
- new ReplaceFstStateFingerprint<StateId, P>,
- new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
- root_size_,
- root_size_ + cumulative_size_array_.back());
- }
-
- VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table)
- : root_size_(table.root_size_),
- cumulative_size_array_(table.cumulative_size_array_) {
- state_table_ = new StateTable(
- new ReplaceRootSelector<StateId, P>,
- new ReplaceFstStateFingerprint<StateId, P>,
- new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
- root_size_,
- root_size_ + cumulative_size_array_.back());
- }
-
- ~VectorHashReplaceStateTable() {
- delete state_table_;
- }
-
- StateId FindState(const StateTuple &tuple) {
- return state_table_->FindState(tuple);
- }
-
- const StateTuple &Tuple(StateId id) const {
- return state_table_->Tuple(id);
- }
-
- private:
- StateId root_size_;
- vector<uint64> cumulative_size_array_;
- StateTable *state_table_;
-};
-
-
-// \class DefaultReplaceStateTable
-// Default replace state table
-template <class A, class P = ssize_t>
-class DefaultReplaceStateTable : public CompactHashStateTable<
- ReplaceStateTuple<typename A::StateId, P>,
- ReplaceHash<typename A::StateId, P> > {
- public:
- typedef A Arc;
- typedef typename A::StateId StateId;
- typedef typename A::Label Label;
- typedef P PrefixId;
- typedef ReplaceStateTuple<StateId, P> StateTuple;
- typedef CompactHashStateTable<StateTuple,
- ReplaceHash<StateId, PrefixId> > StateTable;
-
- using StateTable::FindState;
- using StateTable::Tuple;
-
- DefaultReplaceStateTable(
- const vector<pair<Label, const Fst<A>*> > &fst_tuples,
- Label root) {}
-
- DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table)
- : StateTable() {}
-};
-
-//
-// REPLACE FST CLASS
-//
-
-// By default ReplaceFst will copy the input label of the 'replace arc'.
-// For acceptors we do not want this behaviour. Instead we need to
-// create an epsilon arc when recursing into the appropriate Fst.
-// The 'epsilon_on_replace' option can be used to toggle this behaviour.
-template <class A, class T = DefaultReplaceStateTable<A> >
-struct ReplaceFstOptions : CacheOptions {
- int64 root; // root rule for expansion
- bool epsilon_on_replace;
- bool take_ownership; // take ownership of input Fst(s)
- T* state_table;
-
- ReplaceFstOptions(const CacheOptions &opts, int64 r)
- : CacheOptions(opts),
- root(r),
- epsilon_on_replace(false),
- take_ownership(false),
- state_table(0) {}
- explicit ReplaceFstOptions(int64 r)
- : root(r),
- epsilon_on_replace(false),
- take_ownership(false),
- state_table(0) {}
- ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
- : root(r),
- epsilon_on_replace(epsilon_replace_arc),
- take_ownership(false),
- state_table(0) {}
- ReplaceFstOptions()
- : root(kNoLabel),
- epsilon_on_replace(false),
- take_ownership(false),
- state_table(0) {}
-};
-
-
-// \class ReplaceFstImpl
-// \brief Implementation class for replace class Fst
-//
-// The replace implementation class supports a dynamic
-// expansion of a recursive transition network represented as Fst
-// with dynamic replacable arcs.
-//
-template <class A, class T>
-class ReplaceFstImpl : public CacheImpl<A> {
- friend class ReplaceFstMatcher<A, T>;
-
- public:
- using FstImpl<A>::SetType;
- using FstImpl<A>::SetProperties;
- using FstImpl<A>::WriteHeader;
- using FstImpl<A>::SetInputSymbols;
- using FstImpl<A>::SetOutputSymbols;
- using FstImpl<A>::InputSymbols;
- using FstImpl<A>::OutputSymbols;
-
- using CacheImpl<A>::PushArc;
- using CacheImpl<A>::HasArcs;
- using CacheImpl<A>::HasFinal;
- using CacheImpl<A>::HasStart;
- using CacheImpl<A>::SetArcs;
- using CacheImpl<A>::SetFinal;
- using CacheImpl<A>::SetStart;
-
- typedef typename A::Label Label;
- typedef typename A::Weight Weight;
- typedef typename A::StateId StateId;
- typedef CacheState<A> State;
- typedef A Arc;
- typedef unordered_map<Label, Label> NonTerminalHash;
-
- typedef T StateTable;
- typedef typename T::PrefixId PrefixId;
- typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
-
- // constructor for replace class implementation.
- // \param fst_tuples array of label/fst tuples, one for each non-terminal
- ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
- const ReplaceFstOptions<A, T> &opts)
- : CacheImpl<A>(opts),
- epsilon_on_replace_(opts.epsilon_on_replace),
- state_table_(opts.state_table ? opts.state_table :
- new StateTable(fst_tuples, opts.root)) {
-
- SetType("replace");
-
- if (fst_tuples.size() > 0) {
- SetInputSymbols(fst_tuples[0].second->InputSymbols());
- SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
- }
-
- bool all_negative = true; // all nonterminals are negative?
- bool dense_range = true; // all nonterminals are positive
- // and form a dense range containing 1?
- for (size_t i = 0; i < fst_tuples.size(); ++i) {
- Label nonterminal = fst_tuples[i].first;
- if (nonterminal >= 0)
- all_negative = false;
- if (nonterminal > fst_tuples.size() || nonterminal <= 0)
- dense_range = false;
- }
-
- vector<uint64> inprops;
- bool all_ilabel_sorted = true;
- bool all_olabel_sorted = true;
- bool all_non_empty = true;
- fst_array_.push_back(0);
- for (size_t i = 0; i < fst_tuples.size(); ++i) {
- Label label = fst_tuples[i].first;
- const Fst<A> *fst = fst_tuples[i].second;
- nonterminal_hash_[label] = fst_array_.size();
- nonterminal_set_.insert(label);
- fst_array_.push_back(opts.take_ownership ? fst : fst->Copy());
- if (fst->Start() == kNoStateId)
- all_non_empty = false;
- if(!fst->Properties(kILabelSorted, false))
- all_ilabel_sorted = false;
- if(!fst->Properties(kOLabelSorted, false))
- all_olabel_sorted = false;
- inprops.push_back(fst->Properties(kCopyProperties, false));
- if (i) {
- if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
- FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i
- << " does not match input symbols of base Fst (0'th fst)";
- SetProperties(kError, kError);
- }
- if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
- FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i
- << " does not match output symbols of base Fst "
- << "(0'th fst)";
- SetProperties(kError, kError);
- }
- }
- }
- Label nonterminal = nonterminal_hash_[opts.root];
- if ((nonterminal == 0) && (fst_array_.size() > 1)) {
- FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '"
- << opts.root << "' in the input tuple vector";
- SetProperties(kError, kError);
- }
- root_ = (nonterminal > 0) ? nonterminal : 1;
-
- SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_,
- all_non_empty));
- // We assume that all terminals are positive. The resulting
- // ReplaceFst is known to be kILabelSorted when all sub-FSTs are
- // kILabelSorted and one of the 3 following conditions is satisfied:
- // 1. 'epsilon_on_replace' is false, or
- // 2. all non-terminals are negative, or
- // 3. all non-terninals are positive and form a dense range containing 1.
- if (all_ilabel_sorted &&
- (!epsilon_on_replace_ || all_negative || dense_range))
- SetProperties(kILabelSorted, kILabelSorted);
- // Similarly, the resulting ReplaceFst is known to be
- // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of
- // the 2 following conditions is satisfied:
- // 1. all non-terminals are negative, or
- // 2. all non-terninals are positive and form a dense range containing 1.
- if (all_olabel_sorted && (all_negative || dense_range))
- SetProperties(kOLabelSorted, kOLabelSorted);
-
- // Enable optional caching as long as sorted and all non empty.
- if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty)
- always_cache_ = false;
- else
- always_cache_ = true;
- VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
- << (always_cache_ ? "true" : "false");
- }
-
- ReplaceFstImpl(const ReplaceFstImpl& impl)
- : CacheImpl<A>(impl),
- epsilon_on_replace_(impl.epsilon_on_replace_),
- always_cache_(impl.always_cache_),
- state_table_(new StateTable(*(impl.state_table_))),
- nonterminal_set_(impl.nonterminal_set_),
- nonterminal_hash_(impl.nonterminal_hash_),
- root_(impl.root_) {
- SetType("replace");
- SetProperties(impl.Properties(), kCopyProperties);
- SetInputSymbols(impl.InputSymbols());
- SetOutputSymbols(impl.OutputSymbols());
- fst_array_.reserve(impl.fst_array_.size());
- fst_array_.push_back(0);
- for (size_t i = 1; i < impl.fst_array_.size(); ++i) {
- fst_array_.push_back(impl.fst_array_[i]->Copy(true));
- }
- }
-
- ~ReplaceFstImpl() {
- VLOG(2) << "~ReplaceFstImpl: gc = "
- << (CacheImpl<A>::GetCacheGc() ? "true" : "false")
- << ", gc_size = " << CacheImpl<A>::GetCacheSize()
- << ", gc_limit = " << CacheImpl<A>::GetCacheLimit();
-
- delete state_table_;
- for (size_t i = 1; i < fst_array_.size(); ++i) {
- delete fst_array_[i];
- }
- }
-
- // Computes the dependency graph of the replace class and returns
- // true if the dependencies are cyclic. Cyclic dependencies will result
- // in an un-expandable replace fst.
- bool CyclicDependencies() const {
- ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_);
- return replace_util.CyclicDependencies();
- }
-
- // Return or compute start state of replace fst
- StateId Start() {
- if (!HasStart()) {
- if (fst_array_.size() == 1) { // no fsts defined for replace
- SetStart(kNoStateId);
- return kNoStateId;
- } else {
- const Fst<A>* fst = fst_array_[root_];
- StateId fst_start = fst->Start();
- if (fst_start == kNoStateId) // root Fst is empty
- return kNoStateId;
-
- PrefixId prefix = GetPrefixId(StackPrefix());
- StateId start = state_table_->FindState(
- StateTuple(prefix, root_, fst_start));
- SetStart(start);
- return start;
- }
- } else {
- return CacheImpl<A>::Start();
- }
- }
-
- // return final weight of state (kInfWeight means state is not final)
- Weight Final(StateId s) {
- if (!HasFinal(s)) {
- const StateTuple& tuple = state_table_->Tuple(s);
- const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
- const Fst<A>* fst = fst_array_[tuple.fst_id];
- StateId fst_state = tuple.fst_state;
-
- if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
- SetFinal(s, fst->Final(fst_state));
- else
- SetFinal(s, Weight::Zero());
- }
- return CacheImpl<A>::Final(s);
- }
-
- size_t NumArcs(StateId s) {
- if (HasArcs(s)) { // If state cached, use the cached value.
- return CacheImpl<A>::NumArcs(s);
- } else if (always_cache_) { // If always caching, expand and cache state.
- Expand(s);
- return CacheImpl<A>::NumArcs(s);
- } else { // Otherwise compute the number of arcs without expanding.
- StateTuple tuple = state_table_->Tuple(s);
- if (tuple.fst_state == kNoStateId)
- return 0;
-
- const Fst<A>* fst = fst_array_[tuple.fst_id];
- size_t num_arcs = fst->NumArcs(tuple.fst_state);
- if (ComputeFinalArc(tuple, 0))
- num_arcs++;
-
- return num_arcs;
- }
- }
-
- // Returns whether a given label is a non terminal
- bool IsNonTerminal(Label l) const {
- // TODO(allauzen): be smarter and take advantage of
- // all_dense or all_negative.
- // Use also in ComputeArc, this would require changes to replace
- // so that recursing into an empty fst lead to a non co-accessible
- // state instead of deleting the arc as done currently.
- // Current use correct, since i/olabel sorted iff all_non_empty.
- typename NonTerminalHash::const_iterator it =
- nonterminal_hash_.find(l);
- return it != nonterminal_hash_.end();
- }
-
- size_t NumInputEpsilons(StateId s) {
- if (HasArcs(s)) {
- // If state cached, use the cached value.
- return CacheImpl<A>::NumInputEpsilons(s);
- } else if (always_cache_ || !Properties(kILabelSorted)) {
- // If always caching or if the number of input epsilons is too expensive
- // to compute without caching (i.e. not ilabel sorted),
- // then expand and cache state.
- Expand(s);
- return CacheImpl<A>::NumInputEpsilons(s);
- } else {
- // Otherwise, compute the number of input epsilons without caching.
- StateTuple tuple = state_table_->Tuple(s);
- if (tuple.fst_state == kNoStateId)
- return 0;
- const Fst<A>* fst = fst_array_[tuple.fst_id];
- size_t num = 0;
- if (!epsilon_on_replace_) {
- // If epsilon_on_replace is false, all input epsilon arcs
- // are also input epsilons arcs in the underlying machine.
- fst->NumInputEpsilons(tuple.fst_state);
- } else {
- // Otherwise, one need to consider that all non-terminal arcs
- // in the underlying machine also become input epsilon arc.
- ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
- for (; !aiter.Done() &&
- ((aiter.Value().ilabel == 0) ||
- IsNonTerminal(aiter.Value().olabel));
- aiter.Next())
- ++num;
- }
- if (ComputeFinalArc(tuple, 0))
- num++;
- return num;
- }
- }
-
- size_t NumOutputEpsilons(StateId s) {
- if (HasArcs(s)) {
- // If state cached, use the cached value.
- return CacheImpl<A>::NumOutputEpsilons(s);
- } else if(always_cache_ || !Properties(kOLabelSorted)) {
- // If always caching or if the number of output epsilons is too expensive
- // to compute without caching (i.e. not olabel sorted),
- // then expand and cache state.
- Expand(s);
- return CacheImpl<A>::NumOutputEpsilons(s);
- } else {
- // Otherwise, compute the number of output epsilons without caching.
- StateTuple tuple = state_table_->Tuple(s);
- if (tuple.fst_state == kNoStateId)
- return 0;
- const Fst<A>* fst = fst_array_[tuple.fst_id];
- size_t num = 0;
- ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
- for (; !aiter.Done() &&
- ((aiter.Value().olabel == 0) ||
- IsNonTerminal(aiter.Value().olabel));
- aiter.Next())
- ++num;
- if (ComputeFinalArc(tuple, 0))
- num++;
- return num;
- }
- }
-
- uint64 Properties() const { return Properties(kFstProperties); }
-
- // Set error if found; return FST impl properties.
- uint64 Properties(uint64 mask) const {
- if (mask & kError) {
- for (size_t i = 1; i < fst_array_.size(); ++i) {
- if (fst_array_[i]->Properties(kError, false))
- SetProperties(kError, kError);
- }
- }
- return FstImpl<Arc>::Properties(mask);
- }
-
- // return the base arc iterator, if arcs have not been computed yet,
- // extend/recurse for new arcs.
- void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
- if (!HasArcs(s))
- Expand(s);
- CacheImpl<A>::InitArcIterator(s, data);
- // TODO(allauzen): Set behaviour of generic iterator
- // Warning: ArcIterator<ReplaceFst<A> >::InitCache()
- // relies on current behaviour.
- }
-
-
- // Extend current state (walk arcs one level deep)
- void Expand(StateId s) {
- StateTuple tuple = state_table_->Tuple(s);
-
- // If local fst is empty
- if (tuple.fst_state == kNoStateId) {
- SetArcs(s);
- return;
- }
-
- ArcIterator< Fst<A> > aiter(
- *(fst_array_[tuple.fst_id]), tuple.fst_state);
- Arc arc;
-
- // Create a final arc when needed
- if (ComputeFinalArc(tuple, &arc))
- PushArc(s, arc);
-
- // Expand all arcs leaving the state
- for (;!aiter.Done(); aiter.Next()) {
- if (ComputeArc(tuple, aiter.Value(), &arc))
- PushArc(s, arc);
- }
-
- SetArcs(s);
- }
-
- void Expand(StateId s, const StateTuple &tuple,
- const ArcIteratorData<A> &data) {
- // If local fst is empty
- if (tuple.fst_state == kNoStateId) {
- SetArcs(s);
- return;
- }
-
- ArcIterator< Fst<A> > aiter(data);
- Arc arc;
-
- // Create a final arc when needed
- if (ComputeFinalArc(tuple, &arc))
- AddArc(s, arc);
-
- // Expand all arcs leaving the state
- for (; !aiter.Done(); aiter.Next()) {
- if (ComputeArc(tuple, aiter.Value(), &arc))
- AddArc(s, arc);
- }
-
- SetArcs(s);
- }
-
- // If arcp == 0, only returns if a final arc is required, does not
- // actually compute it.
- bool ComputeFinalArc(const StateTuple &tuple, A* arcp,
- uint32 flags = kArcValueFlags) {
- const Fst<A>* fst = fst_array_[tuple.fst_id];
- StateId fst_state = tuple.fst_state;
- if (fst_state == kNoStateId)
- return false;
-
- // if state is final, pop up stack
- const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
- if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
- if (arcp) {
- arcp->ilabel = 0;
- arcp->olabel = 0;
- if (flags & kArcNextStateValue) {
- PrefixId prefix_id = PopPrefix(stack);
- const PrefixTuple& top = stack.Top();
- arcp->nextstate = state_table_->FindState(
- StateTuple(prefix_id, top.fst_id, top.nextstate));
- }
- if (flags & kArcWeightValue)
- arcp->weight = fst->Final(fst_state);
- }
- return true;
- } else {
- return false;
- }
- }
-
- // Compute the arc in the replace fst corresponding to a given
- // in the underlying machine. Returns false if the underlying arc
- // corresponds to no arc in the replace.
- bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp,
- uint32 flags = kArcValueFlags) {
- if (!epsilon_on_replace_ &&
- (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
- *arcp = arc;
- return true;
- }
-
- if (arc.olabel == 0) { // expand local fst
- StateId nextstate = flags & kArcNextStateValue
- ? state_table_->FindState(
- StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
- : kNoStateId;
- *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
- } else {
- // check for non terminal
- typename NonTerminalHash::const_iterator it =
- nonterminal_hash_.find(arc.olabel);
- if (it != nonterminal_hash_.end()) { // recurse into non terminal
- Label nonterminal = it->second;
- const Fst<A>* nt_fst = fst_array_[nonterminal];
- PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
- tuple.fst_id, arc.nextstate);
-
- // if start state is valid replace, else arc is implicitly
- // deleted
- StateId nt_start = nt_fst->Start();
- if (nt_start != kNoStateId) {
- StateId nt_nextstate = flags & kArcNextStateValue
- ? state_table_->FindState(
- StateTuple(nt_prefix, nonterminal, nt_start))
- : kNoStateId;
- Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel;
- *arcp = A(ilabel, 0, arc.weight, nt_nextstate);
- } else {
- return false;
- }
- } else {
- StateId nextstate = flags & kArcNextStateValue
- ? state_table_->FindState(
- StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
- : kNoStateId;
- *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
- }
- }
- return true;
- }
-
- // Returns the arc iterator flags supported by this Fst.
- uint32 ArcIteratorFlags() const {
- uint32 flags = kArcValueFlags;
- if (!always_cache_)
- flags |= kArcNoCache;
- return flags;
- }
-
- T* GetStateTable() const {
- return state_table_;
- }
-
- const Fst<A>* GetFst(Label fst_id) const {
- return fst_array_[fst_id];
- }
-
- bool EpsilonOnReplace() const { return epsilon_on_replace_; }
-
- // private helper classes
- private:
- static const size_t kPrime0;
-
- // \class PrefixTuple
- // \brief Tuple of fst_id and destination state (entry in stack prefix)
- struct PrefixTuple {
- PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
-
- Label fst_id;
- StateId nextstate;
- };
-
- // \class StackPrefix
- // \brief Container for stack prefix.
- class StackPrefix {
- public:
- StackPrefix() {}
-
- // copy constructor
- StackPrefix(const StackPrefix& x) :
- prefix_(x.prefix_) {
- }
-
- void Push(StateId fst_id, StateId nextstate) {
- prefix_.push_back(PrefixTuple(fst_id, nextstate));
- }
-
- void Pop() {
- prefix_.pop_back();
- }
-
- const PrefixTuple& Top() const {
- return prefix_[prefix_.size()-1];
- }
-
- size_t Depth() const {
- return prefix_.size();
- }
-
- public:
- vector<PrefixTuple> prefix_;
- };
-
-
- // \class StackPrefixEqual
- // \brief Compare two stack prefix classes for equality
- class StackPrefixEqual {
- public:
- bool operator()(const StackPrefix& x, const StackPrefix& y) const {
- if (x.prefix_.size() != y.prefix_.size()) return false;
- for (size_t i = 0; i < x.prefix_.size(); ++i) {
- if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
- x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
- }
- return true;
- }
- };
-
- //
- // \class StackPrefixKey
- // \brief Hash function for stack prefix to prefix id
- class StackPrefixKey {
- public:
- size_t operator()(const StackPrefix& x) const {
- size_t sum = 0;
- for (size_t i = 0; i < x.prefix_.size(); ++i) {
- sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
- }
- return sum;
- }
- };
-
- typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual>
- StackPrefixHash;
-
- // private methods
- private:
- // hash stack prefix (return unique index into stackprefix array)
- PrefixId GetPrefixId(const StackPrefix& prefix) {
- typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
- if (it == prefix_hash_.end()) {
- PrefixId prefix_id = stackprefix_array_.size();
- stackprefix_array_.push_back(prefix);
- prefix_hash_[prefix] = prefix_id;
- return prefix_id;
- } else {
- return it->second;
- }
- }
-
- // prefix id after a stack pop
- PrefixId PopPrefix(StackPrefix prefix) {
- prefix.Pop();
- return GetPrefixId(prefix);
- }
-
- // prefix id after a stack push
- PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
- prefix.Push(fst_id, nextstate);
- return GetPrefixId(prefix);
- }
-
-
- // private data
- private:
- // runtime options
- bool epsilon_on_replace_;
- bool always_cache_; // Optionally caching arc iterator disabled when true
-
- // state table
- StateTable *state_table_;
-
- // cross index of unique stack prefix
- // could potentially have one copy of prefix array
- StackPrefixHash prefix_hash_;
- vector<StackPrefix> stackprefix_array_;
-
- set<Label> nonterminal_set_;
- NonTerminalHash nonterminal_hash_;
- vector<const Fst<A>*> fst_array_;
- Label root_;
-
- void operator=(const ReplaceFstImpl<A, T> &); // disallow
-};
-
-
-template <class A, class T>
-const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853;
-
-//
-// \class ReplaceFst
-// \brief Recursivively replaces arcs in the root Fst with other Fsts.
-// This version is a delayed Fst.
-//
-// ReplaceFst supports dynamic replacement of arcs in one Fst with
-// another Fst. This replacement is recursive. ReplaceFst can be used
-// to support a variety of delayed constructions such as recursive
-// transition networks, union, or closure. It is constructed with an
-// array of Fst(s). One Fst represents the root (or topology)
-// machine. The root Fst refers to other Fsts by recursively replacing
-// arcs labeled as non-terminals with the matching non-terminal
-// Fst. Currently the ReplaceFst uses the output symbols of the arcs
-// to determine whether the arc is a non-terminal arc or not. A
-// non-terminal can be any label that is not a non-zero terminal label
-// in the output alphabet.
-//
-// Note that the constructor uses a vector of pair<>. These correspond
-// to the tuple of non-terminal Label and corresponding Fst. For example
-// to implement the closure operation we need 2 Fsts. The first root
-// Fst is a single Arc on the start State that self loops, it references
-// the particular machine for which we are performing the closure operation.
-//
-// The ReplaceFst class supports an optionally caching arc iterator:
-// ArcIterator< ReplaceFst<A> >
-// The ReplaceFst need to be built such that it is known to be ilabel
-// or olabel sorted (see usage below).
-//
-// Observe that Matcher<Fst<A> > will use the optionally caching arc
-// iterator when available (Fst is ilabel sorted and matching on the
-// input, or Fst is olabel sorted and matching on the output).
-// In order to obtain the most efficient behaviour, it is recommended
-// to set 'epsilon_on_replace' to false (this means constructing acceptors
-// as transducers with epsilons on the input side of nonterminal arcs)
-// and matching on the input side.
-//
-// This class attaches interface to implementation and handles
-// reference counting, delegating most methods to ImplToFst.
-template <class A, class T = DefaultReplaceStateTable<A> >
-class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > {
- public:
- friend class ArcIterator< ReplaceFst<A, T> >;
- friend class StateIterator< ReplaceFst<A, T> >;
- friend class ReplaceFstMatcher<A, T>;
-
- typedef A Arc;
- typedef typename A::Label Label;
- typedef typename A::Weight Weight;
- typedef typename A::StateId StateId;
- typedef CacheState<A> State;
- typedef ReplaceFstImpl<A, T> Impl;
-
- using ImplToFst<Impl>::Properties;
-
- ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
- Label root)
- : ImplToFst<Impl>(new Impl(fst_array, ReplaceFstOptions<A, T>(root))) {}
-
- ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
- const ReplaceFstOptions<A, T> &opts)
- : ImplToFst<Impl>(new Impl(fst_array, opts)) {}
-
- // See Fst<>::Copy() for doc.
- ReplaceFst(const ReplaceFst<A, T>& fst, bool safe = false)
- : ImplToFst<Impl>(fst, safe) {}
-
- // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
- virtual ReplaceFst<A, T> *Copy(bool safe = false) const {
- return new ReplaceFst<A, T>(*this, safe);
- }
-
- virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
-
- virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
- GetImpl()->InitArcIterator(s, data);
- }
-
- virtual MatcherBase<A> *InitMatcher(MatchType match_type) const {
- if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
- ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
- (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
- return new ReplaceFstMatcher<A, T>(*this, match_type);
- }
- else {
- VLOG(2) << "Not using replace matcher";
- return 0;
- }
- }
-
- bool CyclicDependencies() const {
- return GetImpl()->CyclicDependencies();
- }
-
- private:
- // Makes visible to friends.
- Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
-
- void operator=(const ReplaceFst<A> &fst); // disallow
-};
-
-
-// Specialization for ReplaceFst.
-template<class A, class T>
-class StateIterator< ReplaceFst<A, T> >
- : public CacheStateIterator< ReplaceFst<A, T> > {
- public:
- explicit StateIterator(const ReplaceFst<A, T> &fst)
- : CacheStateIterator< ReplaceFst<A, T> >(fst, fst.GetImpl()) {}
-
- private:
- DISALLOW_COPY_AND_ASSIGN(StateIterator);
-};
-
-
-// Specialization for ReplaceFst.
-// Implements optional caching. It can be used as follows:
-//
-// ReplaceFst<A> replace;
-// ArcIterator< ReplaceFst<A> > aiter(replace, s);
-// // Note: ArcIterator< Fst<A> > is always a caching arc iterator.
-// aiter.SetFlags(kArcNoCache, kArcNoCache);
-// // Use the arc iterator, no arc will be cached, no state will be expanded.
-// // The varied 'kArcValueFlags' can be used to decide which part
-// // of arc values needs to be computed.
-// aiter.SetFlags(kArcILabelValue, kArcValueFlags);
-// // Only want the ilabel for this arc
-// aiter.Value(); // Does not compute the destination state.
-// aiter.Next();
-// aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
-// // Want both ilabel and nextstate for that arc
-// aiter.Value(); // Does compute the destination state and inserts it
-// // in the replace state table.
-// // No Arc has been cached at that point.
-//
-template <class A, class T>
-class ArcIterator< ReplaceFst<A, T> > {
- public:
- typedef A Arc;
- typedef typename A::StateId StateId;
-
- ArcIterator(const ReplaceFst<A, T> &fst, StateId s)
- : fst_(fst), state_(s), pos_(0), offset_(0), flags_(0), arcs_(0),
- data_flags_(0), final_flags_(0) {
- cache_data_.ref_count = 0;
- local_data_.ref_count = 0;
-
- // If FST does not support optional caching, force caching.
- if(!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
- !(fst_.GetImpl()->HasArcs(state_)))
- fst_.GetImpl()->Expand(state_);
-
- // If state is already cached, use cached arcs array.
- if (fst_.GetImpl()->HasArcs(state_)) {
- (fst_.GetImpl())->template CacheImpl<A>::InitArcIterator(state_,
- &cache_data_);
- num_arcs_ = cache_data_.narcs;
- arcs_ = cache_data_.arcs; // 'arcs_' is a ptr to the cached arcs.
- data_flags_ = kArcValueFlags; // All the arc member values are valid.
- } else { // Otherwise delay decision until Value() is called.
- tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_);
- if (tuple_.fst_state == kNoStateId) {
- num_arcs_ = 0;
- } else {
- // The decision to cache or not to cache has been defered
- // until Value() or SetFlags() is called. However, the arc
- // iterator is set up now to be ready for non-caching in order
- // to keep the Value() method simple and efficient.
- const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id);
- fst->InitArcIterator(tuple_.fst_state, &local_data_);
- // 'arcs_' is a pointer to the arcs in the underlying machine.
- arcs_ = local_data_.arcs;
- // Compute the final arc (but not its destination state)
- // if a final arc is required.
- bool has_final_arc = fst_.GetImpl()->ComputeFinalArc(
- tuple_,
- &final_arc_,
- kArcValueFlags & ~kArcNextStateValue);
- // Set the arc value flags that hold for 'final_arc_'.
- final_flags_ = kArcValueFlags & ~kArcNextStateValue;
- // Compute the number of arcs.
- num_arcs_ = local_data_.narcs;
- if (has_final_arc)
- ++num_arcs_;
- // Set the offset between the underlying arc positions and
- // the positions in the arc iterator.
- offset_ = num_arcs_ - local_data_.narcs;
- // Defers the decision to cache or not until Value() or
- // SetFlags() is called.
- data_flags_ = 0;
- }
- }
- }
-
- ~ArcIterator() {
- if (cache_data_.ref_count)
- --(*cache_data_.ref_count);
- if (local_data_.ref_count)
- --(*local_data_.ref_count);
- }
-
- void ExpandAndCache() const {
- // TODO(allauzen): revisit this
- // fst_.GetImpl()->Expand(state_, tuple_, local_data_);
- // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_,
- // &cache_data_);
- //
- fst_.InitArcIterator(state_, &cache_data_); // Expand and cache state.
- arcs_ = cache_data_.arcs; // 'arcs_' is a pointer to the cached arcs.
- data_flags_ = kArcValueFlags; // All the arc member values are valid.
- offset_ = 0; // No offset
-
- }
-
- void Init() {
- if (flags_ & kArcNoCache) { // If caching is disabled
- // 'arcs_' is a pointer to the arcs in the underlying machine.
- arcs_ = local_data_.arcs;
- // Set the arcs value flags that hold for 'arcs_'.
- data_flags_ = kArcWeightValue;
- if (!fst_.GetImpl()->EpsilonOnReplace())
- data_flags_ |= kArcILabelValue;
- // Set the offset between the underlying arc positions and
- // the positions in the arc iterator.
- offset_ = num_arcs_ - local_data_.narcs;
- } else { // Otherwise, expand and cache
- ExpandAndCache();
- }
- }
-
- bool Done() const { return pos_ >= num_arcs_; }
-
- const A& Value() const {
- // If 'data_flags_' was set to 0, non-caching was not requested
- if (!data_flags_) {
- // TODO(allauzen): revisit this.
- if (flags_ & kArcNoCache) {
- // Should never happen.
- FSTERROR() << "ReplaceFst: inconsistent arc iterator flags";
- }
- ExpandAndCache(); // Expand and cache.
- }
-
- if (pos_ - offset_ >= 0) { // The requested arc is not the 'final' arc.
- const A& arc = arcs_[pos_ - offset_];
- if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
- // If the value flags for 'arc' match the recquired value flags
- // then return 'arc'.
- return arc;
- } else {
- // Otherwise, compute the corresponding arc on-the-fly.
- fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags);
- return arc_;
- }
- } else { // The requested arc is the 'final' arc.
- if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
- // If the arc value flags that hold for the final arc
- // do not match the requested value flags, then
- // 'final_arc_' needs to be updated.
- fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_,
- flags_ & kArcValueFlags);
- final_flags_ = flags_ & kArcValueFlags;
- }
- return final_arc_;
- }
- }
-
- void Next() { ++pos_; }
-
- size_t Position() const { return pos_; }
-
- void Reset() { pos_ = 0; }
-
- void Seek(size_t pos) { pos_ = pos; }
-
- uint32 Flags() const { return flags_; }
-
- void SetFlags(uint32 f, uint32 mask) {
- // Update the flags taking into account what flags are supported
- // by the Fst.
- flags_ &= ~mask;
- flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags());
- // If non-caching is not requested (and caching has not already
- // been performed), then flush 'data_flags_' to request caching
- // during the next call to Value().
- if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
- if (!fst_.GetImpl()->HasArcs(state_))
- data_flags_ = 0;
- }
- // If 'data_flags_' has been flushed but non-caching is requested
- // before calling Value(), then set up the iterator for non-caching.
- if ((f & kArcNoCache) && (!data_flags_))
- Init();
- }
-
- private:
- const ReplaceFst<A, T> &fst_; // Reference to the FST
- StateId state_; // State in the FST
- mutable typename T::StateTuple tuple_; // Tuple corresponding to state_
-
- ssize_t pos_; // Current position
- mutable ssize_t offset_; // Offset between position in iterator and in arcs_
- ssize_t num_arcs_; // Number of arcs at state_
- uint32 flags_; // Behavorial flags for the arc iterator
- mutable Arc arc_; // Memory to temporarily store computed arcs
-
- mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache
- mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local fst
-
- mutable const A* arcs_; // Array of arcs
- mutable uint32 data_flags_; // Arc value flags valid for data in arcs_
- mutable Arc final_arc_; // Final arc (when required)
- mutable uint32 final_flags_; // Arc value flags valid for final_arc_
-
- DISALLOW_COPY_AND_ASSIGN(ArcIterator);
-};
-
-
-template <class A, class T>
-class ReplaceFstMatcher : public MatcherBase<A> {
- public:
- typedef A Arc;
- typedef typename A::StateId StateId;
- typedef typename A::Label Label;
- typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher;
-
- ReplaceFstMatcher(const ReplaceFst<A, T> &fst, fst::MatchType match_type)
- : fst_(fst),
- impl_(fst_.GetImpl()),
- s_(fst::kNoStateId),
- match_type_(match_type),
- current_loop_(false),
- final_arc_(false),
- loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
- if (match_type_ == fst::MATCH_OUTPUT)
- swap(loop_.ilabel, loop_.olabel);
- InitMatchers();
- }
-
- ReplaceFstMatcher(const ReplaceFstMatcher<A, T> &matcher, bool safe = false)
- : fst_(matcher.fst_),
- impl_(fst_.GetImpl()),
- s_(fst::kNoStateId),
- match_type_(matcher.match_type_),
- current_loop_(false),
- loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
- if (match_type_ == fst::MATCH_OUTPUT)
- swap(loop_.ilabel, loop_.olabel);
- InitMatchers();
- }
-
- // Create a local matcher for each component Fst of replace.
- // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsilonMatcher
- // is used to match each non-terminal arc, since these non-terminal
- // turn into epsilons on recursion.
- void InitMatchers() {
- const vector<const Fst<A>*>& fst_array = impl_->fst_array_;
- matcher_.resize(fst_array.size(), 0);
- for (size_t i = 0; i < fst_array.size(); ++i) {
- if (fst_array[i]) {
- matcher_[i] =
- new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList);
-
- typename set<Label>::iterator it = impl_->nonterminal_set_.begin();
- for (; it != impl_->nonterminal_set_.end(); ++it) {
- matcher_[i]->AddMultiEpsLabel(*it);
- }
- }
- }
- }
-
- virtual ReplaceFstMatcher<A, T> *Copy(bool safe = false) const {
- return new ReplaceFstMatcher<A, T>(*this, safe);
- }
-
- virtual ~ReplaceFstMatcher() {
- for (size_t i = 0; i < matcher_.size(); ++i)
- delete matcher_[i];
- }
-
- virtual MatchType Type(bool test) const {
- if (match_type_ == MATCH_NONE)
- return match_type_;
-
- uint64 true_prop = match_type_ == MATCH_INPUT ?
- kILabelSorted : kOLabelSorted;
- uint64 false_prop = match_type_ == MATCH_INPUT ?
- kNotILabelSorted : kNotOLabelSorted;
- uint64 props = fst_.Properties(true_prop | false_prop, test);
-
- if (props & true_prop)
- return match_type_;
- else if (props & false_prop)
- return MATCH_NONE;
- else
- return MATCH_UNKNOWN;
- }
-
- virtual const Fst<A> &GetFst() const {
- return fst_;
- }
-
- virtual uint64 Properties(uint64 props) const {
- return props;
- }
-
- private:
- // Set the sate from which our matching happens.
- virtual void SetState_(StateId s) {
- if (s_ == s) return;
-
- s_ = s;
- tuple_ = impl_->GetStateTable()->Tuple(s_);
- if (tuple_.fst_state == kNoStateId) {
- done_ = true;
- return;
- }
- // Get current matcher. Used for non epsilon matching
- current_matcher_ = matcher_[tuple_.fst_id];
- current_matcher_->SetState(tuple_.fst_state);
- loop_.nextstate = s_;
-
- final_arc_ = false;
- }
-
- // Search for label, from previous set state. If label == 0, first
- // hallucinate and epsilon loop, else use the underlying matcher to
- // search for the label or epsilons.
- // - Note since the ReplaceFST recursion on non-terminal arcs causes
- // epsilon transitions to be created we use the MultiEpsilonMatcher
- // to search for possible matches of non terminals.
- // - If the component Fst reaches a final state we also need to add
- // the exiting final arc.
- virtual bool Find_(Label label) {
- bool found = false;
- label_ = label;
- if (label_ == 0 || label_ == kNoLabel) {
- // Compute loop directly, saving Replace::ComputeArc
- if (label_ == 0) {
- current_loop_ = true;
- found = true;
- }
- // Search for matching multi epsilons
- final_arc_ = impl_->ComputeFinalArc(tuple_, 0);
- found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
- } else {
- // Search on sub machine directly using sub machine matcher.
- found = current_matcher_->Find(label_);
- }
- return found;
- }
-
- virtual bool Done_() const {
- return !current_loop_ && !final_arc_ && current_matcher_->Done();
- }
-
- virtual const Arc& Value_() const {
- if (current_loop_) {
- return loop_;
- }
- if (final_arc_) {
- impl_->ComputeFinalArc(tuple_, &arc_);
- return arc_;
- }
- const Arc& component_arc = current_matcher_->Value();
- impl_->ComputeArc(tuple_, component_arc, &arc_);
- return arc_;
- }
-
- virtual void Next_() {
- if (current_loop_) {
- current_loop_ = false;
- return;
- }
- if (final_arc_) {
- final_arc_ = false;
- return;
- }
- current_matcher_->Next();
- }
-
- const ReplaceFst<A, T>& fst_;
- ReplaceFstImpl<A, T> *impl_;
- LocalMatcher* current_matcher_;
- vector<LocalMatcher*> matcher_;
-
- StateId s_; // Current state
- Label label_; // Current label
-
- MatchType match_type_; // Supplied by caller
- mutable bool done_;
- mutable bool current_loop_; // Current arc is the implicit loop
- mutable bool final_arc_; // Current arc for exiting recursion
- mutable typename T::StateTuple tuple_; // Tuple corresponding to state_
- mutable Arc arc_;
- Arc loop_;
-};
-
-template <class A, class T> inline
-void ReplaceFst<A, T>::InitStateIterator(StateIteratorData<A> *data) const {
- data->base = new StateIterator< ReplaceFst<A, T> >(*this);
-}
-
-typedef ReplaceFst<StdArc> StdReplaceFst;
-
-
-// // Recursivively replaces arcs in the root Fst with other Fsts.
-// This version writes the result of replacement to an output MutableFst.
-//
-// Replace supports replacement of arcs in one Fst with another
-// Fst. This replacement is recursive. Replace takes an array of
-// Fst(s). One Fst represents the root (or topology) machine. The root
-// Fst refers to other Fsts by recursively replacing arcs labeled as
-// non-terminals with the matching non-terminal Fst. Currently Replace
-// uses the output symbols of the arcs to determine whether the arc is
-// a non-terminal arc or not. A non-terminal can be any label that is
-// not a non-zero terminal label in the output alphabet. Note that
-// input argument is a vector of pair<>. These correspond to the tuple
-// of non-terminal Label and corresponding Fst.
-template<class Arc>
-void Replace(const vector<pair<typename Arc::Label,
- const Fst<Arc>* > >& ifst_array,
- MutableFst<Arc> *ofst, typename Arc::Label root,
- bool epsilon_on_replace) {
- ReplaceFstOptions<Arc> opts(root, epsilon_on_replace);
- opts.gc_limit = 0; // Cache only the last state for fastest copy.
- *ofst = ReplaceFst<Arc>(ifst_array, opts);
-}
-
-template<class Arc>
-void Replace(const vector<pair<typename Arc::Label,
- const Fst<Arc>* > >& ifst_array,
- MutableFst<Arc> *ofst, typename Arc::Label root) {
- Replace(ifst_array, ofst, root, false);
-}
-
-} // namespace fst
-
-#endif // FST_LIB_REPLACE_H__