From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- kaldi_io/src/tools/openfst/include/fst/relabel.h | 528 +++++++++++++++++++++++ 1 file changed, 528 insertions(+) create mode 100644 kaldi_io/src/tools/openfst/include/fst/relabel.h (limited to 'kaldi_io/src/tools/openfst/include/fst/relabel.h') diff --git a/kaldi_io/src/tools/openfst/include/fst/relabel.h b/kaldi_io/src/tools/openfst/include/fst/relabel.h new file mode 100644 index 0000000..dc675b6 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/relabel.h @@ -0,0 +1,528 @@ +// relabel.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: johans@google.com (Johan Schalkwyk) +// +// \file +// Functions and classes to relabel an Fst (either on input or output) +// +#ifndef FST_LIB_RELABEL_H__ +#define FST_LIB_RELABEL_H__ + +#include +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include +#include +using std::pair; using std::make_pair; +#include +using std::vector; + +#include +#include + + +#include +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + +namespace fst { + +// +// Relabels either the input labels or output labels. The old to +// new labels are specified using a vector of pair. +// Any label associations not specified are assumed to be identity +// mapping. +// +// \param fst input fst, must be mutable +// \param ipairs vector of input label pairs indicating old to new mapping +// \param opairs vector of output label pairs indicating old to new mapping +// +template +void Relabel( + MutableFst *fst, + const vector >& ipairs, + const vector >& opairs) { + typedef typename A::StateId StateId; + typedef typename A::Label Label; + + uint64 props = fst->Properties(kFstProperties, false); + + // construct label to label hash. + unordered_map input_map; + for (size_t i = 0; i < ipairs.size(); ++i) { + input_map[ipairs[i].first] = ipairs[i].second; + } + + unordered_map output_map; + for (size_t i = 0; i < opairs.size(); ++i) { + output_map[opairs[i].first] = opairs[i].second; + } + + for (StateIterator > siter(*fst); + !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + for (MutableArcIterator > aiter(fst, s); + !aiter.Done(); aiter.Next()) { + A arc = aiter.Value(); + + // relabel input + // only relabel if relabel pair defined + typename unordered_map::iterator it = + input_map.find(arc.ilabel); + if (it != input_map.end()) { + if (it->second == kNoLabel) { + FSTERROR() << "Input symbol id " << arc.ilabel + << " missing from target vocabulary"; + fst->SetProperties(kError, kError); + return; + } + arc.ilabel = it->second; + } + + // relabel output + it = output_map.find(arc.olabel); + if (it != output_map.end()) { + if (it->second == kNoLabel) { + FSTERROR() << "Output symbol id " << arc.olabel + << " missing from target vocabulary"; + fst->SetProperties(kError, kError); + return; + } + arc.olabel = it->second; + } + + aiter.SetValue(arc); + } + } + + fst->SetProperties(RelabelProperties(props), kFstProperties); +} + +// +// Relabels either the input labels or output labels. The old to +// new labels mappings are specified using an input Symbol set. +// Any label associations not specified are assumed to be identity +// mapping. +// +// \param fst input fst, must be mutable +// \param new_isymbols symbol set indicating new mapping of input symbols +// \param new_osymbols symbol set indicating new mapping of output symbols +// +template +void Relabel(MutableFst *fst, + const SymbolTable* new_isymbols, + const SymbolTable* new_osymbols) { + Relabel(fst, + fst->InputSymbols(), new_isymbols, true, + fst->OutputSymbols(), new_osymbols, true); +} + +template +void Relabel(MutableFst *fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + bool attach_new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols, + bool attach_new_osymbols) { + typedef typename A::StateId StateId; + typedef typename A::Label Label; + + vector > ipairs; + if (old_isymbols && new_isymbols) { + for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); + syms_iter.Next()) { + string isymbol = syms_iter.Symbol(); + int isymbol_val = syms_iter.Value(); + int new_isymbol_val = new_isymbols->Find(isymbol); + ipairs.push_back(make_pair(isymbol_val, new_isymbol_val)); + } + if (attach_new_isymbols) + fst->SetInputSymbols(new_isymbols); + } + + vector > opairs; + if (old_osymbols && new_osymbols) { + for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); + syms_iter.Next()) { + string osymbol = syms_iter.Symbol(); + int osymbol_val = syms_iter.Value(); + int new_osymbol_val = new_osymbols->Find(osymbol); + opairs.push_back(make_pair(osymbol_val, new_osymbol_val)); + } + if (attach_new_osymbols) + fst->SetOutputSymbols(new_osymbols); + } + + // call relabel using vector of relabel pairs. + Relabel(fst, ipairs, opairs); +} + + +typedef CacheOptions RelabelFstOptions; + +template class RelabelFst; + +// +// \class RelabelFstImpl +// \brief Implementation for delayed relabeling +// +// Relabels an FST from one symbol set to another. Relabeling +// can either be on input or output space. RelabelFst implements +// a delayed version of the relabel. Arcs are relabeled on the fly +// and not cached. I.e each request is recomputed. +// +template +class RelabelFstImpl : public CacheImpl { + friend class StateIterator< RelabelFst >; + public: + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::WriteHeader; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheImpl::PushArc; + using CacheImpl::HasArcs; + using CacheImpl::HasFinal; + using CacheImpl::HasStart; + using CacheImpl::SetArcs; + using CacheImpl::SetFinal; + using CacheImpl::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState State; + + RelabelFstImpl(const Fst& fst, + const vector >& ipairs, + const vector >& opairs, + const RelabelFstOptions &opts) + : CacheImpl(opts), fst_(fst.Copy()), + relabel_input_(false), relabel_output_(false) { + uint64 props = fst.Properties(kCopyProperties, false); + SetProperties(RelabelProperties(props)); + SetType("relabel"); + + // create input label map + if (ipairs.size() > 0) { + for (size_t i = 0; i < ipairs.size(); ++i) { + input_map_[ipairs[i].first] = ipairs[i].second; + } + relabel_input_ = true; + } + + // create output label map + if (opairs.size() > 0) { + for (size_t i = 0; i < opairs.size(); ++i) { + output_map_[opairs[i].first] = opairs[i].second; + } + relabel_output_ = true; + } + } + + RelabelFstImpl(const Fst& fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols, + const RelabelFstOptions &opts) + : CacheImpl(opts), fst_(fst.Copy()), + relabel_input_(false), relabel_output_(false) { + SetType("relabel"); + + uint64 props = fst.Properties(kCopyProperties, false); + SetProperties(RelabelProperties(props)); + SetInputSymbols(old_isymbols); + SetOutputSymbols(old_osymbols); + + if (old_isymbols && new_isymbols && + old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) { + for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); + syms_iter.Next()) { + input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol()); + } + SetInputSymbols(new_isymbols); + relabel_input_ = true; + } + + if (old_osymbols && new_osymbols && + old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) { + for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); + syms_iter.Next()) { + output_map_[syms_iter.Value()] = + new_osymbols->Find(syms_iter.Symbol()); + } + SetOutputSymbols(new_osymbols); + relabel_output_ = true; + } + } + + RelabelFstImpl(const RelabelFstImpl& impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + input_map_(impl.input_map_), + output_map_(impl.output_map_), + relabel_input_(impl.relabel_input_), + relabel_output_(impl.relabel_output_) { + SetType("relabel"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~RelabelFstImpl() { delete fst_; } + + StateId Start() { + if (!HasStart()) { + StateId s = fst_->Start(); + SetStart(s); + } + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + SetFinal(s, fst_->Final(s)); + } + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && fst_->Properties(kError, false)) + SetProperties(kError, kError); + return FstImpl::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData* data) { + if (!HasArcs(s)) { + Expand(s); + } + CacheImpl::InitArcIterator(s, data); + } + + void Expand(StateId s) { + for (ArcIterator > aiter(*fst_, s); !aiter.Done(); aiter.Next()) { + A arc = aiter.Value(); + + // relabel input + if (relabel_input_) { + typename unordered_map::iterator it = + input_map_.find(arc.ilabel); + if (it != input_map_.end()) { arc.ilabel = it->second; } + } + + // relabel output + if (relabel_output_) { + typename unordered_map::iterator it = + output_map_.find(arc.olabel); + if (it != output_map_.end()) { arc.olabel = it->second; } + } + + PushArc(s, arc); + } + SetArcs(s); + } + + + private: + const Fst *fst_; + + unordered_map input_map_; + unordered_map output_map_; + bool relabel_input_; + bool relabel_output_; + + void operator=(const RelabelFstImpl &); // disallow +}; + + +// +// \class RelabelFst +// \brief Delayed implementation of arc relabeling +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template +class RelabelFst : public ImplToFst< RelabelFstImpl > { + public: + friend class ArcIterator< RelabelFst >; + friend class StateIterator< RelabelFst >; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState State; + typedef RelabelFstImpl Impl; + + RelabelFst(const Fst& fst, + const vector >& ipairs, + const vector >& opairs) + : ImplToFst(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {} + + RelabelFst(const Fst& fst, + const vector >& ipairs, + const vector >& opairs, + const RelabelFstOptions &opts) + : ImplToFst(new Impl(fst, ipairs, opairs, opts)) {} + + RelabelFst(const Fst& fst, + const SymbolTable* new_isymbols, + const SymbolTable* new_osymbols) + : ImplToFst(new Impl(fst, fst.InputSymbols(), new_isymbols, + fst.OutputSymbols(), new_osymbols, + RelabelFstOptions())) {} + + RelabelFst(const Fst& fst, + const SymbolTable* new_isymbols, + const SymbolTable* new_osymbols, + const RelabelFstOptions &opts) + : ImplToFst(new Impl(fst, fst.InputSymbols(), new_isymbols, + fst.OutputSymbols(), new_osymbols, opts)) {} + + RelabelFst(const Fst& fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols) + : ImplToFst(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, + new_osymbols, RelabelFstOptions())) {} + + RelabelFst(const Fst& fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols, + const RelabelFstOptions &opts) + : ImplToFst(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, + new_osymbols, opts)) {} + + // See Fst<>::Copy() for doc. + RelabelFst(const RelabelFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc. + virtual RelabelFst *Copy(bool safe = false) const { + return new RelabelFst(*this, safe); + } + + virtual void InitStateIterator(StateIteratorData *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData *data) const { + return GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst::GetImpl(); } + + void operator=(const RelabelFst &fst); // disallow +}; + +// Specialization for RelabelFst. +template +class StateIterator< RelabelFst > : public StateIteratorBase { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const RelabelFst &fst) + : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {} + + bool Done() const { return siter_.Done(); } + + StateId Value() const { return s_; } + + void Next() { + if (!siter_.Done()) { + ++s_; + siter_.Next(); + } + } + + void Reset() { + s_ = 0; + siter_.Reset(); + } + + private: + bool Done_() const { return Done(); } + StateId Value_() const { return Value(); } + void Next_() { Next(); } + void Reset_() { Reset(); } + + const RelabelFstImpl *impl_; + StateIterator< Fst > siter_; + StateId s_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + + +// Specialization for RelabelFst. +template +class ArcIterator< RelabelFst > + : public CacheArcIterator< RelabelFst > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const RelabelFst &fst, StateId s) + : CacheArcIterator< RelabelFst >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +template inline +void RelabelFst::InitStateIterator(StateIteratorData *data) const { + data->base = new StateIterator< RelabelFst >(*this); +} + +// Useful alias when using StdArc. +typedef RelabelFst StdRelabelFst; + +} // namespace fst + +#endif // FST_LIB_RELABEL_H__ -- cgit v1.2.3