summaryrefslogblamecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/extensions/far/sttable.h
blob: 3ce0a4b67098867322c9e645c75e4f91be28162b (plain) (tree)


















































































































































































































































































































































































                                                                              
// sttable.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// A generic string-to-type table file format
//
// This is not meant as a generalization of SSTable. This is more of
// a simple replacement for SSTable in order to provide an open-source
// implementation of the FAR format for the external version of the
// FST Library.

#ifndef FST_EXTENSIONS_FAR_STTABLE_H_
#define FST_EXTENSIONS_FAR_STTABLE_H_

#include <algorithm>
#include <iostream>
#include <fstream>
#include <sstream>
#include <fst/util.h>

namespace fst {

// Magic number written as the first int32 of every STTable file; the reader
// rejects any file whose leading int32 differs from this value.
static const int32 kSTTableMagicNumber = 2125656924;
// File format version written right after the magic number; the reader
// rejects any file whose version differs from this value.
static const int32 kSTTableFileVersion = 1;

// String-to-type table writing class for object of type 'T' using functor 'W'
// to write an object of type 'T' to a stream. 'W' must conform to the
// following interface:
//
//   struct Writer {
//     void operator()(ostream &, const T &) const;
//   };
//
template <class T, class W>
class STTableWriter {
 public:
  typedef T EntryType;
  typedef W EntryWriter;

  explicit STTableWriter(const string &filename)
      : stream_(filename.c_str(), ofstream::out | ofstream::binary),
        error_(false) {
    WriteType(stream_, kSTTableMagicNumber);
    WriteType(stream_, kSTTableFileVersion);
    if (!stream_) {
      FSTERROR() << "STTableWriter::STTableWriter: error writing to file: "
                 << filename;
      error_=true;
    }
  }

  static STTableWriter<T, W> *Create(const string &filename) {
    if (filename.empty()) {
      LOG(ERROR) << "STTableWriter: writing to standard out unsupported.";
      return 0;
    }
    return new STTableWriter<T, W>(filename);
  }

  void Add(const string &key, const T &t) {
    if (key == "") {
      FSTERROR() << "STTableWriter::Add: key empty: " << key;
      error_ = true;
    } else if (key < last_key_) {
      FSTERROR() << "STTableWriter::Add: key disorder: " << key;
      error_ = true;
    }
    if (error_) return;
    last_key_ = key;
    positions_.push_back(stream_.tellp());
    WriteType(stream_, key);
    entry_writer_(stream_, t);
  }

  bool Error() const { return error_; }

  ~STTableWriter() {
    WriteType(stream_, positions_);
    WriteType(stream_, static_cast<int64>(positions_.size()));
  }

 private:
  EntryWriter entry_writer_;  // Write functor for 'EntryType'
  ofstream stream_;           // Output stream
  vector<int64> positions_;   // Position in file of each key-entry pair
  string last_key_;           // Last key
  bool error_;

  DISALLOW_COPY_AND_ASSIGN(STTableWriter);
};


// String-to-type table reading class for object of type 'T' using functor 'R'
// to read an object of type 'T' from a stream. 'R' must conform to the
// following interface:
//
//   struct Reader {
//     T *operator()(istream &) const;
//   };
//
template <class T, class R>
class STTableReader {
 public:
  typedef T EntryType;
  typedef R EntryReader;

  explicit STTableReader(const vector<string> &filenames)
      : sources_(filenames), entry_(0), error_(false) {
    compare_ = new Compare(&keys_);
    keys_.resize(filenames.size());
    streams_.resize(filenames.size(), 0);
    positions_.resize(filenames.size());
    for (size_t i = 0; i < filenames.size(); ++i) {
      streams_[i] = new ifstream(
          filenames[i].c_str(), ifstream::in | ifstream::binary);
      int32 magic_number = 0, file_version = 0;
      ReadType(*streams_[i], &magic_number);
      ReadType(*streams_[i], &file_version);
      if (magic_number != kSTTableMagicNumber) {
        FSTERROR() << "STTableReader::STTableReader: wrong file type: "
                   << filenames[i];
        error_ = true;
        return;
      }
      if (file_version != kSTTableFileVersion) {
        FSTERROR() << "STTableReader::STTableReader: wrong file version: "
                   << filenames[i];
        error_ = true;
        return;
      }
      int64 num_entries;
      streams_[i]->seekg(-static_cast<int>(sizeof(int64)), ios_base::end);
      ReadType(*streams_[i], &num_entries);
      streams_[i]->seekg(-static_cast<int>(sizeof(int64)) *
                         (num_entries + 1), ios_base::end);
      positions_[i].resize(num_entries);
      for (size_t j = 0; (j < num_entries) && (*streams_[i]); ++j)
        ReadType(*streams_[i], &(positions_[i][j]));
      streams_[i]->seekg(positions_[i][0]);
      if (!*streams_[i]) {
        FSTERROR() << "STTableReader::STTableReader: error reading file: "
                   << filenames[i];
        error_ = true;
        return;
      }

    }
    MakeHeap();
  }

