summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/minimize.h
blob: 6e9dd3da5769a77ab05a378c26d189646a98e2f2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
// minimize.h
// minimize.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: [email protected] (Johan Schalkwyk)
//
// \file Functions and classes to minimize a finite state acceptor
//

#ifndef FST_LIB_MINIMIZE_H__
#define FST_LIB_MINIMIZE_H__

#include <cmath>

#include <algorithm>
#include <map>
#include <queue>
#include <vector>
using std::vector;

#include <fst/arcsort.h>
#include <fst/connect.h>
#include <fst/dfs-visit.h>
#include <fst/encode.h>
#include <fst/factor-weight.h>
#include <fst/fst.h>
#include <fst/mutable-fst.h>
#include <fst/partition.h>
#include <fst/push.h>
#include <fst/queue.h>
#include <fst/reverse.h>
#include <fst/state-map.h>


namespace fst {

// comparator for creating partition based on sorting on
// - states
// - final weight
// - out degree,
// -  (input label, output label, weight, destination_block)
template <class A>
class StateComparator {
 public:
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;

  static const uint32 kCompareFinal     = 0x00000001;
  static const uint32 kCompareOutDegree = 0x00000002;
  static const uint32 kCompareArcs      = 0x00000004;
  static const uint32 kCompareAll       = 0x00000007;

  StateComparator(const Fst<A>& fst,
                  const Partition<typename A::StateId>& partition,
                  uint32 flags = kCompareAll)
      : fst_(fst), partition_(partition), flags_(flags) {}

  // compare state x with state y based on sort criteria
  bool operator()(const StateId x, const StateId y) const {
    // check for final state equivalence
    if (flags_ & kCompareFinal) {
      const size_t xfinal = fst_.Final(x).Hash();
      const size_t yfinal = fst_.Final(y).Hash();
      if      (xfinal < yfinal) return true;
      else if (xfinal > yfinal) return false;
    }

    if (flags_ & kCompareOutDegree) {
      // check for # arcs
      if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
      if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;

      if (flags_ & kCompareArcs) {
        // # arcs are equal, check for arc match
        for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
             !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
          const A& arc1 = aiter1.Value();
          const A& arc2 = aiter2.Value();
          if (arc1.ilabel < arc2.ilabel) return true;
          if (arc1.ilabel > arc2.ilabel) return false;

          if (partition_.class_id(arc1.nextstate) <
              partition_.class_id(arc2.nextstate)) return true;
          if (partition_.class_id(arc1.nextstate) >
              partition_.class_id(arc2.nextstate)) return false;
        }
      }
    }

    return false;
  }

 private:
  const Fst<A>& fst_;
  const Partition<typename A::StateId>& partition_;
  const uint32 flags_;
};

template <class A> const uint32 StateComparator<A>::kCompareFinal;
template <class A> const uint32 StateComparator<A>::kCompareOutDegree;
template <class A> const uint32 StateComparator<A>::kCompareArcs;
template <class A> const uint32 StateComparator<A>::kCompareAll;


// Computes equivalence classes for cyclic Fsts. For cyclic minimization
// we use the classic HopCroft minimization algorithm, which is of
//
//   O(E)log(N),
//
// where E is the number of edges in the machine and N is number of states.
//
// The following paper describes the original algorithm
//  An N Log N algorithm for minimizing states in a finite automaton
//  by John HopCroft, January 1971
//
template <class A, class Queue>
class CyclicMinimizer {
 public:
  typedef typename A::Label Label;
  typedef typename A::StateId StateId;
  typedef typename A::StateId ClassId;
  typedef typename A::Weight Weight;
  typedef ReverseArc<A> RevA;

  CyclicMinimizer(const ExpandedFst<A>& fst):
      // tell the Partition data-member to expect multiple repeated
      // calls to SplitOn with the same element if we are non-deterministic.
      P_(fst.Properties(kIDeterministic, true) == 0) {
    if(fst.Properties(kIDeterministic, true) == 0)
      CHECK(Weight::Properties() & kIdempotent); // this minimization
    // algorithm for non-deterministic FSTs can only work with idempotent
    // semirings.
    Initialize(fst);
    Compute(fst);
  }

  ~CyclicMinimizer() {
    delete aiter_queue_;
  }

  const Partition<StateId>& partition() const {
    return P_;
  }

  // helper classes
 private:
  typedef ArcIterator<Fst<RevA> > ArcIter;
  class ArcIterCompare {
   public:
    ArcIterCompare(const Partition<StateId>& partition)
        : partition_(partition) {}

