summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/replace-util.h
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/replace-util.h')
-rw-r--r--kaldi_io/src/tools/openfst/include/fst/replace-util.h550
1 files changed, 0 insertions, 550 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/replace-util.h b/kaldi_io/src/tools/openfst/include/fst/replace-util.h
deleted file mode 100644
index d58cb15..0000000
--- a/kaldi_io/src/tools/openfst/include/fst/replace-util.h
+++ /dev/null
@@ -1,550 +0,0 @@
-// replace-util.h
-
-
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// Copyright 2005-2010 Google, Inc.
-// Author: riley@google.com (Michael Riley)
-//
-
-// \file
-// Utility classes for the recursive replacement of Fsts (RTNs).
-
-#ifndef FST_LIB_REPLACE_UTIL_H__
-#define FST_LIB_REPLACE_UTIL_H__
-
-#include <vector>
-using std::vector;
-#include <tr1/unordered_map>
-using std::tr1::unordered_map;
-using std::tr1::unordered_multimap;
-#include <tr1/unordered_set>
-using std::tr1::unordered_set;
-using std::tr1::unordered_multiset;
-#include <map>
-
-#include <fst/connect.h>
-#include <fst/mutable-fst.h>
-#include <fst/topsort.h>
-
-
-namespace fst {
-
-template <class Arc>
-void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&,
- MutableFst<Arc> *, typename Arc::Label, bool);
-
-
-// Utility class for the recursive replacement of Fsts (RTNs). The
-// user provides a set of Label, Fst pairs at construction. These are
-// used by methods for testing cyclic dependencies and connectedness
-// and doing RTN connection and specific Fst replacement by label or
-// for various optimization properties. The modified results can be
-// obtained with the GetFstPairs() or GetMutableFstPairs() methods.
-template <class Arc>
-class ReplaceUtil {
- public:
- typedef typename Arc::Label Label;
- typedef typename Arc::Weight Weight;
- typedef typename Arc::StateId StateId;
-
- typedef pair<Label, const Fst<Arc>*> FstPair;
- typedef pair<Label, MutableFst<Arc>*> MutableFstPair;
- typedef unordered_map<Label, Label> NonTerminalHash;
-
- // Constructs from mutable Fsts; Fst ownership given to ReplaceUtil.
- ReplaceUtil(const vector<MutableFstPair> &fst_pairs,
- Label root_label, bool epsilon_on_replace = false);
-
- // Constructs from Fsts; Fst ownership retained by caller.
- ReplaceUtil(const vector<FstPair> &fst_pairs,
- Label root_label, bool epsilon_on_replace = false);
-
- // Constructs from ReplaceFst internals; ownership retained by caller.
- ReplaceUtil(const vector<const Fst<Arc> *> &fst_array,
- const NonTerminalHash &nonterminal_hash, Label root_fst,
- bool epsilon_on_replace = false);
-
- ~ReplaceUtil() {
- for (Label i = 0; i < fst_array_.size(); ++i)
- delete fst_array_[i];
- }
-
- // True if the non-terminal dependencies are cyclic. Cyclic
- // dependencies will result in an unexpandable replace fst.
- bool CyclicDependencies() const {
- GetDependencies(false);
- return depprops_ & kCyclic;
- }
-
- // Returns true if no useless Fsts, states or transitions.
- bool Connected() const {
- GetDependencies(false);
- uint64 props = kAccessible | kCoAccessible;
- for (Label i = 0; i < fst_array_.size(); ++i) {
- if (!fst_array_[i])
- continue;
- if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i])
- return false;
- }
- return true;
- }
-
- // Removes useless Fsts, states and transitions.
- void Connect();
-
- // Replaces Fsts specified by labels.
- // Does nothing if there are cyclic dependencies.
- void ReplaceLabels(const vector<Label> &labels);
-
- // Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and
- // 'nnonterm' non-terminals (updating in reverse dependency order).
- // Does nothing if there are cyclic dependencies.
- void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms);
-
- // Replaces singleton Fsts.
- // Does nothing if there are cyclic dependencies.
- void ReplaceTrivial() { ReplaceBySize(2, 1, 1); }
-
- // Replaces non-terminals that have at most 'ninstances' instances
- // (updating in dependency order).
- // Does nothing if there are cyclic dependencies.
- void ReplaceByInstances(size_t ninstances);
-
- // Replaces non-terminals that have only one instance.
- // Does nothing if there are cyclic dependencies.
- void ReplaceUnique() { ReplaceByInstances(1); }
-
- // Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil.
- void GetFstPairs(vector<FstPair> *fst_pairs);
-
- // Returns Label, MutableFst pairs; Fst ownership given to caller.
- void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs);
-
- private:
- // Per Fst statistics
- struct ReplaceStats {
- StateId nstates; // # of states
- StateId nfinal; // # of final states
- size_t narcs; // # of arcs
- Label nnonterms; // # of non-terminals in Fst
- size_t nref; // # of non-terminal instances referring to this Fst
-
- // # of times that ith Fst references this Fst
- map<Label, size_t> inref;
- // # of times that this Fst references the ith Fst
- map<Label, size_t> outref;
-
- ReplaceStats()
- : nstates(0),
- nfinal(0),
- narcs(0),
- nnonterms(0),
- nref(0) {}
- };
-
- // Check Mutable Fsts exist o.w. create them.
- void CheckMutableFsts();
-
- // Computes the dependency graph of the replace Fsts.
- // If 'stats' is true, dependency statistics computed as well.
- void GetDependencies(bool stats) const;
-
- void ClearDependencies() const {
- depfst_.DeleteStates();
- stats_.clear();
- depprops_ = 0;
- have_stats_ = false;
- }
-
- // Get topological order of dependencies. Returns false with cyclic input.
- bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const;
-
- // Update statistics assuming that jth Fst will be replaced.
- void UpdateStats(Label j);
-
- Label root_label_; // root non-terminal
- Label root_fst_; // root Fst ID
- bool epsilon_on_replace_; // see Replace()
- vector<const Fst<Arc> *> fst_array_; // Fst per ID
- vector<MutableFst<Arc> *> mutable_fst_array_; // MutableFst per ID
- vector<Label> nonterminal_array_; // Fst ID to non-terminal
- NonTerminalHash nonterminal_hash_; // non-terminal to Fst ID
- mutable VectorFst<Arc> depfst_; // Fst ID dependencies
- mutable vector<bool> depaccess_; // Fst ID accessibility
- mutable uint64 depprops_; // dependency Fst props
- mutable bool have_stats_; // have dependency statistics
- mutable vector<ReplaceStats> stats_; // Per Fst statistics
- DISALLOW_COPY_AND_ASSIGN(ReplaceUtil);
-};
-
-template <class Arc>
-ReplaceUtil<Arc>::ReplaceUtil(
- const vector<MutableFstPair> &fst_pairs,
- Label root_label, bool epsilon_on_replace)
- : root_label_(root_label),
- epsilon_on_replace_(epsilon_on_replace),
- depprops_(0),
- have_stats_(false) {
- fst_array_.push_back(0);
- mutable_fst_array_.push_back(0);
- nonterminal_array_.push_back(kNoLabel);
- for (Label i = 0; i < fst_pairs.size(); ++i) {
- Label label = fst_pairs[i].first;
- MutableFst<Arc> *fst = fst_pairs[i].second;
- nonterminal_hash_[label] = fst_array_.size();
- nonterminal_array_.push_back(label);
- fst_array_.push_back(fst);
- mutable_fst_array_.push_back(fst);
- }
- root_fst_ = nonterminal_hash_[root_label_];
- if (!root_fst_)
- FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
-}
-
-template <class Arc>
-ReplaceUtil<Arc>::ReplaceUtil(
- const vector<FstPair> &fst_pairs,
- Label root_label, bool epsilon_on_replace)
- : root_label_(root_label),
- epsilon_on_replace_(epsilon_on_replace),
- depprops_(0),
- have_stats_(false) {
- fst_array_.push_back(0);
- nonterminal_array_.push_back(kNoLabel);
- for (Label i = 0; i < fst_pairs.size(); ++i) {
- Label label = fst_pairs[i].first;
- const Fst<Arc> *fst = fst_pairs[i].second;
- nonterminal_hash_[label] = fst_array_.size();
- nonterminal_array_.push_back(label);
- fst_array_.push_back(fst->Copy());
- }
- root_fst_ = nonterminal_hash_[root_label];
- if (!root_fst_)
- FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
-}
-
-template <class Arc>
-ReplaceUtil<Arc>::ReplaceUtil(
- const vector<const Fst<Arc> *> &fst_array,
- const NonTerminalHash &nonterminal_hash, Label root_fst,
- bool epsilon_on_replace)
- : root_fst_(root_fst),
- epsilon_on_replace_(epsilon_on_replace),
- nonterminal_array_(fst_array.size()),
- nonterminal_hash_(nonterminal_hash),
- depprops_(0),
- have_stats_(false) {
- fst_array_.push_back(0);
- for (Label i = 1; i < fst_array.size(); ++i)
- fst_array_.push_back(fst_array[i]->Copy());
- for (typename NonTerminalHash::const_iterator it =
- nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it)
- nonterminal_array_[it->second] = it->first;
- root_label_ = nonterminal_array_[root_fst_];
-}
-
-template <class Arc>
-void ReplaceUtil<Arc>::GetDependencies(bool stats) const {
- if (depfst_.NumStates() > 0) {
- if (stats && !have_stats_)
- ClearDependencies();
- else
- return;
- }
-
- have_stats_ = stats;
- if (have_stats_)
- stats_.reserve(fst_array_.size());
-
- for (Label i = 0; i < fst_array_.size(); ++i) {
- depfst_.AddState();
- depfst_.SetFinal(i, Weight::One());
- if (have_stats_)
- stats_.push_back(ReplaceStats());
- }
- depfst_.SetStart(root_fst_);
-
- // An arc from each state (representing the fst) to the
- // state representing the fst being replaced
- for (Label i = 0; i < fst_array_.size(); ++i) {
- const Fst<Arc> *ifst = fst_array_[i];
- if (!ifst)
- continue;
- for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) {
- StateId s = siter.Value();
- if (have_stats_) {
- ++stats_[i].nstates;
- if (ifst->Final(s) != Weight::Zero())
- ++stats_[i].nfinal;
- }
- for (ArcIterator<Fst<Arc> > aiter(*ifst, s);
- !aiter.Done(); aiter.Next()) {
- if (have_stats_)
- ++stats_[i].narcs;
- const Arc& arc = aiter.Value();
-
- typename NonTerminalHash::const_iterator it =
- nonterminal_hash_.find(arc.olabel);
- if (it != nonterminal_hash_.end()) {
- Label j = it->second;
- depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j));
- if (have_stats_) {
- ++stats_[i].nnonterms;
- ++stats_[j].nref;
- ++stats_[j].inref[i];
- ++stats_[i].outref[j];
- }
- }
- }
- }
- }
-
- // Gets accessibility info
- SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_);
- DfsVisit(depfst_, &scc_visitor);
-}
-
-template <class Arc>
-void ReplaceUtil<Arc>::UpdateStats(Label j) {
- if (!have_stats_) {
- FSTERROR() << "ReplaceUtil::UpdateStats: stats not available";
- return;
- }
-
- if (j == root_fst_) // can't replace root
- return;
-
- typedef typename map<Label, size_t>::iterator Iter;
- for (Iter in = stats_[j].inref.begin();
- in != stats_[j].inref.end();
- ++in) {
- Label i = in->first;
- size_t ni = in->second;
- stats_[i].nstates += stats_[j].nstates * ni;
- stats_[i].narcs += (stats_[j].narcs + 1) * ni; // narcs - 1 + 2 (eps)
- stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni;
- stats_[i].outref.erase(stats_[i].outref.find(j));
- for (Iter out = stats_[j].outref.begin();
- out != stats_[j].outref.end();
- ++out) {
- Label k = out->first;
- size_t nk = out->second;
- stats_[i].outref[k] += ni * nk;
- }
- }
-
- for (Iter out = stats_[j].outref.begin();
- out != stats_[j].outref.end();
- ++out) {
- Label k = out->first;
- size_t nk = out->second;
- stats_[k].nref -= nk;
- stats_[k].inref.erase(stats_[k].inref.find(j));
- for (Iter in = stats_[j].inref.begin();
- in != stats_[j].inref.end();
- ++in) {
- Label i = in->first;
- size_t ni = in->second;
- stats_[k].inref[i] += ni * nk;
- stats_[k].nref += ni * nk;
- }
- }
-}
-
-template <class Arc>
-void ReplaceUtil<Arc>::CheckMutableFsts() {
- if (mutable_fst_array_.size() == 0) {
- for (Label i = 0; i < fst_array_.size(); ++i) {
- if (!fst_array_[i]) {
- mutable_fst_array_.push_back(0);
- } else {
- mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i]));
- delete fst_array_[i];
- fst_array_[i] = mutable_fst_array_[i];
- }
- }
- }
-}
-
-template <class Arc>
-void ReplaceUtil<Arc>::Connect() {
- CheckMutableFsts();
- uint64 props = kAccessible | kCoAccessible;
- for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
- if (!mutable_fst_array_[i])
- continue;
- if (mutable_fst_array_[i]->Properties(props, false) != props)
- fst::Connect(mutable_fst_array_[i]);
- }
- GetDependencies(false);
- for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
- MutableFst<Arc> *fst = mutable_fst_array_[i];
- if (fst && !depaccess_[i]) {
- delete fst;
- fst_array_[i] = 0;
- mutable_fst_array_[i] = 0;
- }
- }
- ClearDependencies();
-}
-
-template <class Arc>
-bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst,
- vector<Label> *toporder) const {
- // Finds topological order of dependencies.
- vector<StateId> order;
- bool acyclic = false;
-
- TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
- DfsVisit(fst, &top_order_visitor);
- if (!acyclic) {
- LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies";
- return false;
- }
-
- toporder->resize(order.size());
- for (Label i = 0; i < order.size(); ++i)
- (*toporder)[order[i]] = i;
-
- return true;
-}
-
-template <class Arc>
-void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) {
- CheckMutableFsts();
- unordered_set<Label> label_set;
- for (Label i = 0; i < labels.size(); ++i)
- if (labels[i] != root_label_) // can't replace root
- label_set.insert(labels[i]);
-
- // Finds Fst dependencies restricted to the labels requested.
- GetDependencies(false);
- VectorFst<Arc> pfst(depfst_);
- for (StateId i = 0; i < pfst.NumStates(); ++i) {
- vector<Arc> arcs;
- for (ArcIterator< VectorFst<Arc> > aiter(pfst, i);
- !aiter.Done(); aiter.Next()) {
- const Arc &arc = aiter.Value();
- Label label = nonterminal_array_[arc.nextstate];
- if (label_set.count(label) > 0)
- arcs.push_back(arc);
- }
- pfst.DeleteArcs(i);
- for (size_t j = 0; j < arcs.size(); ++j)
- pfst.AddArc(i, arcs[j]);
- }
-
- vector<Label> toporder;
- if (!GetTopOrder(pfst, &toporder)) {
- ClearDependencies();
- return;
- }
-
- // Visits Fsts in reverse topological order of dependencies and
- // performs replacements.
- for (Label o = toporder.size() - 1; o >= 0; --o) {
- vector<FstPair> fst_pairs;
- StateId s = toporder[o];
- for (ArcIterator< VectorFst<Arc> > aiter(pfst, s);
- !aiter.Done(); aiter.Next()) {
- const Arc &arc = aiter.Value();
- Label label = nonterminal_array_[arc.nextstate];
- const Fst<Arc> *fst = fst_array_[arc.nextstate];
- fst_pairs.push_back(make_pair(label, fst));
- }
- if (fst_pairs.empty())
- continue;
- Label label = nonterminal_array_[s];
- const Fst<Arc> *fst = fst_array_[s];
- fst_pairs.push_back(make_pair(label, fst));
-
- Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_);
- }
- ClearDependencies();
-}
-
-template <class Arc>
-void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs,
- size_t nnonterms) {
- vector<Label> labels;
- GetDependencies(true);
-
- vector<Label> toporder;
- if (!GetTopOrder(depfst_, &toporder)) {
- ClearDependencies();
- return;
- }
-
- for (Label o = toporder.size() - 1; o >= 0; --o) {
- Label j = toporder[o];
- if (stats_[j].nstates <= nstates &&
- stats_[j].narcs <= narcs &&
- stats_[j].nnonterms <= nnonterms) {
- labels.push_back(nonterminal_array_[j]);
- UpdateStats(j);
- }
- }
- ReplaceLabels(labels);
-}
-
-template <class Arc>
-void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) {
- vector<Label> labels;
- GetDependencies(true);
-
- vector<Label> toporder;
- if (!GetTopOrder(depfst_, &toporder)) {
- ClearDependencies();
- return;
- }
- for (Label o = 0; o < toporder.size(); ++o) {
- Label j = toporder[o];
- if (stats_[j].nref <= ninstances) {
- labels.push_back(nonterminal_array_[j]);
- UpdateStats(j);
- }
- }
- ReplaceLabels(labels);
-}
-
-template <class Arc>
-void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) {
- CheckMutableFsts();
- fst_pairs->clear();
- for (Label i = 0; i < fst_array_.size(); ++i) {
- Label label = nonterminal_array_[i];
- const Fst<Arc> *fst = fst_array_[i];
- if (!fst)
- continue;
- fst_pairs->push_back(make_pair(label, fst));
- }
-}
-
-template <class Arc>
-void ReplaceUtil<Arc>::GetMutableFstPairs(
- vector<MutableFstPair> *mutable_fst_pairs) {
- CheckMutableFsts();
- mutable_fst_pairs->clear();
- for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
- Label label = nonterminal_array_[i];
- MutableFst<Arc> *fst = mutable_fst_array_[i];
- if (!fst)
- continue;
- mutable_fst_pairs->push_back(make_pair(label, fst->Copy()));
- }
-}
-
-} // namespace fst
-
-#endif // FST_LIB_REPLACE_UTIL_H__