summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/symbol-table.h
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-08-14 11:51:42 +0800
committerDeterminant <[email protected]>2015-08-14 11:51:42 +0800
commit96a32415ab43377cf1575bd3f4f2980f58028209 (patch)
tree30a2d92d73e8f40ac87b79f6f56e227bfc4eea6e /kaldi_io/src/tools/openfst/include/fst/symbol-table.h
parentc177a7549bd90670af4b29fa813ddea32cfe0f78 (diff)
add implementation for kaldi io (by ymz)
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/symbol-table.h')
-rw-r--r--kaldi_io/src/tools/openfst/include/fst/symbol-table.h537
1 files changed, 537 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/symbol-table.h b/kaldi_io/src/tools/openfst/include/fst/symbol-table.h
new file mode 100644
index 0000000..6eb6c2d
--- /dev/null
+++ b/kaldi_io/src/tools/openfst/include/fst/symbol-table.h
@@ -0,0 +1,537 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// All Rights Reserved.
+//
+// Author : Johan Schalkwyk
+//
+// \file
+// Classes to provide symbol-to-integer and integer-to-symbol mappings.
+
+#ifndef FST_LIB_SYMBOL_TABLE_H__
+#define FST_LIB_SYMBOL_TABLE_H__
+
+#include <cstring>
+#include <string>
+#include <utility>
+using std::pair; using std::make_pair;
+#include <vector>
+using std::vector;
+
+
+#include <fst/compat.h>
+#include <iostream>
+#include <fstream>
+#include <sstream>
+
+
+#include <map>
+
+DECLARE_bool(fst_compat_symbols);
+
+namespace fst {
+
+// WARNING: Reading via symbol table read options should
+// not be used. This is a temporary work around for
+// reading symbol ranges of previously stored symbol sets.
+struct SymbolTableReadOptions {
+ SymbolTableReadOptions() { }
+
+ SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
+ const string& source_)
+ : string_hash_ranges(string_hash_ranges_),
+ source(source_) { }
+
+ vector<pair<int64, int64> > string_hash_ranges;
+ string source;
+};
+
+struct SymbolTableTextOptions {
+ SymbolTableTextOptions();
+
+ bool allow_negative;
+ string fst_field_separator;
+};
+
+class SymbolTableImpl {
+ public:
+ SymbolTableImpl(const string &name)
+ : name_(name),
+ available_key_(0),
+ dense_key_limit_(0),
+ check_sum_finalized_(false) {}
+
+ explicit SymbolTableImpl(const SymbolTableImpl& impl)
+ : name_(impl.name_),
+ available_key_(0),
+ dense_key_limit_(0),
+ check_sum_finalized_(false) {
+ for (size_t i = 0; i < impl.symbols_.size(); ++i) {
+ AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
+ }
+ }
+
+ ~SymbolTableImpl() {
+ for (size_t i = 0; i < symbols_.size(); ++i)
+ delete[] symbols_[i];
+ }
+
+ // TODO(johans): Add flag to specify whether the symbol
+ // should be indexed as string or int or both.
+ int64 AddSymbol(const string& symbol, int64 key);
+
+ int64 AddSymbol(const string& symbol) {
+ int64 key = Find(symbol);
+ return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
+ }
+
+ static SymbolTableImpl* ReadText(
+ istream &strm, const string &name,
+ const SymbolTableTextOptions &opts = SymbolTableTextOptions());
+
+ static SymbolTableImpl* Read(istream &strm,
+ const SymbolTableReadOptions& opts);
+
+ bool Write(ostream &strm) const;
+
+ //
+ // Return the string associated with the key. If the key is out of
+ // range (<0, >max), return an empty string.
+ string Find(int64 key) const {
+ if (key >=0 && key < dense_key_limit_)
+ return string(symbols_[key]);
+
+ map<int64, const char*>::const_iterator it =
+ key_map_.find(key);
+ if (it == key_map_.end()) {
+ return "";
+ }
+ return string(it->second);
+ }
+
+ //
+ // Return the key associated with the symbol. If the symbol
+ // does not exists, return SymbolTable::kNoSymbol.
+ int64 Find(const string& symbol) const {
+ return Find(symbol.c_str());
+ }
+
+ //
+ // Return the key associated with the symbol. If the symbol
+ // does not exists, return SymbolTable::kNoSymbol.
+ int64 Find(const char* symbol) const {
+ map<const char *, int64, StrCmp>::const_iterator it =
+ symbol_map_.find(symbol);
+ if (it == symbol_map_.end()) {
+ return -1;
+ }
+ return it->second;
+ }
+
+ int64 GetNthKey(ssize_t pos) const {
+ if ((pos < 0) || (pos >= symbols_.size())) return -1;
+ else return Find(symbols_[pos]);
+ }
+
+ const string& Name() const { return name_; }
+
+ int IncrRefCount() const {
+ return ref_count_.Incr();
+ }
+ int DecrRefCount() const {
+ return ref_count_.Decr();
+ }
+ int RefCount() const {
+ return ref_count_.count();
+ }
+
+ string CheckSum() const {
+ MaybeRecomputeCheckSum();
+ return check_sum_string_;
+ }
+
+ string LabeledCheckSum() const {
+ MaybeRecomputeCheckSum();
+ return labeled_check_sum_string_;
+ }
+
+ int64 AvailableKey() const {
+ return available_key_;
+ }
+
+ size_t NumSymbols() const {
+ return symbols_.size();
+ }
+
+ private:
+ // Recomputes the checksums (both of them) if we've had changes since the last
+ // computation (i.e., if check_sum_finalized_ is false).
+ // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon
+ // if the checksum is up-to-date (requiring no recomputation).
+ void MaybeRecomputeCheckSum() const;
+
+ struct StrCmp {
+ bool operator()(const char *s1, const char *s2) const {
+ return strcmp(s1, s2) < 0;
+ }
+ };
+
+ string name_;
+ int64 available_key_;
+ int64 dense_key_limit_;
+ vector<const char *> symbols_;
+ map<int64, const char*> key_map_;
+ map<const char *, int64, StrCmp> symbol_map_;
+
+ mutable RefCounter ref_count_;
+ mutable bool check_sum_finalized_;
+ mutable string check_sum_string_;
+ mutable string labeled_check_sum_string_;
+ mutable Mutex check_sum_mutex_;
+};
+
+//
+// \class SymbolTable
+// \brief Symbol (string) to int and reverse mapping
+//
+// The SymbolTable implements the mappings of labels to strings and reverse.
+// SymbolTables are used to describe the alphabet of the input and output
+// labels for arcs in a Finite State Transducer.
+//
+// SymbolTables are reference counted and can therefore be shared across
+// multiple machines. For example a language model grammar G, with a
+// SymbolTable for the words in the language model can share this symbol
+// table with the lexical representation L o G.
+//
+class SymbolTable {
+ public:
+ static const int64 kNoSymbol = -1;
+
+ // Construct symbol table with an unspecified name.
+ SymbolTable() : impl_(new SymbolTableImpl("<unspecified>")) {}
+
+ // Construct symbol table with a unique name.
+ SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
+
+ // Create a reference counted copy.
+ SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
+ impl_->IncrRefCount();
+ }
+
+ // Derefence implentation object. When reference count hits 0, delete
+ // implementation.
+ virtual ~SymbolTable() {
+ if (!impl_->DecrRefCount()) delete impl_;
+ }
+
+ // Copys the implemenation from one symbol table to another.
+ void operator=(const SymbolTable &st) {
+ if (impl_ != st.impl_) {
+ st.impl_->IncrRefCount();
+ if (!impl_->DecrRefCount()) delete impl_;
+ impl_ = st.impl_;
+ }
+ }
+
+ // Read an ascii representation of the symbol table from an istream. Pass a
+ // name to give the resulting SymbolTable.
+ static SymbolTable* ReadText(
+ istream &strm, const string& name,
+ const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
+ SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, name, opts);
+ if (!impl)
+ return 0;
+ else
+ return new SymbolTable(impl);
+ }
+
+ // read an ascii representation of the symbol table
+ static SymbolTable* ReadText(const string& filename,
+ const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
+ ifstream strm(filename.c_str(), ifstream::in);
+ if (!strm) {
+ LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
+ return 0;
+ }
+ return ReadText(strm, filename, opts);
+ }
+
+
+ // WARNING: Reading via symbol table read options should
+ // not be used. This is a temporary work around.
+ static SymbolTable* Read(istream &strm,
+ const SymbolTableReadOptions& opts) {
+ SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
+ if (!impl)
+ return 0;
+ else
+ return new SymbolTable(impl);
+ }
+
+ // read a binary dump of the symbol table from a stream
+ static SymbolTable* Read(istream &strm, const string& source) {
+ SymbolTableReadOptions opts;
+ opts.source = source;
+ return Read(strm, opts);
+ }
+
+ // read a binary dump of the symbol table
+ static SymbolTable* Read(const string& filename) {
+ ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
+ if (!strm) {
+ LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
+ return 0;
+ }
+ return Read(strm, filename);
+ }
+
+ //--------------------------------------------------------
+ // Derivable Interface (final)
+ //--------------------------------------------------------
+ // create a reference counted copy
+ virtual SymbolTable* Copy() const {
+ return new SymbolTable(*this);
+ }
+
+ // Add a symbol with given key to table. A symbol table also
+ // keeps track of the last available key (highest key value in
+ // the symbol table).
+ virtual int64 AddSymbol(const string& symbol, int64 key) {
+ MutateCheck();
+ return impl_->AddSymbol(symbol, key);
+ }
+
+ // Add a symbol to the table. The associated value key is automatically
+ // assigned by the symbol table.
+ virtual int64 AddSymbol(const string& symbol) {
+ MutateCheck();
+ return impl_->AddSymbol(symbol);
+ }
+
+ // Add another symbol table to this table. All key values will be offset
+ // by the current available key (highest key value in the symbol table).
+ // Note string symbols with the same key value with still have the same
+ // key value after the symbol table has been merged, but a different
+ // value. Adding symbol tables do not result in changes in the base table.
+ virtual void AddTable(const SymbolTable& table);
+
+ // return the name of the symbol table
+ virtual const string& Name() const {
+ return impl_->Name();
+ }
+
+ // Return the label-agnostic MD5 check-sum for this table. All new symbols
+ // added to the table will result in an updated checksum.
+ // DEPRECATED.
+ virtual string CheckSum() const {
+ return impl_->CheckSum();
+ }
+
+ // Same as CheckSum(), but this returns an label-dependent version.
+ virtual string LabeledCheckSum() const {
+ return impl_->LabeledCheckSum();
+ }
+
+ virtual bool Write(ostream &strm) const {
+ return impl_->Write(strm);
+ }
+
+ bool Write(const string& filename) const {
+ ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
+ if (!strm) {
+ LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
+ return false;
+ }
+ return Write(strm);
+ }
+
+ // Dump an ascii text representation of the symbol table via a stream
+ virtual bool WriteText(
+ ostream &strm,
+ const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const;
+
+ // Dump an ascii text representation of the symbol table
+ bool WriteText(const string& filename) const {
+ ofstream strm(filename.c_str());
+ if (!strm) {
+ LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
+ return false;
+ }
+ return WriteText(strm);
+ }
+
+ // Return the string associated with the key. If the key is out of
+ // range (<0, >max), log error and return an empty string.
+ virtual string Find(int64 key) const {
+ return impl_->Find(key);
+ }
+
+ // Return the key associated with the symbol. If the symbol
+ // does not exists, log error and return SymbolTable::kNoSymbol
+ virtual int64 Find(const string& symbol) const {
+ return impl_->Find(symbol);
+ }
+
+ // Return the key associated with the symbol. If the symbol
+ // does not exists, log error and return SymbolTable::kNoSymbol
+ virtual int64 Find(const char* symbol) const {
+ return impl_->Find(symbol);
+ }
+
+ // Return the current available key (i.e highest key number+1) in
+ // the symbol table
+ virtual int64 AvailableKey(void) const {
+ return impl_->AvailableKey();
+ }
+
+ // Return the current number of symbols in table (not necessarily
+ // equal to AvailableKey())
+ virtual size_t NumSymbols(void) const {
+ return impl_->NumSymbols();
+ }
+
+ virtual int64 GetNthKey(ssize_t pos) const {
+ return impl_->GetNthKey(pos);
+ }
+
+ private:
+ explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
+
+ void MutateCheck() {
+ // Copy on write
+ if (impl_->RefCount() > 1) {
+ impl_->DecrRefCount();
+ impl_ = new SymbolTableImpl(*impl_);
+ }
+ }
+
+ const SymbolTableImpl* Impl() const {
+ return impl_;
+ }
+
+ private:
+ SymbolTableImpl* impl_;
+};
+
+
+//
+// \class SymbolTableIterator
+// \brief Iterator class for symbols in a symbol table
+class SymbolTableIterator {
+ public:
+ SymbolTableIterator(const SymbolTable& table)
+ : table_(table),
+ pos_(0),
+ nsymbols_(table.NumSymbols()),
+ key_(table.GetNthKey(0)) { }
+
+ ~SymbolTableIterator() { }
+
+ // is iterator done
+ bool Done(void) {
+ return (pos_ == nsymbols_);
+ }
+
+ // return the Value() of the current symbol (int64 key)
+ int64 Value(void) {
+ return key_;
+ }
+
+ // return the string of the current symbol
+ string Symbol(void) {
+ return table_.Find(key_);
+ }
+
+ // advance iterator forward
+ void Next(void) {
+ ++pos_;
+ if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_);
+ }
+
+ // reset iterator
+ void Reset(void) {
+ pos_ = 0;
+ key_ = table_.GetNthKey(0);
+ }
+
+ private:
+ const SymbolTable& table_;
+ ssize_t pos_;
+ size_t nsymbols_;
+ int64 key_;
+};
+
+
+// Tests compatibilty between two sets of symbol tables
+inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
+ bool warning = true) {
+ if (!FLAGS_fst_compat_symbols) {
+ return true;
+ } else if (!syms1 && !syms2) {
+ return true;
+ } else if (syms1 && !syms2) {
+ if (warning)
+ LOG(WARNING) <<
+ "CompatSymbols: first symbol table present but second missing";
+ return false;
+ } else if (!syms1 && syms2) {
+ if (warning)
+ LOG(WARNING) <<
+ "CompatSymbols: second symbol table present but first missing";
+ return false;
+ } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
+ if (warning)
+ LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
+ return false;
+ } else {
+ return true;
+ }
+}
+
+
+// Relabels a symbol table as specified by the input vector of pairs
+// (old label, new label). The new symbol table only retains symbols
+// for which a relabeling is *explicitely* specified.
+// TODO(allauzen): consider adding options to allow for some form
+// of implicit identity relabeling.
+template <class Label>
+SymbolTable *RelabelSymbolTable(const SymbolTable *table,
+ const vector<pair<Label, Label> > &pairs) {
+ SymbolTable *new_table = new SymbolTable(
+ table->Name().empty() ? string() :
+ (string("relabeled_") + table->Name()));
+
+ for (size_t i = 0; i < pairs.size(); ++i)
+ new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second);
+
+ return new_table;
+}
+
+// Symbol Table Serialization
+inline void SymbolTableToString(const SymbolTable *table, string *result) {
+ ostringstream ostrm;
+ table->Write(ostrm);
+ *result = ostrm.str();
+}
+
+inline SymbolTable *StringToSymbolTable(const string &s) {
+ istringstream istrm(s);
+ return SymbolTable::Read(istrm, SymbolTableReadOptions());
+}
+
+
+
+} // namespace fst
+
+#endif // FST_LIB_SYMBOL_TABLE_H__