summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/replace.h
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/replace.h')
-rw-r--r--kaldi_io/src/tools/openfst/include/fst/replace.h1453
1 files changed, 1453 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/replace.h b/kaldi_io/src/tools/openfst/include/fst/replace.h
new file mode 100644
index 0000000..ef5f6cc
--- /dev/null
+++ b/kaldi_io/src/tools/openfst/include/fst/replace.h
@@ -0,0 +1,1453 @@
+// replace.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: johans@google.com (Johan Schalkwyk)
+//
+// \file
+// Functions and classes for the recursive replacement of Fsts.
+//
+
+#ifndef FST_LIB_REPLACE_H__
+#define FST_LIB_REPLACE_H__
+
+#include <tr1/unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+#include <set>
+#include <string>
+#include <utility>
+using std::pair; using std::make_pair;
+#include <vector>
+using std::vector;
+
+#include <fst/cache.h>
+#include <fst/expanded-fst.h>
+#include <fst/fst.h>
+#include <fst/matcher.h>
+#include <fst/replace-util.h>
+#include <fst/state-table.h>
+#include <fst/test-properties.h>
+
+namespace fst {
+
+//
+// REPLACE STATE TUPLES AND TABLES
+//
+// The replace state table has the form
+//
+// template <class A, class P>
+// class ReplaceStateTable {
+// public:
+// typedef A Arc;
+// typedef P PrefixId;
+// typedef typename A::StateId StateId;
+// typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
+// typedef typename A::Label Label;
+//
+// // Required constructor
+// ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples,
+// Label root);
+//
+// // Required copy constructor that does not copy state
+// ReplaceStateTable(const ReplaceStateTable<A,P> &table);
+//
+// // Lookup state ID by tuple. If it doesn't exist, then add it.
+// StateId FindState(const StateTuple &tuple);
+//
+// // Lookup state tuple by ID.
+// const StateTuple &Tuple(StateId id) const;
+// };
+
+
+// \struct ReplaceStateTuple
+// \brief Triple (prefix id, fst id, fst state) that uniquely identifies a
+// state of the replace Fst.
+template <class S, class P>
+struct ReplaceStateTuple {
+  typedef S StateId;
+  typedef P PrefixId;
+
+  // Builds a tuple from explicit components.
+  ReplaceStateTuple(PrefixId p, StateId f, StateId s)
+      : prefix_id(p),
+        fst_id(f),
+        fst_state(s) {}
+
+  // Builds a placeholder tuple: prefix id -1, no fst and no state.
+  ReplaceStateTuple()
+      : prefix_id(-1),
+        fst_id(kNoStateId),
+        fst_state(kNoStateId) {}
+
+  // Index in the prefix table.
+  PrefixId prefix_id;
+  // Component fst currently being walked.
+  StateId fst_id;
+  // Current state inside that component fst; this is NOT the state id of
+  // the combined (replace) fst.
+  StateId fst_state;
+};
+
+
+// Two replace state tuples compare equal iff their prefix id, fst id and
+// fst state all agree.
+template <class S, class P>
+inline bool operator==(const ReplaceStateTuple<S, P>& x,
+                       const ReplaceStateTuple<S, P>& y) {
+  if (!(x.prefix_id == y.prefix_id))
+    return false;
+  if (!(x.fst_id == y.fst_id))
+    return false;
+  return x.fst_state == y.fst_state;
+}
+
+
+// \class ReplaceRootSelector
+// Predicate over replace state tuples: true exactly when the tuple's
+// prefix id is 0, which marks states of the root FST.
+template <class S, class P>
+class ReplaceRootSelector {
+ public:
+  bool operator()(const ReplaceStateTuple<S, P> &tuple) const {
+    const bool has_root_prefix = (tuple.prefix_id == 0);
+    return has_root_prefix;
+  }
+};
+
+
+// \class ReplaceFingerprint
+// Fingerprint for general replace state tuples, computed from a table of
+// cumulative sizes supplied by the caller.
+template <class S, class P>
+class ReplaceFingerprint {
+ public:
+  // 'size_array' is not owned; it must outlive this object.
+  ReplaceFingerprint(const vector<uint64> *size_array)
+      : cumulative_size_array_(size_array) {}
+
+  // fingerprint = prefix_id * (last cumulative size)
+  //             + cumulative size at index (fst_id - 1)
+  //             + fst_state
+  uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const {
+    const uint64 prefix_part =
+        tuple.prefix_id * (cumulative_size_array_->back());
+    const uint64 fst_part = cumulative_size_array_->at(tuple.fst_id - 1);
+    return prefix_part + fst_part + tuple.fst_state;
+  }
+
+ private:
+  const vector<uint64> *cumulative_size_array_;
+};
+
+
+// \class ReplaceFstStateFingerprint
+// Fingerprint that is simply the tuple's fst_state; useful when the
+// fst_state alone uniquely defines the tuple.
+template <class S, class P>
+class ReplaceFstStateFingerprint {
+ public:
+  uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const {
+    const uint64 fingerprint = tuple.fst_state;
+    return fingerprint;
+  }
+};
+
+
+// \class ReplaceHash
+// Generic hash for replace state tuples: the prefix id plus the fst id and
+// fst state, each scaled by its own constant multiplier.
+template <typename S, typename P>
+class ReplaceHash {
+ public:
+  size_t operator()(const ReplaceStateTuple<S, P>& t) const {
+    const size_t fst_term = t.fst_id * kPrime0;
+    const size_t state_term = t.fst_state * kPrime1;
+    return t.prefix_id + fst_term + state_term;
+  }
+
+ private:
+  static const size_t kPrime0;  // multiplier applied to fst_id
+  static const size_t kPrime1;  // multiplier applied to fst_state
+};
+
+template <typename S, typename P>
+const size_t ReplaceHash<S, P>::kPrime0 = 7853;
+
+template <typename S, typename P>
+const size_t ReplaceHash<S, P>::kPrime1 = 7867;
+
+template <class A, class T> class ReplaceFstMatcher;
+
+
+// \class VectorHashReplaceStateTable
+// A two-level state table for replace.
+// Warning: calls CountStates to compute the number of states of each
+// component Fst.
+//
+// The object owns the underlying VectorHashStateTable through a raw
+// pointer.  Copy construction builds a fresh, empty table with the same
+// size information; assignment is disallowed because the implicitly
+// generated operator would copy the owning pointer (leak + double delete).
+template <class A, class P = ssize_t>
+class VectorHashReplaceStateTable {
+ public:
+  typedef A Arc;
+  typedef typename A::StateId StateId;
+  typedef typename A::Label Label;
+  typedef P PrefixId;
+  typedef ReplaceStateTuple<StateId, P> StateTuple;
+  typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>,
+                               ReplaceRootSelector<StateId, P>,
+                               ReplaceFstStateFingerprint<StateId, P>,
+                               ReplaceFingerprint<StateId, P> > StateTable;
+
+  // Required constructor.
+  // \param fst_tuples (label, Fst) pairs, one per non-terminal
+  // \param root       label of the root Fst
+  VectorHashReplaceStateTable(
+      const vector<pair<Label, const Fst<A>*> > &fst_tuples,
+      Label root) : root_size_(0), state_table_(0) {
+    // cumulative_size_array_[i + 1] holds the total number of states of the
+    // non-root Fsts among the first i + 1 tuples; the root Fst contributes
+    // nothing here and is accounted for separately in root_size_.
+    cumulative_size_array_.push_back(0);
+    for (size_t i = 0; i < fst_tuples.size(); ++i) {
+      if (fst_tuples[i].first == root) {
+        root_size_ = CountStates(*(fst_tuples[i].second));
+        cumulative_size_array_.push_back(cumulative_size_array_.back());
+      } else {
+        cumulative_size_array_.push_back(cumulative_size_array_.back() +
+                                         CountStates(*(fst_tuples[i].second)));
+      }
+    }
+    state_table_ = NewStateTable();
+  }
+
+  // Required copy constructor: copies the size information only, the new
+  // table starts with a freshly constructed underlying state table.
+  VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table)
+      : root_size_(table.root_size_),
+        cumulative_size_array_(table.cumulative_size_array_),
+        state_table_(0) {
+    state_table_ = NewStateTable();
+  }
+
+  ~VectorHashReplaceStateTable() {
+    delete state_table_;
+  }
+
+  // Looks up the state ID of a tuple, delegating to the underlying table.
+  StateId FindState(const StateTuple &tuple) {
+    return state_table_->FindState(tuple);
+  }
+
+  // Looks up the tuple stored for a state ID.
+  const StateTuple &Tuple(StateId id) const {
+    return state_table_->Tuple(id);
+  }
+
+ private:
+  // Builds the underlying table from root_size_ and cumulative_size_array_.
+  // The fingerprint keeps a pointer to this object's own
+  // cumulative_size_array_, so it must only be called once that member is
+  // fully populated.
+  StateTable *NewStateTable() {
+    return new StateTable(
+        new ReplaceRootSelector<StateId, P>,
+        new ReplaceFstStateFingerprint<StateId, P>,
+        new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
+        root_size_,
+        root_size_ + cumulative_size_array_.back());
+  }
+
+  StateId root_size_;
+  vector<uint64> cumulative_size_array_;
+  StateTable *state_table_;
+
+  void operator=(const VectorHashReplaceStateTable<A, P> &);  // disallow
+};
+
+
+// \class DefaultReplaceStateTable
+// Default replace state table: a CompactHashStateTable keyed by
+// ReplaceStateTuple and hashed with ReplaceHash.
+template <class A, class P = ssize_t>
+class DefaultReplaceStateTable : public CompactHashStateTable<
+  ReplaceStateTuple<typename A::StateId, P>,
+  ReplaceHash<typename A::StateId, P> > {
+ public:
+  typedef A Arc;
+  typedef typename A::StateId StateId;
+  typedef typename A::Label Label;
+  typedef P PrefixId;
+  typedef ReplaceStateTuple<StateId, P> StateTuple;
+  typedef CompactHashStateTable<StateTuple,
+                                ReplaceHash<StateId, PrefixId> > StateTable;
+
+  using StateTable::FindState;
+  using StateTable::Tuple;
+
+  // Required constructor; neither argument is used by this table.
+  DefaultReplaceStateTable(
+      const vector<pair<Label, const Fst<A>*> > & /* fst_tuples */,
+      Label /* root */)
+      : StateTable() {}
+
+  // Required copy constructor; it does not copy state, the new table is
+  // default constructed.
+  DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> & /* table */)
+      : StateTable() {}
+};
+
+//
+// REPLACE FST CLASS
+//
+
+// Options for ReplaceFst.
+//
+// By default ReplaceFst copies the input label of the 'replace arc'.
+// For acceptors this is not wanted: an epsilon arc has to be created
+// instead when recursing into the appropriate Fst.  The
+// 'epsilon_on_replace' option toggles this behaviour.
+template <class A, class T = DefaultReplaceStateTable<A> >
+struct ReplaceFstOptions : CacheOptions {
+  int64 root;               // root rule for expansion
+  bool epsilon_on_replace;  // see the comment above
+  bool take_ownership;      // take ownership of input Fst(s)
+  T* state_table;           // optional state table, 0 if none is supplied
+
+  // No root rule (kNoLabel), default cache options.
+  ReplaceFstOptions()
+      : root(kNoLabel),
+        epsilon_on_replace(false),
+        take_ownership(false),
+        state_table(0) {}
+
+  // Root rule 'r', default cache options.
+  explicit ReplaceFstOptions(int64 r)
+      : root(r),
+        epsilon_on_replace(false),
+        take_ownership(false),
+        state_table(0) {}
+
+  // Root rule 'r' with an explicit 'epsilon_on_replace' setting.
+  ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
+      : root(r),
+        epsilon_on_replace(epsilon_replace_arc),
+        take_ownership(false),
+        state_table(0) {}
+
+  // Root rule 'r' with caller supplied cache options.
+  ReplaceFstOptions(const CacheOptions &opts, int64 r)
+      : CacheOptions(opts),
+        root(r),
+        epsilon_on_replace(false),
+        take_ownership(false),
+        state_table(0) {}
+};
+
+
+// \class ReplaceFstImpl
+// \brief Implementation class for replace class Fst
+//
+// The replace implementation class supports the dynamic
+// expansion of a recursive transition network represented as an Fst
+// with dynamically replaceable arcs.
+//
+template <class A, class T>
+class ReplaceFstImpl : public CacheImpl<A> {
+ friend class ReplaceFstMatcher<A, T>;
+
+ public:
+ using FstImpl<A>::SetType;
+ using FstImpl<A>::SetProperties;
+ using FstImpl<A>::WriteHeader;
+ using FstImpl<A>::SetInputSymbols;
+ using FstImpl<A>::SetOutputSymbols;
+ using FstImpl<A>::InputSymbols;
+ using FstImpl<A>::OutputSymbols;
+
+ using CacheImpl<A>::PushArc;
+ using CacheImpl<A>::HasArcs;
+ using CacheImpl<A>::HasFinal;
+ using CacheImpl<A>::HasStart;
+ using CacheImpl<A>::SetArcs;
+ using CacheImpl<A>::SetFinal;
+ using CacheImpl<A>::SetStart;
+
+ typedef typename A::Label Label;
+ typedef typename A::Weight Weight;
+ typedef typename A::StateId StateId;
+ typedef CacheState<A> State;
+ typedef A Arc;
+ typedef unordered_map<Label, Label> NonTerminalHash;
+
+ typedef T StateTable;
+ typedef typename T::PrefixId PrefixId;
+ typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
+
+ // Constructor for the replace class implementation.
+ // \param fst_tuples array of label/fst tuples, one for each non-terminal
+ // \param opts cache options plus root label, epsilon_on_replace,
+ // take_ownership and an optional state table
+ //
+ // Registers every component Fst, derives the properties of the resulting
+ // Fst and decides whether states always have to be cached.
+ ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
+ const ReplaceFstOptions<A, T> &opts)
+ : CacheImpl<A>(opts),
+ epsilon_on_replace_(opts.epsilon_on_replace),
+ // Use the caller supplied state table if there is one, otherwise
+ // build a new one from the tuples and the root label.
+ state_table_(opts.state_table ? opts.state_table :
+ new StateTable(fst_tuples, opts.root)) {
+
+ SetType("replace");
+
+ // The symbol tables of the first Fst are used for the replace Fst.
+ if (fst_tuples.size() > 0) {
+ SetInputSymbols(fst_tuples[0].second->InputSymbols());
+ SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
+ }
+
+ bool all_negative = true; // all nonterminals are negative?
+ bool dense_range = true; // all nonterminals are positive
+ // and form a dense range containing 1?
+ for (size_t i = 0; i < fst_tuples.size(); ++i) {
+ Label nonterminal = fst_tuples[i].first;
+ if (nonterminal >= 0)
+ all_negative = false;
+ if (nonterminal > fst_tuples.size() || nonterminal <= 0)
+ dense_range = false;
+ }
+
+ vector<uint64> inprops;
+ bool all_ilabel_sorted = true;
+ bool all_olabel_sorted = true;
+ bool all_non_empty = true;
+ // Slot 0 is a null placeholder, so component Fst ids start at 1 and a
+ // looked-up id of 0 can stand for "not found" below.
+ fst_array_.push_back(0);
+ for (size_t i = 0; i < fst_tuples.size(); ++i) {
+ Label label = fst_tuples[i].first;
+ const Fst<A> *fst = fst_tuples[i].second;
+ // Map the non-terminal label to the slot its Fst is about to occupy.
+ nonterminal_hash_[label] = fst_array_.size();
+ nonterminal_set_.insert(label);
+ // Either keep the caller's pointer (ownership taken) or store a copy.
+ fst_array_.push_back(opts.take_ownership ? fst : fst->Copy());
+ if (fst->Start() == kNoStateId)
+ all_non_empty = false;
+ if(!fst->Properties(kILabelSorted, false))
+ all_ilabel_sorted = false;
+ if(!fst->Properties(kOLabelSorted, false))
+ all_olabel_sorted = false;
+ inprops.push_back(fst->Properties(kCopyProperties, false));
+ // Every Fst after the first must have symbol tables compatible with
+ // the ones taken from the first Fst.
+ if (i) {
+ if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
+ FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i
+ << " does not match input symbols of base Fst (0'th fst)";
+ SetProperties(kError, kError);
+ }
+ if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
+ FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i
+ << " does not match output symbols of base Fst "
+ << "(0'th fst)";
+ SetProperties(kError, kError);
+ }
+ }
+ }
+ // NOTE(review): operator[] inserts a (root -> 0) entry when the root label
+ // is missing, so afterwards IsNonTerminal(opts.root) is true even though
+ // no Fst is registered for it -- confirm this is acceptable.
+ Label nonterminal = nonterminal_hash_[opts.root];
+ if ((nonterminal == 0) && (fst_array_.size() > 1)) {
+ FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '"
+ << opts.root << "' in the input tuple vector";
+ SetProperties(kError, kError);
+ }
+ // Fall back to slot 1 when the root label was not found.
+ root_ = (nonterminal > 0) ? nonterminal : 1;
+
+ // inprops has no placeholder entry, hence the root index is root_ - 1.
+ SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_,
+ all_non_empty));
+ // We assume that all terminals are positive. The resulting
+ // ReplaceFst is known to be kILabelSorted when all sub-FSTs are
+ // kILabelSorted and one of the 3 following conditions is satisfied:
+ // 1. 'epsilon_on_replace' is false, or
+ // 2. all non-terminals are negative, or
+ // 3. all non-terminals are positive and form a dense range containing 1.
+ if (all_ilabel_sorted &&
+ (!epsilon_on_replace_ || all_negative || dense_range))
+ SetProperties(kILabelSorted, kILabelSorted);
+ // Similarly, the resulting ReplaceFst is known to be
+ // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of
+ // the 2 following conditions is satisfied:
+ // 1. all non-terminals are negative, or
+ // 2. all non-terminals are positive and form a dense range containing 1.
+ if (all_olabel_sorted && (all_negative || dense_range))
+ SetProperties(kOLabelSorted, kOLabelSorted);
+
+ // Enable optional caching as long as sorted and all non empty.
+ if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty)
+ always_cache_ = false;
+ else
+ always_cache_ = true;
+ VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
+ << (always_cache_ ? "true" : "false");
+ }
+
+  // Copy constructor.  Builds a new state table from the source's table,
+  // copies the label bookkeeping and copies every component Fst.  The
+  // stack prefix table and its hash are not copied; they start out empty.
+  ReplaceFstImpl(const ReplaceFstImpl& impl)
+      : CacheImpl<A>(impl),
+        epsilon_on_replace_(impl.epsilon_on_replace_),
+        always_cache_(impl.always_cache_),
+        state_table_(new StateTable(*(impl.state_table_))),
+        nonterminal_set_(impl.nonterminal_set_),
+        nonterminal_hash_(impl.nonterminal_hash_),
+        root_(impl.root_) {
+    SetType("replace");
+    SetProperties(impl.Properties(), kCopyProperties);
+    SetInputSymbols(impl.InputSymbols());
+    SetOutputSymbols(impl.OutputSymbols());
+
+    const size_t num_slots = impl.fst_array_.size();
+    fst_array_.reserve(num_slots);
+    fst_array_.push_back(0);  // slot 0 is the null placeholder
+    for (size_t slot = 1; slot < num_slots; ++slot) {
+      const Fst<A> *source_fst = impl.fst_array_[slot];
+      fst_array_.push_back(source_fst->Copy(true));
+    }
+  }
+
+  // Logs the cache settings, then releases the state table and every
+  // component Fst (slot 0 holds no Fst and is skipped).
+  ~ReplaceFstImpl() {
+    VLOG(2) << "~ReplaceFstImpl: gc = "
+            << (CacheImpl<A>::GetCacheGc() ? "true" : "false")
+            << ", gc_size = " << CacheImpl<A>::GetCacheSize()
+            << ", gc_limit = " << CacheImpl<A>::GetCacheLimit();
+
+    delete state_table_;
+
+    const size_t num_slots = fst_array_.size();
+    for (size_t slot = 1; slot < num_slots; ++slot)
+      delete fst_array_[slot];
+  }
+
+  // Returns true if the dependency graph of the replace class is cyclic.
+  // Cyclic dependencies result in an un-expandable replace fst.  The check
+  // itself is delegated to ReplaceUtil.
+  bool CyclicDependencies() const {
+    return ReplaceUtil<A>(fst_array_, nonterminal_hash_, root_)
+        .CyclicDependencies();
+  }
+
+  // Returns the start state of the replace fst, computing it on first use.
+  StateId Start() {
+    if (HasStart())
+      return CacheImpl<A>::Start();
+
+    // Only the placeholder slot exists: no fsts defined for replace.
+    if (fst_array_.size() == 1) {
+      SetStart(kNoStateId);
+      return kNoStateId;
+    }
+
+    const StateId root_start = fst_array_[root_]->Start();
+    if (root_start == kNoStateId)  // root Fst is empty (result not cached)
+      return kNoStateId;
+
+    // The start state is the root's start state under the empty stack.
+    const PrefixId empty_prefix = GetPrefixId(StackPrefix());
+    const StateId start = state_table_->FindState(
+        StateTuple(empty_prefix, root_, root_start));
+    SetStart(start);
+    return start;
+  }
+
+  // Returns the final weight of state s (Weight::Zero() means that the
+  // state is not final).  A state is final in the replace fst only when it
+  // is final in its component fst and its stack prefix is empty.
+  Weight Final(StateId s) {
+    if (HasFinal(s))
+      return CacheImpl<A>::Final(s);
+
+    const StateTuple& tuple = state_table_->Tuple(s);
+    const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
+    const Fst<A>* fst = fst_array_[tuple.fst_id];
+    const StateId fst_state = tuple.fst_state;
+
+    const bool final_at_top_level =
+        fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0;
+    if (final_at_top_level) {
+      SetFinal(s, fst->Final(fst_state));
+    } else {
+      SetFinal(s, Weight::Zero());
+    }
+    return CacheImpl<A>::Final(s);
+  }
+
+  // Returns the number of arcs leaving state s.
+  size_t NumArcs(StateId s) {
+    // State already cached: use the cached value.
+    if (HasArcs(s))
+      return CacheImpl<A>::NumArcs(s);
+
+    // Always caching: expand and cache the state first.
+    if (always_cache_) {
+      Expand(s);
+      return CacheImpl<A>::NumArcs(s);
+    }
+
+    // Otherwise count the arcs without expanding the state.
+    const StateTuple tuple = state_table_->Tuple(s);
+    if (tuple.fst_state == kNoStateId)
+      return 0;
+
+    size_t arc_count = fst_array_[tuple.fst_id]->NumArcs(tuple.fst_state);
+    if (ComputeFinalArc(tuple, 0))  // one extra arc to pop the stack
+      ++arc_count;
+    return arc_count;
+  }
+
+  // Returns true if label l is registered as a non-terminal.
+  bool IsNonTerminal(Label l) const {
+    // TODO(allauzen): be smarter and take advantage of
+    // all_dense or all_negative.
+    // Use also in ComputeArc, this would require changes to replace
+    // so that recursing into an empty fst lead to a non co-accessible
+    // state instead of deleting the arc as done currently.
+    // Current use correct, since i/olabel sorted iff all_non_empty.
+    return nonterminal_hash_.find(l) != nonterminal_hash_.end();
+  }
+
+  // Returns the number of arcs leaving state s whose input label is
+  // epsilon.
+  //
+  // The cached count is used when s has already been expanded.  When
+  // caching is mandatory, or the replace fst is not known to be ilabel
+  // sorted, the state is expanded and cached first.  Otherwise the count is
+  // derived from the component Fst without caching anything.
+  size_t NumInputEpsilons(StateId s) {
+    if (HasArcs(s)) {
+      // If state cached, use the cached value.
+      return CacheImpl<A>::NumInputEpsilons(s);
+    } else if (always_cache_ || !Properties(kILabelSorted)) {
+      // If always caching or if the number of input epsilons is too expensive
+      // to compute without caching (i.e. not ilabel sorted),
+      // then expand and cache state.
+      Expand(s);
+      return CacheImpl<A>::NumInputEpsilons(s);
+    } else {
+      // Otherwise, compute the number of input epsilons without caching.
+      StateTuple tuple = state_table_->Tuple(s);
+      if (tuple.fst_state == kNoStateId)
+        return 0;
+      const Fst<A>* fst = fst_array_[tuple.fst_id];
+      size_t num = 0;
+      if (!epsilon_on_replace_) {
+        // If epsilon_on_replace is false, all input epsilon arcs
+        // are also input epsilons arcs in the underlying machine.
+        // The count has to be stored: dropping the return value would make
+        // this branch always report zero epsilons.
+        num = fst->NumInputEpsilons(tuple.fst_state);
+      } else {
+        // Otherwise, one need to consider that all non-terminal arcs
+        // in the underlying machine also become input epsilon arc.
+        ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
+        for (; !aiter.Done() &&
+                 ((aiter.Value().ilabel == 0) ||
+                  IsNonTerminal(aiter.Value().olabel));
+             aiter.Next())
+          ++num;
+      }
+      // The arc that pops the stack, if any, is an epsilon arc as well.
+      if (ComputeFinalArc(tuple, 0))
+        num++;
+      return num;
+    }
+  }
+
+  // Returns the number of arcs leaving state s whose output label is
+  // epsilon.
+  size_t NumOutputEpsilons(StateId s) {
+    // If state cached, use the cached value.
+    if (HasArcs(s))
+      return CacheImpl<A>::NumOutputEpsilons(s);
+
+    // If always caching or if the number of output epsilons is too expensive
+    // to compute without caching (i.e. not olabel sorted),
+    // then expand and cache state.
+    if (always_cache_ || !Properties(kOLabelSorted)) {
+      Expand(s);
+      return CacheImpl<A>::NumOutputEpsilons(s);
+    }
+
+    // Otherwise, compute the number of output epsilons without caching.
+    const StateTuple tuple = state_table_->Tuple(s);
+    if (tuple.fst_state == kNoStateId)
+      return 0;
+
+    const Fst<A>* fst = fst_array_[tuple.fst_id];
+    size_t epsilon_count = 0;
+    ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
+    // Count the leading run of arcs whose output label is epsilon or a
+    // non-terminal; stop at the first arc that is neither.
+    while (!aiter.Done()) {
+      const Label olabel = aiter.Value().olabel;
+      if (olabel != 0 && !IsNonTerminal(olabel))
+        break;
+      ++epsilon_count;
+      aiter.Next();
+    }
+    if (ComputeFinalArc(tuple, 0))  // the stack-popping arc counts too
+      ++epsilon_count;
+    return epsilon_count;
+  }
+
+  // Returns all FST impl properties.
+  uint64 Properties() const { return Properties(kFstProperties); }
+
+  // Returns the FST impl properties selected by 'mask'.  When the error bit
+  // is requested, the error bit is first set on this impl if any component
+  // Fst reports an error.
+  uint64 Properties(uint64 mask) const {
+    if (!(mask & kError))
+      return FstImpl<Arc>::Properties(mask);
+
+    const size_t num_slots = fst_array_.size();
+    for (size_t slot = 1; slot < num_slots; ++slot) {
+      if (fst_array_[slot]->Properties(kError, false))
+        SetProperties(kError, kError);
+    }
+    return FstImpl<Arc>::Properties(mask);
+  }
+
+  // Initialises 'data' with the base (cache) arc iterator of state s.  If
+  // the arcs of s have not been computed yet, the state is expanded first.
+  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
+    const bool already_expanded = HasArcs(s);
+    if (!already_expanded) {
+      Expand(s);
+    }
+    CacheImpl<A>::InitArcIterator(s, data);
+    // TODO(allauzen): Set behaviour of generic iterator
+    // Warning: ArcIterator<ReplaceFst<A> >::InitCache()
+    // relies on current behaviour.
+  }
+
+
+  // Expands state s one level deep: computes all of its arcs in the
+  // replace fst and stores them in the cache.
+  void Expand(StateId s) {
+    const StateTuple tuple = state_table_->Tuple(s);
+
+    // The local fst is empty: the state has no arcs.
+    if (tuple.fst_state == kNoStateId) {
+      SetArcs(s);
+      return;
+    }
+
+    ArcIterator< Fst<A> > aiter(
+        *(fst_array_[tuple.fst_id]), tuple.fst_state);
+    Arc replace_arc;
+
+    // The final arc (stack pop), when needed, comes first.
+    if (ComputeFinalArc(tuple, &replace_arc)) {
+      PushArc(s, replace_arc);
+    }
+
+    // Then one arc per arc of the underlying state, unless ComputeArc
+    // reports that the underlying arc has no counterpart.
+    while (!aiter.Done()) {
+      if (ComputeArc(tuple, aiter.Value(), &replace_arc)) {
+        PushArc(s, replace_arc);
+      }
+      aiter.Next();
+    }
+
+    SetArcs(s);
+  }
+
+  // Variant of Expand() for callers that already hold the state tuple and
+  // the arc iterator data of the underlying state.
+  void Expand(StateId s, const StateTuple &tuple,
+              const ArcIteratorData<A> &data) {
+    // The local fst is empty: the state has no arcs.
+    if (tuple.fst_state == kNoStateId) {
+      SetArcs(s);
+      return;
+    }
+
+    ArcIterator< Fst<A> > aiter(data);
+    Arc replace_arc;
+
+    // The final arc (stack pop), when needed, comes first.
+    if (ComputeFinalArc(tuple, &replace_arc)) {
+      AddArc(s, replace_arc);
+    }
+
+    // Then one arc per arc of the underlying state, unless ComputeArc
+    // reports that the underlying arc has no counterpart.
+    while (!aiter.Done()) {
+      if (ComputeArc(tuple, aiter.Value(), &replace_arc)) {
+        AddArc(s, replace_arc);
+      }
+      aiter.Next();
+    }
+
+    SetArcs(s);
+  }
+
+  // Determines whether the state described by 'tuple' needs a final arc,
+  // i.e. an epsilon arc that pops the stack when the state is final in its
+  // component fst and the stack is not empty.  Returns true if so.
+  // If arcp == 0, only reports whether such an arc is required and does
+  // not compute it.  Otherwise the fields selected by 'flags' are filled in.
+  bool ComputeFinalArc(const StateTuple &tuple, A* arcp,
+                       uint32 flags = kArcValueFlags) {
+    const Fst<A>* fst = fst_array_[tuple.fst_id];
+    const StateId fst_state = tuple.fst_state;
+    if (fst_state == kNoStateId)
+      return false;
+
+    // A final arc exists only for final states below the top level.
+    const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
+    const bool needs_final_arc =
+        fst->Final(fst_state) != Weight::Zero() && stack.Depth();
+    if (!needs_final_arc)
+      return false;
+    if (!arcp)
+      return true;
+
+    arcp->ilabel = 0;
+    arcp->olabel = 0;
+    if (flags & kArcNextStateValue) {
+      // Continue in the calling fst, at the state recorded on the stack.
+      PrefixId popped_prefix_id = PopPrefix(stack);
+      const PrefixTuple& top = stack.Top();
+      arcp->nextstate = state_table_->FindState(
+          StateTuple(popped_prefix_id, top.fst_id, top.nextstate));
+    }
+    if (flags & kArcWeightValue)
+      arcp->weight = fst->Final(fst_state);
+    return true;
+  }
+
+ // Compute the arc in the replace fst corresponding to a given arc
+ // in the underlying machine. Returns false if the underlying arc
+ // corresponds to no arc in the replace.
+ // \param tuple state tuple of the source state in the replace fst
+ // \param arc arc of the underlying component fst
+ // \param arcp receives the computed arc
+ // \param flags selects the arc fields that need to be valid; the
+ // destination state is only looked up in the state table when
+ // kArcNextStateValue is set, otherwise nextstate is kNoStateId
+ bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp,
+ uint32 flags = kArcValueFlags) {
+ // Shortcut: without epsilon_on_replace, and when nothing other than the
+ // input label and/or the weight is requested, the underlying arc is
+ // copied as is.
+ if (!epsilon_on_replace_ &&
+ (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
+ *arcp = arc;
+ return true;
+ }
+
+ if (arc.olabel == 0) { // expand local fst
+ // Output epsilon: stay in the same component fst under the same prefix.
+ StateId nextstate = flags & kArcNextStateValue
+ ? state_table_->FindState(
+ StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
+ : kNoStateId;
+ *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
+ } else {
+ // check for non terminal
+ typename NonTerminalHash::const_iterator it =
+ nonterminal_hash_.find(arc.olabel);
+ if (it != nonterminal_hash_.end()) { // recurse into non terminal
+ Label nonterminal = it->second;
+ const Fst<A>* nt_fst = fst_array_[nonterminal];
+ // Push the return point (current fst and the arc's destination) so
+ // that the final arc of the sub-fst can come back to it.
+ PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
+ tuple.fst_id, arc.nextstate);
+
+ // if start state is valid replace, else arc is implicitly
+ // deleted
+ StateId nt_start = nt_fst->Start();
+ if (nt_start != kNoStateId) {
+ StateId nt_nextstate = flags & kArcNextStateValue
+ ? state_table_->FindState(
+ StateTuple(nt_prefix, nonterminal, nt_start))
+ : kNoStateId;
+ // The output label of the replace arc is always epsilon; the input
+ // label is kept unless epsilon_on_replace is set.
+ Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel;
+ *arcp = A(ilabel, 0, arc.weight, nt_nextstate);
+ } else {
+ return false;
+ }
+ } else {
+ // Ordinary terminal arc: stay in the same component fst and prefix.
+ StateId nextstate = flags & kArcNextStateValue
+ ? state_table_->FindState(
+ StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
+ : kNoStateId;
+ *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
+ }
+ }
+ return true;
+ }
+
+  // Returns the arc iterator flags supported by this Fst: all arc value
+  // flags, plus kArcNoCache when caching is optional.
+  uint32 ArcIteratorFlags() const {
+    const uint32 value_flags = kArcValueFlags;
+    if (always_cache_)
+      return value_flags;
+    return value_flags | kArcNoCache;
+  }
+
+  // Returns the state table used by this implementation.
+  T* GetStateTable() const { return state_table_; }
+
+  // Returns the component Fst stored in slot 'fst_id'.
+  const Fst<A>* GetFst(Label fst_id) const { return fst_array_[fst_id]; }
+
+  // Returns the 'epsilon_on_replace' setting.
+  bool EpsilonOnReplace() const {
+    return epsilon_on_replace_;
+  }
+
+ // private helper classes
+ private:
+ static const size_t kPrime0;
+
+  // \class PrefixTuple
+  // \brief One entry of a stack prefix: the id of an fst and the
+  // destination state to continue from.
+  struct PrefixTuple {
+    PrefixTuple(Label f, StateId s)
+        : fst_id(f),
+          nextstate(s) {}
+
+    Label fst_id;       // id of the fst
+    StateId nextstate;  // destination state
+  };
+
+  // \class StackPrefix
+  // \brief Container for a stack prefix: a sequence of PrefixTuple entries
+  // that is pushed to and popped from at its end.
+  class StackPrefix {
+   public:
+    StackPrefix() {}
+
+    // Copy constructor: copies all entries.
+    StackPrefix(const StackPrefix& x)
+        : prefix_(x.prefix_) {}
+
+    // Appends a new entry on top of the stack.
+    void Push(StateId fst_id, StateId nextstate) {
+      const PrefixTuple entry(fst_id, nextstate);
+      prefix_.push_back(entry);
+    }
+
+    // Removes the top entry; the stack must not be empty.
+    void Pop() { prefix_.pop_back(); }
+
+    // Returns the top entry; the stack must not be empty.
+    const PrefixTuple& Top() const { return prefix_.back(); }
+
+    // Returns the number of entries on the stack.
+    size_t Depth() const { return prefix_.size(); }
+
+   public:
+    vector<PrefixTuple> prefix_;
+  };
+
+
+  // \class StackPrefixEqual
+  // \brief Equality predicate for stack prefixes: same depth and pairwise
+  // identical (fst_id, nextstate) entries.
+  class StackPrefixEqual {
+   public:
+    bool operator()(const StackPrefix& x, const StackPrefix& y) const {
+      const size_t depth = x.prefix_.size();
+      if (depth != y.prefix_.size())
+        return false;
+      for (size_t level = 0; level < depth; ++level) {
+        const PrefixTuple &lhs = x.prefix_[level];
+        const PrefixTuple &rhs = y.prefix_[level];
+        if (lhs.fst_id != rhs.fst_id)
+          return false;
+        if (lhs.nextstate != rhs.nextstate)
+          return false;
+      }
+      return true;
+    }
+  };
+
+ //
+  // \class StackPrefixKey
+  // \brief Hash function for stack prefixes: the sum over all entries of
+  // fst_id + nextstate * kPrime0.
+  class StackPrefixKey {
+   public:
+    size_t operator()(const StackPrefix& x) const {
+      size_t hash = 0;
+      const size_t depth = x.prefix_.size();
+      for (size_t level = 0; level < depth; ++level) {
+        const PrefixTuple &entry = x.prefix_[level];
+        hash += entry.fst_id + entry.nextstate*kPrime0;
+      }
+      return hash;
+    }
+  };
+
+ typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual>
+ StackPrefixHash;
+
+ // private methods
+ private:
+  // Returns the unique id (index into stackprefix_array_) of a stack
+  // prefix, registering the prefix first if it has not been seen before.
+  PrefixId GetPrefixId(const StackPrefix& prefix) {
+    typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
+    if (it != prefix_hash_.end())
+      return it->second;
+
+    // New prefix: its id is the next free index in the array.
+    PrefixId new_id = stackprefix_array_.size();
+    stackprefix_array_.push_back(prefix);
+    prefix_hash_[prefix] = new_id;
+    return new_id;
+  }
+
+  // Returns the id of the prefix obtained by popping the top entry of
+  // 'prefix'.  The argument is taken by value, so the caller's prefix is
+  // left untouched.
+  PrefixId PopPrefix(StackPrefix prefix) {
+    prefix.Pop();
+    const PrefixId popped_id = GetPrefixId(prefix);
+    return popped_id;
+  }
+
+  // Returns the id of the prefix obtained by pushing (fst_id, nextstate)
+  // on top of 'prefix'.  The argument is taken by value, so the caller's
+  // prefix is left untouched.
+  PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
+    prefix.Push(fst_id, nextstate);
+    const PrefixId pushed_id = GetPrefixId(prefix);
+    return pushed_id;
+  }
+
+
+ // private data
+ private:
+ // runtime options
+ bool epsilon_on_replace_;
+ bool always_cache_; // Optionally caching arc iterator disabled when true
+
+ // state table
+ StateTable *state_table_;
+
+ // cross index of unique stack prefix
+ // could potentially have one copy of prefix array
+ StackPrefixHash prefix_hash_;
+ vector<StackPrefix> stackprefix_array_;
+
+ set<Label> nonterminal_set_;
+ NonTerminalHash nonterminal_hash_;
+ vector<const Fst<A>*> fst_array_;
+ Label root_;
+
+ void operator=(const ReplaceFstImpl<A, T> &); // disallow
+};
+
+
+// Multiplier applied to 'nextstate' by StackPrefixKey when hashing a
+// stack prefix.
+template <class A, class T>
+const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853;
+
+//
+// \class ReplaceFst
+// \brief Recursively replaces arcs in the root Fst with other Fsts.
+// This version is a delayed Fst.
+//
+// ReplaceFst supports dynamic replacement of arcs in one Fst with
+// another Fst. This replacement is recursive. ReplaceFst can be used
+// to support a variety of delayed constructions such as recursive
+// transition networks, union, or closure. It is constructed with an
+// array of Fst(s). One Fst represents the root (or topology)
+// machine. The root Fst refers to other Fsts by recursively replacing
+// arcs labeled as non-terminals with the matching non-terminal
+// Fst. Currently the ReplaceFst uses the output symbols of the arcs
+// to determine whether the arc is a non-terminal arc or not. A
+// non-terminal can be any label that is not a non-zero terminal label
+// in the output alphabet.
+//
+// Note that the constructor uses a vector of pair<>. These correspond
+// to the tuple of non-terminal Label and corresponding Fst. For example
+// to implement the closure operation we need 2 Fsts. The first root
+// Fst is a single Arc on the start State that self loops, it references
+// the particular machine for which we are performing the closure operation.
+//
+// The ReplaceFst class supports an optionally caching arc iterator:
+// ArcIterator< ReplaceFst<A> >
+// The ReplaceFst needs to be built such that it is known to be ilabel
+// or olabel sorted (see usage below).
+//
+// Observe that Matcher<Fst<A> > will use the optionally caching arc
+// iterator when available (Fst is ilabel sorted and matching on the
+// input, or Fst is olabel sorted and matching on the output).
+// In order to obtain the most efficient behaviour, it is recommended
+// to set 'epsilon_on_replace' to false (this means constructing acceptors
+// as transducers with epsilons on the input side of nonterminal arcs)
+// and matching on the input side.
+//
+// This class attaches interface to implementation and handles
+// reference counting, delegating most methods to ImplToFst.
+template <class A, class T = DefaultReplaceStateTable<A> >
+class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > {
+ public:
+ friend class ArcIterator< ReplaceFst<A, T> >;
+ friend class StateIterator< ReplaceFst<A, T> >;
+ friend class ReplaceFstMatcher<A, T>;
+
+ typedef A Arc;
+ typedef typename A::Label Label;
+ typedef typename A::Weight Weight;
+ typedef typename A::StateId StateId;
+ typedef CacheState<A> State;
+ typedef ReplaceFstImpl<A, T> Impl;
+
+ using ImplToFst<Impl>::Properties;
+
+ ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,