summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/accumulator.h
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/accumulator.h')
-rw-r--r--kaldi_io/src/tools/openfst/include/fst/accumulator.h745
1 files changed, 745 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/accumulator.h b/kaldi_io/src/tools/openfst/include/fst/accumulator.h
new file mode 100644
index 0000000..81d1847
--- /dev/null
+++ b/kaldi_io/src/tools/openfst/include/fst/accumulator.h
@@ -0,0 +1,745 @@
+// accumulator.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Classes to accumulate arc weights. Useful for weight lookahead.
+
+#ifndef FST_LIB_ACCUMULATOR_H__
+#define FST_LIB_ACCUMULATOR_H__
+
+#include <algorithm>
+#include <functional>
+#include <tr1/unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+#include <vector>
+using std::vector;
+
+#include <fst/arcfilter.h>
+#include <fst/arcsort.h>
+#include <fst/dfs-visit.h>
+#include <fst/expanded-fst.h>
+#include <fst/replace.h>
+
+namespace fst {
+
+// This class accumulates arc weights using the semiring Plus().
+template <class A>
+class DefaultAccumulator {
+ public:
+ typedef A Arc;
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+
+ DefaultAccumulator() {}
+
+ DefaultAccumulator(const DefaultAccumulator<A> &acc) {}
+
+ void Init(const Fst<A>& fst, bool copy = false) {}
+
+ void SetState(StateId) {}
+
+ Weight Sum(Weight w, Weight v) {
+ return Plus(w, v);
+ }
+
+ template <class ArcIterator>
+ Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
+ ssize_t end) {
+ Weight sum = w;
+ aiter->Seek(begin);
+ for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
+ sum = Plus(sum, aiter->Value().weight);
+ return sum;
+ }
+
+ bool Error() const { return false; }
+
+ private:
+ void operator=(const DefaultAccumulator<A> &); // Disallow
+};
+
+
+// This class accumulates arc weights using the log semiring Plus()
+// assuming an arc weight has a WeightConvert specialization to
+// and from log64 weights.
+template <class A>
+class LogAccumulator {
+ public:
+ typedef A Arc;
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+
+ LogAccumulator() {}
+
+ LogAccumulator(const LogAccumulator<A> &acc) {}
+
+ void Init(const Fst<A>& fst, bool copy = false) {}
+
+ void SetState(StateId) {}
+
+ Weight Sum(Weight w, Weight v) {
+ return LogPlus(w, v);
+ }
+
+ template <class ArcIterator>
+ Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
+ ssize_t end) {
+ Weight sum = w;
+ aiter->Seek(begin);
+ for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
+ sum = LogPlus(sum, aiter->Value().weight);
+ return sum;
+ }
+
+ bool Error() const { return false; }
+
+ private:
+ double LogPosExp(double x) { return log(1.0F + exp(-x)); }
+
+ Weight LogPlus(Weight w, Weight v) {
+ double f1 = to_log_weight_(w).Value();
+ double f2 = to_log_weight_(v).Value();
+ if (f1 > f2)
+ return to_weight_(f2 - LogPosExp(f1 - f2));
+ else
+ return to_weight_(f1 - LogPosExp(f2 - f1));
+ }
+
+ WeightConvert<Weight, Log64Weight> to_log_weight_;
+ WeightConvert<Log64Weight, Weight> to_weight_;
+
+ void operator=(const LogAccumulator<A> &); // Disallow
+};
+
+
+// Stores shareable data for fast log accumulator copies.
+class FastLogAccumulatorData {
+ public:
+ FastLogAccumulatorData() {}
+
+ vector<double> *Weights() { return &weights_; }
+ vector<ssize_t> *WeightPositions() { return &weight_positions_; }
+ double *WeightEnd() { return &(weights_[weights_.size() - 1]); };
+ int RefCount() const { return ref_count_.count(); }
+ int IncrRefCount() { return ref_count_.Incr(); }
+ int DecrRefCount() { return ref_count_.Decr(); }
+
+ private:
+ // Cummulative weight per state for all states s.t. # of arcs >
+ // arc_limit_ with arcs in order. Special first element per state
+ // being Log64Weight::Zero();
+ vector<double> weights_;
+ // Maps from state to corresponding beginning weight position in
+ // weights_. Position -1 means no pre-computed weights for that
+ // state.
+ vector<ssize_t> weight_positions_;
+ RefCounter ref_count_; // Reference count.
+
+ DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData);
+};
+
+
+// This class accumulates arc weights using the log semiring Plus()
+// assuming an arc weight has a WeightConvert specialization to and
+// from log64 weights. The member function Init(fst) has to be called
+// to setup pre-computed weight information.
+template <class A>
+class FastLogAccumulator {
+ public:
+ typedef A Arc;
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+
+ explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
+ : arc_limit_(arc_limit),
+ arc_period_(arc_period),
+ data_(new FastLogAccumulatorData()),
+ error_(false) {}
+
+ FastLogAccumulator(const FastLogAccumulator<A> &acc)
+ : arc_limit_(acc.arc_limit_),
+ arc_period_(acc.arc_period_),
+ data_(acc.data_),
+ error_(acc.error_) {
+ data_->IncrRefCount();
+ }
+
+ ~FastLogAccumulator() {
+ if (!data_->DecrRefCount())
+ delete data_;
+ }
+
+ void SetState(StateId s) {
+ vector<double> &weights = *data_->Weights();
+ vector<ssize_t> &weight_positions = *data_->WeightPositions();
+
+ if (weight_positions.size() <= s) {
+ FSTERROR() << "FastLogAccumulator::SetState: invalid state id.";
+ error_ = true;
+ return;
+ }
+
+ ssize_t pos = weight_positions[s];
+ if (pos >= 0)
+ state_weights_ = &(weights[pos]);
+ else
+ state_weights_ = 0;
+ }
+
+ Weight Sum(Weight w, Weight v) {
+ return LogPlus(w, v);
+ }
+
+ template <class ArcIterator>
+ Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
+ ssize_t end) {
+ if (error_) return Weight::NoWeight();
+ Weight sum = w;
+ // Finds begin and end of pre-stored weights
+ ssize_t index_begin = -1, index_end = -1;
+ ssize_t stored_begin = end, stored_end = end;
+ if (state_weights_ != 0) {
+ index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0;
+ index_end = end / arc_period_;
+ stored_begin = index_begin * arc_period_;
+ stored_end = index_end * arc_period_;
+ }
+ // Computes sum before pre-stored weights
+ if (begin < stored_begin) {
+ ssize_t pos_end = min(stored_begin, end);
+ aiter->Seek(begin);
+ for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos)
+ sum = LogPlus(sum, aiter->Value().weight);
+ }
+ // Computes sum between pre-stored weights
+ if (stored_begin < stored_end) {
+ sum = LogPlus(sum, LogMinus(state_weights_[index_end],
+ state_weights_[index_begin]));
+ }
+ // Computes sum after pre-stored weights
+ if (stored_end < end) {
+ ssize_t pos_start = max(stored_begin, stored_end);
+ aiter->Seek(pos_start);
+ for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos)
+ sum = LogPlus(sum, aiter->Value().weight);
+ }
+ return sum;
+ }
+
+ template <class F>
+ void Init(const F &fst, bool copy = false) {
+ if (copy)
+ return;
+ vector<double> &weights = *data_->Weights();
+ vector<ssize_t> &weight_positions = *data_->WeightPositions();
+ if (!weights.empty() || arc_limit_ < arc_period_) {
+ FSTERROR() << "FastLogAccumulator: initialization error.";
+ error_ = true;
+ return;
+ }
+ weight_positions.reserve(CountStates(fst));
+
+ ssize_t weight_position = 0;
+ for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
+ StateId s = siter.Value();
+ if (fst.NumArcs(s) >= arc_limit_) {
+ double sum = FloatLimits<double>::PosInfinity();
+ weight_positions.push_back(weight_position);
+ weights.push_back(sum);
+ ++weight_position;
+ ssize_t narcs = 0;
+ for(ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) {
+ const A &arc = aiter.Value();
+ sum = LogPlus(sum, arc.weight);
+ // Stores cumulative weight distribution per arc_period_.
+ if (++narcs % arc_period_ == 0) {
+ weights.push_back(sum);
+ ++weight_position;
+ }
+ }
+ } else {
+ weight_positions.push_back(-1);
+ }
+ }
+ }
+
+ bool Error() const { return error_; }
+
+ private:
+ double LogPosExp(double x) {
+ return x == FloatLimits<double>::PosInfinity() ?
+ 0.0 : log(1.0F + exp(-x));
+ }
+
+ double LogMinusExp(double x) {
+ return x == FloatLimits<double>::PosInfinity() ?
+ 0.0 : log(1.0F - exp(-x));
+ }
+
+ Weight LogPlus(Weight w, Weight v) {
+ double f1 = to_log_weight_(w).Value();
+ double f2 = to_log_weight_(v).Value();
+ if (f1 > f2)
+ return to_weight_(f2 - LogPosExp(f1 - f2));
+ else
+ return to_weight_(f1 - LogPosExp(f2 - f1));
+ }
+
+ double LogPlus(double f1, Weight v) {
+ double f2 = to_log_weight_(v).Value();
+ if (f1 == FloatLimits<double>::PosInfinity())
+ return f2;
+ else if (f1 > f2)
+ return f2 - LogPosExp(f1 - f2);
+ else
+ return f1 - LogPosExp(f2 - f1);
+ }
+
+ Weight LogMinus(double f1, double f2) {
+ if (f1 >= f2) {
+ FSTERROR() << "FastLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1
+ << " and f2 = " << f2;
+ error_ = true;
+ return Weight::NoWeight();
+ }
+ if (f2 == FloatLimits<double>::PosInfinity())
+ return to_weight_(f1);
+ else
+ return to_weight_(f1 - LogMinusExp(f2 - f1));
+ }
+
+ WeightConvert<Weight, Log64Weight> to_log_weight_;
+ WeightConvert<Log64Weight, Weight> to_weight_;
+
+ ssize_t arc_limit_; // Minimum # of arcs to pre-compute state
+ ssize_t arc_period_; // Save cumulative weights per 'arc_period_'.
+ bool init_; // Cumulative weights initialized?
+ FastLogAccumulatorData *data_;
+ double *state_weights_;
+ bool error_;
+
+ void operator=(const FastLogAccumulator<A> &); // Disallow
+};
+
+
+// Stores shareable data for cache log accumulator copies.
+// All copies share the same cache.
+template <class A>
+class CacheLogAccumulatorData {
+ public:
+ typedef A Arc;
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+
+ CacheLogAccumulatorData(bool gc, size_t gc_limit)
+ : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}
+
+ ~CacheLogAccumulatorData() {
+ for(typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
+ it != cache_.end();
+ ++it)
+ delete it->second.weights;
+ }
+
+ bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }
+
+ vector<double> *GetWeights(StateId s) {
+ typename unordered_map<StateId, CacheState>::iterator it = cache_.find(s);
+ if (it != cache_.end()) {
+ it->second.recent = true;
+ return it->second.weights;
+ } else {
+ return 0;
+ }
+ }
+
+ void AddWeights(StateId s, vector<double> *weights) {
+ if (cache_gc_ && cache_size_ >= cache_limit_)
+ GC(false);
+ cache_.insert(make_pair(s, CacheState(weights, true)));
+ if (cache_gc_)
+ cache_size_ += weights->capacity() * sizeof(double);
+ }
+
+ int RefCount() const { return ref_count_.count(); }
+ int IncrRefCount() { return ref_count_.Incr(); }
+ int DecrRefCount() { return ref_count_.Decr(); }
+
+ private:
+ // Cached information for a given state.
+ struct CacheState {
+ vector<double>* weights; // Accumulated weights for this state.
+ bool recent; // Has this state been accessed since last GC?
+
+ CacheState(vector<double> *w, bool r) : weights(w), recent(r) {}
+ };
+
+ // Garbage collect: Delete from cache states that have not been
+ // accessed since the last GC ('free_recent = false') until
+ // 'cache_size_' is 2/3 of 'cache_limit_'. If it does not free enough
+ // memory, start deleting recently accessed states.
+ void GC(bool free_recent) {
+ size_t cache_target = (2 * cache_limit_)/3 + 1;
+ typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
+ while (it != cache_.end() && cache_size_ > cache_target) {
+ CacheState &cs = it->second;
+ if (free_recent || !cs.recent) {
+ cache_size_ -= cs.weights->capacity() * sizeof(double);
+ delete cs.weights;
+ cache_.erase(it++);
+ } else {
+ cs.recent = false;
+ ++it;
+ }
+ }
+ if (!free_recent && cache_size_ > cache_target)
+ GC(true);
+ }
+
+ unordered_map<StateId, CacheState> cache_; // Cache
+ bool cache_gc_; // Enable garbage collection
+ size_t cache_limit_; // # of bytes cached
+ size_t cache_size_; // # of bytes allowed before GC
+ RefCounter ref_count_;
+
+ DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData);
+};
+
+// This class accumulates arc weights using the log semiring Plus()
+// has a WeightConvert specialization to and from log64 weights. It
+// is similar to the FastLogAccumator. However here, the accumulated
+// weights are pre-computed and stored only for the states that are
+// visited. The member function Init(fst) has to be called to setup
+// this accumulator.
+template <class A>
+class CacheLogAccumulator {
+ public:
+ typedef A Arc;
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+
+ explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false,
+ size_t gc_limit = 10 * 1024 * 1024)
+ : arc_limit_(arc_limit), fst_(0), data_(
+ new CacheLogAccumulatorData<A>(gc, gc_limit)), s_(kNoStateId),
+ error_(false) {}
+
+ CacheLogAccumulator(const CacheLogAccumulator<A> &acc)
+ : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0),
+ data_(acc.data_), s_(kNoStateId), error_(acc.error_) {
+ data_->IncrRefCount();
+ }
+
+ ~CacheLogAccumulator() {
+ if (fst_)
+ delete fst_;
+ if (!data_->DecrRefCount())
+ delete data_;
+ }
+
+ // Arg 'arc_limit' specifies minimum # of arcs to pre-compute state.
+ void Init(const Fst<A> &fst, bool copy = false) {
+ if (copy) {
+ delete fst_;
+ } else if (fst_) {
+ FSTERROR() << "CacheLogAccumulator: initialization error.";
+ error_ = true;
+ return;
+ }
+ fst_ = fst.Copy();
+ }
+
+ void SetState(StateId s, int depth = 0) {
+ if (s == s_)
+ return;
+ s_ = s;
+
+ if (data_->CacheDisabled() || error_) {
+ weights_ = 0;
+ return;
+ }
+
+ if (!fst_) {
+ FSTERROR() << "CacheLogAccumulator::SetState: incorrectly initialized.";
+ error_ = true;
+ weights_ = 0;
+ return;
+ }
+
+ weights_ = data_->GetWeights(s);
+ if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) {
+ weights_ = new vector<double>;
+ weights_->reserve(fst_->NumArcs(s) + 1);
+ weights_->push_back(FloatLimits<double>::PosInfinity());
+ data_->AddWeights(s, weights_);
+ }
+ }
+
+ Weight Sum(Weight w, Weight v) {
+ return LogPlus(w, v);
+ }
+
+ template <class Iterator>
+ Weight Sum(Weight w, Iterator *aiter, ssize_t begin,
+ ssize_t end) {
+ if (weights_ == 0) {
+ Weight sum = w;
+ aiter->Seek(begin);
+ for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
+ sum = LogPlus(sum, aiter->Value().weight);
+ return sum;
+ } else {
+ if (weights_->size() <= end)
+ for (aiter->Seek(weights_->size() - 1);
+ weights_->size() <= end;
+ aiter->Next())
+ weights_->push_back(LogPlus(weights_->back(),
+ aiter->Value().weight));
+ return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin]));
+ }
+ }
+
+ template <class Iterator>
+ size_t LowerBound(double w, Iterator *aiter) {
+ if (weights_ != 0) {
+ return lower_bound(weights_->begin() + 1,
+ weights_->end(),
+ w,
+ std::greater<double>())
+ - weights_->begin() - 1;
+ } else {
+ size_t n = 0;
+ double x = FloatLimits<double>::PosInfinity();
+ for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
+ x = LogPlus(x, aiter->Value().weight);
+ if (x < w) break;
+ }
+ return n;
+ }
+ }
+
+ bool Error() const { return error_; }
+
+ private:
+ double LogPosExp(double x) {
+ return x == FloatLimits<double>::PosInfinity() ?
+ 0.0 : log(1.0F + exp(-x));
+ }
+
+ double LogMinusExp(double x) {
+ return x == FloatLimits<double>::PosInfinity() ?
+ 0.0 : log(1.0F - exp(-x));
+ }
+
+ Weight LogPlus(Weight w, Weight v) {
+ double f1 = to_log_weight_(w).Value();
+ double f2 = to_log_weight_(v).Value();
+ if (f1 > f2)
+ return to_weight_(f2 - LogPosExp(f1 - f2));
+ else
+ return to_weight_(f1 - LogPosExp(f2 - f1));
+ }
+
+ double LogPlus(double f1, Weight v) {
+ double f2 = to_log_weight_(v).Value();
+ if (f1 == FloatLimits<double>::PosInfinity())
+ return f2;
+ else if (f1 > f2)
+ return f2 - LogPosExp(f1 - f2);
+ else
+ return f1 - LogPosExp(f2 - f1);
+ }
+
+ Weight LogMinus(double f1, double f2) {
+ if (f1 >= f2) {
+ FSTERROR() << "CacheLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1
+ << " and f2 = " << f2;
+ error_ = true;
+ return Weight::NoWeight();
+ }
+ if (f2 == FloatLimits<double>::PosInfinity())
+ return to_weight_(f1);
+ else
+ return to_weight_(f1 - LogMinusExp(f2 - f1));
+ }
+
+ WeightConvert<Weight, Log64Weight> to_log_weight_;
+ WeightConvert<Log64Weight, Weight> to_weight_;
+
+ ssize_t arc_limit_; // Minimum # of arcs to cache a state
+ vector<double> *weights_; // Accumulated weights for cur. state
+ const Fst<A>* fst_; // Input fst
+ CacheLogAccumulatorData<A> *data_; // Cache data
+ StateId s_; // Current state
+ bool error_;
+
+ void operator=(const CacheLogAccumulator<A> &); // Disallow
+};
+
+
+// Stores shareable data for replace accumulator copies.
+template <class Accumulator, class T>
+class ReplaceAccumulatorData {
+ public:
+ typedef typename Accumulator::Arc Arc;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef T StateTable;
+ typedef typename T::StateTuple StateTuple;
+
+ ReplaceAccumulatorData() : state_table_(0) {}
+
+ ReplaceAccumulatorData(const vector<Accumulator*> &accumulators)
+ : state_table_(0), accumulators_(accumulators) {}
+
+ ~ReplaceAccumulatorData() {
+ for (size_t i = 0; i < fst_array_.size(); ++i)
+ delete fst_array_[i];
+ for (size_t i = 0; i < accumulators_.size(); ++i)
+ delete accumulators_[i];
+ }
+
+ void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
+ const StateTable *state_table) {
+ state_table_ = state_table;
+ accumulators_.resize(fst_tuples.size());
+ for (size_t i = 0; i < accumulators_.size(); ++i) {
+ if (!accumulators_[i])
+ accumulators_[i] = new Accumulator;
+ accumulators_[i]->Init(*(fst_tuples[i].second));
+ fst_array_.push_back(fst_tuples[i].second->Copy());
+ }
+ }
+
+ const StateTuple &GetTuple(StateId s) const {
+ return state_table_->Tuple(s);
+ }
+
+ Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; }
+
+ const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i]; }
+
+ int RefCount() const { return ref_count_.count(); }
+ int IncrRefCount() { return ref_count_.Incr(); }
+ int DecrRefCount() { return ref_count_.Decr(); }
+
+ private:
+ const T * state_table_;
+ vector<Accumulator*> accumulators_;
+ vector<const Fst<Arc>*> fst_array_;
+ RefCounter ref_count_;
+
+ DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData);
+};
+
+// This class accumulates weights in a ReplaceFst. The 'Init' method
+// takes as input the argument used to build the ReplaceFst and the
+// ReplaceFst state table. It uses accumulators of type 'Accumulator'
+// in the underlying FSTs.
+template <class Accumulator,
+ class T = DefaultReplaceStateTable<typename Accumulator::Arc> >
+class ReplaceAccumulator {
+ public:
+ typedef typename Accumulator::Arc Arc;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef typename Arc::Weight Weight;
+ typedef T StateTable;
+ typedef typename T::StateTuple StateTuple;
+
+ ReplaceAccumulator()
+ : init_(false), data_(new ReplaceAccumulatorData<Accumulator, T>()),
+ error_(false) {}
+
+ ReplaceAccumulator(const vector<Accumulator*> &accumulators)
+ : init_(false),
+ data_(new ReplaceAccumulatorData<Accumulator, T>(accumulators)),
+ error_(false) {}
+
+ ReplaceAccumulator(const ReplaceAccumulator<Accumulator, T> &acc)
+ : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
+ if (!init_)
+ FSTERROR() << "ReplaceAccumulator: can't copy unintialized accumulator";
+ data_->IncrRefCount();
+ }
+
+ ~ReplaceAccumulator() {
+ if (!data_->DecrRefCount())
+ delete data_;
+ }
+
+ // Does not take ownership of the state table, the state table
+ // is own by the ReplaceFst
+ void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
+ const StateTable *state_table) {
+ init_ = true;
+ data_->Init(fst_tuples, state_table);
+ }
+
+ void SetState(StateId s) {
+ if (!init_) {
+ FSTERROR() << "ReplaceAccumulator::SetState: incorrectly initialized.";
+ error_ = true;
+ return;
+ }
+ StateTuple tuple = data_->GetTuple(s);
+ fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based
+ data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
+ if ((tuple.prefix_id != 0) &&
+ (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
+ offset_ = 1;
+ offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
+ } else {
+ offset_ = 0;
+ offset_weight_ = Weight::Zero();
+ }
+ }
+
+ Weight Sum(Weight w, Weight v) {
+ if (error_) return Weight::NoWeight();
+ return data_->GetAccumulator(fst_id_)->Sum(w, v);
+ }
+
+ template <class ArcIterator>
+ Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
+ ssize_t end) {
+ if (error_) return Weight::NoWeight();
+ Weight sum = begin == end ? Weight::Zero()
+ : data_->GetAccumulator(fst_id_)->Sum(
+ w, aiter, begin ? begin - offset_ : 0, end - offset_);
+ if (begin == 0 && end != 0 && offset_ > 0)
+ sum = Sum(offset_weight_, sum);
+ return sum;
+ }
+
+ bool Error() const { return error_; }
+
+ private:
+ bool init_;
+ ReplaceAccumulatorData<Accumulator, T> *data_;
+ Label fst_id_;
+ size_t offset_;
+ Weight offset_weight_;
+ bool error_;
+
+ void operator=(const ReplaceAccumulator<Accumulator, T> &); // Disallow
+};
+
+} // namespace fst
+
+#endif // FST_LIB_ACCUMULATOR_H__