// randgen.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Classes and functions to generate random paths through an FST.
#ifndef FST_LIB_RANDGEN_H__
#define FST_LIB_RANDGEN_H__
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <map>
#include <fst/accumulator.h>
#include <fst/cache.h>
#include <fst/dfs-visit.h>
#include <fst/mutable-fst.h>
namespace fst {
//
// ARC SELECTORS - these function objects are used to select a random
// transition to take from an FST's state. They should return a number
// N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th
// transition is selected. If N == NumArcs(), then the final weight at
// that state is selected (i.e., the 'super-final' transition is selected).
// It can be assumed these will not be called unless there are
// transitions leaving the state and/or the state is final.
//
// Randomly selects a transition using the uniform distribution.
template <class A>
struct UniformArcSelector {
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;

  // Seeds the C library generator via srand(); the default seed is the
  // current time.
  UniformArcSelector(int seed = time(0)) { srand(seed); }

  // Returns an index chosen uniformly among the choices available at state
  // 's': one per outgoing arc, plus one more (the value NumArcs(s)) when
  // the final weight of 's' differs from Weight::Zero().
  size_t operator()(const Fst<A> &fst, StateId s) const {
    // Draw in [0, 1): rand() is at most RAND_MAX, the divisor is one larger.
    const double unit_draw = rand() / (RAND_MAX + 1.0);
    const size_t num_arcs = fst.NumArcs(s);
    // The final weight counts as one extra choice when it is not Zero().
    const size_t final_choice = (fst.Final(s) != Weight::Zero()) ? 1 : 0;
    return static_cast<size_t>(unit_draw * (num_arcs + final_choice));
  }
};
// Randomly selects a transition w.r.t. the weights treated as negative
// log probabilities after normalizing for the total weight leaving
// the state. Weight::Zero() transitions are disregarded.
// Assumes Weight::Value() accesses the floating point
// representation of the weight.
template <class A>
class LogProbArcSelector {
public:
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
LogProbArcSelector(int seed = time(0)) { srand(seed); }
size_t operator()(const Fst<A> &fst, StateId s) const {
// Find total weight leaving state
double sum = 0.0;
for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
aiter.Next()) {
const A &arc = aiter.Value();
sum += exp(-to_log_weight_(arc.weight).Value());
}
sum += exp(-to_log_weight_(fst.Final(s)).Value());
double r = rand()/(RAND_MAX + 1.0);
double p = 0.0;
int n = 0;
for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
aiter.Next(), ++n) {
const A &arc = aiter.Value();
p += exp(-to_log_weight_(arc.weight).Value());
if (p > r * sum) return n;
}
return n;
}
private:
WeightConvert<Weight, Log64Weight> to_log_weight_;
};
// Convenience definitions
// LogProbArcSelector instantiated for StdArc.
typedef LogProbArcSelector<StdArc> StdArcSelector;
// LogProbArcSelector instantiated for LogArc.
typedef LogProbArcSelector<LogArc> LogArcSelector;
// Same as LogProbArcSelector but uses CacheLogAccumulator to cache
// the cumulative weight computations.
template <class A>
class FastLogProbArcSelector : public LogProbArcSelector<A> {
 public:
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;

  // Keeps the base class's two-argument selection operator visible next to
  // the accumulator-based overload declared below.
  using LogProbArcSelector<A>::operator();

  // Forwards 'seed' to the base class (which calls srand) and records it so
  // that Seed() can report it later.
  FastLogProbArcSelector(int seed = time(0))
      : LogProbArcSelector<A>(seed), seed_(seed) {}

  // Selects at state 's' using the supplied accumulator: the accumulator is
  // positioned at 's', asked for the sum of the final weight and arcs
  // [0, NumArcs(s)), and then asked for the lower bound of that sum offset
  // by the negative log of a random draw.
  size_t operator()(const Fst<A> &fst, StateId s,
                    CacheLogAccumulator<A> *accumulator) const {
    accumulator->SetState(s);
    ArcIterator< Fst<A> > arcs(fst, s);
    // Total weight leaving the state, as a Log64Weight value.
    const double log_total =
        to_log_weight_(
            accumulator->Sum(fst.Final(s), &arcs, 0, fst.NumArcs(s)))
            .Value();
    // Negative log of a draw from [0, 1).
    const double neg_log_draw = -log(rand() / (RAND_MAX + 1.0));
    return accumulator->LowerBound(neg_log_draw + log_total, &arcs);
  }

  // Returns the seed this selector was constructed with.
  int Seed() const { return seed_; }

 private:
  int seed_;
  // Converts the arc's weight type to Log64Weight before taking Value().
  WeightConvert<Weight, Log64Weight> to_log_weight_;
};
// Random path state info maintained by RandGenFst and passed to samplers.
template <typename A>
struct RandState {
  typedef typename A::StateId StateId;

  StateId state_id;            // State of the input FST this entry refers to.
  size_t nsamples;             // How many samples are to be drawn here.
  size_t length;               // Length of the path leading to this entry.
  size_t select;               // Arc selection made by the previous sample.
  const RandState<A> *parent;  // Preceding random state on the same path.

  // Builds an empty entry: no state, zero counts and a null parent.
  RandState()
      : state_id(kNoStateId),
        nsamples(0),
        length(0),
        select(0),
        parent(0) {}

  // Builds an entry for state 's' with 'n' samples, path length 'l',
  // previous selection 'k' and parent entry 'p'.
  RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p)
      : state_id(s),
        nsamples(n),
        length(l),
        select(k),
        parent(p) {}
};
// This class, given an arc selector, samples, with replacement,
// multiple random transitions from an FST's state. This is a generic
// version with a straightforward use of the arc selector.
// Specializations may be defined for arc selectors for greater
// efficiency or special behavior.
template <class A, class S>
class ArcSampler {
public:
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
// The 'max_length' may be interpreted (including ignored) by a
// sampler as it chooses. This generic version interprets this literally.
ArcSampler(const Fst<A> &fst, const S &arc_selector,
int max_length = INT_MAX)
: fst_(fst),
arc_selector_(arc_selector),
max_length_(max_length) {}
// Allow updating Fst argument; pass only if changed.
ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
: fst_(fst ? *fst : sampler.fst_),
arc_selector_(sampler.arc_selector_),
max_length_(sampler.max_length_) {
Reset();
}
// Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is
// the length of the path to 'rstate'. Retu