diff options
author | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
commit | 96a32415ab43377cf1575bd3f4f2980f58028209 (patch) | |
tree | 30a2d92d73e8f40ac87b79f6f56e227bfc4eea6e /kaldi_io/src/tools/openfst/include/fst/rmepsilon.h | |
parent | c177a7549bd90670af4b29fa813ddea32cfe0f78 (diff) |
add implementation for kaldi io (by ymz)
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/rmepsilon.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/rmepsilon.h | 600 |
1 files changed, 600 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/rmepsilon.h b/kaldi_io/src/tools/openfst/include/fst/rmepsilon.h new file mode 100644 index 0000000..89b8178 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/rmepsilon.h @@ -0,0 +1,600 @@ +// rmepsilon.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Functions and classes that implemement epsilon-removal. + +#ifndef FST_LIB_RMEPSILON_H__ +#define FST_LIB_RMEPSILON_H__ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <fst/slist.h> +#include <stack> +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/arcfilter.h> +#include <fst/cache.h> +#include <fst/connect.h> +#include <fst/factor-weight.h> +#include <fst/invert.h> +#include <fst/prune.h> +#include <fst/queue.h> +#include <fst/shortest-distance.h> +#include <fst/topsort.h> + + +namespace fst { + +template <class Arc, class Queue> +class RmEpsilonOptions + : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > { + public: + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + bool connect; // Connect output + Weight weight_threshold; // Pruning weight threshold. + StateId state_threshold; // Pruning state threshold. + + explicit RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true, + Weight w = Weight::Zero(), + StateId n = kNoStateId) + : ShortestDistanceOptions< Arc, Queue, EpsilonArcFilter<Arc> >( + q, EpsilonArcFilter<Arc>(), kNoStateId, d), + connect(c), weight_threshold(w), state_threshold(n) {} + private: + RmEpsilonOptions(); // disallow +}; + +// Computation state of the epsilon-removal algorithm. +template <class Arc, class Queue> +class RmEpsilonState { + public: + typedef typename Arc::Label Label; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + RmEpsilonState(const Fst<Arc> &fst, + vector<Weight> *distance, + const RmEpsilonOptions<Arc, Queue> &opts) + : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true), + expand_id_(0) {} + + // Compute arcs and final weight for state 's' + void Expand(StateId s); + + // Returns arcs of expanded state. + vector<Arc> &Arcs() { return arcs_; } + + // Returns final weight of expanded state. + const Weight &Final() const { return final_; } + + // Return true if an error has occured. + bool Error() const { return sd_state_.Error(); } + + private: + static const size_t kPrime0 = 7853; + static const size_t kPrime1 = 7867; + + struct Element { + Label ilabel; + Label olabel; + StateId nextstate; + + Element() {} + + Element(Label i, Label o, StateId s) + : ilabel(i), olabel(o), nextstate(s) {} + }; + + class ElementKey { + public: + size_t operator()(const Element& e) const { + return static_cast<size_t>(e.nextstate + + e.ilabel * kPrime0 + + e.olabel * kPrime1); + } + + private: + }; + + class ElementEqual { + public: + bool operator()(const Element &e1, const Element &e2) const { + return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) + && (e1.nextstate == e2.nextstate); + } + }; + + typedef unordered_map<Element, pair<StateId, size_t>, + ElementKey, ElementEqual> ElementMap; + + const Fst<Arc> &fst_; + // Distance from state being expanded in epsilon-closure. + vector<Weight> *distance_; + // Shortest distance algorithm computation state. + ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_; + // Maps an element 'e' to a pair 'p' corresponding to a position + // in the arcs vector of the state being expanded. 'e' corresponds + // to the position 'p.second' in the 'arcs_' vector if 'p.first' is + // equal to the state being expanded. + ElementMap element_map_; + EpsilonArcFilter<Arc> eps_filter_; + stack<StateId> eps_queue_; // Queue used to visit the epsilon-closure + vector<bool> visited_; // '[i] = true' if state 'i' has been visited + slist<StateId> visited_states_; // List of visited states + vector<Arc> arcs_; // Arcs of state being expanded + Weight final_; // Final weight of state being expanded + StateId expand_id_; // Unique ID for each call to Expand + + DISALLOW_COPY_AND_ASSIGN(RmEpsilonState); +}; + +template <class Arc, class Queue> +const size_t RmEpsilonState<Arc, Queue>::kPrime0; +template <class Arc, class Queue> +const size_t RmEpsilonState<Arc, Queue>::kPrime1; + + +template <class Arc, class Queue> +void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) { + final_ = Weight::Zero(); + arcs_.clear(); + sd_state_.ShortestDistance(source); + if (sd_state_.Error()) + return; + eps_queue_.push(source); + + while (!eps_queue_.empty()) { + StateId state = eps_queue_.top(); + eps_queue_.pop(); + + while (visited_.size() <= state) visited_.push_back(false); + if (visited_[state]) continue; + visited_[state] = true; + visited_states_.push_front(state); + + for (ArcIterator< Fst<Arc> > ait(fst_, state); + !ait.Done(); + ait.Next()) { + Arc arc = ait.Value(); + arc.weight = Times((*distance_)[state], arc.weight); + + if (eps_filter_(arc)) { + while (visited_.size() <= arc.nextstate) + visited_.push_back(false); + if (!visited_[arc.nextstate]) + eps_queue_.push(arc.nextstate); + } else { + Element element(arc.ilabel, arc.olabel, arc.nextstate); + typename ElementMap::iterator it = element_map_.find(element); + if (it == element_map_.end()) { + element_map_.insert( + pair<Element, pair<StateId, size_t> > + (element, pair<StateId, size_t>(expand_id_, arcs_.size()))); + arcs_.push_back(arc); + } else { + if (((*it).second).first == expand_id_) { + Weight &w = arcs_[((*it).second).second].weight; + w = Plus(w, arc.weight); + } else { + ((*it).second).first = expand_id_; + ((*it).second).second = arcs_.size(); + arcs_.push_back(arc); + } + } + } + } + final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state))); + } + + while (!visited_states_.empty()) { + visited_[visited_states_.front()] = false; + visited_states_.pop_front(); + } + ++expand_id_; +} + +// Removes epsilon-transitions (when both the input and output label +// are an epsilon) from a transducer. The result will be an equivalent +// FST that has no such epsilon transitions. This version modifies +// its input. It allows fine control via the options argument; see +// below for a simpler interface. +// +// The vector 'distance' will be used to hold the shortest distances +// during the epsilon-closure computation. The state queue discipline +// and convergence delta are taken in the options argument. +template <class Arc, class Queue> +void RmEpsilon(MutableFst<Arc> *fst, + vector<typename Arc::Weight> *distance, + const RmEpsilonOptions<Arc, Queue> &opts) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + if (fst->Start() == kNoStateId) { + return; + } + + // 'noneps_in[s]' will be set to true iff 's' admits a non-epsilon + // incoming transition or is the start state. + vector<bool> noneps_in(fst->NumStates(), false); + noneps_in[fst->Start()] = true; + for (StateId i = 0; i < fst->NumStates(); ++i) { + for (ArcIterator<Fst<Arc> > aiter(*fst, i); + !aiter.Done(); + aiter.Next()) { + if (aiter.Value().ilabel != 0 || aiter.Value().olabel != 0) + noneps_in[aiter.Value().nextstate] = true; + } + } + + // States sorted in topological order when (acyclic) or generic + // topological order (cyclic). + vector<StateId> states; + states.reserve(fst->NumStates()); + + if (fst->Properties(kTopSorted, false) & kTopSorted) { + for (StateId i = 0; i < fst->NumStates(); i++) + states.push_back(i); + } else if (fst->Properties(kAcyclic, false) & kAcyclic) { + vector<StateId> order; + bool acyclic; + TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); + DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>()); + // Sanity check: should be acyclic if property bit is set. + if(!acyclic) { + FSTERROR() << "RmEpsilon: inconsistent acyclic property bit"; + fst->SetProperties(kError, kError); + return; + } + states.resize(order.size()); + for (StateId i = 0; i < order.size(); i++) + states[order[i]] = i; + } else { + uint64 props; + vector<StateId> scc; + SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props); + DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>()); + vector<StateId> first(scc.size(), kNoStateId); + vector<StateId> next(scc.size(), kNoStateId); + for (StateId i = 0; i < scc.size(); i++) { + if (first[scc[i]] != kNoStateId) + next[i] = first[scc[i]]; + first[scc[i]] = i; + } + for (StateId i = 0; i < first.size(); i++) + for (StateId j = first[i]; j != kNoStateId; j = next[j]) + states.push_back(j); + } + + RmEpsilonState<Arc, Queue> + rmeps_state(*fst, distance, opts); + + while (!states.empty()) { + StateId state = states.back(); + states.pop_back(); + if (!noneps_in[state]) + continue; + rmeps_state.Expand(state); + fst->SetFinal(state, rmeps_state.Final()); + fst->DeleteArcs(state); + vector<Arc> &arcs = rmeps_state.Arcs(); + fst->ReserveArcs(state, arcs.size()); + while (!arcs.empty()) { + fst->AddArc(state, arcs.back()); + arcs.pop_back(); + } + } + + for (StateId s = 0; s < fst->NumStates(); ++s) { + if (!noneps_in[s]) + fst->DeleteArcs(s); + } + + if(rmeps_state.Error()) + fst->SetProperties(kError, kError); + fst->SetProperties( + RmEpsilonProperties(fst->Properties(kFstProperties, false)), + kFstProperties); + + if (opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) + Prune(fst, opts.weight_threshold, opts.state_threshold); + if (opts.connect && (opts.weight_threshold == Weight::Zero() || + opts.state_threshold != kNoStateId)) + Connect(fst); +} + +// Removes epsilon-transitions (when both the input and output label +// are an epsilon) from a transducer. The result will be an equivalent +// FST that has no such epsilon transitions. This version modifies its +// input. It has a simplified interface; see above for a version that +// allows finer control. +// +// Complexity: +// - Time: +// - Unweighted: O(V2 + V E) +// - Acyclic: O(V2 + V E) +// - Tropical semiring: O(V2 log V + V E) +// - General: exponential +// - Space: O(V E) +// where V = # of states visited, E = # of arcs. +// +// References: +// - Mehryar Mohri. Generic Epsilon-Removal and Input +// Epsilon-Normalization Algorithms for Weighted Transducers, +// "International Journal of Computer Science", 13(1):129-143 (2002). +template <class Arc> +void RmEpsilon(MutableFst<Arc> *fst, + bool connect = true, + typename Arc::Weight weight_threshold = Arc::Weight::Zero(), + typename Arc::StateId state_threshold = kNoStateId, + float delta = kDelta) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + vector<Weight> distance; + AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>()); + RmEpsilonOptions<Arc, AutoQueue<StateId> > + opts(&state_queue, delta, connect, weight_threshold, state_threshold); + + RmEpsilon(fst, &distance, opts); +} + + +struct RmEpsilonFstOptions : CacheOptions { + float delta; + + RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta) + : CacheOptions(opts), delta(delta) {} + + explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {} +}; + + +// Implementation of delayed RmEpsilonFst. +template <class A> +class RmEpsilonFstImpl : public CacheImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<A> >::PushArc; + using CacheBaseImpl< CacheState<A> >::HasArcs; + using CacheBaseImpl< CacheState<A> >::HasFinal; + using CacheBaseImpl< CacheState<A> >::HasStart; + using CacheBaseImpl< CacheState<A> >::SetArcs; + using CacheBaseImpl< CacheState<A> >::SetFinal; + using CacheBaseImpl< CacheState<A> >::SetStart; + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + + RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts) + : CacheImpl<A>(opts), + fst_(fst.Copy()), + delta_(opts.delta), + rmeps_state_( + *fst_, + &distance_, + RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) { + SetType("rmepsilon"); + uint64 props = fst.Properties(kFstProperties, false); + SetProperties(RmEpsilonProperties(props, true), kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + RmEpsilonFstImpl(const RmEpsilonFstImpl &impl) + : CacheImpl<A>(impl), + fst_(impl.fst_->Copy(true)), + delta_(impl.delta_), + rmeps_state_( + *fst_, + &distance_, + RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) { + SetType("rmepsilon"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~RmEpsilonFstImpl() { + delete fst_; + } + + StateId Start() { + if (!HasStart()) { + SetStart(fst_->Start()); + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + Expand(s); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && + (fst_->Properties(kError, false) || rmeps_state_.Error())) + SetProperties(kError, kError); + return FstImpl<A>::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<A>::InitArcIterator(s, data); + } + + void Expand(StateId s) { + rmeps_state_.Expand(s); + SetFinal(s, rmeps_state_.Final()); + vector<A> &arcs = rmeps_state_.Arcs(); + while (!arcs.empty()) { + PushArc(s, arcs.back()); + arcs.pop_back(); + } + SetArcs(s); + } + + private: + const Fst<A> *fst_; + float delta_; + vector<Weight> distance_; + FifoQueue<StateId> queue_; + RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_; + + void operator=(const RmEpsilonFstImpl<A> &); // disallow +}; + + +// Removes epsilon-transitions (when both the input and output label +// are an epsilon) from a transducer. The result will be an equivalent +// FST that has no such epsilon transitions. This version is a +// delayed Fst. +// +// Complexity: +// - Time: +// - Unweighted: O(v^2 + v e) +// - General: exponential +// - Space: O(v e) +// where v = # of states visited, e = # of arcs visited. Constant time +// to visit an input state or arc is assumed and exclusive of caching. +// +// References: +// - Mehryar Mohri. Generic Epsilon-Removal and Input +// Epsilon-Normalization Algorithms for Weighted Transducers, +// "International Journal of Computer Science", 13(1):129-143 (2002). +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class RmEpsilonFst : public ImplToFst< RmEpsilonFstImpl<A> > { + public: + friend class ArcIterator< RmEpsilonFst<A> >; + friend class StateIterator< RmEpsilonFst<A> >; + + typedef A Arc; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef RmEpsilonFstImpl<A> Impl; + + RmEpsilonFst(const Fst<A> &fst) + : ImplToFst<Impl>(new Impl(fst, RmEpsilonFstOptions())) {} + + RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts) + : ImplToFst<Impl>(new Impl(fst, opts)) {} + + // See Fst<>::Copy() for doc. + RmEpsilonFst(const RmEpsilonFst<A> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc. + virtual RmEpsilonFst<A> *Copy(bool safe = false) const { + return new RmEpsilonFst<A>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const RmEpsilonFst<A> &fst); // disallow +}; + +// Specialization for RmEpsilonFst. +template<class A> +class StateIterator< RmEpsilonFst<A> > + : public CacheStateIterator< RmEpsilonFst<A> > { + public: + explicit StateIterator(const RmEpsilonFst<A> &fst) + : CacheStateIterator< RmEpsilonFst<A> >(fst, fst.GetImpl()) {} +}; + + +// Specialization for RmEpsilonFst. +template <class A> +class ArcIterator< RmEpsilonFst<A> > + : public CacheArcIterator< RmEpsilonFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const RmEpsilonFst<A> &fst, StateId s) + : CacheArcIterator< RmEpsilonFst<A> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + + +template <class A> inline +void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const { + data->base = new StateIterator< RmEpsilonFst<A> >(*this); +} + + +// Useful alias when using StdArc. +typedef RmEpsilonFst<StdArc> StdRmEpsilonFst; + +} // namespace fst + +#endif // FST_LIB_RMEPSILON_H__ |