// factor-weight.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Classes to factor weights in an FST.
#ifndef FST_LIB_FACTOR_WEIGHT_H__
#define FST_LIB_FACTOR_WEIGHT_H__
#include <algorithm>
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <string>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;
#include <fst/cache.h>
#include <fst/test-properties.h>
namespace fst {
// Bit flags for FactorWeightOptions::mode; combine them with '|'.
const uint32 kFactorFinalWeights = 1U << 0;  // factor final weights
const uint32 kFactorArcWeights = 1U << 1;    // factor arc weights
template <class Arc>
struct FactorWeightOptions : CacheOptions {
typedef typename Arc::Label Label;
float delta;
uint32 mode; // factor arc weights and/or final weights
Label final_ilabel; // input label of arc created when factoring final w's
Label final_olabel; // output label of arc created when factoring final w's
FactorWeightOptions(const CacheOptions &opts, float d,
uint32 m = kFactorArcWeights | kFactorFinalWeights,
Label il = 0, Label ol = 0)
: CacheOptions(opts), delta(d), mode(m), final_ilabel(il),
final_olabel(ol) {}
explicit FactorWeightOptions(
float d, uint32 m = kFactorArcWeights | kFactorFinalWeights,
Label il = 0, Label ol = 0)
: delta(d), mode(m), final_ilabel(il), final_olabel(ol) {}
FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights,
Label il = 0, Label ol = 0)
: delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {}
};
// A factor iterator is constructed from a weight w and enumerates a
// sequence of weight pairs (x_i, y_i) such that the semiring sum of the
// products x_i * y_i equals w. If w is already fully factored, the
// iterator yields nothing (Done() is true right after construction).
//
// template <class W>
// class FactorIterator {
//  public:
//   FactorIterator(W w);
//   bool Done() const;
//   void Next();
//   pair<W, W> Value() const;
//   void Reset();
// };
// Trivial factor iterator: every weight is treated as already fully
// factored, so the iterator is Done() from the start and yields nothing.
template <class W>
class IdentityFactor {
 public:
  IdentityFactor(const W & /* w */) {}

  void Reset() {}

  bool Done() const { return true; }

  void Next() {}

  // Placeholder only: Done() is always true, so callers should never need
  // this value. Returns (One, One).
  pair<W, W> Value() const { return pair<W, W>(W::One(), W::One()); }
};
// Factors a StringWeight w as 'ab', where 'a' is the first label produced by
// StringWeightIterator and 'b' is the string of all remaining labels.
// At most one pair is yielded; weights of size <= 1 yield nothing.
template <typename L, StringType S = STRING_LEFT>
class StringFactor {
 public:
  StringFactor(const StringWeight<L, S> &w) : weight_(w) { Reset(); }

  bool Done() const { return done_; }

  // Only a single factorization exists, so one step exhausts the iterator.
  void Next() { done_ = true; }

  pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
    StringWeightIterator<L, S> it(weight_);
    StringWeight<L, S> head(it.Value());  // leading label 'a'
    StringWeight<L, S> tail;              // everything after it, 'b'
    it.Next();
    while (!it.Done()) {
      tail.PushBack(it.Value());
      it.Next();
    }
    return make_pair(head, tail);
  }

  void Reset() { done_ = weight_.Size() <= 1; }

 private:
  StringWeight<L, S> weight_;
  bool done_;
};
// Factors a GallicWeight (s, w) by splitting its string component with
// StringFactor: s = 'ab' gives the pair ((a, w), (b, W::One())). The W
// component stays with the first factor. Weights whose string component
// has size <= 1 yield nothing.
template <class L, class W, StringType S = STRING_LEFT>
class GallicFactor {
 public:
  GallicFactor(const GallicWeight<L, W, S> &w)
      : weight_(w), done_(w.Value1().Size() <= 1) {}

  bool Done() const { return done_; }

  // Only a single factorization exists, so one step exhausts the iterator.
  void Next() { done_ = true; }

  pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
    StringFactor<L, S> iter(weight_.Value1());
    // StringFactor::Value() rebuilds the tail string on every call, so
    // compute the split once and reuse both halves.
    const pair< StringWeight<L, S>, StringWeight<L, S> > factors =
        iter.Value();
    GallicWeight<L, W, S> w1(factors.first, weight_.Value2());
    GallicWeight<L, W, S> w2(factors.second, W::One());
    return make_pair(w1, w2);
  }

  void Reset() { done_ = weight_.Value1().Size() <= 1; }

 private:
  GallicWeight<L, W, S> weight_;
  bool done_;
};
// Implementation class for FactorWeight
template <class A, class F>
class FactorWeightFstImpl
: public CacheImpl<A> {
public:
using FstImpl<A>::SetType;
using FstImpl<A>::SetProperties;
using FstImpl<A>::SetInputSymbols;
using FstImpl<A>::SetOutputSymbols;
using CacheBaseImpl< CacheState<A> >::PushArc;
using CacheBaseImpl< CacheState<A> >::HasStart;
using CacheBaseImpl< CacheState<A> >::HasFinal;
using CacheBaseImpl< CacheState<A> >::HasArcs;
using CacheBaseImpl< CacheState<A> >::SetArcs;
using CacheBaseImpl< CacheState<A> >::SetFinal;
using CacheBaseImpl< CacheState<A> >::SetStart;
typedef A Arc;
typedef typename A::Label Label;
typedef typename A::Weight Weight;
typedef typename A::StateId StateId;
typedef F FactorIterator;
struct Element {
Element() {}
Element(StateId s, Weight w) : state(s), weight(w) {}
StateId state; // Input state Id
Weight weight; // Residual weight
};
FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts)
: CacheImpl<A>(opts),
fst_(fst.Copy()),
delta_(opts.delta),
mode_(opts.mode),
final_ilabel_(opts.final_ilabel),
final_olabel_(opts.final_olabel) {
SetType("factor_weight");
uint64 props = fst.Properties(kFstProperties, false);
SetProperties(FactorWeightProperties(props), kCopyProperties);
SetInputSymbols(fst.InputSymbols());
SetOutputSymbols(fst.OutputSymbols());
if (mode_ == 0)
LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: "
<< "factoring neither arc weights nor final weights.";
}
FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl)
: CacheImpl<A>(impl),
fst_(impl.fst_->Copy(true)),
delta_(impl.delta_),
mode_(impl.mode_),
final_ilabel_(impl.final_ilabel_),
final_olabel_(impl.final_olabel_) {
SetType("factor_weight");
SetProperties(impl.Properties(), kCopyProperties);
SetInputSymbols(impl.InputSymbols());
SetOutputSymbols(impl.OutputSymbols());
}
~FactorWeightFstImpl() {
delete fst_;
}
StateId Start() {
if (!HasStart()) {
StateId s = fst_->Start