summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/reweight.h
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/reweight.h')
-rw-r--r--kaldi_io/src/tools/openfst/include/fst/reweight.h146
1 files changed, 146 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/reweight.h b/kaldi_io/src/tools/openfst/include/fst/reweight.h
new file mode 100644
index 0000000..c051c2a
--- /dev/null
+++ b/kaldi_io/src/tools/openfst/include/fst/reweight.h
@@ -0,0 +1,146 @@
+// reweight.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: allauzen@google.com (Cyril Allauzen)
+//
+// \file
+// Function to reweight an FST.
+
+#ifndef FST_LIB_REWEIGHT_H__
+#define FST_LIB_REWEIGHT_H__
+
+#include <vector>
+using std::vector;
+
+#include <fst/mutable-fst.h>
+
+
+namespace fst {
+
+enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL };
+
+// Reweight FST according to the potentials defined by the POTENTIAL
+// vector in the direction defined by TYPE. Weight needs to be left
+// distributive when reweighting towards the initial state and right
+// distributive when reweighting towards the final states.
+//
+// An arc of weight w, with an origin state of potential p and
+// destination state of potential q, is reweighted by p\wq when
+// reweighting towards the initial state and by pw/q when reweighting
+// towards the final states.
+template <class Arc>
+void Reweight(MutableFst<Arc> *fst,
+ const vector<typename Arc::Weight> &potential,
+ ReweightType type) {
+ typedef typename Arc::Weight Weight;
+
+ if (fst->NumStates() == 0)
+ return;
+
+ if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) {
+ FSTERROR() << "Reweight: Reweighting to the final states requires "
+ << "Weight to be right distributive: "
+ << Weight::Type();
+ fst->SetProperties(kError, kError);
+ return;
+ }
+
+ if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) {
+ FSTERROR() << "Reweight: Reweighting to the initial state requires "
+ << "Weight to be left distributive: "
+ << Weight::Type();
+ fst->SetProperties(kError, kError);
+ return;
+ }
+
+ StateIterator< MutableFst<Arc> > sit(*fst);
+ for (; !sit.Done(); sit.Next()) {
+ typename Arc::StateId state = sit.Value();
+ if (state == potential.size())
+ break;
+ typename Arc::Weight weight = potential[state];
+ if (weight != Weight::Zero()) {
+ for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
+ !ait.Done();
+ ait.Next()) {
+ Arc arc = ait.Value();
+ if (arc.nextstate >= potential.size())
+ continue;
+ typename Arc::Weight nextweight = potential[arc.nextstate];
+ if (nextweight == Weight::Zero())
+ continue;
+ if (type == REWEIGHT_TO_INITIAL)
+ arc.weight = Divide(Times(arc.weight, nextweight), weight,
+ DIVIDE_LEFT);
+ if (type == REWEIGHT_TO_FINAL)
+ arc.weight = Divide(Times(weight, arc.weight), nextweight,
+ DIVIDE_RIGHT);
+ ait.SetValue(arc);
+ }
+ if (type == REWEIGHT_TO_INITIAL)
+ fst->SetFinal(state, Divide(fst->Final(state), weight, DIVIDE_LEFT));
+ }
+ if (type == REWEIGHT_TO_FINAL)
+ fst->SetFinal(state, Times(weight, fst->Final(state)));
+ }
+
+ // This handles elements past the end of the potentials array.
+ for (; !sit.Done(); sit.Next()) {
+ typename Arc::StateId state = sit.Value();
+ if (type == REWEIGHT_TO_FINAL)
+ fst->SetFinal(state, Times(Weight::Zero(), fst->Final(state)));
+ }
+
+ typename Arc::Weight startweight = fst->Start() < potential.size() ?
+ potential[fst->Start()] : Weight::Zero();
+ if ((startweight != Weight::One()) && (startweight != Weight::Zero())) {
+ if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
+ typename Arc::StateId state = fst->Start();
+ for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
+ !ait.Done();
+ ait.Next()) {
+ Arc arc = ait.Value();
+ if (type == REWEIGHT_TO_INITIAL)
+ arc.weight = Times(startweight, arc.weight);
+ else
+ arc.weight = Times(
+ Divide(Weight::One(), startweight, DIVIDE_RIGHT),
+ arc.weight);
+ ait.SetValue(arc);
+ }
+ if (type == REWEIGHT_TO_INITIAL)
+ fst->SetFinal(state, Times(startweight, fst->Final(state)));
+ else
+ fst->SetFinal(state, Times(Divide(Weight::One(), startweight,
+ DIVIDE_RIGHT),
+ fst->Final(state)));
+ } else {
+ typename Arc::StateId state = fst->AddState();
+ Weight w = type == REWEIGHT_TO_INITIAL ? startweight :
+ Divide(Weight::One(), startweight, DIVIDE_RIGHT);
+ Arc arc(0, 0, w, fst->Start());
+ fst->AddArc(state, arc);
+ fst->SetStart(state);
+ }
+ }
+
+ fst->SetProperties(ReweightProperties(
+ fst->Properties(kFstProperties, false)),
+ kFstProperties);
+}
+
+} // namespace fst
+
+#endif // FST_LIB_REWEIGHT_H_