diff options
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/sparse-tuple-weight.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/sparse-tuple-weight.h | 640 |
1 files changed, 640 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/sparse-tuple-weight.h b/kaldi_io/src/tools/openfst/include/fst/sparse-tuple-weight.h new file mode 100644 index 0000000..c12ef4f --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/sparse-tuple-weight.h @@ -0,0 +1,640 @@ +// sparse-tuple-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Kasturi Rangan Raghavan) +// Inspiration: [email protected] (Cyril Allauzen) +// \file +// Sparse version of tuple-weight, based on tuple-weight.h +// Internally stores sparse key, value pairs in linked list +// Default value elemnt is the assumed value of unset keys +// Internal singleton implementation that stores first key, +// value pair as a initialized member variable to avoide +// unnecessary allocation on heap. +// Use SparseTupleWeightIterator to iterate through the key,value pairs +// Note: this does NOT iterate through the default value. +// +// Sparse tuple weight set operation definitions. + +#ifndef FST_LIB_SPARSE_TUPLE_WEIGHT_H__ +#define FST_LIB_SPARSE_TUPLE_WEIGHT_H__ + +#include<string> +#include<list> +#include<stack> +#include<tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + +#include <fst/weight.h> + + +DECLARE_string(fst_weight_parentheses); +DECLARE_string(fst_weight_separator); + +namespace fst { + +template <class W, class K> class SparseTupleWeight; + +template<class W, class K> +class SparseTupleWeightIterator; + +template <class W, class K> +istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w); + +// Arbitrary dimension tuple weight, stored as a sorted linked-list +// W is any weight class, +// K is the key value type. kNoKey(-1) is reserved for internal use +template <class W, class K = int> +class SparseTupleWeight { + public: + typedef pair<K, W> Pair; + typedef SparseTupleWeight<typename W::ReverseWeight, K> ReverseWeight; + + const static K kNoKey = -1; + SparseTupleWeight() { + Init(); + } + + template <class Iterator> + SparseTupleWeight(Iterator begin, Iterator end) { + Init(); + // Assumes input iterator is sorted + for (Iterator it = begin; it != end; ++it) + Push(*it); + } + + + SparseTupleWeight(const K& key, const W &w) { + Init(); + Push(key, w); + } + + SparseTupleWeight(const W &w) { + Init(w); + } + + SparseTupleWeight(const SparseTupleWeight<W, K> &w) { + Init(w.DefaultValue()); + SetDefaultValue(w.DefaultValue()); + for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) { + Push(it.Value()); + } + } + + static const SparseTupleWeight<W, K> &Zero() { + static SparseTupleWeight<W, K> zero; + return zero; + } + + static const SparseTupleWeight<W, K> &One() { + static SparseTupleWeight<W, K> one(W::One()); + return one; + } + + static const SparseTupleWeight<W, K> &NoWeight() { + static SparseTupleWeight<W, K> no_weight(W::NoWeight()); + return no_weight; + } + + istream &Read(istream &strm) { + ReadType(strm, &default_); + ReadType(strm, &first_); + return ReadType(strm, &rest_); + } + + ostream &Write(ostream &strm) const { + WriteType(strm, default_); + WriteType(strm, first_); + return WriteType(strm, rest_); + } + + SparseTupleWeight<W, K> &operator=(const SparseTupleWeight<W, K> &w) { + if (this == &w) return *this; // check for w = w + Init(w.DefaultValue()); + for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) { + Push(it.Value()); + } + return *this; + } + + bool Member() const { + if (!DefaultValue().Member()) return false; + for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) { + if (!it.Value().second.Member()) return false; + } + return true; + } + + // Assumes H() function exists for the hash of the key value + size_t Hash() const { + uint64 h = 0; + std::tr1::hash<K> H; + for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) { + h = 5 * h + H(it.Value().first); + h = 13 * h + it.Value().second.Hash(); + } + return size_t(h); + } + + SparseTupleWeight<W, K> Quantize(float delta = kDelta) const { + SparseTupleWeight<W, K> w; + for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) { + w.Push(it.Value().first, it.Value().second.Quantize(delta)); + } + return w; + } + + ReverseWeight Reverse() const { + SparseTupleWeight<W, K> w; + for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) { + w.Push(it.Value().first, it.Value().second.Reverse()); + } + return w; + } + + // Common initializer among constructors. + void Init() { + Init(W::Zero()); + } + + void Init(const W& default_value) { + first_.first = kNoKey; + /* initialized to the reserved key value */ + default_ = default_value; + rest_.clear(); + } + + size_t Size() const { + if (first_.first == kNoKey) + return 0; + else + return rest_.size() + 1; + } + + inline void Push(const K &k, const W &w, bool default_value_check = true) { + Push(make_pair(k, w), default_value_check); + } + + inline void Push(const Pair &p, bool default_value_check = true) { + if (default_value_check && p.second == default_) return; + if (first_.first == kNoKey) { + first_ = p; + } else { + rest_.push_back(p); + } + } + + void SetDefaultValue(const W& val) { default_ = val; } + + const W& DefaultValue() const { return default_; } + + protected: + static istream& ReadNoParen( + istream&, SparseTupleWeight<W, K>&, char separator); + + static istream& ReadWithParen( + istream&, SparseTupleWeight<W, K>&, + char separator, char open_paren, char close_paren); + + private: + // Assumed default value of uninitialized keys, by default W::Zero() + W default_; + + // Key values pairs are first stored in first_, then fill rest_ + // this way we can avoid dynamic allocation in the common case + // where the weight is a single key,val pair. + Pair first_; + list<Pair> rest_; + + friend istream &operator>><W, K>(istream&, SparseTupleWeight<W, K>&); + friend class SparseTupleWeightIterator<W, K>; +}; + +template<class W, class K> +class SparseTupleWeightIterator { + public: + typedef typename SparseTupleWeight<W, K>::Pair Pair; + typedef typename list<Pair>::const_iterator const_iterator; + typedef typename list<Pair>::iterator iterator; + + explicit SparseTupleWeightIterator(const SparseTupleWeight<W, K>& w) + : first_(w.first_), rest_(w.rest_), init_(true), + iter_(rest_.begin()) {} + + bool Done() const { + if (init_) + return first_.first == SparseTupleWeight<W, K>::kNoKey; + else + return iter_ == rest_.end(); + } + + const Pair& Value() const { return init_ ? first_ : *iter_; } + + void Next() { + if (init_) + init_ = false; + else + ++iter_; + } + + void Reset() { + init_ = true; + iter_ = rest_.begin(); + } + + private: + const Pair &first_; + const list<Pair> & rest_; + bool init_; // in the initialized state? + typename list<Pair>::const_iterator iter_; + + DISALLOW_COPY_AND_ASSIGN(SparseTupleWeightIterator); +}; + +template<class W, class K, class M> +inline void SparseTupleWeightMap( + SparseTupleWeight<W, K>* ret, + const SparseTupleWeight<W, K>& w1, + const SparseTupleWeight<W, K>& w2, + const M& operator_mapper) { + SparseTupleWeightIterator<W, K> w1_it(w1); + SparseTupleWeightIterator<W, K> w2_it(w2); + const W& v1_def = w1.DefaultValue(); + const W& v2_def = w2.DefaultValue(); + ret->SetDefaultValue(operator_mapper.Map(0, v1_def, v2_def)); + while (!w1_it.Done() || !w2_it.Done()) { + const K& k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first; + const K& k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first; + const W& v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second; + const W& v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second; + if (k1 == k2) { + ret->Push(k1, operator_mapper.Map(k1, v1, v2)); + if (!w1_it.Done()) w1_it.Next(); + if (!w2_it.Done()) w2_it.Next(); + } else if (k1 < k2) { + ret->Push(k1, operator_mapper.Map(k1, v1, v2_def)); + w1_it.Next(); + } else { + ret->Push(k2, operator_mapper.Map(k2, v1_def, v2)); + w2_it.Next(); + } + } +} + +template <class W, class K> +inline bool operator==(const SparseTupleWeight<W, K> &w1, + const SparseTupleWeight<W, K> &w2) { + const W& v1_def = w1.DefaultValue(); + const W& v2_def = w2.DefaultValue(); + if (v1_def != v2_def) return false; + + SparseTupleWeightIterator<W, K> w1_it(w1); + SparseTupleWeightIterator<W, K> w2_it(w2); + while (!w1_it.Done() || !w2_it.Done()) { + const K& k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first; + const K& k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first; + const W& v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second; + const W& v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second; + if (k1 == k2) { + if (v1 != v2) return false; + if (!w1_it.Done()) w1_it.Next(); + if (!w2_it.Done()) w2_it.Next(); + } else if (k1 < k2) { + if (v1 != v2_def) return false; + w1_it.Next(); + } else { + if (v1_def != v2) return false; + w2_it.Next(); + } + } + return true; +} + +template <class W, class K> +inline bool operator!=(const SparseTupleWeight<W, K> &w1, + const SparseTupleWeight<W, K> &w2) { + return !(w1 == w2); +} + +template <class W, class K> +inline ostream &operator<<(ostream &strm, const SparseTupleWeight<W, K> &w) { + if(FLAGS_fst_weight_separator.size() != 1) { + FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1"; + strm.clear(std::ios::badbit); + return strm; + } + char separator = FLAGS_fst_weight_separator[0]; + bool write_parens = false; + if (!FLAGS_fst_weight_parentheses.empty()) { + if (FLAGS_fst_weight_parentheses.size() != 2) { + FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2"; + strm.clear(std::ios::badbit); + return strm; + } + write_parens = true; + } + + if (write_parens) + strm << FLAGS_fst_weight_parentheses[0]; + + strm << w.DefaultValue(); + strm << separator; + + size_t n = w.Size(); + strm << n; + strm << separator; + + for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) { + strm << it.Value().first; + strm << separator; + strm << it.Value().second; + strm << separator; + } + + if (write_parens) + strm << FLAGS_fst_weight_parentheses[1]; + + return strm; +} + +template <class W, class K> +inline istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w) { + if(FLAGS_fst_weight_separator.size() != 1) { + FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1"; + strm.clear(std::ios::badbit); + return strm; + } + char separator = FLAGS_fst_weight_separator[0]; + + if (!FLAGS_fst_weight_parentheses.empty()) { + if (FLAGS_fst_weight_parentheses.size() != 2) { + FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2"; + strm.clear(std::ios::badbit); + return strm; + } + return SparseTupleWeight<W, K>::ReadWithParen( + strm, w, separator, FLAGS_fst_weight_parentheses[0], + FLAGS_fst_weight_parentheses[1]); + } else { + return SparseTupleWeight<W, K>::ReadNoParen(strm, w, separator); + } +} + +// Reads SparseTupleWeight when there are no parentheses around tuple terms +template <class W, class K> +inline istream& SparseTupleWeight<W, K>::ReadNoParen( + istream &strm, + SparseTupleWeight<W, K> &w, + char separator) { + int c; + size_t n; + + do { + c = strm.get(); + } while (isspace(c)); + + + { // Read default weight + W default_value; + string s; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> default_value; + w.SetDefaultValue(default_value); + } + + c = strm.get(); + + { // Read n + string s; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> n; + } + + // Read n elements + for (size_t i = 0; i < n; ++i) { + // discard separator + c = strm.get(); + K p; + W r; + + { // read key + string s; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> p; + } + + c = strm.get(); + + { // read weight + string s; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> r; + } + + w.Push(p, r); + } + + c = strm.get(); + if (c != separator) { + strm.clear(std::ios::badbit); + } + + return strm; +} + +// Reads SparseTupleWeight when there are parentheses around tuple terms +template <class W, class K> +inline istream& SparseTupleWeight<W, K>::ReadWithParen( + istream &strm, + SparseTupleWeight<W, K> &w, + char separator, + char open_paren, + char close_paren) { + int c; + size_t n; + + do { + c = strm.get(); + } while (isspace(c)); + + if (c != open_paren) { + FSTERROR() << "is fst_weight_parentheses flag set correcty? "; + strm.clear(std::ios::badbit); + return strm; + } + + c = strm.get(); + + { // Read weight + W default_value; + stack<int> parens; + string s; + while (c != separator || !parens.empty()) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + // If parens encountered before separator, they must be matched + if (c == open_paren) { + parens.push(1); + } else if (c == close_paren) { + // Fail for mismatched parens + if (parens.empty()) { + strm.clear(std::ios::failbit); + return strm; + } + parens.pop(); + } + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> default_value; + w.SetDefaultValue(default_value); + } + + c = strm.get(); + + { // Read n + string s; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> n; + } + + // Read n elements + for (size_t i = 0; i < n; ++i) { + // discard separator + c = strm.get(); + K p; + W r; + + { // Read key + stack<int> parens; + string s; + while (c != separator || !parens.empty()) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + // If parens encountered before separator, they must be matched + if (c == open_paren) { + parens.push(1); + } else if (c == close_paren) { + // Fail for mismatched parens + if (parens.empty()) { + strm.clear(std::ios::failbit); + return strm; + } + parens.pop(); + } + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> p; + } + + c = strm.get(); + + { // Read weight + stack<int> parens; + string s; + while (c != separator || !parens.empty()) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + // If parens encountered before separator, they must be matched + if (c == open_paren) { + parens.push(1); + } else if (c == close_paren) { + // Fail for mismatched parens + if (parens.empty()) { + strm.clear(std::ios::failbit); + return strm; + } + parens.pop(); + } + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> r; + } + + w.Push(p, r); + } + + if (c != separator) { + FSTERROR() << " separator expected, not found! "; + strm.clear(std::ios::badbit); + return strm; + } + + c = strm.get(); + if (c != close_paren) { + FSTERROR() << " is fst_weight_parentheses flag set correcty? "; + strm.clear(std::ios::badbit); + return strm; + } + + return strm; +} + + + +} // namespace fst + +#endif // FST_LIB_SPARSE_TUPLE_WEIGHT_H__ |