summaryrefslogblamecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/script/rmepsilon.h
blob: 62fed0373849fcf46b5e26e5dc4ca161f975e9d3 (plain) (tree)


















































































































































































































                                                                               
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: jpr@google.com (Jake Ratkiewicz)

#ifndef FST_SCRIPT_RMEPSILON_H_
#define FST_SCRIPT_RMEPSILON_H_

#include <vector>
using std::vector;

#include <fst/script/arg-packs.h>
#include <fst/script/fst-class.h>
#include <fst/script/weight-class.h>
#include <fst/script/shortest-distance.h>  // for ShortestDistanceOptions
#include <fst/rmepsilon.h>
#include <fst/queue.h>

// the following is necessary, or SWIG complains mightily about
// shortestdistanceoptions not being defined before being used as a base.
#ifdef SWIG
%include "nlp/fst/script/shortest-distance.h"
#endif


namespace fst {
namespace script {

//
// OPTIONS
//

struct RmEpsilonOptions : public fst::script::ShortestDistanceOptions {
  bool connect;
  WeightClass weight_threshold;
  int64 state_threshold;

  RmEpsilonOptions(QueueType qt = AUTO_QUEUE, float d = kDelta, bool c = true,
                   WeightClass w = fst::script::WeightClass::Zero(),
                   int64 n = kNoStateId)
      : ShortestDistanceOptions(qt, EPSILON_ARC_FILTER,
                                kNoStateId, d),
        connect(c), weight_threshold(w), state_threshold(n) { }
};


//
// TEMPLATES
//

// this function takes care of transforming a script-land RmEpsilonOptions
// into a lib-land RmEpsilonOptions
template<class Arc>
void RmEpsilonHelper(MutableFst<Arc> *fst,
                     vector<typename Arc::Weight> *distance,
                     const RmEpsilonOptions &opts) {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  typename Arc::Weight weight_thresh =
      *(opts.weight_threshold.GetWeight<Weight>());

  switch (opts.queue_type) {
    case AUTO_QUEUE: {
      AutoQueue<StateId> queue(*fst, distance, EpsilonArcFilter<Arc>());
      fst::RmEpsilonOptions<Arc, AutoQueue<StateId> > ropts(
          &queue, opts.delta, opts.connect, weight_thresh,
          opts.state_threshold);
      RmEpsilon(fst, distance, ropts);
      break;
    }
    case FIFO_QUEUE: {
      FifoQueue<StateId> queue;
      fst::RmEpsilonOptions<Arc, FifoQueue<StateId> > ropts(
          &queue, opts.delta, opts.connect, weight_thresh,
          opts.state_threshold);
      RmEpsilon(fst, distance, ropts);
      break;
    }
    case LIFO_QUEUE: {
      LifoQueue<StateId> queue;
      fst::RmEpsilonOptions<Arc, LifoQueue<StateId> > ropts(
          &queue, opts.delta, opts.connect, weight_thresh,
          opts.state_threshold);
      RmEpsilon(fst, distance, ropts);
      break;
    }
    case SHORTEST_FIRST_QUEUE: {
      NaturalShortestFirstQueue<StateId, Weight>  queue(*distance);
      fst::RmEpsilonOptions<Arc, NaturalShortestFirstQueue<StateId,
                                                               Weight> > ropts(
          &queue, opts.delta, opts.connect, weight_thresh,
          opts.state_threshold);
      RmEpsilon(fst, distance, ropts);
      break;
    }
    case STATE_ORDER_QUEUE: {
      StateOrderQueue<StateId> queue;
      fst::RmEpsilonOptions<Arc, StateOrderQueue<StateId> > ropts(
          &queue, opts.delta, opts.connect, weight_thresh,
          opts.state_threshold);
      RmEpsilon(fst, distance, ropts);
      break;
    }
    case TOP_ORDER_QUEUE: {
      TopOrderQueue<StateId> queue(*fst, EpsilonArcFilter<Arc>());
      fst::RmEpsilonOptions<Arc, TopOrderQueue<StateId> > ropts(
          &queue, opts.delta, opts.connect, weight_thresh,
          opts.state_threshold);
      RmEpsilon(fst, distance, ropts);
      break;
    }
    default:
      FSTERROR() << "Unknown or unsupported queue type: " << opts.queue_type;
      fst->SetProperties(kError, kError);
  }
}

// 1
typedef args::Package<const FstClass &, MutableFstClass *,
                      bool, const RmEpsilonOptions &> RmEpsilonArgs1;

template<class Arc>
void RmEpsilon(RmEpsilonArgs1 *args) {
  const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>());
  MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();
  vector<typename Arc::Weight> distance;
  bool reverse = args->arg3;

  if (reverse) {
    VectorFst<Arc> rfst;
    Reverse(ifst, &rfst);
    RmEpsilonHelper(&rfst, &distance, args->arg4);
    Reverse(rfst, ofst);
  } else {
    *ofst = ifst;
  }
  RmEpsilonHelper(ofst, &distance, args->arg4);
}

// 2
typedef args::Package<MutableFstClass *, bool,
                      const WeightClass, int64,
                      float> RmEpsilonArgs2;

template<class Arc>
void RmEpsilon(RmEpsilonArgs2 *args) {
  MutableFst<Arc> *fst = args->arg1->GetMutableFst<Arc>();
  typename Arc::Weight w = *(args->arg3.GetWeight<typename Arc::Weight>());

  RmEpsilon(fst, args->arg2, w, args->arg4, args->arg5);
}

// 3
typedef args::Package<MutableFstClass *, vector<WeightClass> *,
                      const RmEpsilonOptions &> RmEpsilonArgs3;

template<class Arc>
void RmEpsilon(RmEpsilonArgs3 *args) {
  MutableFst<Arc> *fst = args->arg1->GetMutableFst<Arc>();
  const RmEpsilonOptions &opts = args->arg3;

  vector<typename Arc::Weight> weights;

  RmEpsilonHelper(fst, &weights, opts);

  // Copy the weights back
  args->arg2->resize(weights.size());
  for (unsigned i = 0; i < weights.size(); ++i) {
    (*args->arg2)[i] = WeightClass(weights[i]);
  }
}

//
// PROTOTYPES
//

// 1
void RmEpsilon(const FstClass &ifst, MutableFstClass *ofst,
               bool reverse = false,
               const RmEpsilonOptions& opts =
                 fst::script::RmEpsilonOptions());

// 2
void RmEpsilon(MutableFstClass *arc, bool connect = true,
               const WeightClass &weight_threshold =
                 fst::script::WeightClass::Zero(),
               int64 state_threshold = fst::kNoStateId,
               float delta = fst::kDelta);

// 3
void RmEpsilon(MutableFstClass *fst, vector<WeightClass> *distance,
               const RmEpsilonOptions &opts);


}  // namespace script
}  // namespace fst


#endif  // FST_SCRIPT_RMEPSILON_H_