// label_reachable.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Class to determine if a non-epsilon label can be read as the
// first non-epsilon symbol along some path from a given state.
#ifndef FST_LIB_LABEL_REACHABLE_H__
#define FST_LIB_LABEL_REACHABLE_H__
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <vector>
using std::vector;
#include <fst/accumulator.h>
#include <fst/arcsort.h>
#include <fst/interval-set.h>
#include <fst/state-reachable.h>
#include <fst/vector-fst.h>
namespace fst {
// Stores shareable data for label reachable class copies.
template <typename L>
class LabelReachableData {
public:
typedef L Label;
typedef typename IntervalSet<L>::Interval Interval;
explicit LabelReachableData(bool reach_input, bool keep_relabel_data = true)
: reach_input_(reach_input),
keep_relabel_data_(keep_relabel_data),
have_relabel_data_(true),
final_label_(kNoLabel) {}
~LabelReachableData() {}
bool ReachInput() const { return reach_input_; }
vector< IntervalSet<L> > *IntervalSets() { return &isets_; }
unordered_map<L, L> *Label2Index() {
if (!have_relabel_data_)
FSTERROR() << "LabelReachableData: no relabeling data";
return &label2index_;
}
Label FinalLabel() {
if (final_label_ == kNoLabel)
final_label_ = label2index_[kNoLabel];
return final_label_;
}
static LabelReachableData<L> *Read(istream &istrm) {
LabelReachableData<L> *data = new LabelReachableData<L>();
ReadType(istrm, &data->reach_input_);
ReadType(istrm, &data->keep_relabel_data_);
data->have_relabel_data_ = data->keep_relabel_data_;
if (data->keep_relabel_data_)
ReadType(istrm, &data->label2index_);
ReadType(istrm, &data->final_label_);
ReadType(istrm, &data->isets_);
return data;
}
bool Write(ostream &ostrm) {
WriteType(ostrm, reach_input_);
WriteType(ostrm, keep_relabel_data_);
if (keep_relabel_data_)
WriteType(ostrm, label2index_);
WriteType(ostrm, FinalLabel());
WriteType(ostrm, isets_);
return true;
}
int RefCount() const { return ref_count_.count(); }
int IncrRefCount() { return ref_count_.Incr(); }
int DecrRefCount() { return ref_count_.Decr(); }
private:
LabelReachableData() {}
bool reach_input_; // Input or output labels considered?
bool keep_relabel_data_; // Save label2index_ to file?
bool have_relabel_data_; // Using label2index_?
Label final_label_; // Final label
RefCounter ref_count_; // Reference count.
unordered_map<L, L> label2index_; // Finds index for a label.
vector<IntervalSet <L> > isets_; // Interval sets per state.
DISALLOW_COPY_AND_ASSIGN(LabelReachableData);
};
// Tests reachability of labels from a given state. If reach_input =
// true, then input labels are considered, o.w. output labels are
// considered. To test for reachability from a state s, first do
// SetState(s). Then a label l can be reached from state s of FST f
// iff Reach(r) is true where r = Relabel(l). The relabeling is
// required to ensure a compact representation of the reachable
// labels.
// The whole FST can be relabeled instead with Relabel(&f,
// reach_input) so that the test Reach(r) applies directly to the
// labels of the transformed FST f. The relabeled FST will also be
// sorted appropriately for composition.
//
// Reachability of a final state from state s (via an epsilon path)
// can be tested with ReachFinal().
//
// Reachability can also be tested on the set of labels specified by
// an arc iterator, useful for FST composition. In particular,
// Reach(aiter, ...) is true if labels on the input (output) side of
// the transitions of the arc iterator, when iter_input is true
// (false), can be reached from the state s. The iterator labels must
// have already been relabeled.
//
// With the arc iterator test of reachability, the begin position, end
// position and accumulated arc weight of the matches can be
// returned. The optional template argument controls how reachable arc
// weights are accumulated. The default uses the semiring
// Plus(). Alternative ones can be used to distribute the weights in
// composition in various ways.
template <class A, class S = DefaultAccumulator<A> >
class LabelReachable {
public:
typedef A Arc;
typedef typename A::StateId StateId;
typedef typename A::Label Label;
typedef typename A::Weight Weight;
typedef typename IntervalSet<Label>::Interval Interval;
// Computes reachability data from a private VectorFst copy of 'fst'.
// If 's' is non-null, this object takes ownership of it as the accumulator
// (the destructor deletes it); otherwise a default-constructed S is used.
LabelReachable(const Fst<A> &fst, bool reach_input, S *s = 0,
               bool keep_relabel_data = true)
    : fst_(new VectorFst<Arc>(fst)),
      s_(kNoStateId),
      data_(new LabelReachableData<Label>(reach_input, keep_relabel_data)),
      accumulator_(s ? s : new S()),
      ncalls_(0),
      nintervals_(0),
      error_(false) {
  // Record the state count before TransformFst() runs; FindIntervals()
  // is given this original count.
  StateId ins = fst_->NumStates();
  TransformFst();
  FindIntervals(ins);
  // The working copy is only needed during construction. Clear the
  // pointer after deleting it so it cannot dangle.
  delete fst_;
  fst_ = 0;
}
// Reuses already-computed reachability data instead of building it from an
// FST, registering this instance as another holder of 'data'. A non-null
// 's' becomes this object's accumulator (deleted by the destructor);
// a null one is replaced by a freshly allocated default S.
explicit LabelReachable(LabelReachableData<Label> *data, S *s = 0)
    : fst_(0),
      s_(kNoStateId),
      data_(data),
      accumulator_(s == 0 ? new S() : s),
      ncalls_(0),
      nintervals_(0),
      error_(false) {
  // One more holder of the shared data.
  data_->IncrRefCount();
}
LabelReachable(const LabelReachable<A, S> &reachable) :
fst_(0),
s_(kNoStateId),
data_(reachable.data_),
accumulator_(new S(*reachable.accumulator_)),
ncalls_(0),
nintervals_(0),
error_(reachable.error_) {
data_->IncrRefCount();
}
// Drops this instance's reference to the shared data, deleting it when no
// holders remain, and deletes the owned accumulator. If any calls were
// counted, it then logs call statistics at verbosity 2.
~LabelReachable() {
  if (!data_->DecrRefCount())
    delete data_;
  delete accumulator_;
  if (ncalls_ > 0) {
    VLOG(2) << "# of calls: " << ncalls_;
    // Divide in floating point so the per-call average is not truncated.
    VLOG(2) << "# of intervals/call: "
            << (static_cast<double>(nintervals_) / ncalls_);
  }
}
// Maps 'label' to its compact index, i.e. the label space in which the
// reachable label sets are compact. Epsilon (0) is returned unchanged, and
// so is every label once an error has been recorded. A label without a
// nonzero index is assigned one: the map size after insertion, plus one.
Label Relabel(Label label) {
  if (error_ || label == 0)
    return label;
  unordered_map<Label, Label> *table = data_->Label2Index();
  // operator[] inserts a zero entry for labels not seen before.
  Label &index = (*table)[label];
  if (index == 0)
    index = table->size() + 1;
  return index;
}
// Relabels Fst w.r.t to labels that give compact label sets.
void Relabel(MutableFst<Arc> *fst, bool relabel_input) {
for (StateIterator< MutableFst<Arc> > siter(*fst);
!siter.Done(); siter.Next()) {
StateId s = siter.Value();
for (MutableArcIterator< MutableFst<Arc> > aiter(fst, s);
!aiter.Done();
aiter.Next()) {
Arc arc = aiter.Value();
if (relabel_input)
arc.ilabel = Relabel(arc.ilabel);
else
arc.olabel = Relabel(arc.olabel);
aiter.SetValue(arc);
}
}
if (relabel_input) {
ArcSort(fst, ILabelCompare<Arc>());
fst->SetInputSymbols(0);
} else {
<