summaryrefslogblamecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/interval-set.h
blob: 58cad449b8fc303b1f00b6c297cdc1089da12fe5 (plain) (tree)




























































































































































































































































































































































































                                                                                 
// interval-set.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Class to represent and operate on sets of intervals.

#ifndef FST_LIB_INTERVAL_SET_H__
#define FST_LIB_INTERVAL_SET_H__

#include <iostream>
#include <vector>
using std::vector;


#include <fst/util.h>


namespace fst {

// Stores and operates on a set of half-open integral intervals [a,b)
// of signed integers of type T.
template <typename T>
class IntervalSet {
 public:
  struct Interval {
    T begin_;
    T end_;

    Interval() : begin_(-1), end_(-1) {}

    Interval(T b, T e) : begin_(b), end_(e) {}

    bool operator<(const Interval &i) const {
      return begin_ < i.begin_ || (begin_ == i.begin_ && end_ > i.end_);
    }

    bool operator==(const Interval &i) const {
      return begin_ == i.begin_ && end_ == i.end_;
    }

    bool operator!=(const Interval &i) const {
      return begin_ != i.begin_ || end_ != i.end_;
    }

    istream &Read(istream &strm) {
      T n;
      ReadType(strm, &n);
      begin_ = n;
      ReadType(strm, &n);
      end_ = n;
      return strm;
    }

    ostream &Write(ostream &strm) const {
      T n = begin_;
      WriteType(strm, n);
      n = end_;
      WriteType(strm, n);
      return strm;
    }
  };

  IntervalSet() : count_(-1) {}

  // Returns the interval set as a vector.
  vector<Interval> *Intervals() { return &intervals_; }

  const vector<Interval> *Intervals() const { return &intervals_; }

  bool Empty() const { return intervals_.empty(); }

  T Size() const { return intervals_.size(); }

  // Number of points in the intervals (undefined if not normalized).
  T Count() const { return count_; }

  void Clear() {
    intervals_.clear();
    count_ = 0;
  }

  // Adds an interval set to the set. The result may not be normalized.
  void Union(const IntervalSet<T> &iset) {
    const vector<Interval> *intervals = iset.Intervals();
    for (typename vector<Interval>::const_iterator it = intervals->begin();
         it != intervals->end(); ++it)
      intervals_.push_back(*it);
  }

  // Requires intervals be normalized.
  bool Member(T value) const {
    Interval interval(value, value);
    typename vector<Interval>::const_iterator lb =
        lower_bound(intervals_.begin(), intervals_.end(), interval);
    if (lb == intervals_.begin())
      return false;
    return (--lb)->end_ > value;
  }

  // Requires intervals be normalized.
  bool operator==(const IntervalSet<T>& iset) const {
    return *(iset.Intervals()) == intervals_;
  }

  // Requires intervals be normalized.
  bool operator!=(const IntervalSet<T>& iset) const {
    return *(iset.Intervals()) != intervals_;
  }

  bool Singleton() const {
    return intervals_.size() == 1 &&
        intervals_[0].begin_ + 1 == intervals_[0].end_;
  }


  // Sorts; collapses overlapping and adjacent interals; sets count.
  void Normalize();

  // Intersects an interval set with the set. Requires intervals be
  // normalized. The result is normalized.
  void Intersect(const IntervalSet<T> &iset, IntervalSet<T> *oset) const;

  // Complements the set w.r.t [0, maxval). Requires intervals be
  // normalized. The result is normalized.
  void Complement(T maxval, IntervalSet<T> *oset) const;

  // Subtract an interval set from the set. Requires intervals be
  // normalized. The result is normalized.
  void Difference(const IntervalSet<T> &iset, IntervalSet<T> *oset) const;

  // Determines if an interval set overlaps with the set. Requires
  // intervals be normalized.
  bool Overlaps(const IntervalSet<T> &iset) const;

  // Determines if an interval set overlaps with the set but neither
  // is contained in the other. Requires intervals be normalized.
  bool StrictlyOverlaps(const IntervalSet<T> &iset) const;

  // Determines if an interval set is contained within the set. Requires
  // intervals be normalized.
  bool Contains(const IntervalSet<T> &iset) const;

  istream &Read(istream &strm) {
    ReadType(strm, &intervals_);
    return ReadType(strm, &count_);
  }

  ostream &Write(ostream &strm) const {
    WriteType(strm, intervals_);
    return WriteType(strm, count_);
  }

 private:
  vector<Interval> intervals_;
  T count_;
};

// Sorts; collapses overlapping and adjacent intervals; sets count.
// Empty intervals are dropped. An interval whose begin_ equals the running
// end_ is absorbed, so touching intervals are merged as well as overlapping
// ones. Afterwards count_ holds the total number of points covered.
template <typename T>
void IntervalSet<T>::Normalize() {
  sort(intervals_.begin(), intervals_.end());

  typedef typename vector<Interval>::size_type Index;
  const Index total = intervals_.size();
  Index write = 0;  // Next slot for a merged interval; never passes read.
  Index read = 0;   // Next input interval to examine.
  count_ = 0;
  while (read < total) {
    Interval merged = intervals_[read];
    ++read;
    if (merged.begin_ == merged.end_)
      continue;  // Empty interval: contributes no points.
    // Sorted order guarantees later intervals start at or after
    // merged.begin_; absorb each that starts no later than merged.end_.
    while (read < total && intervals_[read].begin_ <= merged.end_) {
      if (merged.end_ < intervals_[read].end_)
        merged.end_ = intervals_[read].end_;
      ++read;
    }
    count_ += merged.end_ - merged.begin_;
    intervals_[write] = merged;
    ++write;
  }
  intervals_.resize(write);
}

// Intersects an interval set with the set. Requires intervals be normalized.
// The result is normalized and its count is set. *oset may alias *this or
// iset (e.g. a.Intersect(b, &a)): the result is accumulated in a local
// buffer and stored into *oset only after both inputs have been read.
template <typename T>
void IntervalSet<T>::Intersect(const IntervalSet<T> &iset,
                               IntervalSet<T> *oset) const {
  const vector<Interval> *iintervals = iset.Intervals();
  vector<Interval> result;
  T count = 0;
  typename vector<Interval>::const_iterator it1 = intervals_.begin();
  typename vector<Interval>::const_iterator it2 = iintervals->begin();

  while (it1 != intervals_.end() && it2 != iintervals->end()) {
    if (it1->end_ <= it2->begin_) {
      ++it1;  // *it1 lies entirely before *it2.
    } else if (it2->end_ <= it1->begin_) {
      ++it2;  // *it2 lies entirely before *it1.
    } else {
      // The intervals overlap; emit their common part.
      Interval interval;
      interval.begin_ = max(it1->begin_, it2->begin_);
      interval.end_ = min(it1->end_, it2->end_);
      result.push_back(interval);
      count += interval.end_ - interval.begin_;
      // Advance whichever ends first; the other may overlap the next one.
      if (it1->end_ < it2->end_)
        ++it1;
      else
        ++it2;
    }
  }

  // Publish only now, so aliasing inputs were never cleared mid-computation.
  oset->Intervals()->swap(result);
  oset->count_ = count;
}

// Complements the set w.r.t [0, maxval). Requires intervals be normalized.
// The result is normalized.
template <typename T>
void IntervalSet<T>::Complement(T maxval, IntervalSet<T> *oset) const {
  vector<Interval> *ointervals = oset->Intervals();
  ointervals->clear();
  oset->count_ = 0;

  Interval interval;
  interval.begin_ = 0;
  for (typename vector<Interval>::