diff options
author | Determinant <ted.sybil@gmail.com> | 2015-08-14 11:51:42 +0800 |
---|---|---|
committer | Determinant <ted.sybil@gmail.com> | 2015-08-14 11:51:42 +0800 |
commit | 96a32415ab43377cf1575bd3f4f2980f58028209 (patch) | |
tree | 30a2d92d73e8f40ac87b79f6f56e227bfc4eea6e /kaldi_io/src/tools/openfst/include/fst/replace.h | |
parent | c177a7549bd90670af4b29fa813ddea32cfe0f78 (diff) |
add implementation for kaldi io (by ymz)
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/replace.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/replace.h | 1453 |
1 files changed, 1453 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/replace.h b/kaldi_io/src/tools/openfst/include/fst/replace.h new file mode 100644 index 0000000..ef5f6cc --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/replace.h @@ -0,0 +1,1453 @@ +// replace.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: johans@google.com (Johan Schalkwyk) +// +// \file +// Functions and classes for the recursive replacement of Fsts. +// + +#ifndef FST_LIB_REPLACE_H__ +#define FST_LIB_REPLACE_H__ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <set> +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/expanded-fst.h> +#include <fst/fst.h> +#include <fst/matcher.h> +#include <fst/replace-util.h> +#include <fst/state-table.h> +#include <fst/test-properties.h> + +namespace fst { + +// +// REPLACE STATE TUPLES AND TABLES +// +// The replace state table has the form +// +// template <class A, class P> +// class ReplaceStateTable { +// public: +// typedef A Arc; +// typedef P PrefixId; +// typedef typename A::StateId StateId; +// typedef ReplaceStateTuple<StateId, PrefixId> StateTuple; +// typedef typename A::Label Label; +// +// // Required constuctor +// ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples, +// Label root); +// +// // Required copy constructor that does not copy state +// ReplaceStateTable(const ReplaceStateTable<A,P> &table); +// +// // Lookup state ID by tuple. If it doesn't exist, then add it. +// StateId FindState(const StateTuple &tuple); +// +// // Lookup state tuple by ID. +// const StateTuple &Tuple(StateId id) const; +// }; + + +// \struct ReplaceStateTuple +// \brief Tuple of information that uniquely defines a state in replace +template <class S, class P> +struct ReplaceStateTuple { + typedef S StateId; + typedef P PrefixId; + + ReplaceStateTuple() + : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {} + + ReplaceStateTuple(PrefixId p, StateId f, StateId s) + : prefix_id(p), fst_id(f), fst_state(s) {} + + PrefixId prefix_id; // index in prefix table + StateId fst_id; // current fst being walked + StateId fst_state; // current state in fst being walked, not to be + // confused with the state_id of the combined fst +}; + + +// Equality of replace state tuples. +template <class S, class P> +inline bool operator==(const ReplaceStateTuple<S, P>& x, + const ReplaceStateTuple<S, P>& y) { + return x.prefix_id == y.prefix_id && + x.fst_id == y.fst_id && + x.fst_state == y.fst_state; +} + + +// \class ReplaceRootSelector +// Functor returning true for tuples corresponding to states in the root FST +template <class S, class P> +class ReplaceRootSelector { + public: + bool operator()(const ReplaceStateTuple<S, P> &tuple) const { + return tuple.prefix_id == 0; + } +}; + + +// \class ReplaceFingerprint +// Fingerprint for general replace state tuples. +template <class S, class P> +class ReplaceFingerprint { + public: + ReplaceFingerprint(const vector<uint64> *size_array) + : cumulative_size_array_(size_array) {} + + uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const { + return tuple.prefix_id * (cumulative_size_array_->back()) + + cumulative_size_array_->at(tuple.fst_id - 1) + + tuple.fst_state; + } + + private: + const vector<uint64> *cumulative_size_array_; +}; + + +// \class ReplaceFstStateFingerprint +// Useful when the fst_state uniquely define the tuple. +template <class S, class P> +class ReplaceFstStateFingerprint { + public: + uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const { + return tuple.fst_state; + } +}; + + +// \class ReplaceHash +// A generic hash function for replace state tuples. +template <typename S, typename P> +class ReplaceHash { + public: + size_t operator()(const ReplaceStateTuple<S, P>& t) const { + return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1; + } + private: + static const size_t kPrime0; + static const size_t kPrime1; +}; + +template <typename S, typename P> +const size_t ReplaceHash<S, P>::kPrime0 = 7853; + +template <typename S, typename P> +const size_t ReplaceHash<S, P>::kPrime1 = 7867; + +template <class A, class T> class ReplaceFstMatcher; + + +// \class VectorHashReplaceStateTable +// A two-level state table for replace. +// Warning: calls CountStates to compute the number of states of each +// component Fst. +template <class A, class P = ssize_t> +class VectorHashReplaceStateTable { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef P PrefixId; + typedef ReplaceStateTuple<StateId, P> StateTuple; + typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>, + ReplaceRootSelector<StateId, P>, + ReplaceFstStateFingerprint<StateId, P>, + ReplaceFingerprint<StateId, P> > StateTable; + + VectorHashReplaceStateTable( + const vector<pair<Label, const Fst<A>*> > &fst_tuples, + Label root) : root_size_(0) { + cumulative_size_array_.push_back(0); + for (size_t i = 0; i < fst_tuples.size(); ++i) { + if (fst_tuples[i].first == root) { + root_size_ = CountStates(*(fst_tuples[i].second)); + cumulative_size_array_.push_back(cumulative_size_array_.back()); + } else { + cumulative_size_array_.push_back(cumulative_size_array_.back() + + CountStates(*(fst_tuples[i].second))); + } + } + state_table_ = new StateTable( + new ReplaceRootSelector<StateId, P>, + new ReplaceFstStateFingerprint<StateId, P>, + new ReplaceFingerprint<StateId, P>(&cumulative_size_array_), + root_size_, + root_size_ + cumulative_size_array_.back()); + } + + VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table) + : root_size_(table.root_size_), + cumulative_size_array_(table.cumulative_size_array_) { + state_table_ = new StateTable( + new ReplaceRootSelector<StateId, P>, + new ReplaceFstStateFingerprint<StateId, P>, + new ReplaceFingerprint<StateId, P>(&cumulative_size_array_), + root_size_, + root_size_ + cumulative_size_array_.back()); + } + + ~VectorHashReplaceStateTable() { + delete state_table_; + } + + StateId FindState(const StateTuple &tuple) { + return state_table_->FindState(tuple); + } + + const StateTuple &Tuple(StateId id) const { + return state_table_->Tuple(id); + } + + private: + StateId root_size_; + vector<uint64> cumulative_size_array_; + StateTable *state_table_; +}; + + +// \class DefaultReplaceStateTable +// Default replace state table +template <class A, class P = ssize_t> +class DefaultReplaceStateTable : public CompactHashStateTable< + ReplaceStateTuple<typename A::StateId, P>, + ReplaceHash<typename A::StateId, P> > { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef P PrefixId; + typedef ReplaceStateTuple<StateId, P> StateTuple; + typedef CompactHashStateTable<StateTuple, + ReplaceHash<StateId, PrefixId> > StateTable; + + using StateTable::FindState; + using StateTable::Tuple; + + DefaultReplaceStateTable( + const vector<pair<Label, const Fst<A>*> > &fst_tuples, + Label root) {} + + DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table) + : StateTable() {} +}; + +// +// REPLACE FST CLASS +// + +// By default ReplaceFst will copy the input label of the 'replace arc'. +// For acceptors we do not want this behaviour. Instead we need to +// create an epsilon arc when recursing into the appropriate Fst. +// The 'epsilon_on_replace' option can be used to toggle this behaviour. +template <class A, class T = DefaultReplaceStateTable<A> > +struct ReplaceFstOptions : CacheOptions { + int64 root; // root rule for expansion + bool epsilon_on_replace; + bool take_ownership; // take ownership of input Fst(s) + T* state_table; + + ReplaceFstOptions(const CacheOptions &opts, int64 r) + : CacheOptions(opts), + root(r), + epsilon_on_replace(false), + take_ownership(false), + state_table(0) {} + explicit ReplaceFstOptions(int64 r) + : root(r), + epsilon_on_replace(false), + take_ownership(false), + state_table(0) {} + ReplaceFstOptions(int64 r, bool epsilon_replace_arc) + : root(r), + epsilon_on_replace(epsilon_replace_arc), + take_ownership(false), + state_table(0) {} + ReplaceFstOptions() + : root(kNoLabel), + epsilon_on_replace(false), + take_ownership(false), + state_table(0) {} +}; + + +// \class ReplaceFstImpl +// \brief Implementation class for replace class Fst +// +// The replace implementation class supports a dynamic +// expansion of a recursive transition network represented as Fst +// with dynamic replacable arcs. +// +template <class A, class T> +class ReplaceFstImpl : public CacheImpl<A> { + friend class ReplaceFstMatcher<A, T>; + + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::WriteHeader; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + using FstImpl<A>::InputSymbols; + using FstImpl<A>::OutputSymbols; + + using CacheImpl<A>::PushArc; + using CacheImpl<A>::HasArcs; + using CacheImpl<A>::HasFinal; + using CacheImpl<A>::HasStart; + using CacheImpl<A>::SetArcs; + using CacheImpl<A>::SetFinal; + using CacheImpl<A>::SetStart; + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef A Arc; + typedef unordered_map<Label, Label> NonTerminalHash; + + typedef T StateTable; + typedef typename T::PrefixId PrefixId; + typedef ReplaceStateTuple<StateId, PrefixId> StateTuple; + + // constructor for replace class implementation. + // \param fst_tuples array of label/fst tuples, one for each non-terminal + ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples, + const ReplaceFstOptions<A, T> &opts) + : CacheImpl<A>(opts), + epsilon_on_replace_(opts.epsilon_on_replace), + state_table_(opts.state_table ? opts.state_table : + new StateTable(fst_tuples, opts.root)) { + + SetType("replace"); + + if (fst_tuples.size() > 0) { + SetInputSymbols(fst_tuples[0].second->InputSymbols()); + SetOutputSymbols(fst_tuples[0].second->OutputSymbols()); + } + + bool all_negative = true; // all nonterminals are negative? + bool dense_range = true; // all nonterminals are positive + // and form a dense range containing 1? + for (size_t i = 0; i < fst_tuples.size(); ++i) { + Label nonterminal = fst_tuples[i].first; + if (nonterminal >= 0) + all_negative = false; + if (nonterminal > fst_tuples.size() || nonterminal <= 0) + dense_range = false; + } + + vector<uint64> inprops; + bool all_ilabel_sorted = true; + bool all_olabel_sorted = true; + bool all_non_empty = true; + fst_array_.push_back(0); + for (size_t i = 0; i < fst_tuples.size(); ++i) { + Label label = fst_tuples[i].first; + const Fst<A> *fst = fst_tuples[i].second; + nonterminal_hash_[label] = fst_array_.size(); + nonterminal_set_.insert(label); + fst_array_.push_back(opts.take_ownership ? fst : fst->Copy()); + if (fst->Start() == kNoStateId) + all_non_empty = false; + if(!fst->Properties(kILabelSorted, false)) + all_ilabel_sorted = false; + if(!fst->Properties(kOLabelSorted, false)) + all_olabel_sorted = false; + inprops.push_back(fst->Properties(kCopyProperties, false)); + if (i) { + if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) { + FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i + << " does not match input symbols of base Fst (0'th fst)"; + SetProperties(kError, kError); + } + if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) { + FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i + << " does not match output symbols of base Fst " + << "(0'th fst)"; + SetProperties(kError, kError); + } + } + } + Label nonterminal = nonterminal_hash_[opts.root]; + if ((nonterminal == 0) && (fst_array_.size() > 1)) { + FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '" + << opts.root << "' in the input tuple vector"; + SetProperties(kError, kError); + } + root_ = (nonterminal > 0) ? nonterminal : 1; + + SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_, + all_non_empty)); + // We assume that all terminals are positive. The resulting + // ReplaceFst is known to be kILabelSorted when all sub-FSTs are + // kILabelSorted and one of the 3 following conditions is satisfied: + // 1. 'epsilon_on_replace' is false, or + // 2. all non-terminals are negative, or + // 3. all non-terninals are positive and form a dense range containing 1. + if (all_ilabel_sorted && + (!epsilon_on_replace_ || all_negative || dense_range)) + SetProperties(kILabelSorted, kILabelSorted); + // Similarly, the resulting ReplaceFst is known to be + // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of + // the 2 following conditions is satisfied: + // 1. all non-terminals are negative, or + // 2. all non-terninals are positive and form a dense range containing 1. + if (all_olabel_sorted && (all_negative || dense_range)) + SetProperties(kOLabelSorted, kOLabelSorted); + + // Enable optional caching as long as sorted and all non empty. + if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty) + always_cache_ = false; + else + always_cache_ = true; + VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = " + << (always_cache_ ? "true" : "false"); + } + + ReplaceFstImpl(const ReplaceFstImpl& impl) + : CacheImpl<A>(impl), + epsilon_on_replace_(impl.epsilon_on_replace_), + always_cache_(impl.always_cache_), + state_table_(new StateTable(*(impl.state_table_))), + nonterminal_set_(impl.nonterminal_set_), + nonterminal_hash_(impl.nonterminal_hash_), + root_(impl.root_) { + SetType("replace"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + fst_array_.reserve(impl.fst_array_.size()); + fst_array_.push_back(0); + for (size_t i = 1; i < impl.fst_array_.size(); ++i) { + fst_array_.push_back(impl.fst_array_[i]->Copy(true)); + } + } + + ~ReplaceFstImpl() { + VLOG(2) << "~ReplaceFstImpl: gc = " + << (CacheImpl<A>::GetCacheGc() ? "true" : "false") + << ", gc_size = " << CacheImpl<A>::GetCacheSize() + << ", gc_limit = " << CacheImpl<A>::GetCacheLimit(); + + delete state_table_; + for (size_t i = 1; i < fst_array_.size(); ++i) { + delete fst_array_[i]; + } + } + + // Computes the dependency graph of the replace class and returns + // true if the dependencies are cyclic. Cyclic dependencies will result + // in an un-expandable replace fst. + bool CyclicDependencies() const { + ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_); + return replace_util.CyclicDependencies(); + } + + // Return or compute start state of replace fst + StateId Start() { + if (!HasStart()) { + if (fst_array_.size() == 1) { // no fsts defined for replace + SetStart(kNoStateId); + return kNoStateId; + } else { + const Fst<A>* fst = fst_array_[root_]; + StateId fst_start = fst->Start(); + if (fst_start == kNoStateId) // root Fst is empty + return kNoStateId; + + PrefixId prefix = GetPrefixId(StackPrefix()); + StateId start = state_table_->FindState( + StateTuple(prefix, root_, fst_start)); + SetStart(start); + return start; + } + } else { + return CacheImpl<A>::Start(); + } + } + + // return final weight of state (kInfWeight means state is not final) + Weight Final(StateId s) { + if (!HasFinal(s)) { + const StateTuple& tuple = state_table_->Tuple(s); + const StackPrefix& stack = stackprefix_array_[tuple.prefix_id]; + const Fst<A>* fst = fst_array_[tuple.fst_id]; + StateId fst_state = tuple.fst_state; + + if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0) + SetFinal(s, fst->Final(fst_state)); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (HasArcs(s)) { // If state cached, use the cached value. + return CacheImpl<A>::NumArcs(s); + } else if (always_cache_) { // If always caching, expand and cache state. + Expand(s); + return CacheImpl<A>::NumArcs(s); + } else { // Otherwise compute the number of arcs without expanding. + StateTuple tuple = state_table_->Tuple(s); + if (tuple.fst_state == kNoStateId) + return 0; + + const Fst<A>* fst = fst_array_[tuple.fst_id]; + size_t num_arcs = fst->NumArcs(tuple.fst_state); + if (ComputeFinalArc(tuple, 0)) + num_arcs++; + + return num_arcs; + } + } + + // Returns whether a given label is a non terminal + bool IsNonTerminal(Label l) const { + // TODO(allauzen): be smarter and take advantage of + // all_dense or all_negative. + // Use also in ComputeArc, this would require changes to replace + // so that recursing into an empty fst lead to a non co-accessible + // state instead of deleting the arc as done currently. + // Current use correct, since i/olabel sorted iff all_non_empty. + typename NonTerminalHash::const_iterator it = + nonterminal_hash_.find(l); + return it != nonterminal_hash_.end(); + } + + size_t NumInputEpsilons(StateId s) { + if (HasArcs(s)) { + // If state cached, use the cached value. + return CacheImpl<A>::NumInputEpsilons(s); + } else if (always_cache_ || !Properties(kILabelSorted)) { + // If always caching or if the number of input epsilons is too expensive + // to compute without caching (i.e. not ilabel sorted), + // then expand and cache state. + Expand(s); + return CacheImpl<A>::NumInputEpsilons(s); + } else { + // Otherwise, compute the number of input epsilons without caching. + StateTuple tuple = state_table_->Tuple(s); + if (tuple.fst_state == kNoStateId) + return 0; + const Fst<A>* fst = fst_array_[tuple.fst_id]; + size_t num = 0; + if (!epsilon_on_replace_) { + // If epsilon_on_replace is false, all input epsilon arcs + // are also input epsilons arcs in the underlying machine. + fst->NumInputEpsilons(tuple.fst_state); + } else { + // Otherwise, one need to consider that all non-terminal arcs + // in the underlying machine also become input epsilon arc. + ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state); + for (; !aiter.Done() && + ((aiter.Value().ilabel == 0) || + IsNonTerminal(aiter.Value().olabel)); + aiter.Next()) + ++num; + } + if (ComputeFinalArc(tuple, 0)) + num++; + return num; + } + } + + size_t NumOutputEpsilons(StateId s) { + if (HasArcs(s)) { + // If state cached, use the cached value. + return CacheImpl<A>::NumOutputEpsilons(s); + } else if(always_cache_ || !Properties(kOLabelSorted)) { + // If always caching or if the number of output epsilons is too expensive + // to compute without caching (i.e. not olabel sorted), + // then expand and cache state. + Expand(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } else { + // Otherwise, compute the number of output epsilons without caching. + StateTuple tuple = state_table_->Tuple(s); + if (tuple.fst_state == kNoStateId) + return 0; + const Fst<A>* fst = fst_array_[tuple.fst_id]; + size_t num = 0; + ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state); + for (; !aiter.Done() && + ((aiter.Value().olabel == 0) || + IsNonTerminal(aiter.Value().olabel)); + aiter.Next()) + ++num; + if (ComputeFinalArc(tuple, 0)) + num++; + return num; + } + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if (mask & kError) { + for (size_t i = 1; i < fst_array_.size(); ++i) { + if (fst_array_[i]->Properties(kError, false)) + SetProperties(kError, kError); + } + } + return FstImpl<Arc>::Properties(mask); + } + + // return the base arc iterator, if arcs have not been computed yet, + // extend/recurse for new arcs. + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<A>::InitArcIterator(s, data); + // TODO(allauzen): Set behaviour of generic iterator + // Warning: ArcIterator<ReplaceFst<A> >::InitCache() + // relies on current behaviour. + } + + + // Extend current state (walk arcs one level deep) + void Expand(StateId s) { + StateTuple tuple = state_table_->Tuple(s); + + // If local fst is empty + if (tuple.fst_state == kNoStateId) { + SetArcs(s); + return; + } + + ArcIterator< Fst<A> > aiter( + *(fst_array_[tuple.fst_id]), tuple.fst_state); + Arc arc; + + // Create a final arc when needed + if (ComputeFinalArc(tuple, &arc)) + PushArc(s, arc); + + // Expand all arcs leaving the state + for (;!aiter.Done(); aiter.Next()) { + if (ComputeArc(tuple, aiter.Value(), &arc)) + PushArc(s, arc); + } + + SetArcs(s); + } + + void Expand(StateId s, const StateTuple &tuple, + const ArcIteratorData<A> &data) { + // If local fst is empty + if (tuple.fst_state == kNoStateId) { + SetArcs(s); + return; + } + + ArcIterator< Fst<A> > aiter(data); + Arc arc; + + // Create a final arc when needed + if (ComputeFinalArc(tuple, &arc)) + AddArc(s, arc); + + // Expand all arcs leaving the state + for (; !aiter.Done(); aiter.Next()) { + if (ComputeArc(tuple, aiter.Value(), &arc)) + AddArc(s, arc); + } + + SetArcs(s); + } + + // If arcp == 0, only returns if a final arc is required, does not + // actually compute it. + bool ComputeFinalArc(const StateTuple &tuple, A* arcp, + uint32 flags = kArcValueFlags) { + const Fst<A>* fst = fst_array_[tuple.fst_id]; + StateId fst_state = tuple.fst_state; + if (fst_state == kNoStateId) + return false; + + // if state is final, pop up stack + const StackPrefix& stack = stackprefix_array_[tuple.prefix_id]; + if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) { + if (arcp) { + arcp->ilabel = 0; + arcp->olabel = 0; + if (flags & kArcNextStateValue) { + PrefixId prefix_id = PopPrefix(stack); + const PrefixTuple& top = stack.Top(); + arcp->nextstate = state_table_->FindState( + StateTuple(prefix_id, top.fst_id, top.nextstate)); + } + if (flags & kArcWeightValue) + arcp->weight = fst->Final(fst_state); + } + return true; + } else { + return false; + } + } + + // Compute the arc in the replace fst corresponding to a given + // in the underlying machine. Returns false if the underlying arc + // corresponds to no arc in the replace. + bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp, + uint32 flags = kArcValueFlags) { + if (!epsilon_on_replace_ && + (flags == (flags & (kArcILabelValue | kArcWeightValue)))) { + *arcp = arc; + return true; + } + + if (arc.olabel == 0) { // expand local fst + StateId nextstate = flags & kArcNextStateValue + ? state_table_->FindState( + StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)) + : kNoStateId; + *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate); + } else { + // check for non terminal + typename NonTerminalHash::const_iterator it = + nonterminal_hash_.find(arc.olabel); + if (it != nonterminal_hash_.end()) { // recurse into non terminal + Label nonterminal = it->second; + const Fst<A>* nt_fst = fst_array_[nonterminal]; + PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id], + tuple.fst_id, arc.nextstate); + + // if start state is valid replace, else arc is implicitly + // deleted + StateId nt_start = nt_fst->Start(); + if (nt_start != kNoStateId) { + StateId nt_nextstate = flags & kArcNextStateValue + ? state_table_->FindState( + StateTuple(nt_prefix, nonterminal, nt_start)) + : kNoStateId; + Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel; + *arcp = A(ilabel, 0, arc.weight, nt_nextstate); + } else { + return false; + } + } else { + StateId nextstate = flags & kArcNextStateValue + ? state_table_->FindState( + StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)) + : kNoStateId; + *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate); + } + } + return true; + } + + // Returns the arc iterator flags supported by this Fst. + uint32 ArcIteratorFlags() const { + uint32 flags = kArcValueFlags; + if (!always_cache_) + flags |= kArcNoCache; + return flags; + } + + T* GetStateTable() const { + return state_table_; + } + + const Fst<A>* GetFst(Label fst_id) const { + return fst_array_[fst_id]; + } + + bool EpsilonOnReplace() const { return epsilon_on_replace_; } + + // private helper classes + private: + static const size_t kPrime0; + + // \class PrefixTuple + // \brief Tuple of fst_id and destination state (entry in stack prefix) + struct PrefixTuple { + PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {} + + Label fst_id; + StateId nextstate; + }; + + // \class StackPrefix + // \brief Container for stack prefix. + class StackPrefix { + public: + StackPrefix() {} + + // copy constructor + StackPrefix(const StackPrefix& x) : + prefix_(x.prefix_) { + } + + void Push(StateId fst_id, StateId nextstate) { + prefix_.push_back(PrefixTuple(fst_id, nextstate)); + } + + void Pop() { + prefix_.pop_back(); + } + + const PrefixTuple& Top() const { + return prefix_[prefix_.size()-1]; + } + + size_t Depth() const { + return prefix_.size(); + } + + public: + vector<PrefixTuple> prefix_; + }; + + + // \class StackPrefixEqual + // \brief Compare two stack prefix classes for equality + class StackPrefixEqual { + public: + bool operator()(const StackPrefix& x, const StackPrefix& y) const { + if (x.prefix_.size() != y.prefix_.size()) return false; + for (size_t i = 0; i < x.prefix_.size(); ++i) { + if (x.prefix_[i].fst_id != y.prefix_[i].fst_id || + x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false; + } + return true; + } + }; + + // + // \class StackPrefixKey + // \brief Hash function for stack prefix to prefix id + class StackPrefixKey { + public: + size_t operator()(const StackPrefix& x) const { + size_t sum = 0; + for (size_t i = 0; i < x.prefix_.size(); ++i) { + sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0; + } + return sum; + } + }; + + typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual> + StackPrefixHash; + + // private methods + private: + // hash stack prefix (return unique index into stackprefix array) + PrefixId GetPrefixId(const StackPrefix& prefix) { + typename StackPrefixHash::iterator it = prefix_hash_.find(prefix); + if (it == prefix_hash_.end()) { + PrefixId prefix_id = stackprefix_array_.size(); + stackprefix_array_.push_back(prefix); + prefix_hash_[prefix] = prefix_id; + return prefix_id; + } else { + return it->second; + } + } + + // prefix id after a stack pop + PrefixId PopPrefix(StackPrefix prefix) { + prefix.Pop(); + return GetPrefixId(prefix); + } + + // prefix id after a stack push + PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) { + prefix.Push(fst_id, nextstate); + return GetPrefixId(prefix); + } + + + // private data + private: + // runtime options + bool epsilon_on_replace_; + bool always_cache_; // Optionally caching arc iterator disabled when true + + // state table + StateTable *state_table_; + + // cross index of unique stack prefix + // could potentially have one copy of prefix array + StackPrefixHash prefix_hash_; + vector<StackPrefix> stackprefix_array_; + + set<Label> nonterminal_set_; + NonTerminalHash nonterminal_hash_; + vector<const Fst<A>*> fst_array_; + Label root_; + + void operator=(const ReplaceFstImpl<A, T> &); // disallow +}; + + +template <class A, class T> +const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853; + +// +// \class ReplaceFst +// \brief Recursivively replaces arcs in the root Fst with other Fsts. +// This version is a delayed Fst. +// +// ReplaceFst supports dynamic replacement of arcs in one Fst with +// another Fst. This replacement is recursive. ReplaceFst can be used +// to support a variety of delayed constructions such as recursive +// transition networks, union, or closure. It is constructed with an +// array of Fst(s). One Fst represents the root (or topology) +// machine. The root Fst refers to other Fsts by recursively replacing +// arcs labeled as non-terminals with the matching non-terminal +// Fst. Currently the ReplaceFst uses the output symbols of the arcs +// to determine whether the arc is a non-terminal arc or not. A +// non-terminal can be any label that is not a non-zero terminal label +// in the output alphabet. +// +// Note that the constructor uses a vector of pair<>. These correspond +// to the tuple of non-terminal Label and corresponding Fst. For example +// to implement the closure operation we need 2 Fsts. The first root +// Fst is a single Arc on the start State that self loops, it references +// the particular machine for which we are performing the closure operation. +// +// The ReplaceFst class supports an optionally caching arc iterator: +// ArcIterator< ReplaceFst<A> > +// The ReplaceFst need to be built such that it is known to be ilabel +// or olabel sorted (see usage below). +// +// Observe that Matcher<Fst<A> > will use the optionally caching arc +// iterator when available (Fst is ilabel sorted and matching on the +// input, or Fst is olabel sorted and matching on the output). +// In order to obtain the most efficient behaviour, it is recommended +// to set 'epsilon_on_replace' to false (this means constructing acceptors +// as transducers with epsilons on the input side of nonterminal arcs) +// and matching on the input side. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A, class T = DefaultReplaceStateTable<A> > +class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > { + public: + friend class ArcIterator< ReplaceFst<A, T> >; + friend class StateIterator< ReplaceFst<A, T> >; + friend class ReplaceFstMatcher<A, T>; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef ReplaceFstImpl<A, |