    ArcIterCompare(const ArcIterCompare& comp)
        : partition_(comp.partition_) {}

    // compare two iterators based on there input labels, and proto state
    // (partition class Ids)
    bool operator()(const ArcIter* x, const ArcIter* y) const {
      const RevA& xarc = x->Value();
      const RevA& yarc = y->Value();
      return (xarc.ilabel > yarc.ilabel);
    }

   private:
    const Partition<StateId>& partition_;
  };

  typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
  ArcIterQueue;

  // helper methods
 private:
  // prepartitions the space into equivalence classes with
  //   same final weight
  //   same # arcs per state
  //   same outgoing arcs
  void PrePartition(const Fst<A>& fst) {
    VLOG(5) << "PrePartition";

    typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
    StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
    EquivalenceMap equiv_map(comp);

    StateIterator<Fst<A> > siter(fst);
    StateId class_id = P_.AddClass();
    P_.Add(siter.Value(), class_id);
    equiv_map[siter.Value()] = class_id;
    L_.Enqueue(class_id);
    for (siter.Next(); !siter.Done(); siter.Next()) {
      StateId  s = siter.Value();
      typename EquivalenceMap::const_iterator it = equiv_map.find(s);
      if (it == equiv_map.end()) {
        class_id = P_.AddClass();
        P_.Add(s, class_id);
        equiv_map[s] = class_id;
        L_.Enqueue(class_id);
      } else {
        P_.Add(s, it->second);
        equiv_map[s] = it->second;
      }
    }

    VLOG(5) << "Initial Partition: " << P_.num_classes();
  }

  // - Create inverse transition Tr_ = rev(fst)
  // - loop over states in fst and split on final, creating two blocks
  //   in the partition corresponding to final, non-final
  void Initialize(const Fst<A>& fst) {
    // construct Tr
    Reverse(fst, &Tr_);
    ILabelCompare<RevA> ilabel_comp;
    ArcSort(&Tr_, ilabel_comp);

    // initial split (F, S - F)
    P_.Initialize(Tr_.NumStates() - 1);

    // prep partition
    PrePartition(fst);

    // allocate arc iterator queue
    ArcIterCompare comp(P_);
    aiter_queue_ = new ArcIterQueue(comp);
  }

  // partition all classes with destination C
  void Split(ClassId C) {
    // Prep priority queue. Open arc iterator for each state in C, and
    // insert into priority queue.
    for (PartitionIterator<StateId> siter(P_, C);
         !siter.Done(); siter.Next()) {
      StateId s = siter.Value();
      if (Tr_.NumArcs(s + 1))
        aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
    }

    // Now pop arc iterator from queue, split entering equivalence class
    // re-insert updated iterator into queue.
    Label prev_label = -1;
    while (!aiter_queue_->empty()) {
      ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
      aiter_queue_->pop();
      if (aiter->Done()) {
        delete aiter;
        continue;
     }

      const RevA& arc = aiter->Value();
      StateId from_state = aiter->Value().nextstate - 1;
      Label   from_label = arc.ilabel;
      if (prev_label != from_label)
        P_.FinalizeSplit(&L_);

      StateId from_class = P_.class_id(from_state);
      if (P_.class_size(from_class) > 1)
        P_.SplitOn(from_state);

      prev_label = from_label;
      aiter->Next();
      if (aiter->Done())
        delete aiter;
      else
        aiter_queue_->push(aiter);
    }
    P_.FinalizeSplit(&L_);
  }

  // Main loop for hopcroft minimization.
  void Compute(const Fst<A>& fst) {
    // process active classes (FIFO, or FILO)
    while (!L_.Empty()) {
      ClassId C = L_.Head();
      L_.Dequeue();

      // split on C, all labels in C
      Split(C);
    }
  }

  // helper data
 private:
  // Partioning of states into equivalence classes
  Partition<StateId> P_;

  // L = set of active classes to be processed in partition P
  Queue L_;

  // reverse transition function
  VectorFst<RevA> Tr_;

  // Priority queue of open arc iterators for all states in the 'splitter'
  // equivalence class
  ArcIterQueue* aiter_queue_;
};


// Computes equivalence classes for acyclic Fsts. The implementation details
// for this algorithms is documented by the following paper.
//
// Minimization of acyclic deterministic automata in linear time
//  Dominque Revuz
//
// Complexity O(|E|)
//
template <class A>
class AcyclicMinimizer {
 public:
  typedef typename A::Label Label;
  typedef typename A::StateId StateId;
  typedef typename A::StateId ClassId;
  typedef typename A::Weight Weight;

