diff options
author | Ted Yin <[email protected]> | 2015-08-14 17:42:26 +0800 |
---|---|---|
committer | Ted Yin <[email protected]> | 2015-08-14 17:42:26 +0800 |
commit | c3cffb58b9921d78753336421b52b9ffdaa5515c (patch) | |
tree | bfea20e97c200cf734021e3756d749c892e658a4 /kaldi_io/src/tools/openfst/include/fst/accumulator.h | |
parent | 10cce5f6a5c9e2f8e00d5a2a4d87c9cb7c26bf4c (diff) | |
parent | dfdd17afc2e984ec6c32ea01290f5c76309a456a (diff) |
Merge pull request #2 from yimmon/master
remove needless files
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/accumulator.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/accumulator.h | 745 |
1 files changed, 0 insertions, 745 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/accumulator.h b/kaldi_io/src/tools/openfst/include/fst/accumulator.h deleted file mode 100644 index 81d1847..0000000 --- a/kaldi_io/src/tools/openfst/include/fst/accumulator.h +++ /dev/null @@ -1,745 +0,0 @@ -// accumulator.h - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Copyright 2005-2010 Google, Inc. -// Author: [email protected] (Michael Riley) -// -// \file -// Classes to accumulate arc weights. Useful for weight lookahead. - -#ifndef FST_LIB_ACCUMULATOR_H__ -#define FST_LIB_ACCUMULATOR_H__ - -#include <algorithm> -#include <functional> -#include <tr1/unordered_map> -using std::tr1::unordered_map; -using std::tr1::unordered_multimap; -#include <vector> -using std::vector; - -#include <fst/arcfilter.h> -#include <fst/arcsort.h> -#include <fst/dfs-visit.h> -#include <fst/expanded-fst.h> -#include <fst/replace.h> - -namespace fst { - -// This class accumulates arc weights using the semiring Plus(). -template <class A> -class DefaultAccumulator { - public: - typedef A Arc; - typedef typename A::StateId StateId; - typedef typename A::Weight Weight; - - DefaultAccumulator() {} - - DefaultAccumulator(const DefaultAccumulator<A> &acc) {} - - void Init(const Fst<A>& fst, bool copy = false) {} - - void SetState(StateId) {} - - Weight Sum(Weight w, Weight v) { - return Plus(w, v); - } - - template <class ArcIterator> - Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, - ssize_t end) { - Weight sum = w; - aiter->Seek(begin); - for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) - sum = Plus(sum, aiter->Value().weight); - return sum; - } - - bool Error() const { return false; } - - private: - void operator=(const DefaultAccumulator<A> &); // Disallow -}; - - -// This class accumulates arc weights using the log semiring Plus() -// assuming an arc weight has a WeightConvert specialization to -// and from log64 weights. -template <class A> -class LogAccumulator { - public: - typedef A Arc; - typedef typename A::StateId StateId; - typedef typename A::Weight Weight; - - LogAccumulator() {} - - LogAccumulator(const LogAccumulator<A> &acc) {} - - void Init(const Fst<A>& fst, bool copy = false) {} - - void SetState(StateId) {} - - Weight Sum(Weight w, Weight v) { - return LogPlus(w, v); - } - - template <class ArcIterator> - Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, - ssize_t end) { - Weight sum = w; - aiter->Seek(begin); - for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) - sum = LogPlus(sum, aiter->Value().weight); - return sum; - } - - bool Error() const { return false; } - - private: - double LogPosExp(double x) { return log(1.0F + exp(-x)); } - - Weight LogPlus(Weight w, Weight v) { - double f1 = to_log_weight_(w).Value(); - double f2 = to_log_weight_(v).Value(); - if (f1 > f2) - return to_weight_(f2 - LogPosExp(f1 - f2)); - else - return to_weight_(f1 - LogPosExp(f2 - f1)); - } - - WeightConvert<Weight, Log64Weight> to_log_weight_; - WeightConvert<Log64Weight, Weight> to_weight_; - - void operator=(const LogAccumulator<A> &); // Disallow -}; - - -// Stores shareable data for fast log accumulator copies. -class FastLogAccumulatorData { - public: - FastLogAccumulatorData() {} - - vector<double> *Weights() { return &weights_; } - vector<ssize_t> *WeightPositions() { return &weight_positions_; } - double *WeightEnd() { return &(weights_[weights_.size() - 1]); }; - int RefCount() const { return ref_count_.count(); } - int IncrRefCount() { return ref_count_.Incr(); } - int DecrRefCount() { return ref_count_.Decr(); } - - private: - // Cummulative weight per state for all states s.t. # of arcs > - // arc_limit_ with arcs in order. Special first element per state - // being Log64Weight::Zero(); - vector<double> weights_; - // Maps from state to corresponding beginning weight position in - // weights_. Position -1 means no pre-computed weights for that - // state. - vector<ssize_t> weight_positions_; - RefCounter ref_count_; // Reference count. - - DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData); -}; - - -// This class accumulates arc weights using the log semiring Plus() -// assuming an arc weight has a WeightConvert specialization to and -// from log64 weights. The member function Init(fst) has to be called -// to setup pre-computed weight information. -template <class A> -class FastLogAccumulator { - public: - typedef A Arc; - typedef typename A::StateId StateId; - typedef typename A::Weight Weight; - - explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10) - : arc_limit_(arc_limit), - arc_period_(arc_period), - data_(new FastLogAccumulatorData()), - error_(false) {} - - FastLogAccumulator(const FastLogAccumulator<A> &acc) - : arc_limit_(acc.arc_limit_), - arc_period_(acc.arc_period_), - data_(acc.data_), - error_(acc.error_) { - data_->IncrRefCount(); - } - - ~FastLogAccumulator() { - if (!data_->DecrRefCount()) - delete data_; - } - - void SetState(StateId s) { - vector<double> &weights = *data_->Weights(); - vector<ssize_t> &weight_positions = *data_->WeightPositions(); - - if (weight_positions.size() <= s) { - FSTERROR() << "FastLogAccumulator::SetState: invalid state id."; - error_ = true; - return; - } - - ssize_t pos = weight_positions[s]; - if (pos >= 0) - state_weights_ = &(weights[pos]); - else - state_weights_ = 0; - } - - Weight Sum(Weight w, Weight v) { - return LogPlus(w, v); - } - - template <class ArcIterator> - Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, - ssize_t end) { - if (error_) return Weight::NoWeight(); - Weight sum = w; - // Finds begin and end of pre-stored weights - ssize_t index_begin = -1, index_end = -1; - ssize_t stored_begin = end, stored_end = end; - if (state_weights_ != 0) { - index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0; - index_end = end / arc_period_; - stored_begin = index_begin * arc_period_; - stored_end = index_end * arc_period_; - } - // Computes sum before pre-stored weights - if (begin < stored_begin) { - ssize_t pos_end = min(stored_begin, end); - aiter->Seek(begin); - for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos) - sum = LogPlus(sum, aiter->Value().weight); - } - // Computes sum between pre-stored weights - if (stored_begin < stored_end) { - sum = LogPlus(sum, LogMinus(state_weights_[index_end], - state_weights_[index_begin])); - } - // Computes sum after pre-stored weights - if (stored_end < end) { - ssize_t pos_start = max(stored_begin, stored_end); - aiter->Seek(pos_start); - for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos) - sum = LogPlus(sum, aiter->Value().weight); - } - return sum; - } - - template <class F> - void Init(const F &fst, bool copy = false) { - if (copy) - return; - vector<double> &weights = *data_->Weights(); - vector<ssize_t> &weight_positions = *data_->WeightPositions(); - if (!weights.empty() || arc_limit_ < arc_period_) { - FSTERROR() << "FastLogAccumulator: initialization error."; - error_ = true; - return; - } - weight_positions.reserve(CountStates(fst)); - - ssize_t weight_position = 0; - for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { - StateId s = siter.Value(); - if (fst.NumArcs(s) >= arc_limit_) { - double sum = FloatLimits<double>::PosInfinity(); - weight_positions.push_back(weight_position); - weights.push_back(sum); - ++weight_position; - ssize_t narcs = 0; - for(ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { - const A &arc = aiter.Value(); - sum = LogPlus(sum, arc.weight); - // Stores cumulative weight distribution per arc_period_. - if (++narcs % arc_period_ == 0) { - weights.push_back(sum); - ++weight_position; - } - } - } else { - weight_positions.push_back(-1); - } - } - } - - bool Error() const { return error_; } - - private: - double LogPosExp(double x) { - return x == FloatLimits<double>::PosInfinity() ? - 0.0 : log(1.0F + exp(-x)); - } - - double LogMinusExp(double x) { - return x == FloatLimits<double>::PosInfinity() ? - 0.0 : log(1.0F - exp(-x)); - } - - Weight LogPlus(Weight w, Weight v) { - double f1 = to_log_weight_(w).Value(); - double f2 = to_log_weight_(v).Value(); - if (f1 > f2) - return to_weight_(f2 - LogPosExp(f1 - f2)); - else - return to_weight_(f1 - LogPosExp(f2 - f1)); - } - - double LogPlus(double f1, Weight v) { - double f2 = to_log_weight_(v).Value(); - if (f1 == FloatLimits<double>::PosInfinity()) - return f2; - else if (f1 > f2) - return f2 - LogPosExp(f1 - f2); - else - return f1 - LogPosExp(f2 - f1); - } - - Weight LogMinus(double f1, double f2) { - if (f1 >= f2) { - FSTERROR() << "FastLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1 - << " and f2 = " << f2; - error_ = true; - return Weight::NoWeight(); - } - if (f2 == FloatLimits<double>::PosInfinity()) - return to_weight_(f1); - else - return to_weight_(f1 - LogMinusExp(f2 - f1)); - } - - WeightConvert<Weight, Log64Weight> to_log_weight_; - WeightConvert<Log64Weight, Weight> to_weight_; - - ssize_t arc_limit_; // Minimum # of arcs to pre-compute state - ssize_t arc_period_; // Save cumulative weights per 'arc_period_'. - bool init_; // Cumulative weights initialized? - FastLogAccumulatorData *data_; - double *state_weights_; - bool error_; - - void operator=(const FastLogAccumulator<A> &); // Disallow -}; - - -// Stores shareable data for cache log accumulator copies. -// All copies share the same cache. -template <class A> -class CacheLogAccumulatorData { - public: - typedef A Arc; - typedef typename A::StateId StateId; - typedef typename A::Weight Weight; - - CacheLogAccumulatorData(bool gc, size_t gc_limit) - : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {} - - ~CacheLogAccumulatorData() { - for(typename unordered_map<StateId, CacheState>::iterator it = cache_.begin(); - it != cache_.end(); - ++it) - delete it->second.weights; - } - - bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; } - - vector<double> *GetWeights(StateId s) { - typename unordered_map<StateId, CacheState>::iterator it = cache_.find(s); - if (it != cache_.end()) { - it->second.recent = true; - return it->second.weights; - } else { - return 0; - } - } - - void AddWeights(StateId s, vector<double> *weights) { - if (cache_gc_ && cache_size_ >= cache_limit_) - GC(false); - cache_.insert(make_pair(s, CacheState(weights, true))); - if (cache_gc_) - cache_size_ += weights->capacity() * sizeof(double); - } - - int RefCount() const { return ref_count_.count(); } - int IncrRefCount() { return ref_count_.Incr(); } - int DecrRefCount() { return ref_count_.Decr(); } - - private: - // Cached information for a given state. - struct CacheState { - vector<double>* weights; // Accumulated weights for this state. - bool recent; // Has this state been accessed since last GC? - - CacheState(vector<double> *w, bool r) : weights(w), recent(r) {} - }; - - // Garbage collect: Delete from cache states that have not been - // accessed since the last GC ('free_recent = false') until - // 'cache_size_' is 2/3 of 'cache_limit_'. If it does not free enough - // memory, start deleting recently accessed states. - void GC(bool free_recent) { - size_t cache_target = (2 * cache_limit_)/3 + 1; - typename unordered_map<StateId, CacheState>::iterator it = cache_.begin(); - while (it != cache_.end() && cache_size_ > cache_target) { - CacheState &cs = it->second; - if (free_recent || !cs.recent) { - cache_size_ -= cs.weights->capacity() * sizeof(double); - delete cs.weights; - cache_.erase(it++); - } else { - cs.recent = false; - ++it; - } - } - if (!free_recent && cache_size_ > cache_target) - GC(true); - } - - unordered_map<StateId, CacheState> cache_; // Cache - bool cache_gc_; // Enable garbage collection - size_t cache_limit_; // # of bytes cached - size_t cache_size_; // # of bytes allowed before GC - RefCounter ref_count_; - - DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData); -}; - -// This class accumulates arc weights using the log semiring Plus() -// has a WeightConvert specialization to and from log64 weights. It -// is similar to the FastLogAccumator. However here, the accumulated -// weights are pre-computed and stored only for the states that are -// visited. The member function Init(fst) has to be called to setup -// this accumulator. -template <class A> -class CacheLogAccumulator { - public: - typedef A Arc; - typedef typename A::StateId StateId; - typedef typename A::Weight Weight; - - explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false, - size_t gc_limit = 10 * 1024 * 1024) - : arc_limit_(arc_limit), fst_(0), data_( - new CacheLogAccumulatorData<A>(gc, gc_limit)), s_(kNoStateId), - error_(false) {} - - CacheLogAccumulator(const CacheLogAccumulator<A> &acc) - : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0), - data_(acc.data_), s_(kNoStateId), error_(acc.error_) { - data_->IncrRefCount(); - } - - ~CacheLogAccumulator() { - if (fst_) - delete fst_; - if (!data_->DecrRefCount()) - delete data_; - } - - // Arg 'arc_limit' specifies minimum # of arcs to pre-compute state. - void Init(const Fst<A> &fst, bool copy = false) { - if (copy) { - delete fst_; - } else if (fst_) { - FSTERROR() << "CacheLogAccumulator: initialization error."; - error_ = true; - return; - } - fst_ = fst.Copy(); - } - - void SetState(StateId s, int depth = 0) { - if (s == s_) - return; - s_ = s; - - if (data_->CacheDisabled() || error_) { - weights_ = 0; - return; - } - - if (!fst_) { - FSTERROR() << "CacheLogAccumulator::SetState: incorrectly initialized."; - error_ = true; - weights_ = 0; - return; - } - - weights_ = data_->GetWeights(s); - if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) { - weights_ = new vector<double>; - weights_->reserve(fst_->NumArcs(s) + 1); - weights_->push_back(FloatLimits<double>::PosInfinity()); - data_->AddWeights(s, weights_); - } - } - - Weight Sum(Weight w, Weight v) { - return LogPlus(w, v); - } - - template <class Iterator> - Weight Sum(Weight w, Iterator *aiter, ssize_t begin, - ssize_t end) { - if (weights_ == 0) { - Weight sum = w; - aiter->Seek(begin); - for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) - sum = LogPlus(sum, aiter->Value().weight); - return sum; - } else { - if (weights_->size() <= end) - for (aiter->Seek(weights_->size() - 1); - weights_->size() <= end; - aiter->Next()) - weights_->push_back(LogPlus(weights_->back(), - aiter->Value().weight)); - return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin])); - } - } - - template <class Iterator> - size_t LowerBound(double w, Iterator *aiter) { - if (weights_ != 0) { - return lower_bound(weights_->begin() + 1, - weights_->end(), - w, - std::greater<double>()) - - weights_->begin() - 1; - } else { - size_t n = 0; - double x = FloatLimits<double>::PosInfinity(); - for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) { - x = LogPlus(x, aiter->Value().weight); - if (x < w) break; - } - return n; - } - } - - bool Error() const { return error_; } - - private: - double LogPosExp(double x) { - return x == FloatLimits<double>::PosInfinity() ? - 0.0 : log(1.0F + exp(-x)); - } - - double LogMinusExp(double x) { - return x == FloatLimits<double>::PosInfinity() ? - 0.0 : log(1.0F - exp(-x)); - } - - Weight LogPlus(Weight w, Weight v) { - double f1 = to_log_weight_(w).Value(); - double f2 = to_log_weight_(v).Value(); - if (f1 > f2) - return to_weight_(f2 - LogPosExp(f1 - f2)); - else - return to_weight_(f1 - LogPosExp(f2 - f1)); - } - - double LogPlus(double f1, Weight v) { - double f2 = to_log_weight_(v).Value(); - if (f1 == FloatLimits<double>::PosInfinity()) - return f2; - else if (f1 > f2) - return f2 - LogPosExp(f1 - f2); - else - return f1 - LogPosExp(f2 - f1); - } - - Weight LogMinus(double f1, double f2) { - if (f1 >= f2) { - FSTERROR() << "CacheLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1 - << " and f2 = " << f2; - error_ = true; - return Weight::NoWeight(); - } - if (f2 == FloatLimits<double>::PosInfinity()) - return to_weight_(f1); - else - return to_weight_(f1 - LogMinusExp(f2 - f1)); - } - - WeightConvert<Weight, Log64Weight> to_log_weight_; - WeightConvert<Log64Weight, Weight> to_weight_; - - ssize_t arc_limit_; // Minimum # of arcs to cache a state - vector<double> *weights_; // Accumulated weights for cur. state - const Fst<A>* fst_; // Input fst - CacheLogAccumulatorData<A> *data_; // Cache data - StateId s_; // Current state - bool error_; - - void operator=(const CacheLogAccumulator<A> &); // Disallow -}; - - -// Stores shareable data for replace accumulator copies. -template <class Accumulator, class T> -class ReplaceAccumulatorData { - public: - typedef typename Accumulator::Arc Arc; - typedef typename Arc::StateId StateId; - typedef typename Arc::Label Label; - typedef T StateTable; - typedef typename T::StateTuple StateTuple; - - ReplaceAccumulatorData() : state_table_(0) {} - - ReplaceAccumulatorData(const vector<Accumulator*> &accumulators) - : state_table_(0), accumulators_(accumulators) {} - - ~ReplaceAccumulatorData() { - for (size_t i = 0; i < fst_array_.size(); ++i) - delete fst_array_[i]; - for (size_t i = 0; i < accumulators_.size(); ++i) - delete accumulators_[i]; - } - - void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples, - const StateTable *state_table) { - state_table_ = state_table; - accumulators_.resize(fst_tuples.size()); - for (size_t i = 0; i < accumulators_.size(); ++i) { - if (!accumulators_[i]) - accumulators_[i] = new Accumulator; - accumulators_[i]->Init(*(fst_tuples[i].second)); - fst_array_.push_back(fst_tuples[i].second->Copy()); - } - } - - const StateTuple &GetTuple(StateId s) const { - return state_table_->Tuple(s); - } - - Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; } - - const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i]; } - - int RefCount() const { return ref_count_.count(); } - int IncrRefCount() { return ref_count_.Incr(); } - int DecrRefCount() { return ref_count_.Decr(); } - - private: - const T * state_table_; - vector<Accumulator*> accumulators_; - vector<const Fst<Arc>*> fst_array_; - RefCounter ref_count_; - - DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData); -}; - -// This class accumulates weights in a ReplaceFst. The 'Init' method -// takes as input the argument used to build the ReplaceFst and the -// ReplaceFst state table. It uses accumulators of type 'Accumulator' -// in the underlying FSTs. -template <class Accumulator, - class T = DefaultReplaceStateTable<typename Accumulator::Arc> > -class ReplaceAccumulator { - public: - typedef typename Accumulator::Arc Arc; - typedef typename Arc::StateId StateId; - typedef typename Arc::Label Label; - typedef typename Arc::Weight Weight; - typedef T StateTable; - typedef typename T::StateTuple StateTuple; - - ReplaceAccumulator() - : init_(false), data_(new ReplaceAccumulatorData<Accumulator, T>()), - error_(false) {} - - ReplaceAccumulator(const vector<Accumulator*> &accumulators) - : init_(false), - data_(new ReplaceAccumulatorData<Accumulator, T>(accumulators)), - error_(false) {} - - ReplaceAccumulator(const ReplaceAccumulator<Accumulator, T> &acc) - : init_(acc.init_), data_(acc.data_), error_(acc.error_) { - if (!init_) - FSTERROR() << "ReplaceAccumulator: can't copy unintialized accumulator"; - data_->IncrRefCount(); - } - - ~ReplaceAccumulator() { - if (!data_->DecrRefCount()) - delete data_; - } - - // Does not take ownership of the state table, the state table - // is own by the ReplaceFst - void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples, - const StateTable *state_table) { - init_ = true; - data_->Init(fst_tuples, state_table); - } - - void SetState(StateId s) { - if (!init_) { - FSTERROR() << "ReplaceAccumulator::SetState: incorrectly initialized."; - error_ = true; - return; - } - StateTuple tuple = data_->GetTuple(s); - fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based - data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state); - if ((tuple.prefix_id != 0) && - (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) { - offset_ = 1; - offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state); - } else { - offset_ = 0; - offset_weight_ = Weight::Zero(); - } - } - - Weight Sum(Weight w, Weight v) { - if (error_) return Weight::NoWeight(); - return data_->GetAccumulator(fst_id_)->Sum(w, v); - } - - template <class ArcIterator> - Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, - ssize_t end) { - if (error_) return Weight::NoWeight(); - Weight sum = begin == end ? Weight::Zero() - : data_->GetAccumulator(fst_id_)->Sum( - w, aiter, begin ? begin - offset_ : 0, end - offset_); - if (begin == 0 && end != 0 && offset_ > 0) - sum = Sum(offset_weight_, sum); - return sum; - } - - bool Error() const { return error_; } - - private: - bool init_; - ReplaceAccumulatorData<Accumulator, T> *data_; - Label fst_id_; - size_t offset_; - Weight offset_weight_; - bool error_; - - void operator=(const ReplaceAccumulator<Accumulator, T> &); // Disallow -}; - -} // namespace fst - -#endif // FST_LIB_ACCUMULATOR_H__ |