summaryrefslogblamecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/script/shortest-distance.h
blob: 5fc2976b7b51be58a744ec4ac10fea6b2ad040c3 (plain) (tree)

























































































































































































































































                                                                               
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: jpr@google.com (Jake Ratkiewicz)

#ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_
#define FST_SCRIPT_SHORTEST_DISTANCE_H_

#include <vector>
using std::vector;

#include <fst/script/arg-packs.h>
#include <fst/script/fst-class.h>
#include <fst/script/weight-class.h>
#include <fst/script/prune.h>  // for ArcFilterType
#include <fst/queue.h>  // for QueueType
#include <fst/shortest-distance.h>

namespace fst {
namespace script {

// Selects the arc filter class that the script-level ShortestDistance
// (variant 1) plugs into both the queue and the search.
enum ArcFilterType {
  ANY_ARC_FILTER,             // AnyArcFilter<Arc>
  EPSILON_ARC_FILTER,         // EpsilonArcFilter<Arc>
  INPUT_EPSILON_ARC_FILTER,   // InputEpsilonArcFilter<Arc>
  OUTPUT_EPSILON_ARC_FILTER   // OutputEpsilonArcFilter<Arc>
};

// See fst/shortest-distance.h for the templated options class
// (fst::ShortestDistanceOptions) that this one shadows.
struct ShortestDistanceOptions {
  const QueueType queue_type;
  const ArcFilterType arc_filter_type;
  const int64 source;
  const float delta;
  const bool first_path;

  ShortestDistanceOptions(QueueType qt, ArcFilterType aft, int64 s,
                          float d)
      : queue_type(qt), arc_filter_type(aft), source(s), delta(d),
        first_path(false) { }
};



// 1
// Argument pack for ShortestDistance variant 1: the input FST, the output
// vector that receives the computed distances as WeightClass values, and the
// script-level options selecting queue type, arc filter, source and delta.
typedef args::Package<const FstClass &, vector<WeightClass> *,
                      const ShortestDistanceOptions &> ShortestDistanceArgs1;

// Factory for the queue handed to fst::ShortestDistance. The generic case
// needs neither the FST nor the distance vector and just value-initializes
// a Queue; queues with richer constructors are specialized separately.
// The returned pointer is heap-allocated; callers in this file delete it.
template<class Queue, class Arc, class ArcFilter>
struct QueueConstructor {
  static Queue *Construct(
      const Fst<Arc> & /* fst */,
      const vector<typename Arc::Weight> * /* distance */) {
    Queue *queue = new Queue();
    return queue;
  }
};

// Specializations for AutoQueue, NaturalShortestFirstQueue and
// TopOrderQueue, whose constructors take arguments beyond the default.
// AutoQueue is built from the FST, the distance vector and a
// default-constructed ArcFilter instance.
template<class Arc, class ArcFilter>
struct QueueConstructor<AutoQueue<typename Arc::StateId>, Arc, ArcFilter> {
  static AutoQueue<typename Arc::StateId> *Construct(
      const Fst<Arc> &fst, const vector<typename Arc::Weight> *distance) {
    typedef AutoQueue<typename Arc::StateId> Q;
    return new Q(fst, distance, ArcFilter());
  }
};

// NaturalShortestFirstQueue is built from the distance vector itself
// (dereferenced without a null check), so the FST argument is unused.
template<class Arc, class ArcFilter>
struct QueueConstructor<NaturalShortestFirstQueue<typename Arc::StateId,
                                                  typename Arc::Weight>,
                        Arc, ArcFilter> {
  static NaturalShortestFirstQueue<typename Arc::StateId, typename Arc::Weight>
  *Construct(const Fst<Arc> & /* fst */,
             const vector<typename Arc::Weight> *distance) {
    typedef NaturalShortestFirstQueue<typename Arc::StateId,
                                      typename Arc::Weight> Q;
    const vector<typename Arc::Weight> &dist_ref = *distance;
    return new Q(dist_ref);
  }
};

// TopOrderQueue is built from the FST and a default-constructed ArcFilter
// instance; the distance vector is not needed.
template<class Arc, class ArcFilter>
struct QueueConstructor<TopOrderQueue<typename Arc::StateId>, Arc, ArcFilter> {
  static TopOrderQueue<typename Arc::StateId> *Construct(
      const Fst<Arc> &fst,
      const vector<typename Arc::Weight> * /* weights */) {
    typedef TopOrderQueue<typename Arc::StateId> Q;
    return new Q(fst, ArcFilter());
  }
};


template<class Arc, class Queue>
void ShortestDistanceHelper(ShortestDistanceArgs1 *args) {
  const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
  const ShortestDistanceOptions &opts = args->arg3;

  vector<typename Arc::Weight> weights;

  switch (opts.arc_filter_type) {
    case ANY_ARC_FILTER: {
      Queue *queue =
          QueueConstructor<Queue, Arc, AnyArcFilter<Arc> >::Construct(
              fst, &weights);
      fst::ShortestDistanceOptions<Arc, Queue, AnyArcFilter<Arc> > sdopts(
          queue, AnyArcFilter<Arc>(), opts.source, opts.delta);
      ShortestDistance(fst, &weights, sdopts);
      delete queue;
      break;
    }
    case EPSILON_ARC_FILTER: {
      Queue *queue =
          QueueConstructor<Queue, Arc, AnyArcFilter<Arc> >::Construct(
              fst, &weights);
      fst::ShortestDistanceOptions<Arc, Queue,
          EpsilonArcFilter<Arc> > sdopts(
              queue, EpsilonArcFilter<Arc>(), opts.source, opts.delta);
      ShortestDistance(fst, &weights, sdopts);
      delete queue;
      break;
    }
    case INPUT_EPSILON_ARC_FILTER: {
      Queue *queue =
          QueueConstructor<Queue, Arc, InputEpsilonArcFilter<Arc> >::Construct(
              fst, &weights);
      fst::ShortestDistanceOptions<Arc, Queue,
          InputEpsilonArcFilter<Arc> > sdopts(
              queue, InputEpsilonArcFilter<Arc>(), opts.source, opts.delta);
      ShortestDistance(fst, &weights, sdopts);
      delete queue;
      break;
    }
    case OUTPUT_EPSILON_ARC_FILTER: {
      Queue *queue =
          QueueConstructor<Queue, Arc,
          OutputEpsilonArcFilter<Arc> >::Construct(
              fst, &weights);
      fst::ShortestDistanceOptions<Arc, Queue,
          OutputEpsilonArcFilter<Arc> > sdopts(
              queue, OutputEpsilonArcFilter<Arc>(), opts.source, opts.delta);
      ShortestDistance(fst, &weights, sdopts);
      delete queue;
      break;
    }
  }

  // Copy the weights back
  args->arg2->resize(weights.size());
  for (unsigned i = 0; i < weights.size(); ++i) {
    (*args->arg2)[i] = WeightClass(weights[i]);
  }
}

// Script-level ShortestDistance variant 1: dispatches on opts.queue_type to
// instantiate ShortestDistanceHelper with the matching queue class (the arc
// filter is dispatched inside the helper).
//
// An unrecognized queue type is reported through FSTERROR and then handled
// with AutoQueue; this fallback was previously an unannotated fall-through.
template<class Arc>
void ShortestDistance(ShortestDistanceArgs1 *args) {
  const ShortestDistanceOptions &opts = args->arg3;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  switch (opts.queue_type) {
    case FIFO_QUEUE:
      ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args);
      return;

    case LIFO_QUEUE:
      ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args);
      return;

