diff options
author | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
commit | 96a32415ab43377cf1575bd3f4f2980f58028209 (patch) | |
tree | 30a2d92d73e8f40ac87b79f6f56e227bfc4eea6e /kaldi_io/src/tools/openfst/include/fst/factor-weight.h | |
parent | c177a7549bd90670af4b29fa813ddea32cfe0f78 (diff) |
add implementation for kaldi io (by ymz)
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/factor-weight.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/factor-weight.h | 475 |
1 files changed, 475 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/factor-weight.h b/kaldi_io/src/tools/openfst/include/fst/factor-weight.h new file mode 100644 index 0000000..685155c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/factor-weight.h @@ -0,0 +1,475 @@ +// factor-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Classes to factor weights in an FST. + +#ifndef FST_LIB_FACTOR_WEIGHT_H__ +#define FST_LIB_FACTOR_WEIGHT_H__ + +#include <algorithm> +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/test-properties.h> + + +namespace fst { + +const uint32 kFactorFinalWeights = 0x00000001; +const uint32 kFactorArcWeights = 0x00000002; + +template <class Arc> +struct FactorWeightOptions : CacheOptions { + typedef typename Arc::Label Label; + float delta; + uint32 mode; // factor arc weights and/or final weights + Label final_ilabel; // input label of arc created when factoring final w's + Label final_olabel; // output label of arc created when factoring final w's + + FactorWeightOptions(const CacheOptions &opts, float d, + uint32 m = kFactorArcWeights | kFactorFinalWeights, + Label il = 0, Label ol = 0) + : CacheOptions(opts), delta(d), mode(m), final_ilabel(il), + final_olabel(ol) {} + + explicit FactorWeightOptions( + float d, uint32 m = kFactorArcWeights | kFactorFinalWeights, + Label il = 0, Label ol = 0) + : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {} + + FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights, + Label il = 0, Label ol = 0) + : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {} +}; + + +// A factor iterator takes as argument a weight w and returns a +// sequence of pairs of weights (xi,yi) such that the sum of the +// products xi times yi is equal to w. If w is fully factored, +// the iterator should return nothing. +// +// template <class W> +// class FactorIterator { +// public: +// FactorIterator(W w); +// bool Done() const; +// void Next(); +// pair<W, W> Value() const; +// void Reset(); +// } + + +// Factor trivially. +template <class W> +class IdentityFactor { + public: + IdentityFactor(const W &w) {} + bool Done() const { return true; } + void Next() {} + pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused + void Reset() {} +}; + + +// Factor a StringWeight w as 'ab' where 'a' is a label. +template <typename L, StringType S = STRING_LEFT> +class StringFactor { + public: + StringFactor(const StringWeight<L, S> &w) + : weight_(w), done_(w.Size() <= 1) {} + + bool Done() const { return done_; } + + void Next() { done_ = true; } + + pair< StringWeight<L, S>, StringWeight<L, S> > Value() const { + StringWeightIterator<L, S> iter(weight_); + StringWeight<L, S> w1(iter.Value()); + StringWeight<L, S> w2; + for (iter.Next(); !iter.Done(); iter.Next()) + w2.PushBack(iter.Value()); + return make_pair(w1, w2); + } + + void Reset() { done_ = weight_.Size() <= 1; } + + private: + StringWeight<L, S> weight_; + bool done_; +}; + + +// Factor a GallicWeight using StringFactor. +template <class L, class W, StringType S = STRING_LEFT> +class GallicFactor { + public: + GallicFactor(const GallicWeight<L, W, S> &w) + : weight_(w), done_(w.Value1().Size() <= 1) {} + + bool Done() const { return done_; } + + void Next() { done_ = true; } + + pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const { + StringFactor<L, S> iter(weight_.Value1()); + GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2()); + GallicWeight<L, W, S> w2(iter.Value().second, W::One()); + return make_pair(w1, w2); + } + + void Reset() { done_ = weight_.Value1().Size() <= 1; } + + private: + GallicWeight<L, W, S> weight_; + bool done_; +}; + + +// Implementation class for FactorWeight +template <class A, class F> +class FactorWeightFstImpl + : public CacheImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<A> >::PushArc; + using CacheBaseImpl< CacheState<A> >::HasStart; + using CacheBaseImpl< CacheState<A> >::HasFinal; + using CacheBaseImpl< CacheState<A> >::HasArcs; + using CacheBaseImpl< CacheState<A> >::SetArcs; + using CacheBaseImpl< CacheState<A> >::SetFinal; + using CacheBaseImpl< CacheState<A> >::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef F FactorIterator; + + struct Element { + Element() {} + + Element(StateId s, Weight w) : state(s), weight(w) {} + + StateId state; // Input state Id + Weight weight; // Residual weight + }; + + FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts) + : CacheImpl<A>(opts), + fst_(fst.Copy()), + delta_(opts.delta), + mode_(opts.mode), + final_ilabel_(opts.final_ilabel), + final_olabel_(opts.final_olabel) { + SetType("factor_weight"); + uint64 props = fst.Properties(kFstProperties, false); + SetProperties(FactorWeightProperties(props), kCopyProperties); + + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + + if (mode_ == 0) + LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: " + << "factoring neither arc weights nor final weights."; + } + + FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl) + : CacheImpl<A>(impl), + fst_(impl.fst_->Copy(true)), + delta_(impl.delta_), + mode_(impl.mode_), + final_ilabel_(impl.final_ilabel_), + final_olabel_(impl.final_olabel_) { + SetType("factor_weight"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~FactorWeightFstImpl() { + delete fst_; + } + + StateId Start() { + if (!HasStart()) { + StateId s = fst_->Start(); + if (s == kNoStateId) + return kNoStateId; + StateId start = FindState(Element(fst_->Start(), Weight::One())); + SetStart(start); + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + const Element &e = elements_[s]; + // TODO: fix so cast is unnecessary + Weight w = e.state == kNoStateId + ? e.weight + : (Weight) Times(e.weight, fst_->Final(e.state)); + FactorIterator f(w); + if (!(mode_ & kFactorFinalWeights) || f.Done()) + SetFinal(s, w); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && fst_->Properties(kError, false)) + SetProperties(kError, kError); + return FstImpl<Arc>::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<A>::InitArcIterator(s, data); + } + + + // Find state corresponding to an element. Create new state + // if element not found. + StateId FindState(const Element &e) { + if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) { + while (unfactored_.size() <= e.state) + unfactored_.push_back(kNoStateId); + if (unfactored_[e.state] == kNoStateId) { + unfactored_[e.state] = elements_.size(); + elements_.push_back(e); + } + return unfactored_[e.state]; + } else { + typename ElementMap::iterator eit = element_map_.find(e); + if (eit != element_map_.end()) { + return (*eit).second; + } else { + StateId s = elements_.size(); + elements_.push_back(e); + element_map_.insert(pair<const Element, StateId>(e, s)); + return s; + } + } + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void Expand(StateId s) { + Element e = elements_[s]; + if (e.state != kNoStateId) { + for (ArcIterator< Fst<A> > ait(*fst_, e.state); + !ait.Done(); + ait.Next()) { + const A &arc = ait.Value(); + Weight w = Times(e.weight, arc.weight); + FactorIterator fit(w); + if (!(mode_ & kFactorArcWeights) || fit.Done()) { + StateId d = FindState(Element(arc.nextstate, Weight::One())); + PushArc(s, Arc(arc.ilabel, arc.olabel, w, d)); + } else { + for (; !fit.Done(); fit.Next()) { + const pair<Weight, Weight> &p = fit.Value(); + StateId d = FindState(Element(arc.nextstate, + p.second.Quantize(delta_))); + PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d)); + } + } + } + } + + if ((mode_ & kFactorFinalWeights) && + ((e.state == kNoStateId) || + (fst_->Final(e.state) != Weight::Zero()))) { + Weight w = e.state == kNoStateId + ? e.weight + : Times(e.weight, fst_->Final(e.state)); + for (FactorIterator fit(w); + !fit.Done(); + fit.Next()) { + const pair<Weight, Weight> &p = fit.Value(); + StateId d = FindState(Element(kNoStateId, + p.second.Quantize(delta_))); + PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d)); + } + } + SetArcs(s); + } + + private: + static const size_t kPrime = 7853; + + // Equality function for Elements, assume weights have been quantized. + class ElementEqual { + public: + bool operator()(const Element &x, const Element &y) const { + return x.state == y.state && x.weight == y.weight; + } + }; + + // Hash function for Elements to Fst states. + class ElementKey { + public: + size_t operator()(const Element &x) const { + return static_cast<size_t>(x.state * kPrime + x.weight.Hash()); + } + private: + }; + + typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap; + + const Fst<A> *fst_; + float delta_; + uint32 mode_; // factoring arc and/or final weights + Label final_ilabel_; // ilabel of arc created when factoring final w's + Label final_olabel_; // olabel of arc created when factoring final w's + vector<Element> elements_; // mapping Fst state to Elements + ElementMap element_map_; // mapping Elements to Fst state + // mapping between old/new 'StateId' for states that do not need to + // be factored when 'mode_' is '0' or 'kFactorFinalWeights' + vector<StateId> unfactored_; + + void operator=(const FactorWeightFstImpl<A, F> &); // disallow +}; + +template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime; + + +// FactorWeightFst takes as template parameter a FactorIterator as +// defined above. The result of weight factoring is a transducer +// equivalent to the input whose path weights have been factored +// according to the FactorIterator. States and transitions will be +// added as necessary. The algorithm is a generalization to arbitrary +// weights of the second step of the input epsilon-normalization +// algorithm due to Mohri, "Generic epsilon-removal and input +// epsilon-normalization algorithms for weighted transducers", +// International Journal of Computer Science 13(1): 129-143 (2002). +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A, class F> +class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > { + public: + friend class ArcIterator< FactorWeightFst<A, F> >; + friend class StateIterator< FactorWeightFst<A, F> >; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef FactorWeightFstImpl<A, F> Impl; + + FactorWeightFst(const Fst<A> &fst) + : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {} + + FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions<A> &opts) + : ImplToFst<Impl>(new Impl(fst, opts)) {} + + // See Fst<>::Copy() for doc. + FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy) + : ImplToFst<Impl>(fst, copy) {} + + // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc. + virtual FactorWeightFst<A, F> *Copy(bool copy = false) const { + return new FactorWeightFst<A, F>(*this, copy); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const FactorWeightFst<A, F> &fst); // Disallow +}; + + +// Specialization for FactorWeightFst. +template<class A, class F> +class StateIterator< FactorWeightFst<A, F> > + : public CacheStateIterator< FactorWeightFst<A, F> > { + public: + explicit StateIterator(const FactorWeightFst<A, F> &fst) + : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {} +}; + + +// Specialization for FactorWeightFst. +template <class A, class F> +class ArcIterator< FactorWeightFst<A, F> > + : public CacheArcIterator< FactorWeightFst<A, F> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const FactorWeightFst<A, F> &fst, StateId s) + : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +template <class A, class F> inline +void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const +{ + data->base = new StateIterator< FactorWeightFst<A, F> >(*this); +} + + +} // namespace fst + +#endif // FST_LIB_FACTOR_WEIGHT_H__ |