diff options
author | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
commit | 96a32415ab43377cf1575bd3f4f2980f58028209 (patch) | |
tree | 30a2d92d73e8f40ac87b79f6f56e227bfc4eea6e /kaldi_io/src/tools/openfst/include/fst/extensions | |
parent | c177a7549bd90670af4b29fa813ddea32cfe0f78 (diff) |
add implementation for kaldi io (by ymz)
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/extensions')
15 files changed, 3614 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/compile-strings.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/compile-strings.h new file mode 100644 index 0000000..ca247db --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/compile-strings.h @@ -0,0 +1,304 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Authors: [email protected] (Cyril Allauzen) +// [email protected] (Terry Tai) +// [email protected] (Jake Ratkiewicz) + + +#ifndef FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_ +#define FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_ + +#include <libgen.h> +#include <string> +#include <vector> +using std::vector; + +#include <fst/extensions/far/far.h> +#include <fst/string.h> + +namespace fst { + +// Construct a reader that provides FSTs from a file (stream) either on a +// line-by-line basis or on a per-stream basis. Note that the freshly +// constructed reader is already set to the first input. 
+// +// Sample Usage: +// for (StringReader<Arc> reader(...); !reader.Done(); reader.Next()) { +// Fst *fst = reader.GetVectorFst(); +// } +template <class A> +class StringReader { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename StringCompiler<A>::TokenType TokenType; + + enum EntryType { LINE = 1, FILE = 2 }; + + StringReader(istream &istrm, + const string &source, + EntryType entry_type, + TokenType token_type, + bool allow_negative_labels, + const SymbolTable *syms = 0, + Label unknown_label = kNoStateId) + : nline_(0), strm_(istrm), source_(source), entry_type_(entry_type), + token_type_(token_type), symbols_(syms), done_(false), + compiler_(token_type, syms, unknown_label, allow_negative_labels) { + Next(); // Initialize the reader to the first input. + } + + bool Done() { + return done_; + } + + void Next() { + VLOG(1) << "Processing source " << source_ << " at line " << nline_; + if (!strm_) { // We're done if we have no more input. + done_ = true; + return; + } + if (entry_type_ == LINE) { + getline(strm_, content_); + ++nline_; + } else { + content_.clear(); + string line; + while (getline(strm_, line)) { + ++nline_; + content_.append(line); + content_.append("\n"); + } + } + if (!strm_ && content_.empty()) // We're also done if we read off all the + done_ = true; // whitespace at the end of a file. 
+ } + + VectorFst<A> *GetVectorFst(bool keep_symbols = false) { + VectorFst<A> *fst = new VectorFst<A>; + if (keep_symbols) { + fst->SetInputSymbols(symbols_); + fst->SetOutputSymbols(symbols_); + } + if (compiler_(content_, fst)) { + return fst; + } else { + delete fst; + return NULL; + } + } + + CompactFst<A, StringCompactor<A> > *GetCompactFst(bool keep_symbols = false) { + CompactFst<A, StringCompactor<A> > *fst; + if (keep_symbols) { + VectorFst<A> tmp; + tmp.SetInputSymbols(symbols_); + tmp.SetOutputSymbols(symbols_); + fst = new CompactFst<A, StringCompactor<A> >(tmp); + } else { + fst = new CompactFst<A, StringCompactor<A> >; + } + if (compiler_(content_, fst)) { + return fst; + } else { + delete fst; + return NULL; + } + } + + private: + size_t nline_; + istream &strm_; + string source_; + EntryType entry_type_; + TokenType token_type_; + const SymbolTable *symbols_; + bool done_; + StringCompiler<A> compiler_; + string content_; // The actual content of the input stream's next FST. + + DISALLOW_COPY_AND_ASSIGN(StringReader); +}; + +// Compute the minimal length required to encode each line number as a decimal +// number. 
+int KeySize(const char *filename); + +template <class Arc> +void FarCompileStrings(const vector<string> &in_fnames, + const string &out_fname, + const string &fst_type, + const FarType &far_type, + int32 generate_keys, + FarEntryType fet, + FarTokenType tt, + const string &symbols_fname, + const string &unknown_symbol, + bool keep_symbols, + bool initial_symbols, + bool allow_negative_labels, + bool file_list_input, + const string &key_prefix, + const string &key_suffix) { + typename StringReader<Arc>::EntryType entry_type; + if (fet == FET_LINE) { + entry_type = StringReader<Arc>::LINE; + } else if (fet == FET_FILE) { + entry_type = StringReader<Arc>::FILE; + } else { + FSTERROR() << "FarCompileStrings: unknown entry type"; + return; + } + + typename StringCompiler<Arc>::TokenType token_type; + if (tt == FTT_SYMBOL) { + token_type = StringCompiler<Arc>::SYMBOL; + } else if (tt == FTT_BYTE) { + token_type = StringCompiler<Arc>::BYTE; + } else if (tt == FTT_UTF8) { + token_type = StringCompiler<Arc>::UTF8; + } else { + FSTERROR() << "FarCompileStrings: unknown token type"; + return; + } + + bool compact; + if (fst_type.empty() || (fst_type == "vector")) { + compact = false; + } else if (fst_type == "compact") { + compact = true; + } else { + FSTERROR() << "FarCompileStrings: unknown fst type: " + << fst_type; + return; + } + + const SymbolTable *syms = 0; + typename Arc::Label unknown_label = kNoLabel; + if (!symbols_fname.empty()) { + SymbolTableTextOptions opts; + opts.allow_negative = allow_negative_labels; + syms = SymbolTable::ReadText(symbols_fname, opts); + if (!syms) { + FSTERROR() << "FarCompileStrings: error reading symbol table: " + << symbols_fname; + return; + } + if (!unknown_symbol.empty()) { + unknown_label = syms->Find(unknown_symbol); + if (unknown_label == kNoLabel) { + FSTERROR() << "FarCompileStrings: unknown label \"" << unknown_label + << "\" missing from symbol table: " << symbols_fname; + return; + } + } + } + + FarWriter<Arc> *far_writer = 
+ FarWriter<Arc>::Create(out_fname, far_type); + if (!far_writer) return; + + vector<string> inputs; + if (file_list_input) { + for (int i = 1; i < in_fnames.size(); ++i) { + istream *istrm = in_fnames.empty() ? &cin : + new ifstream(in_fnames[i].c_str()); + string str; + while (getline(*istrm, str)) + inputs.push_back(str); + if (!in_fnames.empty()) + delete istrm; + } + } else { + inputs = in_fnames; + } + + for (int i = 0, n = 0; i < inputs.size(); ++i) { + if (generate_keys == 0 && inputs[i].empty()) { + FSTERROR() << "FarCompileStrings: read from a file instead of stdin or" + << " set the --generate_keys flags."; + delete far_writer; + delete syms; + return; + } + int key_size = generate_keys ? generate_keys : + (entry_type == StringReader<Arc>::FILE ? 1 : + KeySize(inputs[i].c_str())); + istream *istrm = inputs[i].empty() ? &cin : + new ifstream(inputs[i].c_str()); + + bool keep_syms = keep_symbols; + for (StringReader<Arc> reader( + *istrm, inputs[i].empty() ? "stdin" : inputs[i], + entry_type, token_type, allow_negative_labels, + syms, unknown_label); + !reader.Done(); + reader.Next()) { + ++n; + const Fst<Arc> *fst; + if (compact) + fst = reader.GetCompactFst(keep_syms); + else + fst = reader.GetVectorFst(keep_syms); + if (initial_symbols) + keep_syms = false; + if (!fst) { + FSTERROR() << "FarCompileStrings: compiling string number " << n + << " in file " << inputs[i] << " failed with token_type = " + << (tt == FTT_BYTE ? "byte" : + (tt == FTT_UTF8 ? "utf8" : + (tt == FTT_SYMBOL ? "symbol" : "unknown"))) + << " and entry_type = " + << (fet == FET_LINE ? "line" : + (fet == FET_FILE ? 
"file" : "unknown")); + delete far_writer; + delete syms; + if (!inputs[i].empty()) delete istrm; + return; + } + ostringstream keybuf; + keybuf.width(key_size); + keybuf.fill('0'); + keybuf << n; + string key; + if (generate_keys > 0) { + key = keybuf.str(); + } else { + char* filename = new char[inputs[i].size() + 1]; + strcpy(filename, inputs[i].c_str()); + key = basename(filename); + if (entry_type != StringReader<Arc>::FILE) { + key += "-"; + key += keybuf.str(); + } + delete[] filename; + } + far_writer->Add(key_prefix + key + key_suffix, *fst); + delete fst; + } + if (generate_keys == 0) + n = 0; + if (!inputs[i].empty()) + delete istrm; + } + + delete far_writer; +} + +} // namespace fst + + +#endif // FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/create.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/create.h new file mode 100644 index 0000000..edb31e7 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/create.h @@ -0,0 +1,87 @@ +// create-main.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// Modified: [email protected] (Jake Ratkiewicz) to use new dispatch +// +// \file +// Creates a finite-state archive from component FSTs. Includes +// helper function for farcreate.cc that templates the main on the arc +// type to support multiple and extensible arc types. 
+// + +#ifndef FST_EXTENSIONS_FAR_CREATE_H__ +#define FST_EXTENSIONS_FAR_CREATE_H__ + +#include <libgen.h> +#include <string> +#include <vector> +using std::vector; + +#include <fst/extensions/far/far.h> + +namespace fst { + +template <class Arc> +void FarCreate(const vector<string> &in_fnames, + const string &out_fname, + const int32 generate_keys, + const bool file_list_input, + const FarType &far_type, + const string &key_prefix, + const string &key_suffix) { + FarWriter<Arc> *far_writer = + FarWriter<Arc>::Create(out_fname, far_type); + if (!far_writer) return; + + vector<string> inputs; + if (file_list_input) { + for (int i = 1; i < in_fnames.size(); ++i) { + ifstream istrm(in_fnames[i].c_str()); + string str; + while (getline(istrm, str)) + inputs.push_back(str); + } + } else { + inputs = in_fnames; + } + + for (int i = 0; i < inputs.size(); ++i) { + Fst<Arc> *ifst = Fst<Arc>::Read(inputs[i]); + if (!ifst) return; + string key; + if (generate_keys > 0) { + ostringstream keybuf; + keybuf.width(generate_keys); + keybuf.fill('0'); + keybuf << i + 1; + key = keybuf.str(); + } else { + char* filename = new char[inputs[i].size() + 1]; + strcpy(filename, inputs[i].c_str()); + key = basename(filename); + delete[] filename; + } + + far_writer->Add(key_prefix + key + key_suffix, *ifst); + delete ifst; + } + + delete far_writer; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_CREATE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/equal.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/equal.h new file mode 100644 index 0000000..be82e2d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/equal.h @@ -0,0 +1,99 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) + +#ifndef FST_EXTENSIONS_FAR_EQUAL_H_ +#define FST_EXTENSIONS_FAR_EQUAL_H_ + +#include <string> + +#include <fst/extensions/far/far.h> +#include <fst/equal.h> + +namespace fst { + +template <class Arc> +bool FarEqual(const string &filename1, + const string &filename2, + float delta = kDelta, + const string &begin_key = string(), + const string &end_key = string()) { + + FarReader<Arc> *reader1 = FarReader<Arc>::Open(filename1); + FarReader<Arc> *reader2 = FarReader<Arc>::Open(filename2); + if (!reader1 || !reader2) { + delete reader1; + delete reader2; + VLOG(1) << "FarEqual: cannot open input Far file(s)"; + return false; + } + + if (!begin_key.empty()) { + bool find_begin1 = reader1->Find(begin_key); + bool find_begin2 = reader2->Find(begin_key); + if (!find_begin1 || !find_begin2) { + bool ret = !find_begin1 && !find_begin2; + if (!ret) { + VLOG(1) << "FarEqual: key \"" << begin_key << "\" missing from " + << (find_begin1 ? 
"second" : "first") << " archive."; + } + delete reader1; + delete reader2; + return ret; + } + } + + for(; !reader1->Done() && !reader2->Done(); + reader1->Next(), reader2->Next()) { + const string key1 = reader1->GetKey(); + const string key2 = reader2->GetKey(); + if (!end_key.empty() && end_key < key1 && end_key < key2) { + delete reader1; + delete reader2; + return true; + } + if (key1 != key2) { + VLOG(1) << "FarEqual: mismatched keys \"" + << key1 << "\" <> \"" << key2 << "\"."; + delete reader1; + delete reader2; + return false; + } + if (!Equal(reader1->GetFst(), reader2->GetFst(), delta)) { + VLOG(1) << "FarEqual: Fsts for key \"" << key1 << "\" are not equal."; + delete reader1; + delete reader2; + return false; + } + } + + if (!reader1->Done() || !reader2->Done()) { + VLOG(1) << "FarEqual: key \"" + << (reader1->Done() ? reader2->GetKey() : reader1->GetKey()) + << "\" missing form " << (reader2->Done() ? "first" : "second") + << " archive."; + delete reader1; + delete reader2; + return false; + } + + delete reader1; + delete reader2; + return true; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_EQUAL_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/extract.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/extract.h new file mode 100644 index 0000000..95866de --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/extract.h @@ -0,0 +1,140 @@ +// extract-main.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// Modified: [email protected] (Jake Ratkiewicz) to use the new arc-dispatch + +// \file +// Extracts component FSTs from an finite-state archive. +// + +#ifndef FST_EXTENSIONS_FAR_EXTRACT_H__ +#define FST_EXTENSIONS_FAR_EXTRACT_H__ + +#include <string> +#include <vector> +using std::vector; + +#include <fst/extensions/far/far.h> + +namespace fst { + +template<class Arc> +inline void FarWriteFst(const Fst<Arc>* fst, string key, + string* okey, int* nrep, + const int32 &generate_filenames, int i, + const string &filename_prefix, + const string &filename_suffix) { + if (key == *okey) + ++*nrep; + else + *nrep = 0; + + *okey = key; + + string ofilename; + if (generate_filenames) { + ostringstream tmp; + tmp.width(generate_filenames); + tmp.fill('0'); + tmp << i; + ofilename = tmp.str(); + } else { + if (*nrep > 0) { + ostringstream tmp; + tmp << '.' << nrep; + key.append(tmp.str().data(), tmp.str().size()); + } + ofilename = key; + } + fst->Write(filename_prefix + ofilename + filename_suffix); +} + +template<class Arc> +void FarExtract(const vector<string> &ifilenames, + const int32 &generate_filenames, + const string &keys, + const string &key_separator, + const string &range_delimiter, + const string &filename_prefix, + const string &filename_suffix) { + FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames); + if (!far_reader) return; + + string okey; + int nrep = 0; + + vector<char *> key_vector; + // User has specified a set of fsts to extract, where some of the "fsts" could + // be ranges. 
+ if (!keys.empty()) { + char *keys_cstr = new char[keys.size()+1]; + strcpy(keys_cstr, keys.c_str()); + SplitToVector(keys_cstr, key_separator.c_str(), &key_vector, true); + int i = 0; + for (int k = 0; k < key_vector.size(); ++k, ++i) { + string key = string(key_vector[k]); + char *key_cstr = new char[key.size()+1]; + strcpy(key_cstr, key.c_str()); + vector<char *> range_vector; + SplitToVector(key_cstr, range_delimiter.c_str(), &range_vector, false); + if (range_vector.size() == 1) { // Not a range + if (!far_reader->Find(key)) { + LOG(ERROR) << "FarExtract: Cannot find key: " << key; + return; + } + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); + } else if (range_vector.size() == 2) { // A legal range + string begin_key = string(range_vector[0]); + string end_key = string(range_vector[1]); + if (begin_key.empty() || end_key.empty()) { + LOG(ERROR) << "FarExtract: Illegal range specification: " << key; + return; + } + if (!far_reader->Find(begin_key)) { + LOG(ERROR) << "FarExtract: Cannot find key: " << begin_key; + return; + } + for ( ; !far_reader->Done(); far_reader->Next(), ++i) { + string ikey = far_reader->GetKey(); + if (end_key < ikey) break; + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, ikey, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); + } + } else { + LOG(ERROR) << "FarExtract: Illegal range specification: " << key; + return; + } + delete key_cstr; + } + delete keys_cstr; + return; + } + // Nothing specified: extract everything. 
+ for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) { + string key = far_reader->GetKey(); + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); + } + return; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_EXTRACT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h new file mode 100644 index 0000000..acce76e --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h @@ -0,0 +1,532 @@ +// far.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Finite-State Transducer (FST) archive classes. 
+// + +#ifndef FST_EXTENSIONS_FAR_FAR_H__ +#define FST_EXTENSIONS_FAR_FAR_H__ + +#include <fst/extensions/far/stlist.h> +#include <fst/extensions/far/sttable.h> +#include <fst/fst.h> +#include <fst/vector-fst.h> + +namespace fst { + +enum FarEntryType { FET_LINE, FET_FILE }; +enum FarTokenType { FTT_SYMBOL, FTT_BYTE, FTT_UTF8 }; + +inline bool IsFst(const string &filename) { + ifstream strm(filename.c_str()); + if (!strm) + return false; + return IsFstHeader(strm, filename); +} + +// FST archive header class +class FarHeader { + public: + const string &FarType() const { return fartype_; } + const string &ArcType() const { return arctype_; } + + bool Read(const string &filename) { + FstHeader fsthdr; + if (filename.empty()) { + // Header reading unsupported on stdin. Assumes STList and StdArc. + fartype_ = "stlist"; + arctype_ = "standard"; + return true; + } else if (IsSTTable(filename)) { // Check if STTable + ReadSTTableHeader(filename, &fsthdr); + fartype_ = "sttable"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } else if (IsSTList(filename)) { // Check if STList + ReadSTListHeader(filename, &fsthdr); + fartype_ = "sttable"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } else if (IsFst(filename)) { // Check if Fst + ifstream istrm(filename.c_str()); + fsthdr.Read(istrm, filename); + fartype_ = "fst"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } + return false; + } + + private: + string fartype_; + string arctype_; +}; + +enum FarType { + FAR_DEFAULT = 0, + FAR_STTABLE = 1, + FAR_STLIST = 2, + FAR_FST = 3, +}; + +// This class creates an archive of FSTs. +template <class A> +class FarWriter { + public: + typedef A Arc; + + // Creates a new (empty) FST archive; returns NULL on error. + static FarWriter *Create(const string &filename, FarType type = FAR_DEFAULT); + + // Adds an FST to the end of an archive. 
Keys must be non-empty and + // in lexicographic order. FSTs must have a suitable write method. + virtual void Add(const string &key, const Fst<A> &fst) = 0; + + virtual FarType Type() const = 0; + + virtual bool Error() const = 0; + + virtual ~FarWriter() {} + + protected: + FarWriter() {} + + private: + DISALLOW_COPY_AND_ASSIGN(FarWriter); +}; + + +// This class iterates through an existing archive of FSTs. +template <class A> +class FarReader { + public: + typedef A Arc; + + // Opens an existing FST archive in a single file; returns NULL on error. + // Sets current position to the beginning of the achive. + static FarReader *Open(const string &filename); + + // Opens an existing FST archive in multiple files; returns NULL on error. + // Sets current position to the beginning of the achive. + static FarReader *Open(const vector<string> &filenames); + + // Resets current posision to beginning of archive. + virtual void Reset() = 0; + + // Sets current position to first entry >= key. Returns true if a match. + virtual bool Find(const string &key) = 0; + + // Current position at end of archive? + virtual bool Done() const = 0; + + // Move current position to next FST. + virtual void Next() = 0; + + // Returns key at the current position. This reference is invalidated if + // the current position in the archive is changed. + virtual const string &GetKey() const = 0; + + // Returns FST at the current position. This reference is invalidated if + // the current position in the archive is changed. 
+ virtual const Fst<A> &GetFst() const = 0; + + virtual FarType Type() const = 0; + + virtual bool Error() const = 0; + + virtual ~FarReader() {} + + protected: + FarReader() {} + + private: + DISALLOW_COPY_AND_ASSIGN(FarReader); +}; + + +template <class A> +class FstWriter { + public: + void operator()(ostream &strm, const Fst<A> &fst) const { + fst.Write(strm, FstWriteOptions()); + } +}; + + +template <class A> +class STTableFarWriter : public FarWriter<A> { + public: + typedef A Arc; + + static STTableFarWriter *Create(const string &filename) { + STTableWriter<Fst<A>, FstWriter<A> > *writer = + STTableWriter<Fst<A>, FstWriter<A> >::Create(filename); + return new STTableFarWriter(writer); + } + + void Add(const string &key, const Fst<A> &fst) { writer_->Add(key, fst); } + + FarType Type() const { return FAR_STTABLE; } + + bool Error() const { return writer_->Error(); } + + ~STTableFarWriter() { delete writer_; } + + private: + explicit STTableFarWriter(STTableWriter<Fst<A>, FstWriter<A> > *writer) + : writer_(writer) {} + + private: + STTableWriter<Fst<A>, FstWriter<A> > *writer_; + + DISALLOW_COPY_AND_ASSIGN(STTableFarWriter); +}; + + +template <class A> +class STListFarWriter : public FarWriter<A> { + public: + typedef A Arc; + + static STListFarWriter *Create(const string &filename) { + STListWriter<Fst<A>, FstWriter<A> > *writer = + STListWriter<Fst<A>, FstWriter<A> >::Create(filename); + return new STListFarWriter(writer); + } + + void Add(const string &key, const Fst<A> &fst) { writer_->Add(key, fst); } + + FarType Type() const { return FAR_STLIST; } + + bool Error() const { return writer_->Error(); } + + ~STListFarWriter() { delete writer_; } + + private: + explicit STListFarWriter(STListWriter<Fst<A>, FstWriter<A> > *writer) + : writer_(writer) {} + + private: + STListWriter<Fst<A>, FstWriter<A> > *writer_; + + DISALLOW_COPY_AND_ASSIGN(STListFarWriter); +}; + + +template <class A> +class FstFarWriter : public FarWriter<A> { + public: + typedef A Arc; + + 
explicit FstFarWriter(const string &filename) + : filename_(filename), error_(false), written_(false) {} + + static FstFarWriter *Create(const string &filename) { + return new FstFarWriter(filename); + } + + void Add(const string &key, const Fst<A> &fst) { + if (written_) { + LOG(WARNING) << "FstFarWriter::Add: only one Fst supported," + << " subsequent entries discarded."; + } else { + error_ = !fst.Write(filename_); + written_ = true; + } + } + + FarType Type() const { return FAR_FST; } + + bool Error() const { return error_; } + + ~FstFarWriter() {} + + private: + string filename_; + bool error_; + bool written_; + + DISALLOW_COPY_AND_ASSIGN(FstFarWriter); +}; + + +template <class A> +FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) { + switch(type) { + case FAR_DEFAULT: + if (filename.empty()) + return STListFarWriter<A>::Create(filename); + case FAR_STTABLE: + return STTableFarWriter<A>::Create(filename); + break; + case FAR_STLIST: + return STListFarWriter<A>::Create(filename); + break; + case FAR_FST: + return FstFarWriter<A>::Create(filename); + break; + default: + LOG(ERROR) << "FarWriter::Create: unknown far type"; + return 0; + } +} + + +template <class A> +class FstReader { + public: + Fst<A> *operator()(istream &strm) const { + return Fst<A>::Read(strm, FstReadOptions()); + } +}; + + +template <class A> +class STTableFarReader : public FarReader<A> { + public: + typedef A Arc; + + static STTableFarReader *Open(const string &filename) { + STTableReader<Fst<A>, FstReader<A> > *reader = + STTableReader<Fst<A>, FstReader<A> >::Open(filename); + // TODO: error check + return new STTableFarReader(reader); + } + + static STTableFarReader *Open(const vector<string> &filenames) { + STTableReader<Fst<A>, FstReader<A> > *reader = + STTableReader<Fst<A>, FstReader<A> >::Open(filenames); + // TODO: error check + return new STTableFarReader(reader); + } + + void Reset() { reader_->Reset(); } + + bool Find(const string &key) { return 
reader_->Find(key); } + + bool Done() const { return reader_->Done(); } + + void Next() { return reader_->Next(); } + + const string &GetKey() const { return reader_->GetKey(); } + + const Fst<A> &GetFst() const { return reader_->GetEntry(); } + + FarType Type() const { return FAR_STTABLE; } + + bool Error() const { return reader_->Error(); } + + ~STTableFarReader() { delete reader_; } + + private: + explicit STTableFarReader(STTableReader<Fst<A>, FstReader<A> > *reader) + : reader_(reader) {} + + private: + STTableReader<Fst<A>, FstReader<A> > *reader_; + + DISALLOW_COPY_AND_ASSIGN(STTableFarReader); +}; + + +template <class A> +class STListFarReader : public FarReader<A> { + public: + typedef A Arc; + + static STListFarReader *Open(const string &filename) { + STListReader<Fst<A>, FstReader<A> > *reader = + STListReader<Fst<A>, FstReader<A> >::Open(filename); + // TODO: error check + return new STListFarReader(reader); + } + + static STListFarReader *Open(const vector<string> &filenames) { + STListReader<Fst<A>, FstReader<A> > *reader = + STListReader<Fst<A>, FstReader<A> >::Open(filenames); + // TODO: error check + return new STListFarReader(reader); + } + + void Reset() { reader_->Reset(); } + + bool Find(const string &key) { return reader_->Find(key); } + + bool Done() const { return reader_->Done(); } + + void Next() { return reader_->Next(); } + + const string &GetKey() const { return reader_->GetKey(); } + + const Fst<A> &GetFst() const { return reader_->GetEntry(); } + + FarType Type() const { return FAR_STLIST; } + + bool Error() const { return reader_->Error(); } + + ~STListFarReader() { delete reader_; } + + private: + explicit STListFarReader(STListReader<Fst<A>, FstReader<A> > *reader) + : reader_(reader) {} + + private: + STListReader<Fst<A>, FstReader<A> > *reader_; + + DISALLOW_COPY_AND_ASSIGN(STListFarReader); +}; + +template <class A> +class FstFarReader : public FarReader<A> { + public: + typedef A Arc; + + static FstFarReader *Open(const string 
&filename) { + vector<string> filenames; + filenames.push_back(filename); + return new FstFarReader<A>(filenames); + } + + static FstFarReader *Open(const vector<string> &filenames) { + return new FstFarReader<A>(filenames); + } + + FstFarReader(const vector<string> &filenames) + : keys_(filenames), has_stdin_(false), pos_(0), fst_(0), error_(false) { + sort(keys_.begin(), keys_.end()); + streams_.resize(keys_.size(), 0); + for (size_t i = 0; i < keys_.size(); ++i) { + if (keys_[i].empty()) { + if (!has_stdin_) { + streams_[i] = &cin; + //sources_[i] = "stdin"; + has_stdin_ = true; + } else { + FSTERROR() << "FstFarReader::FstFarReader: stdin should only " + << "appear once in the input file list."; + error_ = true; + return; + } + } else { + streams_[i] = new ifstream( + keys_[i].c_str(), ifstream::in | ifstream::binary); + } + } + if (pos_ >= keys_.size()) return; + ReadFst(); + } + + void Reset() { + if (has_stdin_) { + FSTERROR() << "FstFarReader::Reset: operation not supported on stdin"; + error_ = true; + return; + } + pos_ = 0; + ReadFst(); + } + + bool Find(const string &key) { + if (has_stdin_) { + FSTERROR() << "FstFarReader::Find: operation not supported on stdin"; + error_ = true; + return false; + } + pos_ = 0;//TODO + ReadFst(); + return true; + } + + bool Done() const { return error_ || pos_ >= keys_.size(); } + + void Next() { + ++pos_; + ReadFst(); + } + + const string &GetKey() const { + return keys_[pos_]; + } + + const Fst<A> &GetFst() const { + return *fst_; + } + + FarType Type() const { return FAR_FST; } + + bool Error() const { return error_; } + + ~FstFarReader() { + if (fst_) delete fst_; + for (size_t i = 0; i < keys_.size(); ++i) + delete streams_[i]; + } + + private: + void ReadFst() { + if (fst_) delete fst_; + if (pos_ >= keys_.size()) return; + streams_[pos_]->seekg(0); + fst_ = Fst<A>::Read(*streams_[pos_], FstReadOptions()); + if (!fst_) { + FSTERROR() << "FstFarReader: error reading Fst from: " << keys_[pos_]; + error_ = true; + } 
+ } + + private: + vector<string> keys_; + vector<istream*> streams_; + bool has_stdin_; + size_t pos_; + mutable Fst<A> *fst_; + mutable bool error_; + + DISALLOW_COPY_AND_ASSIGN(FstFarReader); +}; + +template <class A> +FarReader<A> *FarReader<A>::Open(const string &filename) { + if (filename.empty()) + return STListFarReader<A>::Open(filename); + else if (IsSTTable(filename)) + return STTableFarReader<A>::Open(filename); + else if (IsSTList(filename)) + return STListFarReader<A>::Open(filename); + else if (IsFst(filename)) + return FstFarReader<A>::Open(filename); + return 0; +} + + +template <class A> +FarReader<A> *FarReader<A>::Open(const vector<string> &filenames) { + if (!filenames.empty() && filenames[0].empty()) + return STListFarReader<A>::Open(filenames); + else if (!filenames.empty() && IsSTTable(filenames[0])) + return STTableFarReader<A>::Open(filenames); + else if (!filenames.empty() && IsSTList(filenames[0])) + return STListFarReader<A>::Open(filenames); + else if (!filenames.empty() && IsFst(filenames[0])) + return FstFarReader<A>::Open(filenames); + return 0; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_FAR_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/farlib.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/farlib.h new file mode 100644 index 0000000..91ba224 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/farlib.h @@ -0,0 +1,31 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +// A finite-state archive (FAR) is used to store an indexable collection of +// FSTs in a single file. Utilities are provided to create FARs from FSTs, +// to iterate over FARs, and to extract specific FSTs from FARs. + +#ifndef FST_EXTENSIONS_FAR_FARLIB_H_ +#define FST_EXTENSIONS_FAR_FARLIB_H_ + +#include <fst/extensions/far/far.h> +#include <fst/extensions/far/compile-strings.h> +#include <fst/extensions/far/create.h> +#include <fst/extensions/far/extract.h> +#include <fst/extensions/far/info.h> +#include <fst/extensions/far/print-strings.h> + +#endif // FST_EXTENSIONS_FAR_FARLIB_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/farscript.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/farscript.h new file mode 100644 index 0000000..cfd9167 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/farscript.h @@ -0,0 +1,273 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +// Convenience file for including all of the FAR operations, +// or registering them for new arc types. 
+ +#ifndef FST_EXTENSIONS_FAR_FARSCRIPT_H_ +#define FST_EXTENSIONS_FAR_FARSCRIPT_H_ + +#include <vector> +using std::vector; +#include <string> + +#include <fst/script/arg-packs.h> +#include <fst/extensions/far/compile-strings.h> +#include <fst/extensions/far/create.h> +#include <fst/extensions/far/equal.h> +#include <fst/extensions/far/extract.h> +#include <fst/extensions/far/info.h> +#include <fst/extensions/far/print-strings.h> +#include <fst/extensions/far/far.h> + +#include <fst/types.h> + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! +struct FarCompileStringsArgs { + const vector<string> &in_fnames; + const string &out_fname; + const string &fst_type; + const FarType &far_type; + const int32 generate_keys; + const FarEntryType fet; + const FarTokenType tt; + const string &symbols_fname; + const string &unknown_symbol; + const bool keep_symbols; + const bool initial_symbols; + const bool allow_negative_labels; + const bool file_list_input; + const string &key_prefix; + const string &key_suffix; + + FarCompileStringsArgs(const vector<string> &in_fnames, + const string &out_fname, + const string &fst_type, + const FarType &far_type, + int32 generate_keys, + FarEntryType fet, + FarTokenType tt, + const string &symbols_fname, + const string &unknown_symbol, + bool keep_symbols, + bool initial_symbols, + bool allow_negative_labels, + bool file_list_input, + const string &key_prefix, + const string &key_suffix) : + in_fnames(in_fnames), out_fname(out_fname), fst_type(fst_type), + far_type(far_type), generate_keys(generate_keys), fet(fet), + tt(tt), symbols_fname(symbols_fname), unknown_symbol(unknown_symbol), + keep_symbols(keep_symbols), initial_symbols(initial_symbols), + allow_negative_labels(allow_negative_labels), + file_list_input(file_list_input), 
key_prefix(key_prefix), + key_suffix(key_suffix) { } +}; + +template <class Arc> +void FarCompileStrings(FarCompileStringsArgs *args) { + fst::FarCompileStrings<Arc>( + args->in_fnames, args->out_fname, args->fst_type, args->far_type, + args->generate_keys, args->fet, args->tt, args->symbols_fname, + args->unknown_symbol, args->keep_symbols, args->initial_symbols, + args->allow_negative_labels, args->file_list_input, + args->key_prefix, args->key_suffix); +} + +void FarCompileStrings( + const vector<string> &in_fnames, + const string &out_fname, + const string &arc_type, + const string &fst_type, + const FarType &far_type, + int32 generate_keys, + FarEntryType fet, + FarTokenType tt, + const string &symbols_fname, + const string &unknown_symbol, + bool keep_symbols, + bool initial_symbols, + bool allow_negative_labels, + bool file_list_input, + const string &key_prefix, + const string &key_suffix); + + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! 
+struct FarCreateArgs { + const vector<string> &in_fnames; + const string &out_fname; + const int32 generate_keys; + const bool file_list_input; + const FarType &far_type; + const string &key_prefix; + const string &key_suffix; + + FarCreateArgs( + const vector<string> &in_fnames, const string &out_fname, + const int32 generate_keys, const bool file_list_input, + const FarType &far_type, const string &key_prefix, + const string &key_suffix) + : in_fnames(in_fnames), out_fname(out_fname), + generate_keys(generate_keys), file_list_input(file_list_input), + far_type(far_type), key_prefix(key_prefix), key_suffix(key_suffix) { } +}; + +template<class Arc> +void FarCreate(FarCreateArgs *args) { + fst::FarCreate<Arc>(args->in_fnames, args->out_fname, args->generate_keys, + args->file_list_input, args->far_type, + args->key_prefix, args->key_suffix); +} + +void FarCreate(const vector<string> &in_fnames, + const string &out_fname, + const string &arc_type, + const int32 generate_keys, + const bool file_list_input, + const FarType &far_type, + const string &key_prefix, + const string &key_suffix); + + +typedef args::Package<const string &, const string &, float, + const string &, const string &> FarEqualInnerArgs; +typedef args::WithReturnValue<bool, FarEqualInnerArgs> FarEqualArgs; + +template <class Arc> +void FarEqual(FarEqualArgs *args) { + args->retval = fst::FarEqual<Arc>( + args->args.arg1, args->args.arg2, args->args.arg3, + args->args.arg4, args->args.arg5); +} + +bool FarEqual(const string &filename1, + const string &filename2, + const string &arc_type, + float delta = kDelta, + const string &begin_key = string(), + const string &end_key = string()); + + +typedef args::Package<const vector<string> &, int32, + const string&, const string&, const string&, + const string&, const string&> FarExtractArgs; + +template<class Arc> +void FarExtract(FarExtractArgs *args) { + fst::FarExtract<Arc>( + args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6, + 
args->arg7); +} + +void FarExtract(const vector<string> &ifilenames, + const string &arc_type, + int32 generate_filenames, + const string &keys, + const string &key_separator, + const string &range_delimiter, + const string &filename_prefix, + const string &filename_suffix); + +typedef args::Package<const vector<string> &, const string &, + const string &, const bool> FarInfoArgs; + +template <class Arc> +void FarInfo(FarInfoArgs *args) { + fst::FarInfo<Arc>(args->arg1, args->arg2, args->arg3, args->arg4); +} + +void FarInfo(const vector<string> &filenames, + const string &arc_type, + const string &begin_key, + const string &end_key, + const bool list_fsts); + +struct FarPrintStringsArgs { + const vector<string> &ifilenames; + const FarEntryType entry_type; + const FarTokenType token_type; + const string &begin_key; + const string &end_key; + const bool print_key; + const bool print_weight; + const string &symbols_fname; + const bool initial_symbols; + const int32 generate_filenames; + const string &filename_prefix; + const string &filename_suffix; + + FarPrintStringsArgs( + const vector<string> &ifilenames, const FarEntryType entry_type, + const FarTokenType token_type, const string &begin_key, + const string &end_key, const bool print_key, const bool print_weight, + const string &symbols_fname, const bool initial_symbols, + const int32 generate_filenames, + const string &filename_prefix, const string &filename_suffix) : + ifilenames(ifilenames), entry_type(entry_type), token_type(token_type), + begin_key(begin_key), end_key(end_key), + print_key(print_key), print_weight(print_weight), + symbols_fname(symbols_fname), initial_symbols(initial_symbols), + generate_filenames(generate_filenames), filename_prefix(filename_prefix), + filename_suffix(filename_suffix) { } +}; + +template <class Arc> +void FarPrintStrings(FarPrintStringsArgs *args) { + fst::FarPrintStrings<Arc>( + args->ifilenames, args->entry_type, args->token_type, + args->begin_key, args->end_key, 
args->print_key, args->print_weight, + args->symbols_fname, args->initial_symbols, args->generate_filenames, + args->filename_prefix, args->filename_suffix); +} + + +void FarPrintStrings(const vector<string> &ifilenames, + const string &arc_type, + const FarEntryType entry_type, + const FarTokenType token_type, + const string &begin_key, + const string &end_key, + const bool print_key, + const bool print_weight, + const string &symbols_fname, + const bool initial_symbols, + const int32 generate_filenames, + const string &filename_prefix, + const string &filename_suffix); + +} // namespace script +} // namespace fst + + +#define REGISTER_FST_FAR_OPERATIONS(ArcType) \ + REGISTER_FST_OPERATION(FarCompileStrings, ArcType, FarCompileStringsArgs); \ + REGISTER_FST_OPERATION(FarCreate, ArcType, FarCreateArgs); \ + REGISTER_FST_OPERATION(FarEqual, ArcType, FarEqualArgs); \ + REGISTER_FST_OPERATION(FarExtract, ArcType, FarExtractArgs); \ + REGISTER_FST_OPERATION(FarInfo, ArcType, FarInfoArgs); \ + REGISTER_FST_OPERATION(FarPrintStrings, ArcType, FarPrintStringsArgs) + +#endif // FST_EXTENSIONS_FAR_FARSCRIPT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/info.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/info.h new file mode 100644 index 0000000..100fe68 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/info.h @@ -0,0 +1,128 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// Modified: [email protected] (Jake Ratkiewicz) + +#ifndef FST_EXTENSIONS_FAR_INFO_H_ +#define FST_EXTENSIONS_FAR_INFO_H_ + +#include <iomanip> +#include <set> +#include <string> +#include <vector> +using std::vector; + +#include <fst/extensions/far/far.h> +#include <fst/extensions/far/main.h> // For FarTypeToString + +namespace fst { + +template <class Arc> +void CountStatesAndArcs(const Fst<Arc> &fst, size_t *nstate, size_t *narc) { + StateIterator<Fst<Arc> > siter(fst); + for (; !siter.Done(); siter.Next(), ++(*nstate)) { + ArcIterator<Fst<Arc> > aiter(fst, siter.Value()); + for (; !aiter.Done(); aiter.Next(), ++(*narc)) {} + } +} + +struct KeyInfo { + string key; + string type; + size_t nstate; + size_t narc; + + KeyInfo(string k, string t, int64 ns = 0, int64 na = 0) + : key(k), type(t), nstate(ns), narc(na) {} +}; + +template <class Arc> +void FarInfo(const vector<string> &filenames, const string &begin_key, + const string &end_key, const bool list_fsts) { + FarReader<Arc> *far_reader = FarReader<Arc>::Open(filenames); + if (!far_reader) return; + + if (!begin_key.empty()) + far_reader->Find(begin_key); + + vector<KeyInfo> *infos = list_fsts ? 
new vector<KeyInfo>() : 0; + size_t nfst = 0, nstate = 0, narc = 0; + set<string> fst_types; + for (; !far_reader->Done(); far_reader->Next()) { + string key = far_reader->GetKey(); + if (!end_key.empty() && end_key < key) + break; + ++nfst; + const Fst<Arc> &fst = far_reader->GetFst(); + fst_types.insert(fst.Type()); + if (infos) { + KeyInfo info(key, fst.Type()); + CountStatesAndArcs(fst, &info.nstate, &info.narc); + nstate += info.nstate; + nstate += info.narc; + infos->push_back(info); + } else { + CountStatesAndArcs(fst, &nstate, &narc); + } + } + + if (!infos) { + cout << std::left << setw(50) << "far type" + << FarTypeToString(far_reader->Type()) << endl; + cout << std::left << setw(50) << "arc type" << Arc::Type() << endl; + cout << std::left << setw(50) << "fst type"; + for (set<string>::const_iterator iter = fst_types.begin(); + iter != fst_types.end(); + ++iter) { + if (iter != fst_types.begin()) + cout << ","; + cout << *iter; + } + cout << endl; + cout << std::left << setw(50) << "# of FSTs" << nfst << endl; + cout << std::left << setw(50) << "total # of states" << nstate << endl; + cout << std::left << setw(50) << "total # of arcs" << narc << endl; + + } else { + int wkey = 10, wtype = 10, wnstate = 16, wnarc = 16; + for (size_t i = 0; i < infos->size(); ++i) { + const KeyInfo &info = (*infos)[i]; + if (info.key.size() + 2 > wkey) + wkey = info.key.size() + 2; + if (info.type.size() + 2 > wtype) + wtype = info.type.size() + 2; + if (ceil(log10(info.nstate)) + 2 > wnstate) + wnstate = ceil(log10(info.nstate)) + 2; + if (ceil(log10(info.narc)) + 2 > wnarc) + wnarc = ceil(log10(info.narc)) + 2; + } + + cout << std::left << setw(wkey) << "key" << setw(wtype) << "type" + << std::right << setw(wnstate) << "# of states" + << setw(wnarc) << "# of arcs" << endl; + + for (size_t i = 0; i < infos->size(); ++i) { + const KeyInfo &info = (*infos)[i]; + cout << std::left << setw(wkey) << info.key << setw(wtype) << info.type + << std::right << setw(wnstate) << 
info.nstate + << setw(wnarc) << info.narc << endl; + } + } +} + +} // namespace fst + + +#endif // FST_EXTENSIONS_FAR_INFO_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/main.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/main.h new file mode 100644 index 0000000..00ccfef --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/main.h @@ -0,0 +1,43 @@ +// main.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes and functions for registering and invoking Far main +// functions that support multiple and extensible arc types. + +#ifndef FST_EXTENSIONS_FAR_MAIN_H__ +#define FST_EXTENSIONS_FAR_MAIN_H__ + +#include <fst/extensions/far/far.h> + +namespace fst { + +FarEntryType StringToFarEntryType(const string &s); +FarTokenType StringToFarTokenType(const string &s); + +// Return the 'FarType' value corresponding to a far type name. +FarType FarTypeFromString(const string &str); + +// Return the textual name corresponding to a 'FarType;. 
+string FarTypeToString(FarType type); + +string LoadArcTypeFromFar(const string& far_fname); +string LoadArcTypeFromFst(const string& far_fname); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_MAIN_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/print-strings.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/print-strings.h new file mode 100644 index 0000000..dcc7351 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/print-strings.h @@ -0,0 +1,138 @@ +// printstrings-main.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// Modified by: [email protected] (Jake Ratkiewicz) +// +// \file +// Output as strings the string FSTs in a finite-state archive. 
+ +#ifndef FST_EXTENSIONS_FAR_PRINT_STRINGS_H__ +#define FST_EXTENSIONS_FAR_PRINT_STRINGS_H__ + +#include <string> +#include <vector> +using std::vector; + +#include <fst/extensions/far/far.h> +#include <fst/shortest-distance.h> +#include <fst/string.h> + +DECLARE_string(far_field_separator); + +namespace fst { + +template <class Arc> +void FarPrintStrings( + const vector<string> &ifilenames, const FarEntryType entry_type, + const FarTokenType far_token_type, const string &begin_key, + const string &end_key, const bool print_key, const bool print_weight, + const string &symbols_fname, const bool initial_symbols, + const int32 generate_filenames, + const string &filename_prefix, const string &filename_suffix) { + + typename StringPrinter<Arc>::TokenType token_type; + if (far_token_type == FTT_SYMBOL) { + token_type = StringPrinter<Arc>::SYMBOL; + } else if (far_token_type == FTT_BYTE) { + token_type = StringPrinter<Arc>::BYTE; + } else if (far_token_type == FTT_UTF8) { + token_type = StringPrinter<Arc>::UTF8; + } else { + FSTERROR() << "FarPrintStrings: unknown token type"; + return; + } + + const SymbolTable *syms = 0; + if (!symbols_fname.empty()) { + // allow negative flag? 
+ SymbolTableTextOptions opts; + opts.allow_negative = true; + syms = SymbolTable::ReadText(symbols_fname, opts); + if (!syms) { + FSTERROR() << "FarPrintStrings: error reading symbol table: " + << symbols_fname; + return; + } + } + + FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames); + if (!far_reader) return; + + if (!begin_key.empty()) + far_reader->Find(begin_key); + + string okey; + int nrep = 0; + for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) { + string key = far_reader->GetKey(); + if (!end_key.empty() && end_key < key) + break; + if (okey == key) + ++nrep; + else + nrep = 0; + okey = key; + + const Fst<Arc> &fst = far_reader->GetFst(); + if (i == 1 && initial_symbols && syms == 0 && fst.InputSymbols() != 0) + syms = fst.InputSymbols()->Copy(); + string str; + VLOG(2) << "Handling key: " << key; + StringPrinter<Arc> string_printer( + token_type, syms ? syms : fst.InputSymbols()); + string_printer(fst, &str); + + if (entry_type == FET_LINE) { + if (print_key) + cout << key << FLAGS_far_field_separator[0]; + cout << str; + if (print_weight) + cout << FLAGS_far_field_separator[0] << ShortestDistance(fst); + cout << endl; + } else if (entry_type == FET_FILE) { + stringstream sstrm; + if (generate_filenames) { + sstrm.fill('0'); + sstrm << std::right << setw(generate_filenames) << i; + } else { + sstrm << key; + if (nrep > 0) + sstrm << "." 
<< nrep; + } + + string filename; + filename = filename_prefix + sstrm.str() + filename_suffix; + + ofstream ostrm(filename.c_str()); + if (!ostrm) { + FSTERROR() << "FarPrintStrings: Can't open file:" << filename; + delete syms; + delete far_reader; + return; + } + ostrm << str; + if (token_type == StringPrinter<Arc>::SYMBOL) + ostrm << "\n"; + } + } + delete syms; +} + + + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_PRINT_STRINGS_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/stlist.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/stlist.h new file mode 100644 index 0000000..ff3d98b --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/stlist.h @@ -0,0 +1,305 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// A generic (string,type) list file format. +// +// This is a stripped-down version of STTable that does +// not support the Find() operation but that does support +// reading/writting from standard in/out. 
+ +#ifndef FST_EXTENSIONS_FAR_STLIST_H_ +#define FST_EXTENSIONS_FAR_STLIST_H_ + +#include <iostream> +#include <fstream> +#include <sstream> +#include <fst/util.h> + +#include <algorithm> +#include <functional> +#include <queue> +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +namespace fst { + +static const int32 kSTListMagicNumber = 5656924; +static const int32 kSTListFileVersion = 1; + +// String-type list writing class for object of type 'T' using functor 'W' +// to write an object of type 'T' from a stream. 'W' must conform to the +// following interface: +// +// struct Writer { +// void operator()(ostream &, const T &) const; +// }; +// +template <class T, class W> +class STListWriter { + public: + typedef T EntryType; + typedef W EntryWriter; + + explicit STListWriter(const string filename) + : stream_( + filename.empty() ? &cout : + new ofstream(filename.c_str(), ofstream::out | ofstream::binary)), + error_(false) { + WriteType(*stream_, kSTListMagicNumber); + WriteType(*stream_, kSTListFileVersion); + if (!stream_) { + FSTERROR() << "STListWriter::STListWriter: error writing to file: " + << filename; + error_ = true; + } + } + + static STListWriter<T, W> *Create(const string &filename) { + return new STListWriter<T, W>(filename); + } + + void Add(const string &key, const T &t) { + if (key == "") { + FSTERROR() << "STListWriter::Add: key empty: " << key; + error_ = true; + } else if (key < last_key_) { + FSTERROR() << "STListWriter::Add: key disorder: " << key; + error_ = true; + } + if (error_) return; + last_key_ = key; + WriteType(*stream_, key); + entry_writer_(*stream_, t); + } + + bool Error() const { return error_; } + + ~STListWriter() { + WriteType(*stream_, string()); + if (stream_ != &cout) + delete stream_; + } + + private: + EntryWriter entry_writer_; // Write functor for 'EntryType' + ostream *stream_; // Output stream + string last_key_; // Last key + bool error_; + + 
DISALLOW_COPY_AND_ASSIGN(STListWriter); +}; + + +// String-type list reading class for object of type 'T' using functor 'R' +// to read an object of type 'T' form a stream. 'R' must conform to the +// following interface: +// +// struct Reader { +// T *operator()(istream &) const; +// }; +// +template <class T, class R> +class STListReader { + public: + typedef T EntryType; + typedef R EntryReader; + + explicit STListReader(const vector<string> &filenames) + : sources_(filenames), entry_(0), error_(false) { + streams_.resize(filenames.size(), 0); + bool has_stdin = false; + for (size_t i = 0; i < filenames.size(); ++i) { + if (filenames[i].empty()) { + if (!has_stdin) { + streams_[i] = &cin; + sources_[i] = "stdin"; + has_stdin = true; + } else { + FSTERROR() << "STListReader::STListReader: stdin should only " + << "appear once in the input file list."; + error_ = true; + return; + } + } else { + streams_[i] = new ifstream( + filenames[i].c_str(), ifstream::in | ifstream::binary); + } + int32 magic_number = 0, file_version = 0; + ReadType(*streams_[i], &magic_number); + ReadType(*streams_[i], &file_version); + if (magic_number != kSTListMagicNumber) { + FSTERROR() << "STListReader::STListReader: wrong file type: " + << filenames[i]; + error_ = true; + return; + } + if (file_version != kSTListFileVersion) { + FSTERROR() << "STListReader::STListReader: wrong file version: " + << filenames[i]; + error_ = true; + return; + } + string key; + ReadType(*streams_[i], &key); + if (!key.empty()) + heap_.push(make_pair(key, i)); + if (!*streams_[i]) { + FSTERROR() << "STListReader: error reading file: " << sources_[i]; + error_ = true; + return; + } + } + if (heap_.empty()) return; + size_t current = heap_.top().second; + entry_ = entry_reader_(*streams_[current]); + if (!entry_ || !*streams_[current]) { + FSTERROR() << "STListReader: error reading entry for key: " + << heap_.top().first << ", file: " << sources_[current]; + error_ = true; + } + } + + ~STListReader() { + for 
(size_t i = 0; i < streams_.size(); ++i) { + if (streams_[i] != &cin) + delete streams_[i]; + } + if (entry_) + delete entry_; + } + + static STListReader<T, R> *Open(const string &filename) { + vector<string> filenames; + filenames.push_back(filename); + return new STListReader<T, R>(filenames); + } + + static STListReader<T, R> *Open(const vector<string> &filenames) { + return new STListReader<T, R>(filenames); + } + + void Reset() { + FSTERROR() + << "STListReader::Reset: stlist does not support reset operation"; + error_ = true; + } + + bool Find(const string &key) { + FSTERROR() + << "STListReader::Find: stlist does not support find operation"; + error_ = true; + return false; + } + + bool Done() const { + return error_ || heap_.empty(); + } + + void Next() { + if (error_) return; + size_t current = heap_.top().second; + string key; + heap_.pop(); + ReadType(*(streams_[current]), &key); + if (!*streams_[current]) { + FSTERROR() << "STListReader: error reading file: " + << sources_[current]; + error_ = true; + return; + } + if (!key.empty()) + heap_.push(make_pair(key, current)); + + if(!heap_.empty()) { + current = heap_.top().second; + if (entry_) + delete entry_; + entry_ = entry_reader_(*streams_[current]); + if (!entry_ || !*streams_[current]) { + FSTERROR() << "STListReader: error reading entry for key: " + << heap_.top().first << ", file: " << sources_[current]; + error_ = true; + } + } + } + + const string &GetKey() const { + return heap_.top().first; + } + + const EntryType &GetEntry() const { + return *entry_; + } + + bool Error() const { return error_; } + + private: + EntryReader entry_reader_; // Read functor for 'EntryType' + vector<istream*> streams_; // Input streams + vector<string> sources_; // and corresponding file names + priority_queue< + pair<string, size_t>, vector<pair<string, size_t> >, + greater<pair<string, size_t> > > heap_; // (Key, stream id) heap + mutable EntryType *entry_; // Pointer to the currently read entry + bool error_; + 
+ DISALLOW_COPY_AND_ASSIGN(STListReader); +}; + + +// String-type list header reading function template on the entry header +// type 'H' having a member function: +// Read(istream &strm, const string &filename); +// Checks that 'filename' is an STList and call the H::Read() on the last +// entry in the STList. +// Does not support reading from stdin. +template <class H> +bool ReadSTListHeader(const string &filename, H *header) { + if (filename.empty()) { + LOG(ERROR) << "ReadSTListHeader: reading header not supported on stdin"; + return false; + } + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + int32 magic_number = 0, file_version = 0; + ReadType(strm, &magic_number); + ReadType(strm, &file_version); + if (magic_number != kSTListMagicNumber) { + LOG(ERROR) << "ReadSTListHeader: wrong file type: " << filename; + return false; + } + if (file_version != kSTListFileVersion) { + LOG(ERROR) << "ReadSTListHeader: wrong file version: " << filename; + return false; + } + string key; + ReadType(strm, &key); + header->Read(strm, filename + ":" + key); + if (!strm) { + LOG(ERROR) << "ReadSTListHeader: error reading file: " << filename; + return false; + } + return true; +} + +bool IsSTList(const string &filename); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_STLIST_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/sttable.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/sttable.h new file mode 100644 index 0000000..3ce0a4b --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/sttable.h @@ -0,0 +1,371 @@ +// sttable.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// A generic string-to-type table file format +// +// This is not meant as a generalization of SSTable. This is more of +// a simple replacement for SSTable in order to provide an open-source +// implementation of the FAR format for the external version of the +// FST Library. + +#ifndef FST_EXTENSIONS_FAR_STTABLE_H_ +#define FST_EXTENSIONS_FAR_STTABLE_H_ + +#include <algorithm> +#include <iostream> +#include <fstream> +#include <sstream> +#include <fst/util.h> + +namespace fst { + +static const int32 kSTTableMagicNumber = 2125656924; +static const int32 kSTTableFileVersion = 1; + +// String-to-type table writing class for object of type 'T' using functor 'W' +// to write an object of type 'T' from a stream. 
'W' must conform to the +// following interface: +// +// struct Writer { +// void operator()(ostream &, const T &) const; +// }; +// +template <class T, class W> +class STTableWriter { + public: + typedef T EntryType; + typedef W EntryWriter; + + explicit STTableWriter(const string &filename) + : stream_(filename.c_str(), ofstream::out | ofstream::binary), + error_(false) { + WriteType(stream_, kSTTableMagicNumber); + WriteType(stream_, kSTTableFileVersion); + if (!stream_) { + FSTERROR() << "STTableWriter::STTableWriter: error writing to file: " + << filename; + error_=true; + } + } + + static STTableWriter<T, W> *Create(const string &filename) { + if (filename.empty()) { + LOG(ERROR) << "STTableWriter: writing to standard out unsupported."; + return 0; + } + return new STTableWriter<T, W>(filename); + } + + void Add(const string &key, const T &t) { + if (key == "") { + FSTERROR() << "STTableWriter::Add: key empty: " << key; + error_ = true; + } else if (key < last_key_) { + FSTERROR() << "STTableWriter::Add: key disorder: " << key; + error_ = true; + } + if (error_) return; + last_key_ = key; + positions_.push_back(stream_.tellp()); + WriteType(stream_, key); + entry_writer_(stream_, t); + } + + bool Error() const { return error_; } + + ~STTableWriter() { + WriteType(stream_, positions_); + WriteType(stream_, static_cast<int64>(positions_.size())); + } + + private: + EntryWriter entry_writer_; // Write functor for 'EntryType' + ofstream stream_; // Output stream + vector<int64> positions_; // Position in file of each key-entry pair + string last_key_; // Last key + bool error_; + + DISALLOW_COPY_AND_ASSIGN(STTableWriter); +}; + + +// String-to-type table reading class for object of type 'T' using functor 'R' +// to read an object of type 'T' form a stream. 
'R' must conform to the +// following interface: +// +// struct Reader { +// T *operator()(istream &) const; +// }; +// +template <class T, class R> +class STTableReader { + public: + typedef T EntryType; + typedef R EntryReader; + + explicit STTableReader(const vector<string> &filenames) + : sources_(filenames), entry_(0), error_(false) { + compare_ = new Compare(&keys_); + keys_.resize(filenames.size()); + streams_.resize(filenames.size(), 0); + positions_.resize(filenames.size()); + for (size_t i = 0; i < filenames.size(); ++i) { + streams_[i] = new ifstream( + filenames[i].c_str(), ifstream::in | ifstream::binary); + int32 magic_number = 0, file_version = 0; + ReadType(*streams_[i], &magic_number); + ReadType(*streams_[i], &file_version); + if (magic_number != kSTTableMagicNumber) { + FSTERROR() << "STTableReader::STTableReader: wrong file type: " + << filenames[i]; + error_ = true; + return; + } + if (file_version != kSTTableFileVersion) { + FSTERROR() << "STTableReader::STTableReader: wrong file version: " + << filenames[i]; + error_ = true; + return; + } + int64 num_entries; + streams_[i]->seekg(-static_cast<int>(sizeof(int64)), ios_base::end); + ReadType(*streams_[i], &num_entries); + streams_[i]->seekg(-static_cast<int>(sizeof(int64)) * + (num_entries + 1), ios_base::end); + positions_[i].resize(num_entries); + for (size_t j = 0; (j < num_entries) && (*streams_[i]); ++j) + ReadType(*streams_[i], &(positions_[i][j])); + streams_[i]->seekg(positions_[i][0]); + if (!*streams_[i]) { + FSTERROR() << "STTableReader::STTableReader: error reading file: " + << filenames[i]; + error_ = true; + return; + } + + } + MakeHeap(); + } + + ~STTableReader() { + for (size_t i = 0; i < streams_.size(); ++i) + delete streams_[i]; + delete compare_; + if (entry_) + delete entry_; + } + + static STTableReader<T, R> *Open(const string &filename) { + if (filename.empty()) { + LOG(ERROR) << "STTableReader: reading from standard in not supported"; + return 0; + } + vector<string> 
filenames; + filenames.push_back(filename); + return new STTableReader<T, R>(filenames); + } + + static STTableReader<T, R> *Open(const vector<string> &filenames) { + return new STTableReader<T, R>(filenames); + } + + void Reset() { + if (error_) return; + for (size_t i = 0; i < streams_.size(); ++i) + streams_[i]->seekg(positions_[i].front()); + MakeHeap(); + } + + bool Find(const string &key) { + if (error_) return false; + for (size_t i = 0; i < streams_.size(); ++i) + LowerBound(i, key); + MakeHeap(); + return keys_[current_] == key; + } + + bool Done() const { return error_ || heap_.empty(); } + + void Next() { + if (error_) return; + if (streams_[current_]->tellg() <= positions_[current_].back()) { + ReadType(*(streams_[current_]), &(keys_[current_])); + if (!*streams_[current_]) { + FSTERROR() << "STTableReader: error reading file: " + << sources_[current_]; + error_ = true; + return; + } + push_heap(heap_.begin(), heap_.end(), *compare_); + } else { + heap_.pop_back(); + } + if (!heap_.empty()) + PopHeap(); + } + + const string &GetKey() const { + return keys_[current_]; + } + + const EntryType &GetEntry() const { + return *entry_; + } + + bool Error() const { return error_; } + + private: + // Comparison functor used to compare stream IDs in the heap + struct Compare { + Compare(const vector<string> *keys) : keys_(keys) {} + + bool operator()(size_t i, size_t j) const { + return (*keys_)[i] > (*keys_)[j]; + }; + + private: + const vector<string> *keys_; + }; + + // Position the stream with ID 'id' at the position corresponding + // to the lower bound for key 'find_key' + void LowerBound(size_t id, const string &find_key) { + ifstream *strm = streams_[id]; + const vector<int64> &positions = positions_[id]; + size_t low = 0, high = positions.size() - 1; + + while (low < high) { + size_t mid = (low + high)/2; + strm->seekg(positions[mid]); + string key; + ReadType(*strm, &key); + if (key > find_key) { + high = mid; + } else if (key < find_key) { + low = mid + 
1; + } else { + for (size_t i = mid; i > low; --i) { + strm->seekg(positions[i - 1]); + ReadType(*strm, &key); + if (key != find_key) { + strm->seekg(positions[i]); + return; + } + } + strm->seekg(positions[low]); + return; + } + } + strm->seekg(positions[low]); + } + + // Add all streams to the heap + void MakeHeap() { + heap_.clear(); + for (size_t i = 0; i < streams_.size(); ++i) { + ReadType(*streams_[i], &(keys_[i])); + if (!*streams_[i]) { + FSTERROR() << "STTableReader: error reading file: " << sources_[i]; + error_ = true; + return; + } + heap_.push_back(i); + } + make_heap(heap_.begin(), heap_.end(), *compare_); + PopHeap(); + } + + // Position the stream with the lowest key at the top + // of the heap, set 'current_' to the ID of that stream + // and read the current entry from that stream + void PopHeap() { + pop_heap(heap_.begin(), heap_.end(), *compare_); + current_ = heap_.back(); + if (entry_) + delete entry_; + entry_ = entry_reader_(*streams_[current_]); + if (!entry_) + error_ = true; + if (!*streams_[current_]) { + FSTERROR() << "STTableReader: error reading entry for key: " + << keys_[current_] << ", file: " << sources_[current_]; + error_ = true; + } + } + + + EntryReader entry_reader_; // Read functor for 'EntryType' + vector<ifstream*> streams_; // Input streams + vector<string> sources_; // and corresponding file names + vector<vector<int64> > positions_; // Index of positions for each stream + vector<string> keys_; // Lowest unread key for each stream + vector<int64> heap_; // Heap containing ID of streams with unread keys + int64 current_; // Id of current stream to be read + Compare *compare_; // Functor comparing stream IDs for the heap + mutable EntryType *entry_; // Pointer to the currently read entry + bool error_; + + DISALLOW_COPY_AND_ASSIGN(STTableReader); +}; + + +// String-to-type table header reading function template on the entry header +// type 'H' having a member function: +// Read(istream &strm, const string &filename); +// 
Checks that 'filename' is an STTable and call the H::Read() on the last +// entry in the STTable. +template <class H> +bool ReadSTTableHeader(const string &filename, H *header) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + int32 magic_number = 0, file_version = 0; + ReadType(strm, &magic_number); + ReadType(strm, &file_version); + if (magic_number != kSTTableMagicNumber) { + LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename; + return false; + } + if (file_version != kSTTableFileVersion) { + LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename; + return false; + } + int64 i = -1; + strm.seekg(-static_cast<int>(sizeof(int64)), ios_base::end); + ReadType(strm, &i); // Read number of entries + if (!strm) { + LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename; + return false; + } + if (i == 0) return true; // No entry header to read + strm.seekg(-2 * static_cast<int>(sizeof(int64)), ios_base::end); + ReadType(strm, &i); // Read position for last entry in file + strm.seekg(i); + string key; + ReadType(strm, &key); + header->Read(strm, filename + ":" + key); + if (!strm) { + LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename; + return false; + } + return true; +} + +bool IsSTTable(const string &filename); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_STTABLE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/bitmap-index.h b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/bitmap-index.h new file mode 100644 index 0000000..f5a5ba7 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/bitmap-index.h @@ -0,0 +1,183 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jeffrey Sorensen) + +#ifndef FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_ +#define FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_ + +#include <vector> +using std::vector; + +#include <fst/compat.h> + +// This class is a bitstring storage class with an index that allows +// seeking to the Nth set or clear bit in time O(Log(N)) where N is +// the length of the bit vector. In addition, it allows counting set or +// clear bits over ranges in constant time. +// +// This is accomplished by maintaining an "secondary" index of limited +// size in bits that maintains a running count of the number of bits set +// in each block of bitmap data. A block is defined as the number of +// uint64 values that can fit in the secondary index before an overflow +// occurs. +// +// To handle overflows, a "primary" index containing a running count of +// bits set in each block is created using the type uint64. 
+ +namespace fst { + +class BitmapIndex { + public: + static size_t StorageSize(size_t size) { + return ((size + kStorageBlockMask) >> kStorageLogBitSize); + } + + BitmapIndex() : bits_(NULL), size_(0) { } + + bool Get(size_t index) const { + return (bits_[index >> kStorageLogBitSize] & + (kOne << (index & kStorageBlockMask))) != 0; + } + + static void Set(uint64* bits, size_t index) { + bits[index >> kStorageLogBitSize] |= (kOne << (index & kStorageBlockMask)); + } + + static void Clear(uint64* bits, size_t index) { + bits[index >> kStorageLogBitSize] &= ~(kOne << (index & kStorageBlockMask)); + } + + size_t Bits() const { + return size_; + } + + size_t ArraySize() const { + return StorageSize(size_); + } + + // Returns the number of one bits in the bitmap + size_t GetOnesCount() const { + return primary_index_[primary_index_size() - 1]; + } + + // Returns the number of one bits in positions 0 to limit - 1. + // REQUIRES: limit <= Bits() + size_t Rank1(size_t end) const; + + // Returns the number of one bits in the range start to end - 1. + // REQUIRES: limit <= Bits() + size_t GetOnesCountInRange(size_t start, size_t end) const { + return Rank1(end) - Rank1(start); + } + + // Returns the number of zero bits in positions 0 to limit - 1. + // REQUIRES: limit <= Bits() + size_t Rank0(size_t end) const { + return end - Rank1(end); + } + + // Returns the number of zero bits in the range start to end - 1. + // REQUIRES: limit <= Bits() + size_t GetZeroesCountInRange(size_t start, size_t end) const { + return end - start - GetOnesCountInRange(start, end); + } + + // Return true if any bit between begin inclusive and end exclusive + // is set. 0 <= begin <= end <= Bits() is required. 
+ // + bool TestRange(size_t start, size_t end) const { + return Rank1(end) > Rank1(start); + } + + // Returns the offset to the nth set bit (zero based) + // or Bits() if index >= number of ones + size_t Select1(size_t bit_index) const; + + // Returns the offset to the nth clear bit (zero based) + // or Bits() if index > number of + size_t Select0(size_t bit_index) const; + + // Rebuilds from index for the associated Bitmap, should be called + // whenever changes have been made to the Bitmap or else behavior + // of the indexed bitmap methods will be undefined. + void BuildIndex(const uint64 *bits, size_t size); + + // the secondary index accumulates counts until it can possibly overflow + // this constant computes the number of uint64 units that can fit into + // units the size of uint16. + static const uint64 kOne = 1; + static const uint32 kStorageBitSize = 64; + static const uint32 kStorageLogBitSize = 6; + static const uint32 kSecondaryBlockSize = ((1 << 16) - 1) + >> kStorageLogBitSize; + + private: + static const uint32 kStorageBlockMask = kStorageBitSize - 1; + + // returns, from the index, the count of ones up to array_index + size_t get_index_ones_count(size_t array_index) const; + + // because the indexes, both primary and secondary, contain a running + // count of the population of one bits contained in [0,i), there is + // no reason to have an element in the zeroth position as this value would + // necessarily be zero. (The bits are indexed in a zero based way.) Thus + // we don't store the 0th element in either index. Both of the following + // functions, if greater than 0, must be decremented by one before retreiving + // the value from the corresponding array. 
+ // returns the 1 + the block that contains the bitindex in question + // the inverted version works the same but looks for zeros using an inverted + // view of the index + size_t find_primary_block(size_t bit_index) const; + + size_t find_inverted_primary_block(size_t bit_index) const; + + // similarly, the secondary index (which resets its count to zero at + // the end of every kSecondaryBlockSize entries) does not store the element + // at 0. Note that the rem_bit_index parameter is the number of bits + // within the secondary block, after the bits accounted for by the primary + // block have been removed (i.e. the remaining bits) And, because we + // reset to zero with each new block, there is no need to store those + // actual zeros. + // returns 1 + the secondary block that contains the bitindex in question + size_t find_secondary_block(size_t block, size_t rem_bit_index) const; + + size_t find_inverted_secondary_block(size_t block, size_t rem_bit_index) + const; + + // We create a primary index based upon the number of secondary index + // blocks. The primary index uses fields wide enough to accomodate any + // index of the bitarray so cannot overflow + // The primary index is the actual running + // count of one bits set for all blocks (and, thus, all uint64s). + size_t primary_index_size() const { + return (ArraySize() + kSecondaryBlockSize - 1) / kSecondaryBlockSize; + } + + const uint64* bits_; + size_t size_; + + // The primary index contains the running popcount of all blocks + // which means the nth value contains the popcounts of + // [0,n*kSecondaryBlockSize], however, the 0th element is omitted. + vector<uint32> primary_index_; + // The secondary index contains the running popcount of the associated + // bitmap. It is the same length (in units of uint16) as the + // bitmap's map is in units of uint64s. 
+ vector<uint16> secondary_index_; +}; + +} // end namespace fst + +#endif // FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h new file mode 100644 index 0000000..d113fb3 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h @@ -0,0 +1,934 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jeffrey Sorensen) +// +#ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ +#define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ + +#include <stddef.h> +#include <string.h> +#include <algorithm> +#include <string> +#include <vector> +using std::vector; + +#include <fst/compat.h> +#include <fst/fstlib.h> +#include <fst/mapped-file.h> +#include <fst/extensions/ngram/bitmap-index.h> + +// NgramFst implements a n-gram language model based upon the LOUDS data +// structure. Please refer to "Unary Data Strucutres for Language Models" +// http://research.google.com/pubs/archive/37218.pdf + +namespace fst { +template <class A> class NGramFst; +template <class A> class NGramFstMatcher; + +// Instance data containing mutable state for bookkeeping repeated access to +// the same state. 
+template <class A> +struct NGramFstInst { + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + StateId state_; + size_t num_futures_; + size_t offset_; + size_t node_; + StateId node_state_; + vector<Label> context_; + StateId context_state_; + NGramFstInst() + : state_(kNoStateId), node_state_(kNoStateId), + context_state_(kNoStateId) { } +}; + +// Implementation class for LOUDS based NgramFst interface +template <class A> +class NGramFstImpl : public FstImpl<A> { + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + using FstImpl<A>::SetType; + using FstImpl<A>::WriteHeader; + + friend class ArcIterator<NGramFst<A> >; + friend class NGramFstMatcher<A>; + + public: + using FstImpl<A>::InputSymbols; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + NGramFstImpl() : data_region_(0), data_(0), owned_(false) { + SetType("ngram"); + SetInputSymbols(NULL); + SetOutputSymbols(NULL); + SetProperties(kStaticProperties); + } + + NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out); + + ~NGramFstImpl() { + if (owned_) { + delete [] data_; + } + delete data_region_; + } + + static NGramFstImpl<A>* Read(istream &strm, // NOLINT + const FstReadOptions &opts) { + NGramFstImpl<A>* impl = new NGramFstImpl(); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0; + uint64 num_states, num_futures, num_final; + const size_t offset = sizeof(num_states) + sizeof(num_futures) + + sizeof(num_final); + // Peek at num_states and num_futures to see how much more needs to be read. 
+ strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states)); + strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures)); + strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final)); + size_t size = Storage(num_states, num_futures, num_final); + MappedFile *data_region = MappedFile::Allocate(size); + char *data = reinterpret_cast<char *>(data_region->mutable_data()); + // Copy num_states, num_futures and num_final back into data. + memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states)); + memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures), + sizeof(num_futures)); + memcpy(data + sizeof(num_states) + sizeof(num_futures), + reinterpret_cast<char *>(&num_final), sizeof(num_final)); + strm.read(data + offset, size - offset); + if (!strm) { + delete impl; + return NULL; + } + impl->Init(data, false, data_region); + return impl; + } + + bool Write(ostream &strm, // NOLINT + const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(Start()); + hdr.SetNumStates(num_states_); + WriteHeader(strm, opts, kFileVersion, &hdr); + strm.write(data_, Storage(num_states_, num_futures_, num_final_)); + return strm; + } + + StateId Start() const { + return 1; + } + + Weight Final(StateId state) const { + if (final_index_.Get(state)) { + return final_probs_[final_index_.Rank1(state)]; + } else { + return Weight::Zero(); + } + } + + size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const { + if (inst == NULL) { + const size_t next_zero = future_index_.Select0(state + 1); + const size_t this_zero = future_index_.Select0(state); + return next_zero - this_zero - 1; + } + SetInstFuture(state, inst); + return inst->num_futures_ + ((state == 0) ? 0 : 1); + } + + size_t NumInputEpsilons(StateId state) const { + // State 0 has no parent, thus no backoff. 
+ if (state == 0) return 0; + return 1; + } + + size_t NumOutputEpsilons(StateId state) const { + return NumInputEpsilons(state); + } + + StateId NumStates() const { + return num_states_; + } + + void InitStateIterator(StateIteratorData<A>* data) const { + data->base = 0; + data->nstates = num_states_; + } + + static size_t Storage(uint64 num_states, uint64 num_futures, + uint64 num_final) { + uint64 b64; + Weight weight; + Label label; + size_t offset = sizeof(num_states) + sizeof(num_futures) + + sizeof(num_final); + offset += sizeof(b64) * ( + BitmapIndex::StorageSize(num_states * 2 + 1) + + BitmapIndex::StorageSize(num_futures + num_states + 1) + + BitmapIndex::StorageSize(num_states)); + offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label); + // Pad for alignemnt, see + // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding + offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); + offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) + + (num_futures + 1) * sizeof(weight); + return offset; + } + + void SetInstFuture(StateId state, NGramFstInst<A> *inst) const { + if (inst->state_ != state) { + inst->state_ = state; + const size_t next_zero = future_index_.Select0(state + 1); + const size_t this_zero = future_index_.Select0(state); + inst->num_futures_ = next_zero - this_zero - 1; + inst->offset_ = future_index_.Rank1(future_index_.Select0(state) + 1); + } + } + + void SetInstNode(NGramFstInst<A> *inst) const { + if (inst->node_state_ != inst->state_) { + inst->node_state_ = inst->state_; + inst->node_ = context_index_.Select1(inst->state_); + } + } + + void SetInstContext(NGramFstInst<A> *inst) const { + SetInstNode(inst); + if (inst->context_state_ != inst->state_) { + inst->context_state_ = inst->state_; + inst->context_.clear(); + size_t node = inst->node_; + while (node != 0) { + inst->context_.push_back(context_words_[context_index_.Rank1(node)]); + node = 
context_index_.Select1(context_index_.Rank0(node) - 1); + } + } + } + + // Access to the underlying representation + const char* GetData(size_t* data_size) const { + *data_size = Storage(num_states_, num_futures_, num_final_); + return data_; + } + + void Init(const char* data, bool owned, MappedFile *file = 0); + + const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const { + SetInstFuture(s, inst); + SetInstContext(inst); + return inst->context_; + } + + private: + StateId Transition(const vector<Label> &context, Label future) const; + + // Properties always true for this Fst class. + static const uint64 kStaticProperties = kAcceptor | kIDeterministic | + kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted | + kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted | + kAccessible | kCoAccessible | kNotString | kExpanded; + // Current file format version. + static const int kFileVersion = 4; + // Minimum file format version supported. + static const int kMinFileVersion = 4; + + MappedFile *data_region_; + const char* data_; + bool owned_; // True if we own data_ + uint64 num_states_, num_futures_, num_final_; + size_t root_num_children_; + const Label *root_children_; + size_t root_first_child_; + // borrowed references + const uint64 *context_, *future_, *final_; + const Label *context_words_, *future_words_; + const Weight *backoff_, *final_probs_, *future_probs_; + BitmapIndex context_index_; + BitmapIndex future_index_; + BitmapIndex final_index_; + + void operator=(const NGramFstImpl<A> &); // Disallow +}; + +template<typename A> +NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) + : data_region_(0), data_(0), owned_(false) { + typedef A Arc; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + SetType("ngram"); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + 
SetProperties(kStaticProperties); + + // Check basic requirements for an OpenGRM language model Fst. + int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted; + if (fst.Properties(props, true) != props) { + FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input"; + SetProperties(kError, kError); + return; + } + + int64 num_states = CountStates(fst); + Label* context = new Label[num_states]; + + // Find the unigram state by starting from the start state, following + // epsilons. + StateId unigram = fst.Start(); + while (1) { + if (unigram == kNoStateId) { + FSTERROR() << "Could not identify unigram state."; + SetProperties(kError, kError); + return; + } + ArcIterator<Fst<A> > aiter(fst, unigram); + if (aiter.Done()) { + LOG(WARNING) << "Unigram state " << unigram << " has no arcs."; + break; + } + if (aiter.Value().ilabel != 0) break; + unigram = aiter.Value().nextstate; + } + + // Each state's context is determined by the subtree it is under from the + // unigram state. + queue<pair<StateId, Label> > label_queue; + vector<bool> visited(num_states); + // Force an epsilon link to the start state. + label_queue.push(make_pair(fst.Start(), 0)); + for (ArcIterator<Fst<A> > aiter(fst, unigram); + !aiter.Done(); aiter.Next()) { + label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel)); + } + // investigate states in breadth first fashion to assign context words. 
+ while (!label_queue.empty()) { + pair<StateId, Label> &now = label_queue.front(); + if (!visited[now.first]) { + context[now.first] = now.second; + visited[now.first] = true; + for (ArcIterator<Fst<A> > aiter(fst, now.first); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { + label_queue.push(make_pair(arc.nextstate, now.second)); + } + } + } + label_queue.pop(); + } + visited.clear(); + + // The arc from the start state should be assigned an epsilon to put it + // in front of the all other labels (which makes Start state 1 after + // unigram which is state 0). + context[fst.Start()] = 0; + + // Build the tree of contexts fst by reversing the epsilon arcs from fst. + VectorFst<Arc> context_fst; + uint64 num_final = 0; + for (int i = 0; i < num_states; ++i) { + if (fst.Final(i) != Weight::Zero()) { + ++num_final; + } + context_fst.SetFinal(context_fst.AddState(), fst.Final(i)); + } + context_fst.SetStart(unigram); + context_fst.SetInputSymbols(fst.InputSymbols()); + context_fst.SetOutputSymbols(fst.OutputSymbols()); + int64 num_context_arcs = 0; + int64 num_futures = 0; + for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) { + const StateId &state = siter.Value(); + num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state); + ArcIterator<Fst<A> > aiter(fst, state); + if (!aiter.Done()) { + const Arc &arc = aiter.Value(); + // this arc goes from state to arc.nextstate, so create an arc from + // arc.nextstate to state to reverse it. 
+ if (arc.ilabel == 0) { + context_fst.AddArc(arc.nextstate, Arc(context[state], context[state], + arc.weight, state)); + num_context_arcs++; + } + } + } + if (num_context_arcs != context_fst.NumStates() - 1) { + FSTERROR() << "Number of contexts arcs != number of states - 1"; + SetProperties(kError, kError); + return; + } + if (context_fst.NumStates() != num_states) { + FSTERROR() << "Number of contexts != number of states"; + SetProperties(kError, kError); + return; + } + int64 context_props = context_fst.Properties(kIDeterministic | + kILabelSorted, true); + if (!(context_props & kIDeterministic)) { + FSTERROR() << "Input fst is not structured properly"; + SetProperties(kError, kError); + return; + } + if (!(context_props & kILabelSorted)) { + ArcSort(&context_fst, ILabelCompare<Arc>()); + } + + delete [] context; + + uint64 b64; + Weight weight; + Label label = kNoLabel; + const size_t storage = Storage(num_states, num_futures, num_final); + MappedFile *data_region = MappedFile::Allocate(storage); + char *data = reinterpret_cast<char *>(data_region->mutable_data()); + memset(data, 0, storage); + size_t offset = 0; + memcpy(data + offset, reinterpret_cast<char *>(&num_states), + sizeof(num_states)); + offset += sizeof(num_states); + memcpy(data + offset, reinterpret_cast<char *>(&num_futures), + sizeof(num_futures)); + offset += sizeof(num_futures); + memcpy(data + offset, reinterpret_cast<char *>(&num_final), + sizeof(num_final)); + offset += sizeof(num_final); + uint64* context_bits = reinterpret_cast<uint64*>(data + offset); + offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64); + uint64* future_bits = reinterpret_cast<uint64*>(data + offset); + offset += + BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64); + uint64* final_bits = reinterpret_cast<uint64*>(data + offset); + offset += BitmapIndex::StorageSize(num_states) * sizeof(b64); + Label* context_words = reinterpret_cast<Label*>(data + offset); + offset += (num_states + 
1) * sizeof(label); + Label* future_words = reinterpret_cast<Label*>(data + offset); + offset += num_futures * sizeof(label); + offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); + Weight* backoff = reinterpret_cast<Weight*>(data + offset); + offset += (num_states + 1) * sizeof(weight); + Weight* final_probs = reinterpret_cast<Weight*>(data + offset); + offset += num_final * sizeof(weight); + Weight* future_probs = reinterpret_cast<Weight*>(data + offset); + int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0, + final_bit = 0; + + // pseudo-root bits + BitmapIndex::Set(context_bits, context_bit++); + ++context_bit; + context_words[context_arc] = label; + backoff[context_arc] = Weight::Zero(); + context_arc++; + + ++future_bit; + if (order_out) { + order_out->clear(); + order_out->resize(num_states); + } + + queue<StateId> context_q; + context_q.push(context_fst.Start()); + StateId state_number = 0; + while (!context_q.empty()) { + const StateId &state = context_q.front(); + if (order_out) { + (*order_out)[state] = state_number; + } + + const Weight &final = context_fst.Final(state); + if (final != Weight::Zero()) { + BitmapIndex::Set(final_bits, state_number); + final_probs[final_bit] = final; + ++final_bit; + } + + for (ArcIterator<VectorFst<A> > aiter(context_fst, state); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + context_words[context_arc] = arc.ilabel; + backoff[context_arc] = arc.weight; + ++context_arc; + BitmapIndex::Set(context_bits, context_bit++); + context_q.push(arc.nextstate); + } + ++context_bit; + + for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { + future_words[future_arc] = arc.ilabel; + future_probs[future_arc] = arc.weight; + ++future_arc; + BitmapIndex::Set(future_bits, future_bit++); + } + } + ++future_bit; + ++state_number; + context_q.pop(); + } + + if ((state_number != num_states) || + (context_bit 
!= num_states * 2 + 1) || + (context_arc != num_states) || + (future_arc != num_futures) || + (future_bit != num_futures + num_states + 1) || + (final_bit != num_final)) { + FSTERROR() << "Structure problems detected during construction"; + SetProperties(kError, kError); + return; + } + + Init(data, false, data_region); +} + +/* Init: adopts a serialized model buffer. Frees the previously held data_ (only when owned_ is set) and deletes the previous data_region_, then records the new data, owned and data_region values. The buffer is read as three uint64 counts (num_states_, num_futures_, num_final_), followed by the context, future and final bitmaps, the context_words_ and future_words_ label arrays, and the backoff_, final_probs_ and future_probs_ weight arrays. It then builds the three bitmap indices and locates the children of the root; if the bit at root_first_child_ is clear it reports "Missing unigrams", sets kError and returns early. */ template<typename A> +inline void NGramFstImpl<A>::Init(const char* data, bool owned, + MappedFile *data_region) { + if (owned_) { + delete [] data_; + } + delete data_region_; + data_region_ = data_region; + owned_ = owned; + data_ = data; + size_t offset = 0; + num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset)); + offset += sizeof(num_states_); + num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset)); + offset += sizeof(num_futures_); + num_final_ = *(reinterpret_cast<const uint64*>(data_ + offset)); + offset += sizeof(num_final_); + uint64 bits; + size_t context_bits = num_states_ * 2 + 1; + size_t future_bits = num_futures_ + num_states_ + 1; + context_ = reinterpret_cast<const uint64*>(data_ + offset); + offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits); + future_ = reinterpret_cast<const uint64*>(data_ + offset); + offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits); + final_ = reinterpret_cast<const uint64*>(data_ + offset); + offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits); + context_words_ = reinterpret_cast<const Label*>(data_ + offset); + offset += (num_states_ + 1) * sizeof(*context_words_); + future_words_ = reinterpret_cast<const Label*>(data_ + offset); + offset += num_futures_ * sizeof(*future_words_); + /* Round offset up to the next multiple of sizeof(*backoff_) before the weight arrays; the mask form is only a true round-up when that size is a power of two. */ offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1); + backoff_ = reinterpret_cast<const Weight*>(data_ + offset); + offset += (num_states_ + 1) * sizeof(*backoff_); + final_probs_ = reinterpret_cast<const Weight*>(data_ + offset); + offset += num_final_ * sizeof(*final_probs_); + future_probs_ = reinterpret_cast<const Weight*>(data_ + offset); + 
+ context_index_.BuildIndex(context_, context_bits); + future_index_.BuildIndex(future_, future_bits); + final_index_.BuildIndex(final_, num_states_); + + const size_t node_rank = context_index_.Rank1(0); + root_first_child_ = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(root_first_child_) == false) { + FSTERROR() << "Missing unigrams"; + SetProperties(kError, kError); + return; + } + const size_t last_child = context_index_.Select0(node_rank + 1) - 1; + root_num_children_ = last_child - root_first_child_ + 1; + root_children_ = context_words_ + context_index_.Rank1(root_first_child_); +} + +/* Transition: returns the state id reached for label future given the word history in context. The label is first looked up with lower_bound among the children of the root; when it is absent the result is context_index_.Rank1(0). Otherwise the search descends one node per context word, taken from the last element of context to the first, and stops at the first word that has no matching child or at a node whose first-child bit is clear. The value returned is context_index_.Rank1 of the deepest node reached. */ template<typename A> +inline typename A::StateId NGramFstImpl<A>::Transition( + const vector<Label> &context, Label future) const { + size_t num_children = root_num_children_; + const Label *children = root_children_; + const Label *loc = lower_bound(children, children + num_children, future); + if (loc == children + num_children || *loc != future) { + return context_index_.Rank1(0); + } + size_t node = root_first_child_ + loc - children; + size_t node_rank = context_index_.Rank1(node); + size_t first_child = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(first_child) == false) { + return context_index_.Rank1(node); + } + size_t last_child = context_index_.Select0(node_rank + 1) - 1; + /* NOTE(review): num_children is not read again after this assignment; the loop below recomputes last_child - first_child + 1 directly. */ num_children = last_child - first_child + 1; + for (int word = context.size() - 1; word >= 0; --word) { + children = context_words_ + context_index_.Rank1(first_child); + loc = lower_bound(children, children + last_child - first_child + 1, + context[word]); + if (loc == children + last_child - first_child + 1 || + *loc != context[word]) { + break; + } + node = first_child + loc - children; + node_rank = context_index_.Rank1(node); + first_child = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(first_child) == false) break; + last_child = context_index_.Select0(node_rank + 1) - 1; + } + return context_index_.Rank1(node); +} + 
+/*****************************************************************************/ +template<class A> +class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { + friend class ArcIterator<NGramFst<A> >; + friend class NGramFstMatcher<A>; + + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef NGramFstImpl<A> Impl; + + explicit NGramFst(const Fst<A> &dst) + : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {} + + NGramFst(const Fst<A> &fst, vector<StateId>* order_out) + : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {} + + // Because the NGramFstImpl is a const stateless data structure, there + // is never a need to do anything beside copy the reference. + NGramFst(const NGramFst<A> &fst, bool safe = false) + : ImplToExpandedFst<Impl>(fst, false) {} + + NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {} + + // Non-standard constructor to initialize NGramFst directly from data. + NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) { + GetImpl()->Init(data, owned, NULL); + } + + // Get method that gets the data associated with Init(). + const char* GetData(size_t* data_size) const { + return GetImpl()->GetData(data_size); + } + + const vector<Label> GetContext(StateId s) const { + return GetImpl()->GetContext(s, &inst_); + } + + virtual size_t NumArcs(StateId s) const { + return GetImpl()->NumArcs(s, &inst_); + } + + virtual NGramFst<A>* Copy(bool safe = false) const { + return new NGramFst(*this, safe); + } + + static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) { + Impl* impl = Impl::Read(strm, opts); + return impl ? 
new NGramFst<A>(impl) : 0; + } + + static NGramFst<A>* Read(const string &filename) { + if (!filename.empty()) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename; + return 0; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(cin, FstReadOptions("standard input")); + } + } + + virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { + return GetImpl()->Write(strm, opts); + } + + virtual bool Write(const string &filename) const { + return Fst<A>::WriteFile(filename); + } + + virtual inline void InitStateIterator(StateIteratorData<A>* data) const { + GetImpl()->InitStateIterator(data); + } + + virtual inline void InitArcIterator( + StateId s, ArcIteratorData<A>* data) const; + + virtual MatcherBase<A>* InitMatcher(MatchType match_type) const { + return new NGramFstMatcher<A>(*this, match_type); + } + + private: + explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {} + + Impl* GetImpl() const { + return + ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl(); + } + + void SetImpl(Impl* impl, bool own_impl = true) { + ImplToExpandedFst<Impl, Fst<A> >::SetImpl(impl, own_impl); + } + + mutable NGramFstInst<A> inst_; +}; + +template <class A> inline void +NGramFst<A>::InitArcIterator(StateId s, ArcIteratorData<A>* data) const { + GetImpl()->SetInstFuture(s, &inst_); + GetImpl()->SetInstNode(&inst_); + data->base = new ArcIterator<NGramFst<A> >(*this, s); +} + +/*****************************************************************************/ +template <class A> +class NGramFstMatcher : public MatcherBase<A> { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type) + : fst_(fst), inst_(fst.inst_), match_type_(match_type), + current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), 
kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + swap(loop_.ilabel, loop_.olabel); + } + } + + NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false) + : fst_(matcher.fst_), inst_(matcher.inst_), + match_type_(matcher.match_type_), current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + swap(loop_.ilabel, loop_.olabel); + } + } + + virtual NGramFstMatcher<A>* Copy(bool safe = false) const { + return new NGramFstMatcher<A>(*this, safe); + } + + virtual MatchType Type(bool test) const { + return match_type_; + } + + virtual const Fst<A> &GetFst() const { + return fst_; + } + + virtual uint64 Properties(uint64 props) const { + return props; + } + + private: + virtual void SetState_(StateId s) { + fst_.GetImpl()->SetInstFuture(s, &inst_); + current_loop_ = false; + } + + virtual bool Find_(Label label) { + const Label nolabel = kNoLabel; + done_ = true; + if (label == 0 || label == nolabel) { + if (label == 0) { + current_loop_ = true; + loop_.nextstate = inst_.state_; + } + // The unigram state has no epsilon arc. 
+ if (inst_.state_ != 0) { + arc_.ilabel = arc_.olabel = 0; + fst_.GetImpl()->SetInstNode(&inst_); + arc_.nextstate = fst_.GetImpl()->context_index_.Rank1( + fst_.GetImpl()->context_index_.Select1( + fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1)); + arc_.weight = fst_.GetImpl()->backoff_[inst_.state_]; + done_ = false; + } + } else { + const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_; + const Label *end = start + inst_.num_futures_; + const Label* search = lower_bound(start, end, label); + if (search != end && *search == label) { + size_t state = search - start; + arc_.ilabel = arc_.olabel = label; + arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state]; + fst_.GetImpl()->SetInstContext(&inst_); + arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label); + done_ = false; + } + } + return !Done_(); + } + + virtual bool Done_() const { + return !current_loop_ && done_; + } + + virtual const Arc& Value_() const { + return (current_loop_) ? loop_ : arc_; + } + + virtual void Next_() { + if (current_loop_) { + current_loop_ = false; + } else { + done_ = true; + } + } + + const NGramFst<A>& fst_; + NGramFstInst<A> inst_; + MatchType match_type_; // Supplied by caller + bool done_; + Arc arc_; + bool current_loop_; // Current arc is the implicit loop + Arc loop_; +}; + +/*****************************************************************************/ +template<class A> +class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + ArcIterator(const NGramFst<A> &fst, StateId state) + : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) { + inst_ = fst.inst_; + impl_->SetInstFuture(state, &inst_); + impl_->SetInstNode(&inst_); + } + + bool Done() const { + return i_ >= ((inst_.node_ == 0) ? 
inst_.num_futures_ : + inst_.num_futures_ + 1); + } + + const Arc &Value() const { + bool eps = (inst_.node_ != 0 && i_ == 0); + StateId state = (inst_.node_ == 0) ? i_ : i_ - 1; + if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) { + arc_.ilabel = + arc_.olabel = eps ? 0 : impl_->future_words_[inst_.offset_ + state]; + lazy_ &= ~(kArcILabelValue | kArcOLabelValue); + } + if (flags_ & lazy_ & kArcNextStateValue) { + if (eps) { + arc_.nextstate = impl_->context_index_.Rank1( + impl_->context_index_.Select1( + impl_->context_index_.Rank0(inst_.node_) - 1)); + } else { + if (lazy_ & kArcNextStateValue) { + impl_->SetInstContext(&inst_); // first time only. + } + arc_.nextstate = + impl_->Transition(inst_.context_, + impl_->future_words_[inst_.offset_ + state]); + } + lazy_ &= ~kArcNextStateValue; + } + if (flags_ & lazy_ & kArcWeightValue) { + arc_.weight = eps ? impl_->backoff_[inst_.state_] : + impl_->future_probs_[inst_.offset_ + state]; + lazy_ &= ~kArcWeightValue; + } + return arc_; + } + + void Next() { + ++i_; + lazy_ = ~0; + } + + size_t Position() const { return i_; } + + void Reset() { + i_ = 0; + lazy_ = ~0; + } + + void Seek(size_t a) { + if (i_ != a) { + i_ = a; + lazy_ = ~0; + } + } + + uint32 Flags() const { + return flags_; + } + + void SetFlags(uint32 f, uint32 m) { + flags_ &= ~m; + flags_ |= (f & kArcValueFlags); + } + + private: + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual size_t Position_() const { return Position(); } + virtual void Reset_() { Reset(); } + virtual void Seek_(size_t a) { Seek(a); } + uint32 Flags_() const { return Flags(); } + void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } + + mutable Arc arc_; + mutable uint32 lazy_; + const NGramFstImpl<A> *impl_; + mutable NGramFstInst<A> inst_; + + size_t i_; + uint32 flags_; + + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + 
+/*****************************************************************************/ +// Specialization for NGramFst; see generic version in fst.h +// for sample usage (but use the ProdLmFst type!). This version +// should inline. +template <class A> +class StateIterator<NGramFst<A> > : public StateIteratorBase<A> { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const NGramFst<A> &fst) + : s_(0), num_states_(fst.NumStates()) { } + + bool Done() const { return s_ >= num_states_; } + StateId Value() const { return s_; } + void Next() { ++s_; } + void Reset() { s_ = 0; } + + private: + virtual bool Done_() const { return Done(); } + virtual StateId Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual void Reset_() { Reset(); } + + StateId s_, num_states_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; +} // namespace fst +#endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/nthbit.h b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/nthbit.h new file mode 100644 index 0000000..d4a9a5a --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/nthbit.h @@ -0,0 +1,46 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. 
+// Author: [email protected] (Jeffrey Sorensen) +// [email protected] (Doug Rohde) + +#ifndef FST_EXTENSIONS_NGRAM_NTHBIT_H_ +#define FST_EXTENSIONS_NGRAM_NTHBIT_H_ + +#include <fst/types.h> + +extern uint32 nth_bit_bit_offset[]; + +inline uint32 nth_bit(uint64 v, uint32 r) { + uint32 shift = 0; + uint32 c = __builtin_popcount(v & 0xffffffff); + uint32 mask = -(r > c); + r -= c & mask; + shift += (32 & mask); + + c = __builtin_popcount((v >> shift) & 0xffff); + mask = -(r > c); + r -= c & mask; + shift += (16 & mask); + + c = __builtin_popcount((v >> shift) & 0xff); + mask = -(r > c); + r -= c & mask; + shift += (8 & mask); + + return shift + ((nth_bit_bit_offset[(v >> shift) & 0xff] >> + ((r - 1) << 2)) & 0xf); +} + +#endif // FST_EXTENSIONS_NGRAM_NTHBIT_H_ |