  AcyclicMinimizer(const ExpandedFst<A>& fst):
      // tell the Partition data-member to expect multiple repeated
      // calls to SplitOn with the same element if we are non-deterministic.
      partition_(fst.Properties(kIDeterministic, true) == 0) {
    if(fst.Properties(kIDeterministic, true) == 0)
      CHECK(Weight::Properties() & kIdempotent); // minimization for
    // non-deterministic FSTs can only work with idempotent semirings.
    Initialize(fst);
    Refine(fst);
  }

  const Partition<StateId>& partition() {
    return partition_;
  }

  // helper classes
 private:
  // DFS visitor to compute the height (distance) to final state.
  class HeightVisitor {
   public:
    HeightVisitor() : max_height_(0), num_states_(0) { }

    // invoked before dfs visit
    void InitVisit(const Fst<A>& fst) {}

    // invoked when state is discovered (2nd arg is DFS tree root)
    bool InitState(StateId s, StateId root) {
      // extend height array and initialize height (distance) to 0
      for (size_t i = height_.size(); i <= s; ++i)
        height_.push_back(-1);

      if (s >= num_states_) num_states_ = s + 1;
      return true;
    }

    // invoked when tree arc examined (to undiscoverted state)
    bool TreeArc(StateId s, const A& arc) {
      return true;
    }

    // invoked when back arc examined (to unfinished state)
    bool BackArc(StateId s, const A& arc) {
      return true;
    }

    // invoked when forward or cross arc examined (to finished state)
    bool ForwardOrCrossArc(StateId s, const A& arc) {
      if (height_[arc.nextstate] + 1 > height_[s])
        height_[s] = height_[arc.nextstate] + 1;
      return true;
    }

    // invoked when state finished (parent is kNoStateId for tree root)
    void FinishState(StateId s, StateId parent, const A* parent_arc) {
      if (height_[s] == -1) height_[s] = 0;
      StateId h = height_[s] +  1;
      if (parent >= 0) {
        if (h > height_[parent]) height_[parent] = h;
        if (h > max_height_)     max_height_ = h;
      }
    }

    // invoked after DFS visit
    void FinishVisit() {}

    size_t max_height() const { return max_height_; }

    const vector<StateId>& height() const { return height_; }

    const size_t num_states() const { return num_states_; }

   private:
    vector<StateId> height_;
    size_t max_height_;
    size_t num_states_;
  };

  // helper methods
 private:
  // cluster states according to height (distance to final state)
  void Initialize(const Fst<A>& fst) {
    // compute height (distance to final state)
    HeightVisitor hvisitor;
    DfsVisit(fst, &hvisitor);

    // create initial partition based on height
    partition_.Initialize(hvisitor.num_states());
    partition_.AllocateClasses(hvisitor.max_height() + 1);
    const vector<StateId>& hstates = hvisitor.height();
    for (size_t s = 0; s < hstates.size(); ++s)
      partition_.Add(s, hstates[s]);
  }

  // refine states based on arc sort (out degree, arc equivalence)
  void Refine(const Fst<A>& fst) {
    typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
    StateComparator<A> comp(fst, partition_);

    // start with tail (height = 0)
    size_t height = partition_.num_classes();
    for (size_t h = 0; h < height; ++h) {
      EquivalenceMap equiv_classes(comp);

      // sort states within equivalence class
      PartitionIterator<StateId> siter(partition_, h);
      equiv_classes[siter.Value()] = h;
      for (siter.Next(); !siter.Done(); siter.Next()) {
        const StateId s = siter.Value();
        typename EquivalenceMap::const_iterator it = equiv_classes.find(s);
        if (it == equiv_classes.end())
          equiv_classes[s] = partition_.AddClass();
        else
          equiv_classes[s] = it->second;
      }

      // create refined partition
      for (siter.Reset(); !siter.Done();) {
        const StateId s = siter.Value();
        const StateId old_class = partition_.class_id(s);
        const StateId new_class = equiv_classes[s];

        // a move operation can invalidate the iterator, so
        // we first update the iterator to the next element
        // before we move the current element out of the list
        siter.Next();
        if (old_class != new_class)
          partition_.Move(s, new_class);
      }
    }
  }

 private:
  Partition<StateId> partition_;
};


