// accumulator.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: riley@google.com (Michael Riley) // // \file // Classes to accumulate arc weights. Useful for weight lookahead. #ifndef FST_LIB_ACCUMULATOR_H__ #define FST_LIB_ACCUMULATOR_H__ #include #include #include using std::tr1::unordered_map; using std::tr1::unordered_multimap; #include using std::vector; #include #include #include #include #include namespace fst { // This class accumulates arc weights using the semiring Plus(). template class DefaultAccumulator { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Weight Weight; DefaultAccumulator() {} DefaultAccumulator(const DefaultAccumulator &acc) {} void Init(const Fst& fst, bool copy = false) {} void SetState(StateId) {} Weight Sum(Weight w, Weight v) { return Plus(w, v); } template Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, ssize_t end) { Weight sum = w; aiter->Seek(begin); for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) sum = Plus(sum, aiter->Value().weight); return sum; } bool Error() const { return false; } private: void operator=(const DefaultAccumulator &); // Disallow }; // This class accumulates arc weights using the log semiring Plus() // assuming an arc weight has a WeightConvert specialization to // and from log64 weights. template class LogAccumulator { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Weight Weight; LogAccumulator() {} LogAccumulator(const LogAccumulator &acc) {} void Init(const Fst& fst, bool copy = false) {} void SetState(StateId) {} Weight Sum(Weight w, Weight v) { return LogPlus(w, v); } template Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, ssize_t end) { Weight sum = w; aiter->Seek(begin); for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) sum = LogPlus(sum, aiter->Value().weight); return sum; } bool Error() const { return false; } private: double LogPosExp(double x) { return log(1.0F + exp(-x)); } Weight LogPlus(Weight w, Weight v) { double f1 = to_log_weight_(w).Value(); double f2 = to_log_weight_(v).Value(); if (f1 > f2) return to_weight_(f2 - LogPosExp(f1 - f2)); else return to_weight_(f1 - LogPosExp(f2 - f1)); } WeightConvert to_log_weight_; WeightConvert to_weight_; void operator=(const LogAccumulator &); // Disallow }; // Stores shareable data for fast log accumulator copies. class FastLogAccumulatorData { public: FastLogAccumulatorData() {} vector *Weights() { return &weights_; } vector *WeightPositions() { return &weight_positions_; } double *WeightEnd() { return &(weights_[weights_.size() - 1]); }; int RefCount() const { return ref_count_.count(); } int IncrRefCount() { return ref_count_.Incr(); } int DecrRefCount() { return ref_count_.Decr(); } private: // Cummulative weight per state for all states s.t. # of arcs > // arc_limit_ with arcs in order. Special first element per state // being Log64Weight::Zero(); vector weights_; // Maps from state to corresponding beginning weight position in // weights_. Position -1 means no pre-computed weights for that // state. vector weight_positions_; RefCounter ref_count_; // Reference count. DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData); }; // This class accumulates arc weights using the log semiring Plus() // assuming an arc weight has a WeightConvert specialization to and // from log64 weights. The member function Init(fst) has to be called // to setup pre-computed weight information. template class FastLogAccumulator { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Weight Weight; explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10) : arc_limit_(arc_limit), arc_period_(arc_period), data_(new FastLogAccumulatorData()), error_(false) {} FastLogAccumulator(const FastLogAccumulator &acc) : arc_limit_(acc.arc_limit_), arc_period_(acc.arc_period_), data_(acc.data_), error_(acc.error_) { data_->IncrRefCount(); } ~FastLogAccumulator() { if (!data_->DecrRefCount()) delete data_; } void SetState(StateId s) { vector &weights = *data_->Weights(); vector &weight_positions = *data_->WeightPositions(); if (weight_positions.size() <= s) { FSTERROR() << "FastLogAccumulator::SetState: invalid state id."; error_ = true; return; } ssize_t pos = weight_positions[s]; if (pos >= 0) state_weights_ = &(weights[pos]); else state_weights_ = 0; } Weight Sum(Weight w, Weight v) { return LogPlus(w, v); } template Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, ssize_t end) { if (error_) return Weight::NoWeight(); Weight sum = w; // Finds begin and end of pre-stored weights ssize_t index_begin = -1, index_end = -1; ssize_t stored_begin = end, stored_end = end; if (state_weights_ != 0) { index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0; index_end = end / arc_period_; stored_begin = index_begin * arc_period_; stored_end = index_end * arc_period_; } // Computes sum before pre-stored weights if (begin < stored_begin) { ssize_t pos_end = min(stored_begin, end); aiter->Seek(begin); for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos) sum = LogPlus(sum, aiter->Value().weight); } // Computes sum between pre-stored weights if (stored_begin < stored_end) { sum = LogPlus(sum, LogMinus(state_weights_[index_end], state_weights_[index_begin])); } // Computes sum after pre-stored weights if (stored_end < end) { ssize_t pos_start = max(stored_begin, stored_end); aiter->Seek(pos_start); for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos) sum = LogPlus(sum, aiter->Value().weight); } return sum; } template void Init(const F &fst, bool copy = false) { if (copy) return; vector &weights = *data_->Weights(); vector &weight_positions = *data_->WeightPositions(); if (!weights.empty() || arc_limit_ < arc_period_) { FSTERROR() << "FastLogAccumulator: initialization error."; error_ = true; return; } weight_positions.reserve(CountStates(fst)); ssize_t weight_position = 0; for(StateIterator siter(fst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); if (fst.NumArcs(s) >= arc_limit_) { double sum = FloatLimits::PosInfinity(); weight_positions.push_back(weight_position); weights.push_back(sum); ++weight_position; ssize_t narcs = 0; for(ArcIterator aiter(fst, s); !aiter.Done(); aiter.Next()) { const A &arc = aiter.Value(); sum = LogPlus(sum, arc.weight); // Stores cumulative weight distribution per arc_period_. if (++narcs % arc_period_ == 0) { weights.push_back(sum); ++weight_position; } } } else { weight_positions.push_back(-1); } } } bool Error() const { return error_; } private: double LogPosExp(double x) { return x == FloatLimits::PosInfinity() ? 0.0 : log(1.0F + exp(-x)); } double LogMinusExp(double x) { return x == FloatLimits::PosInfinity() ? 0.0 : log(1.0F - exp(-x)); } Weight LogPlus(Weight w, Weight v) { double f1 = to_log_weight_(w).Value(); double f2 = to_log_weight_(v).Value(); if (f1 > f2) return to_weight_(f2 - LogPosExp(f1 - f2)); else return to_weight_(f1 - LogPosExp(f2 - f1)); } double LogPlus(double f1, Weight v) { double f2 = to_log_weight_(v).Value(); if (f1 == FloatLimits::PosInfinity()) return f2; else if (f1 > f2) return f2 - LogPosExp(f1 - f2); else return f1 - LogPosExp(f2 - f1); } Weight LogMinus(double f1, double f2) { if (f1 >= f2) { FSTERROR() << "FastLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1 << " and f2 = " << f2; error_ = true; return Weight::NoWeight(); } if (f2 == FloatLimits::PosInfinity()) return to_weight_(f1); else return to_weight_(f1 - LogMinusExp(f2 - f1)); } WeightConvert to_log_weight_; WeightConvert to_weight_; ssize_t arc_limit_; // Minimum # of arcs to pre-compute state ssize_t arc_period_; // Save cumulative weights per 'arc_period_'. bool init_; // Cumulative weights initialized? FastLogAccumulatorData *data_; double *state_weights_; bool error_; void operator=(const FastLogAccumulator &); // Disallow }; // Stores shareable data for cache log accumulator copies. // All copies share the same cache. template class CacheLogAccumulatorData { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Weight Weight; CacheLogAccumulatorData(bool gc, size_t gc_limit) : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {} ~CacheLogAccumulatorData() { for(typename unordered_map::iterator it = cache_.begin(); it != cache_.end(); ++it) delete it->second.weights; } bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; } vector *GetWeights(StateId s) { typename unordered_map::iterator it = cache_.find(s); if (it != cache_.end()) { it->second.recent = true; return it->second.weights; } else { return 0; } } void AddWeights(StateId s, vector *weights) { if (cache_gc_ && cache_size_ >= cache_limit_) GC(false); cache_.insert(make_pair(s, CacheState(weights, true))); if (cache_gc_) cache_size_ += weights->capacity() * sizeof(double); } int RefCount() const { return ref_count_.count(); } int IncrRefCount() { return ref_count_.Incr(); } int DecrRefCount() { return ref_count_.Decr(); } private: // Cached information for a given state. struct CacheState { vector* weights; // Accumulated weights for this state. bool recent; // Has this state been accessed since last GC? CacheState(vector *w, bool r) : weights(w), recent(r) {} }; // Garbage collect: Delete from cache states that have not been // accessed since the last GC ('free_recent = false') until // 'cache_size_' is 2/3 of 'cache_limit_'. If it does not free enough // memory, start deleting recently accessed states. void GC(bool free_recent) { size_t cache_target = (2 * cache_limit_)/3 + 1; typename unordered_map::iterator it = cache_.begin(); while (it != cache_.end() && cache_size_ > cache_target) { CacheState &cs = it->second; if (free_recent || !cs.recent) { cache_size_ -= cs.weights->capacity() * sizeof(double); delete cs.weights; cache_.erase(it++); } else { cs.recent = false; ++it; } } if (!free_recent && cache_size_ > cache_target) GC(true); } unordered_map cache_; // Cache bool cache_gc_; // Enable garbage collection size_t cache_limit_; // # of bytes cached size_t cache_size_; // # of bytes allowed before GC RefCounter ref_count_; DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData); }; // This class accumulates arc weights using the log semiring Plus() // has a WeightConvert specialization to and from log64 weights. It // is similar to the FastLogAccumator. However here, the accumulated // weights are pre-computed and stored only for the states that are // visited. The member function Init(fst) has to be called to setup // this accumulator. template class CacheLogAccumulator { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Weight Weight; explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false, size_t gc_limit = 10 * 1024 * 1024) : arc_limit_(arc_limit), fst_(0), data_( new CacheLogAccumulatorData(gc, gc_limit)), s_(kNoStateId), error_(false) {} CacheLogAccumulator(const CacheLogAccumulator &acc) : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0), data_(acc.data_), s_(kNoStateId), error_(acc.error_) { data_->IncrRefCount(); } ~CacheLogAccumulator() { if (fst_) delete fst_; if (!data_->DecrRefCount()) delete data_; } // Arg 'arc_limit' specifies minimum # of arcs to pre-compute state. void Init(const Fst &fst, bool copy = false) { if (copy) { delete fst_; } else if (fst_) { FSTERROR() << "CacheLogAccumulator: initialization error."; error_ = true; return; } fst_ = fst.Copy(); } void SetState(StateId s, int depth = 0) { if (s == s_) return; s_ = s; if (data_->CacheDisabled() || error_) { weights_ = 0; return; } if (!fst_) { FSTERROR() << "CacheLogAccumulator::SetState: incorrectly initialized."; error_ = true; weights_ = 0; return; } weights_ = data_->GetWeights(s); if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) { weights_ = new vector; weights_->reserve(fst_->NumArcs(s) + 1); weights_->push_back(FloatLimits::PosInfinity()); data_->AddWeights(s, weights_); } } Weight Sum(Weight w, Weight v) { return LogPlus(w, v); } template Weight Sum(Weight w, Iterator *aiter, ssize_t begin, ssize_t end) { if (weights_ == 0) { Weight sum = w; aiter->Seek(begin); for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) sum = LogPlus(sum, aiter->Value().weight); return sum; } else { if (weights_->size() <= end) for (aiter->Seek(weights_->size() - 1); weights_->size() <= end; aiter->Next()) weights_->push_back(LogPlus(weights_->back(), aiter->Value().weight)); return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin])); } } template size_t LowerBound(double w, Iterator *aiter) { if (weights_ != 0) { return lower_bound(weights_->begin() + 1, weights_->end(), w, std::greater()) - weights_->begin() - 1; } else { size_t n = 0; double x = FloatLimits::PosInfinity(); for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) { x = LogPlus(x, aiter->Value().weight); if (x < w) break; } return n; } } bool Error() const { return error_; } private: double LogPosExp(double x) { return x == FloatLimits::PosInfinity() ? 0.0 : log(1.0F + exp(-x)); } double LogMinusExp(double x) { return x == FloatLimits::PosInfinity() ? 0.0 : log(1.0F - exp(-x)); } Weight LogPlus(Weight w, Weight v) { double f1 = to_log_weight_(w).Value(); double f2 = to_log_weight_(v).Value(); if (f1 > f2) return to_weight_(f2 - LogPosExp(f1 - f2)); else return to_weight_(f1 - LogPosExp(f2 - f1)); } double LogPlus(double f1, Weight v) { double f2 = to_log_weight_(v).Value(); if (f1 == FloatLimits::PosInfinity()) return f2; else if (f1 > f2) return f2 - LogPosExp(f1 - f2); else return f1 - LogPosExp(f2 - f1); } Weight LogMinus(double f1, double f2) { if (f1 >= f2) { FSTERROR() << "CacheLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1 << " and f2 = " << f2; error_ = true; return Weight::NoWeight(); } if (f2 == FloatLimits::PosInfinity()) return to_weight_(f1); else return to_weight_(f1 - LogMinusExp(f2 - f1)); } WeightConvert to_log_weight_; WeightConvert to_weight_; ssize_t arc_limit_; // Minimum # of arcs to cache a state vector *weights_; // Accumulated weights for cur. state const Fst* fst_; // Input fst CacheLogAccumulatorData *data_; // Cache data StateId s_; // Current state bool error_; void operator=(const CacheLogAccumulator &); // Disallow }; // Stores shareable data for replace accumulator copies. template class ReplaceAccumulatorData { public: typedef typename Accumulator::Arc Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef T StateTable; typedef typename T::StateTuple StateTuple; ReplaceAccumulatorData() : state_table_(0) {} ReplaceAccumulatorData(const vector &accumulators) : state_table_(0), accumulators_(accumulators) {} ~ReplaceAccumulatorData() { for (size_t i = 0; i < fst_array_.size(); ++i) delete fst_array_[i]; for (size_t i = 0; i < accumulators_.size(); ++i) delete accumulators_[i]; } void Init(const vector*> > &fst_tuples, const StateTable *state_table) { state_table_ = state_table; accumulators_.resize(fst_tuples.size()); for (size_t i = 0; i < accumulators_.size(); ++i) { if (!accumulators_[i]) accumulators_[i] = new Accumulator; accumulators_[i]->Init(*(fst_tuples[i].second)); fst_array_.push_back(fst_tuples[i].second->Copy()); } } const StateTuple &GetTuple(StateId s) const { return state_table_->Tuple(s); } Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; } const Fst *GetFst(size_t i) const { return fst_array_[i]; } int RefCount() const { return ref_count_.count(); } int IncrRefCount() { return ref_count_.Incr(); } int DecrRefCount() { return ref_count_.Decr(); } private: const T * state_table_; vector accumulators_; vector*> fst_array_; RefCounter ref_count_; DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData); }; // This class accumulates weights in a ReplaceFst. The 'Init' method // takes as input the argument used to build the ReplaceFst and the // ReplaceFst state table. It uses accumulators of type 'Accumulator' // in the underlying FSTs. template > class ReplaceAccumulator { public: typedef typename Accumulator::Arc Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; typedef T StateTable; typedef typename T::StateTuple StateTuple; ReplaceAccumulator() : init_(false), data_(new ReplaceAccumulatorData()), error_(false) {} ReplaceAccumulator(const vector &accumulators) : init_(false), data_(new ReplaceAccumulatorData(accumulators)), error_(false) {} ReplaceAccumulator(const ReplaceAccumulator &acc) : init_(acc.init_), data_(acc.data_), error_(acc.error_) { if (!init_) FSTERROR() << "ReplaceAccumulator: can't copy unintialized accumulator"; data_->IncrRefCount(); } ~ReplaceAccumulator() { if (!data_->DecrRefCount()) delete data_; } // Does not take ownership of the state table, the state table // is own by the ReplaceFst void Init(const vector*> > &fst_tuples, const StateTable *state_table) { init_ = true; data_->Init(fst_tuples, state_table); } void SetState(StateId s) { if (!init_) { FSTERROR() << "ReplaceAccumulator::SetState: incorrectly initialized."; error_ = true; return; } StateTuple tuple = data_->GetTuple(s); fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state); if ((tuple.prefix_id != 0) && (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) { offset_ = 1; offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state); } else { offset_ = 0; offset_weight_ = Weight::Zero(); } } Weight Sum(Weight w, Weight v) { if (error_) return Weight::NoWeight(); return data_->GetAccumulator(fst_id_)->Sum(w, v); } template Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, ssize_t end) { if (error_) return Weight::NoWeight(); Weight sum = begin == end ? Weight::Zero() : data_->GetAccumulator(fst_id_)->Sum( w, aiter, begin ? begin - offset_ : 0, end - offset_); if (begin == 0 && end != 0 && offset_ > 0) sum = Sum(offset_weight_, sum); return sum; } bool Error() const { return error_; } private: bool init_; ReplaceAccumulatorData *data_; Label fst_id_; size_t offset_; Weight offset_weight_; bool error_; void operator=(const ReplaceAccumulator &); // Disallow }; } // namespace fst #endif // FST_LIB_ACCUMULATOR_H__