From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- .../tools/openfst/include/fst/extensions/far/far.h | 532 +++++++++++++++++++++ 1 file changed, 532 insertions(+) create mode 100644 kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h (limited to 'kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h') diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h new file mode 100644 index 0000000..acce76e --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h @@ -0,0 +1,532 @@ +// far.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Finite-State Transducer (FST) archive classes. +// + +#ifndef FST_EXTENSIONS_FAR_FAR_H__ +#define FST_EXTENSIONS_FAR_FAR_H__ + +#include +#include +#include +#include + +namespace fst { + +enum FarEntryType { FET_LINE, FET_FILE }; +enum FarTokenType { FTT_SYMBOL, FTT_BYTE, FTT_UTF8 }; + +inline bool IsFst(const string &filename) { + ifstream strm(filename.c_str()); + if (!strm) + return false; + return IsFstHeader(strm, filename); +} + +// FST archive header class +class FarHeader { + public: + const string &FarType() const { return fartype_; } + const string &ArcType() const { return arctype_; } + + bool Read(const string &filename) { + FstHeader fsthdr; + if (filename.empty()) { + // Header reading unsupported on stdin. Assumes STList and StdArc. + fartype_ = "stlist"; + arctype_ = "standard"; + return true; + } else if (IsSTTable(filename)) { // Check if STTable + ReadSTTableHeader(filename, &fsthdr); + fartype_ = "sttable"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } else if (IsSTList(filename)) { // Check if STList + ReadSTListHeader(filename, &fsthdr); + fartype_ = "sttable"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } else if (IsFst(filename)) { // Check if Fst + ifstream istrm(filename.c_str()); + fsthdr.Read(istrm, filename); + fartype_ = "fst"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } + return false; + } + + private: + string fartype_; + string arctype_; +}; + +enum FarType { + FAR_DEFAULT = 0, + FAR_STTABLE = 1, + FAR_STLIST = 2, + FAR_FST = 3, +}; + +// This class creates an archive of FSTs. +template +class FarWriter { + public: + typedef A Arc; + + // Creates a new (empty) FST archive; returns NULL on error. + static FarWriter *Create(const string &filename, FarType type = FAR_DEFAULT); + + // Adds an FST to the end of an archive. Keys must be non-empty and + // in lexicographic order. FSTs must have a suitable write method. + virtual void Add(const string &key, const Fst &fst) = 0; + + virtual FarType Type() const = 0; + + virtual bool Error() const = 0; + + virtual ~FarWriter() {} + + protected: + FarWriter() {} + + private: + DISALLOW_COPY_AND_ASSIGN(FarWriter); +}; + + +// This class iterates through an existing archive of FSTs. +template +class FarReader { + public: + typedef A Arc; + + // Opens an existing FST archive in a single file; returns NULL on error. + // Sets current position to the beginning of the achive. + static FarReader *Open(const string &filename); + + // Opens an existing FST archive in multiple files; returns NULL on error. + // Sets current position to the beginning of the achive. + static FarReader *Open(const vector &filenames); + + // Resets current posision to beginning of archive. + virtual void Reset() = 0; + + // Sets current position to first entry >= key. Returns true if a match. + virtual bool Find(const string &key) = 0; + + // Current position at end of archive? + virtual bool Done() const = 0; + + // Move current position to next FST. + virtual void Next() = 0; + + // Returns key at the current position. This reference is invalidated if + // the current position in the archive is changed. + virtual const string &GetKey() const = 0; + + // Returns FST at the current position. This reference is invalidated if + // the current position in the archive is changed. + virtual const Fst &GetFst() const = 0; + + virtual FarType Type() const = 0; + + virtual bool Error() const = 0; + + virtual ~FarReader() {} + + protected: + FarReader() {} + + private: + DISALLOW_COPY_AND_ASSIGN(FarReader); +}; + + +template +class FstWriter { + public: + void operator()(ostream &strm, const Fst &fst) const { + fst.Write(strm, FstWriteOptions()); + } +}; + + +template +class STTableFarWriter : public FarWriter { + public: + typedef A Arc; + + static STTableFarWriter *Create(const string &filename) { + STTableWriter, FstWriter > *writer = + STTableWriter, FstWriter >::Create(filename); + return new STTableFarWriter(writer); + } + + void Add(const string &key, const Fst &fst) { writer_->Add(key, fst); } + + FarType Type() const { return FAR_STTABLE; } + + bool Error() const { return writer_->Error(); } + + ~STTableFarWriter() { delete writer_; } + + private: + explicit STTableFarWriter(STTableWriter, FstWriter > *writer) + : writer_(writer) {} + + private: + STTableWriter, FstWriter > *writer_; + + DISALLOW_COPY_AND_ASSIGN(STTableFarWriter); +}; + + +template +class STListFarWriter : public FarWriter { + public: + typedef A Arc; + + static STListFarWriter *Create(const string &filename) { + STListWriter, FstWriter > *writer = + STListWriter, FstWriter >::Create(filename); + return new STListFarWriter(writer); + } + + void Add(const string &key, const Fst &fst) { writer_->Add(key, fst); } + + FarType Type() const { return FAR_STLIST; } + + bool Error() const { return writer_->Error(); } + + ~STListFarWriter() { delete writer_; } + + private: + explicit STListFarWriter(STListWriter, FstWriter > *writer) + : writer_(writer) {} + + private: + STListWriter, FstWriter > *writer_; + + DISALLOW_COPY_AND_ASSIGN(STListFarWriter); +}; + + +template +class FstFarWriter : public FarWriter { + public: + typedef A Arc; + + explicit FstFarWriter(const string &filename) + : filename_(filename), error_(false), written_(false) {} + + static FstFarWriter *Create(const string &filename) { + return new FstFarWriter(filename); + } + + void Add(const string &key, const Fst &fst) { + if (written_) { + LOG(WARNING) << "FstFarWriter::Add: only one Fst supported," + << " subsequent entries discarded."; + } else { + error_ = !fst.Write(filename_); + written_ = true; + } + } + + FarType Type() const { return FAR_FST; } + + bool Error() const { return error_; } + + ~FstFarWriter() {} + + private: + string filename_; + bool error_; + bool written_; + + DISALLOW_COPY_AND_ASSIGN(FstFarWriter); +}; + + +template +FarWriter *FarWriter::Create(const string &filename, FarType type) { + switch(type) { + case FAR_DEFAULT: + if (filename.empty()) + return STListFarWriter::Create(filename); + case FAR_STTABLE: + return STTableFarWriter::Create(filename); + break; + case FAR_STLIST: + return STListFarWriter::Create(filename); + break; + case FAR_FST: + return FstFarWriter::Create(filename); + break; + default: + LOG(ERROR) << "FarWriter::Create: unknown far type"; + return 0; + } +} + + +template +class FstReader { + public: + Fst *operator()(istream &strm) const { + return Fst::Read(strm, FstReadOptions()); + } +}; + + +template +class STTableFarReader : public FarReader { + public: + typedef A Arc; + + static STTableFarReader *Open(const string &filename) { + STTableReader, FstReader > *reader = + STTableReader, FstReader >::Open(filename); + // TODO: error check + return new STTableFarReader(reader); + } + + static STTableFarReader *Open(const vector &filenames) { + STTableReader, FstReader > *reader = + STTableReader, FstReader >::Open(filenames); + // TODO: error check + return new STTableFarReader(reader); + } + + void Reset() { reader_->Reset(); } + + bool Find(const string &key) { return reader_->Find(key); } + + bool Done() const { return reader_->Done(); } + + void Next() { return reader_->Next(); } + + const string &GetKey() const { return reader_->GetKey(); } + + const Fst &GetFst() const { return reader_->GetEntry(); } + + FarType Type() const { return FAR_STTABLE; } + + bool Error() const { return reader_->Error(); } + + ~STTableFarReader() { delete reader_; } + + private: + explicit STTableFarReader(STTableReader, FstReader > *reader) + : reader_(reader) {} + + private: + STTableReader, FstReader > *reader_; + + DISALLOW_COPY_AND_ASSIGN(STTableFarReader); +}; + + +template +class STListFarReader : public FarReader { + public: + typedef A Arc; + + static STListFarReader *Open(const string &filename) { + STListReader, FstReader > *reader = + STListReader, FstReader >::Open(filename); + // TODO: error check + return new STListFarReader(reader); + } + + static STListFarReader *Open(const vector &filenames) { + STListReader, FstReader > *reader = + STListReader, FstReader >::Open(filenames); + // TODO: error check + return new STListFarReader(reader); + } + + void Reset() { reader_->Reset(); } + + bool Find(const string &key) { return reader_->Find(key); } + + bool Done() const { return reader_->Done(); } + + void Next() { return reader_->Next(); } + + const string &GetKey() const { return reader_->GetKey(); } + + const Fst &GetFst() const { return reader_->GetEntry(); } + + FarType Type() const { return FAR_STLIST; } + + bool Error() const { return reader_->Error(); } + + ~STListFarReader() { delete reader_; } + + private: + explicit STListFarReader(STListReader, FstReader > *reader) + : reader_(reader) {} + + private: + STListReader, FstReader > *reader_; + + DISALLOW_COPY_AND_ASSIGN(STListFarReader); +}; + +template +class FstFarReader : public FarReader { + public: + typedef A Arc; + + static FstFarReader *Open(const string &filename) { + vector filenames; + filenames.push_back(filename); + return new FstFarReader(filenames); + } + + static FstFarReader *Open(const vector &filenames) { + return new FstFarReader(filenames); + } + + FstFarReader(const vector &filenames) + : keys_(filenames), has_stdin_(false), pos_(0), fst_(0), error_(false) { + sort(keys_.begin(), keys_.end()); + streams_.resize(keys_.size(), 0); + for (size_t i = 0; i < keys_.size(); ++i) { + if (keys_[i].empty()) { + if (!has_stdin_) { + streams_[i] = &cin; + //sources_[i] = "stdin"; + has_stdin_ = true; + } else { + FSTERROR() << "FstFarReader::FstFarReader: stdin should only " + << "appear once in the input file list."; + error_ = true; + return; + } + } else { + streams_[i] = new ifstream( + keys_[i].c_str(), ifstream::in | ifstream::binary); + } + } + if (pos_ >= keys_.size()) return; + ReadFst(); + } + + void Reset() { + if (has_stdin_) { + FSTERROR() << "FstFarReader::Reset: operation not supported on stdin"; + error_ = true; + return; + } + pos_ = 0; + ReadFst(); + } + + bool Find(const string &key) { + if (has_stdin_) { + FSTERROR() << "FstFarReader::Find: operation not supported on stdin"; + error_ = true; + return false; + } + pos_ = 0;//TODO + ReadFst(); + return true; + } + + bool Done() const { return error_ || pos_ >= keys_.size(); } + + void Next() { + ++pos_; + ReadFst(); + } + + const string &GetKey() const { + return keys_[pos_]; + } + + const Fst &GetFst() const { + return *fst_; + } + + FarType Type() const { return FAR_FST; } + + bool Error() const { return error_; } + + ~FstFarReader() { + if (fst_) delete fst_; + for (size_t i = 0; i < keys_.size(); ++i) + delete streams_[i]; + } + + private: + void ReadFst() { + if (fst_) delete fst_; + if (pos_ >= keys_.size()) return; + streams_[pos_]->seekg(0); + fst_ = Fst::Read(*streams_[pos_], FstReadOptions()); + if (!fst_) { + FSTERROR() << "FstFarReader: error reading Fst from: " << keys_[pos_]; + error_ = true; + } + } + + private: + vector keys_; + vector streams_; + bool has_stdin_; + size_t pos_; + mutable Fst *fst_; + mutable bool error_; + + DISALLOW_COPY_AND_ASSIGN(FstFarReader); +}; + +template +FarReader *FarReader::Open(const string &filename) { + if (filename.empty()) + return STListFarReader::Open(filename); + else if (IsSTTable(filename)) + return STTableFarReader::Open(filename); + else if (IsSTList(filename)) + return STListFarReader::Open(filename); + else if (IsFst(filename)) + return FstFarReader::Open(filename); + return 0; +} + + +template +FarReader *FarReader::Open(const vector &filenames) { + if (!filenames.empty() && filenames[0].empty()) + return STListFarReader::Open(filenames); + else if (!filenames.empty() && IsSTTable(filenames[0])) + return STTableFarReader::Open(filenames); + else if (!filenames.empty() && IsSTList(filenames[0])) + return STListFarReader::Open(filenames); + else if (!filenames.empty() && IsFst(filenames[0])) + return FstFarReader::Open(filenames); + return 0; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_FAR_H__ -- cgit v1.2.3