diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/randgen.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/randgen.h | 712 |
1 files changed, 0 insertions, 712 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/randgen.h b/kaldi_io/src/tools/openfst/include/fst/randgen.h deleted file mode 100644 index 82ddffa..0000000 --- a/kaldi_io/src/tools/openfst/include/fst/randgen.h +++ /dev/null @@ -1,712 +0,0 @@ -// randgen.h - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Copyright 2005-2010 Google, Inc. -// Author: [email protected] (Michael Riley) -// -// \file -// Classes and functions to generate random paths through an FST. - -#ifndef FST_LIB_RANDGEN_H__ -#define FST_LIB_RANDGEN_H__ - -#include <cmath> -#include <cstdlib> -#include <ctime> -#include <map> - -#include <fst/accumulator.h> -#include <fst/cache.h> -#include <fst/dfs-visit.h> -#include <fst/mutable-fst.h> - -namespace fst { - -// -// ARC SELECTORS - these function objects are used to select a random -// transition to take from an FST's state. They should return a number -// N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th -// transition is selected. If N == NumArcs(), then the final weight at -// that state is selected (i.e., the 'super-final' transition is selected). -// It can be assumed these will not be called unless either there -// are transitions leaving the state and/or the state is final. -// - -// Randomly selects a transition using the uniform distribution. -template <class A> -struct UniformArcSelector { - typedef typename A::StateId StateId; - typedef typename A::Weight Weight; - - UniformArcSelector(int seed = time(0)) { srand(seed); } - - size_t operator()(const Fst<A> &fst, StateId s) const { - double r = rand()/(RAND_MAX + 1.0); - size_t n = fst.NumArcs(s); - if (fst.Final(s) != Weight::Zero()) - ++n; - return static_cast<size_t>(r * n); - } -}; - - -// Randomly selects a transition w.r.t. the weights treated as negative -// log probabilities after normalizing for the total weight leaving -// the state. Weight::zero transitions are disregarded. -// Assumes Weight::Value() accesses the floating point -// representation of the weight. -template <class A> -class LogProbArcSelector { - public: - typedef typename A::StateId StateId; - typedef typename A::Weight Weight; - - LogProbArcSelector(int seed = time(0)) { srand(seed); } - - size_t operator()(const Fst<A> &fst, StateId s) const { - // Find total weight leaving state - double sum = 0.0; - for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); - aiter.Next()) { - const A &arc = aiter.Value(); - sum += exp(-to_log_weight_(arc.weight).Value()); - } - sum += exp(-to_log_weight_(fst.Final(s)).Value()); - - double r = rand()/(RAND_MAX + 1.0); - double p = 0.0; - int n = 0; - for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); - aiter.Next(), ++n) { - const A &arc = aiter.Value(); - p += exp(-to_log_weight_(arc.weight).Value()); - if (p > r * sum) return n; - } - return n; - } - - private: - WeightConvert<Weight, Log64Weight> to_log_weight_; -}; - -// Convenience definitions -typedef LogProbArcSelector<StdArc> StdArcSelector; -typedef LogProbArcSelector<LogArc> LogArcSelector; - - -// Same as LogProbArcSelector but use CacheLogAccumulator to cache -// the cummulative weight computations. -template <class A> -class FastLogProbArcSelector : public LogProbArcSelector<A> { - public: - typedef typename A::StateId StateId; - typedef typename A::Weight Weight; - using LogProbArcSelector<A>::operator(); - - FastLogProbArcSelector(int seed = time(0)) - : LogProbArcSelector<A>(seed), - seed_(seed) {} - - size_t operator()(const Fst<A> &fst, StateId s, - CacheLogAccumulator<A> *accumulator) const { - accumulator->SetState(s); - ArcIterator< Fst<A> > aiter(fst, s); - // Find total weight leaving state - double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0, - fst.NumArcs(s))).Value(); - double r = -log(rand()/(RAND_MAX + 1.0)); - return accumulator->LowerBound(r + sum, &aiter); - } - - int Seed() const { return seed_; } - private: - int seed_; - WeightConvert<Weight, Log64Weight> to_log_weight_; -}; - -// Random path state info maintained by RandGenFst and passed to samplers. -template <typename A> -struct RandState { - typedef typename A::StateId StateId; - - StateId state_id; // current input FST state - size_t nsamples; // # of samples to be sampled at this state - size_t length; // length of path to this random state - size_t select; // previous sample arc selection - const RandState<A> *parent; // previous random state on this path - - RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p) - : state_id(s), nsamples(n), length(l), select(k), parent(p) {} - - RandState() - : state_id(kNoStateId), nsamples(0), length(0), select(0), parent(0) {} -}; - -// This class, given an arc selector, samples, with raplacement, -// multiple random transitions from an FST's state. This is a generic -// version with a straight-forward use of the arc selector. -// Specializations may be defined for arc selectors for greater -// efficiency or special behavior. -template <class A, class S> -class ArcSampler { - public: - typedef typename A::StateId StateId; - typedef typename A::Weight Weight; - - // The 'max_length' may be interpreted (including ignored) by a - // sampler as it chooses. This generic version interprets this literally. - ArcSampler(const Fst<A> &fst, const S &arc_selector, - int max_length = INT_MAX) - : fst_(fst), - arc_selector_(arc_selector), - max_length_(max_length) {} - - // Allow updating Fst argument; pass only if changed. - ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0) - : fst_(fst ? *fst : sampler.fst_), - arc_selector_(sampler.arc_selector_), - max_length_(sampler.max_length_) { - Reset(); - } - - // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is - // the length of the path to 'rstate'. Returns true if samples were - // collected. No samples may be collected if either there are no (including - // 'super-final') transitions leaving that state or if the - // 'max_length' has been deemed reached. Use the iterator members to - // read the samples. The samples will be in their original order. - bool Sample(const RandState<A> &rstate) { - sample_map_.clear(); - if ((fst_.NumArcs(rstate.state_id) == 0 && - fst_.Final(rstate.state_id) == Weight::Zero()) || - rstate.length == max_length_) { - Reset(); - return false; - } - - for (size_t i = 0; i < rstate.nsamples; ++i) - ++sample_map_[arc_selector_(fst_, rstate.state_id)]; - Reset(); - return true; - } - - // More samples? - bool Done() const { return sample_iter_ == sample_map_.end(); } - - // Gets the next sample. - void Next() { ++sample_iter_; } - - // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples. - // If N < NumArcs(s), then the N-th transition is specified. - // If N == NumArcs(s), then the final weight at that state is - // specified (i.e., the 'super-final' transition is specified). - // For the specified transition, K repetitions have been sampled. - pair<size_t, size_t> Value() const { return *sample_iter_; } - - void Reset() { sample_iter_ = sample_map_.begin(); } - - bool Error() const { return false; } - - private: - const Fst<A> &fst_; - const S &arc_selector_; - int max_length_; - - // Stores (N, K) as described for Value(). - map<size_t, size_t> sample_map_; - map<size_t, size_t>::const_iterator sample_iter_; - - // disallow - ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s); -}; - - -// Specialization for FastLogProbArcSelector. -template <class A> -class ArcSampler<A, FastLogProbArcSelector<A> > { - public: - typedef FastLogProbArcSelector<A> S; - typedef typename A::StateId StateId; - typedef typename A::Weight Weight; - typedef CacheLogAccumulator<A> C; - - ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX) - : fst_(fst), - arc_selector_(arc_selector), - max_length_(max_length), - accumulator_(new C()) { - accumulator_->Init(fst); - } - - ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0) - : fst_(fst ? *fst : sampler.fst_), - arc_selector_(sampler.arc_selector_), - max_length_(sampler.max_length_) { - if (fst) { - accumulator_ = new C(); - accumulator_->Init(*fst); - } else { // shallow copy - accumulator_ = new C(*sampler.accumulator_); - } - } - - ~ArcSampler() { - delete accumulator_; - } - - bool Sample(const RandState<A> &rstate) { - sample_map_.clear(); - if ((fst_.NumArcs(rstate.state_id) == 0 && - fst_.Final(rstate.state_id) == Weight::Zero()) || - rstate.length == max_length_) { - Reset(); - return false; - } - - for (size_t i = 0; i < rstate.nsamples; ++i) - ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)]; - Reset(); - return true; - } - - bool Done() const { return sample_iter_ == sample_map_.end(); } - void Next() { ++sample_iter_; } - pair<size_t, size_t> Value() const { return *sample_iter_; } - void Reset() { sample_iter_ = sample_map_.begin(); } - - bool Error() const { return accumulator_->Error(); } - - private: - const Fst<A> &fst_; - const S &arc_selector_; - int max_length_; - - // Stores (N, K) as described for Value(). - map<size_t, size_t> sample_map_; - map<size_t, size_t>::const_iterator sample_iter_; - C *accumulator_; - - // disallow - ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s); -}; - - -// Options for random path generation with RandGenFst. The template argument -// is an arc sampler, typically class 'ArcSampler' above. Ownership of -// the sampler is taken by RandGenFst. -template <class S> -struct RandGenFstOptions : public CacheOptions { - S *arc_sampler; // How to sample transitions at a state - size_t npath; // # of paths to generate - bool weighted; // Output tree weighted by path count; o.w. - // output unweighted DAG - bool remove_total_weight; // Remove total weight when output is weighted. - - RandGenFstOptions(const CacheOptions &copts, S *samp, - size_t n = 1, bool w = true, bool rw = false) - : CacheOptions(copts), - arc_sampler(samp), - npath(n), - weighted(w), - remove_total_weight(rw) {} -}; - - -// Implementation of RandGenFst. -template <class A, class B, class S> -class RandGenFstImpl : public CacheImpl<B> { - public: - using FstImpl<B>::SetType; - using FstImpl<B>::SetProperties; - using FstImpl<B>::SetInputSymbols; - using FstImpl<B>::SetOutputSymbols; - - using CacheBaseImpl< CacheState<B> >::AddArc; - using CacheBaseImpl< CacheState<B> >::HasArcs; - using CacheBaseImpl< CacheState<B> >::HasFinal; - using CacheBaseImpl< CacheState<B> >::HasStart; - using CacheBaseImpl< CacheState<B> >::SetArcs; - using CacheBaseImpl< CacheState<B> >::SetFinal; - using CacheBaseImpl< CacheState<B> >::SetStart; - - typedef B Arc; - typedef typename A::Label Label; - typedef typename A::Weight Weight; - typedef typename A::StateId StateId; - - RandGenFstImpl(const Fst<A> &fst, const RandGenFstOptions<S> &opts) - : CacheImpl<B>(opts), - fst_(fst.Copy()), - arc_sampler_(opts.arc_sampler), - npath_(opts.npath), - weighted_(opts.weighted), - remove_total_weight_(opts.remove_total_weight), - superfinal_(kNoLabel) { - SetType("randgen"); - - uint64 props = fst.Properties(kFstProperties, false); - SetProperties(RandGenProperties(props, weighted_), kCopyProperties); - - SetInputSymbols(fst.InputSymbols()); - SetOutputSymbols(fst.OutputSymbols()); - } - - RandGenFstImpl(const RandGenFstImpl &impl) - : CacheImpl<B>(impl), - fst_(impl.fst_->Copy(true)), - arc_sampler_(new S(*impl.arc_sampler_, fst_)), - npath_(impl.npath_), - weighted_(impl.weighted_), - superfinal_(kNoLabel) { - SetType("randgen"); - SetProperties(impl.Properties(), kCopyProperties); - SetInputSymbols(impl.InputSymbols()); - SetOutputSymbols(impl.OutputSymbols()); - } - - ~RandGenFstImpl() { - for (int i = 0; i < state_table_.size(); ++i) - delete state_table_[i]; - delete fst_; - delete arc_sampler_; - } - - StateId Start() { - if (!HasStart()) { - StateId s = fst_->Start(); - if (s == kNoStateId) - return kNoStateId; - StateId start = state_table_.size(); - SetStart(start); - RandState<A> *rstate = new RandState<A>(s, npath_, 0, 0, 0); - state_table_.push_back(rstate); - } - return CacheImpl<B>::Start(); - } - - Weight Final(StateId s) { - if (!HasFinal(s)) { - Expand(s); - } - return CacheImpl<B>::Final(s); - } - - size_t NumArcs(StateId s) { - if (!HasArcs(s)) { - Expand(s); - } - return CacheImpl<B>::NumArcs(s); - } - - size_t NumInputEpsilons(StateId s) { - if (!HasArcs(s)) - Expand(s); - return CacheImpl<B>::NumInputEpsilons(s); - } - - size_t NumOutputEpsilons(StateId s) { - if (!HasArcs(s)) - Expand(s); - return CacheImpl<B>::NumOutputEpsilons(s); - } - - uint64 Properties() const { return Properties(kFstProperties); } - - // Set error if found; return FST impl properties. - uint64 Properties(uint64 mask) const { - if ((mask & kError) && - (fst_->Properties(kError, false) || arc_sampler_->Error())) { - SetProperties(kError, kError); - } - return FstImpl<Arc>::Properties(mask); - } - - void InitArcIterator(StateId s, ArcIteratorData<B> *data) { - if (!HasArcs(s)) - Expand(s); - CacheImpl<B>::InitArcIterator(s, data); - } - - // Computes the outgoing transitions from a state, creating new destination - // states as needed. - void Expand(StateId s) { - if (s == superfinal_) { - SetFinal(s, Weight::One()); - SetArcs(s); - return; - } - - SetFinal(s, Weight::Zero()); - const RandState<A> &rstate = *state_table_[s]; - arc_sampler_->Sample(rstate); - ArcIterator< Fst<A> > aiter(*fst_, rstate.state_id); - size_t narcs = fst_->NumArcs(rstate.state_id); - for (;!arc_sampler_->Done(); arc_sampler_->Next()) { - const pair<size_t, size_t> &sample_pair = arc_sampler_->Value(); - size_t pos = sample_pair.first; - size_t count = sample_pair.second; - double prob = static_cast<double>(count)/rstate.nsamples; - if (pos < narcs) { // regular transition - aiter.Seek(sample_pair.first); - const A &aarc = aiter.Value(); - Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One(); - B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size()); - AddArc(s, barc); - RandState<A> *nrstate = - new RandState<A>(aarc.nextstate, count, rstate.length + 1, - pos, &rstate); - state_table_.push_back(nrstate); - } else { // super-final transition - if (weighted_) { - Weight weight = remove_total_weight_ ? - to_weight_(-log(prob)) : to_weight_(-log(prob * npath_)); - SetFinal(s, weight); - } else { - if (superfinal_ == kNoLabel) { - superfinal_ = state_table_.size(); - RandState<A> *nrstate = new RandState<A>(kNoStateId, 0, 0, 0, 0); - state_table_.push_back(nrstate); - } - for (size_t n = 0; n < count; ++n) { - B barc(0, 0, Weight::One(), superfinal_); - AddArc(s, barc); - } - } - } - } - SetArcs(s); - } - - private: - Fst<A> *fst_; - S *arc_sampler_; - size_t npath_; - vector<RandState<A> *> state_table_; - bool weighted_; - bool remove_total_weight_; - StateId superfinal_; - WeightConvert<Log64Weight, Weight> to_weight_; - - void operator=(const RandGenFstImpl<A, B, S> &); // disallow -}; - - -// Fst class to randomly generate paths through an FST; details controlled -// by RandGenOptionsFst. Output format is a tree weighted by the -// path count. -template <class A, class B, class S> -class RandGenFst : public ImplToFst< RandGenFstImpl<A, B, S> > { - public: - friend class ArcIterator< RandGenFst<A, B, S> >; - friend class StateIterator< RandGenFst<A, B, S> >; - typedef B Arc; - typedef S Sampler; - typedef typename A::Label Label; - typedef typename A::Weight Weight; - typedef typename A::StateId StateId; - typedef CacheState<B> State; - typedef RandGenFstImpl<A, B, S> Impl; - - RandGenFst(const Fst<A> &fst, const RandGenFstOptions<S> &opts) - : ImplToFst<Impl>(new Impl(fst, opts)) {} - - // See Fst<>::Copy() for doc. - RandGenFst(const RandGenFst<A, B, S> &fst, bool safe = false) - : ImplToFst<Impl>(fst, safe) {} - - // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc. - virtual RandGenFst<A, B, S> *Copy(bool safe = false) const { - return new RandGenFst<A, B, S>(*this, safe); - } - - virtual inline void InitStateIterator(StateIteratorData<B> *data) const; - - virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const { - GetImpl()->InitArcIterator(s, data); - } - - private: - // Makes visible to friends. - Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } - - void operator=(const RandGenFst<A, B, S> &fst); // Disallow -}; - - - -// Specialization for RandGenFst. -template <class A, class B, class S> -class StateIterator< RandGenFst<A, B, S> > - : public CacheStateIterator< RandGenFst<A, B, S> > { - public: - explicit StateIterator(const RandGenFst<A, B, S> &fst) - : CacheStateIterator< RandGenFst<A, B, S> >(fst, fst.GetImpl()) {} - - private: - DISALLOW_COPY_AND_ASSIGN(StateIterator); -}; - - -// Specialization for RandGenFst. -template <class A, class B, class S> -class ArcIterator< RandGenFst<A, B, S> > - : public CacheArcIterator< RandGenFst<A, B, S> > { - public: - typedef typename A::StateId StateId; - - ArcIterator(const RandGenFst<A, B, S> &fst, StateId s) - : CacheArcIterator< RandGenFst<A, B, S> >(fst.GetImpl(), s) { - if (!fst.GetImpl()->HasArcs(s)) - fst.GetImpl()->Expand(s); - } - - private: - DISALLOW_COPY_AND_ASSIGN(ArcIterator); -}; - - -template <class A, class B, class S> inline -void RandGenFst<A, B, S>::InitStateIterator(StateIteratorData<B> *data) const -{ - data->base = new StateIterator< RandGenFst<A, B, S> >(*this); -} - -// Options for random path generation. -template <class S> -struct RandGenOptions { - const S &arc_selector; // How an arc is selected at a state - int max_length; // Maximum path length - size_t npath; // # of paths to generate - bool weighted; // Output is tree weighted by path count; o.w. - // output unweighted union of paths. - bool remove_total_weight; // Remove total weight when output is weighted. - - RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1, - bool w = false, bool rw = false) - : arc_selector(sel), - max_length(len), - npath(n), - weighted(w), - remove_total_weight(rw) {} -}; - - -template <class IArc, class OArc> -class RandGenVisitor { - public: - typedef typename IArc::Weight Weight; - typedef typename IArc::StateId StateId; - - RandGenVisitor(MutableFst<OArc> *ofst) : ofst_(ofst) {} - - void InitVisit(const Fst<IArc> &ifst) { - ifst_ = &ifst; - - ofst_->DeleteStates(); - ofst_->SetInputSymbols(ifst.InputSymbols()); - ofst_->SetOutputSymbols(ifst.OutputSymbols()); - if (ifst.Properties(kError, false)) - ofst_->SetProperties(kError, kError); - path_.clear(); - } - - bool InitState(StateId s, StateId root) { return true; } - - bool TreeArc(StateId s, const IArc &arc) { - if (ifst_->Final(arc.nextstate) == Weight::Zero()) { - path_.push_back(arc); - } else { - OutputPath(); - } - return true; - } - - bool BackArc(StateId s, const IArc &arc) { - FSTERROR() << "RandGenVisitor: cyclic input"; - ofst_->SetProperties(kError, kError); - return false; - } - - bool ForwardOrCrossArc(StateId s, const IArc &arc) { - OutputPath(); - return true; - } - - void FinishState(StateId s, StateId p, const IArc *) { - if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) - path_.pop_back(); - } - - void FinishVisit() {} - - private: - void OutputPath() { - if (ofst_->Start() == kNoStateId) { - StateId start = ofst_->AddState(); - ofst_->SetStart(start); - } - - StateId src = ofst_->Start(); - for (size_t i = 0; i < path_.size(); ++i) { - StateId dest = ofst_->AddState(); - OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest); - ofst_->AddArc(src, arc); - src = dest; - } - ofst_->SetFinal(src, Weight::One()); - } - - const Fst<IArc> *ifst_; - MutableFst<OArc> *ofst_; - vector<OArc> path_; - - DISALLOW_COPY_AND_ASSIGN(RandGenVisitor); -}; - - -// Randomly generate paths through an FST; details controlled by -// RandGenOptions. -template<class IArc, class OArc, class Selector> -void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst, - const RandGenOptions<Selector> &opts) { - typedef ArcSampler<IArc, Selector> Sampler; - typedef RandGenFst<IArc, OArc, Sampler> RandFst; - typedef typename OArc::StateId StateId; - typedef typename OArc::Weight Weight; - - Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length); - RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), arc_sampler, - opts.npath, opts.weighted, - opts.remove_total_weight); - RandFst rfst(ifst, fopts); - if (opts.weighted) { - *ofst = rfst; - } else { - RandGenVisitor<IArc, OArc> rand_visitor(ofst); - DfsVisit(rfst, &rand_visitor); - } -} - -// Randomly generate a path through an FST with the uniform distribution -// over the transitions. -template<class IArc, class OArc> -void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst) { - UniformArcSelector<IArc> uniform_selector; - RandGenOptions< UniformArcSelector<IArc> > opts(uniform_selector); - RandGen(ifst, ofst, opts); -} - -} // namespace fst - -#endif // FST_LIB_RANDGEN_H__ |