diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/replace-util.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/replace-util.h | 550 |
1 files changed, 0 insertions, 550 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/replace-util.h b/kaldi_io/src/tools/openfst/include/fst/replace-util.h deleted file mode 100644 index d58cb15..0000000 --- a/kaldi_io/src/tools/openfst/include/fst/replace-util.h +++ /dev/null @@ -1,550 +0,0 @@ -// replace-util.h - - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Copyright 2005-2010 Google, Inc. -// Author: [email protected] (Michael Riley) -// - -// \file -// Utility classes for the recursive replacement of Fsts (RTNs). - -#ifndef FST_LIB_REPLACE_UTIL_H__ -#define FST_LIB_REPLACE_UTIL_H__ - -#include <vector> -using std::vector; -#include <tr1/unordered_map> -using std::tr1::unordered_map; -using std::tr1::unordered_multimap; -#include <tr1/unordered_set> -using std::tr1::unordered_set; -using std::tr1::unordered_multiset; -#include <map> - -#include <fst/connect.h> -#include <fst/mutable-fst.h> -#include <fst/topsort.h> - - -namespace fst { - -template <class Arc> -void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&, - MutableFst<Arc> *, typename Arc::Label, bool); - - -// Utility class for the recursive replacement of Fsts (RTNs). The -// user provides a set of Label, Fst pairs at construction. These are -// used by methods for testing cyclic dependencies and connectedness -// and doing RTN connection and specific Fst replacement by label or -// for various optimization properties. The modified results can be -// obtained with the GetFstPairs() or GetMutableFstPairs() methods. -template <class Arc> -class ReplaceUtil { - public: - typedef typename Arc::Label Label; - typedef typename Arc::Weight Weight; - typedef typename Arc::StateId StateId; - - typedef pair<Label, const Fst<Arc>*> FstPair; - typedef pair<Label, MutableFst<Arc>*> MutableFstPair; - typedef unordered_map<Label, Label> NonTerminalHash; - - // Constructs from mutable Fsts; Fst ownership given to ReplaceUtil. - ReplaceUtil(const vector<MutableFstPair> &fst_pairs, - Label root_label, bool epsilon_on_replace = false); - - // Constructs from Fsts; Fst ownership retained by caller. - ReplaceUtil(const vector<FstPair> &fst_pairs, - Label root_label, bool epsilon_on_replace = false); - - // Constructs from ReplaceFst internals; ownership retained by caller. - ReplaceUtil(const vector<const Fst<Arc> *> &fst_array, - const NonTerminalHash &nonterminal_hash, Label root_fst, - bool epsilon_on_replace = false); - - ~ReplaceUtil() { - for (Label i = 0; i < fst_array_.size(); ++i) - delete fst_array_[i]; - } - - // True if the non-terminal dependencies are cyclic. Cyclic - // dependencies will result in an unexpandable replace fst. - bool CyclicDependencies() const { - GetDependencies(false); - return depprops_ & kCyclic; - } - - // Returns true if no useless Fsts, states or transitions. - bool Connected() const { - GetDependencies(false); - uint64 props = kAccessible | kCoAccessible; - for (Label i = 0; i < fst_array_.size(); ++i) { - if (!fst_array_[i]) - continue; - if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i]) - return false; - } - return true; - } - - // Removes useless Fsts, states and transitions. - void Connect(); - - // Replaces Fsts specified by labels. - // Does nothing if there are cyclic dependencies. - void ReplaceLabels(const vector<Label> &labels); - - // Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and - // 'nnonterm' non-terminals (updating in reverse dependency order). - // Does nothing if there are cyclic dependencies. - void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms); - - // Replaces singleton Fsts. - // Does nothing if there are cyclic dependencies. - void ReplaceTrivial() { ReplaceBySize(2, 1, 1); } - - // Replaces non-terminals that have at most 'ninstances' instances - // (updating in dependency order). - // Does nothing if there are cyclic dependencies. - void ReplaceByInstances(size_t ninstances); - - // Replaces non-terminals that have only one instance. - // Does nothing if there are cyclic dependencies. - void ReplaceUnique() { ReplaceByInstances(1); } - - // Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil. - void GetFstPairs(vector<FstPair> *fst_pairs); - - // Returns Label, MutableFst pairs; Fst ownership given to caller. - void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs); - - private: - // Per Fst statistics - struct ReplaceStats { - StateId nstates; // # of states - StateId nfinal; // # of final states - size_t narcs; // # of arcs - Label nnonterms; // # of non-terminals in Fst - size_t nref; // # of non-terminal instances referring to this Fst - - // # of times that ith Fst references this Fst - map<Label, size_t> inref; - // # of times that this Fst references the ith Fst - map<Label, size_t> outref; - - ReplaceStats() - : nstates(0), - nfinal(0), - narcs(0), - nnonterms(0), - nref(0) {} - }; - - // Check Mutable Fsts exist o.w. create them. - void CheckMutableFsts(); - - // Computes the dependency graph of the replace Fsts. - // If 'stats' is true, dependency statistics computed as well. - void GetDependencies(bool stats) const; - - void ClearDependencies() const { - depfst_.DeleteStates(); - stats_.clear(); - depprops_ = 0; - have_stats_ = false; - } - - // Get topological order of dependencies. Returns false with cyclic input. - bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const; - - // Update statistics assuming that jth Fst will be replaced. - void UpdateStats(Label j); - - Label root_label_; // root non-terminal - Label root_fst_; // root Fst ID - bool epsilon_on_replace_; // see Replace() - vector<const Fst<Arc> *> fst_array_; // Fst per ID - vector<MutableFst<Arc> *> mutable_fst_array_; // MutableFst per ID - vector<Label> nonterminal_array_; // Fst ID to non-terminal - NonTerminalHash nonterminal_hash_; // non-terminal to Fst ID - mutable VectorFst<Arc> depfst_; // Fst ID dependencies - mutable vector<bool> depaccess_; // Fst ID accessibility - mutable uint64 depprops_; // dependency Fst props - mutable bool have_stats_; // have dependency statistics - mutable vector<ReplaceStats> stats_; // Per Fst statistics - DISALLOW_COPY_AND_ASSIGN(ReplaceUtil); -}; - -template <class Arc> -ReplaceUtil<Arc>::ReplaceUtil( - const vector<MutableFstPair> &fst_pairs, - Label root_label, bool epsilon_on_replace) - : root_label_(root_label), - epsilon_on_replace_(epsilon_on_replace), - depprops_(0), - have_stats_(false) { - fst_array_.push_back(0); - mutable_fst_array_.push_back(0); - nonterminal_array_.push_back(kNoLabel); - for (Label i = 0; i < fst_pairs.size(); ++i) { - Label label = fst_pairs[i].first; - MutableFst<Arc> *fst = fst_pairs[i].second; - nonterminal_hash_[label] = fst_array_.size(); - nonterminal_array_.push_back(label); - fst_array_.push_back(fst); - mutable_fst_array_.push_back(fst); - } - root_fst_ = nonterminal_hash_[root_label_]; - if (!root_fst_) - FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_; -} - -template <class Arc> -ReplaceUtil<Arc>::ReplaceUtil( - const vector<FstPair> &fst_pairs, - Label root_label, bool epsilon_on_replace) - : root_label_(root_label), - epsilon_on_replace_(epsilon_on_replace), - depprops_(0), - have_stats_(false) { - fst_array_.push_back(0); - nonterminal_array_.push_back(kNoLabel); - for (Label i = 0; i < fst_pairs.size(); ++i) { - Label label = fst_pairs[i].first; - const Fst<Arc> *fst = fst_pairs[i].second; - nonterminal_hash_[label] = fst_array_.size(); - nonterminal_array_.push_back(label); - fst_array_.push_back(fst->Copy()); - } - root_fst_ = nonterminal_hash_[root_label]; - if (!root_fst_) - FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_; -} - -template <class Arc> -ReplaceUtil<Arc>::ReplaceUtil( - const vector<const Fst<Arc> *> &fst_array, - const NonTerminalHash &nonterminal_hash, Label root_fst, - bool epsilon_on_replace) - : root_fst_(root_fst), - epsilon_on_replace_(epsilon_on_replace), - nonterminal_array_(fst_array.size()), - nonterminal_hash_(nonterminal_hash), - depprops_(0), - have_stats_(false) { - fst_array_.push_back(0); - for (Label i = 1; i < fst_array.size(); ++i) - fst_array_.push_back(fst_array[i]->Copy()); - for (typename NonTerminalHash::const_iterator it = - nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it) - nonterminal_array_[it->second] = it->first; - root_label_ = nonterminal_array_[root_fst_]; -} - -template <class Arc> -void ReplaceUtil<Arc>::GetDependencies(bool stats) const { - if (depfst_.NumStates() > 0) { - if (stats && !have_stats_) - ClearDependencies(); - else - return; - } - - have_stats_ = stats; - if (have_stats_) - stats_.reserve(fst_array_.size()); - - for (Label i = 0; i < fst_array_.size(); ++i) { - depfst_.AddState(); - depfst_.SetFinal(i, Weight::One()); - if (have_stats_) - stats_.push_back(ReplaceStats()); - } - depfst_.SetStart(root_fst_); - - // An arc from each state (representing the fst) to the - // state representing the fst being replaced - for (Label i = 0; i < fst_array_.size(); ++i) { - const Fst<Arc> *ifst = fst_array_[i]; - if (!ifst) - continue; - for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) { - StateId s = siter.Value(); - if (have_stats_) { - ++stats_[i].nstates; - if (ifst->Final(s) != Weight::Zero()) - ++stats_[i].nfinal; - } - for (ArcIterator<Fst<Arc> > aiter(*ifst, s); - !aiter.Done(); aiter.Next()) { - if (have_stats_) - ++stats_[i].narcs; - const Arc& arc = aiter.Value(); - - typename NonTerminalHash::const_iterator it = - nonterminal_hash_.find(arc.olabel); - if (it != nonterminal_hash_.end()) { - Label j = it->second; - depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j)); - if (have_stats_) { - ++stats_[i].nnonterms; - ++stats_[j].nref; - ++stats_[j].inref[i]; - ++stats_[i].outref[j]; - } - } - } - } - } - - // Gets accessibility info - SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_); - DfsVisit(depfst_, &scc_visitor); -} - -template <class Arc> -void ReplaceUtil<Arc>::UpdateStats(Label j) { - if (!have_stats_) { - FSTERROR() << "ReplaceUtil::UpdateStats: stats not available"; - return; - } - - if (j == root_fst_) // can't replace root - return; - - typedef typename map<Label, size_t>::iterator Iter; - for (Iter in = stats_[j].inref.begin(); - in != stats_[j].inref.end(); - ++in) { - Label i = in->first; - size_t ni = in->second; - stats_[i].nstates += stats_[j].nstates * ni; - stats_[i].narcs += (stats_[j].narcs + 1) * ni; // narcs - 1 + 2 (eps) - stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni; - stats_[i].outref.erase(stats_[i].outref.find(j)); - for (Iter out = stats_[j].outref.begin(); - out != stats_[j].outref.end(); - ++out) { - Label k = out->first; - size_t nk = out->second; - stats_[i].outref[k] += ni * nk; - } - } - - for (Iter out = stats_[j].outref.begin(); - out != stats_[j].outref.end(); - ++out) { - Label k = out->first; - size_t nk = out->second; - stats_[k].nref -= nk; - stats_[k].inref.erase(stats_[k].inref.find(j)); - for (Iter in = stats_[j].inref.begin(); - in != stats_[j].inref.end(); - ++in) { - Label i = in->first; - size_t ni = in->second; - stats_[k].inref[i] += ni * nk; - stats_[k].nref += ni * nk; - } - } -} - -template <class Arc> -void ReplaceUtil<Arc>::CheckMutableFsts() { - if (mutable_fst_array_.size() == 0) { - for (Label i = 0; i < fst_array_.size(); ++i) { - if (!fst_array_[i]) { - mutable_fst_array_.push_back(0); - } else { - mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i])); - delete fst_array_[i]; - fst_array_[i] = mutable_fst_array_[i]; - } - } - } -} - -template <class Arc> -void ReplaceUtil<Arc>::Connect() { - CheckMutableFsts(); - uint64 props = kAccessible | kCoAccessible; - for (Label i = 0; i < mutable_fst_array_.size(); ++i) { - if (!mutable_fst_array_[i]) - continue; - if (mutable_fst_array_[i]->Properties(props, false) != props) - fst::Connect(mutable_fst_array_[i]); - } - GetDependencies(false); - for (Label i = 0; i < mutable_fst_array_.size(); ++i) { - MutableFst<Arc> *fst = mutable_fst_array_[i]; - if (fst && !depaccess_[i]) { - delete fst; - fst_array_[i] = 0; - mutable_fst_array_[i] = 0; - } - } - ClearDependencies(); -} - -template <class Arc> -bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst, - vector<Label> *toporder) const { - // Finds topological order of dependencies. - vector<StateId> order; - bool acyclic = false; - - TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); - DfsVisit(fst, &top_order_visitor); - if (!acyclic) { - LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies"; - return false; - } - - toporder->resize(order.size()); - for (Label i = 0; i < order.size(); ++i) - (*toporder)[order[i]] = i; - - return true; -} - -template <class Arc> -void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) { - CheckMutableFsts(); - unordered_set<Label> label_set; - for (Label i = 0; i < labels.size(); ++i) - if (labels[i] != root_label_) // can't replace root - label_set.insert(labels[i]); - - // Finds Fst dependencies restricted to the labels requested. - GetDependencies(false); - VectorFst<Arc> pfst(depfst_); - for (StateId i = 0; i < pfst.NumStates(); ++i) { - vector<Arc> arcs; - for (ArcIterator< VectorFst<Arc> > aiter(pfst, i); - !aiter.Done(); aiter.Next()) { - const Arc &arc = aiter.Value(); - Label label = nonterminal_array_[arc.nextstate]; - if (label_set.count(label) > 0) - arcs.push_back(arc); - } - pfst.DeleteArcs(i); - for (size_t j = 0; j < arcs.size(); ++j) - pfst.AddArc(i, arcs[j]); - } - - vector<Label> toporder; - if (!GetTopOrder(pfst, &toporder)) { - ClearDependencies(); - return; - } - - // Visits Fsts in reverse topological order of dependencies and - // performs replacements. - for (Label o = toporder.size() - 1; o >= 0; --o) { - vector<FstPair> fst_pairs; - StateId s = toporder[o]; - for (ArcIterator< VectorFst<Arc> > aiter(pfst, s); - !aiter.Done(); aiter.Next()) { - const Arc &arc = aiter.Value(); - Label label = nonterminal_array_[arc.nextstate]; - const Fst<Arc> *fst = fst_array_[arc.nextstate]; - fst_pairs.push_back(make_pair(label, fst)); - } - if (fst_pairs.empty()) - continue; - Label label = nonterminal_array_[s]; - const Fst<Arc> *fst = fst_array_[s]; - fst_pairs.push_back(make_pair(label, fst)); - - Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_); - } - ClearDependencies(); -} - -template <class Arc> -void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs, - size_t nnonterms) { - vector<Label> labels; - GetDependencies(true); - - vector<Label> toporder; - if (!GetTopOrder(depfst_, &toporder)) { - ClearDependencies(); - return; - } - - for (Label o = toporder.size() - 1; o >= 0; --o) { - Label j = toporder[o]; - if (stats_[j].nstates <= nstates && - stats_[j].narcs <= narcs && - stats_[j].nnonterms <= nnonterms) { - labels.push_back(nonterminal_array_[j]); - UpdateStats(j); - } - } - ReplaceLabels(labels); -} - -template <class Arc> -void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) { - vector<Label> labels; - GetDependencies(true); - - vector<Label> toporder; - if (!GetTopOrder(depfst_, &toporder)) { - ClearDependencies(); - return; - } - for (Label o = 0; o < toporder.size(); ++o) { - Label j = toporder[o]; - if (stats_[j].nref <= ninstances) { - labels.push_back(nonterminal_array_[j]); - UpdateStats(j); - } - } - ReplaceLabels(labels); -} - -template <class Arc> -void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) { - CheckMutableFsts(); - fst_pairs->clear(); - for (Label i = 0; i < fst_array_.size(); ++i) { - Label label = nonterminal_array_[i]; - const Fst<Arc> *fst = fst_array_[i]; - if (!fst) - continue; - fst_pairs->push_back(make_pair(label, fst)); - } -} - -template <class Arc> -void ReplaceUtil<Arc>::GetMutableFstPairs( - vector<MutableFstPair> *mutable_fst_pairs) { - CheckMutableFsts(); - mutable_fst_pairs->clear(); - for (Label i = 0; i < mutable_fst_array_.size(); ++i) { - Label label = nonterminal_array_[i]; - MutableFst<Arc> *fst = mutable_fst_array_[i]; - if (!fst) - continue; - mutable_fst_pairs->push_back(make_pair(label, fst->Copy())); - } -} - -} // namespace fst - -#endif // FST_LIB_REPLACE_UTIL_H__ |