diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/relabel.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/relabel.h | 528 |
1 files changed, 528 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/relabel.h b/kaldi_io/src/tools/openfst/include/fst/relabel.h new file mode 100644 index 0000000..dc675b6 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/relabel.h @@ -0,0 +1,528 @@ +// relabel.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Johan Schalkwyk) +// +// \file +// Functions and classes to relabel an Fst (either on input or output) +// +#ifndef FST_LIB_RELABEL_H__ +#define FST_LIB_RELABEL_H__ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/test-properties.h> + + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + +namespace fst { + +// +// Relabels either the input labels or output labels. The old to +// new labels are specified using a vector of pair<Label,Label>. +// Any label associations not specified are assumed to be identity +// mapping. +// +// \param fst input fst, must be mutable +// \param ipairs vector of input label pairs indicating old to new mapping +// \param opairs vector of output label pairs indicating old to new mapping +// +template <class A> +void Relabel( + MutableFst<A> *fst, + const vector<pair<typename A::Label, typename A::Label> >& ipairs, + const vector<pair<typename A::Label, typename A::Label> >& opairs) { + typedef typename A::StateId StateId; + typedef typename A::Label Label; + + uint64 props = fst->Properties(kFstProperties, false); + + // construct label to label hash. + unordered_map<Label, Label> input_map; + for (size_t i = 0; i < ipairs.size(); ++i) { + input_map[ipairs[i].first] = ipairs[i].second; + } + + unordered_map<Label, Label> output_map; + for (size_t i = 0; i < opairs.size(); ++i) { + output_map[opairs[i].first] = opairs[i].second; + } + + for (StateIterator<MutableFst<A> > siter(*fst); + !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + for (MutableArcIterator<MutableFst<A> > aiter(fst, s); + !aiter.Done(); aiter.Next()) { + A arc = aiter.Value(); + + // relabel input + // only relabel if relabel pair defined + typename unordered_map<Label, Label>::iterator it = + input_map.find(arc.ilabel); + if (it != input_map.end()) { + if (it->second == kNoLabel) { + FSTERROR() << "Input symbol id " << arc.ilabel + << " missing from target vocabulary"; + fst->SetProperties(kError, kError); + return; + } + arc.ilabel = it->second; + } + + // relabel output + it = output_map.find(arc.olabel); + if (it != output_map.end()) { + if (it->second == kNoLabel) { + FSTERROR() << "Output symbol id " << arc.olabel + << " missing from target vocabulary"; + fst->SetProperties(kError, kError); + return; + } + arc.olabel = it->second; + } + + aiter.SetValue(arc); + } + } + + fst->SetProperties(RelabelProperties(props), kFstProperties); +} + +// +// Relabels either the input labels or output labels. The old to +// new labels mappings are specified using an input Symbol set. +// Any label associations not specified are assumed to be identity +// mapping. +// +// \param fst input fst, must be mutable +// \param new_isymbols symbol set indicating new mapping of input symbols +// \param new_osymbols symbol set indicating new mapping of output symbols +// +template<class A> +void Relabel(MutableFst<A> *fst, + const SymbolTable* new_isymbols, + const SymbolTable* new_osymbols) { + Relabel(fst, + fst->InputSymbols(), new_isymbols, true, + fst->OutputSymbols(), new_osymbols, true); +} + +template<class A> +void Relabel(MutableFst<A> *fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + bool attach_new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols, + bool attach_new_osymbols) { + typedef typename A::StateId StateId; + typedef typename A::Label Label; + + vector<pair<Label, Label> > ipairs; + if (old_isymbols && new_isymbols) { + for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); + syms_iter.Next()) { + string isymbol = syms_iter.Symbol(); + int isymbol_val = syms_iter.Value(); + int new_isymbol_val = new_isymbols->Find(isymbol); + ipairs.push_back(make_pair(isymbol_val, new_isymbol_val)); + } + if (attach_new_isymbols) + fst->SetInputSymbols(new_isymbols); + } + + vector<pair<Label, Label> > opairs; + if (old_osymbols && new_osymbols) { + for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); + syms_iter.Next()) { + string osymbol = syms_iter.Symbol(); + int osymbol_val = syms_iter.Value(); + int new_osymbol_val = new_osymbols->Find(osymbol); + opairs.push_back(make_pair(osymbol_val, new_osymbol_val)); + } + if (attach_new_osymbols) + fst->SetOutputSymbols(new_osymbols); + } + + // call relabel using vector of relabel pairs. + Relabel(fst, ipairs, opairs); +} + + +typedef CacheOptions RelabelFstOptions; + +template <class A> class RelabelFst; + +// +// \class RelabelFstImpl +// \brief Implementation for delayed relabeling +// +// Relabels an FST from one symbol set to another. Relabeling +// can either be on input or output space. RelabelFst implements +// a delayed version of the relabel. Arcs are relabeled on the fly +// and not cached. I.e each request is recomputed. +// +template<class A> +class RelabelFstImpl : public CacheImpl<A> { + friend class StateIterator< RelabelFst<A> >; + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::WriteHeader; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheImpl<A>::PushArc; + using CacheImpl<A>::HasArcs; + using CacheImpl<A>::HasFinal; + using CacheImpl<A>::HasStart; + using CacheImpl<A>::SetArcs; + using CacheImpl<A>::SetFinal; + using CacheImpl<A>::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + + RelabelFstImpl(const Fst<A>& fst, + const vector<pair<Label, Label> >& ipairs, + const vector<pair<Label, Label> >& opairs, + const RelabelFstOptions &opts) + : CacheImpl<A>(opts), fst_(fst.Copy()), + relabel_input_(false), relabel_output_(false) { + uint64 props = fst.Properties(kCopyProperties, false); + SetProperties(RelabelProperties(props)); + SetType("relabel"); + + // create input label map + if (ipairs.size() > 0) { + for (size_t i = 0; i < ipairs.size(); ++i) { + input_map_[ipairs[i].first] = ipairs[i].second; + } + relabel_input_ = true; + } + + // create output label map + if (opairs.size() > 0) { + for (size_t i = 0; i < opairs.size(); ++i) { + output_map_[opairs[i].first] = opairs[i].second; + } + relabel_output_ = true; + } + } + + RelabelFstImpl(const Fst<A>& fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols, + const RelabelFstOptions &opts) + : CacheImpl<A>(opts), fst_(fst.Copy()), + relabel_input_(false), relabel_output_(false) { + SetType("relabel"); + + uint64 props = fst.Properties(kCopyProperties, false); + SetProperties(RelabelProperties(props)); + SetInputSymbols(old_isymbols); + SetOutputSymbols(old_osymbols); + + if (old_isymbols && new_isymbols && + old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) { + for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); + syms_iter.Next()) { + input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol()); + } + SetInputSymbols(new_isymbols); + relabel_input_ = true; + } + + if (old_osymbols && new_osymbols && + old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) { + for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); + syms_iter.Next()) { + output_map_[syms_iter.Value()] = + new_osymbols->Find(syms_iter.Symbol()); + } + SetOutputSymbols(new_osymbols); + relabel_output_ = true; + } + } + + RelabelFstImpl(const RelabelFstImpl<A>& impl) + : CacheImpl<A>(impl), + fst_(impl.fst_->Copy(true)), + input_map_(impl.input_map_), + output_map_(impl.output_map_), + relabel_input_(impl.relabel_input_), + relabel_output_(impl.relabel_output_) { + SetType("relabel"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~RelabelFstImpl() { delete fst_; } + + StateId Start() { + if (!HasStart()) { + StateId s = fst_->Start(); + SetStart(s); + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + SetFinal(s, fst_->Final(s)); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl<A>::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && fst_->Properties(kError, false)) + SetProperties(kError, kError); + return FstImpl<Arc>::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData<A>* data) { + if (!HasArcs(s)) { + Expand(s); + } + CacheImpl<A>::InitArcIterator(s, data); + } + + void Expand(StateId s) { + for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) { + A arc = aiter.Value(); + + // relabel input + if (relabel_input_) { + typename unordered_map<Label, Label>::iterator it = + input_map_.find(arc.ilabel); + if (it != input_map_.end()) { arc.ilabel = it->second; } + } + + // relabel output + if (relabel_output_) { + typename unordered_map<Label, Label>::iterator it = + output_map_.find(arc.olabel); + if (it != output_map_.end()) { arc.olabel = it->second; } + } + + PushArc(s, arc); + } + SetArcs(s); + } + + + private: + const Fst<A> *fst_; + + unordered_map<Label, Label> input_map_; + unordered_map<Label, Label> output_map_; + bool relabel_input_; + bool relabel_output_; + + void operator=(const RelabelFstImpl<A> &); // disallow +}; + + +// +// \class RelabelFst +// \brief Delayed implementation of arc relabeling +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class RelabelFst : public ImplToFst< RelabelFstImpl<A> > { + public: + friend class ArcIterator< RelabelFst<A> >; + friend class StateIterator< RelabelFst<A> >; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef RelabelFstImpl<A> Impl; + + RelabelFst(const Fst<A>& fst, + const vector<pair<Label, Label> >& ipairs, + const vector<pair<Label, Label> >& opairs) + : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {} + + RelabelFst(const Fst<A>& fst, + const vector<pair<Label, Label> >& ipairs, + const vector<pair<Label, Label> >& opairs, + const RelabelFstOptions &opts) + : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {} + + RelabelFst(const Fst<A>& fst, + const SymbolTable* new_isymbols, + const SymbolTable* new_osymbols) + : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, + fst.OutputSymbols(), new_osymbols, + RelabelFstOptions())) {} + + RelabelFst(const Fst<A>& fst, + const SymbolTable* new_isymbols, + const SymbolTable* new_osymbols, + const RelabelFstOptions &opts) + : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, + fst.OutputSymbols(), new_osymbols, opts)) {} + + RelabelFst(const Fst<A>& fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols) + : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, + new_osymbols, RelabelFstOptions())) {} + + RelabelFst(const Fst<A>& fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols, + const RelabelFstOptions &opts) + : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, + new_osymbols, opts)) {} + + // See Fst<>::Copy() for doc. + RelabelFst(const RelabelFst<A> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc. + virtual RelabelFst<A> *Copy(bool safe = false) const { + return new RelabelFst<A>(*this, safe); + } + + virtual void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + return GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const RelabelFst<A> &fst); // disallow +}; + +// Specialization for RelabelFst. +template<class A> +class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const RelabelFst<A> &fst) + : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {} + + bool Done() const { return siter_.Done(); } + + StateId Value() const { return s_; } + + void Next() { + if (!siter_.Done()) { + ++s_; + siter_.Next(); + } + } + + void Reset() { + s_ = 0; + siter_.Reset(); + } + + private: + bool Done_() const { return Done(); } + StateId Value_() const { return Value(); } + void Next_() { Next(); } + void Reset_() { Reset(); } + + const RelabelFstImpl<A> *impl_; + StateIterator< Fst<A> > siter_; + StateId s_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + + +// Specialization for RelabelFst. +template <class A> +class ArcIterator< RelabelFst<A> > + : public CacheArcIterator< RelabelFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const RelabelFst<A> &fst, StateId s) + : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +template <class A> inline +void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const { + data->base = new StateIterator< RelabelFst<A> >(*this); +} + +// Useful alias when using StdArc. +typedef RelabelFst<StdArc> StdRelabelFst; + +} // namespace fst + +#endif // FST_LIB_RELABEL_H__ |