summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/prune.h
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/prune.h')
-rw-r--r--kaldi_io/src/tools/openfst/include/fst/prune.h339
1 files changed, 339 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/prune.h b/kaldi_io/src/tools/openfst/include/fst/prune.h
new file mode 100644
index 0000000..5ea5b4d
--- /dev/null
+++ b/kaldi_io/src/tools/openfst/include/fst/prune.h
@@ -0,0 +1,339 @@
+// prune.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: allauzen@google.com (Cyril Allauzen)
+//
+// \file
+// Functions implementing pruning.
+
+#ifndef FST_LIB_PRUNE_H__
+#define FST_LIB_PRUNE_H__
+
+#include <vector>
+using std::vector;
+
+#include <fst/arcfilter.h>
+#include <fst/heap.h>
+#include <fst/shortest-distance.h>
+
+
+namespace fst {
+
+template <class A, class ArcFilter>
+class PruneOptions {
+ public:
+ typedef typename A::Weight Weight;
+ typedef typename A::StateId StateId;
+
+ // Pruning weight threshold.
+ Weight weight_threshold;
+ // Pruning state threshold.
+ StateId state_threshold;
+ // Arc filter.
+ ArcFilter filter;
+ // If non-zero, passes in pre-computed shortest distance to final states.
+ const vector<Weight> *distance;
+ // Determines the degree of convergence required when computing shortest
+ // distances.
+ float delta;
+
+ explicit PruneOptions(const Weight& w, StateId s, ArcFilter f,
+ vector<Weight> *d = 0, float e = kDelta)
+ : weight_threshold(w),
+ state_threshold(s),
+ filter(f),
+ distance(d),
+ delta(e) {}
+ private:
+ PruneOptions(); // disallow
+};
+
+
+template <class S, class W>
+class PruneCompare {
+ public:
+ typedef S StateId;
+ typedef W Weight;
+
+ PruneCompare(const vector<Weight> &idistance,
+ const vector<Weight> &fdistance)
+ : idistance_(idistance), fdistance_(fdistance) {}
+
+ bool operator()(const StateId x, const StateId y) const {
+ Weight wx = Times(x < idistance_.size() ? idistance_[x] : Weight::Zero(),
+ x < fdistance_.size() ? fdistance_[x] : Weight::Zero());
+ Weight wy = Times(y < idistance_.size() ? idistance_[y] : Weight::Zero(),
+ y < fdistance_.size() ? fdistance_[y] : Weight::Zero());
+ return less_(wx, wy);
+ }
+
+ private:
+ const vector<Weight> &idistance_;
+ const vector<Weight> &fdistance_;
+ NaturalLess<Weight> less_;
+};
+
+
+
+// Pruning algorithm: this version modifies its input and it takes an
+// options class as an argment. Delete states and arcs in 'fst' that
+// do not belong to a successful path whose weight is no more than
+// the weight of the shortest path Times() 'opts.weight_threshold'.
+// When 'opts.state_threshold != kNoStateId', the resulting transducer
+// will restricted further to have at most 'opts.state_threshold'
+// states. Weights need to be commutative and have the path
+// property. The weight 'w' of any cycle needs to be bounded, i.e.,
+// 'Plus(w, W::One()) = One()'.
+template <class Arc, class ArcFilter>
+void Prune(MutableFst<Arc> *fst,
+ const PruneOptions<Arc, ArcFilter> &opts) {
+ typedef typename Arc::Weight Weight;
+ typedef typename Arc::StateId StateId;
+
+ if ((Weight::Properties() & (kPath | kCommutative))
+ != (kPath | kCommutative)) {
+ FSTERROR() << "Prune: Weight needs to have the path property and"
+ << " be commutative: "
+ << Weight::Type();
+ fst->SetProperties(kError, kError);
+ return;
+ }
+ StateId ns = fst->NumStates();
+ if (ns == 0) return;
+ vector<Weight> idistance(ns, Weight::Zero());
+ vector<Weight> tmp;
+ if (!opts.distance) {
+ tmp.reserve(ns);
+ ShortestDistance(*fst, &tmp, true, opts.delta);
+ }
+ const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
+
+ if ((opts.state_threshold == 0) ||
+ (fdistance->size() <= fst->Start()) ||
+ ((*fdistance)[fst->Start()] == Weight::Zero())) {
+ fst->DeleteStates();
+ return;
+ }
+ PruneCompare<StateId, Weight> compare(idistance, *fdistance);
+ Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
+ vector<bool> visited(ns, false);
+ vector<size_t> enqueued(ns, kNoKey);
+ vector<StateId> dead;
+ dead.push_back(fst->AddState());
+ NaturalLess<Weight> less;
+ Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold);
+
+ StateId num_visited = 0;
+ StateId s = fst->Start();
+ if (!less(limit, (*fdistance)[s])) {
+ idistance[s] = Weight::One();
+ enqueued[s] = heap.Insert(s);
+ ++num_visited;
+ }
+
+ while (!heap.Empty()) {
+ s = heap.Top();
+ heap.Pop();
+ enqueued[s] = kNoKey;
+ visited[s] = true;
+ if (less(limit, Times(idistance[s], fst->Final(s))))
+ fst->SetFinal(s, Weight::Zero());
+ for (MutableArcIterator< MutableFst<Arc> > ait(fst, s);
+ !ait.Done();
+ ait.Next()) {
+ Arc arc = ait.Value();
+ if (!opts.filter(arc)) continue;
+ Weight weight = Times(Times(idistance[s], arc.weight),
+ arc.nextstate < fdistance->size()
+ ? (*fdistance)[arc.nextstate]
+ : Weight::Zero());
+ if (less(limit, weight)) {
+ arc.nextstate = dead[0];
+ ait.SetValue(arc);
+ continue;
+ }
+ if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate]))
+ idistance[arc.nextstate] = Times(idistance[s], arc.weight);
+ if (visited[arc.nextstate]) continue;
+ if ((opts.state_threshold != kNoStateId) &&
+ (num_visited >= opts.state_threshold))
+ continue;
+ if (enqueued[arc.nextstate] == kNoKey) {
+ enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
+ ++num_visited;
+ } else {
+ heap.Update(enqueued[arc.nextstate], arc.nextstate);
+ }
+ }
+ }
+ for (size_t i = 0; i < visited.size(); ++i)
+ if (!visited[i]) dead.push_back(i);
+ fst->DeleteStates(dead);
+}
+
+
+// Pruning algorithm: this version modifies its input and simply takes
+// the pruning threshold as an argument. Delete states and arcs in
+// 'fst' that do not belong to a successful path whose weight is no
+// more than the weight of the shortest path Times()
+// 'weight_threshold'. When 'state_threshold != kNoStateId', the
+// resulting transducer will be restricted further to have at most
+// 'opts.state_threshold' states. Weights need to be commutative and
+// have the path property. The weight 'w' of any cycle needs to be
+// bounded, i.e., 'Plus(w, W::One()) = One()'.
+template <class Arc>
+void Prune(MutableFst<Arc> *fst,
+ typename Arc::Weight weight_threshold,
+ typename Arc::StateId state_threshold = kNoStateId,
+ double delta = kDelta) {
+ PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
+ AnyArcFilter<Arc>(), 0, delta);
+ Prune(fst, opts);
+}
+
+
+// Pruning algorithm: this version writes the pruned input Fst to an
+// output MutableFst and it takes an options class as an argument.
+// 'ofst' contains states and arcs that belong to a successful path in
+// 'ifst' whose weight is no more than the weight of the shortest path
+// Times() 'opts.weight_threshold'. When 'opts.state_threshold !=
+// kNoStateId', 'ofst' will be restricted further to have at most
+// 'opts.state_threshold' states. Weights need to be commutative and
+// have the path property. The weight 'w' of any cycle needs to be
+// bounded, i.e., 'Plus(w, W::One()) = One()'.
+template <class Arc, class ArcFilter>
+void Prune(const Fst<Arc> &ifst,
+ MutableFst<Arc> *ofst,
+ const PruneOptions<Arc, ArcFilter> &opts) {
+ typedef typename Arc::Weight Weight;
+ typedef typename Arc::StateId StateId;
+
+ if ((Weight::Properties() & (kPath | kCommutative))
+ != (kPath | kCommutative)) {
+ FSTERROR() << "Prune: Weight needs to have the path property and"
+ << " be commutative: "
+ << Weight::Type();
+ ofst->SetProperties(kError, kError);
+ return;
+ }
+ ofst->DeleteStates();
+ ofst->SetInputSymbols(ifst.InputSymbols());
+ ofst->SetOutputSymbols(ifst.OutputSymbols());
+ if (ifst.Start() == kNoStateId)
+ return;
+ NaturalLess<Weight> less;
+ if (less(opts.weight_threshold, Weight::One()) ||
+ (opts.state_threshold == 0))
+ return;
+ vector<Weight> idistance;
+ vector<Weight> tmp;
+ if (!opts.distance)
+ ShortestDistance(ifst, &tmp, true, opts.delta);
+ const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
+
+ if ((fdistance->size() <= ifst.Start()) ||
+ ((*fdistance)[ifst.Start()] == Weight::Zero())) {
+ return;
+ }
+ PruneCompare<StateId, Weight> compare(idistance, *fdistance);
+ Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
+ vector<StateId> copy;
+ vector<size_t> enqueued;
+ vector<bool> visited;
+
+ StateId s = ifst.Start();
+ Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(),
+ opts.weight_threshold);
+ while (copy.size() <= s)
+ copy.push_back(kNoStateId);
+ copy[s] = ofst->AddState();
+ ofst->SetStart(copy[s]);
+ while (idistance.size() <= s)
+ idistance.push_back(Weight::Zero());
+ idistance[s] = Weight::One();
+ while (enqueued.size() <= s) {
+ enqueued.push_back(kNoKey);
+ visited.push_back(false);
+ }
+ enqueued[s] = heap.Insert(s);
+
+ while (!heap.Empty()) {
+ s = heap.Top();
+ heap.Pop();
+ enqueued[s] = kNoKey;
+ visited[s] = true;
+ if (!less(limit, Times(idistance[s], ifst.Final(s))))
+ ofst->SetFinal(copy[s], ifst.Final(s));
+ for (ArcIterator< Fst<Arc> > ait(ifst, s);
+ !ait.Done();
+ ait.Next()) {
+ const Arc &arc = ait.Value();
+ if (!opts.filter(arc)) continue;
+ Weight weight = Times(Times(idistance[s], arc.weight),
+ arc.nextstate < fdistance->size()
+ ? (*fdistance)[arc.nextstate]
+ : Weight::Zero());
+ if (less(limit, weight)) continue;
+ if ((opts.state_threshold != kNoStateId) &&
+ (ofst->NumStates() >= opts.state_threshold))
+ continue;
+ while (idistance.size() <= arc.nextstate)
+ idistance.push_back(Weight::Zero());
+ if (less(Times(idistance[s], arc.weight),
+ idistance[arc.nextstate]))
+ idistance[arc.nextstate] = Times(idistance[s], arc.weight);
+ while (copy.size() <= arc.nextstate)
+ copy.push_back(kNoStateId);
+ if (copy[arc.nextstate] == kNoStateId)
+ copy[arc.nextstate] = ofst->AddState();
+ ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
+ copy[arc.nextstate]));
+ while (enqueued.size() <= arc.nextstate) {
+ enqueued.push_back(kNoKey);
+ visited.push_back(false);
+ }
+ if (visited[arc.nextstate]) continue;
+ if (enqueued[arc.nextstate] == kNoKey)
+ enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
+ else
+ heap.Update(enqueued[arc.nextstate], arc.nextstate);
+ }
+ }
+}
+
+
+// Pruning algorithm: this version writes the pruned input Fst to an
+// output MutableFst and simply takes the pruning threshold as an
+// argument. 'ofst' contains states and arcs that belong to a
+// successful path in 'ifst' whose weight is no more than
+// the weight of the shortest path Times() 'weight_threshold'. When
+// 'state_threshold != kNoStateId', 'ofst' will be restricted further
+// to have at most 'opts.state_threshold' states. Weights need to be
+// commutative and have the path property. The weight 'w' of any cycle
+// needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'.
+template <class Arc>
+void Prune(const Fst<Arc> &ifst,
+ MutableFst<Arc> *ofst,
+ typename Arc::Weight weight_threshold,
+ typename Arc::StateId state_threshold = kNoStateId,
+ float delta = kDelta) {
+ PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
+ AnyArcFilter<Arc>(), 0, delta);
+ Prune(ifst, ofst, opts);
+}
+
+} // namespace fst
+
+#endif // FST_LIB_PRUNE_H_