summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/prune.h
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/prune.h')
-rw-r--r--kaldi_io/src/tools/openfst/include/fst/prune.h339
1 files changed, 0 insertions, 339 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/prune.h b/kaldi_io/src/tools/openfst/include/fst/prune.h
deleted file mode 100644
index 5ea5b4d..0000000
--- a/kaldi_io/src/tools/openfst/include/fst/prune.h
+++ /dev/null
@@ -1,339 +0,0 @@
-// prune.h
-
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// Copyright 2005-2010 Google, Inc.
-// Author: [email protected] (Cyril Allauzen)
-//
-// \file
-// Functions implementing pruning.
-
-#ifndef FST_LIB_PRUNE_H__
-#define FST_LIB_PRUNE_H__
-
-#include <vector>
-using std::vector;
-
-#include <fst/arcfilter.h>
-#include <fst/heap.h>
-#include <fst/shortest-distance.h>
-
-
-namespace fst {
-
-// Options bundle for the Prune() algorithms.
-//
-// Template parameters:
-//   A         - arc type; supplies the Weight and StateId types.
-//   ArcFilter - functor invoked as filter(arc); Prune() skips every arc for
-//               which it returns false.
-template <class A, class ArcFilter>
-class PruneOptions {
- public:
-  typedef typename A::Weight Weight;
-  typedef typename A::StateId StateId;
-
-  // Pruning weight threshold.
-  Weight weight_threshold;
-  // Pruning state threshold.
-  StateId state_threshold;
-  // Arc filter.
-  ArcFilter filter;
-  // If non-zero, passes in pre-computed shortest distance to final states.
-  // The pointer is stored as-is and only ever read through; this class never
-  // deletes it.
-  const vector<Weight> *distance;
-  // Determines the degree of convergence required when computing shortest
-  // distances.
-  float delta;
-
-  // Builds the options from a weight threshold 'w', a state threshold 's',
-  // an arc filter 'f', an optional pre-computed distance vector 'd' and a
-  // convergence delta 'e'.
-  //
-  // 'd' is accepted as pointer-to-const, matching the 'distance' member, so
-  // callers that only hold a const vector can pass it without a cast.
-  // Pointers to non-const vectors still convert implicitly.
-  explicit PruneOptions(const Weight& w, StateId s, ArcFilter f,
-                        const vector<Weight> *d = 0, float e = kDelta)
-      : weight_threshold(w),
-        state_threshold(s),
-        filter(f),
-        distance(d),
-        delta(e) {}
-
- private:
-  PruneOptions();  // disallow
-};
-
-
-// Ordering functor used by the pruning heap. Each state is ranked by the
-// Times() product of its entry in 'idistance' and its entry in 'fdistance';
-// a state id that is out of range for a vector contributes Weight::Zero().
-// Two states are then ordered with NaturalLess on those products.
-//
-// Both vectors are held by reference, so they must outlive the comparator.
-template <class S, class W>
-class PruneCompare {
- public:
-  typedef S StateId;
-  typedef W Weight;
-
-  PruneCompare(const vector<Weight> &idistance,
-               const vector<Weight> &fdistance)
-      : idistance_(idistance), fdistance_(fdistance) {}
-
-  // Returns true iff the product for 'x' is NaturalLess than the one for 'y'.
-  bool operator()(const StateId x, const StateId y) const {
-    const Weight product_x = Product(x);
-    const Weight product_y = Product(y);
-    return less_(product_x, product_y);
-  }
-
- private:
-  // Returns distance[state], or Weight::Zero() when 'distance' has no entry
-  // for 'state'.
-  static Weight Lookup(const vector<Weight> &distance, const StateId state) {
-    if (state < distance.size()) return distance[state];
-    return Weight::Zero();
-  }
-
-  // Times() of the two distance entries recorded for 'state'.
-  Weight Product(const StateId state) const {
-    return Times(Lookup(idistance_, state), Lookup(fdistance_, state));
-  }
-
-  const vector<Weight> &idistance_;
-  const vector<Weight> &fdistance_;
-  NaturalLess<Weight> less_;
-};
-
-
-
-// Pruning algorithm: this version modifies its input and it takes an
-// options class as an argument. Delete states and arcs in 'fst' that
-// do not belong to a successful path whose weight is no more than
-// the weight of the shortest path Times() 'opts.weight_threshold'.
-// When 'opts.state_threshold != kNoStateId', the resulting transducer
-// will be restricted further to have at most 'opts.state_threshold'
-// states. Weights need to be commutative and have the path
-// property. The weight 'w' of any cycle needs to be bounded, i.e.,
-// 'Plus(w, W::One()) = W::One()'.
-//
-// If the weight type lacks the path or commutative property, an error is
-// logged, kError is set on 'fst' and nothing else is changed.
-template <class Arc, class ArcFilter>
-void Prune(MutableFst<Arc> *fst,
- const PruneOptions<Arc, ArcFilter> &opts) {
- typedef typename Arc::Weight Weight;
- typedef typename Arc::StateId StateId;
-
- // Reject weight types that are not both path and commutative.
- if ((Weight::Properties() & (kPath | kCommutative))
- != (kPath | kCommutative)) {
- FSTERROR() << "Prune: Weight needs to have the path property and"
- << " be commutative: "
- << Weight::Type();
- fst->SetProperties(kError, kError);
- return;
- }
- StateId ns = fst->NumStates();
- if (ns == 0) return;
- // idistance[s]: best weight found so far from the start state to 's';
- // every entry starts at Weight::Zero().
- vector<Weight> idistance(ns, Weight::Zero());
- // 'tmp' is only filled when the caller did not supply 'opts.distance'.
- vector<Weight> tmp;
- if (!opts.distance) {
- tmp.reserve(ns);
- // NOTE(review): the 'true' argument presumably requests reverse
- // distances (to the final states) - confirm against ShortestDistance.
- ShortestDistance(*fst, &tmp, true, opts.delta);
- }
- const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
-
- // Nothing survives when the state budget is zero, when the start state
- // has no entry in 'fdistance', or when that entry is Weight::Zero().
- // NOTE(review): with no start state the size comparison presumably
- // relies on kNoStateId converting to a huge unsigned value - confirm.
- if ((opts.state_threshold == 0) ||
- (fdistance->size() <= fst->Start()) ||
- ((*fdistance)[fst->Start()] == Weight::Zero())) {
- fst->DeleteStates();
- return;
- }
- // The heap orders states through PruneCompare, which reads 'idistance'
- // and '*fdistance' by reference.
- PruneCompare<StateId, Weight> compare(idistance, *fdistance);
- Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
- // visited[s]: 's' has been popped from the heap.
- vector<bool> visited(ns, false);
- // enqueued[s]: heap key of 's' while it sits in the heap, else kNoKey.
- vector<size_t> enqueued(ns, kNoKey);
- // dead[0] is a freshly added state that pruned arcs are redirected to;
- // it is deleted, with every unvisited state, at the end.
- vector<StateId> dead;
- dead.push_back(fst->AddState());
- NaturalLess<Weight> less;
- // Any path weight that 'limit' is NaturalLess than is pruned.
- Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold);
-
- // Number of states inserted into the heap so far; compared against
- // 'opts.state_threshold' below.
- StateId num_visited = 0;
- StateId s = fst->Start();
- // Seed the search with the start state unless it already exceeds 'limit'.
- if (!less(limit, (*fdistance)[s])) {
- idistance[s] = Weight::One();
- enqueued[s] = heap.Insert(s);
- ++num_visited;
- }
-
- while (!heap.Empty()) {
- s = heap.Top();
- heap.Pop();
- enqueued[s] = kNoKey;
- visited[s] = true;
- // Drop the final weight when ending at 's' would exceed 'limit'.
- if (less(limit, Times(idistance[s], fst->Final(s))))
- fst->SetFinal(s, Weight::Zero());
- for (MutableArcIterator< MutableFst<Arc> > ait(fst, s);
- !ait.Done();
- ait.Next()) {
- Arc arc = ait.Value();
- // Arcs rejected by the filter are left untouched.
- if (!opts.filter(arc)) continue;
- // Weight through this arc: start -> s, the arc itself, then the
- // 'fdistance' entry of the destination (Zero if it has none).
- Weight weight = Times(Times(idistance[s], arc.weight),
- arc.nextstate < fdistance->size()
- ? (*fdistance)[arc.nextstate]
- : Weight::Zero());
- // Over the limit: point the arc at the dead state so that it is
- // removed when the dead states are deleted.
- if (less(limit, weight)) {
- arc.nextstate = dead[0];
- ait.SetValue(arc);
- continue;
- }
- // Relax the distance to the destination.
- if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate]))
- idistance[arc.nextstate] = Times(idistance[s], arc.weight);
- if (visited[arc.nextstate]) continue;
- // Once the state budget is used up, nothing more is inserted or
- // updated in the heap.
- if ((opts.state_threshold != kNoStateId) &&
- (num_visited >= opts.state_threshold))
- continue;
- // Insert the destination, or refresh its heap position if it is
- // already queued.
- if (enqueued[arc.nextstate] == kNoKey) {
- enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
- ++num_visited;
- } else {
- heap.Update(enqueued[arc.nextstate], arc.nextstate);
- }
- }
- }
- // Every state that was never popped from the heap is deleted together
- // with the dead state.
- for (size_t i = 0; i < visited.size(); ++i)
- if (!visited[i]) dead.push_back(i);
- fst->DeleteStates(dead);
-}
-
-
-// Pruning algorithm: this version modifies its input and simply takes
-// the pruning threshold as an argument. Delete states and arcs in
-// 'fst' that do not belong to a successful path whose weight is no
-// more than the weight of the shortest path Times()
-// 'weight_threshold'. When 'state_threshold != kNoStateId', the
-// resulting transducer will be restricted further to have at most
-// 'state_threshold' states. Weights need to be commutative and
-// have the path property. The weight 'w' of any cycle needs to be
-// bounded, i.e., 'Plus(w, W::One()) = W::One()'.
-//
-// Every arc is considered (AnyArcFilter) and no pre-computed distance
-// vector is passed, so the options-based Prune() computes the shortest
-// distances itself using 'delta'.
-template <class Arc>
-void Prune(MutableFst<Arc> *fst,
-           typename Arc::Weight weight_threshold,
-           typename Arc::StateId state_threshold = kNoStateId,
-           double delta = kDelta) {
-  // PruneOptions stores delta as a float. Narrow explicitly here instead of
-  // relying on a silent double-to-float conversion at the constructor call.
-  PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
-                                             AnyArcFilter<Arc>(), 0,
-                                             static_cast<float>(delta));
-  Prune(fst, opts);
-}
-
-
-// Pruning algorithm: this version writes the pruned input Fst to an
-// output MutableFst and it takes an options class as an argument.
-// 'ofst' contains states and arcs that belong to a successful path in
-// 'ifst' whose weight is no more than the weight of the shortest path
-// Times() 'opts.weight_threshold'. When 'opts.state_threshold !=
-// kNoStateId', 'ofst' will be restricted further to have at most
-// 'opts.state_threshold' states. Weights need to be commutative and
-// have the path property. The weight 'w' of any cycle needs to be
-// bounded, i.e., 'Plus(w, W::One()) = W::One()'.
-//
-// If the weight type lacks the path or commutative property, an error is
-// logged, kError is set on 'ofst' and nothing else is changed.
-template <class Arc, class ArcFilter>
-void Prune(const Fst<Arc> &ifst,
- MutableFst<Arc> *ofst,
- const PruneOptions<Arc, ArcFilter> &opts) {
- typedef typename Arc::Weight Weight;
- typedef typename Arc::StateId StateId;
-
- // Reject weight types that are not both path and commutative.
- if ((Weight::Properties() & (kPath | kCommutative))
- != (kPath | kCommutative)) {
- FSTERROR() << "Prune: Weight needs to have the path property and"
- << " be commutative: "
- << Weight::Type();
- ofst->SetProperties(kError, kError);
- return;
- }
- // Start from an empty output that shares the input's symbol tables.
- ofst->DeleteStates();
- ofst->SetInputSymbols(ifst.InputSymbols());
- ofst->SetOutputSymbols(ifst.OutputSymbols());
- // An input without a start state yields an empty output.
- if (ifst.Start() == kNoStateId)
- return;
- NaturalLess<Weight> less;
- // The output also stays empty when the weight threshold is NaturalLess
- // than Weight::One() or when the state budget is zero.
- if (less(opts.weight_threshold, Weight::One()) ||
- (opts.state_threshold == 0))
- return;
- // idistance[s]: best weight found so far from the start state to 's'.
- // It is grown on demand, as are 'copy', 'enqueued' and 'visited' below.
- vector<Weight> idistance;
- // 'tmp' is only filled when the caller did not supply 'opts.distance'.
- vector<Weight> tmp;
- // NOTE(review): the 'true' argument presumably requests reverse
- // distances (to the final states) - confirm against ShortestDistance.
- if (!opts.distance)
- ShortestDistance(ifst, &tmp, true, opts.delta);
- const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
-
- // Leave the output empty when the start state has no entry in
- // 'fdistance' or when that entry is Weight::Zero().
- if ((fdistance->size() <= ifst.Start()) ||
- ((*fdistance)[ifst.Start()] == Weight::Zero())) {
- return;
- }
- // The heap orders input states through PruneCompare, which reads
- // 'idistance' and '*fdistance' by reference.
- PruneCompare<StateId, Weight> compare(idistance, *fdistance);
- Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
- // copy[s]: id in 'ofst' of input state 's', or kNoStateId if 's' has
- // not been copied yet.
- vector<StateId> copy;
- // enqueued[s]: heap key of 's' while it sits in the heap, else kNoKey.
- vector<size_t> enqueued;
- // visited[s]: 's' has been popped from the heap.
- vector<bool> visited;
-
- StateId s = ifst.Start();
- // Any path weight that 'limit' is NaturalLess than is pruned.
- Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(),
- opts.weight_threshold);
- // Copy the start state into 'ofst' and seed the heap with it.
- while (copy.size() <= s)
- copy.push_back(kNoStateId);
- copy[s] = ofst->AddState();
- ofst->SetStart(copy[s]);
- while (idistance.size() <= s)
- idistance.push_back(Weight::Zero());
- idistance[s] = Weight::One();
- while (enqueued.size() <= s) {
- enqueued.push_back(kNoKey);
- visited.push_back(false);
- }
- enqueued[s] = heap.Insert(s);
-
- while (!heap.Empty()) {
- s = heap.Top();
- heap.Pop();
- enqueued[s] = kNoKey;
- visited[s] = true;
- // Keep the final weight only if ending at 's' stays within 'limit'.
- if (!less(limit, Times(idistance[s], ifst.Final(s))))
- ofst->SetFinal(copy[s], ifst.Final(s));
- for (ArcIterator< Fst<Arc> > ait(ifst, s);
- !ait.Done();
- ait.Next()) {
- const Arc &arc = ait.Value();
- // Arcs rejected by the filter are not copied.
- if (!opts.filter(arc)) continue;
- // Weight through this arc: start -> s, the arc itself, then the
- // 'fdistance' entry of the destination (Zero if it has none).
- Weight weight = Times(Times(idistance[s], arc.weight),
- arc.nextstate < fdistance->size()
- ? (*fdistance)[arc.nextstate]
- : Weight::Zero());
- if (less(limit, weight)) continue;
- // Once 'ofst' holds the allowed number of states no further arc is
- // copied. This test runs before the destination lookup, so it also
- // skips arcs whose destination already exists in 'ofst'.
- if ((opts.state_threshold != kNoStateId) &&
- (ofst->NumStates() >= opts.state_threshold))
- continue;
- // Relax the distance to the destination.
- while (idistance.size() <= arc.nextstate)
- idistance.push_back(Weight::Zero());
- if (less(Times(idistance[s], arc.weight),
- idistance[arc.nextstate]))
- idistance[arc.nextstate] = Times(idistance[s], arc.weight);
- // Create the destination in 'ofst' on first use, then copy the arc
- // with its endpoints translated to output state ids.
- while (copy.size() <= arc.nextstate)
- copy.push_back(kNoStateId);
- if (copy[arc.nextstate] == kNoStateId)
- copy[arc.nextstate] = ofst->AddState();
- ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
- copy[arc.nextstate]));
- while (enqueued.size() <= arc.nextstate) {
- enqueued.push_back(kNoKey);
- visited.push_back(false);
- }
- if (visited[arc.nextstate]) continue;
- // Insert the destination, or refresh its heap position if it is
- // already queued.
- if (enqueued[arc.nextstate] == kNoKey)
- enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
- else
- heap.Update(enqueued[arc.nextstate], arc.nextstate);
- }
- }
-}
-
-
-// Pruning algorithm, copying variant with plain arguments: the pruned
-// version of 'ifst' is written to 'ofst'. 'ofst' contains the states and
-// arcs that belong to a successful path in 'ifst' whose weight is no more
-// than the weight of the shortest path Times() 'weight_threshold'. When
-// 'state_threshold != kNoStateId', 'ofst' is restricted further to have
-// at most 'state_threshold' states. Weights need to be commutative and
-// have the path property. The weight 'w' of any cycle needs to be
-// bounded, i.e., 'Plus(w, W::One()) = W::One()'.
-//
-// This is a convenience wrapper: it packs its arguments into a
-// PruneOptions that accepts every arc (AnyArcFilter) and carries no
-// pre-computed distance vector, then calls the options-based overload.
-template <class Arc>
-void Prune(const Fst<Arc> &ifst,
-           MutableFst<Arc> *ofst,
-           typename Arc::Weight weight_threshold,
-           typename Arc::StateId state_threshold = kNoStateId,
-           float delta = kDelta) {
-  typedef AnyArcFilter<Arc> Filter;
-  typedef PruneOptions<Arc, Filter> Options;
-
-  const Options prune_opts(weight_threshold, state_threshold,
-                           Filter(), 0, delta);
-  Prune(ifst, ofst, prune_opts);
-}
-
-} // namespace fst
-
-#endif // FST_LIB_PRUNE_H__