// Retrieved from a cgit web view (summary | refs | log | blame | commit | diff).
// path: root/kaldi_io/src/tools/openfst/include/fst/interval-set.h
// blob: 58cad449b8fc303b1f00b6c297cdc1089da12fe5 (plain) (tree)




























































































































































































































































































































































































                                                                                 
// interval-set.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: [email protected] (Michael Riley)
//
// \file
// Class to represent and operate on sets of intervals.

#ifndef FST_LIB_INTERVAL_SET_H__
#define FST_LIB_INTERVAL_SET_H__

#include <iostream>
#include <vector>
using std::vector;


#include <fst/util.h>


namespace fst {

// Stores and operates on a set of half-open integral intervals [a,b)
// of signed integers of type T.
template <typename T>
class IntervalSet {
 public:
  // A half-open interval [begin_, end_). A default-constructed Interval is
  // [-1, -1), i.e. empty.
  struct Interval {
    T begin_;
    T end_;

    Interval() : begin_(-1), end_(-1) {}

    Interval(T b, T e) : begin_(b), end_(e) {}

    // Orders by ascending begin_; ties are broken by DESCENDING end_, so a
    // longer interval sorts before a shorter one starting at the same point.
    // Member() relies on this ordering for its lower_bound probe.
    bool operator<(const Interval &i) const {
      return begin_ < i.begin_ || (begin_ == i.begin_ && end_ > i.end_);
    }

    bool operator==(const Interval &i) const {
      return begin_ == i.begin_ && end_ == i.end_;
    }

    bool operator!=(const Interval &i) const {
      return begin_ != i.begin_ || end_ != i.end_;
    }

    // Reads begin_ and then end_ from 'strm' via ReadType.
    istream &Read(istream &strm) {
      T n;
      ReadType(strm, &n);
      begin_ = n;
      ReadType(strm, &n);
      end_ = n;
      return strm;
    }

    // Writes begin_ and then end_ to 'strm' via WriteType.
    ostream &Write(ostream &strm) const {
      T n = begin_;
      WriteType(strm, n);
      n = end_;
      WriteType(strm, n);
      return strm;
    }
  };

  // Creates an empty set whose count is -1 (undefined) until Normalize(),
  // Clear() or one of the normalized-result operations sets it.
  IntervalSet() : count_(-1) {}

  // Returns the interval set as a vector. The mutable overload lets callers
  // edit the intervals directly; the set is then not necessarily normalized.
  vector<Interval> *Intervals() { return &intervals_; }

  const vector<Interval> *Intervals() const { return &intervals_; }

  bool Empty() const { return intervals_.empty(); }

  // Number of intervals (not points) currently stored.
  T Size() const { return intervals_.size(); }

  // Number of points in the intervals (undefined if not normalized).
  T Count() const { return count_; }

  // Removes all intervals and sets the count to 0.
  void Clear() {
    intervals_.clear();
    count_ = 0;
  }

  // Adds an interval set to the set. The result may not be normalized.
  // 'iset' may be this set itself; every interval is then duplicated.
  void Union(const IntervalSet<T> &iset) {
    const vector<Interval> &src = iset.intervals_;
    // Snapshot the length and reserve up front: 'src' may alias intervals_,
    // so appending must neither reallocate under 'src' nor revisit the
    // freshly appended elements.
    const size_t n = src.size();
    intervals_.reserve(intervals_.size() + n);
    for (size_t k = 0; k < n; ++k)
      intervals_.push_back(src[k]);
  }

  // Returns true iff 'value' lies in one of the intervals. Binary search,
  // so it requires intervals be normalized.
  bool Member(T value) const {
    Interval interval(value, value);
    typename vector<Interval>::const_iterator lb =
        lower_bound(intervals_.begin(), intervals_.end(), interval);
    if (lb == intervals_.begin())
      return false;
    // The predecessor of lb is the last interval starting at or before
    // 'value'; 'value' is a member iff that interval extends past it.
    return (--lb)->end_ > value;
  }

  // Requires intervals be normalized.
  bool operator==(const IntervalSet<T>& iset) const {
    return *(iset.Intervals()) == intervals_;
  }

