diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/randgen.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/randgen.h | 712 |
1 files changed, 712 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/randgen.h b/kaldi_io/src/tools/openfst/include/fst/randgen.h new file mode 100644 index 0000000..82ddffa --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/randgen.h @@ -0,0 +1,712 @@ +// randgen.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes and functions to generate random paths through an FST. + +#ifndef FST_LIB_RANDGEN_H__ +#define FST_LIB_RANDGEN_H__ + +#include <cmath> +#include <cstdlib> +#include <ctime> +#include <map> + +#include <fst/accumulator.h> +#include <fst/cache.h> +#include <fst/dfs-visit.h> +#include <fst/mutable-fst.h> + +namespace fst { + +// +// ARC SELECTORS - these function objects are used to select a random +// transition to take from an FST's state. They should return a number +// N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th +// transition is selected. If N == NumArcs(), then the final weight at +// that state is selected (i.e., the 'super-final' transition is selected). +// It can be assumed these will not be called unless either there +// are transitions leaving the state and/or the state is final. +// + +// Randomly selects a transition using the uniform distribution. +template <class A> +struct UniformArcSelector { + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + UniformArcSelector(int seed = time(0)) { srand(seed); } + + size_t operator()(const Fst<A> &fst, StateId s) const { + double r = rand()/(RAND_MAX + 1.0); + size_t n = fst.NumArcs(s); + if (fst.Final(s) != Weight::Zero()) + ++n; + return static_cast<size_t>(r * n); + } +}; + + +// Randomly selects a transition w.r.t. the weights treated as negative +// log probabilities after normalizing for the total weight leaving +// the state. Weight::zero transitions are disregarded. +// Assumes Weight::Value() accesses the floating point +// representation of the weight. +template <class A> +class LogProbArcSelector { + public: + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + LogProbArcSelector(int seed = time(0)) { srand(seed); } + + size_t operator()(const Fst<A> &fst, StateId s) const { + // Find total weight leaving state + double sum = 0.0; + for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + sum += exp(-to_log_weight_(arc.weight).Value()); + } + sum += exp(-to_log_weight_(fst.Final(s)).Value()); + + double r = rand()/(RAND_MAX + 1.0); + double p = 0.0; + int n = 0; + for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); + aiter.Next(), ++n) { + const A &arc = aiter.Value(); + p += exp(-to_log_weight_(arc.weight).Value()); + if (p > r * sum) return n; + } + return n; + } + + private: + WeightConvert<Weight, Log64Weight> to_log_weight_; +}; + +// Convenience definitions +typedef LogProbArcSelector<StdArc> StdArcSelector; +typedef LogProbArcSelector<LogArc> LogArcSelector; + + +// Same as LogProbArcSelector but use CacheLogAccumulator to cache +// the cummulative weight computations. +template <class A> +class FastLogProbArcSelector : public LogProbArcSelector<A> { + public: + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + using LogProbArcSelector<A>::operator(); + + FastLogProbArcSelector(int seed = time(0)) + : LogProbArcSelector<A>(seed), + seed_(seed) {} + + size_t operator()(const Fst<A> &fst, StateId s, + CacheLogAccumulator<A> *accumulator) const { + accumulator->SetState(s); + ArcIterator< Fst<A> > aiter(fst, s); + // Find total weight leaving state + double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0, + fst.NumArcs(s))).Value(); + double r = -log(rand()/(RAND_MAX + 1.0)); + return accumulator->LowerBound(r + sum, &aiter); + } + + int Seed() const { return seed_; } + private: + int seed_; + WeightConvert<Weight, Log64Weight> to_log_weight_; +}; + +// Random path state info maintained by RandGenFst and passed to samplers. +template <typename A> +struct RandState { + typedef typename A::StateId StateId; + + StateId state_id; // current input FST state + size_t nsamples; // # of samples to be sampled at this state + size_t length; // length of path to this random state + size_t select; // previous sample arc selection + const RandState<A> *parent; // previous random state on this path + + RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p) + : state_id(s), nsamples(n), length(l), select(k), parent(p) {} + + RandState() + : state_id(kNoStateId), nsamples(0), length(0), select(0), parent(0) {} +}; + +// This class, given an arc selector, samples, with raplacement, +// multiple random transitions from an FST's state. This is a generic +// version with a straight-forward use of the arc selector. +// Specializations may be defined for arc selectors for greater +// efficiency or special behavior. +template <class A, class S> +class ArcSampler { + public: + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + // The 'max_length' may be interpreted (including ignored) by a + // sampler as it chooses. This generic version interprets this literally. + ArcSampler(const Fst<A> &fst, const S &arc_selector, + int max_length = INT_MAX) + : fst_(fst), + arc_selector_(arc_selector), + max_length_(max_length) {} + + // Allow updating Fst argument; pass only if changed. + ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0) + : fst_(fst ? *fst : sampler.fst_), + arc_selector_(sampler.arc_selector_), + max_length_(sampler.max_length_) { + Reset(); + } + + // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is + // the length of the path to 'rstate'. Returns true if samples were + // collected. No samples may be collected if either there are no (including + // 'super-final') transitions leaving that state or if the + // 'max_length' has been deemed reached. Use the iterator members to + // read the samples. The samples will be in their original order. + bool Sample(const RandState<A> &rstate) { + sample_map_.clear(); + if ((fst_.NumArcs(rstate.state_id) == 0 && + fst_.Final(rstate.state_id) == Weight::Zero()) || + rstate.length == max_length_) { + Reset(); + return false; + } + + for (size_t i = 0; i < rstate.nsamples; ++i) + ++sample_map_[arc_selector_(fst_, rstate.state_id)]; + Reset(); + return true; + } + + // More samples? + bool Done() const { return sample_iter_ == sample_map_.end(); } + + // Gets the next sample. + void Next() { ++sample_iter_; } + + // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples. + // If N < NumArcs(s), then the N-th transition is specified. + // If N == NumArcs(s), then the final weight at that state is + // specified (i.e., the 'super-final' transition is specified). + // For the specified transition, K repetitions have been sampled. + pair<size_t, size_t> Value() const { return *sample_iter_; } + + void Reset() { sample_iter_ = sample_map_.begin(); } + + bool Error() const { return false; } + + private: + const Fst<A> &fst_; + const S &arc_selector_; + int max_length_; + + // Stores (N, K) as described for Value(). + map<size_t, size_t> sample_map_; + map<size_t, size_t>::const_iterator sample_iter_; + + // disallow + ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s); +}; + + +// Specialization for FastLogProbArcSelector. +template <class A> +class ArcSampler<A, FastLogProbArcSelector<A> > { + public: + typedef FastLogProbArcSelector<A> S; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + typedef CacheLogAccumulator<A> C; + + ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX) + : fst_(fst), + arc_selector_(arc_selector), + max_length_(max_length), + accumulator_(new C()) { + accumulator_->Init(fst); + } + + ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0) + : fst_(fst ? *fst : sampler.fst_), + arc_selector_(sampler.arc_selector_), + max_length_(sampler.max_length_) { + if (fst) { + accumulator_ = new C(); + accumulator_->Init(*fst); + } else { // shallow copy + accumulator_ = new C(*sampler.accumulator_); + } + } + + ~ArcSampler() { + delete accumulator_; + } + + bool Sample(const RandState<A> &rstate) { + sample_map_.clear(); + if ((fst_.NumArcs(rstate.state_id) == 0 && + fst_.Final(rstate.state_id) == Weight::Zero()) || + rstate.length == max_length_) { + Reset(); + return false; + } + + for (size_t i = 0; i < rstate.nsamples; ++i) + ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)]; + Reset(); + return true; + } + + bool Done() const { return sample_iter_ == sample_map_.end(); } + void Next() { ++sample_iter_; } + pair<size_t, size_t> Value() const { return *sample_iter_; } + void Reset() { sample_iter_ = sample_map_.begin(); } + + bool Error() const { return accumulator_->Error(); } + + private: + const Fst<A> &fst_; + const S &arc_selector_; + int max_length_; + + // Stores (N, K) as described for Value(). + map<size_t, size_t> sample_map_; + map<size_t, size_t>::const_iterator sample_iter_; + C *accumulator_; + + // disallow + ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s); +}; + + +// Options for random path generation with RandGenFst. The template argument +// is an arc sampler, typically class 'ArcSampler' above. Ownership of +// the sampler is taken by RandGenFst. +template <class S> +struct RandGenFstOptions : public CacheOptions { + S *arc_sampler; // How to sample transitions at a state + size_t npath; // # of paths to generate + bool weighted; // Output tree weighted by path count; o.w. + // output unweighted DAG + bool remove_total_weight; // Remove total weight when output is weighted. + + RandGenFstOptions(const CacheOptions &copts, S *samp, + size_t n = 1, bool w = true, bool rw = false) + : CacheOptions(copts), + arc_sampler(samp), + npath(n), + weighted(w), + remove_total_weight(rw) {} +}; + + +// Implementation of RandGenFst. +template <class A, class B, class S> +class RandGenFstImpl : public CacheImpl<B> { + public: + using FstImpl<B>::SetType; + using FstImpl<B>::SetProperties; + using FstImpl<B>::SetInputSymbols; + using FstImpl<B>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<B> >::AddArc; + using CacheBaseImpl< CacheState<B> >::HasArcs; + using CacheBaseImpl< CacheState<B> >::HasFinal; + using CacheBaseImpl< CacheState<B> >::HasStart; + using CacheBaseImpl< CacheState<B> >::SetArcs; + using CacheBaseImpl< CacheState<B> >::SetFinal; + using CacheBaseImpl< CacheState<B> >::SetStart; + + typedef B Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + RandGenFstImpl(const Fst<A> &fst, const RandGenFstOptions<S> &opts) + : CacheImpl<B>(opts), + fst_(fst.Copy()), + arc_sampler_(opts.arc_sampler), + npath_(opts.npath), + weighted_(opts.weighted), + remove_total_weight_(opts.remove_total_weight), + superfinal_(kNoLabel) { + SetType("randgen"); + + uint64 props = fst.Properties(kFstProperties, false); + SetProperties(RandGenProperties(props, weighted_), kCopyProperties); + + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + RandGenFstImpl(const RandGenFstImpl &impl) + : CacheImpl<B>(impl), + fst_(impl.fst_->Copy(true)), + arc_sampler_(new S(*impl.arc_sampler_, fst_)), + npath_(impl.npath_), + weighted_(impl.weighted_), + superfinal_(kNoLabel) { + SetType("randgen"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~RandGenFstImpl() { + for (int i = 0; i < state_table_.size(); ++i) + delete state_table_[i]; + delete fst_; + delete arc_sampler_; + } + + StateId Start() { + if (!HasStart()) { + StateId s = fst_->Start(); + if (s == kNoStateId) + return kNoStateId; + StateId start = state_table_.size(); + SetStart(start); + RandState<A> *rstate = new RandState<A>(s, npath_, 0, 0, 0); + state_table_.push_back(rstate); + } + return CacheImpl<B>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + Expand(s); + } + return CacheImpl<B>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl<B>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<B>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<B>::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && + (fst_->Properties(kError, false) || arc_sampler_->Error())) { + SetProperties(kError, kError); + } + return FstImpl<Arc>::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData<B> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<B>::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void Expand(StateId s) { + if (s == superfinal_) { + SetFinal(s, Weight::One()); + SetArcs(s); + return; + } + + SetFinal(s, Weight::Zero()); + const RandState<A> &rstate = *state_table_[s]; + arc_sampler_->Sample(rstate); + ArcIterator< Fst<A> > aiter(*fst_, rstate.state_id); + size_t narcs = fst_->NumArcs(rstate.state_id); + for (;!arc_sampler_->Done(); arc_sampler_->Next()) { + const pair<size_t, size_t> &sample_pair = arc_sampler_->Value(); + size_t pos = sample_pair.first; + size_t count = sample_pair.second; + double prob = static_cast<double>(count)/rstate.nsamples; + if (pos < narcs) { // regular transition + aiter.Seek(sample_pair.first); + const A &aarc = aiter.Value(); + Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One(); + B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size()); + AddArc(s, barc); + RandState<A> *nrstate = + new RandState<A>(aarc.nextstate, count, rstate.length + 1, + pos, &rstate); + state_table_.push_back(nrstate); + } else { // super-final transition + if (weighted_) { + Weight weight = remove_total_weight_ ? + to_weight_(-log(prob)) : to_weight_(-log(prob * npath_)); + SetFinal(s, weight); + } else { + if (superfinal_ == kNoLabel) { + superfinal_ = state_table_.size(); + RandState<A> *nrstate = new RandState<A>(kNoStateId, 0, 0, 0, 0); + state_table_.push_back(nrstate); + } + for (size_t n = 0; n < count; ++n) { + B barc(0, 0, Weight::One(), superfinal_); + AddArc(s, barc); + } + } + } + } + SetArcs(s); + } + + private: + Fst<A> *fst_; + S *arc_sampler_; + size_t npath_; + vector<RandState<A> *> state_table_; + bool weighted_; + bool remove_total_weight_; + StateId superfinal_; + WeightConvert<Log64Weight, Weight> to_weight_; + + void operator=(const RandGenFstImpl<A, B, S> &); // disallow +}; + + +// Fst class to randomly generate paths through an FST; details controlled +// by RandGenOptionsFst. Output format is a tree weighted by the +// path count. +template <class A, class B, class S> +class RandGenFst : public ImplToFst< RandGenFstImpl<A, B, S> > { + public: + friend class ArcIterator< RandGenFst<A, B, S> >; + friend class StateIterator< RandGenFst<A, B, S> >; + typedef B Arc; + typedef S Sampler; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<B> State; + typedef RandGenFstImpl<A, B, S> Impl; + + RandGenFst(const Fst<A> &fst, const RandGenFstOptions<S> &opts) + : ImplToFst<Impl>(new Impl(fst, opts)) {} + + // See Fst<>::Copy() for doc. + RandGenFst(const RandGenFst<A, B, S> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc. + virtual RandGenFst<A, B, S> *Copy(bool safe = false) const { + return new RandGenFst<A, B, S>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<B> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const RandGenFst<A, B, S> &fst); // Disallow +}; + + + +// Specialization for RandGenFst. +template <class A, class B, class S> +class StateIterator< RandGenFst<A, B, S> > + : public CacheStateIterator< RandGenFst<A, B, S> > { + public: + explicit StateIterator(const RandGenFst<A, B, S> &fst) + : CacheStateIterator< RandGenFst<A, B, S> >(fst, fst.GetImpl()) {} + + private: + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + + +// Specialization for RandGenFst. +template <class A, class B, class S> +class ArcIterator< RandGenFst<A, B, S> > + : public CacheArcIterator< RandGenFst<A, B, S> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const RandGenFst<A, B, S> &fst, StateId s) + : CacheArcIterator< RandGenFst<A, B, S> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + + +template <class A, class B, class S> inline +void RandGenFst<A, B, S>::InitStateIterator(StateIteratorData<B> *data) const +{ + data->base = new StateIterator< RandGenFst<A, B, S> >(*this); +} + +// Options for random path generation. +template <class S> +struct RandGenOptions { + const S &arc_selector; // How an arc is selected at a state + int max_length; // Maximum path length + size_t npath; // # of paths to generate + bool weighted; // Output is tree weighted by path count; o.w. + // output unweighted union of paths. + bool remove_total_weight; // Remove total weight when output is weighted. + + RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1, + bool w = false, bool rw = false) + : arc_selector(sel), + max_length(len), + npath(n), + weighted(w), + remove_total_weight(rw) {} +}; + + +template <class IArc, class OArc> +class RandGenVisitor { + public: + typedef typename IArc::Weight Weight; + typedef typename IArc::StateId StateId; + + RandGenVisitor(MutableFst<OArc> *ofst) : ofst_(ofst) {} + + void InitVisit(const Fst<IArc> &ifst) { + ifst_ = &ifst; + + ofst_->DeleteStates(); + ofst_->SetInputSymbols(ifst.InputSymbols()); + ofst_->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Properties(kError, false)) + ofst_->SetProperties(kError, kError); + path_.clear(); + } + + bool InitState(StateId s, StateId root) { return true; } + + bool TreeArc(StateId s, const IArc &arc) { + if (ifst_->Final(arc.nextstate) == Weight::Zero()) { + path_.push_back(arc); + } else { + OutputPath(); + } + return true; + } + + bool BackArc(StateId s, const IArc &arc) { + FSTERROR() << "RandGenVisitor: cyclic input"; + ofst_->SetProperties(kError, kError); + return false; + } + + bool ForwardOrCrossArc(StateId s, const IArc &arc) { + OutputPath(); + return true; + } + + void FinishState(StateId s, StateId p, const IArc *) { + if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) + path_.pop_back(); + } + + void FinishVisit() {} + + private: + void OutputPath() { + if (ofst_->Start() == kNoStateId) { + StateId start = ofst_->AddState(); + ofst_->SetStart(start); + } + + StateId src = ofst_->Start(); + for (size_t i = 0; i < path_.size(); ++i) { + StateId dest = ofst_->AddState(); + OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest); + ofst_->AddArc(src, arc); + src = dest; + } + ofst_->SetFinal(src, Weight::One()); + } + + const Fst<IArc> *ifst_; + MutableFst<OArc> *ofst_; + vector<OArc> path_; + + DISALLOW_COPY_AND_ASSIGN(RandGenVisitor); +}; + + +// Randomly generate paths through an FST; details controlled by +// RandGenOptions. +template<class IArc, class OArc, class Selector> +void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst, + const RandGenOptions<Selector> &opts) { + typedef ArcSampler<IArc, Selector> Sampler; + typedef RandGenFst<IArc, OArc, Sampler> RandFst; + typedef typename OArc::StateId StateId; + typedef typename OArc::Weight Weight; + + Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length); + RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), arc_sampler, + opts.npath, opts.weighted, + opts.remove_total_weight); + RandFst rfst(ifst, fopts); + if (opts.weighted) { + *ofst = rfst; + } else { + RandGenVisitor<IArc, OArc> rand_visitor(ofst); + DfsVisit(rfst, &rand_visitor); + } +} + +// Randomly generate a path through an FST with the uniform distribution +// over the transitions. +template<class IArc, class OArc> +void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst) { + UniformArcSelector<IArc> uniform_selector; + RandGenOptions< UniformArcSelector<IArc> > opts(uniform_selector); + RandGen(ifst, ofst, opts); +} + +} // namespace fst + +#endif // FST_LIB_RANDGEN_H__ |