// accumulator.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Classes to accumulate arc weights. Useful for weight lookahead.
#ifndef FST_LIB_ACCUMULATOR_H__
#define FST_LIB_ACCUMULATOR_H__
#include <algorithm>
#include <cstddef>
#include <functional>
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <vector>
using std::vector;
#include <fst/arcfilter.h>
#include <fst/arcsort.h>
#include <fst/dfs-visit.h>
#include <fst/expanded-fst.h>
#include <fst/replace.h>
namespace fst {
// Accumulates arc weights with the arc semiring's own Plus(). It keeps no
// per-FST or per-state data, so Init() and SetState() are no-ops and
// Error() always reports success.
template <class A>
class DefaultAccumulator {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;

  DefaultAccumulator() {}

  // Nothing to copy: the accumulator carries no data.
  DefaultAccumulator(const DefaultAccumulator<A> &acc) {}

  // No pre-computation is required for plain semiring addition.
  void Init(const Fst<A>& fst, bool copy = false) {}

  // No per-state data to select.
  void SetState(StateId) {}

  // Returns the semiring sum w (+) v.
  Weight Sum(Weight w, Weight v) { return Plus(w, v); }

  // Returns w (+) the weights of the arcs at positions [begin, end) of
  // *aiter. The iterator is first sought to `begin` and is then advanced
  // once per summed arc.
  template <class ArcIterator>
  Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
             ssize_t end) {
    aiter->Seek(begin);
    Weight total = w;
    ssize_t position = begin;
    while (position < end) {
      total = Plus(total, aiter->Value().weight);
      aiter->Next();
      ++position;
    }
    return total;
  }

  // This accumulator has no failure modes.
  bool Error() const { return false; }

 private:
  void operator=(const DefaultAccumulator<A> &);  // Disallow
};
// This class accumulates arc weights using the log semiring Plus()
// assuming an arc weight has a WeightConvert specialization to
// and from log64 weights.
template <class A>
class LogAccumulator {
public:
typedef A Arc;
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
LogAccumulator() {}
LogAccumulator(const LogAccumulator<A> &acc) {}
void Init(const Fst<A>& fst, bool copy = false) {}
void SetState(StateId) {}
Weight Sum(Weight w, Weight v) {
return LogPlus(w, v);
}
template <class ArcIterator>
Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
ssize_t end) {
Weight sum = w;
aiter->Seek(begin);
for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
sum = LogPlus(sum, aiter->Value().weight);
return sum;
}
bool Error() const { return false; }
private:
double LogPosExp(double x) { return log(1.0F + exp(-x)); }
Weight LogPlus(Weight w, Weight v) {
double f1 = to_log_weight_(w).Value();
double f2 = to_log_weight_(v).Value();
if (f1 > f2)
return to_weight_(f2 - LogPosExp(f1 - f2));
else
return to_weight_(f1 - LogPosExp(f2 - f1));
}
WeightConvert<Weight, Log64Weight> to_log_weight_;
WeightConvert<Log64Weight, Weight> to_weight_;
void operator=(const LogAccumulator<A> &); // Disallow
};
// Stores shareable data for fast log accumulator copies. The object carries
// a reference count that owners manage via IncrRefCount()/DecrRefCount().
class FastLogAccumulatorData {
 public:
  FastLogAccumulatorData() {}

  // Mutable access to the cumulative weight table.
  vector<double> *Weights() { return &weights_; }

  // Mutable access to the state -> weight-position table.
  vector<ssize_t> *WeightPositions() { return &weight_positions_; }

  // Returns a pointer to the last stored weight, or NULL if no weights have
  // been stored. Indexing weights_[size() - 1] on an empty vector would be
  // undefined behavior, because the index wraps around.
  double *WeightEnd() {
    if (weights_.empty())
      return NULL;
    return &weights_.back();
  }

  int RefCount() const { return ref_count_.count(); }
  int IncrRefCount() { return ref_count_.Incr(); }
  int DecrRefCount() { return ref_count_.Decr(); }

 private:
  // Cumulative weight per state for all states s.t. # of arcs >
  // arc_limit_, with arcs in order. The special first element per state
  // is Log64Weight::Zero().
  vector<double> weights_;
  // Maps from state to corresponding beginning weight position in
  // weights_. Position -1 means no pre-computed weights for that
  // state.
  vector<ssize_t> weight_positions_;
  RefCounter ref_count_;  // Reference count.

  DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData);
};
// This class accumulates arc weights using the log semiring Plus(),
// assuming an arc weight has a WeightConvert specialization to and
// from log64 weights. The member function Init(fst) must be called
// to set up the pre-computed weight information.
template <class A>
class FastLogAccumulator {
public:
typedef A Arc;
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
: arc_limit_(arc_limit),
arc_period_(arc_period),
data_(new FastLogAccumulatorData()),
error_(false) {}
FastLogAccumulator(const FastLogAccumulator<A> &acc)
: arc_limit_(acc.arc_limit_),
arc_period_(acc.arc_period_),
data_(acc.data_),
error_(acc.error_) {
data_->IncrRefCount();
}
~FastLogAccumulator() {
if (!data_->DecrRefCount())
delete data_;
}
void SetState(StateId s) {
vector<double> &weights = *data_->Weights();
vector<ssize_t> &weight_positions = *data_->WeightPositions();