// randgen.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: riley@google.com (Michael Riley) // // \file // Classes and functions to generate random paths through an FST. #ifndef FST_LIB_RANDGEN_H__ #define FST_LIB_RANDGEN_H__ #include #include #include #include #include #include #include #include namespace fst { // // ARC SELECTORS - these function objects are used to select a random // transition to take from an FST's state. They should return a number // N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th // transition is selected. If N == NumArcs(), then the final weight at // that state is selected (i.e., the 'super-final' transition is selected). // It can be assumed these will not be called unless either there // are transitions leaving the state and/or the state is final. // // Randomly selects a transition using the uniform distribution. template struct UniformArcSelector { typedef typename A::StateId StateId; typedef typename A::Weight Weight; UniformArcSelector(int seed = time(0)) { srand(seed); } size_t operator()(const Fst &fst, StateId s) const { double r = rand()/(RAND_MAX + 1.0); size_t n = fst.NumArcs(s); if (fst.Final(s) != Weight::Zero()) ++n; return static_cast(r * n); } }; // Randomly selects a transition w.r.t. the weights treated as negative // log probabilities after normalizing for the total weight leaving // the state. Weight::zero transitions are disregarded. // Assumes Weight::Value() accesses the floating point // representation of the weight. template class LogProbArcSelector { public: typedef typename A::StateId StateId; typedef typename A::Weight Weight; LogProbArcSelector(int seed = time(0)) { srand(seed); } size_t operator()(const Fst &fst, StateId s) const { // Find total weight leaving state double sum = 0.0; for (ArcIterator< Fst > aiter(fst, s); !aiter.Done(); aiter.Next()) { const A &arc = aiter.Value(); sum += exp(-to_log_weight_(arc.weight).Value()); } sum += exp(-to_log_weight_(fst.Final(s)).Value()); double r = rand()/(RAND_MAX + 1.0); double p = 0.0; int n = 0; for (ArcIterator< Fst > aiter(fst, s); !aiter.Done(); aiter.Next(), ++n) { const A &arc = aiter.Value(); p += exp(-to_log_weight_(arc.weight).Value()); if (p > r * sum) return n; } return n; } private: WeightConvert to_log_weight_; }; // Convenience definitions typedef LogProbArcSelector StdArcSelector; typedef LogProbArcSelector LogArcSelector; // Same as LogProbArcSelector but use CacheLogAccumulator to cache // the cummulative weight computations. template class FastLogProbArcSelector : public LogProbArcSelector { public: typedef typename A::StateId StateId; typedef typename A::Weight Weight; using LogProbArcSelector::operator(); FastLogProbArcSelector(int seed = time(0)) : LogProbArcSelector(seed), seed_(seed) {} size_t operator()(const Fst &fst, StateId s, CacheLogAccumulator *accumulator) const { accumulator->SetState(s); ArcIterator< Fst > aiter(fst, s); // Find total weight leaving state double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0, fst.NumArcs(s))).Value(); double r = -log(rand()/(RAND_MAX + 1.0)); return accumulator->LowerBound(r + sum, &aiter); } int Seed() const { return seed_; } private: int seed_; WeightConvert to_log_weight_; }; // Random path state info maintained by RandGenFst and passed to samplers. template struct RandState { typedef typename A::StateId StateId; StateId state_id; // current input FST state size_t nsamples; // # of samples to be sampled at this state size_t length; // length of path to this random state size_t select; // previous sample arc selection const RandState *parent; // previous random state on this path RandState(StateId s, size_t n, size_t l, size_t k, const RandState *p) : state_id(s), nsamples(n), length(l), select(k), parent(p) {} RandState() : state_id(kNoStateId), nsamples(0), length(0), select(0), parent(0) {} }; // This class, given an arc selector, samples, with raplacement, // multiple random transitions from an FST's state. This is a generic // version with a straight-forward use of the arc selector. // Specializations may be defined for arc selectors for greater // efficiency or special behavior. template class ArcSampler { public: typedef typename A::StateId StateId; typedef typename A::Weight Weight; // The 'max_length' may be interpreted (including ignored) by a // sampler as it chooses. This generic version interprets this literally. ArcSampler(const Fst &fst, const S &arc_selector, int max_length = INT_MAX) : fst_(fst), arc_selector_(arc_selector), max_length_(max_length) {} // Allow updating Fst argument; pass only if changed. ArcSampler(const ArcSampler &sampler, const Fst *fst = 0) : fst_(fst ? *fst : sampler.fst_), arc_selector_(sampler.arc_selector_), max_length_(sampler.max_length_) { Reset(); } // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is // the length of the path to 'rstate'. Returns true if samples were // collected. No samples may be collected if either there are no (including // 'super-final') transitions leaving that state or if the // 'max_length' has been deemed reached. Use the iterator members to // read the samples. The samples will be in their original order. bool Sample(const RandState &rstate) { sample_map_.clear(); if ((fst_.NumArcs(rstate.state_id) == 0 && fst_.Final(rstate.state_id) == Weight::Zero()) || rstate.length == max_length_) { Reset(); return false; } for (size_t i = 0; i < rstate.nsamples; ++i) ++sample_map_[arc_selector_(fst_, rstate.state_id)]; Reset(); return true; } // More samples? bool Done() const { return sample_iter_ == sample_map_.end(); } // Gets the next sample. void Next() { ++sample_iter_; } // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples. // If N < NumArcs(s), then the N-th transition is specified. // If N == NumArcs(s), then the final weight at that state is // specified (i.e., the 'super-final' transition is specified). // For the specified transition, K repetitions have been sampled. pair Value() const { return *sample_iter_; } void Reset() { sample_iter_ = sample_map_.begin(); } bool Error() const { return false; } private: const Fst &fst_; const S &arc_selector_; int max_length_; // Stores (N, K) as described for Value(). map sample_map_; map::const_iterator sample_iter_; // disallow ArcSampler & operator=(const ArcSampler &s); }; // Specialization for FastLogProbArcSelector. template class ArcSampler > { public: typedef FastLogProbArcSelector S; typedef typename A::StateId StateId; typedef typename A::Weight Weight; typedef CacheLogAccumulator C; ArcSampler(const Fst &fst, const S &arc_selector, int max_length = INT_MAX) : fst_(fst), arc_selector_(arc_selector), max_length_(max_length), accumulator_(new C()) { accumulator_->Init(fst); } ArcSampler(const ArcSampler &sampler, const Fst *fst = 0) : fst_(fst ? *fst : sampler.fst_), arc_selector_(sampler.arc_selector_), max_length_(sampler.max_length_) { if (fst) { accumulator_ = new C(); accumulator_->Init(*fst); } else { // shallow copy accumulator_ = new C(*sampler.accumulator_); } } ~ArcSampler() { delete accumulator_; } bool Sample(const RandState &rstate) { sample_map_.clear(); if ((fst_.NumArcs(rstate.state_id) == 0 && fst_.Final(rstate.state_id) == Weight::Zero()) || rstate.length == max_length_) { Reset(); return false; } for (size_t i = 0; i < rstate.nsamples; ++i) ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)]; Reset(); return true; } bool Done() const { return sample_iter_ == sample_map_.end(); } void Next() { ++sample_iter_; } pair Value() const { return *sample_iter_; } void Reset() { sample_iter_ = sample_map_.begin(); } bool Error() const { return accumulator_->Error(); } private: const Fst &fst_; const S &arc_selector_; int max_length_; // Stores (N, K) as described for Value(). map sample_map_; map::const_iterator sample_iter_; C *accumulator_; // disallow ArcSampler & operator=(const ArcSampler &s); }; // Options for random path generation with RandGenFst. The template argument // is an arc sampler, typically class 'ArcSampler' above. Ownership of // the sampler is taken by RandGenFst. template struct RandGenFstOptions : public CacheOptions { S *arc_sampler; // How to sample transitions at a state size_t npath; // # of paths to generate bool weighted; // Output tree weighted by path count; o.w. // output unweighted DAG bool remove_total_weight; // Remove total weight when output is weighted. RandGenFstOptions(const CacheOptions &copts, S *samp, size_t n = 1, bool w = true, bool rw = false) : CacheOptions(copts), arc_sampler(samp), npath(n), weighted(w), remove_total_weight(rw) {} }; // Implementation of RandGenFst. template class RandGenFstImpl : public CacheImpl { public: using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using CacheBaseImpl< CacheState >::AddArc; using CacheBaseImpl< CacheState >::HasArcs; using CacheBaseImpl< CacheState >::HasFinal; using CacheBaseImpl< CacheState >::HasStart; using CacheBaseImpl< CacheState >::SetArcs; using CacheBaseImpl< CacheState >::SetFinal; using CacheBaseImpl< CacheState >::SetStart; typedef B Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; RandGenFstImpl(const Fst &fst, const RandGenFstOptions &opts) : CacheImpl(opts), fst_(fst.Copy()), arc_sampler_(opts.arc_sampler), npath_(opts.npath), weighted_(opts.weighted), remove_total_weight_(opts.remove_total_weight), superfinal_(kNoLabel) { SetType("randgen"); uint64 props = fst.Properties(kFstProperties, false); SetProperties(RandGenProperties(props, weighted_), kCopyProperties); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); } RandGenFstImpl(const RandGenFstImpl &impl) : CacheImpl(impl), fst_(impl.fst_->Copy(true)), arc_sampler_(new S(*impl.arc_sampler_, fst_)), npath_(impl.npath_), weighted_(impl.weighted_), superfinal_(kNoLabel) { SetType("randgen"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); } ~RandGenFstImpl() { for (int i = 0; i < state_table_.size(); ++i) delete state_table_[i]; delete fst_; delete arc_sampler_; } StateId Start() { if (!HasStart()) { StateId s = fst_->Start(); if (s == kNoStateId) return kNoStateId; StateId start = state_table_.size(); SetStart(start); RandState *rstate = new RandState(s, npath_, 0, 0, 0); state_table_.push_back(rstate); } return CacheImpl::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) { Expand(s); } return CacheImpl::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) { Expand(s); } return CacheImpl::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumOutputEpsilons(s); } uint64 Properties() const { return Properties(kFstProperties); } // Set error if found; return FST impl properties. uint64 Properties(uint64 mask) const { if ((mask & kError) && (fst_->Properties(kError, false) || arc_sampler_->Error())) { SetProperties(kError, kError); } return FstImpl::Properties(mask); } void InitArcIterator(StateId s, ArcIteratorData *data) { if (!HasArcs(s)) Expand(s); CacheImpl::InitArcIterator(s, data); } // Computes the outgoing transitions from a state, creating new destination // states as needed. void Expand(StateId s) { if (s == superfinal_) { SetFinal(s, Weight::One()); SetArcs(s); return; } SetFinal(s, Weight::Zero()); const RandState &rstate = *state_table_[s]; arc_sampler_->Sample(rstate); ArcIterator< Fst > aiter(*fst_, rstate.state_id); size_t narcs = fst_->NumArcs(rstate.state_id); for (;!arc_sampler_->Done(); arc_sampler_->Next()) { const pair &sample_pair = arc_sampler_->Value(); size_t pos = sample_pair.first; size_t count = sample_pair.second; double prob = static_cast(count)/rstate.nsamples; if (pos < narcs) { // regular transition aiter.Seek(sample_pair.first); const A &aarc = aiter.Value(); Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One(); B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size()); AddArc(s, barc); RandState *nrstate = new RandState(aarc.nextstate, count, rstate.length + 1, pos, &rstate); state_table_.push_back(nrstate); } else { // super-final transition if (weighted_) { Weight weight = remove_total_weight_ ? to_weight_(-log(prob)) : to_weight_(-log(prob * npath_)); SetFinal(s, weight); } else { if (superfinal_ == kNoLabel) { superfinal_ = state_table_.size(); RandState *nrstate = new RandState(kNoStateId, 0, 0, 0, 0); state_table_.push_back(nrstate); } for (size_t n = 0; n < count; ++n) { B barc(0, 0, Weight::One(), superfinal_); AddArc(s, barc); } } } } SetArcs(s); } private: Fst *fst_; S *arc_sampler_; size_t npath_; vector *> state_table_; bool weighted_; bool remove_total_weight_; StateId superfinal_; WeightConvert to_weight_; void operator=(const RandGenFstImpl &); // disallow }; // Fst class to randomly generate paths through an FST; details controlled // by RandGenOptionsFst. Output format is a tree weighted by the // path count. template class RandGenFst : public ImplToFst< RandGenFstImpl > { public: friend class ArcIterator< RandGenFst >; friend class StateIterator< RandGenFst >; typedef B Arc; typedef S Sampler; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState State; typedef RandGenFstImpl Impl; RandGenFst(const Fst &fst, const RandGenFstOptions &opts) : ImplToFst(new Impl(fst, opts)) {} // See Fst<>::Copy() for doc. RandGenFst(const RandGenFst &fst, bool safe = false) : ImplToFst(fst, safe) {} // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc. virtual RandGenFst *Copy(bool safe = false) const { return new RandGenFst(*this, safe); } virtual inline void InitStateIterator(StateIteratorData *data) const; virtual void InitArcIterator(StateId s, ArcIteratorData *data) const { GetImpl()->InitArcIterator(s, data); } private: // Makes visible to friends. Impl *GetImpl() const { return ImplToFst::GetImpl(); } void operator=(const RandGenFst &fst); // Disallow }; // Specialization for RandGenFst. template class StateIterator< RandGenFst > : public CacheStateIterator< RandGenFst > { public: explicit StateIterator(const RandGenFst &fst) : CacheStateIterator< RandGenFst >(fst, fst.GetImpl()) {} private: DISALLOW_COPY_AND_ASSIGN(StateIterator); }; // Specialization for RandGenFst. template class ArcIterator< RandGenFst > : public CacheArcIterator< RandGenFst > { public: typedef typename A::StateId StateId; ArcIterator(const RandGenFst &fst, StateId s) : CacheArcIterator< RandGenFst >(fst.GetImpl(), s) { if (!fst.GetImpl()->HasArcs(s)) fst.GetImpl()->Expand(s); } private: DISALLOW_COPY_AND_ASSIGN(ArcIterator); }; template inline void RandGenFst::InitStateIterator(StateIteratorData *data) const { data->base = new StateIterator< RandGenFst >(*this); } // Options for random path generation. template struct RandGenOptions { const S &arc_selector; // How an arc is selected at a state int max_length; // Maximum path length size_t npath; // # of paths to generate bool weighted; // Output is tree weighted by path count; o.w. // output unweighted union of paths. bool remove_total_weight; // Remove total weight when output is weighted. RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1, bool w = false, bool rw = false) : arc_selector(sel), max_length(len), npath(n), weighted(w), remove_total_weight(rw) {} }; template class RandGenVisitor { public: typedef typename IArc::Weight Weight; typedef typename IArc::StateId StateId; RandGenVisitor(MutableFst *ofst) : ofst_(ofst) {} void InitVisit(const Fst &ifst) { ifst_ = &ifst; ofst_->DeleteStates(); ofst_->SetInputSymbols(ifst.InputSymbols()); ofst_->SetOutputSymbols(ifst.OutputSymbols()); if (ifst.Properties(kError, false)) ofst_->SetProperties(kError, kError); path_.clear(); } bool InitState(StateId s, StateId root) { return true; } bool TreeArc(StateId s, const IArc &arc) { if (ifst_->Final(arc.nextstate) == Weight::Zero()) { path_.push_back(arc); } else { OutputPath(); } return true; } bool BackArc(StateId s, const IArc &arc) { FSTERROR() << "RandGenVisitor: cyclic input"; ofst_->SetProperties(kError, kError); return false; } bool ForwardOrCrossArc(StateId s, const IArc &arc) { OutputPath(); return true; } void FinishState(StateId s, StateId p, const IArc *) { if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back(); } void FinishVisit() {} private: void OutputPath() { if (ofst_->Start() == kNoStateId) { StateId start = ofst_->AddState(); ofst_->SetStart(start); } StateId src = ofst_->Start(); for (size_t i = 0; i < path_.size(); ++i) { StateId dest = ofst_->AddState(); OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest); ofst_->AddArc(src, arc); src = dest; } ofst_->SetFinal(src, Weight::One()); } const Fst *ifst_; MutableFst *ofst_; vector path_; DISALLOW_COPY_AND_ASSIGN(RandGenVisitor); }; // Randomly generate paths through an FST; details controlled by // RandGenOptions. template void RandGen(const Fst &ifst, MutableFst *ofst, const RandGenOptions &opts) { typedef ArcSampler Sampler; typedef RandGenFst RandFst; typedef typename OArc::StateId StateId; typedef typename OArc::Weight Weight; Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length); RandGenFstOptions fopts(CacheOptions(true, 0), arc_sampler, opts.npath, opts.weighted, opts.remove_total_weight); RandFst rfst(ifst, fopts); if (opts.weighted) { *ofst = rfst; } else { RandGenVisitor rand_visitor(ofst); DfsVisit(rfst, &rand_visitor); } } // Randomly generate a path through an FST with the uniform distribution // over the transitions. template void RandGen(const Fst &ifst, MutableFst *ofst) { UniformArcSelector uniform_selector; RandGenOptions< UniformArcSelector > opts(uniform_selector); RandGen(ifst, ofst, opts); } } // namespace fst #endif // FST_LIB_RANDGEN_H__