diff options
author | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
commit | 96a32415ab43377cf1575bd3f4f2980f58028209 (patch) | |
tree | 30a2d92d73e8f40ac87b79f6f56e227bfc4eea6e /kaldi_io/src/tools/openfst/include/fst/string.h | |
parent | c177a7549bd90670af4b29fa813ddea32cfe0f78 (diff) |
add implementation for kaldi io (by ymz)
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/string.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/string.h | 271 |
1 files changed, 271 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/string.h b/kaldi_io/src/tools/openfst/include/fst/string.h new file mode 100644 index 0000000..9eaf7a3 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/string.h @@ -0,0 +1,271 @@ + +// string.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Utilities to convert strings into FSTs. +// + +#ifndef FST_LIB_STRING_H_ +#define FST_LIB_STRING_H_ + +#include <fst/compact-fst.h> +#include <fst/icu.h> +#include <fst/mutable-fst.h> + +DECLARE_string(fst_field_separator); + +namespace fst { + +// Functor compiling a string in an FST +template <class A> +class StringCompiler { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 }; + + StringCompiler(TokenType type, const SymbolTable *syms = 0, + Label unknown_label = kNoLabel, + bool allow_negative = false) + : token_type_(type), syms_(syms), unknown_label_(unknown_label), + allow_negative_(allow_negative) {} + + // Compile string 's' into FST 'fst'. + template <class F> + bool operator()(const string &s, F *fst) const { + vector<Label> labels; + if (!ConvertStringToLabels(s, &labels)) + return false; + Compile(labels, fst); + return true; + } + + template <class F> + bool operator()(const string &s, F *fst, Weight w) const { + vector<Label> labels; + if (!ConvertStringToLabels(s, &labels)) + return false; + Compile(labels, fst, w); + return true; + } + + private: + bool ConvertStringToLabels(const string &str, vector<Label> *labels) const { + labels->clear(); + if (token_type_ == BYTE) { + for (size_t i = 0; i < str.size(); ++i) + labels->push_back(static_cast<unsigned char>(str[i])); + } else if (token_type_ == UTF8) { + return UTF8StringToLabels(str, labels); + } else { + char *c_str = new char[str.size() + 1]; + str.copy(c_str, str.size()); + c_str[str.size()] = 0; + vector<char *> vec; + string separator = "\n" + FLAGS_fst_field_separator; + SplitToVector(c_str, separator.c_str(), &vec, true); + for (size_t i = 0; i < vec.size(); ++i) { + Label label; + if (!ConvertSymbolToLabel(vec[i], &label)) + return false; + labels->push_back(label); + } + delete[] c_str; + } + return true; + } + + void Compile(const vector<Label> &labels, MutableFst<A> *fst, + const Weight &weight = Weight::One()) const { + fst->DeleteStates(); + while (fst->NumStates() <= labels.size()) + fst->AddState(); + for (size_t i = 0; i < labels.size(); ++i) + fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1)); + fst->SetStart(0); + fst->SetFinal(labels.size(), weight); + } + + template <class Unsigned> + void Compile(const vector<Label> &labels, + CompactFst<A, StringCompactor<A>, Unsigned> *fst) const { + fst->SetCompactElements(labels.begin(), labels.end()); + } + + template <class Unsigned> + void Compile(const vector<Label> &labels, + CompactFst<A, WeightedStringCompactor<A>, Unsigned> *fst, + const Weight &weight = Weight::One()) const { + vector<pair<Label, Weight> > compacts; + compacts.reserve(labels.size()); + for (size_t i = 0; i < labels.size(); ++i) + compacts.push_back(make_pair(labels[i], Weight::One())); + compacts.back().second = weight; + fst->SetCompactElements(compacts.begin(), compacts.end()); + } + + bool ConvertSymbolToLabel(const char *s, Label* output) const { + int64 n; + if (syms_) { + n = syms_->Find(s); + if ((n == -1) && (unknown_label_ != kNoLabel)) + n = unknown_label_; + if (n == -1 || (!allow_negative_ && n < 0)) { + VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s + << "\" is not mapped to any integer label, symbol table = " + << syms_->Name(); + return false; + } + } else { + char *p; + n = strtoll(s, &p, 10); + if (p < s + strlen(s) || (!allow_negative_ && n < 0)) { + VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer " + << "= \"" << s << "\""; + return false; + } + } + *output = n; + return true; + } + + TokenType token_type_; // Token type: symbol, byte or utf8 encoded + const SymbolTable *syms_; // Symbol table used when token type is symbol + Label unknown_label_; // Label for token missing from symbol table + bool allow_negative_; // Negative labels allowed? + + DISALLOW_COPY_AND_ASSIGN(StringCompiler); +}; + +// Functor to print a string FST as a string. +template <class A> +class StringPrinter { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 }; + + StringPrinter(TokenType token_type, + const SymbolTable *syms = 0) + : token_type_(token_type), syms_(syms) {} + + // Convert the FST 'fst' into the string 'output' + bool operator()(const Fst<A> &fst, string *output) { + bool is_a_string = FstToLabels(fst); + if (!is_a_string) { + VLOG(1) << "StringPrinter::operator(): Fst is not a string."; + return false; + } + + output->clear(); + + if (token_type_ == SYMBOL) { + stringstream sstrm; + for (size_t i = 0; i < labels_.size(); ++i) { + if (i) + sstrm << *(FLAGS_fst_field_separator.rbegin()); + if (!PrintLabel(labels_[i], sstrm)) + return false; + } + *output = sstrm.str(); + } else if (token_type_ == BYTE) { + output->reserve(labels_.size()); + for (size_t i = 0; i < labels_.size(); ++i) { + output->push_back(labels_[i]); + } + } else if (token_type_ == UTF8) { + return LabelsToUTF8String(labels_, output); + } else { + VLOG(1) << "StringPrinter::operator(): Unknown token type: " + << token_type_; + return false; + } + return true; + } + + private: + bool FstToLabels(const Fst<A> &fst) { + labels_.clear(); + + StateId s = fst.Start(); + if (s == kNoStateId) { + VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for " + << "string fst."; + return false; + } + + while (fst.Final(s) == Weight::Zero()) { + ArcIterator<Fst<A> > aiter(fst, s); + if (aiter.Done()) { + VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does " + << "not reach final state."; + return false; + } + + const A& arc = aiter.Value(); + labels_.push_back(arc.olabel); + + s = arc.nextstate; + if (s == kNoStateId) { + VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid " + << "state."; + return false; + } + + aiter.Next(); + if (!aiter.Done()) { + VLOG(2) << "StringPrinter::FstToLabels: State with multiple " + << "outgoing arcs found."; + return false; + } + } + + return true; + } + + bool PrintLabel(Label lab, ostream& ostrm) { + if (syms_) { + string symbol = syms_->Find(lab); + if (symbol == "") { + VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not " + << "mapped to any textual symbol, symbol table = " + << syms_->Name(); + return false; + } + ostrm << symbol; + } else { + ostrm << lab; + } + return true; + } + + TokenType token_type_; // Token type: symbol, byte or utf8 encoded + const SymbolTable *syms_; // Symbol table used when token type is symbol + vector<Label> labels_; // Input FST labels. + + DISALLOW_COPY_AND_ASSIGN(StringPrinter); +}; + +} // namespace fst + +#endif // FST_LIB_STRING_H_ |