// shortest-path.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Functions to find shortest paths in an FST.
#ifndef FST_LIB_SHORTEST_PATH_H__
#define FST_LIB_SHORTEST_PATH_H__
#include <functional>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;
#include <fst/cache.h>
#include <fst/determinize.h>
#include <fst/queue.h>
#include <fst/shortest-distance.h>
#include <fst/test-properties.h>
namespace fst {
// Options controlling the shortest-path algorithms. Extends
// ShortestDistanceOptions (queue discipline, arc filter, delta) with
// path-specific settings.
template <class Arc, class Queue, class ArcFilter>
struct ShortestPathOptions
    : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  // How many shortest paths to return.
  size_t nshortest;
  // If true, only paths whose input strings are distinct are returned.
  bool unique;
  // If true, the distance vector passed in already contains the shortest
  // distance from the initial state.
  bool has_distance;
  // If true, the single-shortest-path search halts once it has found a
  // path to a final state. That path is guaranteed to be the shortest one
  // only with a ShortestFirstQueue, and only when every weight in the FST
  // lies between One() and Zero() according to NaturalLess.
  bool first_path;
  // Pruning weight threshold; Weight::Zero() means "no threshold".
  Weight weight_threshold;
  // Pruning state threshold; kNoStateId means "no threshold".
  StateId state_threshold;

  // 'queue', 'filter' and 'delta' are forwarded to the
  // ShortestDistanceOptions base. kNoStateId is passed as the base's third
  // argument (presumably the source state, i.e. "use the FST's start
  // state" -- confirm against shortest-distance.h). The remaining
  // arguments initialize the members above, in declaration order.
  ShortestPathOptions(Queue *queue, ArcFilter filter, size_t n_paths = 1,
                      bool unique_only = false, bool distance_given = false,
                      float delta = kDelta, bool stop_at_first = false,
                      Weight weight_thresh = Weight::Zero(),
                      StateId state_thresh = kNoStateId)
      : ShortestDistanceOptions<Arc, Queue, ArcFilter>(queue, filter,
                                                      kNoStateId, delta),
        nshortest(n_paths),
        unique(unique_only),
        has_distance(distance_given),
        first_path(stop_at_first),
        weight_threshold(weight_thresh),
        state_threshold(state_thresh) {}
};
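// Shortest-path algorithm: normally not called directly; prefer
// 'ShortestPath' below with n=1. On return, 'ofst' contains the shortest
// path in 'ifst', and 'distance' holds the shortest distances from the
// source state to each state of 'ifst' that was reached. If
// 'opts.first_path' is set, the search stops at the first final state
// dequeued, so these distances may be incomplete. 'opts' is used to specify
// options such as the queue discipline, the arc filter and delta;
// 'opts.nshortest' must be 1 and the weight and state thresholds must be
// unset, otherwise an error is signaled on 'ofst'.
//
// The shortest path is the lowest-weight path w.r.t. the natural
// semiring order.
//
// The weights need to be right distributive and have the path (kPath)
// property.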
template<class Arc, class Queue, class ArcFilter>
void SingleShortestPath(const Fst<Arc> &ifst,
MutableFst<Arc> *ofst,
vector<typename Arc::Weight> *distance,
ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
ofst->DeleteStates();
ofst->SetInputSymbols(ifst.InputSymbols());
ofst->SetOutputSymbols(ifst.OutputSymbols());
if (ifst.Start() == kNoStateId) {
if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
return;
}
vector<bool> enqueued;
vector<StateId> parent;
vector<Arc> arc_parent;
Queue *state_queue = opts.state_queue;
StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source;
Weight f_distance = Weight::Zero();
StateId f_parent = kNoStateId;
distance->clear();
state_queue->Clear();
if (opts.nshortest != 1) {
FSTERROR() << "SingleShortestPath: for nshortest > 1, use ShortestPath"
<< " instead";
ofst->SetProperties(kError, kError);
return;
}
if (opts.weight_threshold != Weight::Zero() ||
opts.state_threshold != kNoStateId) {
FSTERROR() <<
"SingleShortestPath: weight and state thresholds not applicable";
ofst->SetProperties(kError, kError);
return;
}
if ((Weight::Properties() & (kPath | kRightSemiring))
!= (kPath | kRightSemiring)) {
FSTERROR() << "SingleShortestPath: Weight needs to have the path"
<< " property and be right distributive: " << Weight::Type();
ofst->SetProperties(kError, kError);
return;
}
while (distance->size() < source) {
distance->push_back(Weight::Zero());
enqueued.push_back(false);
parent.push_back(kNoStateId);
arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
}
distance->push_back(Weight::One());
parent.push_back(kNoStateId);
arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
state_queue->Enqueue(source);
enqueued.push_back(true);
while (!state_queue->Empty()) {
StateId s = state_queue->Head();
state_queue->Dequeue();
enqueued[s] = false;
Weight sd = (*distance)[s];
if (ifst.Final(s) != Weight::Zero()) {
Weight w = Times(sd, ifst.Final(s));
if (f_distance != Plus(f_distance, w)) {
f_distance = Plus(f_distance, w);
f_parent = s;
}
if (!f_distance.Member()) {
ofst->SetProperties(kError, kError);
return;
}
if (opts.first_path)
break;
}
for (ArcIterator< Fst<Arc> > aiter(ifst, s);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
while (distance->size() <= arc.nextstate) {
distance->push_back(Weight::Zero());
enqueued.push_back(false);
parent.push_back(kNoStateId);
arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(),
kNoStateId));
}
Weight &nd = (*distance)[arc.nextstate];
Weight w = Times(sd, arc.weight);
if (nd != Plus(nd, w)) {
nd = Plus(nd, w);
if (!nd.Member()) {
ofst->SetProperties(kError, kError);
return;
}
parent[arc.nextstate] = s;
arc_parent[arc.nextstate] = arc;
if (!enqueued[arc.nextstate]) {
state_queue->Enqueue(arc.nextstate);
enqueued[arc.nextstate] = true;
} else {
state_queue->Update(arc.nextstate);
}
}
}
}
StateId s_p = kNoStateId, d_p = kNoStateId;
for (StateId s = f_parent, d = kNoStateId;
s != kNoStateId;
d = s, s = parent[s]) {
d_p = s_p;
s_p = ofst->AddState();
if (d == kNoStateId) {
ofst->SetFinal(s_p, ifst.Final(f_parent));
} else {
arc_parent[d].nextstate = d_p;
ofst->AddArc(s_p, arc_parent[d]);
}
}
ofst->SetStart(s_p);
if (ifst.Properties(kError, false)) ofst->