  // Requires intervals be normalized.
  bool operator!=(const IntervalSet<T>& iset) const {
    return *(iset.Intervals()) != intervals_;
  }

  // True iff the set consists of exactly one interval holding one point.
  bool Singleton() const {
    return intervals_.size() == 1 &&
        intervals_[0].begin_ + 1 == intervals_[0].end_;
  }


  // Sorts; collapses overlapping and adjacent intervals; sets count.
  void Normalize();

  // Intersects an interval set with the set. Requires intervals be
  // normalized. The result is normalized.
  void Intersect(const IntervalSet<T> &iset, IntervalSet<T> *oset) const;

  // Complements the set w.r.t [0, maxval). Requires intervals be
  // normalized. The result is normalized.
  void Complement(T maxval, IntervalSet<T> *oset) const;

  // Subtract an interval set from the set. Requires intervals be
  // normalized. The result is normalized.
  void Difference(const IntervalSet<T> &iset, IntervalSet<T> *oset) const;

  // Determines if an interval set overlaps with the set. Requires
  // intervals be normalized.
  bool Overlaps(const IntervalSet<T> &iset) const;

  // Determines if an interval set overlaps with the set but neither
  // is contained in the other. Requires intervals be normalized.
  bool StrictlyOverlaps(const IntervalSet<T> &iset) const;

  // Determines if an interval set is contained within the set. Requires
  // intervals be normalized.
  bool Contains(const IntervalSet<T> &iset) const;

  // Reads the interval vector followed by the count.
  istream &Read(istream &strm) {
    ReadType(strm, &intervals_);
    return ReadType(strm, &count_);
  }

  // Writes the interval vector followed by the count.
  ostream &Write(ostream &strm) const {
    WriteType(strm, intervals_);
    return WriteType(strm, count_);
  }

