From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- kaldi_io/src/tools/openfst/include/fst/encode.h | 599 ++++++++++++++++++++++++ 1 file changed, 599 insertions(+) create mode 100644 kaldi_io/src/tools/openfst/include/fst/encode.h (limited to 'kaldi_io/src/tools/openfst/include/fst/encode.h') diff --git a/kaldi_io/src/tools/openfst/include/fst/encode.h b/kaldi_io/src/tools/openfst/include/fst/encode.h new file mode 100644 index 0000000..08b84cb --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/encode.h @@ -0,0 +1,599 @@ +// encode.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: johans@google.com (Johan Schalkwyk) +// +// \file +// Class to encode and decode an Fst.
+ +#ifndef FST_LIB_ENCODE_H__ +#define FST_LIB_ENCODE_H__ + +#include +#include +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include +#include +using std::vector; + +#include +#include + + +namespace fst { + +static const uint32 kEncodeLabels = 0x0001; +static const uint32 kEncodeWeights = 0x0002; +static const uint32 kEncodeFlags = 0x0003; // All non-internal flags + +static const uint32 kEncodeHasISymbols = 0x0004; // For internal use +static const uint32 kEncodeHasOSymbols = 0x0008; // For internal use + +enum EncodeType { ENCODE = 1, DECODE = 2 }; + +// Identifies stream data as an encode table (and its endianity) +static const int32 kEncodeMagicNumber = 2129983209; + + +// The following class encapsulates implementation details for the +// encoding and decoding of label/weight tuples used for encoding +// and decoding of Fsts. The EncodeTable is bidirectional. I.E it +// stores both the Tuple of encode labels and weights to a unique +// label, and the reverse. +template class EncodeTable { + public: + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + // Encoded data consists of arc input/output labels and arc weight + struct Tuple { + Tuple() {} + Tuple(Label ilabel_, Label olabel_, Weight weight_) + : ilabel(ilabel_), olabel(olabel_), weight(weight_) {} + Tuple(const Tuple& tuple) + : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {} + + Label ilabel; + Label olabel; + Weight weight; + }; + + // Comparison object for hashing EncodeTable Tuple(s). + class TupleEqual { + public: + bool operator()(const Tuple* x, const Tuple* y) const { + return (x->ilabel == y->ilabel && + x->olabel == y->olabel && + x->weight == y->weight); + } + }; + + // Hash function for EncodeTabe Tuples. Based on the encode flags + // we either hash the labels, weights or combination of them. 
+ class TupleKey { + public: + TupleKey() + : encode_flags_(kEncodeLabels | kEncodeWeights) {} + + TupleKey(const TupleKey& key) + : encode_flags_(key.encode_flags_) {} + + explicit TupleKey(uint32 encode_flags) + : encode_flags_(encode_flags) {} + + size_t operator()(const Tuple* x) const { + size_t hash = x->ilabel; + const int lshift = 5; + const int rshift = CHAR_BIT * sizeof(size_t) - 5; + if (encode_flags_ & kEncodeLabels) + hash = hash << lshift ^ hash >> rshift ^ x->olabel; + if (encode_flags_ & kEncodeWeights) + hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash(); + return hash; + } + + private: + int32 encode_flags_; + }; + + typedef unordered_map EncodeHash; + + explicit EncodeTable(uint32 encode_flags) + : flags_(encode_flags), + encode_hash_(1024, TupleKey(encode_flags)), + isymbols_(0), osymbols_(0) {} + + ~EncodeTable() { + for (size_t i = 0; i < encode_tuples_.size(); ++i) { + delete encode_tuples_[i]; + } + delete isymbols_; + delete osymbols_; + } + + // Given an arc encode either input/ouptut labels or input/costs or both + Label Encode(const A &arc) { + const Tuple tuple(arc.ilabel, + flags_ & kEncodeLabels ? arc.olabel : 0, + flags_ & kEncodeWeights ? arc.weight : Weight::One()); + typename EncodeHash::const_iterator it = encode_hash_.find(&tuple); + if (it == encode_hash_.end()) { + encode_tuples_.push_back(new Tuple(tuple)); + encode_hash_[encode_tuples_.back()] = encode_tuples_.size(); + return encode_tuples_.size(); + } else { + return it->second; + } + } + + // Given an arc, look up its encoded label. Returns kNoLabel if not found. + Label GetLabel(const A &arc) const { + const Tuple tuple(arc.ilabel, + flags_ & kEncodeLabels ? arc.olabel : 0, + flags_ & kEncodeWeights ? 
arc.weight : Weight::One()); + typename EncodeHash::const_iterator it = encode_hash_.find(&tuple); + if (it == encode_hash_.end()) { + return kNoLabel; + } else { + return it->second; + } + } + + // Given an encode arc Label decode back to input/output labels and costs + const Tuple* Decode(Label key) const { + if (key < 1 || key > encode_tuples_.size()) { + LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key; + return 0; + } + return encode_tuples_[key - 1]; + } + + size_t Size() const { return encode_tuples_.size(); } + + bool Write(ostream &strm, const string &source) const; + + static EncodeTable *Read(istream &strm, const string &source); + + const uint32 flags() const { return flags_ & kEncodeFlags; } + + int RefCount() const { return ref_count_.count(); } + int IncrRefCount() { return ref_count_.Incr(); } + int DecrRefCount() { return ref_count_.Decr(); } + + + SymbolTable *InputSymbols() const { return isymbols_; } + + SymbolTable *OutputSymbols() const { return osymbols_; } + + void SetInputSymbols(const SymbolTable* syms) { + if (isymbols_) delete isymbols_; + if (syms) { + isymbols_ = syms->Copy(); + flags_ |= kEncodeHasISymbols; + } else { + isymbols_ = 0; + flags_ &= ~kEncodeHasISymbols; + } + } + + void SetOutputSymbols(const SymbolTable* syms) { + if (osymbols_) delete osymbols_; + if (syms) { + osymbols_ = syms->Copy(); + flags_ |= kEncodeHasOSymbols; + } else { + osymbols_ = 0; + flags_ &= ~kEncodeHasOSymbols; + } + } + + private: + uint32 flags_; + vector encode_tuples_; + EncodeHash encode_hash_; + RefCounter ref_count_; + SymbolTable *isymbols_; // Pre-encoded ilabel symbol table + SymbolTable *osymbols_; // Pre-encoded olabel symbol table + + DISALLOW_COPY_AND_ASSIGN(EncodeTable); +}; + +template inline +bool EncodeTable::Write(ostream &strm, const string &source) const { + WriteType(strm, kEncodeMagicNumber); + WriteType(strm, flags_); + int64 size = encode_tuples_.size(); + WriteType(strm, size); + for (size_t i = 0; i < size; 
++i) { + const Tuple* tuple = encode_tuples_[i]; + WriteType(strm, tuple->ilabel); + WriteType(strm, tuple->olabel); + tuple->weight.Write(strm); + } + + if (flags_ & kEncodeHasISymbols) + isymbols_->Write(strm); + + if (flags_ & kEncodeHasOSymbols) + osymbols_->Write(strm); + + strm.flush(); + if (!strm) { + LOG(ERROR) << "EncodeTable::Write: write failed: " << source; + return false; + } + return true; +} + +template inline +EncodeTable *EncodeTable::Read(istream &strm, const string &source) { + int32 magic_number = 0; + ReadType(strm, &magic_number); + if (magic_number != kEncodeMagicNumber) { + LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source; + return 0; + } + uint32 flags; + ReadType(strm, &flags); + EncodeTable *table = new EncodeTable(flags); + + int64 size; + ReadType(strm, &size); + if (!strm) { + LOG(ERROR) << "EncodeTable::Read: read failed: " << source; + return 0; + } + + for (size_t i = 0; i < size; ++i) { + Tuple* tuple = new Tuple(); + ReadType(strm, &tuple->ilabel); + ReadType(strm, &tuple->olabel); + tuple->weight.Read(strm); + if (!strm) { + LOG(ERROR) << "EncodeTable::Read: read failed: " << source; + return 0; + } + table->encode_tuples_.push_back(tuple); + table->encode_hash_[table->encode_tuples_.back()] = + table->encode_tuples_.size(); + } + + if (flags & kEncodeHasISymbols) + table->isymbols_ = SymbolTable::Read(strm, source); + + if (flags & kEncodeHasOSymbols) + table->osymbols_ = SymbolTable::Read(strm, source); + + return table; +} + + +// A mapper to encode/decode weighted transducers. Encoding of an +// Fst is useful for performing classical determinization or minimization +// on a weighted transducer by treating it as an unweighted acceptor over +// encoded labels. +// +// The Encode mapper stores the encoding in a local hash table (EncodeTable) +// This table is shared (and reference counted) between the encoder and +// decoder. A decoder has read only access to the EncodeTable. 
+// +// The EncodeMapper allows on the fly encoding of the machine. As the +// EncodeTable is generated the same table may by used to decode the machine +// on the fly. For example in the following sequence of operations +// +// Encode -> Determinize -> Decode +// +// we will use the encoding table generated during the encode step in the +// decode, even though the encoding is not complete. +// +template class EncodeMapper { + typedef typename A::Weight Weight; + typedef typename A::Label Label; + public: + EncodeMapper(uint32 flags, EncodeType type) + : flags_(flags), + type_(type), + table_(new EncodeTable(flags)), + error_(false) {} + + EncodeMapper(const EncodeMapper& mapper) + : flags_(mapper.flags_), + type_(mapper.type_), + table_(mapper.table_), + error_(false) { + table_->IncrRefCount(); + } + + // Copy constructor but setting the type, typically to DECODE + EncodeMapper(const EncodeMapper& mapper, EncodeType type) + : flags_(mapper.flags_), + type_(type), + table_(mapper.table_), + error_(mapper.error_) { + table_->IncrRefCount(); + } + + ~EncodeMapper() { + if (!table_->DecrRefCount()) delete table_; + } + + A operator()(const A &arc); + + MapFinalAction FinalAction() const { + return (type_ == ENCODE && (flags_ & kEncodeWeights)) ? + MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL; + } + + MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;} + + uint64 Properties(uint64 inprops) { + uint64 outprops = inprops; + if (error_) outprops |= kError; + + uint64 mask = kFstProperties; + if (flags_ & kEncodeLabels) + mask &= kILabelInvariantProperties & kOLabelInvariantProperties; + if (flags_ & kEncodeWeights) + mask &= kILabelInvariantProperties & kWeightInvariantProperties & + (type_ == ENCODE ? 
kAddSuperFinalProperties : + kRmSuperFinalProperties); + + return outprops & mask; + } + + const uint32 flags() const { return flags_; } + const EncodeType type() const { return type_; } + const EncodeTable &table() const { return *table_; } + + bool Write(ostream &strm, const string& source) { + return table_->Write(strm, source); + } + + bool Write(const string& filename) { + ofstream strm(filename.c_str(), ofstream::out | ofstream::binary); + if (!strm) { + LOG(ERROR) << "EncodeMap: Can't open file: " << filename; + return false; + } + return Write(strm, filename); + } + + static EncodeMapper *Read(istream &strm, + const string& source, + EncodeType type = ENCODE) { + EncodeTable *table = EncodeTable::Read(strm, source); + return table ? new EncodeMapper(table->flags(), type, table) : 0; + } + + static EncodeMapper *Read(const string& filename, + EncodeType type = ENCODE) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "EncodeMap: Can't open file: " << filename; + return NULL; + } + return Read(strm, filename, type); + } + + SymbolTable *InputSymbols() const { return table_->InputSymbols(); } + + SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); } + + void SetInputSymbols(const SymbolTable* syms) { + table_->SetInputSymbols(syms); + } + + void SetOutputSymbols(const SymbolTable* syms) { + table_->SetOutputSymbols(syms); + } + + private: + uint32 flags_; + EncodeType type_; + EncodeTable* table_; + bool error_; + + explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable *table) + : flags_(flags), type_(type), table_(table) {} + void operator=(const EncodeMapper &); // Disallow. 
+}; + +template inline +A EncodeMapper::operator()(const A &arc) { + if (type_ == ENCODE) { // labels and/or weights to single label + if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) || + (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) && + arc.weight == Weight::Zero())) { + return arc; + } else { + Label label = table_->Encode(arc); + return A(label, + flags_ & kEncodeLabels ? label : arc.olabel, + flags_ & kEncodeWeights ? Weight::One() : arc.weight, + arc.nextstate); + } + } else { // type_ == DECODE + if (arc.nextstate == kNoStateId) { + return arc; + } else { + if (arc.ilabel == 0) return arc; + if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) { + FSTERROR() << "EncodeMapper: Label-encoded arc has different " + "input and output labels"; + error_ = true; + } + if (flags_ & kEncodeWeights && arc.weight != Weight::One()) { + FSTERROR() << + "EncodeMapper: Weight-encoded arc has non-trivial weight"; + error_ = true; + } + const typename EncodeTable::Tuple* tuple = table_->Decode(arc.ilabel); + if (!tuple) { + FSTERROR() << "EncodeMapper: decode failed"; + error_ = true; + return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate); + } else { + return A(tuple->ilabel, + flags_ & kEncodeLabels ? tuple->olabel : arc.olabel, + flags_ & kEncodeWeights ? 
tuple->weight : arc.weight, + arc.nextstate); + } + } + } +} + + +// Complexity: O(nstates + narcs) +template inline +void Encode(MutableFst *fst, EncodeMapper* mapper) { + mapper->SetInputSymbols(fst->InputSymbols()); + mapper->SetOutputSymbols(fst->OutputSymbols()); + ArcMap(fst, mapper); +} + +template inline +void Decode(MutableFst* fst, const EncodeMapper& mapper) { + ArcMap(fst, EncodeMapper(mapper, DECODE)); + RmFinalEpsilon(fst); + fst->SetInputSymbols(mapper.InputSymbols()); + fst->SetOutputSymbols(mapper.OutputSymbols()); +} + + +// On the fly label and/or weight encoding of input Fst +// +// Complexity: +// - Constructor: O(1) +// - Traversal: O(nstates_visited + narcs_visited), assuming constant +// time to visit an input state or arc. +template +class EncodeFst : public ArcMapFst > { + public: + typedef A Arc; + typedef EncodeMapper C; + typedef ArcMapFstImpl< A, A, EncodeMapper > Impl; + using ImplToFst::GetImpl; + + EncodeFst(const Fst &fst, EncodeMapper* encoder) + : ArcMapFst(fst, encoder, ArcMapFstOptions()) { + encoder->SetInputSymbols(fst.InputSymbols()); + encoder->SetOutputSymbols(fst.OutputSymbols()); + } + + EncodeFst(const Fst &fst, const EncodeMapper& encoder) + : ArcMapFst(fst, encoder, ArcMapFstOptions()) {} + + // See Fst<>::Copy() for doc. + EncodeFst(const EncodeFst &fst, bool copy = false) + : ArcMapFst(fst, copy) {} + + // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc. + virtual EncodeFst *Copy(bool safe = false) const { + if (safe) { + FSTERROR() << "EncodeFst::Copy(true): not allowed."; + GetImpl()->SetProperties(kError, kError); + } + return new EncodeFst(*this); + } +}; + + +// On the fly label and/or weight encoding of input Fst +// +// Complexity: +// - Constructor: O(1) +// - Traversal: O(nstates_visited + narcs_visited), assuming constant +// time to visit an input state or arc. 
+template +class DecodeFst : public ArcMapFst > { + public: + typedef A Arc; + typedef EncodeMapper C; + typedef ArcMapFstImpl< A, A, EncodeMapper > Impl; + using ImplToFst::GetImpl; + + DecodeFst(const Fst &fst, const EncodeMapper& encoder) + : ArcMapFst(fst, + EncodeMapper(encoder, DECODE), + ArcMapFstOptions()) { + GetImpl()->SetInputSymbols(encoder.InputSymbols()); + GetImpl()->SetOutputSymbols(encoder.OutputSymbols()); + } + + // See Fst<>::Copy() for doc. + DecodeFst(const DecodeFst &fst, bool safe = false) + : ArcMapFst(fst, safe) {} + + // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc. + virtual DecodeFst *Copy(bool safe = false) const { + return new DecodeFst(*this, safe); + } +}; + + +// Specialization for EncodeFst. +template +class StateIterator< EncodeFst > + : public StateIterator< ArcMapFst > > { + public: + explicit StateIterator(const EncodeFst &fst) + : StateIterator< ArcMapFst > >(fst) {} +}; + + +// Specialization for EncodeFst. +template +class ArcIterator< EncodeFst > + : public ArcIterator< ArcMapFst > > { + public: + ArcIterator(const EncodeFst &fst, typename A::StateId s) + : ArcIterator< ArcMapFst > >(fst, s) {} +}; + + +// Specialization for DecodeFst. +template +class StateIterator< DecodeFst > + : public StateIterator< ArcMapFst > > { + public: + explicit StateIterator(const DecodeFst &fst) + : StateIterator< ArcMapFst > >(fst) {} +}; + + +// Specialization for DecodeFst. +template +class ArcIterator< DecodeFst > + : public ArcIterator< ArcMapFst > > { + public: + ArcIterator(const DecodeFst &fst, typename A::StateId s) + : ArcIterator< ArcMapFst > >(fst, s) {} +}; + + +// Useful aliases when using StdArc. +typedef EncodeFst StdEncodeFst; + +typedef DecodeFst StdDecodeFst; + +} // namespace fst + +#endif // FST_LIB_ENCODE_H__ -- cgit v1.2.3