diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/prune.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/prune.h | 339 |
1 files changed, 0 insertions, 339 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/prune.h b/kaldi_io/src/tools/openfst/include/fst/prune.h deleted file mode 100644 index 5ea5b4d..0000000 --- a/kaldi_io/src/tools/openfst/include/fst/prune.h +++ /dev/null @@ -1,339 +0,0 @@ -// prune.h - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Copyright 2005-2010 Google, Inc. -// Author: [email protected] (Cyril Allauzen) -// -// \file -// Functions implementing pruning. - -#ifndef FST_LIB_PRUNE_H__ -#define FST_LIB_PRUNE_H__ - -#include <vector> -using std::vector; - -#include <fst/arcfilter.h> -#include <fst/heap.h> -#include <fst/shortest-distance.h> - - -namespace fst { - -template <class A, class ArcFilter> -class PruneOptions { - public: - typedef typename A::Weight Weight; - typedef typename A::StateId StateId; - - // Pruning weight threshold. - Weight weight_threshold; - // Pruning state threshold. - StateId state_threshold; - // Arc filter. - ArcFilter filter; - // If non-zero, passes in pre-computed shortest distance to final states. - const vector<Weight> *distance; - // Determines the degree of convergence required when computing shortest - // distances. - float delta; - - explicit PruneOptions(const Weight& w, StateId s, ArcFilter f, - vector<Weight> *d = 0, float e = kDelta) - : weight_threshold(w), - state_threshold(s), - filter(f), - distance(d), - delta(e) {} - private: - PruneOptions(); // disallow -}; - - -template <class S, class W> -class PruneCompare { - public: - typedef S StateId; - typedef W Weight; - - PruneCompare(const vector<Weight> &idistance, - const vector<Weight> &fdistance) - : idistance_(idistance), fdistance_(fdistance) {} - - bool operator()(const StateId x, const StateId y) const { - Weight wx = Times(x < idistance_.size() ? idistance_[x] : Weight::Zero(), - x < fdistance_.size() ? fdistance_[x] : Weight::Zero()); - Weight wy = Times(y < idistance_.size() ? idistance_[y] : Weight::Zero(), - y < fdistance_.size() ? fdistance_[y] : Weight::Zero()); - return less_(wx, wy); - } - - private: - const vector<Weight> &idistance_; - const vector<Weight> &fdistance_; - NaturalLess<Weight> less_; -}; - - - -// Pruning algorithm: this version modifies its input and it takes an -// options class as an argment. Delete states and arcs in 'fst' that -// do not belong to a successful path whose weight is no more than -// the weight of the shortest path Times() 'opts.weight_threshold'. -// When 'opts.state_threshold != kNoStateId', the resulting transducer -// will restricted further to have at most 'opts.state_threshold' -// states. Weights need to be commutative and have the path -// property. The weight 'w' of any cycle needs to be bounded, i.e., -// 'Plus(w, W::One()) = One()'. -template <class Arc, class ArcFilter> -void Prune(MutableFst<Arc> *fst, - const PruneOptions<Arc, ArcFilter> &opts) { - typedef typename Arc::Weight Weight; - typedef typename Arc::StateId StateId; - - if ((Weight::Properties() & (kPath | kCommutative)) - != (kPath | kCommutative)) { - FSTERROR() << "Prune: Weight needs to have the path property and" - << " be commutative: " - << Weight::Type(); - fst->SetProperties(kError, kError); - return; - } - StateId ns = fst->NumStates(); - if (ns == 0) return; - vector<Weight> idistance(ns, Weight::Zero()); - vector<Weight> tmp; - if (!opts.distance) { - tmp.reserve(ns); - ShortestDistance(*fst, &tmp, true, opts.delta); - } - const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp; - - if ((opts.state_threshold == 0) || - (fdistance->size() <= fst->Start()) || - ((*fdistance)[fst->Start()] == Weight::Zero())) { - fst->DeleteStates(); - return; - } - PruneCompare<StateId, Weight> compare(idistance, *fdistance); - Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare); - vector<bool> visited(ns, false); - vector<size_t> enqueued(ns, kNoKey); - vector<StateId> dead; - dead.push_back(fst->AddState()); - NaturalLess<Weight> less; - Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold); - - StateId num_visited = 0; - StateId s = fst->Start(); - if (!less(limit, (*fdistance)[s])) { - idistance[s] = Weight::One(); - enqueued[s] = heap.Insert(s); - ++num_visited; - } - - while (!heap.Empty()) { - s = heap.Top(); - heap.Pop(); - enqueued[s] = kNoKey; - visited[s] = true; - if (less(limit, Times(idistance[s], fst->Final(s)))) - fst->SetFinal(s, Weight::Zero()); - for (MutableArcIterator< MutableFst<Arc> > ait(fst, s); - !ait.Done(); - ait.Next()) { - Arc arc = ait.Value(); - if (!opts.filter(arc)) continue; - Weight weight = Times(Times(idistance[s], arc.weight), - arc.nextstate < fdistance->size() - ? (*fdistance)[arc.nextstate] - : Weight::Zero()); - if (less(limit, weight)) { - arc.nextstate = dead[0]; - ait.SetValue(arc); - continue; - } - if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) - idistance[arc.nextstate] = Times(idistance[s], arc.weight); - if (visited[arc.nextstate]) continue; - if ((opts.state_threshold != kNoStateId) && - (num_visited >= opts.state_threshold)) - continue; - if (enqueued[arc.nextstate] == kNoKey) { - enqueued[arc.nextstate] = heap.Insert(arc.nextstate); - ++num_visited; - } else { - heap.Update(enqueued[arc.nextstate], arc.nextstate); - } - } - } - for (size_t i = 0; i < visited.size(); ++i) - if (!visited[i]) dead.push_back(i); - fst->DeleteStates(dead); -} - - -// Pruning algorithm: this version modifies its input and simply takes -// the pruning threshold as an argument. Delete states and arcs in -// 'fst' that do not belong to a successful path whose weight is no -// more than the weight of the shortest path Times() -// 'weight_threshold'. When 'state_threshold != kNoStateId', the -// resulting transducer will be restricted further to have at most -// 'opts.state_threshold' states. Weights need to be commutative and -// have the path property. The weight 'w' of any cycle needs to be -// bounded, i.e., 'Plus(w, W::One()) = One()'. -template <class Arc> -void Prune(MutableFst<Arc> *fst, - typename Arc::Weight weight_threshold, - typename Arc::StateId state_threshold = kNoStateId, - double delta = kDelta) { - PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold, - AnyArcFilter<Arc>(), 0, delta); - Prune(fst, opts); -} - - -// Pruning algorithm: this version writes the pruned input Fst to an -// output MutableFst and it takes an options class as an argument. -// 'ofst' contains states and arcs that belong to a successful path in -// 'ifst' whose weight is no more than the weight of the shortest path -// Times() 'opts.weight_threshold'. When 'opts.state_threshold != -// kNoStateId', 'ofst' will be restricted further to have at most -// 'opts.state_threshold' states. Weights need to be commutative and -// have the path property. The weight 'w' of any cycle needs to be -// bounded, i.e., 'Plus(w, W::One()) = One()'. -template <class Arc, class ArcFilter> -void Prune(const Fst<Arc> &ifst, - MutableFst<Arc> *ofst, - const PruneOptions<Arc, ArcFilter> &opts) { - typedef typename Arc::Weight Weight; - typedef typename Arc::StateId StateId; - - if ((Weight::Properties() & (kPath | kCommutative)) - != (kPath | kCommutative)) { - FSTERROR() << "Prune: Weight needs to have the path property and" - << " be commutative: " - << Weight::Type(); - ofst->SetProperties(kError, kError); - return; - } - ofst->DeleteStates(); - ofst->SetInputSymbols(ifst.InputSymbols()); - ofst->SetOutputSymbols(ifst.OutputSymbols()); - if (ifst.Start() == kNoStateId) - return; - NaturalLess<Weight> less; - if (less(opts.weight_threshold, Weight::One()) || - (opts.state_threshold == 0)) - return; - vector<Weight> idistance; - vector<Weight> tmp; - if (!opts.distance) - ShortestDistance(ifst, &tmp, true, opts.delta); - const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp; - - if ((fdistance->size() <= ifst.Start()) || - ((*fdistance)[ifst.Start()] == Weight::Zero())) { - return; - } - PruneCompare<StateId, Weight> compare(idistance, *fdistance); - Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare); - vector<StateId> copy; - vector<size_t> enqueued; - vector<bool> visited; - - StateId s = ifst.Start(); - Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(), - opts.weight_threshold); - while (copy.size() <= s) - copy.push_back(kNoStateId); - copy[s] = ofst->AddState(); - ofst->SetStart(copy[s]); - while (idistance.size() <= s) - idistance.push_back(Weight::Zero()); - idistance[s] = Weight::One(); - while (enqueued.size() <= s) { - enqueued.push_back(kNoKey); - visited.push_back(false); - } - enqueued[s] = heap.Insert(s); - - while (!heap.Empty()) { - s = heap.Top(); - heap.Pop(); - enqueued[s] = kNoKey; - visited[s] = true; - if (!less(limit, Times(idistance[s], ifst.Final(s)))) - ofst->SetFinal(copy[s], ifst.Final(s)); - for (ArcIterator< Fst<Arc> > ait(ifst, s); - !ait.Done(); - ait.Next()) { - const Arc &arc = ait.Value(); - if (!opts.filter(arc)) continue; - Weight weight = Times(Times(idistance[s], arc.weight), - arc.nextstate < fdistance->size() - ? (*fdistance)[arc.nextstate] - : Weight::Zero()); - if (less(limit, weight)) continue; - if ((opts.state_threshold != kNoStateId) && - (ofst->NumStates() >= opts.state_threshold)) - continue; - while (idistance.size() <= arc.nextstate) - idistance.push_back(Weight::Zero()); - if (less(Times(idistance[s], arc.weight), - idistance[arc.nextstate])) - idistance[arc.nextstate] = Times(idistance[s], arc.weight); - while (copy.size() <= arc.nextstate) - copy.push_back(kNoStateId); - if (copy[arc.nextstate] == kNoStateId) - copy[arc.nextstate] = ofst->AddState(); - ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight, - copy[arc.nextstate])); - while (enqueued.size() <= arc.nextstate) { - enqueued.push_back(kNoKey); - visited.push_back(false); - } - if (visited[arc.nextstate]) continue; - if (enqueued[arc.nextstate] == kNoKey) - enqueued[arc.nextstate] = heap.Insert(arc.nextstate); - else - heap.Update(enqueued[arc.nextstate], arc.nextstate); - } - } -} - - -// Pruning algorithm: this version writes the pruned input Fst to an -// output MutableFst and simply takes the pruning threshold as an -// argument. 'ofst' contains states and arcs that belong to a -// successful path in 'ifst' whose weight is no more than -// the weight of the shortest path Times() 'weight_threshold'. When -// 'state_threshold != kNoStateId', 'ofst' will be restricted further -// to have at most 'opts.state_threshold' states. Weights need to be -// commutative and have the path property. The weight 'w' of any cycle -// needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'. -template <class Arc> -void Prune(const Fst<Arc> &ifst, - MutableFst<Arc> *ofst, - typename Arc::Weight weight_threshold, - typename Arc::StateId state_threshold = kNoStateId, - float delta = kDelta) { - PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold, - AnyArcFilter<Arc>(), 0, delta); - Prune(ifst, ofst, opts); -} - -} // namespace fst - -#endif // FST_LIB_PRUNE_H_ |