diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/shortest-distance.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/shortest-distance.h | 348 |
1 files changed, 348 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/shortest-distance.h b/kaldi_io/src/tools/openfst/include/fst/shortest-distance.h new file mode 100644 index 0000000..ec47a14 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/shortest-distance.h @@ -0,0 +1,348 @@ +// shortest-distance.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Functions and classes to find shortest distance in an FST. + +#ifndef FST_LIB_SHORTEST_DISTANCE_H__ +#define FST_LIB_SHORTEST_DISTANCE_H__ + +#include <deque> +using std::deque; +#include <vector> +using std::vector; + +#include <fst/arcfilter.h> +#include <fst/cache.h> +#include <fst/queue.h> +#include <fst/reverse.h> +#include <fst/test-properties.h> + + +namespace fst { + +template <class Arc, class Queue, class ArcFilter> +struct ShortestDistanceOptions { + typedef typename Arc::StateId StateId; + + Queue *state_queue; // Queue discipline used; owned by caller + ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph) + StateId source; // If kNoStateId, use the Fst's initial state + float delta; // Determines the degree of convergence required + bool first_path; // For a semiring with the path property (o.w. + // undefined), compute the shortest-distances along + // along the first path to a final state found + // by the algorithm. That path is the shortest-path + // only if the FST has a unique final state (or all + // the final states have the same final weight), the + // queue discipline is shortest-first and all the + // weights in the FST are between One() and Zero() + // according to NaturalLess. + + ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId, + float d = kDelta) + : state_queue(q), arc_filter(filt), source(src), delta(d), + first_path(false) {} +}; + + +// Computation state of the shortest-distance algorithm. Reusable +// information is maintained across calls to member function +// ShortestDistance(source) when 'retain' is true for improved +// efficiency when calling multiple times from different source states +// (e.g., in epsilon removal). Contrary to usual conventions, 'fst' +// may not be freed before this class. Vector 'distance' should not be +// modified by the user between these calls. +// The Error() method returns true if an error was encountered. +template<class Arc, class Queue, class ArcFilter> +class ShortestDistanceState { + public: + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + ShortestDistanceState( + const Fst<Arc> &fst, + vector<Weight> *distance, + const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, + bool retain) + : fst_(fst), distance_(distance), state_queue_(opts.state_queue), + arc_filter_(opts.arc_filter), delta_(opts.delta), + first_path_(opts.first_path), retain_(retain), source_id_(0), + error_(false) { + distance_->clear(); + } + + ~ShortestDistanceState() {} + + void ShortestDistance(StateId source); + + bool Error() const { return error_; } + + private: + const Fst<Arc> &fst_; + vector<Weight> *distance_; + Queue *state_queue_; + ArcFilter arc_filter_; + float delta_; + bool first_path_; + bool retain_; // Retain and reuse information across calls + + vector<Weight> rdistance_; // Relaxation distance. + vector<bool> enqueued_; // Is state enqueued? + vector<StateId> sources_; // Source ID for ith state in 'distance_', + // 'rdistance_', and 'enqueued_' if retained. + StateId source_id_; // Unique ID characterizing each call to SD + + bool error_; +}; + +// Compute the shortest distance. If 'source' is kNoStateId, use +// the initial state of the Fst. +template <class Arc, class Queue, class ArcFilter> +void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance( + StateId source) { + if (fst_.Start() == kNoStateId) { + if (fst_.Properties(kError, false)) error_ = true; + return; + } + + if (!(Weight::Properties() & kRightSemiring)) { + FSTERROR() << "ShortestDistance: Weight needs to be right distributive: " + << Weight::Type(); + error_ = true; + return; + } + + if (first_path_ && !(Weight::Properties() & kPath)) { + FSTERROR() << "ShortestDistance: first_path option disallowed when " + << "Weight does not have the path property: " + << Weight::Type(); + error_ = true; + return; + } + + state_queue_->Clear(); + + if (!retain_) { + distance_->clear(); + rdistance_.clear(); + enqueued_.clear(); + } + + if (source == kNoStateId) + source = fst_.Start(); + + while (distance_->size() <= source) { + distance_->push_back(Weight::Zero()); + rdistance_.push_back(Weight::Zero()); + enqueued_.push_back(false); + } + if (retain_) { + while (sources_.size() <= source) + sources_.push_back(kNoStateId); + sources_[source] = source_id_; + } + (*distance_)[source] = Weight::One(); + rdistance_[source] = Weight::One(); + enqueued_[source] = true; + + state_queue_->Enqueue(source); + + while (!state_queue_->Empty()) { + StateId s = state_queue_->Head(); + state_queue_->Dequeue(); + while (distance_->size() <= s) { + distance_->push_back(Weight::Zero()); + rdistance_.push_back(Weight::Zero()); + enqueued_.push_back(false); + } + if (first_path_ && (fst_.Final(s) != Weight::Zero())) + break; + enqueued_[s] = false; + Weight r = rdistance_[s]; + rdistance_[s] = Weight::Zero(); + for (ArcIterator< Fst<Arc> > aiter(fst_, s); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (!arc_filter_(arc)) + continue; + while (distance_->size() <= arc.nextstate) { + distance_->push_back(Weight::Zero()); + rdistance_.push_back(Weight::Zero()); + enqueued_.push_back(false); + } + if (retain_) { + while (sources_.size() <= arc.nextstate) + sources_.push_back(kNoStateId); + if (sources_[arc.nextstate] != source_id_) { + (*distance_)[arc.nextstate] = Weight::Zero(); + rdistance_[arc.nextstate] = Weight::Zero(); + enqueued_[arc.nextstate] = false; + sources_[arc.nextstate] = source_id_; + } + } + Weight &nd = (*distance_)[arc.nextstate]; + Weight &nr = rdistance_[arc.nextstate]; + Weight w = Times(r, arc.weight); + if (!ApproxEqual(nd, Plus(nd, w), delta_)) { + nd = Plus(nd, w); + nr = Plus(nr, w); + if (!nd.Member() || !nr.Member()) { + error_ = true; + return; + } + if (!enqueued_[arc.nextstate]) { + state_queue_->Enqueue(arc.nextstate); + enqueued_[arc.nextstate] = true; + } else { + state_queue_->Update(arc.nextstate); + } + } + } + } + ++source_id_; + if (fst_.Properties(kError, false)) error_ = true; +} + + +// Shortest-distance algorithm: this version allows fine control +// via the options argument. See below for a simpler interface. +// +// This computes the shortest distance from the 'opts.source' state to +// each visited state S and stores the value in the 'distance' vector. +// An unvisited state S has distance Zero(), which will be stored in +// the 'distance' vector if S is less than the maximum visited state. +// The state queue discipline, arc filter, and convergence delta are +// taken in the options argument. +// The 'distance' vector will contain a unique element for which +// Member() is false if an error was encountered. +// +// The weights must must be right distributive and k-closed (i.e., 1 + +// x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k). +// +// The algorithm is from Mohri, "Semiring Framweork and Algorithms for +// Shortest-Distance Problems", Journal of Automata, Languages and +// Combinatorics 7(3):321-350, 2002. The complexity of algorithm +// depends on the properties of the semiring and the queue discipline +// used. Refer to the paper for more details. +template<class Arc, class Queue, class ArcFilter> +void ShortestDistance( + const Fst<Arc> &fst, + vector<typename Arc::Weight> *distance, + const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) { + + ShortestDistanceState<Arc, Queue, ArcFilter> + sd_state(fst, distance, opts, false); + sd_state.ShortestDistance(opts.source); + if (sd_state.Error()) { + distance->clear(); + distance->resize(1, Arc::Weight::NoWeight()); + } +} + +// Shortest-distance algorithm: simplified interface. See above for a +// version that allows finer control. +// +// If 'reverse' is false, this computes the shortest distance from the +// initial state to each state S and stores the value in the +// 'distance' vector. If 'reverse' is true, this computes the shortest +// distance from each state to the final states. An unvisited state S +// has distance Zero(), which will be stored in the 'distance' vector +// if S is less than the maximum visited state. The state queue +// discipline is automatically-selected. +// The 'distance' vector will contain a unique element for which +// Member() is false if an error was encountered. +// +// The weights must must be right (left) distributive if reverse is +// false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + +// x + x^2 + ... + x^k). +// +// The algorithm is from Mohri, "Semiring Framweork and Algorithms for +// Shortest-Distance Problems", Journal of Automata, Languages and +// Combinatorics 7(3):321-350, 2002. The complexity of algorithm +// depends on the properties of the semiring and the queue discipline +// used. Refer to the paper for more details. +template<class Arc> +void ShortestDistance(const Fst<Arc> &fst, + vector<typename Arc::Weight> *distance, + bool reverse = false, + float delta = kDelta) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + if (!reverse) { + AnyArcFilter<Arc> arc_filter; + AutoQueue<StateId> state_queue(fst, distance, arc_filter); + ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> > + opts(&state_queue, arc_filter); + opts.delta = delta; + ShortestDistance(fst, distance, opts); + } else { + typedef ReverseArc<Arc> ReverseArc; + typedef typename ReverseArc::Weight ReverseWeight; + AnyArcFilter<ReverseArc> rarc_filter; + VectorFst<ReverseArc> rfst; + Reverse(fst, &rfst); + vector<ReverseWeight> rdistance; + AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter); + ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>, + AnyArcFilter<ReverseArc> > + ropts(&state_queue, rarc_filter); + ropts.delta = delta; + ShortestDistance(rfst, &rdistance, ropts); + distance->clear(); + if (rdistance.size() == 1 && !rdistance[0].Member()) { + distance->resize(1, Arc::Weight::NoWeight()); + return; + } + while (distance->size() < rdistance.size() - 1) + distance->push_back(rdistance[distance->size() + 1].Reverse()); + } +} + + +// Return the sum of the weight of all successful paths in an FST, i.e., +// the shortest-distance from the initial state to the final states. +// Returns a weight such that Member() is false if an error was encountered. +template <class Arc> +typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kDelta) { + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + vector<Weight> distance; + if (Weight::Properties() & kRightSemiring) { + ShortestDistance(fst, &distance, false, delta); + if (distance.size() == 1 && !distance[0].Member()) + return Arc::Weight::NoWeight(); + Weight sum = Weight::Zero(); + for (StateId s = 0; s < distance.size(); ++s) + sum = Plus(sum, Times(distance[s], fst.Final(s))); + return sum; + } else { + ShortestDistance(fst, &distance, true, delta); + StateId s = fst.Start(); + if (distance.size() == 1 && !distance[0].Member()) + return Arc::Weight::NoWeight(); + return s != kNoStateId && s < distance.size() ? + distance[s] : Weight::Zero(); + } +} + + +} // namespace fst + +#endif // FST_LIB_SHORTEST_DISTANCE_H__ |