From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- kaldi_io/src/tools/openfst/include/fst/push.h | 175 ++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 kaldi_io/src/tools/openfst/include/fst/push.h (limited to 'kaldi_io/src/tools/openfst/include/fst/push.h') diff --git a/kaldi_io/src/tools/openfst/include/fst/push.h b/kaldi_io/src/tools/openfst/include/fst/push.h new file mode 100644 index 0000000..1f7a8fa --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/push.h @@ -0,0 +1,175 @@ +// push.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: allauzen@google.com (Cyril Allauzen) +// +// \file +// Class to reweight/push an FST. + +#ifndef FST_LIB_PUSH_H__ +#define FST_LIB_PUSH_H__ + +#include +using std::vector; + +#include +#include +#include +#include +#include + + +namespace fst { + +// Private helper functions for Push +namespace internal { + +// Compute the total weight (sum of the weights of all accepting paths) from +// the output of ShortestDistance. 'distance' is the shortest distance from the +// initial state when 'reverse == false' and to the final states when +// 'reverse == true'. +template +typename Arc::Weight ComputeTotalWeight( + const Fst &fst, + const vector &distance, + bool reverse) { + if (reverse) + return fst.Start() < distance.size() ? + distance[fst.Start()] : Arc::Weight::Zero(); + + typename Arc::Weight sum = Arc::Weight::Zero(); + for (typename Arc::StateId s = 0; s < distance.size(); ++s) + sum = Plus(sum, Times(distance[s], fst.Final(s))); + return sum; +} + +// Divide the weight of every accepting path by 'w'. The weight 'w' is +// divided at the final states if 'at_final == true' and at the +// initial state otherwise. +template +void RemoveWeight(MutableFst *fst, typename Arc::Weight w, bool at_final) { + if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero())) + return; + + if (at_final) { + // Remove 'w' from the final states + for (StateIterator< MutableFst > sit(*fst); + !sit.Done(); + sit.Next()) + fst->SetFinal(sit.Value(), + Divide(fst->Final(sit.Value()), w, DIVIDE_RIGHT)); + } else { // at_final == false + // Remove 'w' from the initial state + typename Arc::StateId start = fst->Start(); + for (MutableArcIterator > ait(fst, start); + !ait.Done(); + ait.Next()) { + Arc arc = ait.Value(); + arc.weight = Divide(arc.weight, w, DIVIDE_LEFT); + ait.SetValue(arc); + } + fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT)); + } +} +} // namespace internal + +// Pushes the weights in FST in the direction defined by TYPE. If +// pushing towards the initial state, the sum of the weight of the +// outgoing transitions and final weight at a non-initial state is +// equal to One() in the resulting machine. If pushing towards the +// final state, the same property holds on the reverse machine. +// +// Weight needs to be left distributive when pushing towards the +// initial state and right distributive when pushing towards the final +// states. +template +void Push(MutableFst *fst, + ReweightType type, + float delta = kDelta, + bool remove_total_weight = false) { + vector distance; + ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta); + typename Arc::Weight total_weight = Arc::Weight::One(); + if (remove_total_weight) + total_weight = internal::ComputeTotalWeight(*fst, distance, + type == REWEIGHT_TO_INITIAL); + Reweight(fst, distance, type); + if (remove_total_weight) + internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL); +} + +const uint32 kPushWeights = 0x0001; +const uint32 kPushLabels = 0x0002; +const uint32 kPushRemoveTotalWeight = 0x0004; +const uint32 kPushRemoveCommonAffix = 0x0008; + +// OFST obtained from IFST by pushing weights and/or labels according +// to PTYPE in the direction defined by RTYPE. Weight needs to be +// left distributive when pushing weights towards the initial state +// and right distributive when pushing weights towards the final +// states. +template +void Push(const Fst &ifst, + MutableFst *ofst, + uint32 ptype, + float delta = kDelta) { + + if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) { + *ofst = ifst; + Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight); + } else if (ptype & kPushLabels) { + const StringType stype = rtype == REWEIGHT_TO_INITIAL + ? STRING_LEFT + : STRING_RIGHT; + vector::Weight> gdistance; + VectorFst > gfst; + ArcMap(ifst, &gfst, ToGallicMapper()); + if (ptype & kPushWeights ) { + ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); + } else { + ArcMapFst > + uwfst(ifst, RmWeightMapper()); + ArcMapFst, ToGallicMapper > + guwfst(uwfst, ToGallicMapper()); + ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); + } + typename GallicArc::Weight total_weight = + GallicArc::Weight::One(); + if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { + total_weight = internal::ComputeTotalWeight( + gfst, gdistance, rtype == REWEIGHT_TO_INITIAL); + total_weight = typename GallicArc::Weight( + ptype & kPushRemoveCommonAffix ? total_weight.Value1() + : StringWeight::One(), + ptype & kPushRemoveTotalWeight ? total_weight.Value2() + : Arc::Weight::One()); + } + Reweight(&gfst, gdistance, rtype); + if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) + internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL); + FactorWeightFst< GallicArc, GallicFactor > fwfst(gfst); + ArcMap(fwfst, ofst, FromGallicMapper()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + } else { + LOG(WARNING) << "Push: pushing type is set to 0: " + << "pushing neither labels nor weights."; + *ofst = ifst; + } +} + +} // namespace fst + +#endif /* FST_LIB_PUSH_H_ */ -- cgit v1.2.3