 private:
  vector<Interval> intervals_;
  T count_;
};

// Sorts; collapses overlapping and adjacent intervals; sets count.
template <typename T>
void IntervalSet<T>::Normalize() {
  sort(intervals_.begin(), intervals_.end());

  count_ = 0;
  // Vector positions are indexed with size_t rather than T: T is the value
  // type of the interval endpoints and may be too narrow (e.g. short) to
  // index every element of intervals_ without overflowing.
  size_t size = 0;  // number of merged intervals written back so far
  for (size_t i = 0; i < intervals_.size(); ++i) {
    Interval &inti = intervals_[i];
    if (inti.begin_ == inti.end_)
      continue;  // drop empty intervals
    // Absorb every following interval that overlaps or abuts inti; after the
    // sort their begin_ values are all >= inti.begin_.
    for (size_t j = i + 1; j < intervals_.size(); ++j) {
      Interval &intj = intervals_[j];
      if (intj.begin_ > inti.end_)
        break;
      if (intj.end_ > inti.end_)
        inti.end_ = intj.end_;
      ++i;  // intj is merged into inti; skip it in the outer loop
    }
    count_ += inti.end_ - inti.begin_;
    intervals_[size++] = inti;  // compact in place; size <= i always holds
  }
  intervals_.resize(size);
}

// Intersects an interval set with the set. Requires intervals be normalized.
// The result is normalized.
template <typename T>
void IntervalSet<T>::Intersect(const IntervalSet<T> &iset,
                               IntervalSet<T> *oset) const {
  const vector<Interval> *iintervals = iset.Intervals();
  typename vector<Interval>::const_iterator it1 = intervals_.begin();
  typename vector<Interval>::const_iterator it2 = iintervals->begin();

  // Build the result in locals and publish it at the end: 'oset' may be this
  // set or 'iset', and clearing it up front would destroy an input before
  // the merge reads it (a.Intersect(b, &a) used to yield an empty set).
  vector<Interval> result;
  T count = 0;

  while (it1 != intervals_.end() && it2 != iintervals->end()) {
    if (it1->end_ <= it2->begin_) {
      ++it1;  // it1 lies entirely before it2
    } else if (it2->end_ <= it1->begin_) {
      ++it2;  // it2 lies entirely before it1
    } else {
      Interval interval;
      interval.begin_ = max(it1->begin_, it2->begin_);
      interval.end_ = min(it1->end_, it2->end_);
      result.push_back(interval);
      count += interval.end_ - interval.begin_;
      // Advance whichever interval finishes first; the other one may still
      // overlap the next interval of the opposite set.
      if (it1->end_ < it2->end_)
        ++it1;
      else
        ++it2;
    }
  }

  oset->Intervals()->swap(result);
  oset->count_ = count;
}

// Complements the set w.r.t [0, maxval). Requires intervals be normalized.
// The result is normalized.
template <typename T>
void IntervalSet<T>::Complement(T maxval, IntervalSet<T> *oset) const {
  // Accumulate into locals and publish at the end: 'oset' may be this set,
  // and clearing it first would make s.Complement(m, &s) return [0, m).
  vector<Interval> result;
  T count = 0;

  Interval interval;
  interval.begin_ = 0;  // start of the current gap
  for (typename vector<Interval>::const_iterator it = intervals_.begin();
       it != intervals_.end();
       ++it) {
    // The gap runs from the end of the previous interval (or 0) up to the
    // start of this one, clipped to maxval.
    interval.end_ = min(it->begin_, maxval);
    if (interval.begin_ < interval.end_) {
      result.push_back(interval);
      count += interval.end_ - interval.begin_;
    }
    // NOTE(review): the next gap starts at it->end_ without clamping to 0,
    // so a set containing negative points can yield output below 0.
    interval.begin_ = it->end_;
  }
  // Trailing gap after the last interval.
  interval.end_ = maxval;
  if (interval.begin_ < interval.end_) {
    result.push_back(interval);
    count += interval.end_ - interval.begin_;
  }

  oset->Intervals()->swap(result);
  oset->count_ = count;
}

// Subtract an interval set from the set. Requires intervals be normalized.
// The result is normalized.
template <typename T>
void IntervalSet<T>::Difference(const IntervalSet<T> &iset,
                                IntervalSet<T> *oset) const {
  // Direct two-pointer sweep instead of this ∩ Complement(iset, [0, max)):
  // the complement form silently dropped every negative point of this set,
  // although the endpoints are signed (e.g. {[-5,-1)} \ {} gave {}).
  const vector<Interval> *iintervals = iset.Intervals();
  typename vector<Interval>::const_iterator it2 = iintervals->begin();

  // Built locally so that 'oset' may alias this set or 'iset'.
  vector<Interval> result;
  T count = 0;

  for (typename vector<Interval>::const_iterator it1 = intervals_.begin();
       it1 != intervals_.end(); ++it1) {
    T lo = it1->begin_;         // start of the not-yet-subtracted remainder
    const T hi = it1->end_;
    // Subtrahend intervals ending at or before 'lo' cannot affect this or
    // any later interval of this set (those start even further right).
    while (it2 != iintervals->end() && it2->end_ <= lo)
      ++it2;
    // Cut out every subtrahend interval overlapping [lo, hi). A local
    // iterator is used because the last one may extend past 'hi' and also
    // overlap the next interval of this set.
    typename vector<Interval>::const_iterator it = it2;
    for (; lo < hi && it != iintervals->end() && it->begin_ < hi; ++it) {
      if (lo < it->begin_) {
        result.push_back(Interval(lo, it->begin_));
        count += it->begin_ - lo;
      }
      if (lo < it->end_)
        lo = it->end_;
    }
    if (lo < hi) {
      result.push_back(Interval(lo, hi));
      count += hi - lo;
    }
  }

  oset->Intervals()->swap(result);
  oset->count_ = count;
}

// Determines if an interval set overlaps with the set. Requires
// intervals be normalized.
template <typename T>
bool IntervalSet<T>::Overlaps(const IntervalSet<T> &iset) const {
  typedef typename vector<Interval>::const_iterator Iter;
  const vector<Interval> &other = *iset.Intervals();
  Iter a = intervals_.begin();
  Iter b = other.begin();
  const Iter a_end = intervals_.end();
  const Iter b_end = other.end();

  // Two-pointer sweep over both sorted lists: whenever the current pair is
  // disjoint, discard the one that ends first and compare again.
  while (a != a_end && b != b_end) {
    const bool share_a_point = a->end_ > b->begin_ && b->end_ > a->begin_;
    if (share_a_point)
      return true;
    if (a->end_ <= b->begin_)
      ++a;
    else
      ++b;
  }
  return false;
}

// Determines if an interval set overlaps with the set but neither
// is contained in the other. Requires intervals be normalized.
template <typename T>
bool IntervalSet<T>::StrictlyOverlaps(const IntervalSet<T> &iset) const {
  typedef typename vector<Interval>::const_iterator Iter;
  const vector<Interval> &mine = intervals_;
  const vector<Interval> &theirs = *iset.Intervals();
  Iter p = mine.begin();
  Iter q = theirs.begin();
  bool in_mine_only = false;    // a point in this set but not in iset
  bool in_theirs_only = false;  // a point in iset but not in this set
  bool in_both = false;         // a point in both sets

  while (p != mine.end() && q != theirs.end()) {
    const bool p_holds_q = p->begin_ <= q->begin_ && p->end_ >= q->end_;
    const bool q_holds_p = q->begin_ <= p->begin_ && q->end_ >= p->end_;

    if (p->end_ <= q->begin_) {
      // Disjoint, p comes first.
      in_mine_only = true;
      ++p;
    } else if (q->end_ <= p->begin_) {
      // Disjoint, q comes first.
      in_theirs_only = true;
      ++q;
    } else if (p_holds_q && q_holds_p) {
      // Identical intervals.
      in_both = true;
      ++p;
      ++q;
    } else if (q_holds_p) {
      // p is a proper sub-interval of q.
      in_theirs_only = true;
      in_both = true;
      ++p;
    } else if (p_holds_q) {
      // q is a proper sub-interval of p.
      in_mine_only = true;
      in_both = true;
      ++q;
    } else {
      // Partial overlap: each side has private points and shared points.
      return true;
    }

    if (in_mine_only && in_theirs_only && in_both)
      return true;
  }

  // Leftover intervals on either side are points the other set lacks.
  if (p != mine.end())
    in_mine_only = true;
  if (q != theirs.end())
    in_theirs_only = true;

  return in_mine_only && in_theirs_only && in_both;
}

// Determines if an interval set is contained within the set. Requires
// intervals be normalized.
template <typename T>
bool IntervalSet<T>::Contains(const IntervalSet<T> &iset) const {
  // Cheap rejection: a subset cannot hold more points than its superset.
  if (Count() < iset.Count())
    return false;

  const vector<Interval> &inner = *iset.Intervals();
  typename vector<Interval>::const_iterator sup_it = intervals_.begin();
  typename vector<Interval>::const_iterator sub_it = inner.begin();

  while (sup_it != intervals_.end() && sub_it != inner.end()) {
    const Interval &sup = *sup_it;
    const Interval &sub = *sub_it;
    if (sup.end_ <= sub.begin_) {
      // sup lies wholly before sub; try the next interval of this set.
      ++sup_it;
      continue;
    }
    // sub must fit inside sup, otherwise some point of sub is uncovered.
    if (sub.begin_ < sup.begin_ || sup.end_ < sub.end_)
      return false;
    if (sub.end_ == sup.end_)
      ++sup_it;  // sup is used up
    ++sub_it;
  }
  // Every interval of iset must have been matched.
  return sub_it == inner.end();
}

template <typename T>
ostream &operator<<(ostream &strm, const IntervalSet<T> &s)  {
  typedef typename IntervalSet<T>::Interval Interval;
  const vector<Interval> *intervals = s.Intervals();
  strm << "{";
  for (typename vector<Interval>::const_iterator it = intervals->begin();
       it != intervals->end();
       ++it) {
    if (it != intervals->begin())
      strm << ",";
    strm << "[" << it->begin_ << "," << it->end_ << ")";
  }
  strm << "}";
  return strm;
}

}  // namespace fst

#endif  // FST_LIB_INTERVAL_SET_H__