diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/symbol-table.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/symbol-table.h | 537 |
1 files changed, 537 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/symbol-table.h b/kaldi_io/src/tools/openfst/include/fst/symbol-table.h new file mode 100644 index 0000000..6eb6c2d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/symbol-table.h @@ -0,0 +1,537 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// All Rights Reserved. +// +// Author : Johan Schalkwyk +// +// \file +// Classes to provide symbol-to-integer and integer-to-symbol mappings. + +#ifndef FST_LIB_SYMBOL_TABLE_H__ +#define FST_LIB_SYMBOL_TABLE_H__ + +#include <cstring> +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + + +#include <fst/compat.h> +#include <iostream> +#include <fstream> +#include <sstream> + + +#include <map> + +DECLARE_bool(fst_compat_symbols); + +namespace fst { + +// WARNING: Reading via symbol table read options should +// not be used. This is a temporary work around for +// reading symbol ranges of previously stored symbol sets. +struct SymbolTableReadOptions { + SymbolTableReadOptions() { } + + SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_, + const string& source_) + : string_hash_ranges(string_hash_ranges_), + source(source_) { } + + vector<pair<int64, int64> > string_hash_ranges; + string source; +}; + +struct SymbolTableTextOptions { + SymbolTableTextOptions(); + + bool allow_negative; + string fst_field_separator; +}; + +class SymbolTableImpl { + public: + SymbolTableImpl(const string &name) + : name_(name), + available_key_(0), + dense_key_limit_(0), + check_sum_finalized_(false) {} + + explicit SymbolTableImpl(const SymbolTableImpl& impl) + : name_(impl.name_), + available_key_(0), + dense_key_limit_(0), + check_sum_finalized_(false) { + for (size_t i = 0; i < impl.symbols_.size(); ++i) { + AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i])); + } + } + + ~SymbolTableImpl() { + for (size_t i = 0; i < symbols_.size(); ++i) + delete[] symbols_[i]; + } + + // TODO(johans): Add flag to specify whether the symbol + // should be indexed as string or int or both. + int64 AddSymbol(const string& symbol, int64 key); + + int64 AddSymbol(const string& symbol) { + int64 key = Find(symbol); + return (key == -1) ? AddSymbol(symbol, available_key_++) : key; + } + + static SymbolTableImpl* ReadText( + istream &strm, const string &name, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()); + + static SymbolTableImpl* Read(istream &strm, + const SymbolTableReadOptions& opts); + + bool Write(ostream &strm) const; + + // + // Return the string associated with the key. If the key is out of + // range (<0, >max), return an empty string. + string Find(int64 key) const { + if (key >=0 && key < dense_key_limit_) + return string(symbols_[key]); + + map<int64, const char*>::const_iterator it = + key_map_.find(key); + if (it == key_map_.end()) { + return ""; + } + return string(it->second); + } + + // + // Return the key associated with the symbol. If the symbol + // does not exists, return SymbolTable::kNoSymbol. + int64 Find(const string& symbol) const { + return Find(symbol.c_str()); + } + + // + // Return the key associated with the symbol. If the symbol + // does not exists, return SymbolTable::kNoSymbol. + int64 Find(const char* symbol) const { + map<const char *, int64, StrCmp>::const_iterator it = + symbol_map_.find(symbol); + if (it == symbol_map_.end()) { + return -1; + } + return it->second; + } + + int64 GetNthKey(ssize_t pos) const { + if ((pos < 0) || (pos >= symbols_.size())) return -1; + else return Find(symbols_[pos]); + } + + const string& Name() const { return name_; } + + int IncrRefCount() const { + return ref_count_.Incr(); + } + int DecrRefCount() const { + return ref_count_.Decr(); + } + int RefCount() const { + return ref_count_.count(); + } + + string CheckSum() const { + MaybeRecomputeCheckSum(); + return check_sum_string_; + } + + string LabeledCheckSum() const { + MaybeRecomputeCheckSum(); + return labeled_check_sum_string_; + } + + int64 AvailableKey() const { + return available_key_; + } + + size_t NumSymbols() const { + return symbols_.size(); + } + + private: + // Recomputes the checksums (both of them) if we've had changes since the last + // computation (i.e., if check_sum_finalized_ is false). + // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon + // if the checksum is up-to-date (requiring no recomputation). + void MaybeRecomputeCheckSum() const; + + struct StrCmp { + bool operator()(const char *s1, const char *s2) const { + return strcmp(s1, s2) < 0; + } + }; + + string name_; + int64 available_key_; + int64 dense_key_limit_; + vector<const char *> symbols_; + map<int64, const char*> key_map_; + map<const char *, int64, StrCmp> symbol_map_; + + mutable RefCounter ref_count_; + mutable bool check_sum_finalized_; + mutable string check_sum_string_; + mutable string labeled_check_sum_string_; + mutable Mutex check_sum_mutex_; +}; + +// +// \class SymbolTable +// \brief Symbol (string) to int and reverse mapping +// +// The SymbolTable implements the mappings of labels to strings and reverse. +// SymbolTables are used to describe the alphabet of the input and output +// labels for arcs in a Finite State Transducer. +// +// SymbolTables are reference counted and can therefore be shared across +// multiple machines. For example a language model grammar G, with a +// SymbolTable for the words in the language model can share this symbol +// table with the lexical representation L o G. +// +class SymbolTable { + public: + static const int64 kNoSymbol = -1; + + // Construct symbol table with an unspecified name. + SymbolTable() : impl_(new SymbolTableImpl("<unspecified>")) {} + + // Construct symbol table with a unique name. + SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {} + + // Create a reference counted copy. + SymbolTable(const SymbolTable& table) : impl_(table.impl_) { + impl_->IncrRefCount(); + } + + // Derefence implentation object. When reference count hits 0, delete + // implementation. + virtual ~SymbolTable() { + if (!impl_->DecrRefCount()) delete impl_; + } + + // Copys the implemenation from one symbol table to another. + void operator=(const SymbolTable &st) { + if (impl_ != st.impl_) { + st.impl_->IncrRefCount(); + if (!impl_->DecrRefCount()) delete impl_; + impl_ = st.impl_; + } + } + + // Read an ascii representation of the symbol table from an istream. Pass a + // name to give the resulting SymbolTable. + static SymbolTable* ReadText( + istream &strm, const string& name, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()) { + SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, name, opts); + if (!impl) + return 0; + else + return new SymbolTable(impl); + } + + // read an ascii representation of the symbol table + static SymbolTable* ReadText(const string& filename, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()) { + ifstream strm(filename.c_str(), ifstream::in); + if (!strm) { + LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename; + return 0; + } + return ReadText(strm, filename, opts); + } + + + // WARNING: Reading via symbol table read options should + // not be used. This is a temporary work around. + static SymbolTable* Read(istream &strm, + const SymbolTableReadOptions& opts) { + SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts); + if (!impl) + return 0; + else + return new SymbolTable(impl); + } + + // read a binary dump of the symbol table from a stream + static SymbolTable* Read(istream &strm, const string& source) { + SymbolTableReadOptions opts; + opts.source = source; + return Read(strm, opts); + } + + // read a binary dump of the symbol table + static SymbolTable* Read(const string& filename) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename; + return 0; + } + return Read(strm, filename); + } + + //-------------------------------------------------------- + // Derivable Interface (final) + //-------------------------------------------------------- + // create a reference counted copy + virtual SymbolTable* Copy() const { + return new SymbolTable(*this); + } + + // Add a symbol with given key to table. A symbol table also + // keeps track of the last available key (highest key value in + // the symbol table). + virtual int64 AddSymbol(const string& symbol, int64 key) { + MutateCheck(); + return impl_->AddSymbol(symbol, key); + } + + // Add a symbol to the table. The associated value key is automatically + // assigned by the symbol table. + virtual int64 AddSymbol(const string& symbol) { + MutateCheck(); + return impl_->AddSymbol(symbol); + } + + // Add another symbol table to this table. All key values will be offset + // by the current available key (highest key value in the symbol table). + // Note string symbols with the same key value with still have the same + // key value after the symbol table has been merged, but a different + // value. Adding symbol tables do not result in changes in the base table. + virtual void AddTable(const SymbolTable& table); + + // return the name of the symbol table + virtual const string& Name() const { + return impl_->Name(); + } + + // Return the label-agnostic MD5 check-sum for this table. All new symbols + // added to the table will result in an updated checksum. + // DEPRECATED. + virtual string CheckSum() const { + return impl_->CheckSum(); + } + + // Same as CheckSum(), but this returns an label-dependent version. + virtual string LabeledCheckSum() const { + return impl_->LabeledCheckSum(); + } + + virtual bool Write(ostream &strm) const { + return impl_->Write(strm); + } + + bool Write(const string& filename) const { + ofstream strm(filename.c_str(), ofstream::out | ofstream::binary); + if (!strm) { + LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename; + return false; + } + return Write(strm); + } + + // Dump an ascii text representation of the symbol table via a stream + virtual bool WriteText( + ostream &strm, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const; + + // Dump an ascii text representation of the symbol table + bool WriteText(const string& filename) const { + ofstream strm(filename.c_str()); + if (!strm) { + LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename; + return false; + } + return WriteText(strm); + } + + // Return the string associated with the key. If the key is out of + // range (<0, >max), log error and return an empty string. + virtual string Find(int64 key) const { + return impl_->Find(key); + } + + // Return the key associated with the symbol. If the symbol + // does not exists, log error and return SymbolTable::kNoSymbol + virtual int64 Find(const string& symbol) const { + return impl_->Find(symbol); + } + + // Return the key associated with the symbol. If the symbol + // does not exists, log error and return SymbolTable::kNoSymbol + virtual int64 Find(const char* symbol) const { + return impl_->Find(symbol); + } + + // Return the current available key (i.e highest key number+1) in + // the symbol table + virtual int64 AvailableKey(void) const { + return impl_->AvailableKey(); + } + + // Return the current number of symbols in table (not necessarily + // equal to AvailableKey()) + virtual size_t NumSymbols(void) const { + return impl_->NumSymbols(); + } + + virtual int64 GetNthKey(ssize_t pos) const { + return impl_->GetNthKey(pos); + } + + private: + explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {} + + void MutateCheck() { + // Copy on write + if (impl_->RefCount() > 1) { + impl_->DecrRefCount(); + impl_ = new SymbolTableImpl(*impl_); + } + } + + const SymbolTableImpl* Impl() const { + return impl_; + } + + private: + SymbolTableImpl* impl_; +}; + + +// +// \class SymbolTableIterator +// \brief Iterator class for symbols in a symbol table +class SymbolTableIterator { + public: + SymbolTableIterator(const SymbolTable& table) + : table_(table), + pos_(0), + nsymbols_(table.NumSymbols()), + key_(table.GetNthKey(0)) { } + + ~SymbolTableIterator() { } + + // is iterator done + bool Done(void) { + return (pos_ == nsymbols_); + } + + // return the Value() of the current symbol (int64 key) + int64 Value(void) { + return key_; + } + + // return the string of the current symbol + string Symbol(void) { + return table_.Find(key_); + } + + // advance iterator forward + void Next(void) { + ++pos_; + if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_); + } + + // reset iterator + void Reset(void) { + pos_ = 0; + key_ = table_.GetNthKey(0); + } + + private: + const SymbolTable& table_; + ssize_t pos_; + size_t nsymbols_; + int64 key_; +}; + + +// Tests compatibilty between two sets of symbol tables +inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, + bool warning = true) { + if (!FLAGS_fst_compat_symbols) { + return true; + } else if (!syms1 && !syms2) { + return true; + } else if (syms1 && !syms2) { + if (warning) + LOG(WARNING) << + "CompatSymbols: first symbol table present but second missing"; + return false; + } else if (!syms1 && syms2) { + if (warning) + LOG(WARNING) << + "CompatSymbols: second symbol table present but first missing"; + return false; + } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) { + if (warning) + LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match"; + return false; + } else { + return true; + } +} + + +// Relabels a symbol table as specified by the input vector of pairs +// (old label, new label). The new symbol table only retains symbols +// for which a relabeling is *explicitely* specified. +// TODO(allauzen): consider adding options to allow for some form +// of implicit identity relabeling. +template <class Label> +SymbolTable *RelabelSymbolTable(const SymbolTable *table, + const vector<pair<Label, Label> > &pairs) { + SymbolTable *new_table = new SymbolTable( + table->Name().empty() ? string() : + (string("relabeled_") + table->Name())); + + for (size_t i = 0; i < pairs.size(); ++i) + new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second); + + return new_table; +} + +// Symbol Table Serialization +inline void SymbolTableToString(const SymbolTable *table, string *result) { + ostringstream ostrm; + table->Write(ostrm); + *result = ostrm.str(); +} + +inline SymbolTable *StringToSymbolTable(const string &s) { + istringstream istrm(s); + return SymbolTable::Read(istrm, SymbolTableReadOptions()); +} + + + +} // namespace fst + +#endif // FST_LIB_SYMBOL_TABLE_H__ |