  // Releases every input stream, then the heap comparison functor, then the
  // currently held entry (if any).
  ~STTableReader() {
    for (size_t s = 0, num_streams = streams_.size(); s != num_streams; ++s)
      delete streams_[s];
    delete compare_;
    delete entry_;  // Deleting a null pointer is a no-op
  }

  // Returns a newly allocated reader over the single table in 'filename'.
  // An empty name is rejected: an error is logged and 0 is returned, since
  // reading from standard input is not supported.
  static STTableReader<T, R> *Open(const string &filename) {
    if (!filename.empty()) {
      vector<string> sources(1, filename);
      return new STTableReader<T, R>(sources);
    }
    LOG(ERROR) << "STTableReader: reading from standard in not supported";
    return 0;
  }

  // Returns a newly allocated reader over all the tables in 'filenames'.
  static STTableReader<T, R> *Open(const vector<string> &filenames) {
    STTableReader<T, R> *reader = new STTableReader<T, R>(filenames);
    return reader;
  }

  // Rewinds every stream to the position of its first entry and rebuilds the
  // heap. Does nothing when the reader is in the error state.
  void Reset() {
    if (!error_) {
      for (size_t s = 0; s < streams_.size(); ++s) {
        streams_[s]->seekg(positions_[s][0]);
      }
      MakeHeap();
    }
  }

  bool Find(const string &key) {
    if (error_) return false;
    for (size_t i = 0; i < streams_.size(); ++i)
      LowerBound(i, key);
    MakeHeap();
    return keys_[current_] == key;
  }

  // True when iteration cannot continue: either the reader is in the error
  // state or the heap of stream IDs is empty.
  bool Done() const {
    if (error_) return true;
    return heap_.empty();
  }

  // Advances to the next key. If the current stream is not yet past the
  // start of its last entry, its next key is read and the stream ID is pushed
  // back into heap order; otherwise the stream is dropped from the heap.
  // PopHeap() is then called unless the heap has become empty. A failed key
  // read puts the reader in the error state.
  void Next() {
    if (error_) return;
    const bool stream_has_more =
        streams_[current_]->tellg() <= positions_[current_].back();
    if (!stream_has_more) {
      heap_.pop_back();
    } else {
      ReadType(*(streams_[current_]), &(keys_[current_]));
      if (!*streams_[current_]) {
        FSTERROR() << "STTableReader: error reading file: "
                   << sources_[current_];
        error_ = true;
        return;
      }
      push_heap(heap_.begin(), heap_.end(), *compare_);
    }
    if (heap_.empty()) return;
    PopHeap();
  }

  // Returns the key last read from the current stream.
  const string &GetKey() const {
    const string &current_key = keys_[current_];
    return current_key;
  }

  // Returns the entry currently held by the reader.
  // NOTE(review): 'entry_' starts out null in the constructor, so this
  // presumably must only be called once an entry has been read -- confirm
  // against PopHeap()/MakeHeap().
  const EntryType &GetEntry() const {
    const EntryType &current_entry = *entry_;
    return current_entry;
  }

  bool Error() const { return error_; }

 private:
  // Comparison functor ordering stream IDs by the key each stream currently
  // holds. The sense is 'greater than', so the standard heap algorithms
  // (which keep the largest element first) surface the smallest key.
  struct Compare {
    Compare(const vector<string> *keys) : keys_(keys) {}

    // True when the key of stream 'lhs' sorts after the key of stream 'rhs'.
    bool operator()(size_t lhs, size_t rhs) const {
      const vector<string> &keys = *keys_;
      return keys[rhs] < keys[lhs];
    }

   private:
    const vector<string> *keys_;  // Keys indexed by stream ID
  };

  // Position the stream with ID 'id' at the position corresponding
  // to the lower bound for key 'find_key'
  void LowerBound(size_t id, const