From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- kaldi_io/src/tools/openfst/include/fst/queue.h | 938 +++++++++++++++++++++++++ 1 file changed, 938 insertions(+) create mode 100644 kaldi_io/src/tools/openfst/include/fst/queue.h (limited to 'kaldi_io/src/tools/openfst/include/fst/queue.h') diff --git a/kaldi_io/src/tools/openfst/include/fst/queue.h b/kaldi_io/src/tools/openfst/include/fst/queue.h new file mode 100644 index 0000000..95a082d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/queue.h @@ -0,0 +1,938 @@ +// queue.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: allauzen@google.com (Cyril Allauzen) +// +// \file +// Functions and classes for various Fst state queues with +// a unified interface. + +#ifndef FST_LIB_QUEUE_H__ +#define FST_LIB_QUEUE_H__ + +#include +using std::deque; +#include +using std::vector; + +#include +#include +#include +#include + + +namespace fst { + +// template +// class Queue { +// public: +// typedef typename S StateId; +// +// // Ctr: may need args (e.g., Fst, comparator) for some queues +// Queue(...); +// // Returns the head of the queue +// StateId Head() const; +// // Inserts a state +// void Enqueue(StateId s); +// // Removes the head of the queue +// void Dequeue(); +// // Updates ordering of state s when weight changes, if necessary +// void Update(StateId s); +// // Does the queue contain no elements? +// bool Empty() const; +// // Remove all states from queue +// void Clear(); +// }; + +// State queue types. +enum QueueType { + TRIVIAL_QUEUE = 0, // Single state queue + FIFO_QUEUE = 1, // First-in, first-out queue + LIFO_QUEUE = 2, // Last-in, first-out queue + SHORTEST_FIRST_QUEUE = 3, // Shortest-first queue + TOP_ORDER_QUEUE = 4, // Topologically-ordered queue + STATE_ORDER_QUEUE = 5, // State-ID ordered queue + SCC_QUEUE = 6, // Component graph top-ordered meta-queue + AUTO_QUEUE = 7, // Auto-selected queue + OTHER_QUEUE = 8 + }; + + +// QueueBase, templated on the StateId, is the base class shared by the +// queues considered by AutoQueue. +template +class QueueBase { + public: + typedef S StateId; + + QueueBase(QueueType type) : queue_type_(type), error_(false) {} + virtual ~QueueBase() {} + StateId Head() const { return Head_(); } + void Enqueue(StateId s) { Enqueue_(s); } + void Dequeue() { Dequeue_(); } + void Update(StateId s) { Update_(s); } + bool Empty() const { return Empty_(); } + void Clear() { Clear_(); } + QueueType Type() { return queue_type_; } + bool Error() const { return error_; } + void SetError(bool error) { error_ = error; } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const = 0; + virtual void Enqueue_(StateId s) = 0; + virtual void Dequeue_() = 0; + virtual void Update_(StateId s) = 0; + virtual bool Empty_() const = 0; + virtual void Clear_() = 0; + + QueueType queue_type_; + bool error_; +}; + + +// Trivial queue discipline, templated on the StateId. You may enqueue +// at most one state at a time. It is used for strongly connected components +// with only one state and no self loops. +template +class TrivialQueue : public QueueBase { +public: + typedef S StateId; + + TrivialQueue() : QueueBase(TRIVIAL_QUEUE), front_(kNoStateId) {} + StateId Head() const { return front_; } + void Enqueue(StateId s) { front_ = s; } + void Dequeue() { front_ = kNoStateId; } + void Update(StateId s) {} + bool Empty() const { return front_ == kNoStateId; } + void Clear() { front_ = kNoStateId; } + + +private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + + StateId front_; +}; + + +// First-in, first-out queue discipline, templated on the StateId. +template +class FifoQueue : public QueueBase, public deque { + public: + using deque::back; + using deque::push_front; + using deque::pop_back; + using deque::empty; + using deque::clear; + + typedef S StateId; + + FifoQueue() : QueueBase(FIFO_QUEUE) {} + StateId Head() const { return back(); } + void Enqueue(StateId s) { push_front(s); } + void Dequeue() { pop_back(); } + void Update(StateId s) {} + bool Empty() const { return empty(); } + void Clear() { clear(); } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } +}; + + +// Last-in, first-out queue discipline, templated on the StateId. +template +class LifoQueue : public QueueBase, public deque { + public: + using deque::front; + using deque::push_front; + using deque::pop_front; + using deque::empty; + using deque::clear; + + typedef S StateId; + + LifoQueue() : QueueBase(LIFO_QUEUE) {} + StateId Head() const { return front(); } + void Enqueue(StateId s) { push_front(s); } + void Dequeue() { pop_front(); } + void Update(StateId s) {} + bool Empty() const { return empty(); } + void Clear() { clear(); } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } +}; + + +// Shortest-first queue discipline, templated on the StateId and +// comparison function object. Comparison function object COMP is +// used to compare two StateIds. If a (single) state's order changes, +// it can be reordered in the queue with a call to Update(). +// If 'update == false', call to Update() does not reorder the queue. +template +class ShortestFirstQueue : public QueueBase { + public: + typedef S StateId; + typedef C Compare; + + ShortestFirstQueue(C comp) + : QueueBase(SHORTEST_FIRST_QUEUE), heap_(comp) {} + + StateId Head() const { return heap_.Top(); } + + void Enqueue(StateId s) { + if (update) { + for (StateId i = key_.size(); i <= s; ++i) + key_.push_back(kNoKey); + key_[s] = heap_.Insert(s); + } else { + heap_.Insert(s); + } + } + + void Dequeue() { + if (update) + key_[heap_.Pop()] = kNoKey; + else + heap_.Pop(); + } + + void Update(StateId s) { + if (!update) + return; + if (s >= key_.size() || key_[s] == kNoKey) { + Enqueue(s); + } else { + heap_.Update(key_[s], s); + } + } + + bool Empty() const { return heap_.Empty(); } + + void Clear() { + heap_.Clear(); + if (update) key_.clear(); + } + + private: + Heap heap_; + vector key_; + + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } +}; + + +// Given a vector that maps from states to weights and a Less +// comparison function object between weights, this class defines a +// comparison function object between states. +template +class StateWeightCompare { + public: + typedef L Less; + typedef typename L::Weight Weight; + typedef S StateId; + + StateWeightCompare(const vector& weights, const L &less) + : weights_(weights), less_(less) {} + + bool operator()(const S x, const S y) const { + return less_(weights_[x], weights_[y]); + } + + private: + const vector& weights_; + L less_; +}; + + +// Shortest-first queue discipline, templated on the StateId and Weight, is +// specialized to use the weight's natural order for the comparison function. +template +class NaturalShortestFirstQueue : + public ShortestFirstQueue > > { + public: + typedef StateWeightCompare > C; + + NaturalShortestFirstQueue(const vector &distance) : + ShortestFirstQueue(C(distance, less_)) {} + + private: + NaturalLess less_; +}; + +// Topological-order queue discipline, templated on the StateId. +// States are ordered in the queue topologically. The FST must be acyclic. +template +class TopOrderQueue : public QueueBase { + public: + typedef S StateId; + + // This constructor computes the top. order. It accepts an arc filter + // to limit the transitions considered in that computation (e.g., only + // the epsilon graph). + template + TopOrderQueue(const Fst &fst, ArcFilter filter) + : QueueBase(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), + order_(0), state_(0) { + bool acyclic; + TopOrderVisitor top_order_visitor(&order_, &acyclic); + DfsVisit(fst, &top_order_visitor, filter); + if (!acyclic) { + FSTERROR() << "TopOrderQueue: fst is not acyclic."; + QueueBase::SetError(true); + } + state_.resize(order_.size(), kNoStateId); + } + + // This constructor is passed the top. order, useful when we know it + // beforehand. + TopOrderQueue(const vector &order) + : QueueBase(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), + order_(order), state_(order.size(), kNoStateId) {} + + StateId Head() const { return state_[front_]; } + + void Enqueue(StateId s) { + if (front_ > back_) front_ = back_ = order_[s]; + else if (order_[s] > back_) back_ = order_[s]; + else if (order_[s] < front_) front_ = order_[s]; + state_[order_[s]] = s; + } + + void Dequeue() { + state_[front_] = kNoStateId; + while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_; + } + + void Update(StateId s) {} + + bool Empty() const { return front_ > back_; } + + void Clear() { + for (StateId i = front_; i <= back_; ++i) state_[i] = kNoStateId; + back_ = kNoStateId; + front_ = 0; + } + + private: + StateId front_; + StateId back_; + vector order_; + vector state_; + + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } +}; + + +// State order queue discipline, templated on the StateId. +// States are ordered in the queue by state Id. +template +class StateOrderQueue : public QueueBase { +public: + typedef S StateId; + + StateOrderQueue() + : QueueBase(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {} + + StateId Head() const { return front_; } + + void Enqueue(StateId s) { + if (front_ > back_) front_ = back_ = s; + else if (s > back_) back_ = s; + else if (s < front_) front_ = s; + while (enqueued_.size() <= s) enqueued_.push_back(false); + enqueued_[s] = true; + } + + void Dequeue() { + enqueued_[front_] = false; + while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_; + } + + void Update(StateId s) {} + + bool Empty() const { return front_ > back_; } + + void Clear() { + for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false; + front_ = 0; + back_ = kNoStateId; + } + +private: + StateId front_; + StateId back_; + vector enqueued_; + + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + +}; + + +// SCC topological-order meta-queue discipline, templated on the StateId S +// and a queue Q, which is used inside each SCC. It visits the SCC's +// of an FST in topological order. Its constructor is passed the queues to +// to use within an SCC. +template +class SccQueue : public QueueBase { + public: + typedef S StateId; + typedef Q Queue; + + // Constructor takes a vector specifying the SCC number per state + // and a vector giving the queue to use per SCC number. + SccQueue(const vector &scc, vector *queue) + : QueueBase(SCC_QUEUE), queue_(queue), scc_(scc), front_(0), + back_(kNoStateId) {} + + StateId Head() const { + while ((front_ <= back_) && + (((*queue_)[front_] && (*queue_)[front_]->Empty()) + || (((*queue_)[front_] == 0) && + ((front_ >= trivial_queue_.size()) + || (trivial_queue_[front_] == kNoStateId))))) + ++front_; + if ((*queue_)[front_]) + return (*queue_)[front_]->Head(); + else + return trivial_queue_[front_]; + } + + void Enqueue(StateId s) { + if (front_ > back_) front_ = back_ = scc_[s]; + else if (scc_[s] > back_) back_ = scc_[s]; + else if (scc_[s] < front_) front_ = scc_[s]; + if ((*queue_)[scc_[s]]) { + (*queue_)[scc_[s]]->Enqueue(s); + } else { + while (trivial_queue_.size() <= scc_[s]) + trivial_queue_.push_back(kNoStateId); + trivial_queue_[scc_[s]] = s; + } + } + + void Dequeue() { + if ((*queue_)[front_]) + (*queue_)[front_]->Dequeue(); + else if (front_ < trivial_queue_.size()) + trivial_queue_[front_] = kNoStateId; + } + + void Update(StateId s) { + if ((*queue_)[scc_[s]]) + (*queue_)[scc_[s]]->Update(s); + } + + bool Empty() const { + if (front_ < back_) // Queue scc # back_ not empty unless back_==front_ + return false; + else if (front_ > back_) + return true; + else if ((*queue_)[front_]) + return (*queue_)[front_]->Empty(); + else + return (front_ >= trivial_queue_.size()) + || (trivial_queue_[front_] == kNoStateId); + } + + void Clear() { + for (StateId i = front_; i <= back_; ++i) + if ((*queue_)[i]) + (*queue_)[i]->Clear(); + else if (i < trivial_queue_.size()) + trivial_queue_[i] = kNoStateId; + front_ = 0; + back_ = kNoStateId; + } + +private: + vector *queue_; + const vector &scc_; + mutable StateId front_; + StateId back_; + vector trivial_queue_; + + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + + DISALLOW_COPY_AND_ASSIGN(SccQueue); +}; + + +// Automatic queue discipline, templated on the StateId. It selects a +// queue discipline for a given FST based on its properties. +template +class AutoQueue : public QueueBase { +public: + typedef S StateId; + + // This constructor takes a state distance vector that, if non-null and if + // the Weight type has the path property, will entertain the + // shortest-first queue using the natural order w.r.t to the distance. + template + AutoQueue(const Fst &fst, const vector *distance, + ArcFilter filter) : QueueBase(AUTO_QUEUE) { + typedef typename Arc::Weight Weight; + typedef StateWeightCompare< StateId, NaturalLess > Compare; + + // First check if the FST is known to have these properties. + uint64 props = fst.Properties(kAcyclic | kCyclic | + kTopSorted | kUnweighted, false); + if ((props & kTopSorted) || fst.Start() == kNoStateId) { + queue_ = new StateOrderQueue(); + VLOG(2) << "AutoQueue: using state-order discipline"; + } else if (props & kAcyclic) { + queue_ = new TopOrderQueue(fst, filter); + VLOG(2) << "AutoQueue: using top-order discipline"; + } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) { + queue_ = new LifoQueue(); + VLOG(2) << "AutoQueue: using LIFO discipline"; + } else { + uint64 properties; + // Decompose into strongly-connected components. + SccVisitor scc_visitor(&scc_, 0, 0, &properties); + DfsVisit(fst, &scc_visitor, filter); + StateId nscc = *max_element(scc_.begin(), scc_.end()) + 1; + vector queue_types(nscc); + NaturalLess *less = 0; + Compare *comp = 0; + if (distance && (Weight::Properties() & kPath)) { + less = new NaturalLess; + comp = new Compare(*distance, *less); + } + // Find the queue type to use per SCC. + bool unweighted; + bool all_trivial; + SccQueueType(fst, scc_, &queue_types, filter, less, &all_trivial, + &unweighted); + // If unweighted and semiring is idempotent, use lifo queue. + if (unweighted) { + queue_ = new LifoQueue(); + VLOG(2) << "AutoQueue: using LIFO discipline"; + delete comp; + delete less; + return; + } + // If all the scc are trivial, FST is acyclic and the scc# gives + // the topological order. + if (all_trivial) { + queue_ = new TopOrderQueue(scc_); + VLOG(2) << "AutoQueue: using top-order discipline"; + delete comp; + delete less; + return; + } + VLOG(2) << "AutoQueue: using SCC meta-discipline"; + queues_.resize(nscc); + for (StateId i = 0; i < nscc; ++i) { + switch(queue_types[i]) { + case TRIVIAL_QUEUE: + queues_[i] = 0; + VLOG(3) << "AutoQueue: SCC #" << i + << ": using trivial discipline"; + break; + case SHORTEST_FIRST_QUEUE: + queues_[i] = new ShortestFirstQueue(*comp); + VLOG(3) << "AutoQueue: SCC #" << i << + ": using shortest-first discipline"; + break; + case LIFO_QUEUE: + queues_[i] = new LifoQueue(); + VLOG(3) << "AutoQueue: SCC #" << i + << ": using LIFO disciplle"; + break; + case FIFO_QUEUE: + default: + queues_[i] = new FifoQueue(); + VLOG(3) << "AutoQueue: SCC #" << i + << ": using FIFO disciplle"; + break; + } + } + queue_ = new SccQueue< StateId, QueueBase >(scc_, &queues_); + delete comp; + delete less; + } + } + + ~AutoQueue() { + for (StateId i = 0; i < queues_.size(); ++i) + delete queues_[i]; + delete queue_; + } + + StateId Head() const { return queue_->Head(); } + + void Enqueue(StateId s) { queue_->Enqueue(s); } + + void Dequeue() { queue_->Dequeue(); } + + void Update(StateId s) { queue_->Update(s); } + + bool Empty() const { return queue_->Empty(); } + + void Clear() { queue_->Clear(); } + + + private: + QueueBase *queue_; + vector< QueueBase* > queues_; + vector scc_; + + template + static void SccQueueType(const Fst &fst, + const vector &scc, + vector *queue_types, + ArcFilter filter, Less *less, + bool *all_trivial, bool *unweighted); + + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + + virtual void Enqueue_(StateId s) { Enqueue(s); } + + virtual void Dequeue_() { Dequeue(); } + + virtual void Update_(StateId s) { Update(s); } + + virtual bool Empty_() const { return Empty(); } + + virtual void Clear_() { return Clear(); } + + DISALLOW_COPY_AND_ASSIGN(AutoQueue); +}; + + +// Examines the states in an Fst's strongly connected components and +// determines which type of queue to use per SCC. Stores result in +// vector QUEUE_TYPES, which is assumed to have length equal to the +// number of SCCs. An arc filter is used to limit the transitions +// considered (e.g., only the epsilon graph). ALL_TRIVIAL is set +// to true if every queue is the trivial queue. UNWEIGHTED is set to +// true if the semiring is idempotent and all the arc weights are equal to +// Zero() or One(). +template +template +void AutoQueue::SccQueueType(const Fst &fst, + const vector &scc, + vector *queue_type, + ArcFilter filter, Less *less, + bool *all_trivial, bool *unweighted) { + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + *all_trivial = true; + *unweighted = true; + + for (StateId i = 0; i < queue_type->size(); ++i) + (*queue_type)[i] = TRIVIAL_QUEUE; + + for (StateIterator< Fst > sit(fst); !sit.Done(); sit.Next()) { + StateId state = sit.Value(); + for (ArcIterator< Fst > ait(fst, state); + !ait.Done(); + ait.Next()) { + const Arc &arc = ait.Value(); + if (!filter(arc)) continue; + if (scc[state] == scc[arc.nextstate]) { + QueueType &type = (*queue_type)[scc[state]]; + if (!less || ((*less)(arc.weight, Weight::One()))) + type = FIFO_QUEUE; + else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) { + if (!(Weight::Properties() & kIdempotent) || + (arc.weight != Weight::Zero() && arc.weight != Weight::One())) + type = SHORTEST_FIRST_QUEUE; + else + type = LIFO_QUEUE; + } + if (type != TRIVIAL_QUEUE) *all_trivial = false; + } + if (!(Weight::Properties() & kIdempotent) || + (arc.weight != Weight::Zero() && arc.weight != Weight::One())) + *unweighted = false; + } + } +} + + +// An A* estimate is a function object that maps from a state ID to a +// an estimate of the shortest distance to the final states. +// The trivial A* estimate is always One(). +template +struct TrivialAStarEstimate { + W operator()(S s) const { return W::One(); } +}; + + +// Given a vector that maps from states to weights representing the +// shortest distance from the initial state, a Less comparison +// function object between weights, and an estimate E of the +// shortest distance to the final states, this class defines a +// comparison function object between states. +template +class AStarWeightCompare { + public: + typedef L Less; + typedef typename L::Weight Weight; + typedef S StateId; + + AStarWeightCompare(const vector& weights, const L &less, + const E &estimate) + : weights_(weights), less_(less), estimate_(estimate) {} + + bool operator()(const S x, const S y) const { + Weight wx = Times(weights_[x], estimate_(x)); + Weight wy = Times(weights_[y], estimate_(y)); + return less_(wx, wy); + } + + private: + const vector& weights_; + L less_; + const E &estimate_; +}; + + +// A* queue discipline, templated on the StateId, Weight and an +// estimate E of the shortest distance to the final states, is specialized +// to use the weight's natural order for the comparison function. +template +class NaturalAStarQueue : + public ShortestFirstQueue, E> > { + public: + typedef AStarWeightCompare, E> C; + + NaturalAStarQueue(const vector &distance, const E &estimate) : + ShortestFirstQueue(C(distance, less_, estimate)) {} + + private: + NaturalLess less_; +}; + + +// A state equivalence class is a function object that +// maps from a state ID to an equivalence class (state) ID. +// The trivial equivalence class maps a state to itself. +template +struct TrivialStateEquivClass { + S operator()(S s) const { return s; } +}; + + +// Distance-based pruning queue discipline: Enqueues a state 's' +// only when its shortest distance (so far), as specified by +// 'distance', is less than (as specified by 'comp') the shortest +// distance Times() the 'threshold' to any state in the same +// equivalence class, as specified by the function object +// 'class_func'. The underlying queue discipline is specified by +// 'queue'. The ownership of 'queue' is given to this class. +template +class PruneQueue : public QueueBase { + public: + typedef typename Q::StateId StateId; + typedef typename L::Weight Weight; + + PruneQueue(const vector &distance, Q *queue, L comp, + const C &class_func, Weight threshold) + : QueueBase(OTHER_QUEUE), + distance_(distance), + queue_(queue), + less_(comp), + class_func_(class_func), + threshold_(threshold) {} + + ~PruneQueue() { delete queue_; } + + StateId Head() const { return queue_->Head(); } + + void Enqueue(StateId s) { + StateId c = class_func_(s); + if (c >= class_distance_.size()) + class_distance_.resize(c + 1, Weight::Zero()); + if (less_(distance_[s], class_distance_[c])) + class_distance_[c] = distance_[s]; + + // Enqueue only if below threshold limit + Weight limit = Times(class_distance_[c], threshold_); + if (less_(distance_[s], limit)) + queue_->Enqueue(s); + } + + void Dequeue() { queue_->Dequeue(); } + + void Update(StateId s) { + StateId c = class_func_(s); + if (less_(distance_[s], class_distance_[c])) + class_distance_[c] = distance_[s]; + queue_->Update(s); + } + + bool Empty() const { return queue_->Empty(); } + void Clear() { queue_->Clear(); } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + + const vector &distance_; // shortest distance to state + Q *queue_; + L less_; + const C &class_func_; // eqv. class function object + Weight threshold_; // pruning weight threshold + vector class_distance_; // shortest distance to class + + DISALLOW_COPY_AND_ASSIGN(PruneQueue); +}; + + +// Pruning queue discipline (see above) using the weight's natural +// order for the comparison function. The ownership of 'queue' is +// given to this class. +template +class NaturalPruneQueue : + public PruneQueue, C> { + public: + typedef typename Q::StateId StateId; + typedef W Weight; + + NaturalPruneQueue(const vector &distance, Q *queue, + const C &class_func_, Weight threshold) : + PruneQueue, C>(distance, queue, less_, + class_func_, threshold) {} + + private: + NaturalLess less_; +}; + + +// Filter-based pruning queue discipline: Enqueues a state 's' only +// if allowed by the filter, specified by the function object 'state_filter'. +// The underlying queue discipline is specified by 'queue'. The ownership +// of 'queue' is given to this class. +template +class FilterQueue : public QueueBase { + public: + typedef typename Q::StateId StateId; + + FilterQueue(Q *queue, const F &state_filter) + : QueueBase(OTHER_QUEUE), + queue_(queue), + state_filter_(state_filter) {} + + ~FilterQueue() { delete queue_; } + + StateId Head() const { return queue_->Head(); } + + // Enqueues only if allowed by state filter. + void Enqueue(StateId s) { + if (state_filter_(s)) { + queue_->Enqueue(s); + } + } + + void Dequeue() { queue_->Dequeue(); } + + void Update(StateId s) {} + bool Empty() const { return queue_->Empty(); } + void Clear() { queue_->Clear(); } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + + Q *queue_; + const F &state_filter_; // Filter to prune states + + DISALLOW_COPY_AND_ASSIGN(FilterQueue); +}; + +} // namespace fst + +#endif -- cgit v1.2.3