diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/tuple-weight.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/tuple-weight.h | 332 |
1 files changed, 332 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/tuple-weight.h b/kaldi_io/src/tools/openfst/include/fst/tuple-weight.h new file mode 100644 index 0000000..184026c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/tuple-weight.h @@ -0,0 +1,332 @@ +// tuple-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: allauzen@google (Cyril Allauzen) +// +// \file +// Tuple weight set operation definitions. + +#ifndef FST_LIB_TUPLE_WEIGHT_H__ +#define FST_LIB_TUPLE_WEIGHT_H__ + +#include <string> +#include <vector> +using std::vector; + +#include <fst/weight.h> + + +DECLARE_string(fst_weight_parentheses); +DECLARE_string(fst_weight_separator); + +namespace fst { + +template<class W, unsigned int n> class TupleWeight; +template <class W, unsigned int n> +istream &operator>>(istream &strm, TupleWeight<W, n> &w); + +// n-tuple weight, element of the n-th catersian power of W +template <class W, unsigned int n> +class TupleWeight { + public: + typedef TupleWeight<typename W::ReverseWeight, n> ReverseWeight; + + TupleWeight() {} + + TupleWeight(const TupleWeight &w) { + for (size_t i = 0; i < n; ++i) + values_[i] = w.values_[i]; + } + + template <class Iterator> + TupleWeight(Iterator begin, Iterator end) { + for (Iterator iter = begin; iter != end; ++iter) + values_[iter - begin] = *iter; + } + + TupleWeight(const W &w) { + for (size_t i = 0; i < n; ++i) + values_[i] = w; + } + + static const TupleWeight<W, n> &Zero() { + static const TupleWeight<W, n> zero(W::Zero()); + return zero; + } + + static const TupleWeight<W, n> &One() { + static const TupleWeight<W, n> one(W::One()); + return one; + } + + static const TupleWeight<W, n> &NoWeight() { + static const TupleWeight<W, n> no_weight(W::NoWeight()); + return no_weight; + } + + static unsigned int Length() { + return n; + } + + istream &Read(istream &strm) { + for (size_t i = 0; i < n; ++i) + values_[i].Read(strm); + return strm; + } + + ostream &Write(ostream &strm) const { + for (size_t i = 0; i < n; ++i) + values_[i].Write(strm); + return strm; + } + + TupleWeight<W, n> &operator=(const TupleWeight<W, n> &w) { + for (size_t i = 0; i < n; ++i) + values_[i] = w.values_[i]; + return *this; + } + + bool Member() const { + bool member = true; + for (size_t i = 0; i < n; ++i) + member = member && values_[i].Member(); + return member; + } + + size_t Hash() const { + uint64 hash = 0; + for (size_t i = 0; i < n; ++i) + hash = 5 * hash + values_[i].Hash(); + return size_t(hash); + } + + TupleWeight<W, n> Quantize(float delta = kDelta) const { + TupleWeight<W, n> w; + for (size_t i = 0; i < n; ++i) + w.values_[i] = values_[i].Quantize(delta); + return w; + } + + ReverseWeight Reverse() const { + TupleWeight<W, n> w; + for (size_t i = 0; i < n; ++i) + w.values_[i] = values_[i].Reverse(); + return w; + } + + const W& Value(size_t i) const { return values_[i]; } + + void SetValue(size_t i, const W &w) { values_[i] = w; } + + protected: + // Reads TupleWeight when there are no parentheses around tuple terms + inline static istream &ReadNoParen(istream &strm, + TupleWeight<W, n> &w, + char separator) { + int c; + do { + c = strm.get(); + } while (isspace(c)); + + for (size_t i = 0; i < n - 1; ++i) { + string s; + if (i) + c = strm.get(); + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + // read (i+1)-th element + istringstream sstrm(s); + W r = W::Zero(); + sstrm >> r; + w.SetValue(i, r); + } + + // read n-th element + W r = W::Zero(); + strm >> r; + w.SetValue(n - 1, r); + + return strm; + } + + // Reads TupleWeight when there are parentheses around tuple terms + inline static istream &ReadWithParen(istream &strm, + TupleWeight<W, n> &w, + char separator, + char open_paren, + char close_paren) { + int c; + do { + c = strm.get(); + } while (isspace(c)); + + if (c != open_paren) { + FSTERROR() << " is fst_weight_parentheses flag set correcty? "; + strm.clear(std::ios::badbit); + return strm; + } + + for (size_t i = 0; i < n - 1; ++i) { + // read (i+1)-th element + stack<int> parens; + string s; + c = strm.get(); + while (c != separator || !parens.empty()) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + // if parens encountered before separator, they must be matched + if (c == open_paren) { + parens.push(1); + } else if (c == close_paren) { + // Fail for mismatched parens + if (parens.empty()) { + strm.clear(std::ios::failbit); + return strm; + } + parens.pop(); + } + c = strm.get(); + } + istringstream sstrm(s); + W r = W::Zero(); + sstrm >> r; + w.SetValue(i, r); + } + + // read n-th element + string s; + c = strm.get(); + while (c != EOF) { + s += c; + c = strm.get(); + } + if (s.empty() || *s.rbegin() != close_paren) { + FSTERROR() << " is fst_weight_parentheses flag set correcty? "; + strm.clear(std::ios::failbit); + return strm; + } + s.erase(s.size() - 1, 1); + istringstream sstrm(s); + W r = W::Zero(); + sstrm >> r; + w.SetValue(n - 1, r); + + return strm; + } + + + private: + W values_[n]; + + friend istream &operator>><W, n>(istream&, TupleWeight<W, n>&); +}; + +template <class W, unsigned int n> +inline bool operator==(const TupleWeight<W, n> &w1, + const TupleWeight<W, n> &w2) { + bool equal = true; + for (size_t i = 0; i < n; ++i) + equal = equal && (w1.Value(i) == w2.Value(i)); + return equal; +} + +template <class W, unsigned int n> +inline bool operator!=(const TupleWeight<W, n> &w1, + const TupleWeight<W, n> &w2) { + bool not_equal = false; + for (size_t i = 0; (i < n) && !not_equal; ++i) + not_equal = not_equal || (w1.Value(i) != w2.Value(i)); + return not_equal; +} + +template <class W, unsigned int n> +inline bool ApproxEqual(const TupleWeight<W, n> &w1, + const TupleWeight<W, n> &w2, + float delta = kDelta) { + bool approx_equal = true; + for (size_t i = 0; i < n; ++i) + approx_equal = approx_equal && + ApproxEqual(w1.Value(i), w2.Value(i), delta); + return approx_equal; +} + +template <class W, unsigned int n> +inline ostream &operator<<(ostream &strm, const TupleWeight<W, n> &w) { + if(FLAGS_fst_weight_separator.size() != 1) { + FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1"; + strm.clear(std::ios::badbit); + return strm; + } + char separator = FLAGS_fst_weight_separator[0]; + bool write_parens = false; + if (!FLAGS_fst_weight_parentheses.empty()) { + if (FLAGS_fst_weight_parentheses.size() != 2) { + FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2"; + strm.clear(std::ios::badbit); + return strm; + } + write_parens = true; + } + + if (write_parens) + strm << FLAGS_fst_weight_parentheses[0]; + for (size_t i = 0; i < n; ++i) { + if(i) + strm << separator; + strm << w.Value(i); + } + if (write_parens) + strm << FLAGS_fst_weight_parentheses[1]; + + return strm; +} + +template <class W, unsigned int n> +inline istream &operator>>(istream &strm, TupleWeight<W, n> &w) { + if(FLAGS_fst_weight_separator.size() != 1) { + FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1"; + strm.clear(std::ios::badbit); + return strm; + } + char separator = FLAGS_fst_weight_separator[0]; + + if (!FLAGS_fst_weight_parentheses.empty()) { + if (FLAGS_fst_weight_parentheses.size() != 2) { + FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2"; + strm.clear(std::ios::badbit); + return strm; + } + return TupleWeight<W, n>::ReadWithParen( + strm, w, separator, FLAGS_fst_weight_parentheses[0], + FLAGS_fst_weight_parentheses[1]); + } else { + return TupleWeight<W, n>::ReadNoParen(strm, w, separator); + } +} + + + +} // namespace fst + +#endif // FST_LIB_TUPLE_WEIGHT_H__ |