// Given a partition and a mutable fst, merge states of Fst inplace
// (i.e. destructively). Merging works by taking the first state in
// a class of the partition to be the representative state for the class.
// Each arc is then reconnected to this state. All states in the class
// are merged by adding there arcs to the representative state.
template <class A>
void MergeStates(
    const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
  typedef typename A::StateId StateId;

  vector<StateId> state_map(partition.num_classes());
  for (size_t i = 0; i < partition.num_classes(); ++i) {
    PartitionIterator<StateId> siter(partition, i);
    state_map[i] = siter.Value();  // first state in partition;
  }

  // relabel destination states
  for (size_t c = 0; c < partition.num_classes(); ++c) {
    for (PartitionIterator<StateId> siter(partition, c);
         !siter.Done(); siter.Next()) {
      StateId s = siter.Value();
      for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
           !aiter.Done(); aiter.Next()) {
        A arc = aiter.Value();
        arc.nextstate = state_map[partition.class_id(arc.nextstate)];

        if (s == state_map[c])  // first state just set destination
          aiter.SetValue(arc);
        else
          fst->AddArc(state_map[c], arc);
      }
    }
  }
  fst->SetStart(state_map[partition.class_id(fst->Start())]);

  Connect(fst);
}

template <class A>
void AcceptorMinimize(MutableFst<A>* fst) {
  typedef typename A::StateId StateId;
  if (!(fst->Properties(kAcceptor | kUnweighted, true))) {
    FSTERROR() << "FST is not an unweighted acceptor";
    fst->SetProperties(kError, kError);
    return;
  }

  // connect fst before minimization, handles disconnected states
  Connect(fst);
  if (fst->NumStates() == 0) return;

  if (fst->Properties(kAcyclic, true)) {
    // Acyclic minimization (revuz)
    VLOG(2) << "Acyclic Minimization";
    ArcSort(fst, ILabelCompare<A>());
    AcyclicMinimizer<A> minimizer(*fst);
    MergeStates(minimizer.partition(), fst);

  } else {
    // Cyclic minimizaton (hopcroft)
    VLOG(2) << "Cyclic Minimization";
    CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
    MergeStates(minimizer.partition(), fst);
  }

  // Merge in appropriate semiring
  ArcUniqueMapper<A> mapper(*fst);
  StateMap(fst, mapper);
}


// In place minimization of deterministic weighted automata and transducers.
// For transducers, then the 'sfst' argument is not null, the algorithm
// produces a compact factorization of the minimal transducer.
//
// In the acyclic case, we use an algorithm from Dominique Revuz that
// is linear in the number of arcs (edges) in the machine.
//  Complexity = O(E)
//
// In the cyclic case, we use the classical hopcroft minimization.
//  Complexity = O(|E|log(|N|)
//
template <class A>
void Minimize(MutableFst<A>* fst,
              MutableFst<A>* sfst = 0,
              float delta = kDelta) {
  uint64 props = fst->Properties(kAcceptor | kWeighted | kUnweighted, true);

  if (!(props & kAcceptor)) {  // weighted transducer
    VectorFst< GallicArc<A, STRING_LEFT> > gfst;
    ArcMap(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>());
    fst->DeleteStates();
    gfst.SetProperties(kAcceptor, kAcceptor);
    Push(&gfst, REWEIGHT_TO_INITIAL, delta);
    ArcMap(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >(delta));
    EncodeMapper< GallicArc<A, STRING_LEFT> >
      encoder(kEncodeLabels | kEncodeWeights, ENCODE);
    Encode(&gfst, &encoder);
    AcceptorMinimize(&gfst);
    Decode(&gfst, encoder);

    if (sfst == 0) {
      FactorWeightFst< GallicArc<A, STRING_LEFT>,
        GallicFactor<typename A::Label,
        typename A::Weight, STRING_LEFT> > fwfst(gfst);
      SymbolTable *osyms = fst->OutputSymbols() ?
          fst->OutputSymbols()->Copy() : 0;
      ArcMap(fwfst, fst, FromGallicMapper<A, STRING_LEFT>());
      fst->SetOutputSymbols(osyms);
      delete osyms;
    } else {
      sfst->SetOutputSymbols(fst->OutputSymbols());
      GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
      ArcMap(gfst, fst, &mapper);
      fst->SetOutputSymbols(sfst->InputSymbols());
    }
  } else if (props & kWeighted) {  // weighted acceptor
    Push(fst, REWEIGHT_TO_INITIAL, delta);
    ArcMap(fst, QuantizeMapper<A>(delta));
    EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
    Encode(fst, &encoder);
    AcceptorMinimize(fst);
    Decode(fst, encoder);
  } else {  // unweighted acceptor
    AcceptorMinimize(fst);
  }
}

}  // namespace fst

#endif  // FST_LIB_MINIMIZE_H__