diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/relabel.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/relabel.h | 528 |
1 files changed, 0 insertions, 528 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/relabel.h b/kaldi_io/src/tools/openfst/include/fst/relabel.h deleted file mode 100644 index dc675b6..0000000 --- a/kaldi_io/src/tools/openfst/include/fst/relabel.h +++ /dev/null @@ -1,528 +0,0 @@ -// relabel.h - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Copyright 2005-2010 Google, Inc. -// Author: [email protected] (Johan Schalkwyk) -// -// \file -// Functions and classes to relabel an Fst (either on input or output) -// -#ifndef FST_LIB_RELABEL_H__ -#define FST_LIB_RELABEL_H__ - -#include <tr1/unordered_map> -using std::tr1::unordered_map; -using std::tr1::unordered_multimap; -#include <string> -#include <utility> -using std::pair; using std::make_pair; -#include <vector> -using std::vector; - -#include <fst/cache.h> -#include <fst/test-properties.h> - - -#include <tr1/unordered_map> -using std::tr1::unordered_map; -using std::tr1::unordered_multimap; - -namespace fst { - -// -// Relabels either the input labels or output labels. The old to -// new labels are specified using a vector of pair<Label,Label>. -// Any label associations not specified are assumed to be identity -// mapping. -// -// \param fst input fst, must be mutable -// \param ipairs vector of input label pairs indicating old to new mapping -// \param opairs vector of output label pairs indicating old to new mapping -// -template <class A> -void Relabel( - MutableFst<A> *fst, - const vector<pair<typename A::Label, typename A::Label> >& ipairs, - const vector<pair<typename A::Label, typename A::Label> >& opairs) { - typedef typename A::StateId StateId; - typedef typename A::Label Label; - - uint64 props = fst->Properties(kFstProperties, false); - - // construct label to label hash. - unordered_map<Label, Label> input_map; - for (size_t i = 0; i < ipairs.size(); ++i) { - input_map[ipairs[i].first] = ipairs[i].second; - } - - unordered_map<Label, Label> output_map; - for (size_t i = 0; i < opairs.size(); ++i) { - output_map[opairs[i].first] = opairs[i].second; - } - - for (StateIterator<MutableFst<A> > siter(*fst); - !siter.Done(); siter.Next()) { - StateId s = siter.Value(); - for (MutableArcIterator<MutableFst<A> > aiter(fst, s); - !aiter.Done(); aiter.Next()) { - A arc = aiter.Value(); - - // relabel input - // only relabel if relabel pair defined - typename unordered_map<Label, Label>::iterator it = - input_map.find(arc.ilabel); - if (it != input_map.end()) { - if (it->second == kNoLabel) { - FSTERROR() << "Input symbol id " << arc.ilabel - << " missing from target vocabulary"; - fst->SetProperties(kError, kError); - return; - } - arc.ilabel = it->second; - } - - // relabel output - it = output_map.find(arc.olabel); - if (it != output_map.end()) { - if (it->second == kNoLabel) { - FSTERROR() << "Output symbol id " << arc.olabel - << " missing from target vocabulary"; - fst->SetProperties(kError, kError); - return; - } - arc.olabel = it->second; - } - - aiter.SetValue(arc); - } - } - - fst->SetProperties(RelabelProperties(props), kFstProperties); -} - -// -// Relabels either the input labels or output labels. The old to -// new labels mappings are specified using an input Symbol set. -// Any label associations not specified are assumed to be identity -// mapping. -// -// \param fst input fst, must be mutable -// \param new_isymbols symbol set indicating new mapping of input symbols -// \param new_osymbols symbol set indicating new mapping of output symbols -// -template<class A> -void Relabel(MutableFst<A> *fst, - const SymbolTable* new_isymbols, - const SymbolTable* new_osymbols) { - Relabel(fst, - fst->InputSymbols(), new_isymbols, true, - fst->OutputSymbols(), new_osymbols, true); -} - -template<class A> -void Relabel(MutableFst<A> *fst, - const SymbolTable* old_isymbols, - const SymbolTable* new_isymbols, - bool attach_new_isymbols, - const SymbolTable* old_osymbols, - const SymbolTable* new_osymbols, - bool attach_new_osymbols) { - typedef typename A::StateId StateId; - typedef typename A::Label Label; - - vector<pair<Label, Label> > ipairs; - if (old_isymbols && new_isymbols) { - for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); - syms_iter.Next()) { - string isymbol = syms_iter.Symbol(); - int isymbol_val = syms_iter.Value(); - int new_isymbol_val = new_isymbols->Find(isymbol); - ipairs.push_back(make_pair(isymbol_val, new_isymbol_val)); - } - if (attach_new_isymbols) - fst->SetInputSymbols(new_isymbols); - } - - vector<pair<Label, Label> > opairs; - if (old_osymbols && new_osymbols) { - for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); - syms_iter.Next()) { - string osymbol = syms_iter.Symbol(); - int osymbol_val = syms_iter.Value(); - int new_osymbol_val = new_osymbols->Find(osymbol); - opairs.push_back(make_pair(osymbol_val, new_osymbol_val)); - } - if (attach_new_osymbols) - fst->SetOutputSymbols(new_osymbols); - } - - // call relabel using vector of relabel pairs. - Relabel(fst, ipairs, opairs); -} - - -typedef CacheOptions RelabelFstOptions; - -template <class A> class RelabelFst; - -// -// \class RelabelFstImpl -// \brief Implementation for delayed relabeling -// -// Relabels an FST from one symbol set to another. Relabeling -// can either be on input or output space. RelabelFst implements -// a delayed version of the relabel. Arcs are relabeled on the fly -// and not cached. I.e each request is recomputed. -// -template<class A> -class RelabelFstImpl : public CacheImpl<A> { - friend class StateIterator< RelabelFst<A> >; - public: - using FstImpl<A>::SetType; - using FstImpl<A>::SetProperties; - using FstImpl<A>::WriteHeader; - using FstImpl<A>::SetInputSymbols; - using FstImpl<A>::SetOutputSymbols; - - using CacheImpl<A>::PushArc; - using CacheImpl<A>::HasArcs; - using CacheImpl<A>::HasFinal; - using CacheImpl<A>::HasStart; - using CacheImpl<A>::SetArcs; - using CacheImpl<A>::SetFinal; - using CacheImpl<A>::SetStart; - - typedef A Arc; - typedef typename A::Label Label; - typedef typename A::Weight Weight; - typedef typename A::StateId StateId; - typedef CacheState<A> State; - - RelabelFstImpl(const Fst<A>& fst, - const vector<pair<Label, Label> >& ipairs, - const vector<pair<Label, Label> >& opairs, - const RelabelFstOptions &opts) - : CacheImpl<A>(opts), fst_(fst.Copy()), - relabel_input_(false), relabel_output_(false) { - uint64 props = fst.Properties(kCopyProperties, false); - SetProperties(RelabelProperties(props)); - SetType("relabel"); - - // create input label map - if (ipairs.size() > 0) { - for (size_t i = 0; i < ipairs.size(); ++i) { - input_map_[ipairs[i].first] = ipairs[i].second; - } - relabel_input_ = true; - } - - // create output label map - if (opairs.size() > 0) { - for (size_t i = 0; i < opairs.size(); ++i) { - output_map_[opairs[i].first] = opairs[i].second; - } - relabel_output_ = true; - } - } - - RelabelFstImpl(const Fst<A>& fst, - const SymbolTable* old_isymbols, - const SymbolTable* new_isymbols, - const SymbolTable* old_osymbols, - const SymbolTable* new_osymbols, - const RelabelFstOptions &opts) - : CacheImpl<A>(opts), fst_(fst.Copy()), - relabel_input_(false), relabel_output_(false) { - SetType("relabel"); - - uint64 props = fst.Properties(kCopyProperties, false); - SetProperties(RelabelProperties(props)); - SetInputSymbols(old_isymbols); - SetOutputSymbols(old_osymbols); - - if (old_isymbols && new_isymbols && - old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) { - for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); - syms_iter.Next()) { - input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol()); - } - SetInputSymbols(new_isymbols); - relabel_input_ = true; - } - - if (old_osymbols && new_osymbols && - old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) { - for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); - syms_iter.Next()) { - output_map_[syms_iter.Value()] = - new_osymbols->Find(syms_iter.Symbol()); - } - SetOutputSymbols(new_osymbols); - relabel_output_ = true; - } - } - - RelabelFstImpl(const RelabelFstImpl<A>& impl) - : CacheImpl<A>(impl), - fst_(impl.fst_->Copy(true)), - input_map_(impl.input_map_), - output_map_(impl.output_map_), - relabel_input_(impl.relabel_input_), - relabel_output_(impl.relabel_output_) { - SetType("relabel"); - SetProperties(impl.Properties(), kCopyProperties); - SetInputSymbols(impl.InputSymbols()); - SetOutputSymbols(impl.OutputSymbols()); - } - - ~RelabelFstImpl() { delete fst_; } - - StateId Start() { - if (!HasStart()) { - StateId s = fst_->Start(); - SetStart(s); - } - return CacheImpl<A>::Start(); - } - - Weight Final(StateId s) { - if (!HasFinal(s)) { - SetFinal(s, fst_->Final(s)); - } - return CacheImpl<A>::Final(s); - } - - size_t NumArcs(StateId s) { - if (!HasArcs(s)) { - Expand(s); - } - return CacheImpl<A>::NumArcs(s); - } - - size_t NumInputEpsilons(StateId s) { - if (!HasArcs(s)) { - Expand(s); - } - return CacheImpl<A>::NumInputEpsilons(s); - } - - size_t NumOutputEpsilons(StateId s) { - if (!HasArcs(s)) { - Expand(s); - } - return CacheImpl<A>::NumOutputEpsilons(s); - } - - uint64 Properties() const { return Properties(kFstProperties); } - - // Set error if found; return FST impl properties. - uint64 Properties(uint64 mask) const { - if ((mask & kError) && fst_->Properties(kError, false)) - SetProperties(kError, kError); - return FstImpl<Arc>::Properties(mask); - } - - void InitArcIterator(StateId s, ArcIteratorData<A>* data) { - if (!HasArcs(s)) { - Expand(s); - } - CacheImpl<A>::InitArcIterator(s, data); - } - - void Expand(StateId s) { - for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) { - A arc = aiter.Value(); - - // relabel input - if (relabel_input_) { - typename unordered_map<Label, Label>::iterator it = - input_map_.find(arc.ilabel); - if (it != input_map_.end()) { arc.ilabel = it->second; } - } - - // relabel output - if (relabel_output_) { - typename unordered_map<Label, Label>::iterator it = - output_map_.find(arc.olabel); - if (it != output_map_.end()) { arc.olabel = it->second; } - } - - PushArc(s, arc); - } - SetArcs(s); - } - - - private: - const Fst<A> *fst_; - - unordered_map<Label, Label> input_map_; - unordered_map<Label, Label> output_map_; - bool relabel_input_; - bool relabel_output_; - - void operator=(const RelabelFstImpl<A> &); // disallow -}; - - -// -// \class RelabelFst -// \brief Delayed implementation of arc relabeling -// -// This class attaches interface to implementation and handles -// reference counting, delegating most methods to ImplToFst. -template <class A> -class RelabelFst : public ImplToFst< RelabelFstImpl<A> > { - public: - friend class ArcIterator< RelabelFst<A> >; - friend class StateIterator< RelabelFst<A> >; - - typedef A Arc; - typedef typename A::Label Label; - typedef typename A::Weight Weight; - typedef typename A::StateId StateId; - typedef CacheState<A> State; - typedef RelabelFstImpl<A> Impl; - - RelabelFst(const Fst<A>& fst, - const vector<pair<Label, Label> >& ipairs, - const vector<pair<Label, Label> >& opairs) - : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {} - - RelabelFst(const Fst<A>& fst, - const vector<pair<Label, Label> >& ipairs, - const vector<pair<Label, Label> >& opairs, - const RelabelFstOptions &opts) - : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {} - - RelabelFst(const Fst<A>& fst, - const SymbolTable* new_isymbols, - const SymbolTable* new_osymbols) - : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, - fst.OutputSymbols(), new_osymbols, - RelabelFstOptions())) {} - - RelabelFst(const Fst<A>& fst, - const SymbolTable* new_isymbols, - const SymbolTable* new_osymbols, - const RelabelFstOptions &opts) - : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, - fst.OutputSymbols(), new_osymbols, opts)) {} - - RelabelFst(const Fst<A>& fst, - const SymbolTable* old_isymbols, - const SymbolTable* new_isymbols, - const SymbolTable* old_osymbols, - const SymbolTable* new_osymbols) - : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, - new_osymbols, RelabelFstOptions())) {} - - RelabelFst(const Fst<A>& fst, - const SymbolTable* old_isymbols, - const SymbolTable* new_isymbols, - const SymbolTable* old_osymbols, - const SymbolTable* new_osymbols, - const RelabelFstOptions &opts) - : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, - new_osymbols, opts)) {} - - // See Fst<>::Copy() for doc. - RelabelFst(const RelabelFst<A> &fst, bool safe = false) - : ImplToFst<Impl>(fst, safe) {} - - // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc. - virtual RelabelFst<A> *Copy(bool safe = false) const { - return new RelabelFst<A>(*this, safe); - } - - virtual void InitStateIterator(StateIteratorData<A> *data) const; - - virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { - return GetImpl()->InitArcIterator(s, data); - } - - private: - // Makes visible to friends. - Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } - - void operator=(const RelabelFst<A> &fst); // disallow -}; - -// Specialization for RelabelFst. -template<class A> -class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> { - public: - typedef typename A::StateId StateId; - - explicit StateIterator(const RelabelFst<A> &fst) - : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {} - - bool Done() const { return siter_.Done(); } - - StateId Value() const { return s_; } - - void Next() { - if (!siter_.Done()) { - ++s_; - siter_.Next(); - } - } - - void Reset() { - s_ = 0; - siter_.Reset(); - } - - private: - bool Done_() const { return Done(); } - StateId Value_() const { return Value(); } - void Next_() { Next(); } - void Reset_() { Reset(); } - - const RelabelFstImpl<A> *impl_; - StateIterator< Fst<A> > siter_; - StateId s_; - - DISALLOW_COPY_AND_ASSIGN(StateIterator); -}; - - -// Specialization for RelabelFst. -template <class A> -class ArcIterator< RelabelFst<A> > - : public CacheArcIterator< RelabelFst<A> > { - public: - typedef typename A::StateId StateId; - - ArcIterator(const RelabelFst<A> &fst, StateId s) - : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) { - if (!fst.GetImpl()->HasArcs(s)) - fst.GetImpl()->Expand(s); - } - - private: - DISALLOW_COPY_AND_ASSIGN(ArcIterator); -}; - -template <class A> inline -void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const { - data->base = new StateIterator< RelabelFst<A> >(*this); -} - -// Useful alias when using StdArc. -typedef RelabelFst<StdArc> StdRelabelFst; - -} // namespace fst - -#endif // FST_LIB_RELABEL_H__ |