summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/script/shortest-distance.h
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/script/shortest-distance.h')
-rw-r--r--kaldi_io/src/tools/openfst/include/fst/script/shortest-distance.h250
1 files changed, 250 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/script/shortest-distance.h b/kaldi_io/src/tools/openfst/include/fst/script/shortest-distance.h
new file mode 100644
index 0000000..5fc2976
--- /dev/null
+++ b/kaldi_io/src/tools/openfst/include/fst/script/shortest-distance.h
@@ -0,0 +1,250 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: jpr@google.com (Jake Ratkiewicz)
+
+#ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_
+#define FST_SCRIPT_SHORTEST_DISTANCE_H_
+
+#include <vector>
+using std::vector;
+
+#include <fst/script/arg-packs.h>
+#include <fst/script/fst-class.h>
+#include <fst/script/weight-class.h>
+#include <fst/script/prune.h> // for ArcFilterType
+#include <fst/queue.h> // for QueueType
+#include <fst/shortest-distance.h>
+
+namespace fst {
+namespace script {
+
+enum ArcFilterType { ANY_ARC_FILTER, EPSILON_ARC_FILTER,
+ INPUT_EPSILON_ARC_FILTER, OUTPUT_EPSILON_ARC_FILTER };
+
+// See nlp/fst/lib/shortest-distance.h for the template options class
+// that this one shadows
+struct ShortestDistanceOptions {
+ const QueueType queue_type;
+ const ArcFilterType arc_filter_type;
+ const int64 source;
+ const float delta;
+ const bool first_path;
+
+ ShortestDistanceOptions(QueueType qt, ArcFilterType aft, int64 s,
+ float d)
+ : queue_type(qt), arc_filter_type(aft), source(s), delta(d),
+ first_path(false) { }
+};
+
+
+
+// 1
+typedef args::Package<const FstClass &, vector<WeightClass> *,
+ const ShortestDistanceOptions &> ShortestDistanceArgs1;
+
+template<class Queue, class Arc, class ArcFilter>
+struct QueueConstructor {
+ // template<class Arc, class ArcFilter>
+ static Queue *Construct(const Fst<Arc> &,
+ const vector<typename Arc::Weight> *) {
+ return new Queue();
+ }
+};
+
+// Specializations to deal with AutoQueue, NaturalShortestFirstQueue,
+// and TopOrderQueue's different constructors
+template<class Arc, class ArcFilter>
+struct QueueConstructor<AutoQueue<typename Arc::StateId>, Arc, ArcFilter> {
+ // template<class Arc, class ArcFilter>
+ static AutoQueue<typename Arc::StateId> *Construct(
+ const Fst<Arc> &fst,
+ const vector<typename Arc::Weight> *distance) {
+ return new AutoQueue<typename Arc::StateId>(fst, distance, ArcFilter());
+ }
+};
+
+template<class Arc, class ArcFilter>
+struct QueueConstructor<NaturalShortestFirstQueue<typename Arc::StateId,
+ typename Arc::Weight>,
+ Arc, ArcFilter> {
+ // template<class Arc, class ArcFilter>
+ static NaturalShortestFirstQueue<typename Arc::StateId, typename Arc::Weight>
+ *Construct(const Fst<Arc> &fst,
+ const vector<typename Arc::Weight> *distance) {
+ return new NaturalShortestFirstQueue<typename Arc::StateId,
+ typename Arc::Weight>(*distance);
+ }
+};
+
+template<class Arc, class ArcFilter>
+struct QueueConstructor<TopOrderQueue<typename Arc::StateId>, Arc, ArcFilter> {
+ // template<class Arc, class ArcFilter>
+ static TopOrderQueue<typename Arc::StateId> *Construct(
+ const Fst<Arc> &fst, const vector<typename Arc::Weight> *weights) {
+ return new TopOrderQueue<typename Arc::StateId>(fst, ArcFilter());
+ }
+};
+
+
+template<class Arc, class Queue>
+void ShortestDistanceHelper(ShortestDistanceArgs1 *args) {
+ const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
+ const ShortestDistanceOptions &opts = args->arg3;
+
+ vector<typename Arc::Weight> weights;
+
+ switch (opts.arc_filter_type) {
+ case ANY_ARC_FILTER: {
+ Queue *queue =
+ QueueConstructor<Queue, Arc, AnyArcFilter<Arc> >::Construct(
+ fst, &weights);
+ fst::ShortestDistanceOptions<Arc, Queue, AnyArcFilter<Arc> > sdopts(
+ queue, AnyArcFilter<Arc>(), opts.source, opts.delta);
+ ShortestDistance(fst, &weights, sdopts);
+ delete queue;
+ break;
+ }
+ case EPSILON_ARC_FILTER: {
+ Queue *queue =
+ QueueConstructor<Queue, Arc, AnyArcFilter<Arc> >::Construct(
+ fst, &weights);
+ fst::ShortestDistanceOptions<Arc, Queue,
+ EpsilonArcFilter<Arc> > sdopts(
+ queue, EpsilonArcFilter<Arc>(), opts.source, opts.delta);
+ ShortestDistance(fst, &weights, sdopts);
+ delete queue;
+ break;
+ }
+ case INPUT_EPSILON_ARC_FILTER: {
+ Queue *queue =
+ QueueConstructor<Queue, Arc, InputEpsilonArcFilter<Arc> >::Construct(
+ fst, &weights);
+ fst::ShortestDistanceOptions<Arc, Queue,
+ InputEpsilonArcFilter<Arc> > sdopts(
+ queue, InputEpsilonArcFilter<Arc>(), opts.source, opts.delta);
+ ShortestDistance(fst, &weights, sdopts);
+ delete queue;
+ break;
+ }
+ case OUTPUT_EPSILON_ARC_FILTER: {
+ Queue *queue =
+ QueueConstructor<Queue, Arc,
+ OutputEpsilonArcFilter<Arc> >::Construct(
+ fst, &weights);
+ fst::ShortestDistanceOptions<Arc, Queue,
+ OutputEpsilonArcFilter<Arc> > sdopts(
+ queue, OutputEpsilonArcFilter<Arc>(), opts.source, opts.delta);
+ ShortestDistance(fst, &weights, sdopts);
+ delete queue;
+ break;
+ }
+ }
+
+ // Copy the weights back
+ args->arg2->resize(weights.size());
+ for (unsigned i = 0; i < weights.size(); ++i) {
+ (*args->arg2)[i] = WeightClass(weights[i]);
+ }
+}
+
+template<class Arc>
+void ShortestDistance(ShortestDistanceArgs1 *args) {
+ const ShortestDistanceOptions &opts = args->arg3;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Weight Weight;
+
+ // Must consider (opts.queue_type x opts.filter_type) options
+ switch (opts.queue_type) {
+ default:
+ FSTERROR() << "Unknown queue type." << opts.queue_type;
+
+ case AUTO_QUEUE:
+ ShortestDistanceHelper<Arc, AutoQueue<StateId> >(args);
+ return;
+
+ case FIFO_QUEUE:
+ ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args);
+ return;
+
+ case LIFO_QUEUE:
+ ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args);
+ return;
+
+ case SHORTEST_FIRST_QUEUE:
+ ShortestDistanceHelper<Arc,
+ NaturalShortestFirstQueue<StateId, Weight> >(args);
+ return;
+
+ case STATE_ORDER_QUEUE:
+ ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args);
+ return;
+
+ case TOP_ORDER_QUEUE:
+ ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args);
+ return;
+ }
+}
+
+// 2
+typedef args::Package<const FstClass&, vector<WeightClass>*,
+ bool, double> ShortestDistanceArgs2;
+
+template<class Arc>
+void ShortestDistance(ShortestDistanceArgs2 *args) {
+ const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
+ vector<typename Arc::Weight> distance;
+
+ ShortestDistance(fst, &distance, args->arg3, args->arg4);
+
+ // convert the typed weights back into weightclass
+ vector<WeightClass> *retval = args->arg2;
+ retval->resize(distance.size());
+
+ for (unsigned i = 0; i < distance.size(); ++i) {
+ (*retval)[i] = WeightClass(distance[i]);
+ }
+}
+
+// 3
+typedef args::WithReturnValue<WeightClass,
+ const FstClass &> ShortestDistanceArgs3;
+
+template<class Arc>
+void ShortestDistance(ShortestDistanceArgs3 *args) {
+ const Fst<Arc> &fst = *(args->args.GetFst<Arc>());
+
+ args->retval = WeightClass(ShortestDistance(fst));
+}
+
+
+// 1
+void ShortestDistance(const FstClass &fst, vector<WeightClass> *distance,
+ const ShortestDistanceOptions &opts);
+
+// 2
+void ShortestDistance(const FstClass &ifst, vector<WeightClass> *distance,
+ bool reverse = false, double delta = fst::kDelta);
+
+#ifndef SWIG
+// 3
+WeightClass ShortestDistance(const FstClass &ifst);
+#endif
+
+} // namespace script
+} // namespace fst
+
+
+
+#endif // FST_SCRIPT_SHORTEST_DISTANCE_H_