// rmepsilon.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Functions and classes that implement epsilon-removal.
#ifndef FST_LIB_RMEPSILON_H__
#define FST_LIB_RMEPSILON_H__
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <fst/slist.h>
#include <stack>
#include <string>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;
#include <fst/arcfilter.h>
#include <fst/cache.h>
#include <fst/connect.h>
#include <fst/factor-weight.h>
#include <fst/invert.h>
#include <fst/prune.h>
#include <fst/queue.h>
#include <fst/shortest-distance.h>
#include <fst/topsort.h>
namespace fst {
// Options for epsilon removal. Extends the shortest-distance options (with an
// EpsilonArcFilter) by a connect flag and pruning thresholds.
template <class Arc, class Queue>
class RmEpsilonOptions
    : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
  // Shorthand for the base options type, used in the constructor below.
  typedef ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> >
      SdOptions;

 public:
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  bool connect;             // Whether to connect the output.
  Weight weight_threshold;  // Pruning weight threshold.
  StateId state_threshold;  // Pruning state threshold.

  // q: queue handed to the shortest-distance options.
  // d: delta forwarded to the shortest-distance options (kDelta by default).
  // c, w, n: initial values of 'connect', 'weight_threshold' and
  //          'state_threshold' respectively.
  // The shortest-distance source is left as kNoStateId.
  explicit RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true,
                            Weight w = Weight::Zero(),
                            StateId n = kNoStateId)
      : SdOptions(q, EpsilonArcFilter<Arc>(), kNoStateId, d),
        connect(c),
        weight_threshold(w),
        state_threshold(n) {}

 private:
  RmEpsilonOptions();  // Default construction is disallowed.
};
// Computation state of the epsilon-removal algorithm: Expand() computes the
// arcs and final weight of one state, which are then read back through
// Arcs() and Final().
template <class Arc, class Queue>
class RmEpsilonState {
 public:
  typedef typename Arc::Label Label;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  // 'fst' is held by reference, so it must outlive this object. 'distance'
  // and 'opts' are handed to the shortest-distance computation state.
  RmEpsilonState(const Fst<Arc> &fst,
                 vector<Weight> *distance,
                 const RmEpsilonOptions<Arc, Queue> &opts)
      : fst_(fst),
        distance_(distance),
        sd_state_(fst_, distance, opts, true),
        expand_id_(0) {}

  // Computes the arcs and final weight of state 's'.
  void Expand(StateId s);

  // Arcs computed by the last Expand() (returned as a mutable reference).
  vector<Arc> &Arcs() { return arcs_; }

  // Final weight computed by the last Expand().
  const Weight &Final() const { return final_; }

  // True if the shortest-distance computation state reports an error.
  bool Error() const { return sd_state_.Error(); }

 private:
  // Multipliers used by ElementKey when hashing an Element.
  static const size_t kPrime0 = 7853;
  static const size_t kPrime1 = 7867;

  // An (ilabel, olabel, nextstate) triple used as the key identifying an
  // arc of the state being expanded.
  struct Element {
    Label ilabel;
    Label olabel;
    StateId nextstate;

    Element() {}
    Element(Label i, Label o, StateId s)
        : ilabel(i), olabel(o), nextstate(s) {}
  };

  // Hash functor for Element; all arithmetic is carried out in size_t.
  class ElementKey {
   public:
    size_t operator()(const Element &e) const {
      size_t key = static_cast<size_t>(e.nextstate);
      key += static_cast<size_t>(e.ilabel) * kPrime0;
      key += static_cast<size_t>(e.olabel) * kPrime1;
      return key;
    }
  };

  // Equality functor for Element: all three fields must match.
  class ElementEqual {
   public:
    bool operator()(const Element &lhs, const Element &rhs) const {
      if (lhs.nextstate != rhs.nextstate) return false;
      if (lhs.ilabel != rhs.ilabel) return false;
      return lhs.olabel == rhs.olabel;
    }
  };

  typedef unordered_map<Element, pair<StateId, size_t>,
                        ElementKey, ElementEqual> ElementMap;

  const Fst<Arc> &fst_;
  // Distance from the state being expanded, within its epsilon-closure.
  vector<Weight> *distance_;
  // Computation state of the shortest-distance algorithm.
  ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
  // Maps an element 'e' to a pair 'p' locating it in the arcs vector of the
  // state being expanded: 'e' sits at position 'p.second' of 'arcs_' only
  // when 'p.first' equals the ID of the current expansion.
  ElementMap element_map_;
  EpsilonArcFilter<Arc> eps_filter_;  // Arcs it accepts are followed in the closure.
  stack<StateId> eps_queue_;          // Epsilon-closure states still to visit.
  vector<bool> visited_;              // visited_[i] is true once state i is visited.
  slist<StateId> visited_states_;     // States visited so far.
  vector<Arc> arcs_;                  // Arcs of the state being expanded.
  Weight final_;                      // Final weight of the state being expanded.
  StateId expand_id_;                 // Unique ID for each call to Expand().

  DISALLOW_COPY_AND_ASSIGN(RmEpsilonState);
};
// Namespace-scope definitions of the hash multipliers declared (and
// initialized) inside RmEpsilonState.
template <typename Arc, typename Queue>
const size_t RmEpsilonState<Arc, Queue>::kPrime0;

template <typename Arc, typename Queue>
const size_t RmEpsilonState<Arc, Queue>::kPrime1;
template <class Arc, class Queue>
void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
final_ = Weight::Zero();
arcs_.clear();
sd_state_.ShortestDistance(source);
if (sd_state_.Error())
return;
eps_queue_.push(source);
while (!eps_queue_.empty()) {
StateId state = eps_queue_.top();
eps_queue_.pop();
while (visited_.size() <= state) visited_.push_back(false);
if (visited_[state]) continue;
visited_[state] = true;
visited_states_.push_front(state);
for (ArcIterator< Fst<Arc> > ait(fst_, state);
!ait.Done();
ait.Next()) {
Arc arc = ait.Value();
arc.weight = Times((*distance_)[state], arc.weight);
if (eps_filter_(arc)) {
while (visited_.size() <= arc.nextstate)
visited_.push_back(false);
if (!visited_[arc.nextstate])
eps_queue_.push(arc.nextstate);
} else {
Element element(arc.ilabel, arc.olabel, arc.nextstate);
typename ElementMap::iterator it = element_map_.find(element);
if (it == element_map_.end()) {
element_map_.insert(
pair<Element, pair<StateId, size_t> >
(element, pair<StateId, size_t>(expand_id_, arcs_.size())));
arcs_.push_back(arc);
} else {
if (((*it).second).first == expand_id_) {
Weight &w =