    case SHORTEST_FIRST_QUEUE:
      ShortestDistanceHelper<Arc,
          NaturalShortestFirstQueue<StateId, Weight> >(args);
      return;

    case STATE_ORDER_QUEUE:
      ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args);
      return;

    case TOP_ORDER_QUEUE:
      ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args);
      return;

    case AUTO_QUEUE:
      break;

    default:
      FSTERROR() << "Unknown queue type: " << opts.queue_type;
      break;  // Deliberately fall back to AutoQueue below.
  }

  ShortestDistanceHelper<Arc, AutoQueue<StateId> >(args);
}

// 2
// Argument pack for ShortestDistance variant 2: the input FST, the output
// distance vector (WeightClass values), and the `reverse` flag and `delta`
// that are forwarded to fst::ShortestDistance.
typedef args::Package<const FstClass&, vector<WeightClass>*,
                      bool, double> ShortestDistanceArgs2;

// Script-level ShortestDistance variant 2: runs fst::ShortestDistance with
// the `reverse` flag (arg3) and `delta` (arg4), then converts the typed
// distances into WeightClass values stored in args->arg2.
template<class Arc>
void ShortestDistance(ShortestDistanceArgs2 *args) {
  const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
  vector<typename Arc::Weight> distance;

  ShortestDistance(fst, &distance, args->arg3, args->arg4);

  // Convert the typed weights back into WeightClass. Index with the vector's
  // own size_type so the counter cannot be narrower than the size.
  vector<WeightClass> *retval = args->arg2;
  retval->resize(distance.size());
  for (typename vector<typename Arc::Weight>::size_type i = 0;
       i < distance.size(); ++i) {
    (*retval)[i] = WeightClass(distance[i]);
  }
}

// 3
// Argument/return pack for ShortestDistance variant 3: takes the input FST
// and carries back a single WeightClass result in `retval`.
typedef args::WithReturnValue<WeightClass,
                              const FstClass &> ShortestDistanceArgs3;

// Script-level ShortestDistance variant 3: unwraps the typed FST, calls the
// single-argument fst::ShortestDistance overload, and wraps its Weight
// result as a WeightClass in args->retval.
template<class Arc>
void ShortestDistance(ShortestDistanceArgs3 *args) {
  const FstClass &fst_class = args->args;
  const Fst<Arc> *typed_fst = fst_class.GetFst<Arc>();
  args->retval = WeightClass(ShortestDistance(*typed_fst));
}


// 1
// Fills `distance` using the queue type, arc filter, source and delta from
// `opts`. Presumably dispatches on the FST's arc type to the
// ShortestDistance<Arc>(ShortestDistanceArgs1 *) template above; the
// definition is not in this header.
void ShortestDistance(const FstClass &fst, vector<WeightClass> *distance,
                      const ShortestDistanceOptions &opts);

// 2
// Fills `distance`; `reverse` defaults to false and `delta` to fst::kDelta.
// Presumably dispatches to the ShortestDistanceArgs2 template above.
void ShortestDistance(const FstClass &ifst, vector<WeightClass> *distance,
                      bool reverse = false, double delta = fst::kDelta);

#ifndef SWIG
// 3
// Returns a single WeightClass result; excluded when SWIG is defined.
// Presumably dispatches to the ShortestDistanceArgs3 template above.
WeightClass ShortestDistance(const FstClass &ifst);
#endif

}  // namespace script
}  // namespace fst



#endif  // FST_SCRIPT_SHORTEST_DISTANCE_H_