From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- kaldi_io/src/tools/openfst/include/fst/prune.h | 339 +++++++++++++++++++++++++ 1 file changed, 339 insertions(+) create mode 100644 kaldi_io/src/tools/openfst/include/fst/prune.h (limited to 'kaldi_io/src/tools/openfst/include/fst/prune.h') diff --git a/kaldi_io/src/tools/openfst/include/fst/prune.h b/kaldi_io/src/tools/openfst/include/fst/prune.h new file mode 100644 index 0000000..5ea5b4d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/prune.h @@ -0,0 +1,339 @@ +// prune.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: allauzen@google.com (Cyril Allauzen) +// +// \file +// Functions implementing pruning. + +#ifndef FST_LIB_PRUNE_H__ +#define FST_LIB_PRUNE_H__ + +#include +using std::vector; + +#include +#include +#include + + +namespace fst { + +template +class PruneOptions { + public: + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + // Pruning weight threshold. + Weight weight_threshold; + // Pruning state threshold. + StateId state_threshold; + // Arc filter. + ArcFilter filter; + // If non-zero, passes in pre-computed shortest distance to final states. + const vector *distance; + // Determines the degree of convergence required when computing shortest + // distances. + float delta; + + explicit PruneOptions(const Weight& w, StateId s, ArcFilter f, + vector *d = 0, float e = kDelta) + : weight_threshold(w), + state_threshold(s), + filter(f), + distance(d), + delta(e) {} + private: + PruneOptions(); // disallow +}; + + +template +class PruneCompare { + public: + typedef S StateId; + typedef W Weight; + + PruneCompare(const vector &idistance, + const vector &fdistance) + : idistance_(idistance), fdistance_(fdistance) {} + + bool operator()(const StateId x, const StateId y) const { + Weight wx = Times(x < idistance_.size() ? idistance_[x] : Weight::Zero(), + x < fdistance_.size() ? fdistance_[x] : Weight::Zero()); + Weight wy = Times(y < idistance_.size() ? idistance_[y] : Weight::Zero(), + y < fdistance_.size() ? fdistance_[y] : Weight::Zero()); + return less_(wx, wy); + } + + private: + const vector &idistance_; + const vector &fdistance_; + NaturalLess less_; +}; + + + +// Pruning algorithm: this version modifies its input and it takes an +// options class as an argment. Delete states and arcs in 'fst' that +// do not belong to a successful path whose weight is no more than +// the weight of the shortest path Times() 'opts.weight_threshold'. +// When 'opts.state_threshold != kNoStateId', the resulting transducer +// will restricted further to have at most 'opts.state_threshold' +// states. Weights need to be commutative and have the path +// property. The weight 'w' of any cycle needs to be bounded, i.e., +// 'Plus(w, W::One()) = One()'. +template +void Prune(MutableFst *fst, + const PruneOptions &opts) { + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + if ((Weight::Properties() & (kPath | kCommutative)) + != (kPath | kCommutative)) { + FSTERROR() << "Prune: Weight needs to have the path property and" + << " be commutative: " + << Weight::Type(); + fst->SetProperties(kError, kError); + return; + } + StateId ns = fst->NumStates(); + if (ns == 0) return; + vector idistance(ns, Weight::Zero()); + vector tmp; + if (!opts.distance) { + tmp.reserve(ns); + ShortestDistance(*fst, &tmp, true, opts.delta); + } + const vector *fdistance = opts.distance ? opts.distance : &tmp; + + if ((opts.state_threshold == 0) || + (fdistance->size() <= fst->Start()) || + ((*fdistance)[fst->Start()] == Weight::Zero())) { + fst->DeleteStates(); + return; + } + PruneCompare compare(idistance, *fdistance); + Heap< StateId, PruneCompare, false> heap(compare); + vector visited(ns, false); + vector enqueued(ns, kNoKey); + vector dead; + dead.push_back(fst->AddState()); + NaturalLess less; + Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold); + + StateId num_visited = 0; + StateId s = fst->Start(); + if (!less(limit, (*fdistance)[s])) { + idistance[s] = Weight::One(); + enqueued[s] = heap.Insert(s); + ++num_visited; + } + + while (!heap.Empty()) { + s = heap.Top(); + heap.Pop(); + enqueued[s] = kNoKey; + visited[s] = true; + if (less(limit, Times(idistance[s], fst->Final(s)))) + fst->SetFinal(s, Weight::Zero()); + for (MutableArcIterator< MutableFst > ait(fst, s); + !ait.Done(); + ait.Next()) { + Arc arc = ait.Value(); + if (!opts.filter(arc)) continue; + Weight weight = Times(Times(idistance[s], arc.weight), + arc.nextstate < fdistance->size() + ? (*fdistance)[arc.nextstate] + : Weight::Zero()); + if (less(limit, weight)) { + arc.nextstate = dead[0]; + ait.SetValue(arc); + continue; + } + if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) + idistance[arc.nextstate] = Times(idistance[s], arc.weight); + if (visited[arc.nextstate]) continue; + if ((opts.state_threshold != kNoStateId) && + (num_visited >= opts.state_threshold)) + continue; + if (enqueued[arc.nextstate] == kNoKey) { + enqueued[arc.nextstate] = heap.Insert(arc.nextstate); + ++num_visited; + } else { + heap.Update(enqueued[arc.nextstate], arc.nextstate); + } + } + } + for (size_t i = 0; i < visited.size(); ++i) + if (!visited[i]) dead.push_back(i); + fst->DeleteStates(dead); +} + + +// Pruning algorithm: this version modifies its input and simply takes +// the pruning threshold as an argument. Delete states and arcs in +// 'fst' that do not belong to a successful path whose weight is no +// more than the weight of the shortest path Times() +// 'weight_threshold'. When 'state_threshold != kNoStateId', the +// resulting transducer will be restricted further to have at most +// 'opts.state_threshold' states. Weights need to be commutative and +// have the path property. The weight 'w' of any cycle needs to be +// bounded, i.e., 'Plus(w, W::One()) = One()'. +template +void Prune(MutableFst *fst, + typename Arc::Weight weight_threshold, + typename Arc::StateId state_threshold = kNoStateId, + double delta = kDelta) { + PruneOptions > opts(weight_threshold, state_threshold, + AnyArcFilter(), 0, delta); + Prune(fst, opts); +} + + +// Pruning algorithm: this version writes the pruned input Fst to an +// output MutableFst and it takes an options class as an argument. +// 'ofst' contains states and arcs that belong to a successful path in +// 'ifst' whose weight is no more than the weight of the shortest path +// Times() 'opts.weight_threshold'. When 'opts.state_threshold != +// kNoStateId', 'ofst' will be restricted further to have at most +// 'opts.state_threshold' states. Weights need to be commutative and +// have the path property. The weight 'w' of any cycle needs to be +// bounded, i.e., 'Plus(w, W::One()) = One()'. +template +void Prune(const Fst &ifst, + MutableFst *ofst, + const PruneOptions &opts) { + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + if ((Weight::Properties() & (kPath | kCommutative)) + != (kPath | kCommutative)) { + FSTERROR() << "Prune: Weight needs to have the path property and" + << " be commutative: " + << Weight::Type(); + ofst->SetProperties(kError, kError); + return; + } + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst.InputSymbols()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Start() == kNoStateId) + return; + NaturalLess less; + if (less(opts.weight_threshold, Weight::One()) || + (opts.state_threshold == 0)) + return; + vector idistance; + vector tmp; + if (!opts.distance) + ShortestDistance(ifst, &tmp, true, opts.delta); + const vector *fdistance = opts.distance ? opts.distance : &tmp; + + if ((fdistance->size() <= ifst.Start()) || + ((*fdistance)[ifst.Start()] == Weight::Zero())) { + return; + } + PruneCompare compare(idistance, *fdistance); + Heap< StateId, PruneCompare, false> heap(compare); + vector copy; + vector enqueued; + vector visited; + + StateId s = ifst.Start(); + Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(), + opts.weight_threshold); + while (copy.size() <= s) + copy.push_back(kNoStateId); + copy[s] = ofst->AddState(); + ofst->SetStart(copy[s]); + while (idistance.size() <= s) + idistance.push_back(Weight::Zero()); + idistance[s] = Weight::One(); + while (enqueued.size() <= s) { + enqueued.push_back(kNoKey); + visited.push_back(false); + } + enqueued[s] = heap.Insert(s); + + while (!heap.Empty()) { + s = heap.Top(); + heap.Pop(); + enqueued[s] = kNoKey; + visited[s] = true; + if (!less(limit, Times(idistance[s], ifst.Final(s)))) + ofst->SetFinal(copy[s], ifst.Final(s)); + for (ArcIterator< Fst > ait(ifst, s); + !ait.Done(); + ait.Next()) { + const Arc &arc = ait.Value(); + if (!opts.filter(arc)) continue; + Weight weight = Times(Times(idistance[s], arc.weight), + arc.nextstate < fdistance->size() + ? (*fdistance)[arc.nextstate] + : Weight::Zero()); + if (less(limit, weight)) continue; + if ((opts.state_threshold != kNoStateId) && + (ofst->NumStates() >= opts.state_threshold)) + continue; + while (idistance.size() <= arc.nextstate) + idistance.push_back(Weight::Zero()); + if (less(Times(idistance[s], arc.weight), + idistance[arc.nextstate])) + idistance[arc.nextstate] = Times(idistance[s], arc.weight); + while (copy.size() <= arc.nextstate) + copy.push_back(kNoStateId); + if (copy[arc.nextstate] == kNoStateId) + copy[arc.nextstate] = ofst->AddState(); + ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight, + copy[arc.nextstate])); + while (enqueued.size() <= arc.nextstate) { + enqueued.push_back(kNoKey); + visited.push_back(false); + } + if (visited[arc.nextstate]) continue; + if (enqueued[arc.nextstate] == kNoKey) + enqueued[arc.nextstate] = heap.Insert(arc.nextstate); + else + heap.Update(enqueued[arc.nextstate], arc.nextstate); + } + } +} + + +// Pruning algorithm: this version writes the pruned input Fst to an +// output MutableFst and simply takes the pruning threshold as an +// argument. 'ofst' contains states and arcs that belong to a +// successful path in 'ifst' whose weight is no more than +// the weight of the shortest path Times() 'weight_threshold'. When +// 'state_threshold != kNoStateId', 'ofst' will be restricted further +// to have at most 'opts.state_threshold' states. Weights need to be +// commutative and have the path property. The weight 'w' of any cycle +// needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'. +template +void Prune(const Fst &ifst, + MutableFst *ofst, + typename Arc::Weight weight_threshold, + typename Arc::StateId state_threshold = kNoStateId, + float delta = kDelta) { + PruneOptions > opts(weight_threshold, state_threshold, + AnyArcFilter(), 0, delta); + Prune(ifst, ofst, opts); +} + +} // namespace fst + +#endif // FST_LIB_PRUNE_H_ -- cgit v1.2.3