diff options
author | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
commit | 96a32415ab43377cf1575bd3f4f2980f58028209 (patch) | |
tree | 30a2d92d73e8f40ac87b79f6f56e227bfc4eea6e /kaldi_io/src/kaldi | |
parent | c177a7549bd90670af4b29fa813ddea32cfe0f78 (diff) |
add implementation for kaldi io (by ymz)
Diffstat (limited to 'kaldi_io/src/kaldi')
69 files changed, 17535 insertions, 0 deletions
diff --git a/kaldi_io/src/kaldi/base/io-funcs-inl.h b/kaldi_io/src/kaldi/base/io-funcs-inl.h new file mode 100644 index 0000000..e55458e --- /dev/null +++ b/kaldi_io/src/kaldi/base/io-funcs-inl.h @@ -0,0 +1,219 @@ +// base/io-funcs-inl.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin Qian; Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_IO_FUNCS_INL_H_ +#define KALDI_BASE_IO_FUNCS_INL_H_ 1 + +// Do not include this file directly. It is included by base/io-funcs.h + +#include <limits> +#include <vector> + +namespace kaldi { + +// Template that covers integers. +template<class T> void WriteBasicType(std::ostream &os, + bool binary, T t) { + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + char len_c = (std::numeric_limits<T>::is_signed ? 1 : -1) + * static_cast<char>(sizeof(t)); + os.put(len_c); + os.write(reinterpret_cast<const char *>(&t), sizeof(t)); + } else { + if (sizeof(t) == 1) + os << static_cast<int16>(t) << " "; + else + os << t << " "; + } + if (os.fail()) { + throw std::runtime_error("Write failure in WriteBasicType."); + } +} + +// Template that covers integers. +template<class T> inline void ReadBasicType(std::istream &is, + bool binary, T *t) { + KALDI_PARANOID_ASSERT(t != NULL); + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + int len_c_in = is.get(); + if (len_c_in == -1) + KALDI_ERR << "ReadBasicType: encountered end of stream."; + char len_c = static_cast<char>(len_c_in), len_c_expected + = (std::numeric_limits<T>::is_signed ? 1 : -1) + * static_cast<char>(sizeof(*t)); + + if (len_c != len_c_expected) { + KALDI_ERR << "ReadBasicType: did not get expected integer type, " + << static_cast<int>(len_c) + << " vs. " << static_cast<int>(len_c_expected) + << ". You can change this code to successfully" + << " read it later, if needed."; + // insert code here to read "wrong" type. Might have a switch statement. + } + is.read(reinterpret_cast<char *>(t), sizeof(*t)); + } else { + if (sizeof(*t) == 1) { + int16 i; + is >> i; + *t = i; + } else { + is >> *t; + } + } + if (is.fail()) { + KALDI_ERR << "Read failure in ReadBasicType, file position is " + << is.tellg() << ", next char is " << is.peek(); + } +} + + +template<class T> inline void WriteIntegerVector(std::ostream &os, bool binary, + const std::vector<T> &v) { + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + char sz = sizeof(T); // this is currently just a check. + os.write(&sz, 1); + int32 vecsz = static_cast<int32>(v.size()); + KALDI_ASSERT((size_t)vecsz == v.size()); + os.write(reinterpret_cast<const char *>(&vecsz), sizeof(vecsz)); + if (vecsz != 0) { + os.write(reinterpret_cast<const char *>(&(v[0])), sizeof(T)*vecsz); + } + } else { + // focus here is on prettiness of text form rather than + // efficiency of reading-in. + // reading-in is dominated by low-level operations anyway: + // for efficiency use binary. + os << "[ "; + typename std::vector<T>::const_iterator iter = v.begin(), end = v.end(); + for (; iter != end; ++iter) { + if (sizeof(T) == 1) + os << static_cast<int16>(*iter) << " "; + else + os << *iter << " "; + } + os << "]\n"; + } + if (os.fail()) { + throw std::runtime_error("Write failure in WriteIntegerType."); + } +} + + +template<class T> inline void ReadIntegerVector(std::istream &is, + bool binary, + std::vector<T> *v) { + KALDI_ASSERT_IS_INTEGER_TYPE(T); + KALDI_ASSERT(v != NULL); + if (binary) { + int sz = is.peek(); + if (sz == sizeof(T)) { + is.get(); + } else { // this is currently just a check. + KALDI_ERR << "ReadIntegerVector: expected to see type of size " + << sizeof(T) << ", saw instead " << sz << ", at file position " + << is.tellg(); + } + int32 vecsz; + is.read(reinterpret_cast<char *>(&vecsz), sizeof(vecsz)); + if (is.fail() || vecsz < 0) goto bad; + v->resize(vecsz); + if (vecsz > 0) { + is.read(reinterpret_cast<char *>(&((*v)[0])), sizeof(T)*vecsz); + } + } else { + std::vector<T> tmp_v; // use temporary so v doesn't use extra memory + // due to resizing. + is >> std::ws; + if (is.peek() != static_cast<int>('[')) { + KALDI_ERR << "ReadIntegerVector: expected to see [, saw " + << is.peek() << ", at file position " << is.tellg(); + } + is.get(); // consume the '['. + is >> std::ws; // consume whitespace. + while (is.peek() != static_cast<int>(']')) { + if (sizeof(T) == 1) { // read/write chars as numbers. + int16 next_t; + is >> next_t >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back((T)next_t); + } else { + T next_t; + is >> next_t >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back(next_t); + } + } + is.get(); // get the final ']'. + *v = tmp_v; // could use std::swap to use less temporary memory, but this + // uses less permanent memory. + } + if (!is.fail()) return; + bad: + KALDI_ERR << "ReadIntegerVector: read failure at file position " + << is.tellg(); +} + +// Initialize an opened stream for writing by writing an optional binary +// header and modifying the floating-point precision. +inline void InitKaldiOutputStream(std::ostream &os, bool binary) { + // This does not throw exceptions (does not check for errors). + if (binary) { + os.put('\0'); + os.put('B'); + } + // Note, in non-binary mode we may at some point want to mess with + // the precision a bit. + // 7 is a bit more than the precision of float.. + if (os.precision() < 7) + os.precision(7); +} + +/// Initialize an opened stream for reading by detecting the binary header and +// setting the "binary" value appropriately. +inline bool InitKaldiInputStream(std::istream &is, bool *binary) { + // Sets the 'binary' variable. + // Throws exception in the very unusual situation that stream + // starts with '\0' but not then 'B'. + + if (is.peek() == '\0') { // seems to be binary + is.get(); + if (is.peek() != 'B') { + return false; + } + is.get(); + *binary = true; + return true; + } else { + *binary = false; + return true; + } +} + +} // end namespace kaldi. + +#endif // KALDI_BASE_IO_FUNCS_INL_H_ diff --git a/kaldi_io/src/kaldi/base/io-funcs.h b/kaldi_io/src/kaldi/base/io-funcs.h new file mode 100644 index 0000000..2bc9da8 --- /dev/null +++ b/kaldi_io/src/kaldi/base/io-funcs.h @@ -0,0 +1,231 @@ +// base/io-funcs.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_IO_FUNCS_H_ +#define KALDI_BASE_IO_FUNCS_H_ + +// This header only contains some relatively low-level I/O functions. +// The full Kaldi I/O declarations are in ../util/kaldi-io.h +// and ../util/kaldi-table.h +// They were put in util/ in order to avoid making the Matrix library +// dependent on them. + +#include <cctype> +#include <vector> +#include <string> +#include "base/kaldi-common.h" + +namespace kaldi { + + + +/* + This comment describes the Kaldi approach to I/O. All objects can be written + and read in two modes: binary and text. In addition we want to make the I/O + work if we redefine the typedef "BaseFloat" between floats and doubles. + We also want to have control over whitespace in text mode without affecting + the meaning of the file, for pretty-printing purposes. + + Errors are handled by throwing an exception (std::runtime_error). + + For integer and floating-point types (and boolean values): + + WriteBasicType(std::ostream &, bool binary, const T&); + ReadBasicType(std::istream &, bool binary, T*); + + and we expect these functions to be defined in such a way that they work when + the type T changes between float and double, so you can read float into double + and vice versa]. Note that for efficiency and space-saving reasons, the Vector + and Matrix classes do not use these functions [but they preserve the type + interchangeability in their own way] + + For a class (or struct) C: + class C { + .. + Write(std::ostream &, bool binary, [possibly extra optional args for specific classes]) const; + Read(std::istream &, bool binary, [possibly extra optional args for specific classes]); + .. + } + NOTE: The only actual optional args we used are the "add" arguments in + Vector/Matrix classes, which specify whether we should sum the data already + in the class with the data being read. + + For types which are typedef's involving stl classes, I/O is as follows: + typedef std::vector<std::pair<A, B> > MyTypedefName; + + The user should define something like: + + WriteMyTypedefName(std::ostream &, bool binary, const MyTypedefName &t); + ReadMyTypedefName(std::ostream &, bool binary, MyTypedefName *t); + + The user would have to write these functions. + + For a type std::vector<T>: + + void WriteIntegerVector(std::ostream &os, bool binary, const std::vector<T> &v); + void ReadIntegerVector(std::istream &is, bool binary, std::vector<T> *v); + + For other types, e.g. vectors of pairs, the user should create a routine of the + type WriteMyTypedefName. This is to avoid introducing confusing templated functions; + we could easily create templated functions to handle most of these cases but they + would have to share the same name. + + It also often happens that the user needs to write/read special tokens as part + of a file. These might be class headers, or separators/identifiers in the class. + We provide special functions for manipulating these. These special tokens must + be nonempty and must not contain any whitespace. + + void WriteToken(std::ostream &os, bool binary, const char*); + void WriteToken(std::ostream &os, bool binary, const std::string & token); + int Peek(std::istream &is, bool binary); + void ReadToken(std::istream &is, bool binary, std::string *str); + void PeekToken(std::istream &is, bool binary, std::string *str); + + + WriteToken writes the token and one space (whether in binary or text mode). + + Peek returns the first character of the next token, by consuming whitespace + (in text mode) and then returning the peek() character. It returns -1 at EOF; + it doesn't throw. It's useful if a class can have various forms based on + typedefs and virtual classes, and wants to know which version to read. + + ReadToken allow the caller to obtain the next token. PeekToken works just + like ReadToken, but seeks back to the beginning of the token. A subsequent + call to ReadToken will read the same token again. This is useful when + different object types are written to the same file; using PeekToken one can + decide which of the objects to read. + + There is currently no special functionality for writing/reading strings (where the strings + contain data rather than "special tokens" that are whitespace-free and nonempty). This is + because Kaldi is structured in such a way that strings don't appear, except as OpenFst symbol + table entries (and these have their own format). + + + NOTE: you should not call ReadIntegerType and WriteIntegerType with types, + such as int and size_t, that are machine-independent -- at least not + if you want your file formats to port between machines. Use int32 and + int64 where necessary. There is no way to detect this using compile-time + assertions because C++ only keeps track of the internal representation of + the type. +*/ + +/// \addtogroup io_funcs_basic +/// @{ + + +/// WriteBasicType is the name of the write function for bool, integer types, +/// and floating-point types. They all throw on error. +template<class T> void WriteBasicType(std::ostream &os, bool binary, T t); + +/// ReadBasicType is the name of the read function for bool, integer types, +/// and floating-point types. They all throw on error. +template<class T> void ReadBasicType(std::istream &is, bool binary, T *t); + + +// Declare specialization for bool. +template<> +void WriteBasicType<bool>(std::ostream &os, bool binary, bool b); + +template <> +void ReadBasicType<bool>(std::istream &is, bool binary, bool *b); + +// Declare specializations for float and double. +template<> +void WriteBasicType<float>(std::ostream &os, bool binary, float f); + +template<> +void WriteBasicType<double>(std::ostream &os, bool binary, double f); + +template<> +void ReadBasicType<float>(std::istream &is, bool binary, float *f); + +template<> +void ReadBasicType<double>(std::istream &is, bool binary, double *f); + +// Define ReadBasicType that accepts an "add" parameter to add to +// the destination. Caution: if used in Read functions, be careful +// to initialize the parameters concerned to zero in the default +// constructor. +template<class T> +inline void ReadBasicType(std::istream &is, bool binary, T *t, bool add) { + if (!add) { + ReadBasicType(is, binary, t); + } else { + T tmp = T(0); + ReadBasicType(is, binary, &tmp); + *t += tmp; + } +} + +/// Function for writing STL vectors of integer types. +template<class T> inline void WriteIntegerVector(std::ostream &os, bool binary, + const std::vector<T> &v); + +/// Function for reading STL vector of integer types. +template<class T> inline void ReadIntegerVector(std::istream &is, bool binary, + std::vector<T> *v); + +/// The WriteToken functions are for writing nonempty sequences of non-space +/// characters. They are not for general strings. +void WriteToken(std::ostream &os, bool binary, const char *token); +void WriteToken(std::ostream &os, bool binary, const std::string & token); + +/// Peek consumes whitespace (if binary == false) and then returns the peek() +/// value of the stream. +int Peek(std::istream &is, bool binary); + +/// ReadToken gets the next token and puts it in str (exception on failure). +void ReadToken(std::istream &is, bool binary, std::string *token); + +/// PeekToken will return the first character of the next token, or -1 if end of +/// file. It's the same as Peek(), except if the first character is '<' it will +/// skip over it and will return the next character. It will unget the '<' so +/// the stream is where it was before you did PeekToken(). +int PeekToken(std::istream &is, bool binary); + +/// ExpectToken tries to read in the given token, and throws an exception +/// on failure. +void ExpectToken(std::istream &is, bool binary, const char *token); +void ExpectToken(std::istream &is, bool binary, const std::string & token); + +/// ExpectPretty attempts to read the text in "token", but only in non-binary +/// mode. Throws exception on failure. It expects an exact match except that +/// arbitrary whitespace matches arbitrary whitespace. +void ExpectPretty(std::istream &is, bool binary, const char *token); +void ExpectPretty(std::istream &is, bool binary, const std::string & token); + +/// @} end "addtogroup io_funcs_basic" + + +/// InitKaldiOutputStream initializes an opened stream for writing by writing an +/// optional binary header and modifying the floating-point precision; it will +/// typically not be called by users directly. +inline void InitKaldiOutputStream(std::ostream &os, bool binary); + +/// InitKaldiInputStream initializes an opened stream for reading by detecting +/// the binary header and setting the "binary" value appropriately; +/// It will typically not be called by users directly. +inline bool InitKaldiInputStream(std::istream &is, bool *binary); + +} // end namespace kaldi. + +#include "base/io-funcs-inl.h" + +#endif // KALDI_BASE_IO_FUNCS_H_ diff --git a/kaldi_io/src/kaldi/base/kaldi-common.h b/kaldi_io/src/kaldi/base/kaldi-common.h new file mode 100644 index 0000000..33f6f31 --- /dev/null +++ b/kaldi_io/src/kaldi/base/kaldi-common.h @@ -0,0 +1,41 @@ +// base/kaldi-common.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_COMMON_H_ +#define KALDI_BASE_KALDI_COMMON_H_ 1 + +#include <cstddef> +#include <cstdlib> +#include <cstring> // C string stuff like strcpy +#include <string> +#include <sstream> +#include <stdexcept> +#include <cassert> +#include <vector> +#include <iostream> +#include <fstream> + +#include "base/kaldi-utils.h" +#include "base/kaldi-error.h" +#include "base/kaldi-types.h" +#include "base/io-funcs.h" +#include "base/kaldi-math.h" + +#endif // KALDI_BASE_KALDI_COMMON_H_ + diff --git a/kaldi_io/src/kaldi/base/kaldi-error.h b/kaldi_io/src/kaldi/base/kaldi-error.h new file mode 100644 index 0000000..8334e42 --- /dev/null +++ b/kaldi_io/src/kaldi/base/kaldi-error.h @@ -0,0 +1,153 @@ +// base/kaldi-error.h + +// Copyright 2009-2011 Microsoft Corporation; Ondrej Glembek; Lukas Burget; +// Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_ERROR_H_ +#define KALDI_BASE_KALDI_ERROR_H_ 1 + +#include <stdexcept> +#include <string> +#include <cstring> +#include <sstream> +#include <cstdio> + +#ifdef _MSC_VER +#define NOEXCEPT(Predicate) +#elif __cplusplus > 199711L || defined(__GXX_EXPERIMENTAL_CXX0X__) +#define NOEXCEPT(Predicate) noexcept((Predicate)) +#else +#define NOEXCEPT(Predicate) +#endif + +#include "base/kaldi-types.h" +#include "base/kaldi-utils.h" + +/* Important that this file does not depend on any other kaldi headers. */ + + +namespace kaldi { + +/// \addtogroup error_group +/// @{ + +/// This is set by util/parse-options.{h, cc} if you set --verbose = ? option +extern int32 g_kaldi_verbose_level; + +/// This is set by util/parse-options.{h, cc} (from argv[0]) and used (if set) +/// in error reporting code to display the name of the program (this is because +/// in our scripts, we often mix together the stderr of many programs). it is +/// the base-name of the program (no directory), followed by ':' We don't use +/// std::string, due to the static initialization order fiasco. +extern const char *g_program_name; + +inline int32 GetVerboseLevel() { return g_kaldi_verbose_level; } + +/// This should be rarely used; command-line programs set the verbose level +/// automatically from ParseOptions. +inline void SetVerboseLevel(int32 i) { g_kaldi_verbose_level = i; } + +// Class KaldiLogMessage is invoked from the KALDI_WARN, KALDI_VLOG and +// KALDI_LOG macros. It prints the message to stderr. Note: we avoid +// using cerr, due to problems with thread safety. fprintf is guaranteed +// thread-safe. + +// class KaldiWarnMessage is invoked from the KALDI_WARN macro. +class KaldiWarnMessage { + public: + inline std::ostream &stream() { return ss; } + KaldiWarnMessage(const char *func, const char *file, int32 line); + ~KaldiWarnMessage() { fprintf(stderr, "%s\n", ss.str().c_str()); } + private: + std::ostringstream ss; +}; + +// class KaldiLogMessage is invoked from the KALDI_LOG macro. +class KaldiLogMessage { + public: + inline std::ostream &stream() { return ss; } + KaldiLogMessage(const char *func, const char *file, int32 line); + ~KaldiLogMessage() { fprintf(stderr, "%s\n", ss.str().c_str()); } + private: + std::ostringstream ss; +}; + +// Class KaldiVlogMessage is invoked from the KALDI_VLOG macro. +class KaldiVlogMessage { + public: + KaldiVlogMessage(const char *func, const char *file, int32 line, + int32 verbose_level); + inline std::ostream &stream() { return ss; } + ~KaldiVlogMessage() { fprintf(stderr, "%s\n", ss.str().c_str()); } + private: + std::ostringstream ss; +}; + + +// class KaldiErrorMessage is invoked from the KALDI_ERROR macro. +// The destructor throws an exception. +class KaldiErrorMessage { + public: + KaldiErrorMessage(const char *func, const char *file, int32 line); + inline std::ostream &stream() { return ss; } + ~KaldiErrorMessage() NOEXCEPT(false); // defined in kaldi-error.cc + private: + std::ostringstream ss; +}; + + + +#ifdef _MSC_VER +#define __func__ __FUNCTION__ +#endif + +#ifndef NDEBUG +#define KALDI_ASSERT(cond) \ + if (!(cond)) kaldi::KaldiAssertFailure_(__func__, __FILE__, __LINE__, #cond); +#else +#define KALDI_ASSERT(cond) +#endif +// also see KALDI_COMPILE_TIME_ASSERT, defined in base/kaldi-utils.h, +// and KALDI_ASSERT_IS_INTEGER_TYPE and KALDI_ASSERT_IS_FLOATING_TYPE, +// also defined there. +#ifdef KALDI_PARANOID // some more expensive asserts only checked if this defined +#define KALDI_PARANOID_ASSERT(cond) \ + if (!(cond)) kaldi::KaldiAssertFailure_(__func__, __FILE__, __LINE__, #cond); +#else +#define KALDI_PARANOID_ASSERT(cond) +#endif + +#define KALDI_ERR kaldi::KaldiErrorMessage(__func__, __FILE__, __LINE__).stream() +#define KALDI_WARN kaldi::KaldiWarnMessage(__func__, __FILE__, __LINE__).stream() +#define KALDI_LOG kaldi::KaldiLogMessage(__func__, __FILE__, __LINE__).stream() + +#define KALDI_VLOG(v) if (v <= kaldi::g_kaldi_verbose_level) \ + kaldi::KaldiVlogMessage(__func__, __FILE__, __LINE__, v).stream() + +inline bool IsKaldiError(const std::string &str) { + return(!strncmp(str.c_str(), "ERROR ", 6)); +} + +void KaldiAssertFailure_(const char *func, const char *file, + int32 line, const char *cond_str); + +/// @} end "addtogroup error_group" + +} // namespace kaldi + +#endif // KALDI_BASE_KALDI_ERROR_H_ diff --git a/kaldi_io/src/kaldi/base/kaldi-math.h b/kaldi_io/src/kaldi/base/kaldi-math.h new file mode 100644 index 0000000..4f60d00 --- /dev/null +++ b/kaldi_io/src/kaldi/base/kaldi-math.h @@ -0,0 +1,346 @@ +// base/kaldi-math.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian; +// Jan Silovsky; Saarland University +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_MATH_H_ +#define KALDI_BASE_KALDI_MATH_H_ 1 + +#ifdef _MSC_VER +#include <float.h> +#endif + +#include <cmath> +#include <limits> +#include <vector> + +#include "base/kaldi-types.h" +#include "base/kaldi-common.h" + + +#ifndef DBL_EPSILON +#define DBL_EPSILON 2.2204460492503131e-16 +#endif +#ifndef FLT_EPSILON +#define FLT_EPSILON 1.19209290e-7f +#endif + +#ifndef M_PI +# define M_PI 3.1415926535897932384626433832795 +#endif + +#ifndef M_SQRT2 +# define M_SQRT2 1.4142135623730950488016887 +#endif + + +#ifndef M_2PI +# define M_2PI 6.283185307179586476925286766559005 +#endif + +#ifndef M_SQRT1_2 +# define M_SQRT1_2 0.7071067811865475244008443621048490 +#endif + +#ifndef M_LOG_2PI +#define M_LOG_2PI 1.8378770664093454835606594728112 +#endif + +#ifndef M_LN2 +#define M_LN2 0.693147180559945309417232121458 +#endif + +#ifdef _MSC_VER +# define KALDI_ISNAN _isnan +# define KALDI_ISINF(x) (!_isnan(x) && _isnan(x-x)) +# define KALDI_ISFINITE _finite +#else +# define KALDI_ISNAN std::isnan +# define KALDI_ISINF std::isinf +# define KALDI_ISFINITE(x) std::isfinite(x) +#endif +#if !defined(KALDI_SQR) +# define KALDI_SQR(x) ((x) * (x)) +#endif + +namespace kaldi { + +// -infinity +const float kLogZeroFloat = -std::numeric_limits<float>::infinity(); +const double kLogZeroDouble = -std::numeric_limits<double>::infinity(); +const BaseFloat kLogZeroBaseFloat = -std::numeric_limits<BaseFloat>::infinity(); + +// Returns a random integer between 0 and RAND_MAX, inclusive +int Rand(struct RandomState* state=NULL); + +// State for thread-safe random number generator +struct RandomState { + RandomState(); + unsigned seed; +}; + +// Returns a random integer between min and max inclusive. +int32 RandInt(int32 min, int32 max, struct RandomState* state=NULL); + +bool WithProb(BaseFloat prob, struct RandomState* state=NULL); // Returns true with probability "prob", +// with 0 <= prob <= 1 [we check this]. +// Internally calls Rand(). This function is carefully implemented so +// that it should work even if prob is very small. + +/// Returns a random number strictly between 0 and 1. +inline float RandUniform(struct RandomState* state = NULL) { + return static_cast<float>((Rand(state) + 1.0) / (RAND_MAX+2.0)); +} + +inline float RandGauss(struct RandomState* state = NULL) { + return static_cast<float>(sqrtf (-2 * logf(RandUniform(state))) + * cosf(2*M_PI*RandUniform(state))); +} + +// Returns poisson-distributed random number. Uses Knuth's algorithm. +// Take care: this takes time proportinal +// to lambda. Faster algorithms exist but are more complex. +int32 RandPoisson(float lambda, struct RandomState* state=NULL); + +// Returns a pair of gaussian random numbers. Uses Box-Muller transform +void RandGauss2(float *a, float *b, RandomState *state = NULL); +void RandGauss2(double *a, double *b, RandomState *state = NULL); + +// Also see Vector<float,double>::RandCategorical(). + +// This is a randomized pruning mechanism that preserves expectations, +// that we typically use to prune posteriors. +template<class Float> +inline Float RandPrune(Float post, BaseFloat prune_thresh, struct RandomState* state=NULL) { + KALDI_ASSERT(prune_thresh >= 0.0); + if (post == 0.0 || std::abs(post) >= prune_thresh) + return post; + return (post >= 0 ? 1.0 : -1.0) * + (RandUniform(state) <= fabs(post)/prune_thresh ? prune_thresh : 0.0); +} + +static const double kMinLogDiffDouble = std::log(DBL_EPSILON); // negative! +static const float kMinLogDiffFloat = std::log(FLT_EPSILON); // negative! + +inline double LogAdd(double x, double y) { + double diff; + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffDouble) { + double res; +#ifdef _MSC_VER + res = x + log(1.0 + exp(diff)); +#else + res = x + log1p(exp(diff)); +#endif + return res; + } else { + return x; // return the larger one. + } +} + + +inline float LogAdd(float x, float y) { + float diff; + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffFloat) { + float res; +#ifdef _MSC_VER + res = x + logf(1.0 + expf(diff)); +#else + res = x + log1pf(expf(diff)); +#endif + return res; + } else { + return x; // return the larger one. + } +} + + +// returns exp(x) - exp(y). +inline double LogSub(double x, double y) { + if (y >= x) { // Throws exception if y>=x. + if (y == x) + return kLogZeroDouble; + else + KALDI_ERR << "Cannot subtract a larger from a smaller number."; + } + + double diff = y - x; // Will be negative. + double res = x + log(1.0 - exp(diff)); + + // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision + if (KALDI_ISNAN(res)) + return kLogZeroDouble; + return res; +} + + +// returns exp(x) - exp(y). +inline float LogSub(float x, float y) { + if (y >= x) { // Throws exception if y>=x. + if (y == x) + return kLogZeroDouble; + else + KALDI_ERR << "Cannot subtract a larger from a smaller number."; + } + + float diff = y - x; // Will be negative. + float res = x + logf(1.0 - expf(diff)); + + // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision + if (KALDI_ISNAN(res)) + return kLogZeroFloat; + return res; +} + +/// return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)). +static inline bool ApproxEqual(float a, float b, + float relative_tolerance = 0.001) { + // a==b handles infinities. + if (a==b) return true; + float diff = std::abs(a-b); + if (diff == std::numeric_limits<float>::infinity() + || diff != diff) return false; // diff is +inf or nan. + return (diff <= relative_tolerance*(std::abs(a)+std::abs(b))); +} + +/// assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b)) +static inline void AssertEqual(float a, float b, + float relative_tolerance = 0.001) { + // a==b handles infinities. + KALDI_ASSERT(ApproxEqual(a, b, relative_tolerance)); +} + + +// RoundUpToNearestPowerOfTwo does the obvious thing. It crashes if n <= 0. +int32 RoundUpToNearestPowerOfTwo(int32 n); + +template<class I> I Gcd(I m, I n) { + if (m == 0 || n == 0) { + if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. + KALDI_ERR << "Undefined GCD since m = 0, n = 0."; + } + return (m == 0 ? (n > 0 ? n : -n) : ( m > 0 ? m : -m)); + // return absolute value of whichever is nonzero + } + // could use compile-time assertion + // but involves messing with complex template stuff. + KALDI_ASSERT(std::numeric_limits<I>::is_integer); + while (1) { + m %= n; + if (m == 0) return (n > 0 ? n : -n); + n %= m; + if (n == 0) return (m > 0 ? m : -m); + } +} + +/// Returns the least common multiple of two integers. Will +/// crash unless the inputs are positive. +template<class I> I Lcm(I m, I n) { + KALDI_ASSERT(m > 0 && n > 0); + I gcd = Gcd(m, n); + return gcd * (m/gcd) * (n/gcd); +} + + +template<class I> void Factorize(I m, std::vector<I> *factors) { + // Splits a number into its prime factors, in sorted order from + // least to greatest, with duplication. A very inefficient + // algorithm, which is mainly intended for use in the + // mixed-radix FFT computation (where we assume most factors + // are small). + KALDI_ASSERT(factors != NULL); + KALDI_ASSERT(m >= 1); // Doesn't work for zero or negative numbers. + factors->clear(); + I small_factors[10] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29 }; + + // First try small factors. + for (I i = 0; i < 10; i++) { + if (m == 1) return; // We're done. + while (m % small_factors[i] == 0) { + m /= small_factors[i]; + factors->push_back(small_factors[i]); + } + } + // Next try all odd numbers starting from 31. + for (I j = 31;; j += 2) { + if (m == 1) return; + while (m % j == 0) { + m /= j; + factors->push_back(j); + } + } +} + +inline double Hypot(double x, double y) { return hypot(x, y); } + +inline float Hypot(float x, float y) { return hypotf(x, y); } + +#if !defined(_MSC_VER) || (_MSC_VER >= 1800) +inline double Log1p(double x) { return log1p(x); } + +inline float Log1p(float x) { return log1pf(x); } +#else +inline double Log1p(double x) { + const double cutoff = 1.0e-08; + if (x < cutoff) + return x - 2 * x * x; + else + return log(1.0 + x); +} + +inline float Log1p(float x) { + const float cutoff = 1.0e-07; + if (x < cutoff) + return x - 2 * x * x; + else + return log(1.0 + x); +} +#endif + +inline double Exp(double x) { return exp(x); } + +#ifndef KALDI_NO_EXPF +inline float Exp(float x) { return expf(x); } +#else +inline float Exp(float x) { return exp(x); } +#endif + +inline double Log(double x) { return log(x); } + +inline float Log(float x) { return logf(x); } + + +} // namespace kaldi + + +#endif // KALDI_BASE_KALDI_MATH_H_ diff --git a/kaldi_io/src/kaldi/base/kaldi-types.h b/kaldi_io/src/kaldi/base/kaldi-types.h new file mode 100644 index 0000000..04354b2 --- /dev/null +++ b/kaldi_io/src/kaldi/base/kaldi-types.h @@ -0,0 +1,64 @@ +// base/kaldi-types.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_TYPES_H_ +#define KALDI_BASE_KALDI_TYPES_H_ 1 + +namespace kaldi { +// TYPEDEFS .................................................................. +#if (KALDI_DOUBLEPRECISION != 0) +typedef double BaseFloat; +#else +typedef float BaseFloat; +#endif +} + +#ifdef _MSC_VER +namespace kaldi { +typedef unsigned __int16 uint16; +typedef unsigned __int32 uint32; +typedef __int16 int16; +typedef __int32 int32; +typedef __int64 int64; +typedef unsigned __int64 uint64; +typedef float float32; +typedef double double64; +} +#include <basetsd.h> +#define ssize_t SSIZE_T + +#else +// we can do this a different way if some platform +// we find in the future lacks stdint.h +#include <stdint.h> + +namespace kaldi { +typedef uint16_t uint16; +typedef uint32_t uint32; +typedef uint64_t uint64; +typedef int16_t int16; +typedef int32_t int32; +typedef int64_t int64; +typedef float float32; +typedef double double64; +} // end namespace kaldi +#endif + +#endif // KALDI_BASE_KALDI_TYPES_H_ diff --git a/kaldi_io/src/kaldi/base/kaldi-utils.h b/kaldi_io/src/kaldi/base/kaldi-utils.h new file mode 100644 index 0000000..1b2c893 --- /dev/null +++ b/kaldi_io/src/kaldi/base/kaldi-utils.h @@ -0,0 +1,157 @@ +// base/kaldi-utils.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; +// Saarland University; Karel Vesely; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_UTILS_H_ +#define KALDI_BASE_KALDI_UTILS_H_ 1 + +#include <limits> +#include <string> + +#if defined(_MSC_VER) +# define WIN32_LEAN_AND_MEAN +# define NOMINMAX +# include <windows.h> +#endif + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4056 4305 4800 4267 4996 4756 4661) +#define __restrict__ +#endif + +#ifdef HAVE_POSIX_MEMALIGN +# define KALDI_MEMALIGN(align, size, pp_orig) \ + (!posix_memalign(pp_orig, align, size) ? *(pp_orig) : NULL) +# define KALDI_MEMALIGN_FREE(x) free(x) +#elif defined(HAVE_MEMALIGN) + /* Some systems have memalign() but no declaration for it */ + void * memalign(size_t align, size_t size); +# define KALDI_MEMALIGN(align, size, pp_orig) \ + (*(pp_orig) = memalign(align, size)) +# define KALDI_MEMALIGN_FREE(x) free(x) +#elif defined(_MSC_VER) +# define KALDI_MEMALIGN(align, size, pp_orig) \ + (*(pp_orig) = _aligned_malloc(size, align)) +# define KALDI_MEMALIGN_FREE(x) _aligned_free(x) +#else +#error Manual memory alignment is no longer supported +#endif + +#ifdef __ICC +#pragma warning(disable: 383) // ICPC remark we don't want. +#pragma warning(disable: 810) // ICPC remark we don't want. +#pragma warning(disable: 981) // ICPC remark we don't want. +#pragma warning(disable: 1418) // ICPC remark we don't want. +#pragma warning(disable: 444) // ICPC remark we don't want. +#pragma warning(disable: 869) // ICPC remark we don't want. +#pragma warning(disable: 1287) // ICPC remark we don't want. +#pragma warning(disable: 279) // ICPC remark we don't want. +#pragma warning(disable: 981) // ICPC remark we don't want. +#endif + + +namespace kaldi { + + +// CharToString prints the character in a human-readable form, for debugging. +std::string CharToString(const char &c); + + +inline int MachineIsLittleEndian() { + int check = 1; + return (*reinterpret_cast<char*>(&check) != 0); +} + +// This function kaldi::Sleep() provides a portable way to sleep for a possibly fractional +// number of seconds. On Windows it's only accurate to microseconds. +void Sleep(float seconds); + +} + +#define KALDI_SWAP8(a) { \ + int t = ((char*)&a)[0]; ((char*)&a)[0]=((char*)&a)[7]; ((char*)&a)[7]=t;\ + t = ((char*)&a)[1]; ((char*)&a)[1]=((char*)&a)[6]; ((char*)&a)[6]=t;\ + t = ((char*)&a)[2]; ((char*)&a)[2]=((char*)&a)[5]; ((char*)&a)[5]=t;\ + t = ((char*)&a)[3]; ((char*)&a)[3]=((char*)&a)[4]; ((char*)&a)[4]=t;} +#define KALDI_SWAP4(a) { \ + int t = ((char*)&a)[0]; ((char*)&a)[0]=((char*)&a)[3]; ((char*)&a)[3]=t;\ + t = ((char*)&a)[1]; ((char*)&a)[1]=((char*)&a)[2]; ((char*)&a)[2]=t;} +#define KALDI_SWAP2(a) { \ + int t = ((char*)&a)[0]; ((char*)&a)[0]=((char*)&a)[1]; ((char*)&a)[1]=t;} + + +// Makes copy constructor and operator= private. Same as in compat.h of OpenFst +// toolkit. If using VS, for which this results in compilation errors, we +// do it differently. + +#if defined(_MSC_VER) +#define KALDI_DISALLOW_COPY_AND_ASSIGN(type) \ + void operator = (const type&) +#else +#define KALDI_DISALLOW_COPY_AND_ASSIGN(type) \ + type(const type&); \ + void operator = (const type&) +#endif + +template<bool B> class KaldiCompileTimeAssert { }; +template<> class KaldiCompileTimeAssert<true> { + public: + static inline void Check() { } +}; + +#define KALDI_COMPILE_TIME_ASSERT(b) KaldiCompileTimeAssert<(b)>::Check() + +#define KALDI_ASSERT_IS_INTEGER_TYPE(I) \ + KaldiCompileTimeAssert<std::numeric_limits<I>::is_specialized \ + && std::numeric_limits<I>::is_integer>::Check() + +#define KALDI_ASSERT_IS_FLOATING_TYPE(F) \ + KaldiCompileTimeAssert<std::numeric_limits<F>::is_specialized \ + && !std::numeric_limits<F>::is_integer>::Check() + +#ifdef _MSC_VER +#include <stdio.h> +#define unlink _unlink +#else +#include <unistd.h> +#endif + + +#ifdef _MSC_VER +#define KALDI_STRCASECMP _stricmp +#else +#define KALDI_STRCASECMP strcasecmp +#endif +#ifdef _MSC_VER +# define KALDI_STRTOLL(cur_cstr, end_cstr) _strtoi64(cur_cstr, end_cstr, 10); +#else +# define KALDI_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10); +#endif + +#define KALDI_STRTOD(cur_cstr, end_cstr) strtod(cur_cstr, end_cstr) + +#ifdef _MSC_VER +# define KALDI_STRTOF(cur_cstr, end_cstr) \ + static_cast<float>(strtod(cur_cstr, end_cstr)); +#else +# define KALDI_STRTOF(cur_cstr, end_cstr) strtof(cur_cstr, end_cstr); +#endif + +#endif // KALDI_BASE_KALDI_UTILS_H_ + diff --git a/kaldi_io/src/kaldi/base/timer.h b/kaldi_io/src/kaldi/base/timer.h new file mode 100644 index 0000000..d93a461 --- /dev/null +++ b/kaldi_io/src/kaldi/base/timer.h @@ -0,0 +1,83 @@ +// base/timer.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_BASE_TIMER_H_ +#define KALDI_BASE_TIMER_H_ + +#include "base/kaldi-utils.h" +// Note: Sleep(float secs) is included in base/kaldi-utils.h. + + +#if defined(_MSC_VER) || defined(MINGW) + +namespace kaldi +{ + +class Timer { + public: + Timer() { Reset(); } + void Reset() { + QueryPerformanceCounter(&time_start_); + } + double Elapsed() { + LARGE_INTEGER time_end; + LARGE_INTEGER freq; + QueryPerformanceCounter(&time_end); + if (QueryPerformanceFrequency(&freq) == 0) return 0.0; // Hardware does not support this. + return ((double)time_end.QuadPart - (double)time_start_.QuadPart) / + ((double)freq.QuadPart); + } + private: + LARGE_INTEGER time_start_; +}; +} + +#else + +# include <sys/time.h> +# include <unistd.h> +namespace kaldi +{ +class Timer +{ + public: + Timer() { Reset(); } + + void Reset() { gettimeofday(&this->time_start_, &time_zone_); } + + /// Returns time in seconds. + double Elapsed() { + struct timeval time_end; + gettimeofday(&time_end, &time_zone_); + double t1, t2; + t1 = (double)time_start_.tv_sec + + (double)time_start_.tv_usec/(1000*1000); + t2 = (double)time_end.tv_sec + (double)time_end.tv_usec/(1000*1000); + return t2-t1; + } + + private: + struct timeval time_start_; + struct timezone time_zone_; +}; +} + +#endif + + +#endif diff --git a/kaldi_io/src/kaldi/hmm/hmm-topology.h b/kaldi_io/src/kaldi/hmm/hmm-topology.h new file mode 100644 index 0000000..53ca427 --- /dev/null +++ b/kaldi_io/src/kaldi/hmm/hmm-topology.h @@ -0,0 +1,172 @@ +// hmm/hmm-topology.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_HMM_TOPOLOGY_H_ +#define KALDI_HMM_HMM_TOPOLOGY_H_ + +#include "base/kaldi-common.h" +#include "tree/context-dep.h" +#include "util/const-integer-set.h" + + +namespace kaldi { + + +/// \addtogroup hmm_group +/// @{ + +/* + // The following would be the text form for the "normal" HMM topology. + // Note that the first state is the start state, and the final state, + // which must have no output transitions and must be nonemitting, has + // an exit probability of one (no other state can have nonzero exit + // probability; you can treat the transition probability to the final + // state as an exit probability). + // Note also that it's valid to omit the "<PdfClass>" entry of the <State>, which + // will mean we won't have a pdf on that state [non-emitting state]. This is equivalent + // to setting the <PdfClass> to -1. We do this normally just for the final state. + // The Topology object can have multiple <TopologyEntry> blocks. + // This is useful if there are multiple types of topology in the system. + + <Topology> + <TopologyEntry> + <ForPhones> 1 2 3 4 5 6 7 8 </ForPhones> + <State> 0 <PdfClass> 0 + <Transition> 0 0.5 + <Transition> 1 0.5 + </State> + <State> 1 <PdfClass> 1 + <Transition> 1 0.5 + <Transition> 2 0.5 + </State> + <State> 2 <PdfClass> 2 + <Transition> 2 0.5 + <Transition> 3 0.5 + <Final> 0.5 + </State> + <State> 3 + </State> + </TopologyEntry> + </Topology> +*/ + +// kNoPdf is used where pdf_class or pdf would be used, to indicate, +// none is there. Mainly useful in skippable models, but also used +// for end states. +// A caveat with nonemitting states is that their out-transitions +// are not trainable, due to technical issues with the way +// we decided to accumulate the stats. Any transitions arising from (*) +// HMM states with "kNoPdf" as the label are second-class transitions, +// They do not have "transition-states" or "transition-ids" associated +// with them. They are used to create the FST version of the +// HMMs, where they lead to epsilon arcs. +// (*) "arising from" is a bit of a technical term here, due to the way +// (if reorder == true), we put the transition-id associated with the +// outward arcs of the state, on the input transition to the state. + +/// A constant used in the HmmTopology class as the \ref pdf_class "pdf-class" +/// kNoPdf, which is used when a HMM-state is nonemitting (has no associated +/// PDF). + +static const int32 kNoPdf = -1; + +/// A class for storing topology information for phones. See \ref hmm for context. +/// This object is sometimes accessed in a file by itself, but more often +/// as a class member of the Transition class (this is for convenience to reduce +/// the number of files programs have to access). + +class HmmTopology { + public: + /// A structure defined inside HmmTopology to represent a HMM state. + struct HmmState { + /// The \ref pdf_class pdf-class, typically 0, 1 or 2 (the same as the HMM-state index), + /// but may be different to enable us to hardwire sharing of state, and may be + /// equal to \ref kNoPdf == -1 in order to specify nonemitting states (unusual). + int32 pdf_class; + + /// A list of transitions. The first member of each pair is the index of + /// the next HmmState, and the second is the default transition probability + /// (before training). + std::vector<std::pair<int32, BaseFloat> > transitions; + + explicit HmmState(int32 p): pdf_class(p) { } + + bool operator == (const HmmState &other) const { + return (pdf_class == other.pdf_class && transitions == other.transitions); + } + + HmmState(): pdf_class(-1) { } + }; + + /// TopologyEntry is a typedef that represents the topology of + /// a single (prototype) state. + typedef std::vector<HmmState> TopologyEntry; + + void Read(std::istream &is, bool binary); + void Write(std::ostream &os, bool binary) const; + + // Checks that the object is valid, and throw exception otherwise. + void Check(); + + + /// Returns the topology entry (i.e. vector of HmmState) for this phone; + /// will throw exception if phone not covered by the topology. + const TopologyEntry &TopologyForPhone(int32 phone) const; + + /// Returns the number of \ref pdf_class "pdf-classes" for this phone; + /// throws exception if phone not covered by this topology. + int32 NumPdfClasses(int32 phone) const; + + /// Returns a reference to a sorted, unique list of phones covered by + /// the topology (these phones will be positive integers, and usually + /// contiguous and starting from one but the toolkit doesn't assume + /// they are contiguous). + const std::vector<int32> &GetPhones() const { return phones_; }; + + /// Outputs a vector of int32, indexed by phone, that gives the + /// number of \ref pdf_class pdf-classes for the phones; this is + /// used by tree-building code such as BuildTree(). + void GetPhoneToNumPdfClasses(std::vector<int32> *phone2num_pdf_classes) const; + + HmmTopology() {} + + bool operator == (const HmmTopology &other) const { + return phones_ == other.phones_ && phone2idx_ == other.phone2idx_ + && entries_ == other.entries_; + } + // Allow default assignment operator and copy constructor. + private: + std::vector<int32> phones_; // list of all phones we have topology for. Sorted, uniq. no epsilon (zero) phone. + std::vector<int32> phone2idx_; // map from phones to indexes into the entries vector (or -1 for not present). + std::vector<TopologyEntry> entries_; +}; + + +/// This function returns a HmmTopology object giving a normal 3-state topology, +/// covering all phones in the list "phones". This is mainly of use in testing +/// code. +HmmTopology GetDefaultTopology(const std::vector<int32> &phones); + +/// @} end "addtogroup hmm_group" + + +} // end namespace kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/hmm/hmm-utils.h b/kaldi_io/src/kaldi/hmm/hmm-utils.h new file mode 100644 index 0000000..240f706 --- /dev/null +++ b/kaldi_io/src/kaldi/hmm/hmm-utils.h @@ -0,0 +1,295 @@ +// hmm/hmm-utils.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_HMM_UTILS_H_ +#define KALDI_HMM_HMM_UTILS_H_ + +#include "hmm/hmm-topology.h" +#include "hmm/transition-model.h" +#include "lat/kaldi-lattice.h" + +namespace kaldi { + + +/// \defgroup hmm_group_graph Classes and functions for creating FSTs from HMMs +/// \ingroup hmm_group +/// @{ + +/// Configuration class for the GetHTransducer() function; see +/// \ref hmm_graph_config for context. +struct HTransducerConfig { + /// Transition log-prob scale, see \ref hmm_scale. + /// Note this doesn't apply to self-loops; GetHTransducer() does + /// not include self-loops. + BaseFloat transition_scale; + + /// if true, we are constructing time-reversed FST: phone-seqs in ilabel_info + /// are backwards, and we want to output a backwards version of the HMM + /// corresponding to each phone. If reverse == true, + bool reverse; + + /// This variable is only looked at if reverse == true. If reverse == true + /// and push_weights == true, then we push the weights in the reversed FSTs we create for each + /// phone HMM. This is only safe if the HMMs are probabilistic (i.e. not discriminatively + bool push_weights; + + /// delta used if we do push_weights [only relevant if reverse == true + /// and push_weights == true]. + BaseFloat push_delta; + + HTransducerConfig(): + transition_scale(1.0), + reverse(false), + push_weights(true), + push_delta(0.001) + { } + + // Note-- this Register registers the easy-to-register options + // but not the "sym_type" which is an enum and should be handled + // separately in main(). + void Register (OptionsItf *po) { + po->Register("transition-scale", &transition_scale, + "Scale of transition probs (relative to LM)"); + po->Register("reverse", &reverse, + "Set true to build time-reversed FST."); + po->Register("push-weights", &push_weights, + "Push weights (only applicable if reverse == true)"); + po->Register("push-delta", &push_delta, + "Delta used in pushing weights (only applicable if " + "reverse && push-weights"); + } +}; + + +struct HmmCacheHash { + int operator () (const std::pair<int32, std::vector<int32> >&p) const { + VectorHasher<int32> v; + int32 prime = 103049; + return prime*p.first + v(p.second); + } +}; + +/// HmmCacheType is a map from (central-phone, sequence of pdf-ids) to FST, used +/// as cache in GetHmmAsFst, as an optimization. +typedef unordered_map<std::pair<int32, std::vector<int32> >, + fst::VectorFst<fst::StdArc>*, + HmmCacheHash> HmmCacheType; + + +/// Called by GetHTransducer() and probably will not need to be called directly; +/// it creates the FST corresponding to the phone. Does not include self-loops; +/// you have to call AddSelfLoops() for that. Result owned by caller. +/// Returns an acceptor (i.e. ilabels, olabels identical) with transition-ids +/// as the symbols. +/// For documentation in context, see \ref hmm_graph_get_hmm_as_fst +/// @param context_window A vector representing the phonetic context; see +/// \ref tree_window "here" for explanation. +/// @param ctx_dep The object that contains the phonetic decision-tree +/// @param trans_model The transition-model object, which provides +/// the mappings to transition-ids and also the transition +/// probabilities. +/// @param config Configuration object, see \ref HTransducerConfig. +/// @param cache Object used as a lookaside buffer to save computation; +/// if it finds that the object it needs is already there, it will +/// just return a pointer value from "cache"-- not that this means +/// you have to be careful not to delete things twice. + +fst::VectorFst<fst::StdArc> *GetHmmAsFst( + std::vector<int32> context_window, + const ContextDependencyInterface &ctx_dep, + const TransitionModel &trans_model, + const HTransducerConfig &config, + HmmCacheType *cache = NULL); + +/// Included mainly as a form of documentation, not used in any other code +/// currently. Creates the FST with self-loops, and with fewer options. +fst::VectorFst<fst::StdArc>* +GetHmmAsFstSimple(std::vector<int32> context_window, + const ContextDependencyInterface &ctx_dep, + const TransitionModel &trans_model, + BaseFloat prob_scale); + + +/** + * Returns the H tranducer; result owned by caller. + * See \ref hmm_graph_get_h_transducer. The H transducer has on the + * input transition-ids, and also possibly some disambiguation symbols, which + * will be put in disambig_syms. The output side contains the identifiers that + * are indexes into "ilabel_info" (these represent phones-in-context or + * disambiguation symbols). The ilabel_info vector allows GetHTransducer to map + * from symbols to phones-in-context (i.e. phonetic context windows). Any + * singleton symbols in the ilabel_info vector which are not phones, will be + * treated as disambiguation symbols. [Not all recipes use these]. The output + * "disambig_syms_left" will be set to a list of the disambiguation symbols on + * the input of the transducer (i.e. same symbol type as whatever is on the + * input of the transducer + */ +fst::VectorFst<fst::StdArc>* +GetHTransducer (const std::vector<std::vector<int32> > &ilabel_info, + const ContextDependencyInterface &ctx_dep, + const TransitionModel &trans_model, + const HTransducerConfig &config, + std::vector<int32> *disambig_syms_left); + +/** + * GetIlabelMapping produces a mapping that's similar to HTK's logical-to-physical + * model mapping (i.e. the xwrd.clustered.mlist files). It groups together + * "logical HMMs" (i.e. in our world, phonetic context windows) that share the + * same sequence of transition-ids. This can be used in an + * optional graph-creation step that produces a remapped form of CLG that can be + * more productively determinized and minimized. This is used in the command-line program + * make-ilabel-transducer.cc. + * @param ilabel_info_old [in] The original \ref tree_ilabel "ilabel_info" vector + * @param ctx_dep [in] The tree + * @param trans_model [in] The transition-model object + * @param old2new_map [out] The output; this vector, which is of size equal to the + * number of new labels, is a mapping to the old labels such that we could + * create a vector ilabel_info_new such that + * ilabel_info_new[i] == ilabel_info_old[old2new_map[i]] + */ +void GetIlabelMapping (const std::vector<std::vector<int32> > &ilabel_info_old, + const ContextDependencyInterface &ctx_dep, + const TransitionModel &trans_model, + std::vector<int32> *old2new_map); + + + +/** + * For context, see \ref hmm_graph_add_self_loops. Expands an FST that has been + * built without self-loops, and adds the self-loops (it also needs to modify + * the probability of the non-self-loop ones, as the graph without self-loops + * was created in such a way that it was stochastic). Note that the + * disambig_syms will be empty in some recipes (e.g. if you already removed + * the disambiguation symbols). + * @param trans_model [in] Transition model + * @param disambig_syms [in] Sorted, uniq list of disambiguation symbols, required + * if the graph contains disambiguation symbols but only needed for sanity checks. + * @param self_loop_scale [in] Transition-probability scale for self-loops; c.f. + * \ref hmm_scale + * @param reorder [in] If true, reorders the transitions (see \ref hmm_reorder). + * @param fst [in, out] The FST to be modified. + */ +void AddSelfLoops(const TransitionModel &trans_model, + const std::vector<int32> &disambig_syms, // used as a check only. + BaseFloat self_loop_scale, + bool reorder, // true->dan-style, false->lukas-style. + fst::VectorFst<fst::StdArc> *fst); + +/** + * Adds transition-probs, with the supplied + * scales (see \ref hmm_scale), to the graph. + * Useful if you want to create a graph without transition probs, then possibly + * train the model (including the transition probs) but keep the graph fixed, + * and add back in the transition probs. It assumes the fst has transition-ids + * on it. It is not an error if the FST has no states (nothing will be done). + * @param trans_model [in] The transition model + * @param disambig_syms [in] A list of disambiguation symbols, required if the + * graph has disambiguation symbols on its input but only + * used for checks. + * @param transition_scale [in] A scale on transition-probabilities apart from + * those involving self-loops; see \ref hmm_scale. + * @param self_loop_scale [in] A scale on self-loop transition probabilities; + * see \ref hmm_scale. + * @param fst [in, out] The FST to be modified. + */ +void AddTransitionProbs(const TransitionModel &trans_model, + const std::vector<int32> &disambig_syms, + BaseFloat transition_scale, + BaseFloat self_loop_scale, + fst::VectorFst<fst::StdArc> *fst); + +/** + This is as AddSelfLoops(), but operates on a Lattice, where + it affects the graph part of the weight (the first element + of the pair). */ +void AddTransitionProbs(const TransitionModel &trans_model, + BaseFloat transition_scale, + BaseFloat self_loop_scale, + Lattice *lat); + + +/// Returns a transducer from pdfs plus one (input) to transition-ids (output). +/// Currenly of use only for testing. +fst::VectorFst<fst::StdArc>* +GetPdfToTransitionIdTransducer(const TransitionModel &trans_model); + +/// Converts all transition-ids in the FST to pdfs plus one. +/// Placeholder: not implemented yet! +void ConvertTransitionIdsToPdfs(const TransitionModel &trans_model, + const std::vector<int32> &disambig_syms, + fst::VectorFst<fst::StdArc> *fst); + +/// @} end "defgroup hmm_group_graph" + +/// \addtogroup hmm_group +/// @{ + +/// SplitToPhones splits up the TransitionIds in "alignment" into their +/// individual phones (one vector per instance of a phone). At output, +/// the sum of the sizes of the vectors in split_alignment will be the same +/// as the corresponding sum for "alignment". The function returns +/// true on success. If the alignment appears to be incomplete, e.g. +/// not ending at the end-state of a phone, it will still break it up into +/// phones but it will return false. For more serious errors it will +/// die or throw an exception. +/// This function works out by itself whether the graph was created +/// with "reordering" (dan-style graph), and just does the right thing. + +bool SplitToPhones(const TransitionModel &trans_model, + const std::vector<int32> &alignment, + std::vector<std::vector<int32> > *split_alignment); + +/// ConvertAlignment converts an alignment that was created using one +/// model, to another model. They must use a compatible topology (so we +/// know the state alignments of the new model). +/// It returns false if it could not be split to phones (probably +/// because the alignment was partial), but for other kinds of +/// error that are more likely a coding error, it will throw +/// an exception. +bool ConvertAlignment(const TransitionModel &old_trans_model, + const TransitionModel &new_trans_model, + const ContextDependencyInterface &new_ctx_dep, + const std::vector<int32> &old_alignment, + const std::vector<int32> *phone_map, // may be NULL + std::vector<int32> *new_alignment); + +// ConvertPhnxToProns is only needed in bin/phones-to-prons.cc and +// isn't closely related with HMMs, but we put it here as there isn't +// any other obvious place for it and it needs to be tested. +// This function takes a phone-sequence with word-start and word-end +// markers in it, and a word-sequence, and outputs the pronunciations +// "prons"... the format of "prons" is, each element is a vector, +// where the first element is the word (or zero meaning no word, e.g. +// for optional silence introduced by the lexicon), and the remaining +// elements are the phones in the word's pronunciation. +// It returns false if it encounters a problem of some kind, e.g. +// if the phone-sequence doesn't seem to have the right number of +// words in it. +bool ConvertPhnxToProns(const std::vector<int32> &phnx, + const std::vector<int32> &words, + int32 word_start_sym, + int32 word_end_sym, + std::vector<std::vector<int32> > *prons); + +/// @} end "addtogroup hmm_group" + +} // end namespace kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/hmm/posterior.h b/kaldi_io/src/kaldi/hmm/posterior.h new file mode 100644 index 0000000..be73be9 --- /dev/null +++ b/kaldi_io/src/kaldi/hmm/posterior.h @@ -0,0 +1,214 @@ +// hmm/posterior.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013-2014 Johns Hopkins University (author: Daniel Povey) +// 2014 Guoguo Chen + + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_POSTERIOR_H_ +#define KALDI_HMM_POSTERIOR_H_ + +#include "base/kaldi-common.h" +#include "tree/context-dep.h" +#include "util/const-integer-set.h" +#include "util/kaldi-table.h" +#include "hmm/transition-model.h" + + +namespace kaldi { + + +/// \addtogroup posterior_group +/// @{ + +/// Posterior is a typedef for storing acoustic-state (actually, transition-id) +/// posteriors over an utterance. The "int32" is a transition-id, and the BaseFloat +/// is a probability (typically between zero and one). +typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior; + +/// GaussPost is a typedef for storing Gaussian-level posteriors for an utterance. +/// the "int32" is a transition-id, and the Vector<BaseFloat> is a vector of +/// Gaussian posteriors. +/// WARNING: We changed "int32" from transition-id to pdf-id, and the change is +/// applied for all programs using GaussPost. This is for efficiency purpose. We +/// also changed the name slightly from GauPost to GaussPost to reduce the +/// chance that the change will go un-noticed in downstream code. +typedef std::vector<std::vector<std::pair<int32, Vector<BaseFloat> > > > GaussPost; + + +// PosteriorHolder is a holder for Posterior, which is +// std::vector<std::vector<std::pair<int32, BaseFloat> > > +// This is used for storing posteriors of transition id's for an +// utterance. +class PosteriorHolder { + public: + typedef Posterior T; + + PosteriorHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t); + + void Clear() { Posterior tmp; std::swap(tmp, t_); } + + // Reads into the holder. + bool Read(std::istream &is); + + // Kaldi objects always have the stream open in binary mode for + // reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(PosteriorHolder); + T t_; +}; + + +// GaussPostHolder is a holder for GaussPost, which is +// std::vector<std::vector<std::pair<int32, Vector<BaseFloat> > > > +// This is used for storing posteriors of transition id's for an +// utterance. +class GaussPostHolder { + public: + typedef GaussPost T; + + GaussPostHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t); + + void Clear() { GaussPost tmp; std::swap(tmp, t_); } + + // Reads into the holder. + bool Read(std::istream &is); + + // Kaldi objects always have the stream open in binary mode for + // reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(GaussPostHolder); + T t_; +}; + + +// Posterior is a typedef: vector<vector<pair<int32, BaseFloat> > >, +// representing posteriors over (typically) transition-ids for an +// utterance. +typedef TableWriter<PosteriorHolder> PosteriorWriter; +typedef SequentialTableReader<PosteriorHolder> SequentialPosteriorReader; +typedef RandomAccessTableReader<PosteriorHolder> RandomAccessPosteriorReader; + + +// typedef std::vector<std::vector<std::pair<int32, Vector<BaseFloat> > > > GaussPost; +typedef TableWriter<GaussPostHolder> GaussPostWriter; +typedef SequentialTableReader<GaussPostHolder> SequentialGaussPostReader; +typedef RandomAccessTableReader<GaussPostHolder> RandomAccessGaussPostReader; + + +/// Scales the BaseFloat (weight) element in the posterior entries. +void ScalePosterior(BaseFloat scale, Posterior *post); + +/// Returns the total of all the weights in "post". +BaseFloat TotalPosterior(const Posterior &post); + +/// Returns true if the two lists of pairs have no common .first element. +bool PosteriorEntriesAreDisjoint( + const std::vector<std::pair<int32, BaseFloat> > &post_elem1, + const std::vector<std::pair<int32, BaseFloat> > &post_elem2); + + +/// Merge two sets of posteriors, which must have the same length. If "merge" +/// is true, it will make a common entry whenever there are duplicated entries, +/// adding up the weights. If "drop_frames" is true, for frames where the +/// two sets of posteriors were originally disjoint, makes no entries for that +/// frame (relates to frame dropping, or drop_frames, see Vesely et al, ICASSP +/// 2013). Returns the number of frames for which the two posteriors were +/// disjoint (i.e. no common transition-ids or whatever index we are using). +int32 MergePosteriors(const Posterior &post1, + const Posterior &post2, + bool merge, + bool drop_frames, + Posterior *post); + +/// Given a vector of log-likelihoods (typically of Gaussians in a GMM +/// but could be of pdf-ids), a number gselect >= 1 and a minimum posterior +/// 0 <= min_post < 1, it gets the posterior for each element of log-likes +/// by applying Softmax(), then prunes the posteriors using "gselect" and +/// "min_post" (keeping at least one), and outputs the result into +/// "post_entry", sorted from greatest to least posterior. +/// Returns the total log-likelihood (the output of calling ApplySoftMax() +/// on a copy of log_likes). +BaseFloat VectorToPosteriorEntry( + const VectorBase<BaseFloat> &log_likes, + int32 num_gselect, + BaseFloat min_post, + std::vector<std::pair<int32, BaseFloat> > *post_entry); + +/// Convert an alignment to a posterior (with a scale of 1.0 on +/// each entry). +void AlignmentToPosterior(const std::vector<int32> &ali, + Posterior *post); + +/// Sorts posterior entries so that transition-ids with same pdf-id are next to +/// each other. +void SortPosteriorByPdfs(const TransitionModel &tmodel, + Posterior *post); + + +/// Converts a posterior over transition-ids to be a posterior +/// over pdf-ids. +void ConvertPosteriorToPdfs(const TransitionModel &tmodel, + const Posterior &post_in, + Posterior *post_out); + +/// Converts a posterior over transition-ids to be a posterior +/// over phones. +void ConvertPosteriorToPhones(const TransitionModel &tmodel, + const Posterior &post_in, + Posterior *post_out); + +/// Weight any silence phones in the posterior (i.e. any phones +/// in the set "silence_set" by scale "silence_scale". +/// The interface was changed in Feb 2014 to do the modification +/// "in-place" rather than having separate input and output. +void WeightSilencePost(const TransitionModel &trans_model, + const ConstIntegerSet<int32> &silence_set, + BaseFloat silence_scale, + Posterior *post); + +/// This is similar to WeightSilencePost, except that on each frame it +/// works out the amount by which the overall posterior would be reduced, +/// and scales down everything on that frame by the same amount. It +/// has the effect that frames that are mostly silence get down-weighted. +/// The interface was changed in Feb 2014 to do the modification +/// "in-place" rather than having separate input and output. +void WeightSilencePostDistributed(const TransitionModel &trans_model, + const ConstIntegerSet<int32> &silence_set, + BaseFloat silence_scale, + Posterior *post); + +/// @} end "addtogroup posterior_group" + + +} // end namespace kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/hmm/transition-model.h b/kaldi_io/src/kaldi/hmm/transition-model.h new file mode 100644 index 0000000..ccc4f11 --- /dev/null +++ b/kaldi_io/src/kaldi/hmm/transition-model.h @@ -0,0 +1,345 @@ +// hmm/transition-model.h + +// Copyright 2009-2012 Microsoft Corporation +// Johns Hopkins University (author: Guoguo Chen) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_TRANSITION_MODEL_H_ +#define KALDI_HMM_TRANSITION_MODEL_H_ + +#include "base/kaldi-common.h" +#include "tree/context-dep.h" +#include "util/const-integer-set.h" +#include "fst/fst-decl.h" // forward declarations. +#include "hmm/hmm-topology.h" +#include "itf/options-itf.h" + +namespace kaldi { + +/// \addtogroup hmm_group +/// @{ + +// The class TransitionModel is a repository for the transition probabilities. +// It also handles certain integer mappings. +// The basic model is as follows. Each phone has a HMM topology defined in +// hmm-topology.h. Each HMM-state of each of these phones has a number of +// transitions (and final-probs) out of it. Each HMM-state defined in the +// HmmTopology class has an associated "pdf_class". This gets replaced with +// an actual pdf-id via the tree. The transition model associates the +// transition probs with the (phone, HMM-state, pdf-id). We associate with +// each such triple a transition-state. Each +// transition-state has a number of associated probabilities to estimate; +// this depends on the number of transitions/final-probs in the topology for +// that (phone, HMM-state). Each probability has an associated transition-index. +// We associate with each (transition-state, transition-index) a unique transition-id. +// Each individual probability estimated by the transition-model is asociated with a +// transition-id. +// +// List of the various types of quantity referred to here and what they mean: +// phone: a phone index (1, 2, 3 ...) +// HMM-state: a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h) +// pdf-id: a number output by the Compute function of ContextDependency (it +// indexes pdf's). Zero-based. +// transition-state: the states for which we estimate transition probabilities for transitions +// out of them. In some topologies, will map one-to-one with pdf-ids. +// One-based, since it appears on FSTs. +// transition-index: identifier of a transition (or final-prob) in the HMM. Indexes the +// "transitions" vector in HmmTopology::HmmState. [if it is out of range, +// equal to transitions.size(), it refers to the final-prob.] +// Zero-based. +// transition-id: identifier of a unique parameter of the TransitionModel. +// Associated with a (transition-state, transition-index) pair. +// One-based, since it appears on FSTs. +// +// List of the possible mappings TransitionModel can do: +// (phone, HMM-state, pdf-id) -> transition-state +// (transition-state, transition-index) -> transition-id +// Reverse mappings: +// transition-id -> transition-state +// transition-id -> transition-index +// transition-state -> phone +// transition-state -> HMM-state +// transition-state -> pdf-id +// +// The main things the TransitionModel object can do are: +// Get initialized (need ContextDependency and HmmTopology objects). +// Read/write. +// Update [given a vector of counts indexed by transition-id]. +// Do the various integer mappings mentioned above. +// Get the probability (or log-probability) associated with a particular transition-id. + + +// Note: this was previously called TransitionUpdateConfig. +struct MleTransitionUpdateConfig { + BaseFloat floor; + BaseFloat mincount; + bool share_for_pdfs; // If true, share all transition parameters that have the same pdf. + MleTransitionUpdateConfig(BaseFloat floor = 0.01, + BaseFloat mincount = 5.0, + bool share_for_pdfs = false): + floor(floor), mincount(mincount), share_for_pdfs(share_for_pdfs) {} + + void Register (OptionsItf *po) { + po->Register("transition-floor", &floor, + "Floor for transition probabilities"); + po->Register("transition-min-count", &mincount, + "Minimum count required to update transitions from a state"); + po->Register("share-for-pdfs", &share_for_pdfs, + "If true, share all transition parameters where the states " + "have the same pdf."); + } +}; + +struct MapTransitionUpdateConfig { + BaseFloat tau; + bool share_for_pdfs; // If true, share all transition parameters that have the same pdf. + MapTransitionUpdateConfig(): tau(5.0), share_for_pdfs(false) { } + + void Register (OptionsItf *po) { + po->Register("transition-tau", &tau, "Tau value for MAP estimation of transition " + "probabilities."); + po->Register("share-for-pdfs", &share_for_pdfs, + "If true, share all transition parameters where the states " + "have the same pdf."); + } +}; + +class TransitionModel { + + public: + /// Initialize the object [e.g. at the start of training]. + /// The class keeps a copy of the HmmTopology object, but not + /// the ContextDependency object. + TransitionModel(const ContextDependency &ctx_dep, + const HmmTopology &hmm_topo); + + + /// Constructor that takes no arguments: typically used prior to calling Read. + TransitionModel() { } + + void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. + void Write(std::ostream &os, bool binary) const; + + + /// return reference to HMM-topology object. + const HmmTopology &GetTopo() const { return topo_; } + + /// \name Integer mapping functions + /// @{ + + int32 TripleToTransitionState(int32 phone, int32 hmm_state, int32 pdf) const; + int32 PairToTransitionId(int32 trans_state, int32 trans_index) const; + int32 TransitionIdToTransitionState(int32 trans_id) const; + int32 TransitionIdToTransitionIndex(int32 trans_id) const; + int32 TransitionStateToPhone(int32 trans_state) const; + int32 TransitionStateToHmmState(int32 trans_state) const; + int32 TransitionStateToPdf(int32 trans_state) const; + int32 SelfLoopOf(int32 trans_state) const; // returns the self-loop transition-id, or zero if + // this state doesn't have a self-loop. + + inline int32 TransitionIdToPdf(int32 trans_id) const; + int32 TransitionIdToPhone(int32 trans_id) const; + int32 TransitionIdToPdfClass(int32 trans_id) const; + int32 TransitionIdToHmmState(int32 trans_id) const; + + /// @} + + bool IsFinal(int32 trans_id) const; // returns true if this trans_id goes to the final state + // (which is bound to be nonemitting). + bool IsSelfLoop(int32 trans_id) const; // return true if this trans_id corresponds to a self-loop. + + /// Returns the total number of transition-ids (note, these are one-based). + inline int32 NumTransitionIds() const { return id2state_.size()-1; } + + /// Returns the number of transition-indices for a particular transition-state. + /// Note: "Indices" is the plural of "index". Index is not the same as "id", + /// here. A transition-index is a zero-based offset into the transitions + /// out of a particular transition state. + int32 NumTransitionIndices(int32 trans_state) const; + + /// Returns the total number of transition-states (note, these are one-based). + int32 NumTransitionStates() const { return triples_.size(); } + + // NumPdfs() actually returns the highest-numbered pdf we ever saw, plus one. + // In normal cases this should equal the number of pdfs in the system, but if you + // initialized this object with fewer than all the phones, and it happens that + // an unseen phone has the highest-numbered pdf, this might be different. + int32 NumPdfs() const { return num_pdfs_; } + + // This loops over the triples and finds the highest phone index present. If + // the FST symbol table for the phones is created in the expected way, i.e.: + // starting from 1 (<eps> is 0) and numbered contiguously till the last phone, + // this will be the total number of phones. + int32 NumPhones() const; + + /// Returns a sorted, unique list of phones. + const std::vector<int32> &GetPhones() const { return topo_.GetPhones(); } + + // Transition-parameter-getting functions: + BaseFloat GetTransitionProb(int32 trans_id) const; + BaseFloat GetTransitionLogProb(int32 trans_id) const; + + // The following functions are more specialized functions for getting + // transition probabilities, that are provided for convenience. + + /// Returns the log-probability of a particular non-self-loop transition + /// after subtracting the probability mass of the self-loop and renormalizing; + /// will crash if called on a self-loop. Specifically: + /// for non-self-loops it returns the log of that prob divided by (1 minus + /// self-loop-prob-for-that-state). + BaseFloat GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const; + + /// Returns the log-prob of the non-self-loop probability + /// mass for this transition state. (you can get the self-loop prob, if a self-loop + /// exists, by calling GetTransitionLogProb(SelfLoopOf(trans_state)). + BaseFloat GetNonSelfLoopLogProb(int32 trans_state) const; + + /// Does Maximum Likelihood estimation. The stats are counts/weights, indexed + /// by transition-id. This was previously called Update(). + void MleUpdate(const Vector<double> &stats, + const MleTransitionUpdateConfig &cfg, + BaseFloat *objf_impr_out, + BaseFloat *count_out); + + /// Does Maximum A Posteriori (MAP) estimation. The stats are counts/weights, + /// indexed by transition-id. + void MapUpdate(const Vector<double> &stats, + const MapTransitionUpdateConfig &cfg, + BaseFloat *objf_impr_out, + BaseFloat *count_out); + + /// Print will print the transition model in a human-readable way, for purposes of human + /// inspection. The "occs" are optional (they are indexed by pdf-id). + void Print(std::ostream &os, + const std::vector<std::string> &phone_names, + const Vector<double> *occs = NULL); + + + void InitStats(Vector<double> *stats) const { stats->Resize(NumTransitionIds()+1); } + + void Accumulate(BaseFloat prob, int32 trans_id, Vector<double> *stats) const { + KALDI_ASSERT(trans_id <= NumTransitionIds()); + (*stats)(trans_id) += prob; + // This is trivial and doesn't require class members, but leaves us more open + // to design changes than doing it manually. + } + + /// returns true if all the integer class members are identical (but does not + /// compare the transition probabilities. + bool Compatible(const TransitionModel &other) const; + + private: + void MleUpdateShared(const Vector<double> &stats, + const MleTransitionUpdateConfig &cfg, + BaseFloat *objf_impr_out, BaseFloat *count_out); + void MapUpdateShared(const Vector<double> &stats, + const MapTransitionUpdateConfig &cfg, + BaseFloat *objf_impr_out, BaseFloat *count_out); + void ComputeTriples(const ContextDependency &ctx_dep); // called from constructor. initializes triples_. + void ComputeDerived(); // called from constructor and Read function: computes state2id_ and id2state_. + void ComputeDerivedOfProbs(); // computes quantities derived from log-probs (currently just + // non_self_loop_log_probs_; called whenever log-probs change. + void InitializeProbs(); // called from constructor. + void Check() const; + + struct Triple { + int32 phone; + int32 hmm_state; + int32 pdf; + Triple() { } + Triple(int32 phone, int32 hmm_state, int32 pdf): + phone(phone), hmm_state(hmm_state), pdf(pdf) { } + bool operator < (const Triple &other) const { + if (phone < other.phone) return true; + else if (phone > other.phone) return false; + else if (hmm_state < other.hmm_state) return true; + else if (hmm_state > other.hmm_state) return false; + else return pdf < other.pdf; + } + bool operator == (const Triple &other) const { + return (phone == other.phone && hmm_state == other.hmm_state + && pdf == other.pdf); + } + }; + + HmmTopology topo_; + + /// Triples indexed by transition state minus one; + /// the triples are in sorted order which allows us to do the reverse mapping from + /// triple to transition state + std::vector<Triple> triples_; + + /// Gives the first transition_id of each transition-state; indexed by + /// the transition-state. Array indexed 1..num-transition-states+1 (the last one + /// is needed so we can know the num-transitions of the last transition-state. + std::vector<int32> state2id_; + + /// For each transition-id, the corresponding transition + /// state (indexed by transition-id). + std::vector<int32> id2state_; + + /// For each transition-id, the corresponding log-prob. Indexed by transition-id. + Vector<BaseFloat> log_probs_; + + /// For each transition-state, the log of (1 - self-loop-prob). Indexed by + /// transition-state. + Vector<BaseFloat> non_self_loop_log_probs_; + + /// This is actually one plus the highest-numbered pdf we ever got back from the + /// tree (but the tree numbers pdfs contiguously from zero so this is the number + /// of pdfs). + int32 num_pdfs_; + + + DISALLOW_COPY_AND_ASSIGN(TransitionModel); + +}; + +inline int32 TransitionModel::TransitionIdToPdf(int32 trans_id) const { + // If a lot of time is spent here we may create an extra array + // to handle this. + KALDI_ASSERT(static_cast<size_t>(trans_id) < id2state_.size() && + "Likely graph/model mismatch (graph built from wrong model?)"); + int32 trans_state = id2state_[trans_id]; + return triples_[trans_state-1].pdf; +} + +/// Works out which pdfs might correspond to the given phones. Will return true +/// if these pdfs correspond *just* to these phones, false if these pdfs are also +/// used by other phones. +/// @param trans_model [in] Transition-model used to work out this information +/// @param phones [in] A sorted, uniq vector that represents a set of phones +/// @param pdfs [out] Will be set to a sorted, uniq list of pdf-ids that correspond +/// to one of this set of phones. +/// @return Returns true if all of the pdfs output to "pdfs" correspond to phones from +/// just this set (false if they may be shared with phones outside this set). +bool GetPdfsForPhones(const TransitionModel &trans_model, + const std::vector<int32> &phones, + std::vector<int32> *pdfs); + +/// Works out which phones might correspond to the given pdfs. Similar to the +/// above GetPdfsForPhones(, ,) +bool GetPhonesForPdfs(const TransitionModel &trans_model, + const std::vector<int32> &pdfs, + std::vector<int32> *phones); +/// @} + + +} // end namespace kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/hmm/tree-accu.h b/kaldi_io/src/kaldi/hmm/tree-accu.h new file mode 100644 index 0000000..d571762 --- /dev/null +++ b/kaldi_io/src/kaldi/hmm/tree-accu.h @@ -0,0 +1,69 @@ +// hmm/tree-accu.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_HMM_TREE_ACCU_H_ +#define KALDI_HMM_TREE_ACCU_H_ + +#include <cctype> // For isspace. +#include <limits> +#include "base/kaldi-common.h" +#include "hmm/transition-model.h" +#include "tree/clusterable-classes.h" +#include "tree/build-tree-questions.h" // needed for this typedef: +// typedef std::vector<std::pair<EventVector, Clusterable*> > BuildTreeStatsType; + +namespace kaldi { + +/// \ingroup tree_group_top +/// @{ + + +/// Accumulates the stats needed for training context-dependency trees (in the +/// "normal" way). It adds to 'stats' the stats obtained from this file. Any +/// new GaussClusterable* pointers in "stats" will be allocated with "new". + +void AccumulateTreeStats(const TransitionModel &trans_model, + BaseFloat var_floor, + int N, // context window size. + int P, // central position. + const std::vector<int32> &ci_phones, // sorted + const std::vector<int32> &alignment, + const Matrix<BaseFloat> &features, + const std::vector<int32> *phone_map, // or NULL + std::map<EventType, GaussClusterable*> *stats); + + + +/*** Read a mapping from one phone set to another. The phone map file has lines + of the form <old-phone> <new-phone>, where both entries are integers, usually + nonzero (but this is not enforced). This program will crash if the input is + invalid, e.g. there are multiple inconsistent entries for the same old phone. + The output vector "phone_map" will be indexed by old-phone and will contain + the corresponding new-phone, or -1 for any entry that was not defined. */ + +void ReadPhoneMap(std::string phone_map_rxfilename, + std::vector<int32> *phone_map); + + + +/// @} + +} // end namespace kaldi. + +#endif diff --git a/kaldi_io/src/kaldi/itf/clusterable-itf.h b/kaldi_io/src/kaldi/itf/clusterable-itf.h new file mode 100644 index 0000000..7ef9ae0 --- /dev/null +++ b/kaldi_io/src/kaldi/itf/clusterable-itf.h @@ -0,0 +1,97 @@ +// itf/clusterable-itf.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_ITF_CLUSTERABLE_ITF_H_ +#define KALDI_ITF_CLUSTERABLE_ITF_H_ 1 + +#include <string> +#include "base/kaldi-common.h" + +namespace kaldi { + + +/** \addtogroup clustering_group + @{ + A virtual class for clusterable objects; see \ref clustering for an + explanation if its function. +*/ + + + +class Clusterable { + public: + /// \name Functions that must be overridden + /// @{ + + /// Return a copy of this object. + virtual Clusterable *Copy() const = 0; + /// Return the objective function associated with the stats + /// [assuming ML estimation] + virtual BaseFloat Objf() const = 0; + /// Return the normalizer (typically, count) associated with the stats + virtual BaseFloat Normalizer() const = 0; + /// Set stats to empty. + virtual void SetZero() = 0; + /// Add other stats. + virtual void Add(const Clusterable &other) = 0; + /// Subtract other stats. + virtual void Sub(const Clusterable &other) = 0; + /// Scale the stats by a positive number f [not mandatory to supply this]. + virtual void Scale(BaseFloat f) { + KALDI_ERR << "This Clusterable object does not implement Scale()."; + } + + /// Return a string that describes the inherited type. + virtual std::string Type() const = 0; + + /// Write data to stream. + virtual void Write(std::ostream &os, bool binary) const = 0; + + /// Read data from a stream and return the corresponding object (const + /// function; it's a class member because we need access to the vtable + /// so generic code can read derived types). + virtual Clusterable* ReadNew(std::istream &os, bool binary) const = 0; + + virtual ~Clusterable() {} + + /// @} + + /// \name Functions that have default implementations + /// @{ + + // These functions have default implementations (but may be overridden for + // speed). Implementatons in tree/clusterable-classes.cc + + /// Return the objective function of the combined object this + other. + virtual BaseFloat ObjfPlus(const Clusterable &other) const; + /// Return the objective function of the subtracted object this - other. + virtual BaseFloat ObjfMinus(const Clusterable &other) const; + /// Return the objective function decrease from merging the two + /// clusters, negated to be a positive number (or zero). + virtual BaseFloat Distance(const Clusterable &other) const; + /// @} + +}; +/// @} end of "ingroup clustering_group" + +} // end namespace kaldi + +#endif // KALDI_ITF_CLUSTERABLE_ITF_H_ + diff --git a/kaldi_io/src/kaldi/itf/context-dep-itf.h b/kaldi_io/src/kaldi/itf/context-dep-itf.h new file mode 100644 index 0000000..6a0bd0f --- /dev/null +++ b/kaldi_io/src/kaldi/itf/context-dep-itf.h @@ -0,0 +1,80 @@ +// itf/context-dep-itf.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_ITF_CONTEXT_DEP_ITF_H_ +#define KALDI_ITF_CONTEXT_DEP_ITF_H_ +#include "base/kaldi-common.h" + +namespace kaldi { +/// @ingroup tree_group +/// @{ + +/// context-dep-itf.h provides a link between +/// the tree-building code in ../tree/, and the FST code in ../fstext/ +/// (particularly, ../fstext/context-dep.h). It is an abstract +/// interface that describes an object that can map from a +/// phone-in-context to a sequence of integer leaf-ids. +class ContextDependencyInterface { + public: + /// ContextWidth() returns the value N (e.g. 3 for triphone models) that says how many phones + /// are considered for computing context. + virtual int ContextWidth() const = 0; + + /// Central position P of the phone context, in 0-based numbering, e.g. P = 1 for typical + /// triphone system. We have to see if we can do without this function. + virtual int CentralPosition() const = 0; + + /// The "new" Compute interface. For typical topologies, + /// pdf_class would be 0, 1, 2. + /// Returns success or failure; outputs the pdf-id. + /// + /// "Compute" is the main function of this interface, that takes a + /// sequence of N phones (and it must be N phones), possibly + /// including epsilons (symbol id zero) but only at positions other + /// than P [these represent unknown phone context due to end or + /// begin of sequence]. We do not insist that Compute must always + /// output (into stateseq) a nonempty sequence of states, but we + /// anticipate that stateseq will alyway be nonempty at output in + /// typical use cases. "Compute" returns false if expansion somehow + /// failed. Normally the calling code should raise an exception if + /// this happens. We can define a different interface later in + /// order to handle other kinds of information-- the underlying + /// data-structures from event-map.h are very flexible. + virtual bool Compute(const std::vector<int32> &phoneseq, int32 pdf_class, + int32 *pdf_id) const = 0; + + + + /// NumPdfs() returns the number of acoustic pdfs (they are numbered 0.. NumPdfs()-1). + virtual int32 NumPdfs() const = 0; + + virtual ~ContextDependencyInterface() {}; + ContextDependencyInterface() {} + + /// Returns pointer to new object which is copy of current one. + virtual ContextDependencyInterface *Copy() const = 0; + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(ContextDependencyInterface); +}; +/// @} +} // namespace Kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/itf/decodable-itf.h b/kaldi_io/src/kaldi/itf/decodable-itf.h new file mode 100644 index 0000000..ba4d765 --- /dev/null +++ b/kaldi_io/src/kaldi/itf/decodable-itf.h @@ -0,0 +1,123 @@ +// itf/decodable-itf.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Mirko Hannemann; Go Vivace Inc.; +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_ITF_DECODABLE_ITF_H_ +#define KALDI_ITF_DECODABLE_ITF_H_ 1 +#include "base/kaldi-common.h" + +namespace kaldi { +/// @ingroup Interfaces +/// @{ + + +/** + DecodableInterface provides a link between the (acoustic-modeling and + feature-processing) code and the decoder. The idea is to make this + interface as small as possible, and to make it as agnostic as possible about + the form of the acoustic model (e.g. don't assume the probabilities are a + function of just a vector of floats), and about the decoder (e.g. don't + assume it accesses frames in strict left-to-right order). For normal + models, without on-line operation, the "decodable" sub-class will just be a + wrapper around a matrix of features and an acoustic model, and it will + answer the question 'what is the acoustic likelihood for this index and this + frame?'. + + For online decoding, where the features are coming in in real time, it is + important to understand the IsLastFrame() and NumFramesReady() functions. + There are two ways these are used: the old online-decoding code, in ../online/, + and the new online-decoding code, in ../online2/. In the old online-decoding + code, the decoder would do: + \code{.cc} + for (int frame = 0; !decodable.IsLastFrame(frame); frame++) { + // Process this frame + } + \endcode + and the the call to IsLastFrame would block if the features had not arrived yet. + The decodable object would have to know when to terminate the decoding. This + online-decoding mode is still supported, it is what happens when you call, for + example, LatticeFasterDecoder::Decode(). + + We realized that this "blocking" mode of decoding is not very convenient + because it forces the program to be multi-threaded and makes it complex to + control endpointing. In the "new" decoding code, you don't call (for example) + LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(), + and then each time you get more features, you provide them to the decodable + object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does + something like this: + \code{.cc} + while (num_frames_decoded_ < decodable.NumFramesReady()) { + // Decode one more frame [increments num_frames_decoded_] + } + \endcode + So the decodable object never has IsLastFrame() called. For decoding where + you are starting with a matrix of features, the NumFramesReady() function will + always just return the number of frames in the file, and IsLastFrame() will + return true for the last frame. + + For truly online decoding, the "old" online decodable objects in ../online/ have a + "blocking" IsLastFrame() and will crash if you call NumFramesReady(). + The "new" online decodable objects in ../online2/ return the number of frames + currently accessible if you call NumFramesReady(). You will likely not need + to call IsLastFrame(), but we implement it to only return true for the last + frame of the file once we've decided to terminate decoding. +*/ + +class DecodableInterface { + public: + /// Returns the log likelihood, which will be negated in the decoder. + /// The "frame" starts from zero. You should verify that IsLastFrame(frame-1) + /// returns false before calling this. + virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; + + /// Returns true if this is the last frame. Frames are zero-based, so the + /// first frame is zero. IsLastFrame(-1) will return false, unless the file + /// is empty (which is a case that I'm not sure all the code will handle, so + /// be careful). Caution: the behavior of this function in an online setting + /// is being changed somewhat. In future it may return false in cases where + /// we haven't yet decided to terminate decoding, but later true if we decide + /// to terminate decoding. The plan in future is to rely more on + /// NumFramesReady(), and in future, IsLastFrame() would always return false + /// in an online-decoding setting, and would only return true in a + /// decoding-from-matrix setting where we want to allow the last delta or LDA + /// features to be flushed out for compatibility with the baseline setup. + virtual bool IsLastFrame(int32 frame) const = 0; + + /// The call NumFramesReady() will return the number of frames currently available + /// for this decodable object. This is for use in setups where you don't want the + /// decoder to block while waiting for input. This is newly added as of Jan 2014, + /// and I hope, going forward, to rely on this mechanism more than IsLastFrame to + /// know when to stop decoding. + virtual int32 NumFramesReady() const { + KALDI_ERR << "NumFramesReady() not implemented for this decodable type."; + return -1; + } + + /// Returns the number of states in the acoustic model + /// (they will be indexed one-based, i.e. from 1 to NumIndices(); + /// this is for compatibility with OpenFst. + virtual int32 NumIndices() const = 0; + + virtual ~DecodableInterface() {} +}; +/// @} +} // namespace Kaldi + +#endif // KALDI_ITF_DECODABLE_ITF_H_ diff --git a/kaldi_io/src/kaldi/itf/online-feature-itf.h b/kaldi_io/src/kaldi/itf/online-feature-itf.h new file mode 100644 index 0000000..dafcd8a --- /dev/null +++ b/kaldi_io/src/kaldi/itf/online-feature-itf.h @@ -0,0 +1,105 @@ +// itf/online-feature-itf.h + +// Copyright 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_ITF_ONLINE_FEATURE_ITF_H_ +#define KALDI_ITF_ONLINE_FEATURE_ITF_H_ 1 +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { +/// @ingroup Interfaces +/// @{ + +/** + OnlineFeatureInterface is an interface for online feature processing (it is + also usable in the offline setting, but currently we're not using it for + that). This is for use in the online2/ directory, and it supersedes the + interface in ../online/online-feat-input.h. We have a slighty different + model that puts more control in the hands of the calling thread, and won't + involve waiting on semaphores in the decoding thread. + + This interface only specifies how the object *outputs* the features. + How it obtains the features, e.g. from a previous object or objects of type + OnlineFeatureInterface, is not specified in the interface and you will + likely define new constructors or methods in the derived type to do that. + + You should appreciate that this interface is designed to allow random + access to features, as long as they are ready. That is, the user + can call GetFrame for any frame less than NumFramesReady(), and when + implementing a child class you must not make assumptions about the + order in which the user makes these calls. +*/ + +class OnlineFeatureInterface { + public: + virtual int32 Dim() const = 0; /// returns the feature dimension. + + /// Returns the total number of frames, since the start of the utterance, that + /// are now available. In an online-decoding context, this will likely + /// increase with time as more data becomes available. + virtual int32 NumFramesReady() const = 0; + + /// Returns true if this is the last frame. Frame indices are zero-based, so the + /// first frame is zero. IsLastFrame(-1) will return false, unless the file + /// is empty (which is a case that I'm not sure all the code will handle, so + /// be careful). This function may return false for some frame if + /// we haven't yet decided to terminate decoding, but later true if we decide + /// to terminate decoding. This function exists mainly to correctly handle + /// end effects in feature extraction, and is not a mechanism to determine how + /// many frames are in the decodable object (as it used to be, and for backward + /// compatibility, still is, in the Decodable interface). + virtual bool IsLastFrame(int32 frame) const = 0; + + /// Gets the feature vector for this frame. Before calling this for a given + /// frame, it is assumed that you called NumFramesReady() and it returned a + /// number greater than "frame". Otherwise this call will likely crash with + /// an assert failure. This function is not declared const, in case there is + /// some kind of caching going on, but most of the time it shouldn't modify + /// the class. + virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat) = 0; + + /// Virtual destructor. Note: constructors that take another member of + /// type OnlineFeatureInterface are not expected to take ownership of + /// that pointer; the caller needs to keep track of that manually. + virtual ~OnlineFeatureInterface() { } +}; + + +/// Add a virtual class for "source" features such as MFCC or PLP or pitch +/// features. +class OnlineBaseFeature: public OnlineFeatureInterface { + public: + /// This would be called from the application, when you get more wave data. + /// Note: the sampling_rate is typically only provided so the code can assert + /// that it matches the sampling rate expected in the options. + virtual void AcceptWaveform(BaseFloat sampling_rate, + const VectorBase<BaseFloat> &waveform) = 0; + + /// InputFinished() tells the class you won't be providing any + /// more waveform. This will help flush out the last few frames + /// of delta or LDA features (it will typically affect the return value + /// of IsLastFrame. + virtual void InputFinished() = 0; +}; + + +/// @} +} // namespace Kaldi + +#endif // KALDI_ITF_ONLINE_FEATURE_ITF_H_ diff --git a/kaldi_io/src/kaldi/itf/optimizable-itf.h b/kaldi_io/src/kaldi/itf/optimizable-itf.h new file mode 100644 index 0000000..1b8f54b --- /dev/null +++ b/kaldi_io/src/kaldi/itf/optimizable-itf.h @@ -0,0 +1,51 @@ +// itf/optimizable-itf.h + +// Copyright 2009-2011 Go Vivace Inc.; Microsoft Corporation; Georg Stemmer + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_ITF_OPTIMIZABLE_ITF_H_ +#define KALDI_ITF_OPTIMIZABLE_ITF_H_ + +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { +/// @ingroup Interfaces +/// @{ + +/// OptimizableInterface provides +/// a virtual class for optimizable objects. +/// E.g. a class that computed a likelihood function and +/// its gradient using some parameter +/// that has to be optimized on data +/// could inherit from it. +template<class Real> +class OptimizableInterface { + public: + /// computes gradient for a parameter params and returns it + /// in gradient_out + virtual void ComputeGradient(const Vector<Real> ¶ms, + Vector<Real> *gradient_out) = 0; + /// computes the function value for a parameter params + /// and returns it + virtual Real ComputeValue(const Vector<Real> ¶ms) = 0; + + virtual ~OptimizableInterface() {} +}; +/// @} end of "Interfaces" +} // end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/itf/options-itf.h b/kaldi_io/src/kaldi/itf/options-itf.h new file mode 100644 index 0000000..204f46d --- /dev/null +++ b/kaldi_io/src/kaldi/itf/options-itf.h @@ -0,0 +1,49 @@ +// itf/options-itf.h + +// Copyright 2013 Tanel Alumae, Tallinn University of Technology + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_ITF_OPTIONS_ITF_H_ +#define KALDI_ITF_OPTIONS_ITF_H_ 1 +#include "base/kaldi-common.h" + +namespace kaldi { + +class OptionsItf { + public: + + virtual void Register(const std::string &name, + bool *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + int32 *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + uint32 *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + float *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + double *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + std::string *ptr, const std::string &doc) = 0; + + virtual ~OptionsItf() {} +}; + +} // namespace Kaldi + +#endif // KALDI_ITF_OPTIONS_ITF_H_ + + diff --git a/kaldi_io/src/kaldi/matrix/cblas-wrappers.h b/kaldi_io/src/kaldi/matrix/cblas-wrappers.h new file mode 100644 index 0000000..ebec0a3 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/cblas-wrappers.h @@ -0,0 +1,491 @@ +// matrix/cblas-wrappers.h + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey); +// Haihua Xu; Wei Shi + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_CBLAS_WRAPPERS_H_ +#define KALDI_MATRIX_CBLAS_WRAPPERS_H_ 1 + + +#include <limits> +#include "matrix/sp-matrix.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/matrix-functions.h" + +// Do not include this file directly. It is to be included +// by .cc files in this directory. + +namespace kaldi { + + +inline void cblas_Xcopy(const int N, const float *X, const int incX, float *Y, + const int incY) { + cblas_scopy(N, X, incX, Y, incY); +} + +inline void cblas_Xcopy(const int N, const double *X, const int incX, double *Y, + const int incY) { + cblas_dcopy(N, X, incX, Y, incY); +} + + +inline float cblas_Xasum(const int N, const float *X, const int incX) { + return cblas_sasum(N, X, incX); +} + +inline double cblas_Xasum(const int N, const double *X, const int incX) { + return cblas_dasum(N, X, incX); +} + +inline void cblas_Xrot(const int N, float *X, const int incX, float *Y, + const int incY, const float c, const float s) { + cblas_srot(N, X, incX, Y, incY, c, s); +} +inline void cblas_Xrot(const int N, double *X, const int incX, double *Y, + const int incY, const double c, const double s) { + cblas_drot(N, X, incX, Y, incY, c, s); +} +inline float cblas_Xdot(const int N, const float *const X, + const int incX, const float *const Y, + const int incY) { + return cblas_sdot(N, X, incX, Y, incY); +} +inline double cblas_Xdot(const int N, const double *const X, + const int incX, const double *const Y, + const int incY) { + return cblas_ddot(N, X, incX, Y, incY); +} +inline void cblas_Xaxpy(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY) { + cblas_saxpy(N, alpha, X, incX, Y, incY); +} +inline void cblas_Xaxpy(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY) { + cblas_daxpy(N, alpha, X, incX, Y, incY); +} +inline void cblas_Xscal(const int N, const float alpha, float *data, + const int inc) { + cblas_sscal(N, alpha, data, inc); +} +inline void cblas_Xscal(const int N, const double alpha, double *data, + const int inc) { + cblas_dscal(N, alpha, data, inc); +} +inline void cblas_Xspmv(const float alpha, const int num_rows, const float *Mdata, + const float *v, const int v_inc, + const float beta, float *y, const int y_inc) { + cblas_sspmv(CblasRowMajor, CblasLower, num_rows, alpha, Mdata, v, v_inc, beta, y, y_inc); +} +inline void cblas_Xspmv(const double alpha, const int num_rows, const double *Mdata, + const double *v, const int v_inc, + const double beta, double *y, const int y_inc) { + cblas_dspmv(CblasRowMajor, CblasLower, num_rows, alpha, Mdata, v, v_inc, beta, y, y_inc); +} +inline void cblas_Xtpmv(MatrixTransposeType trans, const float *Mdata, + const int num_rows, float *y, const int y_inc) { + cblas_stpmv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} +inline void cblas_Xtpmv(MatrixTransposeType trans, const double *Mdata, + const int num_rows, double *y, const int y_inc) { + cblas_dtpmv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} + + +inline void cblas_Xtpsv(MatrixTransposeType trans, const float *Mdata, + const int num_rows, float *y, const int y_inc) { + cblas_stpsv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} +inline void cblas_Xtpsv(MatrixTransposeType trans, const double *Mdata, + const int num_rows, double *y, const int y_inc) { + cblas_dtpsv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} + +// x = alpha * M * y + beta * x +inline void cblas_Xspmv(MatrixIndexT dim, float alpha, const float *Mdata, + const float *ydata, MatrixIndexT ystride, + float beta, float *xdata, MatrixIndexT xstride) { + cblas_sspmv(CblasRowMajor, CblasLower, dim, alpha, Mdata, + ydata, ystride, beta, xdata, xstride); +} +inline void cblas_Xspmv(MatrixIndexT dim, double alpha, const double *Mdata, + const double *ydata, MatrixIndexT ystride, + double beta, double *xdata, MatrixIndexT xstride) { + cblas_dspmv(CblasRowMajor, CblasLower, dim, alpha, Mdata, + ydata, ystride, beta, xdata, xstride); +} + +// Implements A += alpha * (x y' + y x'); A is symmetric matrix. +inline void cblas_Xspr2(MatrixIndexT dim, float alpha, const float *Xdata, + MatrixIndexT incX, const float *Ydata, MatrixIndexT incY, + float *Adata) { + cblas_sspr2(CblasRowMajor, CblasLower, dim, alpha, Xdata, + incX, Ydata, incY, Adata); +} +inline void cblas_Xspr2(MatrixIndexT dim, double alpha, const double *Xdata, + MatrixIndexT incX, const double *Ydata, MatrixIndexT incY, + double *Adata) { + cblas_dspr2(CblasRowMajor, CblasLower, dim, alpha, Xdata, + incX, Ydata, incY, Adata); +} + +// Implements A += alpha * (x x'); A is symmetric matrix. +inline void cblas_Xspr(MatrixIndexT dim, float alpha, const float *Xdata, + MatrixIndexT incX, float *Adata) { + cblas_sspr(CblasRowMajor, CblasLower, dim, alpha, Xdata, incX, Adata); +} +inline void cblas_Xspr(MatrixIndexT dim, double alpha, const double *Xdata, + MatrixIndexT incX, double *Adata) { + cblas_dspr(CblasRowMajor, CblasLower, dim, alpha, Xdata, incX, Adata); +} + +// sgemv,dgemv: y = alpha M x + beta y. +inline void cblas_Xgemv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, float alpha, const float *Mdata, + MatrixIndexT stride, const float *xdata, + MatrixIndexT incX, float beta, float *ydata, MatrixIndexT incY) { + cblas_sgemv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows, + num_cols, alpha, Mdata, stride, xdata, incX, beta, ydata, incY); +} +inline void cblas_Xgemv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, double alpha, const double *Mdata, + MatrixIndexT stride, const double *xdata, + MatrixIndexT incX, double beta, double *ydata, MatrixIndexT incY) { + cblas_dgemv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows, + num_cols, alpha, Mdata, stride, xdata, incX, beta, ydata, incY); +} + +// sgbmv, dgmmv: y = alpha M x + + beta * y. +inline void cblas_Xgbmv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, MatrixIndexT num_below, + MatrixIndexT num_above, float alpha, const float *Mdata, + MatrixIndexT stride, const float *xdata, + MatrixIndexT incX, float beta, float *ydata, MatrixIndexT incY) { + cblas_sgbmv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows, + num_cols, num_below, num_above, alpha, Mdata, stride, xdata, + incX, beta, ydata, incY); +} +inline void cblas_Xgbmv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, MatrixIndexT num_below, + MatrixIndexT num_above, double alpha, const double *Mdata, + MatrixIndexT stride, const double *xdata, + MatrixIndexT incX, double beta, double *ydata, MatrixIndexT incY) { + cblas_dgbmv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows, + num_cols, num_below, num_above, alpha, Mdata, stride, xdata, + incX, beta, ydata, incY); +} + + +template<typename Real> +inline void Xgemv_sparsevec(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, Real alpha, const Real *Mdata, + MatrixIndexT stride, const Real *xdata, + MatrixIndexT incX, Real beta, Real *ydata, + MatrixIndexT incY) { + if (trans == kNoTrans) { + if (beta != 1.0) cblas_Xscal(num_rows, beta, ydata, incY); + for (MatrixIndexT i = 0; i < num_cols; i++) { + Real x_i = xdata[i * incX]; + if (x_i == 0.0) continue; + // Add to ydata, the i'th column of M, times alpha * x_i + cblas_Xaxpy(num_rows, x_i * alpha, Mdata + i, stride, ydata, incY); + } + } else { + if (beta != 1.0) cblas_Xscal(num_cols, beta, ydata, incY); + for (MatrixIndexT i = 0; i < num_rows; i++) { + Real x_i = xdata[i * incX]; + if (x_i == 0.0) continue; + // Add to ydata, the i'th row of M, times alpha * x_i + cblas_Xaxpy(num_cols, x_i * alpha, + Mdata + (i * stride), 1, ydata, incY); + } + } +} + +inline void cblas_Xgemm(const float alpha, + MatrixTransposeType transA, + const float *Adata, + MatrixIndexT a_num_rows, MatrixIndexT a_num_cols, MatrixIndexT a_stride, + MatrixTransposeType transB, + const float *Bdata, MatrixIndexT b_stride, + const float beta, + float *Mdata, + MatrixIndexT num_rows, MatrixIndexT num_cols,MatrixIndexT stride) { + cblas_sgemm(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(transA), + static_cast<CBLAS_TRANSPOSE>(transB), + num_rows, num_cols, transA == kNoTrans ? a_num_cols : a_num_rows, + alpha, Adata, a_stride, Bdata, b_stride, + beta, Mdata, stride); +} +inline void cblas_Xgemm(const double alpha, + MatrixTransposeType transA, + const double *Adata, + MatrixIndexT a_num_rows, MatrixIndexT a_num_cols, MatrixIndexT a_stride, + MatrixTransposeType transB, + const double *Bdata, MatrixIndexT b_stride, + const double beta, + double *Mdata, + MatrixIndexT num_rows, MatrixIndexT num_cols,MatrixIndexT stride) { + cblas_dgemm(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(transA), + static_cast<CBLAS_TRANSPOSE>(transB), + num_rows, num_cols, transA == kNoTrans ? a_num_cols : a_num_rows, + alpha, Adata, a_stride, Bdata, b_stride, + beta, Mdata, stride); +} + + +inline void cblas_Xsymm(const float alpha, + MatrixIndexT sz, + const float *Adata,MatrixIndexT a_stride, + const float *Bdata,MatrixIndexT b_stride, + const float beta, + float *Mdata, MatrixIndexT stride) { + cblas_ssymm(CblasRowMajor, CblasLeft, CblasLower, sz, sz, alpha, Adata, + a_stride, Bdata, b_stride, beta, Mdata, stride); +} +inline void cblas_Xsymm(const double alpha, + MatrixIndexT sz, + const double *Adata,MatrixIndexT a_stride, + const double *Bdata,MatrixIndexT b_stride, + const double beta, + double *Mdata, MatrixIndexT stride) { + cblas_dsymm(CblasRowMajor, CblasLeft, CblasLower, sz, sz, alpha, Adata, + a_stride, Bdata, b_stride, beta, Mdata, stride); +} +// ger: M += alpha x y^T. +inline void cblas_Xger(MatrixIndexT num_rows, MatrixIndexT num_cols, float alpha, + const float *xdata, MatrixIndexT incX, const float *ydata, + MatrixIndexT incY, float *Mdata, MatrixIndexT stride) { + cblas_sger(CblasRowMajor, num_rows, num_cols, alpha, xdata, 1, ydata, 1, + Mdata, stride); +} +inline void cblas_Xger(MatrixIndexT num_rows, MatrixIndexT num_cols, double alpha, + const double *xdata, MatrixIndexT incX, const double *ydata, + MatrixIndexT incY, double *Mdata, MatrixIndexT stride) { + cblas_dger(CblasRowMajor, num_rows, num_cols, alpha, xdata, 1, ydata, 1, + Mdata, stride); +} + +// syrk: symmetric rank-k update. +// if trans==kNoTrans, then C = alpha A A^T + beta C +// else C = alpha A^T A + beta C. +// note: dim_c is dim(C), other_dim_a is the "other" dimension of A, i.e. +// num-cols(A) if kNoTrans, or num-rows(A) if kTrans. +// We only need the row-major and lower-triangular option of this, and this +// is hard-coded. +inline void cblas_Xsyrk ( + const MatrixTransposeType trans, const MatrixIndexT dim_c, + const MatrixIndexT other_dim_a, const float alpha, const float *A, + const MatrixIndexT a_stride, const float beta, float *C, + const MatrixIndexT c_stride) { + cblas_ssyrk(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + dim_c, other_dim_a, alpha, A, a_stride, beta, C, c_stride); +} + +inline void cblas_Xsyrk( + const MatrixTransposeType trans, const MatrixIndexT dim_c, + const MatrixIndexT other_dim_a, const double alpha, const double *A, + const MatrixIndexT a_stride, const double beta, double *C, + const MatrixIndexT c_stride) { + cblas_dsyrk(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + dim_c, other_dim_a, alpha, A, a_stride, beta, C, c_stride); +} + +/// matrix-vector multiply using a banded matrix; we always call this +/// with b = 1 meaning we're multiplying by a diagonal matrix. This is used for +/// elementwise multiplication. We miss some of the arguments out of this +/// wrapper. +inline void cblas_Xsbmv1( + const MatrixIndexT dim, + const double *A, + const double alpha, + const double *x, + const double beta, + double *y) { + cblas_dsbmv(CblasRowMajor, CblasLower, dim, 0, alpha, A, + 1, x, 1, beta, y, 1); +} + +inline void cblas_Xsbmv1( + const MatrixIndexT dim, + const float *A, + const float alpha, + const float *x, + const float beta, + float *y) { + cblas_ssbmv(CblasRowMajor, CblasLower, dim, 0, alpha, A, + 1, x, 1, beta, y, 1); +} + + +/// This is not really a wrapper for CBLAS as CBLAS does not have this; in future we could +/// extend this somehow. +inline void mul_elements( + const MatrixIndexT dim, + const double *a, + double *b) { // does b *= a, elementwise. + double c1, c2, c3, c4; + MatrixIndexT i; + for (i = 0; i + 4 <= dim; i += 4) { + c1 = a[i] * b[i]; + c2 = a[i+1] * b[i+1]; + c3 = a[i+2] * b[i+2]; + c4 = a[i+3] * b[i+3]; + b[i] = c1; + b[i+1] = c2; + b[i+2] = c3; + b[i+3] = c4; + } + for (; i < dim; i++) + b[i] *= a[i]; +} + +inline void mul_elements( + const MatrixIndexT dim, + const float *a, + float *b) { // does b *= a, elementwise. + float c1, c2, c3, c4; + MatrixIndexT i; + for (i = 0; i + 4 <= dim; i += 4) { + c1 = a[i] * b[i]; + c2 = a[i+1] * b[i+1]; + c3 = a[i+2] * b[i+2]; + c4 = a[i+3] * b[i+3]; + b[i] = c1; + b[i+1] = c2; + b[i+2] = c3; + b[i+3] = c4; + } + for (; i < dim; i++) + b[i] *= a[i]; +} + + + +// add clapack here +#if !defined(HAVE_ATLAS) +inline void clapack_Xtptri(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *result) { + stptri_(const_cast<char *>("U"), const_cast<char *>("N"), num_rows, Mdata, result); +} +inline void clapack_Xtptri(KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *result) { + dtptri_(const_cast<char *>("U"), const_cast<char *>("N"), num_rows, Mdata, result); +} +// +inline void clapack_Xgetrf2(KaldiBlasInt *num_rows, KaldiBlasInt *num_cols, + float *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot, + KaldiBlasInt *result) { + sgetrf_(num_rows, num_cols, Mdata, stride, pivot, result); +} +inline void clapack_Xgetrf2(KaldiBlasInt *num_rows, KaldiBlasInt *num_cols, + double *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot, + KaldiBlasInt *result) { + dgetrf_(num_rows, num_cols, Mdata, stride, pivot, result); +} + +// +inline void clapack_Xgetri2(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *stride, + KaldiBlasInt *pivot, float *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + sgetri_(num_rows, Mdata, stride, pivot, p_work, l_work, result); +} +inline void clapack_Xgetri2(KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *stride, + KaldiBlasInt *pivot, double *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + dgetri_(num_rows, Mdata, stride, pivot, p_work, l_work, result); +} +// +inline void clapack_Xgesvd(char *v, char *u, KaldiBlasInt *num_cols, + KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *stride, + float *sv, float *Vdata, KaldiBlasInt *vstride, + float *Udata, KaldiBlasInt *ustride, float *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + sgesvd_(v, u, + num_cols, num_rows, Mdata, stride, + sv, Vdata, vstride, Udata, ustride, + p_work, l_work, result); +} +inline void clapack_Xgesvd(char *v, char *u, KaldiBlasInt *num_cols, + KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *stride, + double *sv, double *Vdata, KaldiBlasInt *vstride, + double *Udata, KaldiBlasInt *ustride, double *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + dgesvd_(v, u, + num_cols, num_rows, Mdata, stride, + sv, Vdata, vstride, Udata, ustride, + p_work, l_work, result); +} +// +void inline clapack_Xsptri(KaldiBlasInt *num_rows, float *Mdata, + KaldiBlasInt *ipiv, float *work, KaldiBlasInt *result) { + ssptri_(const_cast<char *>("U"), num_rows, Mdata, ipiv, work, result); +} +void inline clapack_Xsptri(KaldiBlasInt *num_rows, double *Mdata, + KaldiBlasInt *ipiv, double *work, KaldiBlasInt *result) { + dsptri_(const_cast<char *>("U"), num_rows, Mdata, ipiv, work, result); +} +// +void inline clapack_Xsptrf(KaldiBlasInt *num_rows, float *Mdata, + KaldiBlasInt *ipiv, KaldiBlasInt *result) { + ssptrf_(const_cast<char *>("U"), num_rows, Mdata, ipiv, result); +} +void inline clapack_Xsptrf(KaldiBlasInt *num_rows, double *Mdata, + KaldiBlasInt *ipiv, KaldiBlasInt *result) { + dsptrf_(const_cast<char *>("U"), num_rows, Mdata, ipiv, result); +} +#else +inline void clapack_Xgetrf(MatrixIndexT num_rows, MatrixIndexT num_cols, + float *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_sgetrf(CblasColMajor, num_rows, num_cols, + Mdata, stride, pivot); +} + +inline void clapack_Xgetrf(MatrixIndexT num_rows, MatrixIndexT num_cols, + double *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_dgetrf(CblasColMajor, num_rows, num_cols, + Mdata, stride, pivot); +} +// +inline int clapack_Xtrtri(int num_rows, float *Mdata, MatrixIndexT stride) { + return clapack_strtri(CblasColMajor, CblasUpper, CblasNonUnit, num_rows, + Mdata, stride); +} + +inline int clapack_Xtrtri(int num_rows, double *Mdata, MatrixIndexT stride) { + return clapack_dtrtri(CblasColMajor, CblasUpper, CblasNonUnit, num_rows, + Mdata, stride); +} +// +inline void clapack_Xgetri(MatrixIndexT num_rows, float *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_sgetri(CblasColMajor, num_rows, Mdata, stride, pivot); +} +inline void clapack_Xgetri(MatrixIndexT num_rows, double *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_dgetri(CblasColMajor, num_rows, Mdata, stride, pivot); +} +#endif + +} +// namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/matrix/compressed-matrix.h b/kaldi_io/src/kaldi/matrix/compressed-matrix.h new file mode 100644 index 0000000..746cab3 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/compressed-matrix.h @@ -0,0 +1,179 @@ +// matrix/compressed-matrix.h + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) +// Frantisek Skala, Wei Shi + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_COMPRESSED_MATRIX_H_ +#define KALDI_MATRIX_COMPRESSED_MATRIX_H_ 1 + +#include "kaldi-matrix.h" + +namespace kaldi { + +/// \addtogroup matrix_group +/// @{ + +/// This class does lossy compression of a matrix. It only +/// supports copying to-from a KaldiMatrix. For large matrices, +/// each element is compressed into about one byte, but there +/// is a little overhead on top of that (globally, and also per +/// column). + +/// The basic idea is for each column (in the normal configuration) +/// we work out the values at the 0th, 25th, 50th and 100th percentiles +/// and store them as 16-bit integers; we then encode each value in +/// the column as a single byte, in 3 separate ranges with different +/// linear encodings (0-25th, 25-50th, 50th-100th). +/// If the matrix has 8 rows or fewer, we simply store all values as +/// uint16. + +class CompressedMatrix { + public: + CompressedMatrix(): data_(NULL) { } + + ~CompressedMatrix() { Destroy(); } + + template<typename Real> + CompressedMatrix(const MatrixBase<Real> &mat): data_(NULL) { CopyFromMat(mat); } + + /// Initializer that can be used to select part of an existing + /// CompressedMatrix without un-compressing and re-compressing (note: unlike + /// similar initializers for class Matrix, it doesn't point to the same memory + /// location). + CompressedMatrix(const CompressedMatrix &mat, + const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols); + + void *Data() const { return this->data_; } + + /// This will resize *this and copy the contents of mat to *this. + template<typename Real> + void CopyFromMat(const MatrixBase<Real> &mat); + + CompressedMatrix(const CompressedMatrix &mat); + + CompressedMatrix &operator = (const CompressedMatrix &mat); // assignment operator. + + template<typename Real> + CompressedMatrix &operator = (const MatrixBase<Real> &mat); // assignment operator. + + /// Copies contents to matrix. Note: mat must have the correct size, + /// CopyToMat no longer attempts to resize it. + template<typename Real> + void CopyToMat(MatrixBase<Real> *mat) const; + + void Write(std::ostream &os, bool binary) const; + + void Read(std::istream &is, bool binary); + + /// Returns number of rows (or zero for emtpy matrix). + inline MatrixIndexT NumRows() const { return (data_ == NULL) ? 0 : + (*reinterpret_cast<GlobalHeader*>(data_)).num_rows; } + + /// Returns number of columns (or zero for emtpy matrix). + inline MatrixIndexT NumCols() const { return (data_ == NULL) ? 0 : + (*reinterpret_cast<GlobalHeader*>(data_)).num_cols; } + + /// Copies row #row of the matrix into vector v. + /// Note: v must have same size as #cols. + template<typename Real> + void CopyRowToVec(MatrixIndexT row, VectorBase<Real> *v) const; + + /// Copies column #col of the matrix into vector v. + /// Note: v must have same size as #rows. + template<typename Real> + void CopyColToVec(MatrixIndexT col, VectorBase<Real> *v) const; + + /// Copies submatrix of compressed matrix into matrix dest. + /// Submatrix starts at row row_offset and column column_offset and its size + /// is defined by size of provided matrix dest + template<typename Real> + void CopyToMat(int32 row_offset, + int32 column_offset, + MatrixBase<Real> *dest) const; + + void Swap(CompressedMatrix *other) { std::swap(data_, other->data_); } + + friend class Matrix<float>; + friend class Matrix<double>; + private: + + // allocates data using new [], ensures byte alignment + // sufficient for float. + static void *AllocateData(int32 num_bytes); + + // the "format" will be 1 for the original format where each column has a + // PerColHeader, and 2 for the format now used for matrices with 8 or fewer + // rows, where everything is represented as 16-bit integers. + struct GlobalHeader { + int32 format; + float min_value; + float range; + int32 num_rows; + int32 num_cols; + }; + + static MatrixIndexT DataSize(const GlobalHeader &header); + + struct PerColHeader { + uint16 percentile_0; + uint16 percentile_25; + uint16 percentile_75; + uint16 percentile_100; + }; + + template<typename Real> + static void CompressColumn(const GlobalHeader &global_header, + const Real *data, MatrixIndexT stride, + int32 num_rows, PerColHeader *header, + unsigned char *byte_data); + template<typename Real> + static void ComputeColHeader(const GlobalHeader &global_header, + const Real *data, MatrixIndexT stride, + int32 num_rows, PerColHeader *header); + + static inline uint16 FloatToUint16(const GlobalHeader &global_header, + float value); + + static inline float Uint16ToFloat(const GlobalHeader &global_header, + uint16 value); + static inline unsigned char FloatToChar(float p0, float p25, + float p75, float p100, + float value); + static inline float CharToFloat(float p0, float p25, + float p75, float p100, + unsigned char value); + + void Destroy(); + + void *data_; // first GlobalHeader, then PerColHeader (repeated), then + // the byte data for each column (repeated). Note: don't intersperse + // the byte data with the PerColHeaders, because of alignment issues. + +}; + + +/// @} end of \addtogroup matrix_group + + +} // namespace kaldi + + +#endif // KALDI_MATRIX_COMPRESSED_MATRIX_H_ diff --git a/kaldi_io/src/kaldi/matrix/jama-eig.h b/kaldi_io/src/kaldi/matrix/jama-eig.h new file mode 100644 index 0000000..c7278bc --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/jama-eig.h @@ -0,0 +1,924 @@ +// matrix/jama-eig.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// This file consists of a port and modification of materials from +// JAMA: A Java Matrix Package +// under the following notice: This software is a cooperative product of +// The MathWorks and the National Institute of Standards and Technology (NIST) +// which has been released to the public. This notice and the original code are +// available at http://math.nist.gov/javanumerics/jama/domain.notice + + + +#ifndef KALDI_MATRIX_JAMA_EIG_H_ +#define KALDI_MATRIX_JAMA_EIG_H_ 1 + +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +// This class is not to be used externally. See the Eig function in the Matrix +// class in kaldi-matrix.h. This is the external interface. + +template<typename Real> class EigenvalueDecomposition { + // This class is based on the EigenvalueDecomposition class from the JAMA + // library (version 1.0.2). + public: + EigenvalueDecomposition(const MatrixBase<Real> &A); + + ~EigenvalueDecomposition(); // free memory. + + void GetV(MatrixBase<Real> *V_out) { // V is what we call P externally; it's the matrix of + // eigenvectors. + KALDI_ASSERT(V_out->NumRows() == static_cast<MatrixIndexT>(n_) + && V_out->NumCols() == static_cast<MatrixIndexT>(n_)); + for (int i = 0; i < n_; i++) + for (int j = 0; j < n_; j++) + (*V_out)(i, j) = V(i, j); // V(i, j) is member function. + } + void GetRealEigenvalues(VectorBase<Real> *r_out) { + // returns real part of eigenvalues. + KALDI_ASSERT(r_out->Dim() == static_cast<MatrixIndexT>(n_)); + for (int i = 0; i < n_; i++) + (*r_out)(i) = d_[i]; + } + void GetImagEigenvalues(VectorBase<Real> *i_out) { + // returns imaginary part of eigenvalues. + KALDI_ASSERT(i_out->Dim() == static_cast<MatrixIndexT>(n_)); + for (int i = 0; i < n_; i++) + (*i_out)(i) = e_[i]; + } + private: + + inline Real &H(int r, int c) { return H_[r*n_ + c]; } + inline Real &V(int r, int c) { return V_[r*n_ + c]; } + + // complex division + inline static void cdiv(Real xr, Real xi, Real yr, Real yi, Real *cdivr, Real *cdivi) { + Real r, d; + if (std::abs(yr) > std::abs(yi)) { + r = yi/yr; + d = yr + r*yi; + *cdivr = (xr + r*xi)/d; + *cdivi = (xi - r*xr)/d; + } else { + r = yr/yi; + d = yi + r*yr; + *cdivr = (r*xr + xi)/d; + *cdivi = (r*xi - xr)/d; + } + } + + // Nonsymmetric reduction from Hessenberg to real Schur form. + void Hqr2 (); + + + int n_; // matrix dimension. + + Real *d_, *e_; // real and imaginary parts of eigenvalues. + Real *V_; // the eigenvectors (P in our external notation) + Real *H_; // the nonsymmetric Hessenberg form. + Real *ort_; // working storage for nonsymmetric algorithm. + + // Symmetric Householder reduction to tridiagonal form. + void Tred2 (); + + // Symmetric tridiagonal QL algorithm. + void Tql2 (); + + // Nonsymmetric reduction to Hessenberg form. + void Orthes (); + +}; + +template class EigenvalueDecomposition<float>; // force instantiation. +template class EigenvalueDecomposition<double>; // force instantiation. + +template<typename Real> void EigenvalueDecomposition<Real>::Tred2() { + // This is derived from the Algol procedures tred2 by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + for (int j = 0; j < n_; j++) { + d_[j] = V(n_-1, j); + } + + // Householder reduction to tridiagonal form. + + for (int i = n_-1; i > 0; i--) { + + // Scale to avoid under/overflow. + + Real scale = 0.0; + Real h = 0.0; + for (int k = 0; k < i; k++) { + scale = scale + std::abs(d_[k]); + } + if (scale == 0.0) { + e_[i] = d_[i-1]; + for (int j = 0; j < i; j++) { + d_[j] = V(i-1, j); + V(i, j) = 0.0; + V(j, i) = 0.0; + } + } else { + + // Generate Householder vector. + + for (int k = 0; k < i; k++) { + d_[k] /= scale; + h += d_[k] * d_[k]; + } + Real f = d_[i-1]; + Real g = std::sqrt(h); + if (f > 0) { + g = -g; + } + e_[i] = scale * g; + h = h - f * g; + d_[i-1] = f - g; + for (int j = 0; j < i; j++) { + e_[j] = 0.0; + } + + // Apply similarity transformation to remaining columns. + + for (int j = 0; j < i; j++) { + f = d_[j]; + V(j, i) = f; + g =e_[j] + V(j, j) * f; + for (int k = j+1; k <= i-1; k++) { + g += V(k, j) * d_[k]; + e_[k] += V(k, j) * f; + } + e_[j] = g; + } + f = 0.0; + for (int j = 0; j < i; j++) { + e_[j] /= h; + f += e_[j] * d_[j]; + } + Real hh = f / (h + h); + for (int j = 0; j < i; j++) { + e_[j] -= hh * d_[j]; + } + for (int j = 0; j < i; j++) { + f = d_[j]; + g = e_[j]; + for (int k = j; k <= i-1; k++) { + V(k, j) -= (f * e_[k] + g * d_[k]); + } + d_[j] = V(i-1, j); + V(i, j) = 0.0; + } + } + d_[i] = h; + } + + // Accumulate transformations. + + for (int i = 0; i < n_-1; i++) { + V(n_-1, i) = V(i, i); + V(i, i) = 1.0; + Real h = d_[i+1]; + if (h != 0.0) { + for (int k = 0; k <= i; k++) { + d_[k] = V(k, i+1) / h; + } + for (int j = 0; j <= i; j++) { + Real g = 0.0; + for (int k = 0; k <= i; k++) { + g += V(k, i+1) * V(k, j); + } + for (int k = 0; k <= i; k++) { + V(k, j) -= g * d_[k]; + } + } + } + for (int k = 0; k <= i; k++) { + V(k, i+1) = 0.0; + } + } + for (int j = 0; j < n_; j++) { + d_[j] = V(n_-1, j); + V(n_-1, j) = 0.0; + } + V(n_-1, n_-1) = 1.0; + e_[0] = 0.0; +} + +template<typename Real> void EigenvalueDecomposition<Real>::Tql2() { + // This is derived from the Algol procedures tql2, by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + for (int i = 1; i < n_; i++) { + e_[i-1] = e_[i]; + } + e_[n_-1] = 0.0; + + Real f = 0.0; + Real tst1 = 0.0; + Real eps = std::numeric_limits<Real>::epsilon(); + for (int l = 0; l < n_; l++) { + + // Find small subdiagonal element + + tst1 = std::max(tst1, std::abs(d_[l]) + std::abs(e_[l])); + int m = l; + while (m < n_) { + if (std::abs(e_[m]) <= eps*tst1) { + break; + } + m++; + } + + // If m == l, d_[l] is an eigenvalue, + // otherwise, iterate. + + if (m > l) { + int iter = 0; + do { + iter = iter + 1; // (Could check iteration count here.) + + // Compute implicit shift + + Real g = d_[l]; + Real p = (d_[l+1] - g) / (2.0 *e_[l]); + Real r = Hypot(p, static_cast<Real>(1.0)); // This is a Kaldi version of hypot that works with templates. + if (p < 0) { + r = -r; + } + d_[l] =e_[l] / (p + r); + d_[l+1] =e_[l] * (p + r); + Real dl1 = d_[l+1]; + Real h = g - d_[l]; + for (int i = l+2; i < n_; i++) { + d_[i] -= h; + } + f = f + h; + + // Implicit QL transformation. + + p = d_[m]; + Real c = 1.0; + Real c2 = c; + Real c3 = c; + Real el1 =e_[l+1]; + Real s = 0.0; + Real s2 = 0.0; + for (int i = m-1; i >= l; i--) { + c3 = c2; + c2 = c; + s2 = s; + g = c *e_[i]; + h = c * p; + r = Hypot(p, e_[i]); // This is a Kaldi version of Hypot that works with templates. + e_[i+1] = s * r; + s =e_[i] / r; + c = p / r; + p = c * d_[i] - s * g; + d_[i+1] = h + s * (c * g + s * d_[i]); + + // Accumulate transformation. + + for (int k = 0; k < n_; k++) { + h = V(k, i+1); + V(k, i+1) = s * V(k, i) + c * h; + V(k, i) = c * V(k, i) - s * h; + } + } + p = -s * s2 * c3 * el1 *e_[l] / dl1; + e_[l] = s * p; + d_[l] = c * p; + + // Check for convergence. + + } while (std::abs(e_[l]) > eps*tst1); + } + d_[l] = d_[l] + f; + e_[l] = 0.0; + } + + // Sort eigenvalues and corresponding vectors. + + for (int i = 0; i < n_-1; i++) { + int k = i; + Real p = d_[i]; + for (int j = i+1; j < n_; j++) { + if (d_[j] < p) { + k = j; + p = d_[j]; + } + } + if (k != i) { + d_[k] = d_[i]; + d_[i] = p; + for (int j = 0; j < n_; j++) { + p = V(j, i); + V(j, i) = V(j, k); + V(j, k) = p; + } + } + } +} + +template<typename Real> +void EigenvalueDecomposition<Real>::Orthes() { + + // This is derived from the Algol procedures orthes and ortran, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutines in EISPACK. + + int low = 0; + int high = n_-1; + + for (int m = low+1; m <= high-1; m++) { + + // Scale column. + + Real scale = 0.0; + for (int i = m; i <= high; i++) { + scale = scale + std::abs(H(i, m-1)); + } + if (scale != 0.0) { + + // Compute Householder transformation. + + Real h = 0.0; + for (int i = high; i >= m; i--) { + ort_[i] = H(i, m-1)/scale; + h += ort_[i] * ort_[i]; + } + Real g = std::sqrt(h); + if (ort_[m] > 0) { + g = -g; + } + h = h - ort_[m] * g; + ort_[m] = ort_[m] - g; + + // Apply Householder similarity transformation + // H = (I-u*u'/h)*H*(I-u*u')/h) + + for (int j = m; j < n_; j++) { + Real f = 0.0; + for (int i = high; i >= m; i--) { + f += ort_[i]*H(i, j); + } + f = f/h; + for (int i = m; i <= high; i++) { + H(i, j) -= f*ort_[i]; + } + } + + for (int i = 0; i <= high; i++) { + Real f = 0.0; + for (int j = high; j >= m; j--) { + f += ort_[j]*H(i, j); + } + f = f/h; + for (int j = m; j <= high; j++) { + H(i, j) -= f*ort_[j]; + } + } + ort_[m] = scale*ort_[m]; + H(m, m-1) = scale*g; + } + } + + // Accumulate transformations (Algol's ortran). + + for (int i = 0; i < n_; i++) { + for (int j = 0; j < n_; j++) { + V(i, j) = (i == j ? 1.0 : 0.0); + } + } + + for (int m = high-1; m >= low+1; m--) { + if (H(m, m-1) != 0.0) { + for (int i = m+1; i <= high; i++) { + ort_[i] = H(i, m-1); + } + for (int j = m; j <= high; j++) { + Real g = 0.0; + for (int i = m; i <= high; i++) { + g += ort_[i] * V(i, j); + } + // Double division avoids possible underflow + g = (g / ort_[m]) / H(m, m-1); + for (int i = m; i <= high; i++) { + V(i, j) += g * ort_[i]; + } + } + } + } +} + +template<typename Real> void EigenvalueDecomposition<Real>::Hqr2() { + // This is derived from the Algol procedure hqr2, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + int nn = n_; + int n = nn-1; + int low = 0; + int high = nn-1; + Real eps = std::numeric_limits<Real>::epsilon(); + Real exshift = 0.0; + Real p = 0, q = 0, r = 0, s = 0, z=0, t, w, x, y; + + // Store roots isolated by balanc and compute matrix norm + + Real norm = 0.0; + for (int i = 0; i < nn; i++) { + if (i < low || i > high) { + d_[i] = H(i, i); + e_[i] = 0.0; + } + for (int j = std::max(i-1, 0); j < nn; j++) { + norm = norm + std::abs(H(i, j)); + } + } + + // Outer loop over eigenvalue index + + int iter = 0; + while (n >= low) { + + // Look for single small sub-diagonal element + + int l = n; + while (l > low) { + s = std::abs(H(l-1, l-1)) + std::abs(H(l, l)); + if (s == 0.0) { + s = norm; + } + if (std::abs(H(l, l-1)) < eps * s) { + break; + } + l--; + } + + // Check for convergence + // One root found + + if (l == n) { + H(n, n) = H(n, n) + exshift; + d_[n] = H(n, n); + e_[n] = 0.0; + n--; + iter = 0; + + // Two roots found + + } else if (l == n-1) { + w = H(n, n-1) * H(n-1, n); + p = (H(n-1, n-1) - H(n, n)) / 2.0; + q = p * p + w; + z = std::sqrt(std::abs(q)); + H(n, n) = H(n, n) + exshift; + H(n-1, n-1) = H(n-1, n-1) + exshift; + x = H(n, n); + + // Real pair + + if (q >= 0) { + if (p >= 0) { + z = p + z; + } else { + z = p - z; + } + d_[n-1] = x + z; + d_[n] = d_[n-1]; + if (z != 0.0) { + d_[n] = x - w / z; + } + e_[n-1] = 0.0; + e_[n] = 0.0; + x = H(n, n-1); + s = std::abs(x) + std::abs(z); + p = x / s; + q = z / s; + r = std::sqrt(p * p+q * q); + p = p / r; + q = q / r; + + // Row modification + + for (int j = n-1; j < nn; j++) { + z = H(n-1, j); + H(n-1, j) = q * z + p * H(n, j); + H(n, j) = q * H(n, j) - p * z; + } + + // Column modification + + for (int i = 0; i <= n; i++) { + z = H(i, n-1); + H(i, n-1) = q * z + p * H(i, n); + H(i, n) = q * H(i, n) - p * z; + } + + // Accumulate transformations + + for (int i = low; i <= high; i++) { + z = V(i, n-1); + V(i, n-1) = q * z + p * V(i, n); + V(i, n) = q * V(i, n) - p * z; + } + + // Complex pair + + } else { + d_[n-1] = x + p; + d_[n] = x + p; + e_[n-1] = z; + e_[n] = -z; + } + n = n - 2; + iter = 0; + + // No convergence yet + + } else { + + // Form shift + + x = H(n, n); + y = 0.0; + w = 0.0; + if (l < n) { + y = H(n-1, n-1); + w = H(n, n-1) * H(n-1, n); + } + + // Wilkinson's original ad hoc shift + + if (iter == 10) { + exshift += x; + for (int i = low; i <= n; i++) { + H(i, i) -= x; + } + s = std::abs(H(n, n-1)) + std::abs(H(n-1, n-2)); + x = y = 0.75 * s; + w = -0.4375 * s * s; + } + + // MATLAB's new ad hoc shift + + if (iter == 30) { + s = (y - x) / 2.0; + s = s * s + w; + if (s > 0) { + s = std::sqrt(s); + if (y < x) { + s = -s; + } + s = x - w / ((y - x) / 2.0 + s); + for (int i = low; i <= n; i++) { + H(i, i) -= s; + } + exshift += s; + x = y = w = 0.964; + } + } + + iter = iter + 1; // (Could check iteration count here.) + + // Look for two consecutive small sub-diagonal elements + + int m = n-2; + while (m >= l) { + z = H(m, m); + r = x - z; + s = y - z; + p = (r * s - w) / H(m+1, m) + H(m, m+1); + q = H(m+1, m+1) - z - r - s; + r = H(m+2, m+1); + s = std::abs(p) + std::abs(q) + std::abs(r); + p = p / s; + q = q / s; + r = r / s; + if (m == l) { + break; + } + if (std::abs(H(m, m-1)) * (std::abs(q) + std::abs(r)) < + eps * (std::abs(p) * (std::abs(H(m-1, m-1)) + std::abs(z) + + std::abs(H(m+1, m+1))))) { + break; + } + m--; + } + + for (int i = m+2; i <= n; i++) { + H(i, i-2) = 0.0; + if (i > m+2) { + H(i, i-3) = 0.0; + } + } + + // Double QR step involving rows l:n and columns m:n + + for (int k = m; k <= n-1; k++) { + bool notlast = (k != n-1); + if (k != m) { + p = H(k, k-1); + q = H(k+1, k-1); + r = (notlast ? H(k+2, k-1) : 0.0); + x = std::abs(p) + std::abs(q) + std::abs(r); + if (x != 0.0) { + p = p / x; + q = q / x; + r = r / x; + } + } + if (x == 0.0) { + break; + } + s = std::sqrt(p * p + q * q + r * r); + if (p < 0) { + s = -s; + } + if (s != 0) { + if (k != m) { + H(k, k-1) = -s * x; + } else if (l != m) { + H(k, k-1) = -H(k, k-1); + } + p = p + s; + x = p / s; + y = q / s; + z = r / s; + q = q / p; + r = r / p; + + // Row modification + + for (int j = k; j < nn; j++) { + p = H(k, j) + q * H(k+1, j); + if (notlast) { + p = p + r * H(k+2, j); + H(k+2, j) = H(k+2, j) - p * z; + } + H(k, j) = H(k, j) - p * x; + H(k+1, j) = H(k+1, j) - p * y; + } + + // Column modification + + for (int i = 0; i <= std::min(n, k+3); i++) { + p = x * H(i, k) + y * H(i, k+1); + if (notlast) { + p = p + z * H(i, k+2); + H(i, k+2) = H(i, k+2) - p * r; + } + H(i, k) = H(i, k) - p; + H(i, k+1) = H(i, k+1) - p * q; + } + + // Accumulate transformations + + for (int i = low; i <= high; i++) { + p = x * V(i, k) + y * V(i, k+1); + if (notlast) { + p = p + z * V(i, k+2); + V(i, k+2) = V(i, k+2) - p * r; + } + V(i, k) = V(i, k) - p; + V(i, k+1) = V(i, k+1) - p * q; + } + } // (s != 0) + } // k loop + } // check convergence + } // while (n >= low) + + // Backsubstitute to find vectors of upper triangular form + + if (norm == 0.0) { + return; + } + + for (n = nn-1; n >= 0; n--) { + p = d_[n]; + q = e_[n]; + + // Real vector + + if (q == 0) { + int l = n; + H(n, n) = 1.0; + for (int i = n-1; i >= 0; i--) { + w = H(i, i) - p; + r = 0.0; + for (int j = l; j <= n; j++) { + r = r + H(i, j) * H(j, n); + } + if (e_[i] < 0.0) { + z = w; + s = r; + } else { + l = i; + if (e_[i] == 0.0) { + if (w != 0.0) { + H(i, n) = -r / w; + } else { + H(i, n) = -r / (eps * norm); + } + + // Solve real equations + + } else { + x = H(i, i+1); + y = H(i+1, i); + q = (d_[i] - p) * (d_[i] - p) +e_[i] *e_[i]; + t = (x * s - z * r) / q; + H(i, n) = t; + if (std::abs(x) > std::abs(z)) { + H(i+1, n) = (-r - w * t) / x; + } else { + H(i+1, n) = (-s - y * t) / z; + } + } + + // Overflow control + + t = std::abs(H(i, n)); + if ((eps * t) * t > 1) { + for (int j = i; j <= n; j++) { + H(j, n) = H(j, n) / t; + } + } + } + } + + // Complex vector + + } else if (q < 0) { + int l = n-1; + + // Last vector component imaginary so matrix is triangular + + if (std::abs(H(n, n-1)) > std::abs(H(n-1, n))) { + H(n-1, n-1) = q / H(n, n-1); + H(n-1, n) = -(H(n, n) - p) / H(n, n-1); + } else { + Real cdivr, cdivi; + cdiv(0.0, -H(n-1, n), H(n-1, n-1)-p, q, &cdivr, &cdivi); + H(n-1, n-1) = cdivr; + H(n-1, n) = cdivi; + } + H(n, n-1) = 0.0; + H(n, n) = 1.0; + for (int i = n-2; i >= 0; i--) { + Real ra, sa, vr, vi; + ra = 0.0; + sa = 0.0; + for (int j = l; j <= n; j++) { + ra = ra + H(i, j) * H(j, n-1); + sa = sa + H(i, j) * H(j, n); + } + w = H(i, i) - p; + + if (e_[i] < 0.0) { + z = w; + r = ra; + s = sa; + } else { + l = i; + if (e_[i] == 0) { + Real cdivr, cdivi; + cdiv(-ra, -sa, w, q, &cdivr, &cdivi); + H(i, n-1) = cdivr; + H(i, n) = cdivi; + } else { + Real cdivr, cdivi; + // Solve complex equations + + x = H(i, i+1); + y = H(i+1, i); + vr = (d_[i] - p) * (d_[i] - p) +e_[i] *e_[i] - q * q; + vi = (d_[i] - p) * 2.0 * q; + if (vr == 0.0 && vi == 0.0) { + vr = eps * norm * (std::abs(w) + std::abs(q) + + std::abs(x) + std::abs(y) + std::abs(z)); + } + cdiv(x*r-z*ra+q*sa, x*s-z*sa-q*ra, vr, vi, &cdivr, &cdivi); + H(i, n-1) = cdivr; + H(i, n) = cdivi; + if (std::abs(x) > (std::abs(z) + std::abs(q))) { + H(i+1, n-1) = (-ra - w * H(i, n-1) + q * H(i, n)) / x; + H(i+1, n) = (-sa - w * H(i, n) - q * H(i, n-1)) / x; + } else { + cdiv(-r-y*H(i, n-1), -s-y*H(i, n), z, q, &cdivr, &cdivi); + H(i+1, n-1) = cdivr; + H(i+1, n) = cdivi; + } + } + + // Overflow control + + t = std::max(std::abs(H(i, n-1)), std::abs(H(i, n))); + if ((eps * t) * t > 1) { + for (int j = i; j <= n; j++) { + H(j, n-1) = H(j, n-1) / t; + H(j, n) = H(j, n) / t; + } + } + } + } + } + } + + // Vectors of isolated roots + + for (int i = 0; i < nn; i++) { + if (i < low || i > high) { + for (int j = i; j < nn; j++) { + V(i, j) = H(i, j); + } + } + } + + // Back transformation to get eigenvectors of original matrix + + for (int j = nn-1; j >= low; j--) { + for (int i = low; i <= high; i++) { + z = 0.0; + for (int k = low; k <= std::min(j, high); k++) { + z = z + V(i, k) * H(k, j); + } + V(i, j) = z; + } + } +} + +template<typename Real> +EigenvalueDecomposition<Real>::EigenvalueDecomposition(const MatrixBase<Real> &A) { + KALDI_ASSERT(A.NumCols() == A.NumRows() && A.NumCols() >= 1); + n_ = A.NumRows(); + V_ = new Real[n_*n_]; + d_ = new Real[n_]; + e_ = new Real[n_]; + H_ = NULL; + ort_ = NULL; + if (A.IsSymmetric(0.0)) { + + for (int i = 0; i < n_; i++) + for (int j = 0; j < n_; j++) + V(i, j) = A(i, j); // Note that V(i, j) is a member function; A(i, j) is an operator + // of the matrix A. + // Tridiagonalize. + Tred2(); + + // Diagonalize. + Tql2(); + } else { + H_ = new Real[n_*n_]; + ort_ = new Real[n_]; + for (int i = 0; i < n_; i++) + for (int j = 0; j < n_; j++) + H(i, j) = A(i, j); // as before: H is member function, A(i, j) is operator of matrix. + + // Reduce to Hessenberg form. + Orthes(); + + // Reduce Hessenberg to real Schur form. + Hqr2(); + } +} + +template<typename Real> +EigenvalueDecomposition<Real>::~EigenvalueDecomposition() { + delete [] d_; + delete [] e_; + delete [] V_; + if (H_) delete [] H_; + if (ort_) delete [] ort_; +} + +// see function MatrixBase<Real>::Eig in kaldi-matrix.cc + + +} // namespace kaldi + +#endif // KALDI_MATRIX_JAMA_EIG_H_ diff --git a/kaldi_io/src/kaldi/matrix/jama-svd.h b/kaldi_io/src/kaldi/matrix/jama-svd.h new file mode 100644 index 0000000..8304dac --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/jama-svd.h @@ -0,0 +1,531 @@ +// matrix/jama-svd.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// This file consists of a port and modification of materials from +// JAMA: A Java Matrix Package +// under the following notice: This software is a cooperative product of +// The MathWorks and the National Institute of Standards and Technology (NIST) +// which has been released to the public. This notice and the original code are +// available at http://math.nist.gov/javanumerics/jama/domain.notice + + +#ifndef KALDI_MATRIX_JAMA_SVD_H_ +#define KALDI_MATRIX_JAMA_SVD_H_ 1 + + +#include "matrix/kaldi-matrix.h" +#include "matrix/sp-matrix.h" +#include "matrix/cblas-wrappers.h" + +namespace kaldi { + +#if defined(HAVE_ATLAS) || defined(USE_KALDI_SVD) +// using ATLAS as our math library, which doesn't have SVD -> need +// to implement it. + +// This routine is a modified form of jama_svd.h which is part of the TNT distribution. +// (originally comes from JAMA). + +/** Singular Value Decomposition. + * <P> + * For an m-by-n matrix A with m >= n, the singular value decomposition is + * an m-by-n orthogonal matrix U, an n-by-n diagonal matrix S, and + * an n-by-n orthogonal matrix V so that A = U*S*V'. + * <P> + * The singular values, sigma[k] = S(k, k), are ordered so that + * sigma[0] >= sigma[1] >= ... >= sigma[n-1]. + * <P> + * The singular value decompostion always exists, so the constructor will + * never fail. The matrix condition number and the effective numerical + * rank can be computed from this decomposition. + + * <p> + * (Adapted from JAMA, a Java Matrix Library, developed by jointly + * by the Mathworks and NIST; see http://math.nist.gov/javanumerics/jama). + */ + + +template<typename Real> +bool MatrixBase<Real>::JamaSvd(VectorBase<Real> *s_in, + MatrixBase<Real> *U_in, + MatrixBase<Real> *V_in) { // Destructive! + KALDI_ASSERT(s_in != NULL && U_in != this && V_in != this); + int wantu = (U_in != NULL), wantv = (V_in != NULL); + Matrix<Real> Utmp, Vtmp; + MatrixBase<Real> &U = (U_in ? *U_in : Utmp), &V = (V_in ? *V_in : Vtmp); + VectorBase<Real> &s = *s_in; + + int m = num_rows_, n = num_cols_; + KALDI_ASSERT(m>=n && m != 0 && n != 0); + if (wantu) KALDI_ASSERT((int)U.num_rows_ == m && (int)U.num_cols_ == n); + if (wantv) KALDI_ASSERT((int)V.num_rows_ == n && (int)V.num_cols_ == n); + KALDI_ASSERT((int)s.Dim() == n); // n<=m so n is min. + + int nu = n; + U.SetZero(); // make sure all zero. + Vector<Real> e(n); + Vector<Real> work(m); + MatrixBase<Real> &A(*this); + Real *adata = A.Data(), *workdata = work.Data(), *edata = e.Data(), + *udata = U.Data(), *vdata = V.Data(); + int astride = static_cast<int>(A.Stride()), + ustride = static_cast<int>(U.Stride()), + vstride = static_cast<int>(V.Stride()); + int i = 0, j = 0, k = 0; + + // Reduce A to bidiagonal form, storing the diagonal elements + // in s and the super-diagonal elements in e. + + int nct = std::min(m-1, n); + int nrt = std::max(0, std::min(n-2, m)); + for (k = 0; k < std::max(nct, nrt); k++) { + if (k < nct) { + + // Compute the transformation for the k-th column and + // place the k-th diagonal in s(k). + // Compute 2-norm of k-th column without under/overflow. + s(k) = 0; + for (i = k; i < m; i++) { + s(k) = hypot(s(k), A(i, k)); + } + if (s(k) != 0.0) { + if (A(k, k) < 0.0) { + s(k) = -s(k); + } + for (i = k; i < m; i++) { + A(i, k) /= s(k); + } + A(k, k) += 1.0; + } + s(k) = -s(k); + } + for (j = k+1; j < n; j++) { + if ((k < nct) && (s(k) != 0.0)) { + + // Apply the transformation. + + Real t = cblas_Xdot(m - k, adata + astride*k + k, astride, + adata + astride*k + j, astride); + /*for (i = k; i < m; i++) { + t += adata[i*astride + k]*adata[i*astride + j]; // A(i, k)*A(i, j); // 3 + }*/ + t = -t/A(k, k); + cblas_Xaxpy(m - k, t, adata + k*astride + k, astride, + adata + k*astride + j, astride); + /*for (i = k; i < m; i++) { + adata[i*astride + j] += t*adata[i*astride + k]; // A(i, j) += t*A(i, k); // 5 + }*/ + } + + // Place the k-th row of A into e for the + // subsequent calculation of the row transformation. + + e(j) = A(k, j); + } + if (wantu & (k < nct)) { + + // Place the transformation in U for subsequent back + // multiplication. + + for (i = k; i < m; i++) { + U(i, k) = A(i, k); + } + } + if (k < nrt) { + + // Compute the k-th row transformation and place the + // k-th super-diagonal in e(k). + // Compute 2-norm without under/overflow. + e(k) = 0; + for (i = k+1; i < n; i++) { + e(k) = hypot(e(k), e(i)); + } + if (e(k) != 0.0) { + if (e(k+1) < 0.0) { + e(k) = -e(k); + } + for (i = k+1; i < n; i++) { + e(i) /= e(k); + } + e(k+1) += 1.0; + } + e(k) = -e(k); + if ((k+1 < m) & (e(k) != 0.0)) { + + // Apply the transformation. + + for (i = k+1; i < m; i++) { + work(i) = 0.0; + } + for (j = k+1; j < n; j++) { + for (i = k+1; i < m; i++) { + workdata[i] += edata[j] * adata[i*astride + j]; // work(i) += e(j)*A(i, j); // 5 + } + } + for (j = k+1; j < n; j++) { + Real t(-e(j)/e(k+1)); + cblas_Xaxpy(m - (k+1), t, workdata + (k+1), 1, + adata + (k+1)*astride + j, astride); + /* + for (i = k+1; i < m; i++) { + adata[i*astride + j] += t*workdata[i]; // A(i, j) += t*work(i); // 5 + }*/ + } + } + if (wantv) { + + // Place the transformation in V for subsequent + // back multiplication. + + for (i = k+1; i < n; i++) { + V(i, k) = e(i); + } + } + } + } + + // Set up the final bidiagonal matrix or order p. + + int p = std::min(n, m+1); + if (nct < n) { + s(nct) = A(nct, nct); + } + if (m < p) { + s(p-1) = 0.0; + } + if (nrt+1 < p) { + e(nrt) = A(nrt, p-1); + } + e(p-1) = 0.0; + + // If required, generate U. + + if (wantu) { + for (j = nct; j < nu; j++) { + for (i = 0; i < m; i++) { + U(i, j) = 0.0; + } + U(j, j) = 1.0; + } + for (k = nct-1; k >= 0; k--) { + if (s(k) != 0.0) { + for (j = k+1; j < nu; j++) { + Real t = cblas_Xdot(m - k, udata + k*ustride + k, ustride, udata + k*ustride + j, ustride); + //for (i = k; i < m; i++) { + // t += udata[i*ustride + k]*udata[i*ustride + j]; // t += U(i, k)*U(i, j); // 8 + // } + t = -t/U(k, k); + cblas_Xaxpy(m - k, t, udata + ustride*k + k, ustride, + udata + k*ustride + j, ustride); + /*for (i = k; i < m; i++) { + udata[i*ustride + j] += t*udata[i*ustride + k]; // U(i, j) += t*U(i, k); // 4 + }*/ + } + for (i = k; i < m; i++ ) { + U(i, k) = -U(i, k); + } + U(k, k) = 1.0 + U(k, k); + for (i = 0; i < k-1; i++) { + U(i, k) = 0.0; + } + } else { + for (i = 0; i < m; i++) { + U(i, k) = 0.0; + } + U(k, k) = 1.0; + } + } + } + + // If required, generate V. + + if (wantv) { + for (k = n-1; k >= 0; k--) { + if ((k < nrt) & (e(k) != 0.0)) { + for (j = k+1; j < nu; j++) { + Real t = cblas_Xdot(n - (k+1), vdata + (k+1)*vstride + k, vstride, + vdata + (k+1)*vstride + j, vstride); + /*Real t (0.0); + for (i = k+1; i < n; i++) { + t += vdata[i*vstride + k]*vdata[i*vstride + j]; // t += V(i, k)*V(i, j); // 7 + }*/ + t = -t/V(k+1, k); + cblas_Xaxpy(n - (k+1), t, vdata + (k+1)*vstride + k, vstride, + vdata + (k+1)*vstride + j, vstride); + /*for (i = k+1; i < n; i++) { + vdata[i*vstride + j] += t*vdata[i*vstride + k]; // V(i, j) += t*V(i, k); // 7 + }*/ + } + } + for (i = 0; i < n; i++) { + V(i, k) = 0.0; + } + V(k, k) = 1.0; + } + } + + // Main iteration loop for the singular values. + + int pp = p-1; + int iter = 0; + // note: -52.0 is from Jama code; the -23 is the extension + // to float, because mantissa length in (double, float) + // is (52, 23) bits respectively. + Real eps(pow(2.0, sizeof(Real) == 4 ? -23.0 : -52.0)); + // Note: the -966 was taken from Jama code, but the -120 is a guess + // of how to extend this to float... the exponent in double goes + // from -1022 .. 1023, and in float from -126..127. I'm not sure + // what the significance of 966 is, so -120 just represents a number + // that's a bit less negative than -126. If we get convergence + // failure in float only, this may mean that we have to make the + // -120 value less negative. + Real tiny(pow(2.0, sizeof(Real) == 4 ? -120.0: -966.0 )); + + while (p > 0) { + int k = 0; + int kase = 0; + + if (iter == 500 || iter == 750) { + KALDI_WARN << "Svd taking a long time: making convergence criterion less exact."; + eps = pow(static_cast<Real>(0.8), eps); + tiny = pow(static_cast<Real>(0.8), tiny); + } + if (iter > 1000) { + KALDI_WARN << "Svd not converging on matrix of size " << m << " by " <<n; + return false; + } + + // This section of the program inspects for + // negligible elements in the s and e arrays. On + // completion the variables kase and k are set as follows. + + // kase = 1 if s(p) and e(k-1) are negligible and k < p + // kase = 2 if s(k) is negligible and k < p + // kase = 3 if e(k-1) is negligible, k < p, and + // s(k), ..., s(p) are not negligible (qr step). + // kase = 4 if e(p-1) is negligible (convergence). + + for (k = p-2; k >= -1; k--) { + if (k == -1) { + break; + } + if (std::abs(e(k)) <= + tiny + eps*(std::abs(s(k)) + std::abs(s(k+1)))) { + e(k) = 0.0; + break; + } + } + if (k == p-2) { + kase = 4; + } else { + int ks; + for (ks = p-1; ks >= k; ks--) { + if (ks == k) { + break; + } + Real t( (ks != p ? std::abs(e(ks)) : 0.) + + (ks != k+1 ? std::abs(e(ks-1)) : 0.)); + if (std::abs(s(ks)) <= tiny + eps*t) { + s(ks) = 0.0; + break; + } + } + if (ks == k) { + kase = 3; + } else if (ks == p-1) { + kase = 1; + } else { + kase = 2; + k = ks; + } + } + k++; + + // Perform the task indicated by kase. + + switch (kase) { + + // Deflate negligible s(p). + + case 1: { + Real f(e(p-2)); + e(p-2) = 0.0; + for (j = p-2; j >= k; j--) { + Real t( hypot(s(j), f)); + Real cs(s(j)/t); + Real sn(f/t); + s(j) = t; + if (j != k) { + f = -sn*e(j-1); + e(j-1) = cs*e(j-1); + } + if (wantv) { + for (i = 0; i < n; i++) { + t = cs*V(i, j) + sn*V(i, p-1); + V(i, p-1) = -sn*V(i, j) + cs*V(i, p-1); + V(i, j) = t; + } + } + } + } + break; + + // Split at negligible s(k). + + case 2: { + Real f(e(k-1)); + e(k-1) = 0.0; + for (j = k; j < p; j++) { + Real t(hypot(s(j), f)); + Real cs( s(j)/t); + Real sn(f/t); + s(j) = t; + f = -sn*e(j); + e(j) = cs*e(j); + if (wantu) { + for (i = 0; i < m; i++) { + t = cs*U(i, j) + sn*U(i, k-1); + U(i, k-1) = -sn*U(i, j) + cs*U(i, k-1); + U(i, j) = t; + } + } + } + } + break; + + // Perform one qr step. + + case 3: { + + // Calculate the shift. + + Real scale = std::max(std::max(std::max(std::max( + std::abs(s(p-1)), std::abs(s(p-2))), std::abs(e(p-2))), + std::abs(s(k))), std::abs(e(k))); + Real sp = s(p-1)/scale; + Real spm1 = s(p-2)/scale; + Real epm1 = e(p-2)/scale; + Real sk = s(k)/scale; + Real ek = e(k)/scale; + Real b = ((spm1 + sp)*(spm1 - sp) + epm1*epm1)/2.0; + Real c = (sp*epm1)*(sp*epm1); + Real shift = 0.0; + if ((b != 0.0) || (c != 0.0)) { + shift = std::sqrt(b*b + c); + if (b < 0.0) { + shift = -shift; + } + shift = c/(b + shift); + } + Real f = (sk + sp)*(sk - sp) + shift; + Real g = sk*ek; + + // Chase zeros. + + for (j = k; j < p-1; j++) { + Real t = hypot(f, g); + Real cs = f/t; + Real sn = g/t; + if (j != k) { + e(j-1) = t; + } + f = cs*s(j) + sn*e(j); + e(j) = cs*e(j) - sn*s(j); + g = sn*s(j+1); + s(j+1) = cs*s(j+1); + if (wantv) { + cblas_Xrot(n, vdata + j, vstride, vdata + j+1, vstride, cs, sn); + /*for (i = 0; i < n; i++) { + t = cs*vdata[i*vstride + j] + sn*vdata[i*vstride + j+1]; // t = cs*V(i, j) + sn*V(i, j+1); // 13 + vdata[i*vstride + j+1] = -sn*vdata[i*vstride + j] + cs*vdata[i*vstride + j+1]; // V(i, j+1) = -sn*V(i, j) + cs*V(i, j+1); // 5 + vdata[i*vstride + j] = t; // V(i, j) = t; // 4 + }*/ + } + t = hypot(f, g); + cs = f/t; + sn = g/t; + s(j) = t; + f = cs*e(j) + sn*s(j+1); + s(j+1) = -sn*e(j) + cs*s(j+1); + g = sn*e(j+1); + e(j+1) = cs*e(j+1); + if (wantu && (j < m-1)) { + cblas_Xrot(m, udata + j, ustride, udata + j+1, ustride, cs, sn); + /*for (i = 0; i < m; i++) { + t = cs*udata[i*ustride + j] + sn*udata[i*ustride + j+1]; // t = cs*U(i, j) + sn*U(i, j+1); // 7 + udata[i*ustride + j+1] = -sn*udata[i*ustride + j] +cs*udata[i*ustride + j+1]; // U(i, j+1) = -sn*U(i, j) + cs*U(i, j+1); // 8 + udata[i*ustride + j] = t; // U(i, j) = t; // 1 + }*/ + } + } + e(p-2) = f; + iter = iter + 1; + } + break; + + // Convergence. + + case 4: { + + // Make the singular values positive. + + if (s(k) <= 0.0) { + s(k) = (s(k) < 0.0 ? -s(k) : 0.0); + if (wantv) { + for (i = 0; i <= pp; i++) { + V(i, k) = -V(i, k); + } + } + } + + // Order the singular values. + + while (k < pp) { + if (s(k) >= s(k+1)) { + break; + } + Real t = s(k); + s(k) = s(k+1); + s(k+1) = t; + if (wantv && (k < n-1)) { + for (i = 0; i < n; i++) { + t = V(i, k+1); V(i, k+1) = V(i, k); V(i, k) = t; + } + } + if (wantu && (k < m-1)) { + for (i = 0; i < m; i++) { + t = U(i, k+1); U(i, k+1) = U(i, k); U(i, k) = t; + } + } + k++; + } + iter = 0; + p--; + } + break; + } + } + return true; +} + +#endif // defined(HAVE_ATLAS) || defined(USE_KALDI_SVD) + +} // namespace kaldi + +#endif // KALDI_MATRIX_JAMA_SVD_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-blas.h b/kaldi_io/src/kaldi/matrix/kaldi-blas.h new file mode 100644 index 0000000..5d25ab8 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-blas.h @@ -0,0 +1,132 @@ +// matrix/kaldi-blas.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_KALDI_BLAS_H_ +#define KALDI_MATRIX_KALDI_BLAS_H_ + +// This file handles the #includes for BLAS, LAPACK and so on. +// It manipulates the declarations into a common format that kaldi can handle. +// However, the kaldi code will check whether HAVE_ATLAS is defined as that +// code is called a bit differently from CLAPACK that comes from other sources. + +// There are three alternatives: +// (i) you have ATLAS, which includes the ATLAS implementation of CBLAS +// plus a subset of CLAPACK (but with clapack_ in the function declarations). +// In this case, define HAVE_ATLAS and make sure the relevant directories are +// in the include path. + +// (ii) you have CBLAS (some implementation thereof) plus CLAPACK. +// In this case, define HAVE_CLAPACK. +// [Since CLAPACK depends on BLAS, the presence of BLAS is implicit]. + +// (iii) you have the MKL library, which includes CLAPACK and CBLAS. + +// Note that if we are using ATLAS, no Svd implementation is supplied, +// so we define HAVE_Svd to be zero and this directs our implementation to +// supply its own "by hand" implementation which is based on TNT code. + + + + +#if (defined(HAVE_CLAPACK) && (defined(HAVE_ATLAS) || defined(HAVE_MKL))) \ + || (defined(HAVE_ATLAS) && defined(HAVE_MKL)) +#error "Do not define more than one of HAVE_CLAPACK, HAVE_ATLAS and HAVE_MKL" +#endif + +#ifdef HAVE_ATLAS + extern "C" { + #include <cblas.h> + #include <clapack.h> + } +#elif defined(HAVE_CLAPACK) + #ifdef __APPLE__ + #ifndef __has_extension + #define __has_extension(x) 0 + #endif + #define vImage_Utilities_h + #define vImage_CVUtilities_h + #include <Accelerate/Accelerate.h> + typedef __CLPK_integer integer; + typedef __CLPK_logical logical; + typedef __CLPK_real real; + typedef __CLPK_doublereal doublereal; + typedef __CLPK_complex complex; + typedef __CLPK_doublecomplex doublecomplex; + typedef __CLPK_ftnlen ftnlen; + #else + extern "C" { + // May be in /usr/[local]/include if installed; else this uses the one + // from the tools/CLAPACK_include directory. + #include <cblas.h> + #include <f2c.h> + #include <clapack.h> + + // get rid of macros from f2c.h -- these are dangerous. + #undef abs + #undef dabs + #undef min + #undef max + #undef dmin + #undef dmax + #undef bit_test + #undef bit_clear + #undef bit_set + } + #endif +#elif defined(HAVE_MKL) + extern "C" { + #include <mkl.h> + } +#elif defined(HAVE_OPENBLAS) + // getting cblas.h and lapacke.h from <openblas-install-dir>/. + // putting in "" not <> to search -I before system libraries. + #include "cblas.h" + #include "lapacke.h" + #undef I + #undef complex + // get rid of macros from f2c.h -- these are dangerous. + #undef abs + #undef dabs + #undef min + #undef max + #undef dmin + #undef dmax + #undef bit_test + #undef bit_clear + #undef bit_set +#else + #error "You need to define (using the preprocessor) either HAVE_CLAPACK or HAVE_ATLAS or HAVE_MKL (but not more than one)" +#endif + +#ifdef HAVE_OPENBLAS +typedef int KaldiBlasInt; // try int. +#endif +#ifdef HAVE_CLAPACK +typedef integer KaldiBlasInt; +#endif +#ifdef HAVE_MKL +typedef MKL_INT KaldiBlasInt; +#endif + +#ifdef HAVE_ATLAS +// in this case there is no need for KaldiBlasInt-- this typedef is only needed +// for Svd code which is not included in ATLAS (we re-implement it). +#endif + + +#endif // KALDI_MATRIX_KALDI_BLAS_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-gpsr.h b/kaldi_io/src/kaldi/matrix/kaldi-gpsr.h new file mode 100644 index 0000000..c294bdd --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-gpsr.h @@ -0,0 +1,166 @@ +// matrix/kaldi-gpsr.h + +// Copyright 2012 Arnab Ghoshal + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_GPSR_H_ +#define KALDI_MATRIX_KALDI_GPSR_H_ + +#include <string> +#include <vector> + +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" +#include "itf/options-itf.h" + +namespace kaldi { + +/// This is an implementation of the GPSR algorithm. See, Figueiredo, Nowak and +/// Wright, "Gradient Projection for Sparse Reconstruction: Application to +/// Compressed Sensing and Other Inverse Problems," IEEE Journal of Selected +/// Topics in Signal Processing, vol. 1, no. 4, pp. 586-597, 2007. +/// http://dx.doi.org/10.1109/JSTSP.2007.910281 + +/// The GPSR algorithm, described in Figueiredo, et al., 2007, solves: +/// \f[ \min_x 0.5 * ||y - Ax||_2^2 + \tau ||x||_1, \f] +/// where \f$ x \in R^n, y \in R^k \f$, and \f$ A \in R^{n \times k} \f$. +/// In this implementation, we solve: +/// \f[ \min_x 0.5 * x^T H x - g^T x + \tau ||x||_1, \f] +/// which is the more natural form in which such problems arise in our case. +/// Here, \f$ H = A^T A \in R^{n \times n} \f$ and \f$ g = A^T y \in R^n \f$. + + +/** \struct GpsrConfig + * Configuration variables needed in the GPSR algorithm. + */ +struct GpsrConfig { + bool use_gpsr_bb; ///< Use the Barzilai-Borwein gradient projection method + + /// The following options are common to both the basic & Barzilai-Borwein + /// versions of GPSR + double stop_thresh; ///< Stopping threshold + int32 max_iters; ///< Maximum number of iterations + double gpsr_tau; ///< Regularization scale + double alpha_min; ///< Minimum step size in the feasible direction + double alpha_max; ///< Maximum step size in the feasible direction + double max_sparsity; ///< Maximum percentage of dimensions set to 0 + double tau_reduction; ///< Multiply tau by this if max_sparsity reached + + /// The following options are for the backtracking line search in basic GPSR. + /// Step size reduction factor in backtracking line search. 0 < beta < 1 + double gpsr_beta; + /// Improvement factor in backtracking line search, i.e. the new objective + /// function must be less than the old one by mu times the gradient in the + /// direction of the change in x. 0 < mu < 1 + double gpsr_mu; + int32 max_iters_backtrak; ///< Max iterations for backtracking line search + + bool debias; ///< Do debiasing, i.e. unconstrained optimization at the end + double stop_thresh_debias; ///< Stopping threshold for debiasing stage + int32 max_iters_debias; ///< Maximum number of iterations for debiasing stage + + GpsrConfig() { + use_gpsr_bb = true; + + stop_thresh = 0.005; + max_iters = 100; + gpsr_tau = 10; + alpha_min = 1.0e-10; + alpha_max = 1.0e+20; + max_sparsity = 0.9; + tau_reduction = 0.8; + + gpsr_beta = 0.5; + gpsr_mu = 0.1; + max_iters_backtrak = 50; + + debias = false; + stop_thresh_debias = 0.001; + max_iters_debias = 50; + } + + void Register(OptionsItf *po); +}; + +inline void GpsrConfig::Register(OptionsItf *po) { + std::string module = "GpsrConfig: "; + po->Register("use-gpsr-bb", &use_gpsr_bb, module+ + "Use the Barzilai-Borwein gradient projection method."); + + po->Register("stop-thresh", &stop_thresh, module+ + "Stopping threshold for GPSR."); + po->Register("max-iters", &max_iters, module+ + "Maximum number of iterations of GPSR."); + po->Register("gpsr-tau", &gpsr_tau, module+ + "Regularization scale for GPSR."); + po->Register("alpha-min", &alpha_min, module+ + "Minimum step size in feasible direction."); + po->Register("alpha-max", &alpha_max, module+ + "Maximum step size in feasible direction."); + po->Register("max-sparsity", &max_sparsity, module+ + "Maximum percentage of dimensions set to 0."); + po->Register("tau-reduction", &tau_reduction, module+ + "Multiply tau by this if maximum sparsity is reached."); + + po->Register("gpsr-beta", &gpsr_beta, module+ + "Step size reduction factor in backtracking line search (0<beta<1)."); + po->Register("gpsr-mu", &gpsr_mu, module+ + "Improvement factor in backtracking line search (0<mu<1)."); + po->Register("max-iters-backtrack", &max_iters_backtrak, module+ + "Maximum number of iterations of backtracking line search."); + + po->Register("debias", &debias, module+ + "Do final debiasing step."); + po->Register("stop-thresh-debias", &stop_thresh_debias, module+ + "Stopping threshold for debiaisng step."); + po->Register("max-iters-debias", &max_iters_debias, module+ + "Maximum number of iterations of debiasing."); +} + +/// Solves a quadratic program in \f$ x \f$, with L_1 regularization: +/// \f[ \min_x 0.5 * x^T H x - g^T x + \tau ||x||_1. \f] +/// This is similar to SolveQuadraticProblem() in sp-matrix.h with an added +/// L_1 term. +template<typename Real> +Real Gpsr(const GpsrConfig &opts, const SpMatrix<Real> &H, + const Vector<Real> &g, Vector<Real> *x, + const char *debug_str = "[unknown]") { + if (opts.use_gpsr_bb) + return GpsrBB(opts, H, g, x, debug_str); + else + return GpsrBasic(opts, H, g, x, debug_str); +} + +/// This is the basic GPSR algorithm, where the step size is determined by a +/// backtracking line search. The line search is called "Armijo rule along the +/// projection arc" in Bertsekas, Nonlinear Programming, 2nd ed. page 230. +template<typename Real> +Real GpsrBasic(const GpsrConfig &opts, const SpMatrix<Real> &H, + const Vector<Real> &g, Vector<Real> *x, + const char *debug_str = "[unknown]"); + +/// This is the paper calls the Barzilai-Borwein variant. This is a constrained +/// Netwon's method where the Hessian is approximated by scaled identity matrix +template<typename Real> +Real GpsrBB(const GpsrConfig &opts, const SpMatrix<Real> &H, + const Vector<Real> &g, Vector<Real> *x, + const char *debug_str = "[unknown]"); + + +} // namespace kaldi + +#endif // KALDI_MATRIX_KALDI_GPSR_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-matrix-inl.h b/kaldi_io/src/kaldi/matrix/kaldi-matrix-inl.h new file mode 100644 index 0000000..8bc4749 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-matrix-inl.h @@ -0,0 +1,62 @@ +// matrix/kaldi-matrix-inl.h + +// Copyright 2009-2011 Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_MATRIX_INL_H_ +#define KALDI_MATRIX_KALDI_MATRIX_INL_H_ 1 + +#include "matrix/kaldi-vector.h" + +namespace kaldi { + +/// Empty constructor +template<typename Real> +Matrix<Real>::Matrix(): MatrixBase<Real>(NULL, 0, 0, 0) { } + + +template<> +template<> +void MatrixBase<float>::AddVecVec(const float alpha, const VectorBase<float> &ra, const VectorBase<float> &rb); + +template<> +template<> +void MatrixBase<double>::AddVecVec(const double alpha, const VectorBase<double> &ra, const VectorBase<double> &rb); + +template<typename Real> +inline std::ostream & operator << (std::ostream & os, const MatrixBase<Real> & M) { + M.Write(os, false); + return os; +} + +template<typename Real> +inline std::istream & operator >> (std::istream & is, Matrix<Real> & M) { + M.Read(is, false); + return is; +} + + +template<typename Real> +inline std::istream & operator >> (std::istream & is, MatrixBase<Real> & M) { + M.Read(is, false); + return is; +} + +}// namespace kaldi + + +#endif // KALDI_MATRIX_KALDI_MATRIX_INL_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-matrix.h b/kaldi_io/src/kaldi/matrix/kaldi-matrix.h new file mode 100644 index 0000000..e6829e0 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-matrix.h @@ -0,0 +1,983 @@ +// matrix/kaldi-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University; Petr Schwarz; Yanmin Qian; +// Karel Vesely; Go Vivace Inc.; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_MATRIX_H_ +#define KALDI_MATRIX_KALDI_MATRIX_H_ 1 + +#include "matrix/matrix-common.h" + +namespace kaldi { + +/// @{ \addtogroup matrix_funcs_scalar + +/// We need to declare this here as it will be a friend function. +/// tr(A B), or tr(A B^T). +template<typename Real> +Real TraceMatMat(const MatrixBase<Real> &A, const MatrixBase<Real> &B, + MatrixTransposeType trans = kNoTrans); +/// @} + +/// \addtogroup matrix_group +/// @{ + +/// Base class which provides matrix operations not involving resizing +/// or allocation. Classes Matrix and SubMatrix inherit from it and take care +/// of allocation and resizing. +template<typename Real> +class MatrixBase { + public: + // so this child can access protected members of other instances. + friend class Matrix<Real>; + // friend declarations for CUDA matrices (see ../cudamatrix/) + friend class CuMatrixBase<Real>; + friend class CuMatrix<Real>; + friend class CuSubMatrix<Real>; + friend class CuPackedMatrix<Real>; + + friend class PackedMatrix<Real>; + + /// Returns number of rows (or zero for emtpy matrix). + inline MatrixIndexT NumRows() const { return num_rows_; } + + /// Returns number of columns (or zero for emtpy matrix). + inline MatrixIndexT NumCols() const { return num_cols_; } + + /// Stride (distance in memory between each row). Will be >= NumCols. + inline MatrixIndexT Stride() const { return stride_; } + + /// Returns size in bytes of the data held by the matrix. + size_t SizeInBytes() const { + return static_cast<size_t>(num_rows_) * static_cast<size_t>(stride_) * + sizeof(Real); + } + + /// Gives pointer to raw data (const). + inline const Real* Data() const { + return data_; + } + + /// Gives pointer to raw data (non-const). + inline Real* Data() { return data_; } + + /// Returns pointer to data for one row (non-const) + inline Real* RowData(MatrixIndexT i) { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(num_rows_)); + return data_ + i * stride_; + } + + /// Returns pointer to data for one row (const) + inline const Real* RowData(MatrixIndexT i) const { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(num_rows_)); + return data_ + i * stride_; + } + + /// Indexing operator, non-const + /// (only checks sizes if compiled with -DKALDI_PARANOID) + inline Real& operator() (MatrixIndexT r, MatrixIndexT c) { + KALDI_PARANOID_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(num_rows_) && + static_cast<UnsignedMatrixIndexT>(c) < + static_cast<UnsignedMatrixIndexT>(num_cols_)); + return *(data_ + r * stride_ + c); + } + /// Indexing operator, provided for ease of debugging (gdb doesn't work + /// with parenthesis operator). + Real &Index (MatrixIndexT r, MatrixIndexT c) { return (*this)(r, c); } + + /// Indexing operator, const + /// (only checks sizes if compiled with -DKALDI_PARANOID) + inline const Real operator() (MatrixIndexT r, MatrixIndexT c) const { + KALDI_PARANOID_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(num_rows_) && + static_cast<UnsignedMatrixIndexT>(c) < + static_cast<UnsignedMatrixIndexT>(num_cols_)); + return *(data_ + r * stride_ + c); + } + + /* Basic setting-to-special values functions. */ + + /// Sets matrix to zero. + void SetZero(); + /// Sets all elements to a specific value. + void Set(Real); + /// Sets to zero, except ones along diagonal [for non-square matrices too] + void SetUnit(); + /// Sets to random values of a normal distribution + void SetRandn(); + /// Sets to numbers uniformly distributed on (0, 1) + void SetRandUniform(); + + /* Copying functions. These do not resize the matrix! */ + + + /// Copy given matrix. (no resize is done). + template<typename OtherReal> + void CopyFromMat(const MatrixBase<OtherReal> & M, + MatrixTransposeType trans = kNoTrans); + + /// Copy from compressed matrix. + void CopyFromMat(const CompressedMatrix &M); + + /// Copy given spmatrix. (no resize is done). + template<typename OtherReal> + void CopyFromSp(const SpMatrix<OtherReal> &M); + + /// Copy given tpmatrix. (no resize is done). + template<typename OtherReal> + void CopyFromTp(const TpMatrix<OtherReal> &M, + MatrixTransposeType trans = kNoTrans); + + /// Copy from CUDA matrix. Implemented in ../cudamatrix/cu-matrix.h + template<typename OtherReal> + void CopyFromMat(const CuMatrixBase<OtherReal> &M, + MatrixTransposeType trans = kNoTrans); + + /// Inverse of vec() operator. Copies vector into matrix, row-by-row. + /// Note that rv.Dim() must either equal NumRows()*NumCols() or + /// NumCols()-- this has two modes of operation. + void CopyRowsFromVec(const VectorBase<Real> &v); + + /// This version of CopyRowsFromVec is implemented in ../cudamatrix/cu-vector.cc + void CopyRowsFromVec(const CuVectorBase<Real> &v); + + template<typename OtherReal> + void CopyRowsFromVec(const VectorBase<OtherReal> &v); + + /// Copies vector into matrix, column-by-column. + /// Note that rv.Dim() must either equal NumRows()*NumCols() or NumRows(); + /// this has two modes of operation. + void CopyColsFromVec(const VectorBase<Real> &v); + + /// Copy vector into specific column of matrix. + void CopyColFromVec(const VectorBase<Real> &v, const MatrixIndexT col); + /// Copy vector into specific row of matrix. + void CopyRowFromVec(const VectorBase<Real> &v, const MatrixIndexT row); + /// Copy vector into diagonal of matrix. + void CopyDiagFromVec(const VectorBase<Real> &v); + + /* Accessing of sub-parts of the matrix. */ + + /// Return specific row of matrix [const]. + inline const SubVector<Real> Row(MatrixIndexT i) const { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(num_rows_)); + return SubVector<Real>(data_ + (i * stride_), NumCols()); + } + + /// Return specific row of matrix. + inline SubVector<Real> Row(MatrixIndexT i) { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(num_rows_)); + return SubVector<Real>(data_ + (i * stride_), NumCols()); + } + + /// Return a sub-part of matrix. + inline SubMatrix<Real> Range(const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix<Real>(*this, row_offset, num_rows, + col_offset, num_cols); + } + inline SubMatrix<Real> RowRange(const MatrixIndexT row_offset, + const MatrixIndexT num_rows) const { + return SubMatrix<Real>(*this, row_offset, num_rows, 0, num_cols_); + } + inline SubMatrix<Real> ColRange(const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix<Real>(*this, 0, num_rows_, col_offset, num_cols); + } + + /* Various special functions. */ + /// Returns sum of all elements in matrix. + Real Sum() const; + /// Returns trace of matrix. + Real Trace(bool check_square = true) const; + // If check_square = true, will crash if matrix is not square. + + /// Returns maximum element of matrix. + Real Max() const; + /// Returns minimum element of matrix. + Real Min() const; + + /// Element by element multiplication with a given matrix. + void MulElements(const MatrixBase<Real> &A); + + /// Divide each element by the corresponding element of a given matrix. + void DivElements(const MatrixBase<Real> &A); + + /// Multiply each element with a scalar value. + void Scale(Real alpha); + + /// Set, element-by-element, *this = max(*this, A) + void Max(const MatrixBase<Real> &A); + + /// Equivalent to (*this) = (*this) * diag(scale). Scaling + /// each column by a scalar taken from that dimension of the vector. + void MulColsVec(const VectorBase<Real> &scale); + + /// Equivalent to (*this) = diag(scale) * (*this). Scaling + /// each row by a scalar taken from that dimension of the vector. + void MulRowsVec(const VectorBase<Real> &scale); + + /// Divide each row into src.NumCols() equal groups, and then scale i'th row's + /// j'th group of elements by src(i, j). Requires src.NumRows() == + /// this->NumRows() and this->NumCols() % src.NumCols() == 0. + void MulRowsGroupMat(const MatrixBase<Real> &src); + + /// Returns logdet of matrix. + Real LogDet(Real *det_sign = NULL) const; + + /// matrix inverse. + /// if inverse_needed = false, will fill matrix with garbage. + /// (only useful if logdet wanted). + void Invert(Real *log_det = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + /// matrix inverse [double]. + /// if inverse_needed = false, will fill matrix with garbage + /// (only useful if logdet wanted). + /// Does inversion in double precision even if matrix was not double. + void InvertDouble(Real *LogDet = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + + /// Inverts all the elements of the matrix + void InvertElements(); + + /// Transpose the matrix. This one is only + /// applicable to square matrices (the one in the + /// Matrix child class works also for non-square. + void Transpose(); + + /// Copies column r from column indices[r] of src. + /// As a special case, if indexes[i] == -1, sets column i to zero + /// indices.size() must equal this->NumCols(), + /// all elements of "reorder" must be in [-1, src.NumCols()-1], + /// and src.NumRows() must equal this.NumRows() + void CopyCols(const MatrixBase<Real> &src, + const std::vector<MatrixIndexT> &indices); + + /// Copies row r from row indices[r] of src. + /// As a special case, if indexes[i] == -1, sets row i to zero + /// "reorder".size() must equal this->NumRows(), + /// all elements of "reorder" must be in [-1, src.NumRows()-1], + /// and src.NumCols() must equal this.NumCols() + void CopyRows(const MatrixBase<Real> &src, + const std::vector<MatrixIndexT> &indices); + + /// Applies floor to all matrix elements + void ApplyFloor(Real floor_val); + + /// Applies floor to all matrix elements + void ApplyCeiling(Real ceiling_val); + + /// Calculates log of all the matrix elemnts + void ApplyLog(); + + /// Exponentiate each of the elements. + void ApplyExp(); + + /// Applies power to all matrix elements + void ApplyPow(Real power); + + /// Apply power to the absolute value of each element. + /// Include the sign of the input element if include_sign == true. + /// If the power is negative and the input to the power is zero, + /// The output will be set zero. + void ApplyPowAbs(Real power, bool include_sign=false); + + /// Applies the Heaviside step function (x > 0 ? 1 : 0) to all matrix elements + /// Note: in general you can make different choices for x = 0, but for now + /// please leave it as it (i.e. returning zero) because it affects the + /// RectifiedLinearComponent in the neural net code. + void ApplyHeaviside(); + + /// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D + /// P^{-1}. Be careful: the relationship of D to the eigenvalues we output is + /// slightly complicated, due to the need for P to be real. In the symmetric + /// case D is diagonal and real, but in + /// the non-symmetric case there may be complex-conjugate pairs of eigenvalues. + /// In this case, for the equation (*this) = P D P^{-1} to hold, D must actually + /// be block diagonal, with 2x2 blocks corresponding to any such pairs. If a + /// pair is lambda +- i*mu, D will have a corresponding 2x2 block + /// [lambda, mu; -mu, lambda]. + /// Note that if the input matrix (*this) is non-invertible, P may not be invertible + /// so in this case instead of the equation (*this) = P D P^{-1} holding, we have + /// instead (*this) P = P D. + /// + /// The non-member function CreateEigenvalueMatrix creates D from eigs_real and eigs_imag. + void Eig(MatrixBase<Real> *P, + VectorBase<Real> *eigs_real, + VectorBase<Real> *eigs_imag) const; + + /// The Power method attempts to take the matrix to a power using a method that + /// works in general for fractional and negative powers. The input matrix must + /// be invertible and have reasonable condition (or we don't guarantee the + /// results. The method is based on the eigenvalue decomposition. It will + /// return false and leave the matrix unchanged, if at entry the matrix had + /// real negative eigenvalues (or if it had zero eigenvalues and the power was + /// negative). + bool Power(Real pow); + + /** Singular value decomposition + Major limitations: + For nonsquare matrices, we assume m>=n (NumRows >= NumCols), and we return + the "skinny" Svd, i.e. the matrix in the middle is diagonal, and the + one on the left is rectangular. + + In Svd, *this = U*diag(S)*Vt. + Null pointers for U and/or Vt at input mean we do not want that output. We + expect that S.Dim() == m, U is either NULL or m by n, + and v is either NULL or n by n. + The singular values are not sorted (use SortSvd for that). */ + void DestructiveSvd(VectorBase<Real> *s, MatrixBase<Real> *U, + MatrixBase<Real> *Vt); // Destroys calling matrix. + + /// Compute SVD (*this) = U diag(s) Vt. Note that the V in the call is already + /// transposed; the normal formulation is U diag(s) V^T. + /// Null pointers for U or V mean we don't want that output (this saves + /// compute). The singular values are not sorted (use SortSvd for that). + void Svd(VectorBase<Real> *s, MatrixBase<Real> *U, + MatrixBase<Real> *Vt) const; + /// Compute SVD but only retain the singular values. + void Svd(VectorBase<Real> *s) const { Svd(s, NULL, NULL); } + + + /// Returns smallest singular value. + Real MinSingularValue() const { + Vector<Real> tmp(std::min(NumRows(), NumCols())); + Svd(&tmp); + return tmp.Min(); + } + + void TestUninitialized() const; // This function is designed so that if any element + // if the matrix is uninitialized memory, valgrind will complain. + + /// Returns condition number by computing Svd. Works even if cols > rows. + /// Returns infinity if all singular values are zero. + Real Cond() const; + + /// Returns true if matrix is Symmetric. + bool IsSymmetric(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is Diagonal. + bool IsDiagonal(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if the matrix is all zeros, except for ones on diagonal. (it + /// does not have to be square). More specifically, this function returns + /// false if for any i, j, (*this)(i, j) differs by more than cutoff from the + /// expression (i == j ? 1 : 0). + bool IsUnit(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-05) const; // replace magic number + + /// Frobenius norm, which is the sqrt of sum of square elements. Same as Schatten 2-norm, + /// or just "2-norm". + Real FrobeniusNorm() const; + + /// Returns true if ((*this)-other).FrobeniusNorm() + /// <= tol * (*this).FrobeniusNorm(). + bool ApproxEqual(const MatrixBase<Real> &other, float tol = 0.01) const; + + /// Tests for exact equality. It's usually preferable to use ApproxEqual. + bool Equal(const MatrixBase<Real> &other) const; + + /// largest absolute value. + Real LargestAbsElem() const; // largest absolute value. + + /// Returns log(sum(exp())) without exp overflow + /// If prune > 0.0, it uses a pruning beam, discarding + /// terms less than (max - prune). Note: in future + /// we may change this so that if prune = 0.0, it takes + /// the max, so use -1 if you don't want to prune. + Real LogSumExp(Real prune = -1.0) const; + + /// Apply soft-max to the collection of all elements of the + /// matrix and return normalizer (log sum of exponentials). + Real ApplySoftMax(); + + /// Set each element to the sigmoid of the corresponding element of "src". + void Sigmoid(const MatrixBase<Real> &src); + + /// Set each element to y = log(1 + exp(x)) + void SoftHinge(const MatrixBase<Real> &src); + + /// Apply the function y(i) = (sum_{j = i*G}^{(i+1)*G-1} x_j^(power))^(1 / p). + /// Requires src.NumRows() == this->NumRows() and src.NumCols() % this->NumCols() == 0. + void GroupPnorm(const MatrixBase<Real> &src, Real power); + + + /// Calculate derivatives for the GroupPnorm function above... + /// if "input" is the input to the GroupPnorm function above (i.e. the "src" variable), + /// and "output" is the result of the computation (i.e. the "this" of that function + /// call), and *this has the same dimension as "input", then it sets each element + /// of *this to the derivative d(output-elem)/d(input-elem) for each element of "input", where + /// "output-elem" is whichever element of output depends on that input element. + void GroupPnormDeriv(const MatrixBase<Real> &input, const MatrixBase<Real> &output, + Real power); + + + /// Set each element to the tanh of the corresponding element of "src". + void Tanh(const MatrixBase<Real> &src); + + // Function used in backpropagating derivatives of the sigmoid function: + // element-by-element, set *this = diff * value * (1.0 - value). + void DiffSigmoid(const MatrixBase<Real> &value, + const MatrixBase<Real> &diff); + + // Function used in backpropagating derivatives of the tanh function: + // element-by-element, set *this = diff * (1.0 - value^2). + void DiffTanh(const MatrixBase<Real> &value, + const MatrixBase<Real> &diff); + + /** Uses Svd to compute the eigenvalue decomposition of a symmetric positive + * semi-definite matrix: (*this) = rP * diag(rS) * rP^T, with rP an + * orthogonal matrix so rP^{-1} = rP^T. Throws exception if input was not + * positive semi-definite (check_thresh controls how stringent the check is; + * set it to 2 to ensure it won't ever complain, but it will zero out negative + * dimensions in your matrix. + */ + void SymPosSemiDefEig(VectorBase<Real> *s, MatrixBase<Real> *P, + Real check_thresh = 0.001); + + friend Real kaldi::TraceMatMat<Real>(const MatrixBase<Real> &A, + const MatrixBase<Real> &B, MatrixTransposeType trans); // tr (A B) + + // so it can get around const restrictions on the pointer to data_. + friend class SubMatrix<Real>; + + /// Add a scalar to each element + void Add(const Real alpha); + + /// Add a scalar to each diagonal element. + void AddToDiag(const Real alpha); + + /// *this += alpha * a * b^T + template<typename OtherReal> + void AddVecVec(const Real alpha, const VectorBase<OtherReal> &a, + const VectorBase<OtherReal> &b); + + /// [each row of *this] += alpha * v + template<typename OtherReal> + void AddVecToRows(const Real alpha, const VectorBase<OtherReal> &v); + + /// [each col of *this] += alpha * v + template<typename OtherReal> + void AddVecToCols(const Real alpha, const VectorBase<OtherReal> &v); + + /// *this += alpha * M [or M^T] + void AddMat(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transA = kNoTrans); + + /// *this = beta * *this + alpha * M M^T, for symmetric matrices. It only + /// updates the lower triangle of *this. It will leave the matrix asymmetric; + /// if you need it symmetric as a regular matrix, do CopyLowerToUpper(). + void SymAddMat2(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transA, Real beta); + + /// *this = beta * *this + alpha * diag(v) * M [or M^T]. + /// The same as adding M but scaling each row M_i by v(i). + void AddDiagVecMat(const Real alpha, VectorBase<Real> &v, + const MatrixBase<Real> &M, MatrixTransposeType transM, + Real beta = 1.0); + + /// *this = beta * *this + alpha * M [or M^T] * diag(v) + /// The same as adding M but scaling each column M_j by v(j). + void AddMatDiagVec(const Real alpha, + const MatrixBase<Real> &M, MatrixTransposeType transM, + VectorBase<Real> &v, + Real beta = 1.0); + + /// *this = beta * *this + alpha * A .* B (.* element by element multiplication) + void AddMatMatElements(const Real alpha, + const MatrixBase<Real>& A, + const MatrixBase<Real>& B, + const Real beta); + + /// *this += alpha * S + template<typename OtherReal> + void AddSp(const Real alpha, const SpMatrix<OtherReal> &S); + + void AddMatMat(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const Real beta); + + /// *this = a * b / c (by element; when c = 0, *this = a) + void AddMatMatDivMat(const MatrixBase<Real>& A, + const MatrixBase<Real>& B, + const MatrixBase<Real>& C); + + /// A version of AddMatMat specialized for when the second argument + /// contains a lot of zeroes. + void AddMatSmat(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const Real beta); + + /// A version of AddMatMat specialized for when the first argument + /// contains a lot of zeroes. + void AddSmatMat(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const Real beta); + + /// this <-- beta*this + alpha*A*B*C. + void AddMatMatMat(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const MatrixBase<Real>& C, MatrixTransposeType transC, + const Real beta); + + /// this <-- beta*this + alpha*SpA*B. + // This and the routines below are really + // stubs that need to be made more efficient. + void AddSpMat(const Real alpha, + const SpMatrix<Real>& A, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const Real beta) { + Matrix<Real> M(A); + return AddMatMat(alpha, M, kNoTrans, B, transB, beta); + } + /// this <-- beta*this + alpha*A*B. + void AddTpMat(const Real alpha, + const TpMatrix<Real>& A, MatrixTransposeType transA, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const Real beta) { + Matrix<Real> M(A); + return AddMatMat(alpha, M, transA, B, transB, beta); + } + /// this <-- beta*this + alpha*A*B. + void AddMatSp(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const SpMatrix<Real>& B, + const Real beta) { + Matrix<Real> M(B); + return AddMatMat(alpha, A, transA, M, kNoTrans, beta); + } + /// this <-- beta*this + alpha*A*B*C. + void AddSpMatSp(const Real alpha, + const SpMatrix<Real> &A, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const SpMatrix<Real>& C, + const Real beta) { + Matrix<Real> M(A), N(C); + return AddMatMatMat(alpha, M, kNoTrans, B, transB, N, kNoTrans, beta); + } + /// this <-- beta*this + alpha*A*B. + void AddMatTp(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const TpMatrix<Real>& B, MatrixTransposeType transB, + const Real beta) { + Matrix<Real> M(B); + return AddMatMat(alpha, A, transA, M, transB, beta); + } + + /// this <-- beta*this + alpha*A*B. + void AddTpTp(const Real alpha, + const TpMatrix<Real>& A, MatrixTransposeType transA, + const TpMatrix<Real>& B, MatrixTransposeType transB, + const Real beta) { + Matrix<Real> M(A), N(B); + return AddMatMat(alpha, M, transA, N, transB, beta); + } + + /// this <-- beta*this + alpha*A*B. + // This one is more efficient, not like the others above. + void AddSpSp(const Real alpha, + const SpMatrix<Real>& A, const SpMatrix<Real>& B, + const Real beta); + + /// Copy lower triangle to upper triangle (symmetrize) + void CopyLowerToUpper(); + + /// Copy upper triangle to lower triangle (symmetrize) + void CopyUpperToLower(); + + /// This function orthogonalizes the rows of a matrix using the Gram-Schmidt + /// process. It is only applicable if NumRows() <= NumCols(). It will use + /// random number generation to fill in rows with something nonzero, in cases + /// where the original matrix was of deficient row rank. + void OrthogonalizeRows(); + + /// stream read. + /// Use instead of stream<<*this, if you want to add to existing contents. + // Will throw exception on failure. + void Read(std::istream & in, bool binary, bool add = false); + /// write to stream. + void Write(std::ostream & out, bool binary) const; + + // Below is internal methods for Svd, user does not have to know about this. +#if !defined(HAVE_ATLAS) && !defined(USE_KALDI_SVD) + // protected: + // Should be protected but used directly in testing routine. + // destroys *this! + void LapackGesvd(VectorBase<Real> *s, MatrixBase<Real> *U, + MatrixBase<Real> *Vt); +#else + protected: + // destroys *this! + bool JamaSvd(VectorBase<Real> *s, MatrixBase<Real> *U, + MatrixBase<Real> *V); + +#endif + protected: + + /// Initializer, callable only from child. + explicit MatrixBase(Real *data, MatrixIndexT cols, MatrixIndexT rows, MatrixIndexT stride) : + data_(data), num_cols_(cols), num_rows_(rows), stride_(stride) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + /// Initializer, callable only from child. + /// Empty initializer, for un-initialized matrix. + explicit MatrixBase(): data_(NULL) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + // Make sure pointers to MatrixBase cannot be deleted. + ~MatrixBase() { } + + /// A workaround that allows SubMatrix to get a pointer to non-const data + /// for const Matrix. Unfortunately C++ does not allow us to declare a + /// "public const" inheritance or anything like that, so it would require + /// a lot of work to make the SubMatrix class totally const-correct-- + /// we would have to override many of the Matrix functions. + inline Real* Data_workaround() const { + return data_; + } + + /// data memory area + Real* data_; + + /// these atributes store the real matrix size as it is stored in memory + /// including memalignment + MatrixIndexT num_cols_; /// < Number of columns + MatrixIndexT num_rows_; /// < Number of rows + /** True number of columns for the internal matrix. This number may differ + * from num_cols_ as memory alignment might be used. */ + MatrixIndexT stride_; + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(MatrixBase); +}; + +/// A class for storing matrices. +template<typename Real> +class Matrix : public MatrixBase<Real> { + public: + + /// Empty constructor. + Matrix(); + + /// Basic constructor. Sets to zero by default. + /// if set_zero == false, memory contents are undefined. + Matrix(const MatrixIndexT r, const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero): + MatrixBase<Real>() { Resize(r, c, resize_type); } + + /// Copy constructor from CUDA matrix + /// This is defined in ../cudamatrix/cu-matrix.h + template<typename OtherReal> + explicit Matrix(const CuMatrixBase<OtherReal> &cu, + MatrixTransposeType trans = kNoTrans); + + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Matrix<Real> *other); + + /// Defined in ../cudamatrix/cu-matrix.cc + void Swap(CuMatrix<Real> *mat); + + /// Constructor from any MatrixBase. Can also copy with transpose. + /// Allocates new memory. + explicit Matrix(const MatrixBase<Real> & M, + MatrixTransposeType trans = kNoTrans); + + /// Same as above, but need to avoid default copy constructor. + Matrix(const Matrix<Real> & M); // (cannot make explicit) + + /// Copy constructor: as above, but from another type. + template<typename OtherReal> + explicit Matrix(const MatrixBase<OtherReal> & M, + MatrixTransposeType trans = kNoTrans); + + /// Copy constructor taking SpMatrix... + /// It is symmetric, so no option for transpose, and NumRows == Cols + template<typename OtherReal> + explicit Matrix(const SpMatrix<OtherReal> & M) : MatrixBase<Real>() { + Resize(M.NumRows(), M.NumRows(), kUndefined); + this->CopyFromSp(M); + } + + /// Constructor from CompressedMatrix + explicit Matrix(const CompressedMatrix &C); + + /// Copy constructor taking TpMatrix... + template <typename OtherReal> + explicit Matrix(const TpMatrix<OtherReal> & M, + MatrixTransposeType trans = kNoTrans) : MatrixBase<Real>() { + if (trans == kNoTrans) { + Resize(M.NumRows(), M.NumCols(), kUndefined); + this->CopyFromTp(M); + } else { + Resize(M.NumCols(), M.NumRows(), kUndefined); + this->CopyFromTp(M, kTrans); + } + } + + /// read from stream. + // Unlike one in base, allows resizing. + void Read(std::istream & in, bool binary, bool add = false); + + /// Remove a specified row. + void RemoveRow(MatrixIndexT i); + + /// Transpose the matrix. Works for non-square + /// matrices as well as square ones. + void Transpose(); + + /// Distructor to free matrices. + ~Matrix() { Destroy(); } + + /// Sets matrix to a specified size (zero is OK as long as both r and c are + /// zero). The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero); + + /// Assignment operator that takes MatrixBase. + Matrix<Real> &operator = (const MatrixBase<Real> &other) { + if (MatrixBase<Real>::NumRows() != other.NumRows() || + MatrixBase<Real>::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase<Real>::CopyFromMat(other); + return *this; + } + + /// Assignment operator. Needed for inclusion in std::vector. + Matrix<Real> &operator = (const Matrix<Real> &other) { + if (MatrixBase<Real>::NumRows() != other.NumRows() || + MatrixBase<Real>::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase<Real>::CopyFromMat(other); + return *this; + } + + + private: + /// Deallocates memory and sets to empty matrix (dimension 0, 0). + void Destroy(); + + /// Init assumes the current class contents are invalid (i.e. junk or have + /// already been freed), and it sets the matrix to newly allocated memory with + /// the specified number of rows and columns. r == c == 0 is acceptable. The data + /// memory contents will be undefined. + void Init(const MatrixIndexT r, + const MatrixIndexT c); + +}; +/// @} end "addtogroup matrix_group" + +/// \addtogroup matrix_funcs_io +/// @{ + +/// A structure containing the HTK header. +/// [TODO: change the style of the variables to Kaldi-compliant] +struct HtkHeader { + /// Number of samples. + int32 mNSamples; + /// Sample period. + int32 mSamplePeriod; + /// Sample size + int16 mSampleSize; + /// Sample kind. + uint16 mSampleKind; +}; + +// Read HTK formatted features from file into matrix. +template<typename Real> +bool ReadHtk(std::istream &is, Matrix<Real> *M, HtkHeader *header_ptr); + +// Write (HTK format) features to file from matrix. +template<typename Real> +bool WriteHtk(std::ostream &os, const MatrixBase<Real> &M, HtkHeader htk_hdr); + +// Write (CMUSphinx format) features to file from matrix. +template<typename Real> +bool WriteSphinx(std::ostream &os, const MatrixBase<Real> &M); + +/// @} end of "addtogroup matrix_funcs_io" + +/** + Sub-matrix representation. + Can work with sub-parts of a matrix using this class. + Note that SubMatrix is not very const-correct-- it allows you to + change the contents of a const Matrix. Be careful! +*/ + +template<typename Real> +class SubMatrix : public MatrixBase<Real> { + public: + // Initialize a SubMatrix from part of a matrix; this is + // a bit like A(b:c, d:e) in Matlab. + // This initializer is against the proper semantics of "const", since + // SubMatrix can change its contents. It would be hard to implement + // a "const-safe" version of this class. + SubMatrix(const MatrixBase<Real>& T, + const MatrixIndexT ro, // row offset, 0 < ro < NumRows() + const MatrixIndexT r, // number of rows, r > 0 + const MatrixIndexT co, // column offset, 0 < co < NumCols() + const MatrixIndexT c); // number of columns, c > 0 + + // This initializer is mostly intended for use in CuMatrix and related + // classes. Be careful! + SubMatrix(Real *data, + MatrixIndexT num_rows, + MatrixIndexT num_cols, + MatrixIndexT stride); + + ~SubMatrix<Real>() {} + + /// This type of constructor is needed for Range() to work [in Matrix base + /// class]. Cannot make it explicit. + SubMatrix<Real> (const SubMatrix &other): + MatrixBase<Real> (other.data_, other.num_cols_, other.num_rows_, + other.stride_) {} + + private: + /// Disallow assignment. + SubMatrix<Real> &operator = (const SubMatrix<Real> &other); +}; +/// @} End of "addtogroup matrix_funcs_io". + +/// \addtogroup matrix_funcs_scalar +/// @{ + +// Some declarations. These are traces of products. + + +template<typename Real> +bool ApproxEqual(const MatrixBase<Real> &A, + const MatrixBase<Real> &B, Real tol = 0.01) { + return A.ApproxEqual(B, tol); +} + +template<typename Real> +inline void AssertEqual(const MatrixBase<Real> &A, const MatrixBase<Real> &B, + float tol = 0.01) { + KALDI_ASSERT(A.ApproxEqual(B, tol)); +} + +/// Returns trace of matrix. +template <typename Real> +double TraceMat(const MatrixBase<Real> &A) { return A.Trace(); } + + +/// Returns tr(A B C) +template <typename Real> +Real TraceMatMatMat(const MatrixBase<Real> &A, MatrixTransposeType transA, + const MatrixBase<Real> &B, MatrixTransposeType transB, + const MatrixBase<Real> &C, MatrixTransposeType transC); + +/// Returns tr(A B C D) +template <typename Real> +Real TraceMatMatMatMat(const MatrixBase<Real> &A, MatrixTransposeType transA, + const MatrixBase<Real> &B, MatrixTransposeType transB, + const MatrixBase<Real> &C, MatrixTransposeType transC, + const MatrixBase<Real> &D, MatrixTransposeType transD); + +/// @} end "addtogroup matrix_funcs_scalar" + + +/// \addtogroup matrix_funcs_misc +/// @{ + + +/// Function to ensure that SVD is sorted. This function is made as generic as +/// possible, to be applicable to other types of problems. s->Dim() should be +/// the same as U->NumCols(), and we sort s from greatest to least absolute +/// value (if sort_on_absolute_value == true) or greatest to least value +/// otherwise, moving the columns of U, if it exists, and the rows of Vt, if it +/// exists, around in the same way. Note: the "absolute value" part won't matter +/// if this is an actual SVD, since singular values are non-negative. +template<typename Real> void SortSvd(VectorBase<Real> *s, MatrixBase<Real> *U, + MatrixBase<Real>* Vt = NULL, + bool sort_on_absolute_value = true); + +/// Creates the eigenvalue matrix D that is part of the decomposition used Matrix::Eig. +/// D will be block-diagonal with blocks of size 1 (for real eigenvalues) or 2x2 +/// for complex pairs. If a complex pair is lambda +- i*mu, D will have a corresponding +/// 2x2 block [lambda, mu; -mu, lambda]. +/// This function will throw if any complex eigenvalues are not in complex conjugate +/// pairs (or the members of such pairs are not consecutively numbered). +template<typename Real> +void CreateEigenvalueMatrix(const VectorBase<Real> &real, const VectorBase<Real> &imag, + MatrixBase<Real> *D); + +/// The following function is used in Matrix::Power, and separately tested, so we +/// declare it here mainly for the testing code to see. It takes a complex value to +/// a power using a method that will work for noninteger powers (but will fail if the +/// complex value is real and negative). +template<typename Real> +bool AttemptComplexPower(Real *x_re, Real *x_im, Real power); + + + +/// @} end of addtogroup matrix_funcs_misc + +/// \addtogroup matrix_funcs_io +/// @{ +template<typename Real> +std::ostream & operator << (std::ostream & Out, const MatrixBase<Real> & M); + +template<typename Real> +std::istream & operator >> (std::istream & In, MatrixBase<Real> & M); + +// The Matrix read allows resizing, so we override the MatrixBase one. +template<typename Real> +std::istream & operator >> (std::istream & In, Matrix<Real> & M); + + +template<typename Real> +bool SameDim(const MatrixBase<Real> &M, const MatrixBase<Real> &N) { + return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols()); +} + +/// @} end of \addtogroup matrix_funcs_io + + +} // namespace kaldi + + + +// we need to include the implementation and some +// template specializations. +#include "matrix/kaldi-matrix-inl.h" + + +#endif // KALDI_MATRIX_KALDI_MATRIX_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-vector-inl.h b/kaldi_io/src/kaldi/matrix/kaldi-vector-inl.h new file mode 100644 index 0000000..c3a4f52 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-vector-inl.h @@ -0,0 +1,58 @@ +// matrix/kaldi-vector-inl.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; +// Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// This is an internal header file, included by other library headers. +// You should not attempt to use it directly. + +#ifndef KALDI_MATRIX_KALDI_VECTOR_INL_H_ +#define KALDI_MATRIX_KALDI_VECTOR_INL_H_ 1 + +namespace kaldi { + +template<typename Real> +std::ostream & operator << (std::ostream &os, const VectorBase<Real> &rv) { + rv.Write(os, false); + return os; +} + +template<typename Real> +std::istream &operator >> (std::istream &is, VectorBase<Real> &rv) { + rv.Read(is, false); + return is; +} + +template<typename Real> +std::istream &operator >> (std::istream &is, Vector<Real> &rv) { + rv.Read(is, false); + return is; +} + +template<> +template<> +void VectorBase<float>::AddVec(const float alpha, const VectorBase<float> &rv); + +template<> +template<> +void VectorBase<double>::AddVec<double>(const double alpha, + const VectorBase<double> &rv); + +} // namespace kaldi + +#endif // KALDI_MATRIX_KALDI_VECTOR_INL_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-vector.h b/kaldi_io/src/kaldi/matrix/kaldi-vector.h new file mode 100644 index 0000000..2b3395b --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-vector.h @@ -0,0 +1,585 @@ +// matrix/kaldi-vector.h + +// Copyright 2009-2012 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University (Author: Arnab Ghoshal); +// Ariya Rastrow; Petr Schwarz; Yanmin Qian; +// Karel Vesely; Go Vivace Inc.; Arnab Ghoshal +// Wei Shi; + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_VECTOR_H_ +#define KALDI_MATRIX_KALDI_VECTOR_H_ 1 + +#include "matrix/matrix-common.h" + +namespace kaldi { + +/// \addtogroup matrix_group +/// @{ + +/// Provides a vector abstraction class. +/// This class provides a way to work with vectors in kaldi. +/// It encapsulates basic operations and memory optimizations. +template<typename Real> +class VectorBase { + public: + /// Set vector to all zeros. + void SetZero(); + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-06) const; // replace magic number + + /// Set all members of a vector to a specified value. + void Set(Real f); + + /// Set vector to random normally-distributed noise. + void SetRandn(); + + /// This function returns a random index into this vector, + /// chosen with probability proportional to the corresponding + /// element. Requires that this->Min() >= 0 and this->Sum() > 0. + MatrixIndexT RandCategorical() const; + + /// Returns the dimension of the vector. + inline MatrixIndexT Dim() const { return dim_; } + + /// Returns the size in memory of the vector, in bytes. + inline MatrixIndexT SizeInBytes() const { return (dim_*sizeof(Real)); } + + /// Returns a pointer to the start of the vector's data. + inline Real* Data() { return data_; } + + /// Returns a pointer to the start of the vector's data (const). + inline const Real* Data() const { return data_; } + + /// Indexing operator (const). + inline Real operator() (MatrixIndexT i) const { + KALDI_PARANOID_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(dim_)); + return *(data_ + i); + } + + /// Indexing operator (non-const). + inline Real & operator() (MatrixIndexT i) { + KALDI_PARANOID_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(dim_)); + return *(data_ + i); + } + + /** @brief Returns a sub-vector of a vector (a range of elements). + * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l) { + return SubVector<Real>(*this, o, l); + } + + /** @brief Returns a const sub-vector of a vector (a range of elements). + * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + const SubVector<Real> Range(const MatrixIndexT o, + const MatrixIndexT l) const { + return SubVector<Real>(*this, o, l); + } + + /// Copy data from another vector (must match own size). + void CopyFromVec(const VectorBase<Real> &v); + + /// Copy data from a SpMatrix or TpMatrix (must match own size). + template<typename OtherReal> + void CopyFromPacked(const PackedMatrix<OtherReal> &M); + + /// Copy data from another vector of different type (double vs. float) + template<typename OtherReal> + void CopyFromVec(const VectorBase<OtherReal> &v); + + /// Copy from CuVector. This is defined in ../cudamatrix/cu-vector.h + template<typename OtherReal> + void CopyFromVec(const CuVectorBase<OtherReal> &v); + + + /// Apply natural log to all elements. Throw if any element of + /// the vector is negative (but doesn't complain about zero; the + /// log will be -infinity + void ApplyLog(); + + /// Apply natural log to another vector and put result in *this. + void ApplyLogAndCopy(const VectorBase<Real> &v); + + /// Apply exponential to each value in vector. + void ApplyExp(); + + /// Take absolute value of each of the elements + void ApplyAbs(); + + /// Applies floor to all elements. Returns number of elements floored. + MatrixIndexT ApplyFloor(Real floor_val); + + /// Applies ceiling to all elements. Returns number of elements changed. + MatrixIndexT ApplyCeiling(Real ceil_val); + + /// Applies floor to all elements. Returns number of elements floored. + MatrixIndexT ApplyFloor(const VectorBase<Real> &floor_vec); + + /// Apply soft-max to vector and return normalizer (log sum of exponentials). + /// This is the same as: \f$ x(i) = exp(x(i)) / \sum_i exp(x(i)) \f$ + Real ApplySoftMax(); + + /// Sets each element of *this to the tanh of the corresponding element of "src". + void Tanh(const VectorBase<Real> &src); + + /// Sets each element of *this to the sigmoid function of the corresponding + /// element of "src". + void Sigmoid(const VectorBase<Real> &src); + + /// Take all elements of vector to a power. + void ApplyPow(Real power); + + /// Take the absolute value of all elements of a vector to a power. + /// Include the sign of the input element if include_sign == true. + /// If power is negative and the input value is zero, the output is set zero. + void ApplyPowAbs(Real power, bool include_sign=false); + + /// Compute the p-th norm of the vector. + Real Norm(Real p) const; + + /// Returns true if ((*this)-other).Norm(2.0) <= tol * (*this).Norm(2.0). + bool ApproxEqual(const VectorBase<Real> &other, float tol = 0.01) const; + + /// Invert all elements. + void InvertElements(); + + /// Add vector : *this = *this + alpha * rv (with casting between floats and + /// doubles) + template<typename OtherReal> + void AddVec(const Real alpha, const VectorBase<OtherReal> &v); + + /// Add vector : *this = *this + alpha * rv^2 [element-wise squaring]. + void AddVec2(const Real alpha, const VectorBase<Real> &v); + + /// Add vector : *this = *this + alpha * rv^2 [element-wise squaring], + /// with casting between floats and doubles. + template<typename OtherReal> + void AddVec2(const Real alpha, const VectorBase<OtherReal> &v); + + /// Add matrix times vector : this <-- beta*this + alpha*M*v. + /// Calls BLAS GEMV. + void AddMatVec(const Real alpha, const MatrixBase<Real> &M, + const MatrixTransposeType trans, const VectorBase<Real> &v, + const Real beta); // **beta previously defaulted to 0.0** + + /// This is as AddMatVec, except optimized for where v contains a lot + /// of zeros. + void AddMatSvec(const Real alpha, const MatrixBase<Real> &M, + const MatrixTransposeType trans, const VectorBase<Real> &v, + const Real beta); // **beta previously defaulted to 0.0** + + + /// Add symmetric positive definite matrix times vector: + /// this <-- beta*this + alpha*M*v. Calls BLAS SPMV. + void AddSpVec(const Real alpha, const SpMatrix<Real> &M, + const VectorBase<Real> &v, const Real beta); // **beta previously defaulted to 0.0** + + /// Add triangular matrix times vector: this <-- beta*this + alpha*M*v. + /// Works even if rv == *this. + void AddTpVec(const Real alpha, const TpMatrix<Real> &M, + const MatrixTransposeType trans, const VectorBase<Real> &v, + const Real beta); // **beta previously defaulted to 0.0** + + /// Set each element to y = (x == orig ? changed : x). + void ReplaceValue(Real orig, Real changed); + + /// Multipy element-by-element by another vector. + void MulElements(const VectorBase<Real> &v); + /// Multipy element-by-element by another vector of different type. + template<typename OtherReal> + void MulElements(const VectorBase<OtherReal> &v); + + /// Divide element-by-element by a vector. + void DivElements(const VectorBase<Real> &v); + /// Divide element-by-element by a vector of different type. + template<typename OtherReal> + void DivElements(const VectorBase<OtherReal> &v); + + /// Add a constant to each element of a vector. + void Add(Real c); + + /// Add element-by-element product of vectlrs: + // this <-- alpha * v .* r + beta*this . + void AddVecVec(Real alpha, const VectorBase<Real> &v, + const VectorBase<Real> &r, Real beta); + + /// Add element-by-element quotient of two vectors. + /// this <---- alpha*v/r + beta*this + void AddVecDivVec(Real alpha, const VectorBase<Real> &v, + const VectorBase<Real> &r, Real beta); + + /// Multiplies all elements by this constant. + void Scale(Real alpha); + + /// Multiplies this vector by lower-triangular marix: *this <-- *this *M + void MulTp(const TpMatrix<Real> &M, const MatrixTransposeType trans); + + /// If trans == kNoTrans, solves M x = b, where b is the value of *this at input + /// and x is the value of *this at output. + /// If trans == kTrans, solves M' x = b. + /// Does not test for M being singular or near-singular, so test it before + /// calling this routine. + void Solve(const TpMatrix<Real> &M, const MatrixTransposeType trans); + + /// Performs a row stack of the matrix M + void CopyRowsFromMat(const MatrixBase<Real> &M); + template<typename OtherReal> + void CopyRowsFromMat(const MatrixBase<OtherReal> &M); + + /// The following is implemented in ../cudamatrix/cu-matrix.cc + void CopyRowsFromMat(const CuMatrixBase<Real> &M); + + /// Performs a column stack of the matrix M + void CopyColsFromMat(const MatrixBase<Real> &M); + + /// Extracts a row of the matrix M. Could also do this with + /// this->Copy(M[row]). + void CopyRowFromMat(const MatrixBase<Real> &M, MatrixIndexT row); + /// Extracts a row of the matrix M with type conversion. + template<typename OtherReal> + void CopyRowFromMat(const MatrixBase<OtherReal> &M, MatrixIndexT row); + + /// Extracts a row of the symmetric matrix S. + template<typename OtherReal> + void CopyRowFromSp(const SpMatrix<OtherReal> &S, MatrixIndexT row); + + /// Extracts a column of the matrix M. + template<typename OtherReal> + void CopyColFromMat(const MatrixBase<OtherReal> &M , MatrixIndexT col); + + /// Extracts the diagonal of the matrix M. + void CopyDiagFromMat(const MatrixBase<Real> &M); + + /// Extracts the diagonal of a packed matrix M; works for Sp or Tp. + void CopyDiagFromPacked(const PackedMatrix<Real> &M); + + + /// Extracts the diagonal of a symmetric matrix. + inline void CopyDiagFromSp(const SpMatrix<Real> &M) { CopyDiagFromPacked(M); } + + /// Extracts the diagonal of a triangular matrix. + inline void CopyDiagFromTp(const TpMatrix<Real> &M) { CopyDiagFromPacked(M); } + + /// Returns the maximum value of any element, or -infinity for the empty vector. + Real Max() const; + + /// Returns the maximum value of any element, and the associated index. + /// Error if vector is empty. + Real Max(MatrixIndexT *index) const; + + /// Returns the minimum value of any element, or +infinity for the empty vector. + Real Min() const; + + /// Returns the minimum value of any element, and the associated index. + /// Error if vector is empty. + Real Min(MatrixIndexT *index) const; + + /// Returns sum of the elements + Real Sum() const; + + /// Returns sum of the logs of the elements. More efficient than + /// just taking log of each. Will return NaN if any elements are + /// negative. + Real SumLog() const; + + /// Does *this = alpha * (sum of rows of M) + beta * *this. + void AddRowSumMat(Real alpha, const MatrixBase<Real> &M, Real beta = 1.0); + + /// Does *this = alpha * (sum of columns of M) + beta * *this. + void AddColSumMat(Real alpha, const MatrixBase<Real> &M, Real beta = 1.0); + + /// Add the diagonal of a matrix times itself: + /// *this = diag(M M^T) + beta * *this (if trans == kNoTrans), or + /// *this = diag(M^T M) + beta * *this (if trans == kTrans). + void AddDiagMat2(Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType trans = kNoTrans, Real beta = 1.0); + + /// Add the diagonal of a matrix product: *this = diag(M N), assuming the + /// "trans" arguments are both kNoTrans; for transpose arguments, it behaves + /// as you would expect. + void AddDiagMatMat(Real alpha, const MatrixBase<Real> &M, MatrixTransposeType transM, + const MatrixBase<Real> &N, MatrixTransposeType transN, + Real beta = 1.0); + + /// Returns log(sum(exp())) without exp overflow + /// If prune > 0.0, ignores terms less than the max - prune. + /// [Note: in future, if prune = 0.0, it will take the max. + /// For now, use -1 if you don't want it to prune.] + Real LogSumExp(Real prune = -1.0) const; + + /// Reads from C++ stream (option to add to existing contents). + /// Throws exception on failure + void Read(std::istream & in, bool binary, bool add = false); + + /// Writes to C++ stream (option to write in binary). + void Write(std::ostream &Out, bool binary) const; + + friend class VectorBase<double>; + friend class VectorBase<float>; + friend class CuVectorBase<Real>; + friend class CuVector<Real>; + protected: + /// Destructor; does not deallocate memory, this is handled by child classes. + /// This destructor is protected so this object so this object can only be + /// deleted via a child. + ~VectorBase() {} + + /// Empty initializer, corresponds to vector of zero size. + explicit VectorBase(): data_(NULL), dim_(0) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + +// Took this out since it is not currently used, and it is possible to create +// objects where the allocated memory is not the same size as dim_ : Arnab +// /// Initializer from a pointer and a size; keeps the pointer internally +// /// (ownership or non-ownership depends on the child class). +// explicit VectorBase(Real* data, MatrixIndexT dim) +// : data_(data), dim_(dim) {} + + // Arnab : made this protected since it is unsafe too. + /// Load data into the vector: sz must match own size. + void CopyFromPtr(const Real* Data, MatrixIndexT sz); + + /// data memory area + Real* data_; + /// dimension of vector + MatrixIndexT dim_; + KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase); +}; // class VectorBase + +/** @brief A class representing a vector. + * + * This class provides a way to work with vectors in kaldi. + * It encapsulates basic operations and memory optimizations. */ +template<typename Real> +class Vector: public VectorBase<Real> { + public: + /// Constructor that takes no arguments. Initializes to empty. + Vector(): VectorBase<Real>() {} + + /// Constructor with specific size. Sets to all-zero by default + /// if set_zero == false, memory contents are undefined. + explicit Vector(const MatrixIndexT s, + MatrixResizeType resize_type = kSetZero) + : VectorBase<Real>() { Resize(s, resize_type); } + + /// Copy constructor from CUDA vector + /// This is defined in ../cudamatrix/cu-vector.h + template<typename OtherReal> + explicit Vector(const CuVectorBase<OtherReal> &cu); + + /// Copy constructor. The need for this is controversial. + Vector(const Vector<Real> &v) : VectorBase<Real>() { // (cannot be explicit) + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Copy-constructor from base-class, needed to copy from SubVector. + explicit Vector(const VectorBase<Real> &v) : VectorBase<Real>() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Type conversion constructor. + template<typename OtherReal> + explicit Vector(const VectorBase<OtherReal> &v): VectorBase<Real>() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + +// Took this out since it is unsafe : Arnab +// /// Constructor from a pointer and a size; copies the data to a location +// /// it owns. +// Vector(const Real* Data, const MatrixIndexT s): VectorBase<Real>() { +// Resize(s); + // CopyFromPtr(Data, s); +// } + + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Vector<Real> *other); + + /// Destructor. Deallocates memory. + ~Vector() { Destroy(); } + + /// Read function using C++ streams. Can also add to existing contents + /// of matrix. + void Read(std::istream & in, bool binary, bool add = false); + + /// Set vector to a specified size (can be zero). + /// The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero); + + /// Remove one element and shifts later elements down. + void RemoveElement(MatrixIndexT i); + + /// Assignment operator, protected so it can only be used by std::vector + Vector<Real> &operator = (const Vector<Real> &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + + /// Assignment operator that takes VectorBase. + Vector<Real> &operator = (const VectorBase<Real> &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + private: + /// Init assumes the current contents of the class are invalid (i.e. junk or + /// has already been freed), and it sets the vector to newly allocated memory + /// with the specified dimension. dim == 0 is acceptable. The memory contents + /// pointed to by data_ will be undefined. + void Init(const MatrixIndexT dim); + + /// Destroy function, called internally. + void Destroy(); + +}; + + +/// Represents a non-allocating general vector which can be defined +/// as a sub-vector of higher-level vector [or as the row of a matrix]. +template<typename Real> +class SubVector : public VectorBase<Real> { + public: + /// Constructor from a Vector or SubVector. + /// SubVectors are not const-safe and it's very hard to make them + /// so for now we just give up. This function contains const_cast. + SubVector(const VectorBase<Real> &t, const MatrixIndexT origin, + const MatrixIndexT length) : VectorBase<Real>() { + // following assert equiv to origin>=0 && length>=0 && + // origin+length <= rt.dim_ + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(origin)+ + static_cast<UnsignedMatrixIndexT>(length) <= + static_cast<UnsignedMatrixIndexT>(t.Dim())); + VectorBase<Real>::data_ = const_cast<Real*> (t.Data()+origin); + VectorBase<Real>::dim_ = length; + } + + /// This constructor initializes the vector to point at the contents + /// of this packed matrix (SpMatrix or TpMatrix). + SubVector(const PackedMatrix<Real> &M) { + VectorBase<Real>::data_ = const_cast<Real*> (M.Data()); + VectorBase<Real>::dim_ = (M.NumRows()*(M.NumRows()+1))/2; + } + + /// Copy constructor + SubVector(const SubVector &other) : VectorBase<Real> () { + // this copy constructor needed for Range() to work in base class. + VectorBase<Real>::data_ = other.data_; + VectorBase<Real>::dim_ = other.dim_; + } + + /// Constructor from a pointer to memory and a length. Keeps a pointer + /// to the data but does not take ownership (will never delete). + SubVector(Real *data, MatrixIndexT length) : VectorBase<Real> () { + VectorBase<Real>::data_ = data; + VectorBase<Real>::dim_ = length; + } + + + /// This operation does not preserve const-ness, so be careful. + SubVector(const MatrixBase<Real> &matrix, MatrixIndexT row) { + VectorBase<Real>::data_ = const_cast<Real*>(matrix.RowData(row)); + VectorBase<Real>::dim_ = matrix.NumCols(); + } + + ~SubVector() {} ///< Destructor (does nothing; no pointers are owned here). + + private: + /// Disallow assignment operator. + SubVector & operator = (const SubVector &other) {} +}; + +/// @} end of "addtogroup matrix_group" +/// \addtogroup matrix_funcs_io +/// @{ +/// Output to a C++ stream. Non-binary by default (use Write for +/// binary output). +template<typename Real> +std::ostream & operator << (std::ostream & out, const VectorBase<Real> & v); + +/// Input from a C++ stream. Will automatically read text or +/// binary data from the stream. +template<typename Real> +std::istream & operator >> (std::istream & in, VectorBase<Real> & v); + +/// Input from a C++ stream. Will automatically read text or +/// binary data from the stream. +template<typename Real> +std::istream & operator >> (std::istream & in, Vector<Real> & v); +/// @} end of \addtogroup matrix_funcs_io + +/// \addtogroup matrix_funcs_scalar +/// @{ + + +template<typename Real> +bool ApproxEqual(const VectorBase<Real> &a, + const VectorBase<Real> &b, Real tol = 0.01) { + return a.ApproxEqual(b, tol); +} + +template<typename Real> +inline void AssertEqual(VectorBase<Real> &a, VectorBase<Real> &b, + float tol = 0.01) { + KALDI_ASSERT(a.ApproxEqual(b, tol)); +} + + +/// Returns dot product between v1 and v2. +template<typename Real> +Real VecVec(const VectorBase<Real> &v1, const VectorBase<Real> &v2); + +template<typename Real, typename OtherReal> +Real VecVec(const VectorBase<Real> &v1, const VectorBase<OtherReal> &v2); + + +/// Returns \f$ v_1^T M v_2 \f$ . +/// Not as efficient as it could be where v1 == v2. +template<typename Real> +Real VecMatVec(const VectorBase<Real> &v1, const MatrixBase<Real> &M, + const VectorBase<Real> &v2); + +/// @} End of "addtogroup matrix_funcs_scalar" + + +} // namespace kaldi + +// we need to include the implementation +#include "matrix/kaldi-vector-inl.h" + + + +#endif // KALDI_MATRIX_KALDI_VECTOR_H_ + diff --git a/kaldi_io/src/kaldi/matrix/matrix-common.h b/kaldi_io/src/kaldi/matrix/matrix-common.h new file mode 100644 index 0000000..d202b2e --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/matrix-common.h @@ -0,0 +1,100 @@ +// matrix/matrix-common.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_MATRIX_COMMON_H_ +#define KALDI_MATRIX_MATRIX_COMMON_H_ + +// This file contains some #includes, forward declarations +// and typedefs that are needed by all the main header +// files in this directory. + +#include "base/kaldi-common.h" +#include "matrix/kaldi-blas.h" + +namespace kaldi { +typedef enum { + kTrans = CblasTrans, + kNoTrans = CblasNoTrans +} MatrixTransposeType; + +typedef enum { + kSetZero, + kUndefined, + kCopyData +} MatrixResizeType; + +typedef enum { + kTakeLower, + kTakeUpper, + kTakeMean, + kTakeMeanAndCheck +} SpCopyType; + +template<typename Real> class VectorBase; +template<typename Real> class Vector; +template<typename Real> class SubVector; +template<typename Real> class MatrixBase; +template<typename Real> class SubMatrix; +template<typename Real> class Matrix; +template<typename Real> class SpMatrix; +template<typename Real> class TpMatrix; +template<typename Real> class PackedMatrix; + +// these are classes that won't be defined in this +// directory; they're mostly needed for friend declarations. +template<typename Real> class CuMatrixBase; +template<typename Real> class CuSubMatrix; +template<typename Real> class CuMatrix; +template<typename Real> class CuVectorBase; +template<typename Real> class CuSubVector; +template<typename Real> class CuVector; +template<typename Real> class CuPackedMatrix; +template<typename Real> class CuSpMatrix; +template<typename Real> class CuTpMatrix; + +class CompressedMatrix; + +/// This class provides a way for switching between double and float types. +template<typename T> class OtherReal { }; // useful in reading+writing routines + // to switch double and float. +/// A specialized class for switching from float to double. +template<> class OtherReal<float> { + public: + typedef double Real; +}; +/// A specialized class for switching from double to float. +template<> class OtherReal<double> { + public: + typedef float Real; +}; + + +typedef int32 MatrixIndexT; +typedef int32 SignedMatrixIndexT; +typedef uint32 UnsignedMatrixIndexT; + +// If you want to use size_t for the index type, do as follows instead: +//typedef size_t MatrixIndexT; +//typedef ssize_t SignedMatrixIndexT; +//typedef size_t UnsignedMatrixIndexT; + +} + + + +#endif // KALDI_MATRIX_MATRIX_COMMON_H_ diff --git a/kaldi_io/src/kaldi/matrix/matrix-functions-inl.h b/kaldi_io/src/kaldi/matrix/matrix-functions-inl.h new file mode 100644 index 0000000..9fac851 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/matrix-functions-inl.h @@ -0,0 +1,56 @@ +// matrix/matrix-functions-inl.h + +// Copyright 2009-2011 Microsoft Corporation +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ +#define KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ + +namespace kaldi { + +//! ComplexMul implements, inline, the complex multiplication b *= a. +template<typename Real> inline void ComplexMul(const Real &a_re, const Real &a_im, + Real *b_re, Real *b_im) { + Real tmp_re = (*b_re * a_re) - (*b_im * a_im); + *b_im = *b_re * a_im + *b_im * a_re; + *b_re = tmp_re; +} + +template<typename Real> inline void ComplexAddProduct(const Real &a_re, const Real &a_im, + const Real &b_re, const Real &b_im, + Real *c_re, Real *c_im) { + *c_re += b_re*a_re - b_im*a_im; + *c_im += b_re*a_im + b_im*a_re; +} + + +template<typename Real> inline void ComplexImExp(Real x, Real *a_re, Real *a_im) { + *a_re = std::cos(x); + *a_im = std::sin(x); +} + + +} // end namespace kaldi + + +#endif // KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ + diff --git a/kaldi_io/src/kaldi/matrix/matrix-functions.h b/kaldi_io/src/kaldi/matrix/matrix-functions.h new file mode 100644 index 0000000..b70ca56 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/matrix-functions.h @@ -0,0 +1,235 @@ +// matrix/matrix-functions.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc.; Jan Silovsky; +// Yanmin Qian; 1991 Henrique (Rico) Malvar (*) +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_MATRIX_FUNCTIONS_H_ +#define KALDI_MATRIX_MATRIX_FUNCTIONS_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// @addtogroup matrix_funcs_misc +/// @{ + +/** The function ComplexFft does an Fft on the vector argument v. + v is a vector of even dimension, interpreted for both input + and output as a vector of complex numbers i.e. + \f[ v = ( re_0, im_0, re_1, im_1, ... ) \f] + The dimension of v must be a power of 2. + + If "forward == true" this routine does the Discrete Fourier Transform + (DFT), i.e.: + \f[ vout[m] \leftarrow \sum_{n = 0}^{N-1} vin[i] exp( -2pi m n / N ) \f] + + If "backward" it does the Inverse Discrete Fourier Transform (IDFT) + *WITHOUT THE FACTOR 1/N*, + i.e.: + \f[ vout[m] <-- \sum_{n = 0}^{N-1} vin[i] exp( 2pi m n / N ) \f] + [note the sign difference on the 2 pi for the backward one.] + + Note that this is the definition of the FT given in most texts, but + it differs from the Numerical Recipes version in which the forward + and backward algorithms are flipped. + + Note that you would have to multiply by 1/N after the IDFT to get + back to where you started from. We don't do this because + in some contexts, the transform is made symmetric by multiplying + by sqrt(N) in both passes. The user can do this by themselves. + + See also SplitRadixComplexFft, declared in srfft.h, which is more efficient + but only works if the length of the input is a power of 2. + */ +template<typename Real> void ComplexFft (VectorBase<Real> *v, bool forward, Vector<Real> *tmp_work = NULL); + +/// ComplexFt is the same as ComplexFft but it implements the Fourier +/// transform in an inefficient way. It is mainly included for testing purposes. +/// See comment for ComplexFft to describe the input and outputs and what it does. +template<typename Real> void ComplexFt (const VectorBase<Real> &in, + VectorBase<Real> *out, bool forward); + +/// RealFft is a fourier transform of real inputs. Internally it uses +/// ComplexFft. The input dimension N must be even. If forward == true, +/// it transforms from a sequence of N real points to its complex fourier +/// transform; otherwise it goes in the reverse direction. If you call it +/// in the forward and then reverse direction and multiply by 1.0/N, you +/// will get back the original data. +/// The interpretation of the complex-FFT data is as follows: the array +/// is a sequence of complex numbers C_n of length N/2 with (real, im) format, +/// i.e. [real0, real_{N/2}, real1, im1, real2, im2, real3, im3, ...]. +/// See also SplitRadixRealFft, declared in srfft.h, which is more efficient +/// but only works if the length of the input is a power of 2. + +template<typename Real> void RealFft (VectorBase<Real> *v, bool forward); + + +/// RealFt has the same input and output format as RealFft above, but it is +/// an inefficient implementation included for testing purposes. +template<typename Real> void RealFftInefficient (VectorBase<Real> *v, bool forward); + +/// ComputeDctMatrix computes a matrix corresponding to the DCT, such that +/// M * v equals the DCT of vector v. M must be square at input. +/// This is the type = III DCT with normalization, corresponding to the +/// following equations, where x is the signal and X is the DCT: +/// X_0 = 1/sqrt(2*N) \sum_{n = 0}^{N-1} x_n +/// X_k = 1/sqrt(N) \sum_{n = 0}^{N-1} x_n cos( \pi/N (n + 1/2) k ) +/// This matrix's transpose is its own inverse, so transposing this +/// matrix will give the inverse DCT. +/// Caution: the type III DCT is generally known as the "inverse DCT" (with the +/// type II being the actual DCT), so this function is somewhatd mis-named. It +/// was probably done this way for HTK compatibility. We don't change it +/// because it was this way from the start and changing it would affect the +/// feature generation. + +template<typename Real> void ComputeDctMatrix(Matrix<Real> *M); + + +/// ComplexMul implements, inline, the complex multiplication b *= a. +template<typename Real> inline void ComplexMul(const Real &a_re, const Real &a_im, + Real *b_re, Real *b_im); + +/// ComplexMul implements, inline, the complex operation c += (a * b). +template<typename Real> inline void ComplexAddProduct(const Real &a_re, const Real &a_im, + const Real &b_re, const Real &b_im, + Real *c_re, Real *c_im); + + +/// ComplexImExp implements a <-- exp(i x), inline. +template<typename Real> inline void ComplexImExp(Real x, Real *a_re, Real *a_im); + + +// This class allows you to compute the matrix exponential function +// B = I + A + 1/2! A^2 + 1/3! A^3 + ... +// This method is most accurate where the result is of the same order of +// magnitude as the unit matrix (it will typically not work well when +// the answer has almost-zero eigenvalues or is close to zero). +// It also provides a function that allows you do back-propagate the +// derivative of a scalar function through this calculation. +// The +template<typename Real> +class MatrixExponential { + public: + MatrixExponential() { } + + void Compute(const MatrixBase<Real> &M, MatrixBase<Real> *X); // does *X = exp(M) + + // Version for symmetric matrices (it just copies to full matrix). + void Compute(const SpMatrix<Real> &M, SpMatrix<Real> *X); // does *X = exp(M) + + void Backprop(const MatrixBase<Real> &hX, MatrixBase<Real> *hM) const; // Propagates + // the gradient of a scalar function f backwards through this operation, i.e.: + // if the parameter dX represents df/dX (with no transpose, so element i, j of dX + // is the derivative of f w.r.t. E(i, j)), it sets dM to df/dM, again with no + // transpose (of course, only the part thereof that comes through the effect of + // A on B). This applies to the values of A and E that were called most recently + // with Compute(). + + // Version for symmetric matrices (it just copies to full matrix). + void Backprop(const SpMatrix<Real> &hX, SpMatrix<Real> *hM) const; + + private: + void Clear(); + + static MatrixIndexT ComputeN(const MatrixBase<Real> &M); + + // This is intended for matrices P with small norms: compute B_0 = exp(P) - I. + // Keeps adding terms in the Taylor series till there is no further + // change in the result. Stores some of the powers of A in powers_, + // and the number of terms K as K_. + void ComputeTaylor(const MatrixBase<Real> &P, MatrixBase<Real> *B0); + + // Backprop through the Taylor-series computation above. + // note: hX is \hat{X} in the math; hM is \hat{M} in the math. + void BackpropTaylor(const MatrixBase<Real> &hX, + MatrixBase<Real> *hM) const; + + Matrix<Real> P_; // Equals M * 2^(-N_) + std::vector<Matrix<Real> > B_; // B_[0] = exp(P_) - I, + // B_[k] = 2 B_[k-1] + B_[k-1]^2 [k > 0], + // ( = exp(P_)^k - I ) + // goes from 0..N_ [size N_+1]. + + std::vector<Matrix<Real> > powers_; // powers (>1) of P_ stored here, + // up to all but the last one used in the Taylor expansion (this is the + // last one we need in the backprop). The index is the power minus 2. + + MatrixIndexT N_; // Power N_ >=0 such that P_ = A * 2^(-N_), + // we choose it so that P_ has a sufficiently small norm + // that the Taylor series will converge fast. +}; + + +/** + ComputePCA does a PCA computation, using either outer products + or inner products, whichever is more efficient. Let D be + the dimension of the data points, N be the number of data + points, and G be the PCA dimension we want to retain. We assume + G <= N and G <= D. + + @param X [in] An N x D matrix. Each row of X is a point x_i. + @param U [out] A G x D matrix. Each row of U is a basis element u_i. + @param A [out] An N x D matrix, or NULL. Each row of A is a set of coefficients + in the basis for a point x_i, so A(i, g) is the coefficient of u_i + in x_i. + @param print_eigs [in] If true, prints out diagnostic information about the + eigenvalues. + @param exact [in] If true, does the exact computation; if false, does + a much faster (but almost exact) computation based on the Lanczos + method. +*/ + +template<typename Real> +void ComputePca(const MatrixBase<Real> &X, + MatrixBase<Real> *U, + MatrixBase<Real> *A, + bool print_eigs = false, + bool exact = true); + + + +// This function does: *plus += max(0, a b^T), +// *minus += max(0, -(a b^T)). +template<typename Real> +void AddOuterProductPlusMinus(Real alpha, + const VectorBase<Real> &a, + const VectorBase<Real> &b, + MatrixBase<Real> *plus, + MatrixBase<Real> *minus); + +template<typename Real1, typename Real2> +inline void AssertSameDim(const MatrixBase<Real1> &mat1, const MatrixBase<Real2> &mat2) { + KALDI_ASSERT(mat1.NumRows() == mat2.NumRows() + && mat1.NumCols() == mat2.NumCols()); +} + + +/// @} end of "addtogroup matrix_funcs_misc" + +} // end namespace kaldi + +#include "matrix/matrix-functions-inl.h" + + +#endif diff --git a/kaldi_io/src/kaldi/matrix/matrix-lib.h b/kaldi_io/src/kaldi/matrix/matrix-lib.h new file mode 100644 index 0000000..39acec5 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/matrix-lib.h @@ -0,0 +1,37 @@ +// matrix/matrix-lib.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// Include everything from this directory. +// These files include other stuff that we need. +#ifndef KALDI_MATRIX_MATRIX_LIB_H_ +#define KALDI_MATRIX_MATRIX_LIB_H_ + +#include "matrix/cblas-wrappers.h" +#include "base/kaldi-common.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/sp-matrix.h" +#include "matrix/tp-matrix.h" +#include "matrix/matrix-functions.h" +#include "matrix/srfft.h" +#include "matrix/compressed-matrix.h" +#include "matrix/optimization.h" + +#endif + diff --git a/kaldi_io/src/kaldi/matrix/optimization.h b/kaldi_io/src/kaldi/matrix/optimization.h new file mode 100644 index 0000000..66309ac --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/optimization.h @@ -0,0 +1,248 @@ +// matrix/optimization.h + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_OPTIMIZATION_H_ +#define KALDI_MATRIX_OPTIMIZATION_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + + +/// @addtogroup matrix_optimization +/// @{ + +struct LinearCgdOptions { + int32 max_iters; // Maximum number of iters (if >= 0). + BaseFloat max_error; // Maximum 2-norm of the residual A x - b (convergence + // test) + // Every time the residual 2-norm decreases by this recompute_residual_factor + // since the last time it was computed from scratch, recompute it from + // scratch. This helps to keep the computed residual accurate even in the + // presence of roundoff. + BaseFloat recompute_residual_factor; + + LinearCgdOptions(): max_iters(-1), + max_error(0.0), + recompute_residual_factor(0.01) { } +}; + +/* + This function uses linear conjugate gradient descent to approximately solve + the system A x = b. The value of x at entry corresponds to the initial guess + of x. The algorithm continues until the number of iterations equals b.Dim(), + or until the 2-norm of (A x - b) is <= max_error, or until the number of + iterations equals max_iter, whichever happens sooner. It is a requirement + that A be positive definite. + It returns the number of iterations that were actually executed (this is + useful for testing purposes). +*/ +template<typename Real> +int32 LinearCgd(const LinearCgdOptions &opts, + const SpMatrix<Real> &A, const VectorBase<Real> &b, + VectorBase<Real> *x); + + + + + + +/** + This is an implementation of L-BFGS. It pushes responsibility for + determining when to stop, onto the user. There is no call-back here: + everything is done via calls to the class itself (see the example in + matrix-lib-test.cc). This does not implement constrained L-BFGS, but it will + handle constrained problems correctly as long as the function approaches + +infinity (or -infinity for maximization problems) when it gets close to the + bound of the constraint. In these types of problems, you just let the + function value be +infinity for minimization problems, or -infinity for + maximization problems, outside these bounds). +*/ + +struct LbfgsOptions { + bool minimize; // if true, we're minimizing, else maximizing. + int m; // m is the number of stored vectors L-BFGS keeps. + float first_step_learning_rate; // The very first step of L-BFGS is + // like gradient descent. If you want to configure the size of that step, + // you can do it using this variable. + float first_step_length; // If this variable is >0.0, it overrides + // first_step_learning_rate; on the first step we choose an approximate + // Hessian that is the multiple of the identity that would generate this + // step-length, or 1.0 if the gradient is zero. + float first_step_impr; // If this variable is >0.0, it overrides + // first_step_learning_rate; on the first step we choose an approximate + // Hessian that is the multiple of the identity that would generate this + // amount of objective function improvement (assuming the "real" objf + // was linear). + float c1; // A constant in Armijo rule = Wolfe condition i) + float c2; // A constant in Wolfe condition ii) + float d; // An amount > 1.0 (default 2.0) that we initially multiply or + // divide the step length by, in the line search. + int max_line_search_iters; // after this many iters we restart L-BFGS. + int avg_step_length; // number of iters to avg step length over, in + // RecentStepLength(). + + LbfgsOptions (bool minimize = true): + minimize(minimize), + m(10), + first_step_learning_rate(1.0), + first_step_length(0.0), + first_step_impr(0.0), + c1(1.0e-04), + c2(0.9), + d(2.0), + max_line_search_iters(50), + avg_step_length(4) { } +}; + +template<typename Real> +class OptimizeLbfgs { + public: + /// Initializer takes the starting value of x. + OptimizeLbfgs(const VectorBase<Real> &x, + const LbfgsOptions &opts); + + /// This returns the value of the variable x that has the best objective + /// function so far, and the corresponding objective function value if + /// requested. This would typically be called only at the end. + const VectorBase<Real>& GetValue(Real *objf_value = NULL) const; + + /// This returns the value at which the function wants us + /// to compute the objective function and gradient. + const VectorBase<Real>& GetProposedValue() const { return new_x_; } + + /// Returns the average magnitude of the last n steps (but not + /// more than the number we have stored). Before we have taken + /// any steps, returns +infinity. Note: if the most recent + /// step length was 0, it returns 0, regardless of the other + /// step lengths. This makes it suitable as a convergence test + /// (else we'd generate NaN's). + Real RecentStepLength() const; + + /// The user calls this function to provide the class with the + /// function and gradient info at the point GetProposedValue(). + /// If this point is outside the constraints you can set function_value + /// to {+infinity,-infinity} for {minimization,maximization} problems. + /// In this case the gradient, and also the second derivative (if you call + /// the second overloaded version of this function) will be ignored. + void DoStep(Real function_value, + const VectorBase<Real> &gradient); + + /// The user can call this version of DoStep() if it is desired to set some + /// kind of approximate Hessian on this iteration. Note: it is a prerequisite + /// that diag_approx_2nd_deriv must be strictly positive (minimizing), or + /// negative (maximizing). + void DoStep(Real function_value, + const VectorBase<Real> &gradient, + const VectorBase<Real> &diag_approx_2nd_deriv); + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(OptimizeLbfgs); + + + // The following variable says what stage of the computation we're at. + // Refer to Algorithm 7.5 (L-BFGS) of Nodecdal & Wright, "Numerical + // Optimization", 2nd edition. + // kBeforeStep means we're about to do + /// "compute p_k <-- - H_k \delta f_k" (i.e. Algorithm 7.4). + // kWithinStep means we're at some point within line search; note + // that line search is iterative so we can stay in this state more + // than one time on each iteration. + enum ComputationState { + kBeforeStep, + kWithinStep, // This means we're within the step-size computation, and + // have not yet done the 1st function evaluation. + }; + + inline MatrixIndexT Dim() { return x_.Dim(); } + inline MatrixIndexT M() { return opts_.m; } + SubVector<Real> Y(MatrixIndexT i) { + return SubVector<Real>(data_, (i % M()) * 2); // vector y_i + } + SubVector<Real> S(MatrixIndexT i) { + return SubVector<Real>(data_, (i % M()) * 2 + 1); // vector s_i + } + // The following are subroutines within DoStep(): + bool AcceptStep(Real function_value, + const VectorBase<Real> &gradient); + void Restart(const VectorBase<Real> &x, + Real function_value, + const VectorBase<Real> &gradient); + void ComputeNewDirection(Real function_value, + const VectorBase<Real> &gradient); + void ComputeHifNeeded(const VectorBase<Real> &gradient); + void StepSizeIteration(Real function_value, + const VectorBase<Real> &gradient); + void RecordStepLength(Real s); + + + LbfgsOptions opts_; + SignedMatrixIndexT k_; // Iteration number, starts from zero. Gets set back to zero + // when we restart. + + ComputationState computation_state_; + bool H_was_set_; // True if the user specified H_; if false, + // we'll use a heuristic to estimate it. + + + Vector<Real> x_; // current x. + Vector<Real> new_x_; // the x proposed in the line search. + Vector<Real> best_x_; // the x with the best objective function so far + // (either the same as x_ or something in the current line search.) + Vector<Real> deriv_; // The most recently evaluated derivative-- at x_k. + Vector<Real> temp_; + Real f_; // The function evaluated at x_k. + Real best_f_; // the best objective function so far. + Real d_; // a number d > 1.0, but during an iteration we may decrease this, when + // we switch between armijo and wolfe failures. + + int num_wolfe_i_failures_; // the num times we decreased step size. + int num_wolfe_ii_failures_; // the num times we increased step size. + enum { kWolfeI, kWolfeII, kNone } last_failure_type_; // last type of step-search + // failure on this iter. + + Vector<Real> H_; // Current inverse-Hessian estimate. May be computed by this class itself, + // or provided by user using 2nd form of SetGradientInfo(). + Matrix<Real> data_; // dimension (m*2) x dim. Even rows store + // gradients y_i, odd rows store steps s_i. + Vector<Real> rho_; // dimension m; rho_(m) = 1/(y_m^T s_m), Eq. 7.17. + + std::vector<Real> step_lengths_; // The step sizes we took on the last + // (up to m) iterations; these are not stored in a rotating buffer but + // are shifted by one each time (this is more convenient when we + // restart, as we keep this info past restarting). + + +}; + +/// @} + + +} // end namespace kaldi + + + +#endif + diff --git a/kaldi_io/src/kaldi/matrix/packed-matrix.h b/kaldi_io/src/kaldi/matrix/packed-matrix.h new file mode 100644 index 0000000..722d932 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/packed-matrix.h @@ -0,0 +1,197 @@ +// matrix/packed-matrix.h + +// Copyright 2009-2013 Ondrej Glembek; Lukas Burget; Microsoft Corporation; +// Saarland University; Yanmin Qian; +// Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_PACKED_MATRIX_H_ +#define KALDI_MATRIX_PACKED_MATRIX_H_ + +#include "matrix/matrix-common.h" +#include <algorithm> + +namespace kaldi { + +/// \addtogroup matrix_funcs_io +// we need to declare the friend << operator here +template<typename Real> +std::ostream & operator <<(std::ostream & out, const PackedMatrix<Real>& M); + + +/// \addtogroup matrix_group +/// @{ + +/// @brief Packed matrix: base class for triangular and symmetric matrices. +template<typename Real> class PackedMatrix { + friend class CuPackedMatrix<Real>; + public: + //friend class CuPackedMatrix<Real>; + + PackedMatrix() : data_(NULL), num_rows_(0) {} + + explicit PackedMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero): + data_(NULL) { Resize(r, resize_type); } + + explicit PackedMatrix(const PackedMatrix<Real> &orig) : data_(NULL) { + Resize(orig.num_rows_, kUndefined); + CopyFromPacked(orig); + } + + template<typename OtherReal> + explicit PackedMatrix(const PackedMatrix<OtherReal> &orig) : data_(NULL) { + Resize(orig.NumRows(), kUndefined); + CopyFromPacked(orig); + } + + void SetZero(); /// < Set to zero + void SetUnit(); /// < Set to unit matrix. + void SetRandn(); /// < Set to random values of a normal distribution + + Real Trace() const; + + // Needed for inclusion in std::vector + PackedMatrix<Real> & operator =(const PackedMatrix<Real> &other) { + Resize(other.NumRows()); + CopyFromPacked(other); + return *this; + } + + ~PackedMatrix() { + Destroy(); + } + + /// Set packed matrix to a specified size (can be zero). + /// The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero); + + void AddToDiag(const Real r); // Adds r to diaginal + + void ScaleDiag(const Real alpha); // Scales diagonal by alpha. + + void SetDiag(const Real alpha); // Sets diagonal to this value. + + template<typename OtherReal> + void CopyFromPacked(const PackedMatrix<OtherReal> &orig); + + /// CopyFromVec just interprets the vector as having the same layout + /// as the packed matrix. Must have the same dimension, i.e. + /// orig.Dim() == (NumRows()*(NumRows()+1)) / 2; + template<typename OtherReal> + void CopyFromVec(const SubVector<OtherReal> &orig); + + Real* Data() { return data_; } + const Real* Data() const { return data_; } + inline MatrixIndexT NumRows() const { return num_rows_; } + inline MatrixIndexT NumCols() const { return num_rows_; } + size_t SizeInBytes() const { + size_t nr = static_cast<size_t>(num_rows_); + return ((nr * (nr+1)) / 2) * sizeof(Real); + } + + //MatrixIndexT Stride() const { return stride_; } + + // This code is duplicated in child classes to avoid extra levels of calls. + Real operator() (MatrixIndexT r, MatrixIndexT c) const { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(num_rows_) && + static_cast<UnsignedMatrixIndexT>(c) < + static_cast<UnsignedMatrixIndexT>(num_rows_) + && c <= r); + return *(data_ + (r * (r + 1)) / 2 + c); + } + + // This code is duplicated in child classes to avoid extra levels of calls. + Real &operator() (MatrixIndexT r, MatrixIndexT c) { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(num_rows_) && + static_cast<UnsignedMatrixIndexT>(c) < + static_cast<UnsignedMatrixIndexT>(num_rows_) + && c <= r); + return *(data_ + (r * (r + 1)) / 2 + c); + } + + Real Max() const { + KALDI_ASSERT(num_rows_ > 0); + return * (std::max_element(data_, data_ + ((num_rows_*(num_rows_+1))/2) )); + } + + Real Min() const { + KALDI_ASSERT(num_rows_ > 0); + return * (std::min_element(data_, data_ + ((num_rows_*(num_rows_+1))/2) )); + } + + void Scale(Real c); + + friend std::ostream & operator << <> (std::ostream & out, + const PackedMatrix<Real> &m); + // Use instead of stream<<*this, if you want to add to existing contents. + // Will throw exception on failure. + void Read(std::istream &in, bool binary, bool add = false); + + void Write(std::ostream &out, bool binary) const; + + void Destroy(); + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(PackedMatrix<Real> *other); + void Swap(Matrix<Real> *other); + + + protected: + // Will only be called from this class or derived classes. + void AddPacked(const Real alpha, const PackedMatrix<Real>& M); + Real *data_; + MatrixIndexT num_rows_; + //MatrixIndexT stride_; + private: + /// Init assumes the current contents of the class are is invalid (i.e. junk or + /// has already been freed), and it sets the matrixd to newly allocated memory + /// with the specified dimension. dim == 0 is acceptable. The memory contents + /// pointed to by data_ will be undefined. + void Init(MatrixIndexT dim); + +}; +/// @} end "addtogroup matrix_group" + + +/// \addtogroup matrix_funcs_io +/// @{ + +template<typename Real> +std::ostream & operator << (std::ostream & os, const PackedMatrix<Real>& M) { + M.Write(os, false); + return os; +} + +template<typename Real> +std::istream & operator >> (std::istream &is, PackedMatrix<Real> &M) { + M.Read(is, false); + return is; +} + +/// @} + +} // namespace kaldi + +#endif + diff --git a/kaldi_io/src/kaldi/matrix/sp-matrix-inl.h b/kaldi_io/src/kaldi/matrix/sp-matrix-inl.h new file mode 100644 index 0000000..1579592 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/sp-matrix-inl.h @@ -0,0 +1,42 @@ +// matrix/sp-matrix-inl.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_SP_MATRIX_INL_H_ +#define KALDI_MATRIX_SP_MATRIX_INL_H_ + +#include "matrix/tp-matrix.h" + +namespace kaldi { + +// All the lines in this file seem to be declaring template specializations. +// These tell the compiler that we'll implement the templated function +// separately for the different template arguments (float, double). + +template<> +double SolveQuadraticProblem(const SpMatrix<double> &H, const VectorBase<double> &g, + const SolverOptions &opts, VectorBase<double> *x); + +template<> +float SolveQuadraticProblem(const SpMatrix<float> &H, const VectorBase<float> &g, + const SolverOptions &opts, VectorBase<float> *x); + +} // namespace kaldi + + +#endif // KALDI_MATRIX_SP_MATRIX_INL_H_ diff --git a/kaldi_io/src/kaldi/matrix/sp-matrix.h b/kaldi_io/src/kaldi/matrix/sp-matrix.h new file mode 100644 index 0000000..209d24a --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/sp-matrix.h @@ -0,0 +1,524 @@ +// matrix/sp-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University; Ariya Rastrow; Yanmin Qian; +// Jan Silovsky + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_SP_MATRIX_H_ +#define KALDI_MATRIX_SP_MATRIX_H_ + +#include <algorithm> +#include <vector> + +#include "matrix/packed-matrix.h" + +namespace kaldi { + + +/// \addtogroup matrix_group +/// @{ +template<typename Real> class SpMatrix; + + +/** + * @brief Packed symetric matrix class +*/ +template<typename Real> +class SpMatrix : public PackedMatrix<Real> { + friend class CuSpMatrix<Real>; + public: + // so it can use our assignment operator. + friend class std::vector<Matrix<Real> >; + + SpMatrix(): PackedMatrix<Real>() {} + + /// Copy constructor from CUDA version of SpMatrix + /// This is defined in ../cudamatrix/cu-sp-matrix.h + + explicit SpMatrix(const CuSpMatrix<Real> &cu); + + explicit SpMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero) + : PackedMatrix<Real>(r, resize_type) {} + + SpMatrix(const SpMatrix<Real> &orig) + : PackedMatrix<Real>(orig) {} + + template<typename OtherReal> + explicit SpMatrix(const SpMatrix<OtherReal> &orig) + : PackedMatrix<Real>(orig) {} + +#ifdef KALDI_PARANOID + explicit SpMatrix(const MatrixBase<Real> & orig, + SpCopyType copy_type = kTakeMeanAndCheck) + : PackedMatrix<Real>(orig.NumRows(), kUndefined) { + CopyFromMat(orig, copy_type); + } +#else + explicit SpMatrix(const MatrixBase<Real> & orig, + SpCopyType copy_type = kTakeMean) + : PackedMatrix<Real>(orig.NumRows(), kUndefined) { + CopyFromMat(orig, copy_type); + } +#endif + + /// Shallow swap. + void Swap(SpMatrix *other); + + inline void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero) { + PackedMatrix<Real>::Resize(nRows, resize_type); + } + + void CopyFromSp(const SpMatrix<Real> &other) { + PackedMatrix<Real>::CopyFromPacked(other); + } + + template<typename OtherReal> + void CopyFromSp(const SpMatrix<OtherReal> &other) { + PackedMatrix<Real>::CopyFromPacked(other); + } + +#ifdef KALDI_PARANOID + void CopyFromMat(const MatrixBase<Real> &orig, + SpCopyType copy_type = kTakeMeanAndCheck); +#else // different default arg if non-paranoid mode. + void CopyFromMat(const MatrixBase<Real> &orig, + SpCopyType copy_type = kTakeMean); +#endif + + inline Real operator() (MatrixIndexT r, MatrixIndexT c) const { + // if column is less than row, then swap these as matrix is stored + // as upper-triangular... only allowed for const matrix object. + if (static_cast<UnsignedMatrixIndexT>(c) > + static_cast<UnsignedMatrixIndexT>(r)) + std::swap(c, r); + // c<=r now so don't have to check c. + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(this->num_rows_)); + return *(this->data_ + (r*(r+1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + + inline Real &operator() (MatrixIndexT r, MatrixIndexT c) { + if (static_cast<UnsignedMatrixIndexT>(c) > + static_cast<UnsignedMatrixIndexT>(r)) + std::swap(c, r); + // c<=r now so don't have to check c. + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(this->num_rows_)); + return *(this->data_ + (r * (r + 1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + + using PackedMatrix<Real>::operator =; + using PackedMatrix<Real>::Scale; + + /// matrix inverse. + /// if inverse_needed = false, will fill matrix with garbage. + /// (only useful if logdet wanted). + void Invert(Real *logdet = NULL, Real *det_sign= NULL, + bool inverse_needed = true); + + // Below routine does inversion in double precision, + // even for single-precision object. + void InvertDouble(Real *logdet = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + + /// Returns maximum ratio of singular values. + inline Real Cond() const { + Matrix<Real> tmp(*this); + return tmp.Cond(); + } + + /// Takes matrix to a fraction power via Svd. + /// Will throw exception if matrix is not positive semidefinite + /// (to within a tolerance) + void ApplyPow(Real exponent); + + /// This is the version of SVD that we implement for symmetric positive + /// definite matrices. This exists for historical reasons; right now its + /// internal implementation is the same as Eig(). It computes the eigenvalue + /// decomposition (*this) = P * diag(s) * P^T with P orthogonal. Will throw + /// exception if input is not positive semidefinite to within a tolerance. + void SymPosSemiDefEig(VectorBase<Real> *s, MatrixBase<Real> *P, + Real tolerance = 0.001) const; + + /// Solves the symmetric eigenvalue problem: at end we should have (*this) = P + /// * diag(s) * P^T. We solve the problem using the symmetric QR method. + /// P may be NULL. + /// Implemented in qr.cc. + /// If you need the eigenvalues sorted, the function SortSvd declared in + /// kaldi-matrix is suitable. + void Eig(VectorBase<Real> *s, MatrixBase<Real> *P = NULL) const; + + /// This function gives you, approximately, the largest eigenvalues of the + /// symmetric matrix and the corresponding eigenvectors. (largest meaning, + /// further from zero). It does this by doing a SVD within the Krylov + /// subspace generated by this matrix and a random vector. This is + /// a form of the Lanczos method with complete reorthogonalization, followed + /// by SVD within a smaller dimension ("lanczos_dim"). + /// + /// If *this is m by m, s should be of dimension n and P should be of + /// dimension m by n, with n <= m. The *columns* of P are the approximate + /// eigenvectors; P * diag(s) * P^T would be a low-rank reconstruction of + /// *this. The columns of P will be orthogonal, and the elements of s will be + /// the eigenvalues of *this projected into that subspace, but beyond that + /// there are no exact guarantees. (This is because the convergence of this + /// method is statistical). Note: it only makes sense to use this + /// method if you are in very high dimension and n is substantially smaller + /// than m: for example, if you want the 100 top eigenvalues of a 10k by 10k + /// matrix. This function calls Rand() to initialize the lanczos + /// iterations and also for restarting. + /// If lanczos_dim is zero, it will default to the greater of: + /// s->Dim() + 50 or s->Dim() + s->Dim()/2, but not more than this->Dim(). + /// If lanczos_dim == this->Dim(), you might as well just call the function + /// Eig() since the result will be the same, and Eig() would be faster; the + /// whole point of this function is to reduce the dimension of the SVD + /// computation. + void TopEigs(VectorBase<Real> *s, MatrixBase<Real> *P, + MatrixIndexT lanczos_dim = 0) const; + + + + /// Takes log of the matrix (does eigenvalue decomposition then takes + /// log of eigenvalues and reconstructs). Will throw of not +ve definite. + void Log(); + + + // Takes exponential of the matrix (equivalent to doing eigenvalue + // decomposition then taking exp of eigenvalues and reconstructing). + void Exp(); + + /// Returns the maximum of the absolute values of any of the + /// eigenvalues. + Real MaxAbsEig() const; + + void PrintEigs(const char *name) { + Vector<Real> s((*this).NumRows()); + Matrix<Real> P((*this).NumRows(), (*this).NumCols()); + SymPosSemiDefEig(&s, &P); + KALDI_LOG << "PrintEigs: " << name << ": " << s; + } + + bool IsPosDef() const; // returns true if Cholesky succeeds. + void AddSp(const Real alpha, const SpMatrix<Real> &Ma) { + this->AddPacked(alpha, Ma); + } + + /// Computes log determinant but only for +ve-def matrices + /// (it uses Cholesky). + /// If matrix is not +ve-def, it will throw an exception + /// was LogPDDeterminant() + Real LogPosDefDet() const; + + Real LogDet(Real *det_sign = NULL) const; + + /// rank-one update, this <-- this + alpha v v' + template<typename OtherReal> + void AddVec2(const Real alpha, const VectorBase<OtherReal> &v); + + /// rank-two update, this <-- this + alpha (v w' + w v'). + void AddVecVec(const Real alpha, const VectorBase<Real> &v, + const VectorBase<Real> &w); + + /// Does *this = beta * *thi + alpha * diag(v) * S * diag(v) + void AddVec2Sp(const Real alpha, const VectorBase<Real> &v, + const SpMatrix<Real> &S, const Real beta); + + /// diagonal update, this <-- this + diag(v) + template<typename OtherReal> + void AddDiagVec(const Real alpha, const VectorBase<OtherReal> &v); + + /// rank-N update: + /// if (transM == kNoTrans) + /// (*this) = beta*(*this) + alpha * M * M^T, + /// or (if transM == kTrans) + /// (*this) = beta*(*this) + alpha * M^T * M + /// Note: beta used to default to 0.0. + void AddMat2(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transM, const Real beta); + + /// Extension of rank-N update: + /// this <-- beta*this + alpha * M * A * M^T. + /// (*this) and A are allowed to be the same. + /// If transM == kTrans, then we do it as M^T * A * M. + void AddMat2Sp(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transM, const SpMatrix<Real> &A, + const Real beta = 0.0); + + /// This is a version of AddMat2Sp specialized for when M is fairly sparse. + /// This was required for making the raw-fMLLR code efficient. + void AddSmat2Sp(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transM, const SpMatrix<Real> &A, + const Real beta = 0.0); + + /// The following function does: + /// this <-- beta*this + alpha * T * A * T^T. + /// (*this) and A are allowed to be the same. + /// If transM == kTrans, then we do it as alpha * T^T * A * T. + /// Currently it just calls AddMat2Sp, but if needed we + /// can implement it more efficiently. + void AddTp2Sp(const Real alpha, const TpMatrix<Real> &T, + MatrixTransposeType transM, const SpMatrix<Real> &A, + const Real beta = 0.0); + + /// The following function does: + /// this <-- beta*this + alpha * T * T^T. + /// (*this) and A are allowed to be the same. + /// If transM == kTrans, then we do it as alpha * T^T * T + /// Currently it just calls AddMat2, but if needed we + /// can implement it more efficiently. + void AddTp2(const Real alpha, const TpMatrix<Real> &T, + MatrixTransposeType transM, const Real beta = 0.0); + + /// Extension of rank-N update: + /// this <-- beta*this + alpha * M * diag(v) * M^T. + /// if transM == kTrans, then + /// this <-- beta*this + alpha * M^T * diag(v) * M. + void AddMat2Vec(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transM, const VectorBase<Real> &v, + const Real beta = 0.0); + + + /// Floors this symmetric matrix to the matrix + /// alpha * Floor, where the matrix Floor is positive + /// definite. + /// It is floored in the sense that after flooring, + /// x^T (*this) x >= x^T (alpha*Floor) x. + /// This is accomplished using an Svd. It will crash + /// if Floor is not positive definite. Returns the number of + /// elements that were floored. + int ApplyFloor(const SpMatrix<Real> &Floor, Real alpha = 1.0, + bool verbose = false); + + /// Floor: Given a positive semidefinite matrix, floors the eigenvalues + /// to the specified quantity. A previous version of this function had + /// a tolerance which is now no longer needed since we have code to + /// do the symmetric eigenvalue decomposition and no longer use the SVD + /// code for that purose. + int ApplyFloor(Real floor); + + bool IsDiagonal(Real cutoff = 1.0e-05) const; + bool IsUnit(Real cutoff = 1.0e-05) const; + bool IsZero(Real cutoff = 1.0e-05) const; + bool IsTridiagonal(Real cutoff = 1.0e-05) const; + + /// sqrt of sum of square elements. + Real FrobeniusNorm() const; + + /// Returns true if ((*this)-other).FrobeniusNorm() <= + /// tol*(*this).FrobeniusNorma() + bool ApproxEqual(const SpMatrix<Real> &other, float tol = 0.01) const; + + // LimitCond: + // Limits the condition of symmetric positive semidefinite matrix to + // a specified value + // by flooring all eigenvalues to a positive number which is some multiple + // of the largest one (or zero if there are no positive eigenvalues). + // Takes the condition number we are willing to accept, and floors + // eigenvalues to the largest eigenvalue divided by this. + // Returns #eigs floored or already equal to the floor. + // Throws exception if input is not positive definite. + // returns #floored. + MatrixIndexT LimitCond(Real maxCond = 1.0e+5, bool invert = false); + + // as LimitCond but all done in double precision. // returns #floored. + MatrixIndexT LimitCondDouble(Real maxCond = 1.0e+5, bool invert = false) { + SpMatrix<double> dmat(*this); + MatrixIndexT ans = dmat.LimitCond(maxCond, invert); + (*this).CopyFromSp(dmat); + return ans; + } + Real Trace() const; + + /// Tridiagonalize the matrix with an orthogonal transformation. If + /// *this starts as S, produce T (and Q, if non-NULL) such that + /// T = Q A Q^T, i.e. S = Q^T T Q. Caution: this is the other way + /// round from most authors (it's more efficient in row-major indexing). + void Tridiagonalize(MatrixBase<Real> *Q); + + /// The symmetric QR algorithm. This will mostly be useful in internal code. + /// Typically, you will call this after Tridiagonalize(), on the same object. + /// When called, *this (call it A at this point) must be tridiagonal; at exit, + /// *this will be a diagonal matrix D that is similar to A via orthogonal + /// transformations. This algorithm right-multiplies Q by orthogonal + /// transformations. It turns *this from a tridiagonal into a diagonal matrix + /// while maintaining that (Q *this Q^T) has the same value at entry and exit. + /// At entry Q should probably be either NULL or orthogonal, but we don't check + /// this. + void Qr(MatrixBase<Real> *Q); + + private: + void EigInternal(VectorBase<Real> *s, MatrixBase<Real> *P, + Real tolerance, int recurse) const; +}; + +/// @} end of "addtogroup matrix_group" + +/// \addtogroup matrix_funcs_scalar +/// @{ + + +/// Returns tr(A B). +float TraceSpSp(const SpMatrix<float> &A, const SpMatrix<float> &B); +double TraceSpSp(const SpMatrix<double> &A, const SpMatrix<double> &B); + + +template<typename Real> +inline bool ApproxEqual(const SpMatrix<Real> &A, + const SpMatrix<Real> &B, Real tol = 0.01) { + return A.ApproxEqual(B, tol); +} + +template<typename Real> +inline void AssertEqual(const SpMatrix<Real> &A, + const SpMatrix<Real> &B, Real tol = 0.01) { + KALDI_ASSERT(ApproxEqual(A, B, tol)); +} + + + +/// Returns tr(A B). +template<typename Real, typename OtherReal> +Real TraceSpSp(const SpMatrix<Real> &A, const SpMatrix<OtherReal> &B); + + + +// TraceSpSpLower is the same as Trace(A B) except the lower-diagonal elements +// are counted only once not twice as they should be. It is useful in certain +// optimizations. +template<typename Real> +Real TraceSpSpLower(const SpMatrix<Real> &A, const SpMatrix<Real> &B); + + +/// Returns tr(A B). +/// No option to transpose B because would make no difference. +template<typename Real> +Real TraceSpMat(const SpMatrix<Real> &A, const MatrixBase<Real> &B); + +/// Returns tr(A B C) +/// (A and C may be transposed as specified by transA and transC). +template<typename Real> +Real TraceMatSpMat(const MatrixBase<Real> &A, MatrixTransposeType transA, + const SpMatrix<Real> &B, const MatrixBase<Real> &C, + MatrixTransposeType transC); + +/// Returns tr (A B C D) +/// (A and C may be transposed as specified by transA and transB). +template<typename Real> +Real TraceMatSpMatSp(const MatrixBase<Real> &A, MatrixTransposeType transA, + const SpMatrix<Real> &B, const MatrixBase<Real> &C, + MatrixTransposeType transC, const SpMatrix<Real> &D); + +/** Computes v1^T * M * v2. Not as efficient as it could be where v1 == v2 + * (but no suitable blas routines available). + */ + +/// Returns \f$ v_1^T M v_2 \f$ +/// Not as efficient as it could be where v1 == v2. +template<typename Real> +Real VecSpVec(const VectorBase<Real> &v1, const SpMatrix<Real> &M, + const VectorBase<Real> &v2); + + +/// @} \addtogroup matrix_funcs_scalar + +/// \addtogroup matrix_funcs_misc +/// @{ + + +/// This class describes the options for maximizing various quadratic objective +/// functions. It's mostly as described in the SGMM paper "the subspace +/// Gaussian mixture model -- a structured model for speech recognition", but +/// the diagonal_precondition option is newly added, to handle problems where +/// different dimensions have very different scaling (we recommend to use the +/// option but it's set false for back compatibility). +struct SolverOptions { + BaseFloat K; // maximum condition number + BaseFloat eps; + std::string name; + bool optimize_delta; + bool diagonal_precondition; + bool print_debug_output; + explicit SolverOptions(const std::string &name): + K(1.0e+4), eps(1.0e-40), name(name), + optimize_delta(true), diagonal_precondition(false), + print_debug_output(true) { } + SolverOptions(): K(1.0e+4), eps(1.0e-40), name("[unknown]"), + optimize_delta(true), diagonal_precondition(false), + print_debug_output(true) { } + void Check() const; +}; + + +/// Maximizes the auxiliary function +/// \f[ Q(x) = x.g - 0.5 x^T H x \f] +/// using a numerically stable method. Like a numerically stable version of +/// \f$ x := Q^{-1} g. \f$ +/// Assumes H positive semidefinite. +/// Returns the objective-function change. + +template<typename Real> +Real SolveQuadraticProblem(const SpMatrix<Real> &H, + const VectorBase<Real> &g, + const SolverOptions &opts, + VectorBase<Real> *x); + + + +/// Maximizes the auxiliary function : +/// \f[ Q(x) = tr(M^T P Y) - 0.5 tr(P M Q M^T) \f] +/// Like a numerically stable version of \f$ M := Y Q^{-1} \f$. +/// Assumes Q and P positive semidefinite, and matrix dimensions match +/// enough to make expressions meaningful. +/// This is mostly as described in the SGMM paper "the subspace Gaussian mixture +/// model -- a structured model for speech recognition", but the +/// diagonal_precondition option is newly added, to handle problems +/// where different dimensions have very different scaling (we recommend to use +/// the option but it's set false for back compatibility). +template<typename Real> +Real SolveQuadraticMatrixProblem(const SpMatrix<Real> &Q, + const MatrixBase<Real> &Y, + const SpMatrix<Real> &P, + const SolverOptions &opts, + MatrixBase<Real> *M); + +/// Maximizes the auxiliary function : +/// \f[ Q(M) = tr(M^T G) -0.5 tr(P_1 M Q_1 M^T) -0.5 tr(P_2 M Q_2 M^T). \f] +/// Encountered in matrix update with a prior. We also apply a limit on the +/// condition but it should be less frequently necessary, and can be set larger. +template<typename Real> +Real SolveDoubleQuadraticMatrixProblem(const MatrixBase<Real> &G, + const SpMatrix<Real> &P1, + const SpMatrix<Real> &P2, + const SpMatrix<Real> &Q1, + const SpMatrix<Real> &Q2, + const SolverOptions &opts, + MatrixBase<Real> *M); + + +/// @} End of "addtogroup matrix_funcs_misc" + +} // namespace kaldi + + +// Including the implementation (now actually just includes some +// template specializations). +#include "matrix/sp-matrix-inl.h" + + +#endif // KALDI_MATRIX_SP_MATRIX_H_ + diff --git a/kaldi_io/src/kaldi/matrix/srfft.h b/kaldi_io/src/kaldi/matrix/srfft.h new file mode 100644 index 0000000..c0d36af --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/srfft.h @@ -0,0 +1,132 @@ +// matrix/srfft.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. +// 2014 Daniel Povey +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// This file includes a modified version of code originally published in Malvar, +// H., "Signal processing with lapped transforms, " Artech House, Inc., 1992. The +// current copyright holder of the original code, Henrique S. Malvar, has given +// his permission for the release of this modified version under the Apache +// License v2.0. + +#ifndef KALDI_MATRIX_SRFFT_H_ +#define KALDI_MATRIX_SRFFT_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// @addtogroup matrix_funcs_misc +/// @{ + + +// This class is based on code by Henrique (Rico) Malvar, from his book +// "Signal Processing with Lapped Transforms" (1992). Copied with +// permission, optimized by Go Vivace Inc., and converted into C++ by +// Microsoft Corporation +// This is a more efficient way of doing the complex FFT than ComplexFft +// (declared in matrix-functios.h), but it only works for powers of 2. +// Note: in multi-threaded code, you would need to have one of these objects per +// thread, because multiple calls to Compute in parallel would not work. +template<typename Real> +class SplitRadixComplexFft { + public: + typedef MatrixIndexT Integer; + + // N is the number of complex points (must be a power of two, or this + // will crash). Note that the constructor does some work so it's best to + // initialize the object once and do the computation many times. + SplitRadixComplexFft(Integer N); + + // Does the FFT computation, given pointers to the real and + // imaginary parts. If "forward", do the forward FFT; else + // do the inverse FFT (without the 1/N factor). + // xr and xi are pointers to zero-based arrays of size N, + // containing the real and imaginary parts + // respectively. + void Compute(Real *xr, Real *xi, bool forward) const; + + // This version of Compute takes a single array of size N*2, + // containing [ r0 im0 r1 im1 ... ]. Otherwise its behavior is the + // same as the version above. + void Compute(Real *x, bool forward); + + + // This version of Compute is const; it operates on an array of size N*2 + // containing [ r0 im0 r1 im1 ... ], but it uses the argument "temp_buffer" as + // temporary storage instead of a class-member variable. It will allocate it if + // needed. + void Compute(Real *x, bool forward, std::vector<Real> *temp_buffer) const; + + ~SplitRadixComplexFft(); + + protected: + // temp_buffer_ is allocated only if someone calls Compute with only one Real* + // argument and we need a temporary buffer while creating interleaved data. + std::vector<Real> temp_buffer_; + private: + void ComputeTables(); + void ComputeRecursive(Real *xr, Real *xi, Integer logn) const; + void BitReversePermute(Real *x, Integer logn) const; + + Integer N_; + Integer logn_; // log(N) + + Integer *brseed_; + // brseed is Evans' seed table, ref: (Ref: D. M. W. + // Evans, "An improved digit-reversal permutation algorithm ...", + // IEEE Trans. ASSP, Aug. 1987, pp. 1120-1125). + Real **tab_; // Tables of butterfly coefficients. + + KALDI_DISALLOW_COPY_AND_ASSIGN(SplitRadixComplexFft); +}; + +template<typename Real> +class SplitRadixRealFft: private SplitRadixComplexFft<Real> { + public: + SplitRadixRealFft(MatrixIndexT N): // will fail unless N>=4 and N is a power of 2. + SplitRadixComplexFft<Real> (N/2), N_(N) { } + + /// If forward == true, this function transforms from a sequence of N real points to its complex fourier + /// transform; otherwise it goes in the reverse direction. If you call it + /// in the forward and then reverse direction and multiply by 1.0/N, you + /// will get back the original data. + /// The interpretation of the complex-FFT data is as follows: the array + /// is a sequence of complex numbers C_n of length N/2 with (real, im) format, + /// i.e. [real0, real_{N/2}, real1, im1, real2, im2, real3, im3, ...]. + void Compute(Real *x, bool forward); + + + /// This is as the other Compute() function, but it is a const version that + /// uses a user-supplied buffer. + void Compute(Real *x, bool forward, std::vector<Real> *temp_buffer) const; + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(SplitRadixRealFft); + int N_; +}; + + +/// @} end of "addtogroup matrix_funcs_misc" + +} // end namespace kaldi + + +#endif + diff --git a/kaldi_io/src/kaldi/matrix/tp-matrix.h b/kaldi_io/src/kaldi/matrix/tp-matrix.h new file mode 100644 index 0000000..f43e86c --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/tp-matrix.h @@ -0,0 +1,131 @@ +// matrix/tp-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Lukas Burget; Microsoft Corporation; +// Saarland University; Yanmin Qian; Haihua Xu +// 2013 Johns Hopkins Universith (author: Daniel Povey) + + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_TP_MATRIX_H_ +#define KALDI_MATRIX_TP_MATRIX_H_ + + +#include "matrix/packed-matrix.h" + +namespace kaldi { +/// \addtogroup matrix_group +/// @{ + +template<typename Real> class TpMatrix; + +/// @brief Packed symetric matrix class +template<typename Real> +class TpMatrix : public PackedMatrix<Real> { + friend class CuTpMatrix<float>; + friend class CuTpMatrix<double>; + public: + TpMatrix() : PackedMatrix<Real>() {} + explicit TpMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero) + : PackedMatrix<Real>(r, resize_type) {} + TpMatrix(const TpMatrix<Real>& orig) : PackedMatrix<Real>(orig) {} + + /// Copy constructor from CUDA TpMatrix + /// This is defined in ../cudamatrix/cu-tp-matrix.cc + explicit TpMatrix(const CuTpMatrix<Real> &cu); + + + template<typename OtherReal> explicit TpMatrix(const TpMatrix<OtherReal>& orig) + : PackedMatrix<Real>(orig) {} + + Real operator() (MatrixIndexT r, MatrixIndexT c) const { + if (static_cast<UnsignedMatrixIndexT>(c) > + static_cast<UnsignedMatrixIndexT>(r)) { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(c) < + static_cast<UnsignedMatrixIndexT>(this->num_rows_)); + return 0; + } + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(this->num_rows_)); + // c<=r now so don't have to check c. + return *(this->data_ + (r*(r+1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + + Real &operator() (MatrixIndexT r, MatrixIndexT c) { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(this->num_rows_)); + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(c) <= + static_cast<UnsignedMatrixIndexT>(r) && + "you cannot access the upper triangle of TpMatrix using " + "a non-const matrix object."); + return *(this->data_ + (r*(r+1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + // Note: Cholesky may throw std::runtime_error + void Cholesky(const SpMatrix<Real>& orig); + + void Invert(); + + // Inverts in double precision. + void InvertDouble() { + TpMatrix<double> dmat(*this); + dmat.Invert(); + (*this).CopyFromTp(dmat); + } + + /// Shallow swap + void Swap(TpMatrix<Real> *other); + + /// Returns the determinant of the matrix (product of diagonals) + Real Determinant(); + + /// CopyFromMat copies the lower triangle of M into *this + /// (or the upper triangle, if Trans == kTrans). + void CopyFromMat(const MatrixBase<Real> &M, + MatrixTransposeType Trans = kNoTrans); + + /// This is implemented in ../cudamatrix/cu-tp-matrix.cc + void CopyFromMat(const CuTpMatrix<Real> &other); + + /// CopyFromTp copies another triangular matrix into this one. + void CopyFromTp(const TpMatrix<Real> &other) { + PackedMatrix<Real>::CopyFromPacked(other); + } + + template<typename OtherReal> void CopyFromTp(const TpMatrix<OtherReal> &other) { + PackedMatrix<Real>::CopyFromPacked(other); + } + + /// AddTp does *this += alpha * M. + void AddTp(const Real alpha, const TpMatrix<Real> &M) { + this->AddPacked(alpha, M); + } + + using PackedMatrix<Real>::operator =; + using PackedMatrix<Real>::Scale; + + void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero) { + PackedMatrix<Real>::Resize(nRows, resize_type); + } +}; + +/// @} end of "addtogroup matrix_group". + +} // namespace kaldi + + +#endif + diff --git a/kaldi_io/src/kaldi/tree/build-tree-questions.h b/kaldi_io/src/kaldi/tree/build-tree-questions.h new file mode 100644 index 0000000..a6bcfdd --- /dev/null +++ b/kaldi_io/src/kaldi/tree/build-tree-questions.h @@ -0,0 +1,133 @@ +// tree/build-tree-questions.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_BUILD_TREE_QUESTIONS_H_ +#define KALDI_TREE_BUILD_TREE_QUESTIONS_H_ + +#include "util/stl-utils.h" +#include "tree/context-dep.h" + +namespace kaldi { + + +/// \addtogroup tree_group +/// @{ +/// Typedef for statistics to build trees. +typedef std::vector<std::pair<EventType, Clusterable*> > BuildTreeStatsType; + +/// Typedef used when we get "all keys" from a set of stats-- used in specifying +/// which kinds of questions to ask. +typedef enum { kAllKeysInsistIdentical, kAllKeysIntersection, kAllKeysUnion } AllKeysType; + +/// @} + +/// \defgroup tree_group_questions Question sets for decision-tree clustering +/// See \ref tree_internals (and specifically \ref treei_func_questions) for context. +/// \ingroup tree_group +/// @{ + +/// QuestionsForKey is a class used to define the questions for a key, +/// and also options that allow us to refine the question during tree-building +/// (i.e. make a question specific to the location in the tree). +/// The Questions class handles aggregating these options for a set +/// of different keys. +struct QuestionsForKey { // Configuration class associated with a particular key + // (of type EventKeyType). It also contains the questions themselves. + std::vector<std::vector<EventValueType> > initial_questions; + RefineClustersOptions refine_opts; // if refine_opts.max_iter == 0, + // we just pick from the initial questions. + + QuestionsForKey(int32 num_iters = 5): refine_opts(num_iters, 2) { + // refine_cfg with 5 iters and top-n = 2 (this is no restriction because + // RefineClusters called with 2 clusters; would get set to that anyway as + // it's the only possible value for 2 clusters). User has to add questions. + // This config won't work as-is, as it has no questions. + } + + void Check() const { + for (size_t i = 0;i < initial_questions.size();i++) KALDI_ASSERT(IsSorted(initial_questions[i])); + } + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + // copy and assign allowed. +}; + +/// This class defines, for each EventKeyType, a set of initial questions that +/// it tries and also a number of iterations for which to refine the questions to increase +/// likelihood. It is perhaps a bit more than an options class, as it contains the +/// actual questions. +class Questions { // careful, this is a class. + public: + const QuestionsForKey &GetQuestionsOf(EventKeyType key) const { + std::map<EventKeyType, size_t>::const_iterator iter; + if ( (iter = key_idx_.find(key)) == key_idx_.end()) { + KALDI_ERR << "Questions: no options for key "<< key; + } + size_t idx = iter->second; + KALDI_ASSERT(idx < key_options_.size()); + key_options_[idx]->Check(); + return *(key_options_[idx]); + } + void SetQuestionsOf(EventKeyType key, const QuestionsForKey &options_of_key) { + options_of_key.Check(); + if (key_idx_.count(key) == 0) { + key_idx_[key] = key_options_.size(); + key_options_.push_back(new QuestionsForKey()); + *(key_options_.back()) = options_of_key; + } else { + size_t idx = key_idx_[key]; + KALDI_ASSERT(idx < key_options_.size()); + *(key_options_[idx]) = options_of_key; + } + } + void GetKeysWithQuestions(std::vector<EventKeyType> *keys_out) const { + KALDI_ASSERT(keys_out != NULL); + CopyMapKeysToVector(key_idx_, keys_out); + } + const bool HasQuestionsForKey(EventKeyType key) const { return (key_idx_.count(key) != 0); } + ~Questions() { kaldi::DeletePointers(&key_options_); } + + + /// Initializer with arguments. After using this you would have to set up the config for each key you + /// are going to use, or use InitRand(). + Questions() { } + + + /// InitRand attempts to generate "reasonable" random questions. Only + /// of use for debugging. This initializer creates a config that is + /// ready to use. + /// e.g. num_iters_refine = 0 means just use stated questions (if >1, will use + /// different questions at each split of the tree). + void InitRand(const BuildTreeStatsType &stats, int32 num_quest, int32 num_iters_refine, AllKeysType all_keys_type); + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + private: + std::vector<QuestionsForKey*> key_options_; + std::map<EventKeyType, size_t> key_idx_; + KALDI_DISALLOW_COPY_AND_ASSIGN(Questions); +}; + +/// @} + +}// end namespace kaldi + +#endif // KALDI_TREE_BUILD_TREE_QUESTIONS_H_ diff --git a/kaldi_io/src/kaldi/tree/build-tree-utils.h b/kaldi_io/src/kaldi/tree/build-tree-utils.h new file mode 100644 index 0000000..464fc6b --- /dev/null +++ b/kaldi_io/src/kaldi/tree/build-tree-utils.h @@ -0,0 +1,324 @@ +// tree/build-tree-utils.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_BUILD_TREE_UTILS_H_ +#define KALDI_TREE_BUILD_TREE_UTILS_H_ + +#include "tree/build-tree-questions.h" + +// build-tree-questions.h needed for this typedef: +// typedef std::vector<std::pair<EventType, Clusterable*> > BuildTreeStatsType; +// and for other #includes. + +namespace kaldi { + + +/// \defgroup tree_group_lower Low-level functions for manipulating statistics and event-maps +/// See \ref tree_internals and specifically \ref treei_func for context. +/// \ingroup tree_group +/// +/// @{ + + + +/// This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL. +/// Does not delete the pointer "stats" itself. +void DeleteBuildTreeStats(BuildTreeStatsType *stats); + +/// Writes BuildTreeStats object. This works even if pointers are NULL. +void WriteBuildTreeStats(std::ostream &os, bool binary, + const BuildTreeStatsType &stats); + +/// Reads BuildTreeStats object. The "example" argument must be of the same +/// type as the stats on disk, and is needed for access to the correct "Read" +/// function. It was organized this way for easier extensibility (so adding new +/// Clusterable derived classes isn't painful) +void ReadBuildTreeStats(std::istream &is, bool binary, + const Clusterable &example, BuildTreeStatsType *stats); + +/// Convenience function e.g. to work out possible values of the phones from just the stats. +/// Returns true if key was always defined inside the stats. +/// May be used with and == NULL to find out of key was always defined. +bool PossibleValues(EventKeyType key, const BuildTreeStatsType &stats, + std::vector<EventValueType> *ans); + + +/// Splits stats according to the EventMap, indexing them at output by the +/// leaf type. A utility function. NOTE-- pointers in stats_out point to +/// the same memory location as those in stats. No copying of Clusterable* +/// objects happens. Will add to stats in stats_out if non-empty at input. +/// This function may increase the size of vector stats_out as necessary +/// to accommodate stats, but will never decrease the size. +void SplitStatsByMap(const BuildTreeStatsType &stats_in, const EventMap &e, + std::vector<BuildTreeStatsType> *stats_out); + +/// SplitStatsByKey splits stats up according to the value of a particular key, +/// which must be always defined and nonnegative. Like MapStats. Pointers to +/// Clusterable* in stats_out are not newly allocated-- they are the same as the +/// ones in stats_in. Generally they will still be owned at stats_in (user can +/// decide where to allocate ownership). +void SplitStatsByKey(const BuildTreeStatsType &stats_in, EventKeyType key, + std::vector<BuildTreeStatsType> *stats_out); + + +/// Converts stats from a given context-window (N) and central-position (P) to a +/// different N and P, by possibly reducing context. This function does a job +/// that's quite specific to the "normal" stats format we use. See \ref +/// tree_window for background. This function may delete some keys and change +/// others, depending on the N and P values. It expects that at input, all keys +/// will either be -1 or lie between 0 and oldN-1. At output, keys will be +/// either -1 or between 0 and newN-1. +/// Returns false if we could not convert the stats (e.g. because newN is larger +/// than oldN). +bool ConvertStats(int32 oldN, int32 oldP, int32 newN, int32 newP, + BuildTreeStatsType *stats); + + +/// FilterStatsByKey filters the stats according the value of a specified key. +/// If include_if_present == true, it only outputs the stats whose key is in +/// "values"; otherwise it only outputs the stats whose key is not in "values". +/// At input, "values" must be sorted and unique, and all stats in "stats_in" +/// must have "key" defined. At output, pointers to Clusterable* in stats_out +/// are not newly allocated-- they are the same as the ones in stats_in. +void FilterStatsByKey(const BuildTreeStatsType &stats_in, + EventKeyType key, + std::vector<EventValueType> &values, + bool include_if_present, // true-> retain only if in "values", + // false-> retain only if not in "values". + BuildTreeStatsType *stats_out); + + +/// Sums stats, or returns NULL stats_in has no non-NULL stats. +/// Stats are newly allocated, owned by caller. +Clusterable *SumStats(const BuildTreeStatsType &stats_in); + +/// Sums the normalizer [typically, data-count] over the stats. +BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in); + +/// Sums the objective function over the stats. +BaseFloat SumObjf(const BuildTreeStatsType &stats_in); + + +/// Sum a vector of stats. Leaves NULL as pointer if no stats available. +/// The pointers in stats_out are owned by caller. At output, there may be +/// NULLs in the vector stats_out. +void SumStatsVec(const std::vector<BuildTreeStatsType> &stats_in, std::vector<Clusterable*> *stats_out); + +/// Cluster the stats given the event map return the total objf given those clusters. +BaseFloat ObjfGivenMap(const BuildTreeStatsType &stats_in, const EventMap &e); + + +/// FindAllKeys puts in *keys the (sorted, unique) list of all key identities in the stats. +/// If type == kAllKeysInsistIdentical, it will insist that this set of keys is the same for all the +/// stats (else exception is thrown). +/// if type == kAllKeysIntersection, it will return the smallest common set of keys present in +/// the set of stats +/// if type== kAllKeysUnion (currently probably not so useful since maps will return "undefined" +/// if key is not present), it will return the union of all the keys present in the stats. +void FindAllKeys(const BuildTreeStatsType &stats, AllKeysType keys_type, + std::vector<EventKeyType> *keys); + + +/// @} + + +/** + \defgroup tree_group_intermediate Intermediate-level functions used in building the tree + These functions are are used in top-level tree-building code (\ref tree_group_top); see + \ref tree_internals for documentation. + \ingroup tree_group + @{ +*/ + + +/// Returns a tree with just one node. Used @ start of tree-building process. +/// Not really used in current recipes. +inline EventMap *TrivialTree(int32 *num_leaves) { + KALDI_ASSERT(*num_leaves == 0); // in envisaged usage. + return new ConstantEventMap( (*num_leaves)++ ); +} + +/// DoTableSplit does a complete split on this key (e.g. might correspond to central phone +/// (key = P-1), or HMM-state position (key == kPdfClass == -1). Stats used to work out possible +/// values of the event. "num_leaves" is used to allocate new leaves. All stats must have +/// this key defined, or this function will crash. +EventMap *DoTableSplit(const EventMap &orig, EventKeyType key, + const BuildTreeStatsType &stats, int32 *num_leaves); + + +/// DoTableSplitMultiple does a complete split on all the keys, in order from keys[0], +/// keys[1] +/// and so on. The stats are used to work out possible values corresponding to the key. +/// "num_leaves" is used to allocate new leaves. All stats must have +/// the keys defined, or this function will crash. +/// Returns a newly allocated event map. +EventMap *DoTableSplitMultiple(const EventMap &orig, + const std::vector<EventKeyType> &keys, + const BuildTreeStatsType &stats, + int32 *num_leaves); + + +/// "ClusterEventMapGetMapping" clusters the leaves of the EventMap, with "thresh" a delta-likelihood +/// threshold to control how many leaves we combine (might be the same as the delta-like +/// threshold used in splitting. +// The function returns the #leaves we combined. The same leaf-ids of the leaves being clustered +// will be used for the clustered leaves (but other than that there is no special rule which +// leaf-ids should be used at output). +// It outputs the mapping for leaves, in "mapping", which may be empty at the start +// but may also contain mappings for other parts of the tree, which must contain +// disjoint leaves from this part. This is so that Cluster can +// be called multiple times for sub-parts of the tree (with disjoint sets of leaves), +// e.g. if we want to avoid sharing across phones. Afterwards you can use Copy function +// of EventMap to apply the mapping, i.e. call e_in.Copy(mapping) to get the new map. +// Note that the application of Cluster creates gaps in the leaves. You should then +// call RenumberEventMap(e_in.Copy(mapping), num_leaves). +// *If you only want to cluster a subset of the leaves (e.g. just non-silence, or just +// a particular phone, do this by providing a set of "stats" that correspond to just +// this subset of leaves*. Leaves with no stats will not be clustered. +// See build-tree.cc for an example of usage. +int ClusterEventMapGetMapping(const EventMap &e_in, const BuildTreeStatsType &stats, + BaseFloat thresh, std::vector<EventMap*> *mapping); + +/// This is as ClusterEventMapGetMapping but a more convenient interface +/// that exposes less of the internals. It uses a bottom-up clustering to +/// combine the leaves, until the log-likelihood decrease from combinging two +/// leaves exceeds the threshold. +EventMap *ClusterEventMap(const EventMap &e_in, const BuildTreeStatsType &stats, + BaseFloat thresh, int32 *num_removed); + +/// This is as ClusterEventMap, but first splits the stats on the keys specified +/// in "keys" (e.g. typically keys = [ -1, P ]), and only clusters within the +/// classes defined by that splitting. +/// Note-- leaves will be non-consecutive at output, use RenumberEventMap. +EventMap *ClusterEventMapRestrictedByKeys(const EventMap &e_in, + const BuildTreeStatsType &stats, + BaseFloat thresh, + const std::vector<EventKeyType> &keys, + int32 *num_removed); + + +/// This version of ClusterEventMapRestricted restricts the clustering to only +/// allow things that "e_restrict" maps to the same value to be clustered +/// together. +EventMap *ClusterEventMapRestrictedByMap(const EventMap &e_in, + const BuildTreeStatsType &stats, + BaseFloat thresh, + const EventMap &e_restrict, + int32 *num_removed); + + +/// RenumberEventMap [intended to be used after calling ClusterEventMap] renumbers +/// an EventMap so its leaves are consecutive. +/// It puts the number of leaves in *num_leaves. If later you need the mapping of +/// the leaves, modify the function and add a new argument. +EventMap *RenumberEventMap(const EventMap &e_in, int32 *num_leaves); + +/// This function remaps the event-map leaves using this mapping, +/// indexed by the number at leaf. +EventMap *MapEventMapLeaves(const EventMap &e_in, + const std::vector<int32> &mapping); + + + +/// ShareEventMapLeaves performs a quite specific function that allows us to +/// generate trees where, for a certain list of phones, and for all states in +/// the phone, all the pdf's are shared. +/// Each element of "values" contains a list of phones (may be just one phone), +/// all states of which we want shared together). Typically at input, "key" will +/// equal P, the central-phone position, and "values" will contain just one +/// list containing the silence phone. +/// This function renumbers the event map leaves after doing the sharing, to +/// make the event-map leaves contiguous. +EventMap *ShareEventMapLeaves(const EventMap &e_in, EventKeyType key, + std::vector<std::vector<EventValueType> > &values, + int32 *num_leaves); + + + +/// Does a decision-tree split at the leaves of an EventMap. +/// @param orig [in] The EventMap whose leaves we want to split. [may be either a trivial or a +/// non-trivial one]. +/// @param stats [in] The statistics for splitting the tree; if you do not want a particular +/// subset of leaves to be split, make sure the stats corresponding to those leaves +/// are not present in "stats". +/// @param qcfg [in] Configuration class that contains initial questions (e.g. sets of phones) +/// for each key and says whether to refine these questions during tree building. +/// @param thresh [in] A log-likelihood threshold (e.g. 300) that can be used to +/// limit the number of leaves; you can use zero and set max_leaves instead. +/// @param max_leaves [in] Will stop leaves being split after they reach this number. +/// @param num_leaves [in,out] A pointer used to allocate leaves; always corresponds to the +/// current number of leaves (is incremented when this is increased). +/// @param objf_impr_out [out] If non-NULL, will be set to the objective improvement due to splitting +/// (not normalized by the number of frames). +/// @param smallest_split_change_out If non-NULL, will be set to the smallest objective-function +/// improvement that we got from splitting any leaf; useful to provide a threshold +/// for ClusterEventMap. +/// @return The EventMap after splitting is returned; pointer is owned by caller. +EventMap *SplitDecisionTree(const EventMap &orig, + const BuildTreeStatsType &stats, + Questions &qcfg, + BaseFloat thresh, + int32 max_leaves, // max_leaves<=0 -> no maximum. + int32 *num_leaves, + BaseFloat *objf_impr_out, + BaseFloat *smallest_split_change_out); + +/// CreateRandomQuestions will initialize a Questions randomly, in a reasonable +/// way [for testing purposes, or when hand-designed questions are not available]. +/// e.g. num_quest = 5 might be a reasonable value if num_iters > 0, or num_quest = 20 otherwise. +void CreateRandomQuestions(const BuildTreeStatsType &stats, int32 num_quest, Questions *cfg_out); + + +/// FindBestSplitForKey is a function used in DoDecisionTreeSplit. +/// It finds the best split for this key, given these stats. +/// It will return 0 if the key was not always defined for the stats. +BaseFloat FindBestSplitForKey(const BuildTreeStatsType &stats, + const Questions &qcfg, + EventKeyType key, + std::vector<EventValueType> *yes_set); + + +/// GetStubMap is used in tree-building functions to get the initial +/// to-states map, before the decision-tree-building process. It creates +/// a simple map that splits on groups of phones. For the set of phones in +/// phone_sets[i] it creates either: if share_roots[i] == true, a single +/// leaf node, or if share_roots[i] == false, separate root nodes for +/// each HMM-position (it goes up to the highest position for any +/// phone in the set, although it will warn if you share roots between +/// phones with different numbers of states, which is a weird thing to +/// do but should still work. If any phone is present +/// in "phone_sets" but "phone2num_pdf_classes" does not map it to a length, +/// it is an error. Note that the behaviour of the resulting map is +/// undefined for phones not present in "phone_sets". +/// At entry, this function should be called with (*num_leaves == 0). +/// It will number the leaves starting from (*num_leaves). + +EventMap *GetStubMap(int32 P, + const std::vector<std::vector<int32> > &phone_sets, + const std::vector<int32> &phone2num_pdf_classes, + const std::vector<bool> &share_roots, // indexed by index into phone_sets. + int32 *num_leaves); +/// Note: GetStubMap with P = 0 can be used to get a standard monophone system. + +/// @} + + +}// end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/tree/build-tree.h b/kaldi_io/src/kaldi/tree/build-tree.h new file mode 100644 index 0000000..37bb108 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/build-tree.h @@ -0,0 +1,250 @@ +// tree/build-tree.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_BUILD_TREE_H_ +#define KALDI_TREE_BUILD_TREE_H_ + +// The file build-tree.h contains outer-level routines used in tree-building +// and related tasks, that are directly called by the command-line tools. + +#include "tree/build-tree-utils.h" +#include "tree/context-dep.h" +namespace kaldi { + +/// \defgroup tree_group_top Top-level tree-building functions +/// See \ref tree_internals for context. +/// \ingroup tree_group +/// @{ + +// Note, in tree_group_top we also include AccumulateTreeStats, in +// ../hmm/tree-accu.h (it has some extra dependencies so we didn't +// want to include it here). + +/** + * BuildTree is the normal way to build a set of decision trees. + * The sets "phone_sets" dictate how we set up the roots of the decision trees. + * each set of phones phone_sets[i] has shared decision-tree roots, and if + * the corresponding variable share_roots[i] is true, the root will be shared + * for the different HMM-positions in the phone. All phones in "phone_sets" + * should be in the stats (use FixUnseenPhones to ensure this). + * if for any i, do_split[i] is false, we will not do any tree splitting for + * phones in that set. + * @param qopts [in] Questions options class, contains questions for each key + * (e.g. each phone position) + * @param phone_sets [in] Each element of phone_sets is a set of phones whose + * roots are shared together (prior to decision-tree splitting). + * @param phone2num_pdf_classes [in] A map from phones to the number of + * \ref pdf_class "pdf-classes" + * in the phone (this info is derived from the HmmTopology object) + * @param share_roots [in] A vector the same size as phone_sets; says for each + * phone set whether the root should be shared among all the + * pdf-classes or not. + * @param do_split [in] A vector the same size as phone_sets; says for each + * phone set whether decision-tree splitting should be done + * (generally true for non-silence phones). + * @param stats [in] The statistics used in tree-building. + * @param thresh [in] Threshold used in decision-tree splitting (e.g. 1000), + * or you may use 0 in which case max_leaves becomes the + * constraint. + * @param max_leaves [in] Maximum number of leaves it will create; set this + * to a large number if you want to just specify "thresh". + * @param cluster_thresh [in] Threshold for clustering leaves after decision-tree + * splitting (only within each phone-set); leaves will be combined + * if log-likelihood change is less than this. A value about equal + * to "thresh" is suitable + * if thresh != 0; otherwise, zero will mean no clustering is done, + * or a negative value (e.g. -1) sets it to the smallest likelihood + * change seen during the splitting algorithm; this typically causes + * about a 20% reduction in the number of leaves. + + * @param P [in] The central position of the phone context window, e.g. 1 for a + * triphone system. + * @return Returns a pointer to an EventMap object that is the tree. + +*/ + +EventMap *BuildTree(Questions &qopts, + const std::vector<std::vector<int32> > &phone_sets, + const std::vector<int32> &phone2num_pdf_classes, + const std::vector<bool> &share_roots, + const std::vector<bool> &do_split, + const BuildTreeStatsType &stats, + BaseFloat thresh, + int32 max_leaves, + BaseFloat cluster_thresh, // typically == thresh. If negative, use smallest split. + int32 P); + + +/** + * + * BuildTreeTwoLevel builds a two-level tree, useful for example in building tied mixture + * systems with multiple codebooks. It first builds a small tree by splitting to + * "max_leaves_first". It then splits at the leaves of "max_leaves_first" (think of this + * as creating multiple little trees at the leaves of the first tree), until the total + * number of leaves reaches "max_leaves_second". It then outputs the second tree, along + * with a mapping from the leaf-ids of the second tree to the leaf-ids of the first tree. + * Note that the interface is similar to BuildTree, and in fact it calls BuildTree + * internally. + * + * The sets "phone_sets" dictate how we set up the roots of the decision trees. + * each set of phones phone_sets[i] has shared decision-tree roots, and if + * the corresponding variable share_roots[i] is true, the root will be shared + * for the different HMM-positions in the phone. All phones in "phone_sets" + * should be in the stats (use FixUnseenPhones to ensure this). + * if for any i, do_split[i] is false, we will not do any tree splitting for + * phones in that set. + * + * @param qopts [in] Questions options class, contains questions for each key + * (e.g. each phone position) + * @param phone_sets [in] Each element of phone_sets is a set of phones whose + * roots are shared together (prior to decision-tree splitting). + * @param phone2num_pdf_classes [in] A map from phones to the number of + * \ref pdf_class "pdf-classes" + * in the phone (this info is derived from the HmmTopology object) + * @param share_roots [in] A vector the same size as phone_sets; says for each + * phone set whether the root should be shared among all the + * pdf-classes or not. + * @param do_split [in] A vector the same size as phone_sets; says for each + * phone set whether decision-tree splitting should be done + * (generally true for non-silence phones). + * @param stats [in] The statistics used in tree-building. + * @param max_leaves_first [in] Maximum number of leaves it will create in first + * level of decision tree. + * @param max_leaves_second [in] Maximum number of leaves it will create in second + * level of decision tree. Must be > max_leaves_first. + * @param cluster_leaves [in] Boolean value; if true, we post-cluster the leaves produced + * in the second level of decision-tree split; if false, we don't. + * The threshold for post-clustering is the log-like change of the last + * decision-tree split; this typically causes about a 20% reduction in + * the number of leaves. + * @param P [in] The central position of the phone context window, e.g. 1 for a + * triphone system. + * @param leaf_map [out] Will be set to be a mapping from the leaves of the + * "big" tree to the leaves of the "little" tree, which you can + * view as cluster centers. + * @return Returns a pointer to an EventMap object that is the (big) tree. + +*/ + +EventMap *BuildTreeTwoLevel(Questions &qopts, + const std::vector<std::vector<int32> > &phone_sets, + const std::vector<int32> &phone2num_pdf_classes, + const std::vector<bool> &share_roots, + const std::vector<bool> &do_split, + const BuildTreeStatsType &stats, + int32 max_leaves_first, + int32 max_leaves_second, + bool cluster_leaves, + int32 P, + std::vector<int32> *leaf_map); + + +/// GenRandStats generates random statistics of the form used by BuildTree. +/// It tries to do so in such a way that they mimic "real" stats. The event keys +/// and their corresponding values are: +/// - key == -1 == kPdfClass -> pdf-class, generally corresponds to +/// zero-based position in HMM (0, 1, 2 .. hmm_lengths[phone]-1) +/// - key == 0 -> phone-id of left-most context phone. +/// - key == 1 -> phone-id of one-from-left-most context phone. +/// - key == P-1 -> phone-id of central phone. +/// - key == N-1 -> phone-id of right-most context phone. +/// GenRandStats is useful only for testing but it serves to document the format of +/// stats used by BuildTreeDefault. +/// if is_ctx_dep[phone] is set to false, GenRandStats will not define the keys for +/// other than the P-1'th phone. + +/// @param dim [in] dimension of features. +/// @param num_stats [in] approximate number of separate phones-in-context wanted. +/// @param N [in] context-size (typically 3) +/// @param P [in] central-phone position in zero-based numbering (typically 1) +/// @param phone_ids [in] integer ids of phones +/// @param hmm_lengths [in] lengths of hmm for phone, indexed by phone. +/// @param is_ctx_dep [in] boolean array indexed by phone, saying whether each phone +/// is context dependent. +/// @param ensure_all_phones_covered [in] Boolean argument: if true, GenRandStats +/// ensures that every phone is seen at least once in the central position (P). +/// @param stats_out [out] The statistics that this routine outputs. + +void GenRandStats(int32 dim, int32 num_stats, int32 N, int32 P, + const std::vector<int32> &phone_ids, + const std::vector<int32> &hmm_lengths, + const std::vector<bool> &is_ctx_dep, + bool ensure_all_phones_covered, + BuildTreeStatsType *stats_out); + + +/// included here because it's used in some tree-building +/// calling code. Reads an OpenFst symbl table, +/// discards the symbols and outputs the integers +void ReadSymbolTableAsIntegers(std::string filename, + bool include_eps, + std::vector<int32> *syms); + + + +/** + * Outputs sets of phones that are reasonable for questions + * to ask in the tree-building algorithm. These are obtained by tree + * clustering of the phones; for each node in the tree, all the leaves + * accessible from that node form one of the sets of phones. + * @param stats [in] The statistics as used for normal tree-building. + * @param phone_sets_in [in] All the phones, pre-partitioned into sets. + * The output sets will be various unions of these sets. These sets + * will normally correspond to "real phones", in cases where the phones + * have stress and position markings. + * @param all_pdf_classes_in [in] All the \ref pdf_class "pdf-classes" + * that we consider for clustering. In the normal case this is the singleton + * set {1}, which means that we only consider the central hmm-position + * of the standard 3-state HMM, for clustering purposes. + * @param P [in] The central position in the phone context window; normally + * 1 for triphone system.s + * @param questions_out [out] The questions (sets of phones) are output to here. + **/ +void AutomaticallyObtainQuestions(BuildTreeStatsType &stats, + const std::vector<std::vector<int32> > &phone_sets_in, + const std::vector<int32> &all_pdf_classes_in, + int32 P, + std::vector<std::vector<int32> > *questions_out); + +/// This function clusters the phones (or some initially specified sets of phones) +/// into sets of phones, using a k-means algorithm. Useful, for example, in building +/// simple models for purposes of adaptation. + +void KMeansClusterPhones(BuildTreeStatsType &stats, + const std::vector<std::vector<int32> > &phone_sets_in, + const std::vector<int32> &all_pdf_classes_in, + int32 P, + int32 num_classes, + std::vector<std::vector<int32> > *sets_out); + +/// Reads the roots file (throws on error). Format is lines like: +/// "shared split 1 2 3 4", +/// "not-shared not-split 5", +/// and so on. The numbers are indexes of phones. +void ReadRootsFile(std::istream &is, + std::vector<std::vector<int32> > *phone_sets, + std::vector<bool> *is_shared_root, + std::vector<bool> *is_split_root); + + +/// @} + +}// end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/tree/cluster-utils.h b/kaldi_io/src/kaldi/tree/cluster-utils.h new file mode 100644 index 0000000..55583a2 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/cluster-utils.h @@ -0,0 +1,291 @@ +// tree/cluster-utils.h + +// Copyright 2012 Arnab Ghoshal +// Copyright 2009-2011 Microsoft Corporation; Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_CLUSTER_UTILS_H_ +#define KALDI_TREE_CLUSTER_UTILS_H_ + +#include <vector> +#include "matrix/matrix-lib.h" +#include "itf/clusterable-itf.h" + +namespace kaldi { + +/// \addtogroup clustering_group_simple +/// @{ + +/// Returns the total objective function after adding up all the +/// statistics in the vector (pointers may be NULL). +BaseFloat SumClusterableObjf(const std::vector<Clusterable*> &vec); + +/// Returns the total normalizer (usually count) of the cluster (pointers may be NULL). +BaseFloat SumClusterableNormalizer(const std::vector<Clusterable*> &vec); + +/// Sums stats (ptrs may be NULL). Returns NULL if no non-NULL stats present. +Clusterable *SumClusterable(const std::vector<Clusterable*> &vec); + +/** Fills in any (NULL) holes in "stats" vector, with empty stats, because + * certain algorithms require non-NULL stats. If "stats" nonempty, requires it + * to contain at least one non-NULL pointer that we can call Copy() on. + */ +void EnsureClusterableVectorNotNull(std::vector<Clusterable*> *stats); + + +/** Given stats and a vector "assignments" of the same size (that maps to + * cluster indices), sums the stats up into "clusters." It will add to any + * stats already present in "clusters" (although typically "clusters" will be + * empty when called), and it will extend with NULL pointers for any unseen + * indices. Call EnsureClusterableStatsNotNull afterwards if you want to ensure + * all non-NULL clusters. Pointer in "clusters" are owned by caller. Pointers in + * "stats" do not have to be non-NULL. + */ +void AddToClusters(const std::vector<Clusterable*> &stats, + const std::vector<int32> &assignments, + std::vector<Clusterable*> *clusters); + + +/// AddToClustersOptimized does the same as AddToClusters (it sums up the stats +/// within each cluster, except it uses the sum of all the stats ("total") to +/// optimize the computation for speed, if possible. This will generally only be +/// a significant speedup in the case where there are just two clusters, which +/// can happen in algorithms that are doing binary splits; the idea is that we +/// sum up all the stats in one cluster (the one with the fewest points in it), +/// and then subtract from the total. +void AddToClustersOptimized(const std::vector<Clusterable*> &stats, + const std::vector<int32> &assignments, + const Clusterable &total, + std::vector<Clusterable*> *clusters); + +/// @} end "addtogroup clustering_group_simple" + +/// \addtogroup clustering_group_algo +/// @{ + +// Note, in the algorithms below, it is assumed that the input "points" (which +// is std::vector<Clusterable*>) is all non-NULL. + +/** A bottom-up clustering algorithm. There are two parameters that control how + * many clusters we get: a "max_merge_thresh" which is a threshold for merging + * clusters, and a min_clust which puts a floor on the number of clusters we want. Set + * max_merge_thresh = large to use the min_clust only, or min_clust to 0 to use + * the max_merge_thresh only. + * + * The algorithm is: + * \code + * while (num-clusters > min_clust && smallest_merge_cost <= max_merge_thresh) + * merge the closest two clusters. + * \endcode + * + * @param points [in] Points to be clustered (may not contain NULL pointers) + * @param thresh [in] Threshold on cost change from merging clusters; clusters + * won't be merged if the cost is more than this + * @param min_clust [in] Minimum number of clusters desired; we'll stop merging + * after reaching this number. + * @param clusters_out [out] If non-NULL, will be set to a vector of size equal + * to the number of output clusters, containing the clustered + * statistics. Must be empty when called. + * @param assignments_out [out] If non-NULL, will be resized to the number of + * points, and each element is the index of the cluster that point + * was assigned to. + * @return Returns the total objf change relative to all clusters being separate, which is + * a negative. Note that this is not the same as what the other clustering algorithms return. + */ +BaseFloat ClusterBottomUp(const std::vector<Clusterable*> &points, + BaseFloat thresh, + int32 min_clust, + std::vector<Clusterable*> *clusters_out, + std::vector<int32> *assignments_out); + +/** This is a bottom-up clustering where the points are pre-clustered in a set + * of compartments, such that only points in the same compartment are clustered + * together. The compartment and pair of points with the smallest merge cost + * is selected and the points are clustered. The result stays in the same + * compartment. The code does not merge compartments, and hence assumes that + * the number of compartments is smaller than the 'min_clust' option. + * The clusters in "clusters_out" are newly allocated and owned by the caller. + */ +BaseFloat ClusterBottomUpCompartmentalized( + const std::vector< std::vector<Clusterable*> > &points, BaseFloat thresh, + int32 min_clust, std::vector< std::vector<Clusterable*> > *clusters_out, + std::vector< std::vector<int32> > *assignments_out); + + +struct RefineClustersOptions { + int32 num_iters; // must be >= 0. If zero, does nothing. + int32 top_n; // must be >= 2. + RefineClustersOptions() : num_iters(100), top_n(5) {} + RefineClustersOptions(int32 num_iters_in, int32 top_n_in) + : num_iters(num_iters_in), top_n(top_n_in) {} + // include Write and Read functions because this object gets written/read as + // part of the QuestionsForKeyOptions class. + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); +}; + +/** RefineClusters is mainly used internally by other clustering algorithms. + * + * It starts with a given assignment of points to clusters and + * keeps trying to improve it by moving points from cluster to cluster, up to + * a maximum number of iterations. + * + * "clusters" and "assignments" are both input and output variables, and so + * both MUST be non-NULL. + * + * "top_n" (>=2) is a pruning value: more is more exact, fewer is faster. The + * algorithm initially finds the "top_n" closest clusters to any given point, + * and from that point only consider move to those "top_n" clusters. Since + * RefineClusters is called multiple times from ClusterKMeans (for instance), + * this is not really a limitation. + */ +BaseFloat RefineClusters(const std::vector<Clusterable*> &points, + std::vector<Clusterable*> *clusters /*non-NULL*/, + std::vector<int32> *assignments /*non-NULL*/, + RefineClustersOptions cfg = RefineClustersOptions()); + +struct ClusterKMeansOptions { + RefineClustersOptions refine_cfg; + int32 num_iters; + int32 num_tries; // if >1, try whole procedure >once and pick best. + bool verbose; + ClusterKMeansOptions() + : refine_cfg(), num_iters(20), num_tries(2), verbose(true) {} +}; + +/** ClusterKMeans is a K-means-like clustering algorithm. It starts with + * pseudo-random initialization of points to clusters and uses RefineClusters + * to iteratively improve the cluster assignments. It does this for + * multiple iterations and picks the result with the best objective function. + * + * + * ClusterKMeans implicitly uses Rand(). It will not necessarily return + * the same value on different calls. Use sRand() if you want consistent + * results. + * The algorithm used in ClusterKMeans is a "k-means-like" algorithm that tries + * to be as efficient as possible. Firstly, since the algorithm it uses + * includes random initialization, it tries the whole thing cfg.num_tries times + * and picks the one with the best objective function. Each try, it does as + * follows: it randomly initializes points to clusters, and then for + * cfg.num_iters iterations it calls RefineClusters(). The options to + * RefineClusters() are given by cfg.refine_cfg. Calling RefineClusters once + * will always be at least as good as doing one iteration of reassigning points to + * clusters, but will generally be quite a bit better (without taking too + * much extra time). + * + * @param points [in] points to be clustered (must be all non-NULL). + * @param num_clust [in] number of clusters requested (it will always return exactly + * this many, or will fail if num_clust > points.size()). + * @param clusters_out [out] may be NULL; if non-NULL, should be empty when called. + * Will be set to a vector of statistics corresponding to the output clusters. + * @param assignments_out [out] may be NULL; if non-NULL, will be set to a vector of + * same size as "points", which says for each point which cluster + * it is assigned to. + * @param cfg [in] configuration class specifying options to the algorithm. + * @return Returns the objective function improvement versus everything being + * in the same cluster. + * + */ +BaseFloat ClusterKMeans(const std::vector<Clusterable*> &points, + int32 num_clust, // exact number of clusters + std::vector<Clusterable*> *clusters_out, // may be NULL + std::vector<int32> *assignments_out, // may be NULL + ClusterKMeansOptions cfg = ClusterKMeansOptions()); + +struct TreeClusterOptions { + ClusterKMeansOptions kmeans_cfg; + int32 branch_factor; + BaseFloat thresh; // Objf change: if >0, may be used to control number of leaves. + TreeClusterOptions() + : kmeans_cfg(), branch_factor(2), thresh(0) { + kmeans_cfg.verbose = false; + } +}; + +/** TreeCluster is a top-down clustering algorithm, using a binary tree (not + * necessarily balanced). Returns objf improvement versus having all points + * in one cluster. The algorithm is: + * - Initialize to 1 cluster (tree with 1 node). + * - Maintain, for each cluster, a "best-binary-split" (using ClusterKMeans + * to do so). Always split the highest scoring cluster, until we can do no + * more splits. + * + * @param points [in] Data points to be clustered + * @param max_clust [in] Maximum number of clusters (you will get exactly this number, + * if there are at least this many points, except if you set the + * cfg.thresh value nonzero, in which case that threshold may limit + * the number of clusters. + * @param clusters_out [out] If non-NULL, will be set to the a vector whose first + * (*num_leaves_out) elements are the leaf clusters, and whose + * subsequent elements are the nonleaf nodes in the tree, in + * topological order with the root node last. Must be empty vector + * when this function is called. + * @param assignments_out [out] If non-NULL, will be set to a vector to a vector the + * same size as "points", where assignments[i] is the leaf node index i + * to which the i'th point gets clustered. + * @param clust_assignments_out [out] If non-NULL, will be set to a vector the same size + * as clusters_out which says for each node (leaf or nonleaf), the + * index of its parent. For the root node (which is last), + * assignments_out[i] == i. For each i, assignments_out[i]>=i, i.e. + * any node's parent is higher numbered than itself. If you don't need + * this information, consider using instead the ClusterTopDown function. + * @param num_leaves_out [out] If non-NULL, will be set to the number of leaf nodes + * in the tree. + * @param cfg [in] Configuration object that controls clustering behavior. Most + * important value is "thresh", which provides an alternative mechanism + * [other than max_clust] to limit the number of leaves. + */ +BaseFloat TreeCluster(const std::vector<Clusterable*> &points, + int32 max_clust, // max number of leaf-level clusters. + std::vector<Clusterable*> *clusters_out, + std::vector<int32> *assignments_out, + std::vector<int32> *clust_assignments_out, + int32 *num_leaves_out, + TreeClusterOptions cfg = TreeClusterOptions()); + + +/** + * A clustering algorithm that internally uses TreeCluster, + * but does not give you the information about the structure of the tree. + * The "clusters_out" and "assignments_out" may be NULL if the outputs are not + * needed. + * + * @param points [in] points to be clustered (must be all non-NULL). + * @param max_clust [in] Maximum number of clusters (you will get exactly this number, + * if there are at least this many points, except if you set the + * cfg.thresh value nonzero, in which case that threshold may limit + * the number of clusters. + * @param clusters_out [out] may be NULL; if non-NULL, should be empty when called. + * Will be set to a vector of statistics corresponding to the output clusters. + * @param assignments_out [out] may be NULL; if non-NULL, will be set to a vector of + * same size as "points", which says for each point which cluster + * it is assigned to. + * @param cfg [in] Configuration object that controls clustering behavior. Most + * important value is "thresh", which provides an alternative mechanism + * [other than max_clust] to limit the number of leaves. +*/ +BaseFloat ClusterTopDown(const std::vector<Clusterable*> &points, + int32 max_clust, // max number of clusters. + std::vector<Clusterable*> *clusters_out, + std::vector<int32> *assignments_out, + TreeClusterOptions cfg = TreeClusterOptions()); + +/// @} end of "addtogroup clustering_group_algo" + +} // end namespace kaldi. + +#endif // KALDI_TREE_CLUSTER_UTILS_H_ diff --git a/kaldi_io/src/kaldi/tree/clusterable-classes.h b/kaldi_io/src/kaldi/tree/clusterable-classes.h new file mode 100644 index 0000000..817d0c6 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/clusterable-classes.h @@ -0,0 +1,158 @@ +// tree/clusterable-classes.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University +// 2014 Daniel Povey + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_CLUSTERABLE_CLASSES_H_ +#define KALDI_TREE_CLUSTERABLE_CLASSES_H_ 1 + +#include <string> +#include "itf/clusterable-itf.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { + +// Note: see sgmm/sgmm-clusterable.h for an SGMM-based clusterable +// class. We didn't include it here, to avoid adding an extra +// dependency to this directory. + +/// \addtogroup clustering_group +/// @{ + +/// ScalarClusterable clusters scalars with x^2 loss. +class ScalarClusterable: public Clusterable { + public: + ScalarClusterable(): x_(0), x2_(0), count_(0) {} + explicit ScalarClusterable(BaseFloat x): x_(x), x2_(x*x), count_(1) {} + virtual std::string Type() const { return "scalar"; } + virtual BaseFloat Objf() const; + virtual void SetZero() { count_ = x_ = x2_ = 0.0; } + virtual void Add(const Clusterable &other_in); + virtual void Sub(const Clusterable &other_in); + virtual Clusterable* Copy() const; + virtual BaseFloat Normalizer() const { + return static_cast<BaseFloat>(count_); + } + + // Function to write data to stream. Will organize input later [more complex] + virtual void Write(std::ostream &os, bool binary) const; + virtual Clusterable* ReadNew(std::istream &is, bool binary) const; + + std::string Info(); // For debugging. + BaseFloat Mean() { return (count_ != 0 ? x_/count_ : 0.0); } + private: + BaseFloat x_; + BaseFloat x2_; + BaseFloat count_; + + void Read(std::istream &is, bool binary); +}; + + +/// GaussClusterable wraps Gaussian statistics in a form accessible +/// to generic clustering algorithms. +class GaussClusterable: public Clusterable { + public: + GaussClusterable(): count_(0.0), var_floor_(0.0) {} + GaussClusterable(int32 dim, BaseFloat var_floor): + count_(0.0), stats_(2, dim), var_floor_(var_floor) {} + + GaussClusterable(const Vector<BaseFloat> &x_stats, + const Vector<BaseFloat> &x2_stats, + BaseFloat var_floor, BaseFloat count); + + virtual std::string Type() const { return "gauss"; } + void AddStats(const VectorBase<BaseFloat> &vec, BaseFloat weight = 1.0); + virtual BaseFloat Objf() const; + virtual void SetZero(); + virtual void Add(const Clusterable &other_in); + virtual void Sub(const Clusterable &other_in); + virtual BaseFloat Normalizer() const { return count_; } + virtual Clusterable *Copy() const; + virtual void Scale(BaseFloat f); + virtual void Write(std::ostream &os, bool binary) const; + virtual Clusterable *ReadNew(std::istream &is, bool binary) const; + virtual ~GaussClusterable() {} + + BaseFloat count() const { return count_; } + // The next two functions are not const-correct, because of SubVector. + SubVector<double> x_stats() const { return stats_.Row(0); } + SubVector<double> x2_stats() const { return stats_.Row(1); } + private: + double count_; + Matrix<double> stats_; // two rows: sum, then sum-squared. + double var_floor_; // should be common for all objects created. + + void Read(std::istream &is, bool binary); +}; + +/// @} end of "addtogroup clustering_group" + +inline void GaussClusterable::SetZero() { + count_ = 0; + stats_.SetZero(); +} + +inline GaussClusterable::GaussClusterable(const Vector<BaseFloat> &x_stats, + const Vector<BaseFloat> &x2_stats, + BaseFloat var_floor, BaseFloat count): + count_(count), stats_(2, x_stats.Dim()), var_floor_(var_floor) { + stats_.Row(0).CopyFromVec(x_stats); + stats_.Row(1).CopyFromVec(x2_stats); +} + + +/// VectorClusterable wraps vectors in a form accessible to generic clustering +/// algorithms. Each vector is associated with a weight; these could be 1.0. +/// The objective function (to be maximized) is the negated sum of squared +/// distances from the cluster center to each vector, times that vector's +/// weight. +class VectorClusterable: public Clusterable { + public: + VectorClusterable(): weight_(0.0), sumsq_(0.0) {} + + VectorClusterable(const Vector<BaseFloat> &vector, + BaseFloat weight); + + virtual std::string Type() const { return "vector"; } + // Objf is negated weighted sum of squared distances. + virtual BaseFloat Objf() const; + virtual void SetZero() { weight_ = 0.0; sumsq_ = 0.0; stats_.Set(0.0); } + virtual void Add(const Clusterable &other_in); + virtual void Sub(const Clusterable &other_in); + virtual BaseFloat Normalizer() const { return weight_; } + virtual Clusterable *Copy() const; + virtual void Scale(BaseFloat f); + virtual void Write(std::ostream &os, bool binary) const; + virtual Clusterable *ReadNew(std::istream &is, bool binary) const; + virtual ~VectorClusterable() {} + + private: + double weight_; // sum of weights of the source vectors. Never negative. + Vector<double> stats_; // Equals the weighted sum of the source vectors. + double sumsq_; // Equals the sum over all sources, of weight_ * vec.vec, + // where vec = stats_ / weight_. Used in computing + // the objective function. + void Read(std::istream &is, bool binary); +}; + + + +} // end namespace kaldi. + +#endif // KALDI_TREE_CLUSTERABLE_CLASSES_H_ diff --git a/kaldi_io/src/kaldi/tree/context-dep.h b/kaldi_io/src/kaldi/tree/context-dep.h new file mode 100644 index 0000000..307fcd4 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/context-dep.h @@ -0,0 +1,166 @@ +// tree/context-dep.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_CONTEXT_DEP_H_ +#define KALDI_TREE_CONTEXT_DEP_H_ + +#include "itf/context-dep-itf.h" +#include "tree/event-map.h" +#include "matrix/matrix-lib.h" +#include "tree/cluster-utils.h" + +/* + This header provides the declarations for the class ContextDependency, which inherits + from the interface class "ContextDependencyInterface" in itf/context-dep-itf.h. + This is basically a wrapper around an EventMap. The EventMap + (tree/event-map.h) declares most of the internals of the class, and the building routines are + in build-tree.h which uses build-tree-utils.h, which uses cluster-utils.h . */ + + +namespace kaldi { + +static const EventKeyType kPdfClass = -1; // The "name" to which we assign the +// pdf-class (generally corresponds ot position in the HMM, zero-based); +// must not be used for any other event. I.e. the value corresponding to +// this key is the pdf-class (see hmm-topology.h for explanation of what this is). + + +/* ContextDependency is quite a generic decision tree. + + It does not actually do very much-- all the magic is in the EventMap object. + All this class does is to encode the phone context as a sequence of events, and + pass this to the EventMap object to turn into what it will interpret as a + vector of pdfs. + + Different versions of the ContextDependency class that are written in the future may + have slightly different interfaces and pass more stuff in as events, to the + EventMap object. + + In order to separate the process of training decision trees from the process + of actually using them, we do not put any training code into the ContextDependency class. + */ +class ContextDependency: public ContextDependencyInterface { + public: + virtual int32 ContextWidth() const { return N_; } + virtual int32 CentralPosition() const { return P_; } + + + /// returns success or failure; outputs pdf to pdf_id + virtual bool Compute(const std::vector<int32> &phoneseq, + int32 pdf_class, int32 *pdf_id) const; + + virtual int32 NumPdfs() const { + // this routine could be simplified to return to_pdf_->MaxResult()+1. we're a + // bit more paranoid than that. + if (!to_pdf_) return 0; + EventAnswerType max_result = to_pdf_->MaxResult(); + if (max_result < 0 ) return 0; + else return (int32) max_result+1; + } + virtual ContextDependencyInterface *Copy() const { + return new ContextDependency(N_, P_, to_pdf_->Copy()); + } + + /// Read context-dependency object from disk; throws on error + void Read (std::istream &is, bool binary); + + // Constructor with no arguments; will normally be called + // prior to Read() + ContextDependency(): N_(0), P_(0), to_pdf_(NULL) { } + + // Constructor takes ownership of pointers. + ContextDependency(int32 N, int32 P, + EventMap *to_pdf): + N_(N), P_(P), to_pdf_(to_pdf) { } + void Write (std::ostream &os, bool binary) const; + + ~ContextDependency() { if (to_pdf_ != NULL) delete to_pdf_; } + + const EventMap &ToPdfMap() const { return *to_pdf_; } + + /// GetPdfInfo returns a vector indexed by pdf-id, saying for each pdf which + /// pairs of (phone, pdf-class) it can correspond to. (Usually just one). + /// c.f. hmm/hmm-topology.h for meaning of pdf-class. + + void GetPdfInfo(const std::vector<int32> &phones, // list of phones + const std::vector<int32> &num_pdf_classes, // indexed by phone, + std::vector<std::vector<std::pair<int32, int32> > > *pdf_info) + const; + + private: + int32 N_; // + int32 P_; + EventMap *to_pdf_; // owned here. + + KALDI_DISALLOW_COPY_AND_ASSIGN(ContextDependency); +}; + +/// GenRandContextDependency is mainly of use for debugging. Phones must be sorted and uniq +/// on input. +/// @param phones [in] A vector of phone id's [must be sorted and uniq]. +/// @param ensure_all_covered [in] boolean argument; if true, GenRandContextDependency +/// generates a context-dependency object that "works" for all phones [no gaps]. +/// @param num_pdf_classes [out] outputs a vector indexed by phone, of the number +/// of pdf classes (e.g. states) for that phone. +/// @return Returns the a context dependency object. +ContextDependency *GenRandContextDependency(const std::vector<int32> &phones, + bool ensure_all_covered, + std::vector<int32> *num_pdf_classes); + +/// GenRandContextDependencyLarge is like GenRandContextDependency but generates a larger tree +/// with specified N and P for use in "one-time" larger-scale tests. +ContextDependency *GenRandContextDependencyLarge(const std::vector<int32> &phones, + int N, int P, + bool ensure_all_covered, + std::vector<int32> *num_pdf_classes); + +// MonophoneContextDependency() returns a new ContextDependency object that +// corresponds to a monophone system. +// The map phone2num_pdf_classes maps from the phone id to the number of +// pdf-classes we have for that phone (e.g. 3, so the pdf-classes would be +// 0, 1, 2). + +ContextDependency* +MonophoneContextDependency(const std::vector<int32> phones, + const std::vector<int32> phone2num_pdf_classes); + +// MonophoneContextDependencyShared is as MonophoneContextDependency but lets +// you define classes of phones which share pdfs (e.g. different stress-markers of a single +// phone.) Each element of phone_classes is a set of phones that are in that class. +ContextDependency* +MonophoneContextDependencyShared(const std::vector<std::vector<int32> > phone_classes, + const std::vector<int32> phone2num_pdf_classes); + + +// Important note: +// Statistics for training decision trees will be of type: +// std::vector<std::pair<EventType, Clusterable*> > +// We don't make this a typedef as it doesn't add clarity. +// they will be sorted and unique on the EventType member, which +// itself is sorted and unique on the name (see event-map.h). + +// See build-tree.h for functions relating to actually building the decision trees. + + + + +} // namespace Kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/tree/event-map.h b/kaldi_io/src/kaldi/tree/event-map.h new file mode 100644 index 0000000..07fcc2b --- /dev/null +++ b/kaldi_io/src/kaldi/tree/event-map.h @@ -0,0 +1,365 @@ +// tree/event-map.h + +// Copyright 2009-2011 Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_EVENT_MAP_H_ +#define KALDI_TREE_EVENT_MAP_H_ + +#include <vector> +#include <map> +#include <algorithm> +#include "base/kaldi-common.h" +#include "util/stl-utils.h" +#include "util/const-integer-set.h" + +namespace kaldi { + +/// \defgroup event_map_group Event maps +/// \ingroup tree_group +/// See \ref tree_internals for overview, and specifically \ref treei_event_map. + + +// Note RE negative values: some of this code will not work if things of type +// EventValueType are negative. In particular, TableEventMap can't be used if +// things of EventValueType are negative, and additionally TableEventMap won't +// be efficient if things of EventValueType take on extremely large values. The +// EventKeyType can be negative though. + +/// Things of type EventKeyType can take any value. The code does not assume they are contiguous. +/// So values like -1, 1000000 and the like are acceptable. +typedef int32 EventKeyType; + +/// Given current code, things of type EventValueType should generally be nonnegative and in a +/// reasonably small range (e.g. not one million), as we sometimes construct vectors of the size: +/// [largest value we saw for this key]. This deficiency may be fixed in future [would require +/// modifying TableEventMap] +typedef int32 EventValueType; + +/// As far as the event-map code itself is concerned, things of type EventAnswerType may take +/// any value except kNoAnswer (== -1). However, some specific uses of EventMap (e.g. in +/// build-tree-utils.h) assume these quantities are nonnegative. +typedef int32 EventAnswerType; + +typedef std::vector<std::pair<EventKeyType, EventValueType> > EventType; +// It is required to be sorted and have unique keys-- i.e. functions assume this when called +// with this type. + +inline std::pair<EventKeyType, EventValueType> MakeEventPair (EventKeyType k, EventValueType v) { + return std::pair<EventKeyType, EventValueType>(k, v); +} + +void WriteEventType(std::ostream &os, bool binary, const EventType &vec); +void ReadEventType(std::istream &is, bool binary, EventType *vec); + +std::string EventTypeToString(const EventType &evec); // so we can print events out in error messages. + +struct EventMapVectorHash { // Hashing object for EventMapVector. Works for both pointers and references. + // Not used in event-map.{h, cc} + size_t operator () (const EventType &vec); + size_t operator () (const EventType *ptr) { return (*this)(*ptr); } +}; +struct EventMapVectorEqual { // Equality object for EventType pointers-- test equality of underlying vector. + // Not used in event-map.{h, cc} + size_t operator () (const EventType *p1, const EventType *p2) { return (*p1 == *p2); } +}; + + +/// A class that is capable of representing a generic mapping from +/// EventType (which is a vector of (key, value) pairs) to +/// EventAnswerType which is just an integer. See \ref tree_internals +/// for overview. +class EventMap { + public: + static void Check(const EventType &event); // will crash if not sorted and unique on key. + static bool Lookup(const EventType &event, EventKeyType key, EventValueType *ans); + + // Maps events to the answer type. input must be sorted. + virtual bool Map(const EventType &event, EventAnswerType *ans) const = 0; + + // MultiMap maps a partially specified set of events to the set of answers it might + // map to. It appends these to "ans". "ans" is + // **not guaranteed unique at output** if the + // tree contains duplicate answers at leaves -- you should sort & uniq afterwards. + // e.g.: SortAndUniq(ans). + virtual void MultiMap(const EventType &event, std::vector<EventAnswerType> *ans) const = 0; + + // GetChildren() returns the EventMaps that are immediate children of this + // EventMap (if they exist), by putting them in *out. Useful for + // determining the structure of the event map. + virtual void GetChildren(std::vector<EventMap*> *out) const = 0; + + // This Copy() does a deep copy of the event map. + // If new_leaves is nonempty when it reaches a leaf with value l s.t. new_leaves[l] != NULL, + // it replaces it with a copy of that EventMap. This makes it possible to extend and modify + // It's the way we do splits of trees, and clustering of trees. Think about this carefully, because + // the EventMap structure does not support modification of an existing tree. Do not be tempted + // to do this differently, because other kinds of mechanisms would get very messy and unextensible. + // Copy() is the only mechanism to modify a tree. It's similar to a kind of function composition. + // Copy() does not take ownership of the pointers in new_leaves (it uses the Copy() function of those + // EventMaps). + virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const = 0; + + EventMap *Copy() const { std::vector<EventMap*> new_leaves; return Copy(new_leaves); } + + // The function MapValues() is intended to be used to map phone-sets between + // different integer representations. For all the keys in the set + // "keys_to_map", it will map the corresponding values using the map + // "value_map". Note: these values are the values in the key->value pairs of + // the EventMap, which really correspond to phones in the usual case; they are + // not the "answers" of the EventMap which correspond to clustered states. In + // case multiple values are mapped to the same value, it will try to deal with + // it gracefully where it can, but will crash if, for example, this would + // cause problems with the TableEventMap. It will also crash if any values + // used for keys in "keys_to_map" are not mapped by "value_map". This + // function is not currently used. + virtual EventMap *MapValues( + const unordered_set<EventKeyType> &keys_to_map, + const unordered_map<EventValueType,EventValueType> &value_map) const = 0; + + // The function Prune() is like Copy(), except it removes parts of the tree + // that return only -1 (it will return NULL if this EventMap returns only -1). + // This is a mechanism to remove parts of the tree-- you would first use the + // Copy() function with a vector of EventMap*, and for the parts you don't + // want, you'd put a ConstantEventMap with -1; you'd then call + // Prune() on the result. This function is not currently used. + virtual EventMap *Prune() const = 0; + + virtual EventAnswerType MaxResult() const { // child classes may override this for efficiency; here is basic version. + // returns -1 if nothing found. + std::vector<EventAnswerType> tmp; EventType empty_event; + MultiMap(empty_event, &tmp); + if (tmp.empty()) { + KALDI_WARN << "EventMap::MaxResult(), empty result"; + return std::numeric_limits<EventAnswerType>::min(); + } + else { return * std::max_element(tmp.begin(), tmp.end()); } + } + + /// Write to stream. + virtual void Write(std::ostream &os, bool binary) = 0; + + virtual ~EventMap() {} + + /// a Write function that takes care of NULL pointers. + static void Write(std::ostream &os, bool binary, EventMap *emap); + /// a Read function that reads an arbitrary EventMap; also + /// works for NULL pointers. + static EventMap *Read(std::istream &is, bool binary); +}; + + +class ConstantEventMap: public EventMap { + public: + virtual bool Map(const EventType &event, EventAnswerType *ans) const { + *ans = answer_; + return true; + } + + virtual void MultiMap(const EventType &, + std::vector<EventAnswerType> *ans) const { + ans->push_back(answer_); + } + + virtual void GetChildren(std::vector<EventMap*> *out) const { out->clear(); } + + virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const { + if (answer_ < 0 || answer_ >= (EventAnswerType)new_leaves.size() || + new_leaves[answer_] == NULL) + return new ConstantEventMap(answer_); + else return new_leaves[answer_]->Copy(); + } + + virtual EventMap *MapValues( + const unordered_set<EventKeyType> &keys_to_map, + const unordered_map<EventValueType,EventValueType> &value_map) const { + return new ConstantEventMap(answer_); + } + + virtual EventMap *Prune() const { + return (answer_ == -1 ? NULL : new ConstantEventMap(answer_)); + } + + explicit ConstantEventMap(EventAnswerType answer): answer_(answer) { } + + virtual void Write(std::ostream &os, bool binary); + static ConstantEventMap *Read(std::istream &is, bool binary); + private: + EventAnswerType answer_; + KALDI_DISALLOW_COPY_AND_ASSIGN(ConstantEventMap); +}; + +class TableEventMap: public EventMap { + public: + + virtual bool Map(const EventType &event, EventAnswerType *ans) const { + EventValueType tmp; *ans = -1; // means no answer + if (Lookup(event, key_, &tmp) && tmp >= 0 + && tmp < (EventValueType)table_.size() && table_[tmp] != NULL) { + return table_[tmp]->Map(event, ans); + } + return false; + } + + virtual void GetChildren(std::vector<EventMap*> *out) const { + out->clear(); + for (size_t i = 0; i<table_.size(); i++) + if (table_[i] != NULL) out->push_back(table_[i]); + } + + virtual void MultiMap(const EventType &event, std::vector<EventAnswerType> *ans) const { + EventValueType tmp; + if (Lookup(event, key_, &tmp)) { + if (tmp >= 0 && tmp < (EventValueType)table_.size() && table_[tmp] != NULL) + return table_[tmp]->MultiMap(event, ans); + // else no answers. + } else { // all answers are possible if no such key. + for (size_t i = 0;i < table_.size();i++) + if (table_[i] != NULL) table_[i]->MultiMap(event, ans); // append. + } + } + + virtual EventMap *Prune() const; + + virtual EventMap *MapValues( + const unordered_set<EventKeyType> &keys_to_map, + const unordered_map<EventValueType,EventValueType> &value_map) const; + + /// Takes ownership of pointers. + explicit TableEventMap(EventKeyType key, const std::vector<EventMap*> &table): key_(key), table_(table) {} + /// Takes ownership of pointers. + explicit TableEventMap(EventKeyType key, const std::map<EventValueType, EventMap*> &map_in); + /// This initializer creates a ConstantEventMap for each value in the map. + explicit TableEventMap(EventKeyType key, const std::map<EventValueType, EventAnswerType> &map_in); + + virtual void Write(std::ostream &os, bool binary); + static TableEventMap *Read(std::istream &is, bool binary); + + virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const { + std::vector<EventMap*> new_table_(table_.size(), NULL); + for (size_t i = 0;i<table_.size();i++) if (table_[i]) new_table_[i]=table_[i]->Copy(new_leaves); + return new TableEventMap(key_, new_table_); + } + virtual ~TableEventMap() { + DeletePointers(&table_); + } + private: + EventKeyType key_; + std::vector<EventMap*> table_; + KALDI_DISALLOW_COPY_AND_ASSIGN(TableEventMap); +}; + + + + +class SplitEventMap: public EventMap { // A decision tree [non-leaf] node. + public: + + virtual bool Map(const EventType &event, EventAnswerType *ans) const { + EventValueType value; + if (Lookup(event, key_, &value)) { + // if (std::binary_search(yes_set_.begin(), yes_set_.end(), value)) { + if (yes_set_.count(value)) { + return yes_->Map(event, ans); + } + return no_->Map(event, ans); + } + return false; + } + + virtual void MultiMap(const EventType &event, std::vector<EventAnswerType> *ans) const { + EventValueType tmp; + if (Lookup(event, key_, &tmp)) { + if (std::binary_search(yes_set_.begin(), yes_set_.end(), tmp)) + yes_->MultiMap(event, ans); + else + no_->MultiMap(event, ans); + } else { // both yes and no contribute. + yes_->MultiMap(event, ans); + no_->MultiMap(event, ans); + } + } + + virtual void GetChildren(std::vector<EventMap*> *out) const { + out->clear(); + out->push_back(yes_); + out->push_back(no_); + } + + virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const { + return new SplitEventMap(key_, yes_set_, yes_->Copy(new_leaves), no_->Copy(new_leaves)); + } + + virtual void Write(std::ostream &os, bool binary); + static SplitEventMap *Read(std::istream &is, bool binary); + + virtual EventMap *Prune() const; + + virtual EventMap *MapValues( + const unordered_set<EventKeyType> &keys_to_map, + const unordered_map<EventValueType,EventValueType> &value_map) const; + + virtual ~SplitEventMap() { Destroy(); } + + /// This constructor takes ownership of the "yes" and "no" arguments. + SplitEventMap(EventKeyType key, const std::vector<EventValueType> &yes_set, + EventMap *yes, EventMap *no): key_(key), yes_set_(yes_set), yes_(yes), no_(no) { + KALDI_PARANOID_ASSERT(IsSorted(yes_set)); + KALDI_ASSERT(yes_ != NULL && no_ != NULL); + } + + + private: + /// This constructor used in the Copy() function. + SplitEventMap(EventKeyType key, const ConstIntegerSet<EventValueType> &yes_set, + EventMap *yes, EventMap *no): key_(key), yes_set_(yes_set), yes_(yes), no_(no) { + KALDI_ASSERT(yes_ != NULL && no_ != NULL); + } + void Destroy() { + delete yes_; delete no_; + } + EventKeyType key_; + // std::vector<EventValueType> yes_set_; + ConstIntegerSet<EventValueType> yes_set_; // more efficient Map function. + EventMap *yes_; // owned here. + EventMap *no_; // owned here. + SplitEventMap &operator = (const SplitEventMap &other); // Disallow. +}; + +/** + This function gets the tree structure of the EventMap "map" in a convenient form. + If "map" corresponds to a tree structure (not necessarily binary) with leaves + uniquely numbered from 0 to num_leaves-1, then the function will return true, + output "num_leaves", and set "parent" to a vector of size equal to the number of + nodes in the tree (nonleaf and leaf), where each index corresponds to a node + and the leaf indices correspond to the values returned by the EventMap from + that leaf; for an index i, parent[i] equals the parent of that node in the tree + structure, where parent[i] > i, except for the last (root) node where parent[i] == i. + If the EventMap does not have this structure (e.g. if multiple different leaf nodes share + the same number), then it will return false. +*/ + +bool GetTreeStructure(const EventMap &map, + int32 *num_leaves, + std::vector<int32> *parents); + + +/// @} end "addtogroup event_map_group" + +} + +#endif diff --git a/kaldi_io/src/kaldi/tree/tree-renderer.h b/kaldi_io/src/kaldi/tree/tree-renderer.h new file mode 100644 index 0000000..5e0b0d8 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/tree-renderer.h @@ -0,0 +1,84 @@ +// tree/tree-renderer.h + +// Copyright 2012 Vassil Panayotov + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_TREE_RENDERER_H_ +#define KALDI_TREE_TREE_RENDERER_H_ + +#include "base/kaldi-common.h" +#include "tree/event-map.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "fst/fstlib.h" + +namespace kaldi { + +// Parses a decision tree file and outputs its description in GraphViz format +class TreeRenderer { + public: + const static int32 kEdgeWidth; // normal width of the edges and state contours + const static int32 kEdgeWidthQuery; // edge and state width when in query + const static std::string kEdgeColor; // normal color for states and edges + const static std::string kEdgeColorQuery; // edge and state color when in query + + TreeRenderer(std::istream &is, bool binary, std::ostream &os, + fst::SymbolTable &phone_syms, bool use_tooltips) + : phone_syms_(phone_syms), is_(is), out_(os), binary_(binary), + N_(-1), use_tooltips_(use_tooltips), next_id_(0) {} + + // Renders the tree and if the "query" parameter is not NULL + // a distinctly colored trace corresponding to the event. + void Render(const EventType *query); + + private: + // Looks-up the next token from the stream and invokes + // the appropriate render method to visualize it + void RenderSubTree(const EventType *query, int32 id); + + // Renders a leaf node (constant event map) + void RenderConstant(const EventType *query, int32 id); + + // Renders a split event map node and the edges to the nodes + // representing YES and NO sets + void RenderSplit(const EventType *query, int32 id); + + // Renders a table event map node and the edges to its (non-null) children + void RenderTable(const EventType *query, int32 id); + + // Makes a comma-separated string from the elements of a set of identifiers + // If the identifiers represent phones, their symbolic representations are used + std::string MakeEdgeLabel(const EventKeyType &key, + const ConstIntegerSet<EventValueType> &intset); + + // Writes the GraphViz representation of a non-leaf node to the out stream + // A question about a phone from the context window or about pdf-class + // is used as a label. + void RenderNonLeaf(int32 id, const EventKeyType &key, bool in_query); + + fst::SymbolTable &phone_syms_; // phone symbols to be used as edge labels + std::istream &is_; // the stream from which the tree is read + std::ostream &out_; // the GraphViz representation is written to this stream + bool binary_; // is the input stream binary? + int32 N_, P_; // context-width and central position + bool use_tooltips_; // use tooltips(useful in e.g. SVG) instead of labels + int32 next_id_; // the first unused GraphViz node ID +}; + +} // namespace kaldi + +#endif // KALDI_TREE_TREE_RENDERER_H_ diff --git a/kaldi_io/src/kaldi/util/basic-filebuf.h b/kaldi_io/src/kaldi/util/basic-filebuf.h new file mode 100644 index 0000000..cf2e079 --- /dev/null +++ b/kaldi_io/src/kaldi/util/basic-filebuf.h @@ -0,0 +1,1065 @@ +/////////////////////////////////////////////////////////////////////////////// +// This is a modified version of the std::basic_filebuf from libc++ +// (http://libcxx.llvm.org/). +// It allows one to create basic_filebuf from an existing FILE* handle or file +// descriptor. +// +// This file is dual licensed under the MIT and the University of Illinois Open +// Source License licenses. See LICENSE.TXT for details (included at the +// bottom). +/////////////////////////////////////////////////////////////////////////////// +#ifndef KALDI_UTIL_BASIC_FILEBUF_H_ +#define KALDI_UTIL_BASIC_FILEBUF_H_ + +/////////////////////////////////////////////////////////////////////////////// +#include <fstream> +#include <cstdio> +#include <cstring> + +/////////////////////////////////////////////////////////////////////////////// +namespace kaldi +{ + +/////////////////////////////////////////////////////////////////////////////// +template <typename CharT, typename Traits = std::char_traits<CharT> > +class basic_filebuf : public std::basic_streambuf<CharT, Traits> +{ +public: + typedef CharT char_type; + typedef Traits traits_type; + typedef typename traits_type::int_type int_type; + typedef typename traits_type::pos_type pos_type; + typedef typename traits_type::off_type off_type; + typedef typename traits_type::state_type state_type; + + basic_filebuf(); + basic_filebuf(basic_filebuf&& rhs); + virtual ~basic_filebuf(); + + basic_filebuf& operator=(basic_filebuf&& rhs); + void swap(basic_filebuf& rhs); + + bool is_open() const; + basic_filebuf* open(const char* s, std::ios_base::openmode mode); + basic_filebuf* open(const std::string& s, std::ios_base::openmode mode); + basic_filebuf* open(int fd, std::ios_base::openmode mode); + basic_filebuf* open(FILE* f, std::ios_base::openmode mode); + basic_filebuf* close(); + + FILE* file() { return this->_M_file; } + int fd() { return fileno(this->_M_file); } + +protected: + int_type underflow() override; + int_type pbackfail(int_type c = traits_type::eof()) override; + int_type overflow (int_type c = traits_type::eof()) override; + std::basic_streambuf<char_type, traits_type>* setbuf(char_type* s, std::streamsize n) override; + pos_type seekoff(off_type off, std::ios_base::seekdir way, + std::ios_base::openmode wch = std::ios_base::in | std::ios_base::out) override; + pos_type seekpos(pos_type sp, + std::ios_base::openmode wch = std::ios_base::in | std::ios_base::out) override; + int sync() override; + void imbue(const std::locale& loc) override; + +protected: + char* _M_extbuf; + const char* _M_extbufnext; + const char* _M_extbufend; + char _M_extbuf_min[8]; + size_t _M_ebs; + char_type* _M_intbuf; + size_t _M_ibs; + FILE* _M_file; + const std::codecvt<char_type, char, state_type>* _M_cv; + state_type _M_st; + state_type _M_st_last; + std::ios_base::openmode _M_om; + std::ios_base::openmode _M_cm; + bool _M_owns_eb; + bool _M_owns_ib; + bool _M_always_noconv; + + const char* _M_get_mode(std::ios_base::openmode mode); + bool _M_read_mode(); + void _M_write_mode(); +}; + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>::basic_filebuf() + : _M_extbuf(nullptr), + _M_extbufnext(nullptr), + _M_extbufend(nullptr), + _M_ebs(0), + _M_intbuf(nullptr), + _M_ibs(0), + _M_file(nullptr), + _M_cv(nullptr), + _M_st(), + _M_st_last(), + _M_om(std::ios_base::openmode(0)), + _M_cm(std::ios_base::openmode(0)), + _M_owns_eb(false), + _M_owns_ib(false), + _M_always_noconv(false) +{ + if (std::has_facet<std::codecvt<char_type, char, state_type> >(this->getloc())) + { + _M_cv = &std::use_facet<std::codecvt<char_type, char, state_type> >(this->getloc()); + _M_always_noconv = _M_cv->always_noconv(); + } + setbuf(0, 4096); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>::basic_filebuf(basic_filebuf&& rhs) + : std::basic_streambuf<CharT, Traits>(rhs) +{ + if (rhs._M_extbuf == rhs._M_extbuf_min) + { + _M_extbuf = _M_extbuf_min; + _M_extbufnext = _M_extbuf + (rhs._M_extbufnext - rhs._M_extbuf); + _M_extbufend = _M_extbuf + (rhs._M_extbufend - rhs._M_extbuf); + } + else + { + _M_extbuf = rhs._M_extbuf; + _M_extbufnext = rhs._M_extbufnext; + _M_extbufend = rhs._M_extbufend; + } + _M_ebs = rhs._M_ebs; + _M_intbuf = rhs._M_intbuf; + _M_ibs = rhs._M_ibs; + _M_file = rhs._M_file; + _M_cv = rhs._M_cv; + _M_st = rhs._M_st; + _M_st_last = rhs._M_st_last; + _M_om = rhs._M_om; + _M_cm = rhs._M_cm; + _M_owns_eb = rhs._M_owns_eb; + _M_owns_ib = rhs._M_owns_ib; + _M_always_noconv = rhs._M_always_noconv; + if (rhs.pbase()) + { + if (rhs.pbase() == rhs._M_intbuf) + this->setp(_M_intbuf, _M_intbuf + (rhs. epptr() - rhs.pbase())); + else + this->setp((char_type*)_M_extbuf, + (char_type*)_M_extbuf + (rhs. epptr() - rhs.pbase())); + this->pbump(rhs. pptr() - rhs.pbase()); + } + else if (rhs.eback()) + { + if (rhs.eback() == rhs._M_intbuf) + this->setg(_M_intbuf, _M_intbuf + (rhs.gptr() - rhs.eback()), + _M_intbuf + (rhs.egptr() - rhs.eback())); + else + this->setg((char_type*)_M_extbuf, + (char_type*)_M_extbuf + (rhs.gptr() - rhs.eback()), + (char_type*)_M_extbuf + (rhs.egptr() - rhs.eback())); + } + rhs._M_extbuf = nullptr; + rhs._M_extbufnext = nullptr; + rhs._M_extbufend = nullptr; + rhs._M_ebs = 0; + rhs._M_intbuf = nullptr; + rhs._M_ibs = 0; + rhs._M_file = nullptr; + rhs._M_st = state_type(); + rhs._M_st_last = state_type(); + rhs._M_om = std::ios_base::openmode(0); + rhs._M_cm = std::ios_base::openmode(0); + rhs._M_owns_eb = false; + rhs._M_owns_ib = false; + rhs.setg(0, 0, 0); + rhs.setp(0, 0); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +inline +basic_filebuf<CharT, Traits>& +basic_filebuf<CharT, Traits>::operator=(basic_filebuf&& rhs) +{ + close(); + swap(rhs); + return *this; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>::~basic_filebuf() +{ + // try + // { + // close(); + // } + // catch (...) + // { + // } + if (_M_owns_eb) + delete [] _M_extbuf; + if (_M_owns_ib) + delete [] _M_intbuf; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +void +basic_filebuf<CharT, Traits>::swap(basic_filebuf& rhs) +{ + std::basic_streambuf<char_type, traits_type>::swap(rhs); + if (_M_extbuf != _M_extbuf_min && rhs._M_extbuf != rhs._M_extbuf_min) + { + std::swap(_M_extbuf, rhs._M_extbuf); + std::swap(_M_extbufnext, rhs._M_extbufnext); + std::swap(_M_extbufend, rhs._M_extbufend); + } + else + { + ptrdiff_t ln = _M_extbufnext - _M_extbuf; + ptrdiff_t le = _M_extbufend - _M_extbuf; + ptrdiff_t rn = rhs._M_extbufnext - rhs._M_extbuf; + ptrdiff_t re = rhs._M_extbufend - rhs._M_extbuf; + if (_M_extbuf == _M_extbuf_min && rhs._M_extbuf != rhs._M_extbuf_min) + { + _M_extbuf = rhs._M_extbuf; + rhs._M_extbuf = rhs._M_extbuf_min; + } + else if (_M_extbuf != _M_extbuf_min && rhs._M_extbuf == rhs._M_extbuf_min) + { + rhs._M_extbuf = _M_extbuf; + _M_extbuf = _M_extbuf_min; + } + _M_extbufnext = _M_extbuf + rn; + _M_extbufend = _M_extbuf + re; + rhs._M_extbufnext = rhs._M_extbuf + ln; + rhs._M_extbufend = rhs._M_extbuf + le; + } + std::swap(_M_ebs, rhs._M_ebs); + std::swap(_M_intbuf, rhs._M_intbuf); + std::swap(_M_ibs, rhs._M_ibs); + std::swap(_M_file, rhs._M_file); + std::swap(_M_cv, rhs._M_cv); + std::swap(_M_st, rhs._M_st); + std::swap(_M_st_last, rhs._M_st_last); + std::swap(_M_om, rhs._M_om); + std::swap(_M_cm, rhs._M_cm); + std::swap(_M_owns_eb, rhs._M_owns_eb); + std::swap(_M_owns_ib, rhs._M_owns_ib); + std::swap(_M_always_noconv, rhs._M_always_noconv); + if (this->eback() == (char_type*)rhs._M_extbuf_min) + { + ptrdiff_t n = this->gptr() - this->eback(); + ptrdiff_t e = this->egptr() - this->eback(); + this->setg((char_type*)_M_extbuf_min, + (char_type*)_M_extbuf_min + n, + (char_type*)_M_extbuf_min + e); + } + else if (this->pbase() == (char_type*)rhs._M_extbuf_min) + { + ptrdiff_t n = this->pptr() - this->pbase(); + ptrdiff_t e = this->epptr() - this->pbase(); + this->setp((char_type*)_M_extbuf_min, + (char_type*)_M_extbuf_min + e); + this->pbump(n); + } + if (rhs.eback() == (char_type*)_M_extbuf_min) + { + ptrdiff_t n = rhs.gptr() - rhs.eback(); + ptrdiff_t e = rhs.egptr() - rhs.eback(); + rhs.setg((char_type*)rhs._M_extbuf_min, + (char_type*)rhs._M_extbuf_min + n, + (char_type*)rhs._M_extbuf_min + e); + } + else if (rhs.pbase() == (char_type*)_M_extbuf_min) + { + ptrdiff_t n = rhs.pptr() - rhs.pbase(); + ptrdiff_t e = rhs.epptr() - rhs.pbase(); + rhs.setp((char_type*)rhs._M_extbuf_min, + (char_type*)rhs._M_extbuf_min + e); + rhs.pbump(n); + } +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +inline +void +swap(basic_filebuf<CharT, Traits>& x, basic_filebuf<CharT, Traits>& y) +{ + x.swap(y); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +inline +bool +basic_filebuf<CharT, Traits>::is_open() const +{ + return _M_file != nullptr; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +const char* basic_filebuf<CharT, Traits>::_M_get_mode(std::ios_base::openmode mode) +{ + switch ((mode & ~std::ios_base::ate) | 0) + { + case std::ios_base::out: + case std::ios_base::out | std::ios_base::trunc: + return "w"; + case std::ios_base::out | std::ios_base::app: + case std::ios_base::app: + return "a"; + break; + case std::ios_base::in: + return "r"; + case std::ios_base::in | std::ios_base::out: + return "r+"; + case std::ios_base::in | std::ios_base::out | std::ios_base::trunc: + return "w+"; + case std::ios_base::in | std::ios_base::out | std::ios_base::app: + case std::ios_base::in | std::ios_base::app: + return "a+"; + case std::ios_base::out | std::ios_base::binary: + case std::ios_base::out | std::ios_base::trunc | std::ios_base::binary: + return "wb"; + case std::ios_base::out | std::ios_base::app | std::ios_base::binary: + case std::ios_base::app | std::ios_base::binary: + return "ab"; + case std::ios_base::in | std::ios_base::binary: + return "rb"; + case std::ios_base::in | std::ios_base::out | std::ios_base::binary: + return "r+b"; + case std::ios_base::in | std::ios_base::out | std::ios_base::trunc | std::ios_base::binary: + return "w+b"; + case std::ios_base::in | std::ios_base::out | std::ios_base::app | std::ios_base::binary: + case std::ios_base::in | std::ios_base::app | std::ios_base::binary: + return "a+b"; + default: + return nullptr; + } +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::open(const char* s, std::ios_base::openmode mode) +{ + basic_filebuf<CharT, Traits>* rt = nullptr; + if (_M_file == nullptr) + { + const char* md= _M_get_mode(mode); + if (md) + { + _M_file = fopen(s, md); + if (_M_file) + { + rt = this; + _M_om = mode; + if (mode & std::ios_base::ate) + { + if (fseek(_M_file, 0, SEEK_END)) + { + fclose(_M_file); + _M_file = nullptr; + rt = nullptr; + } + } + } + } + } + return rt; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +inline +basic_filebuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::open(const std::string& s, std::ios_base::openmode mode) +{ + return open(s.c_str(), mode); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::open(int fd, std::ios_base::openmode mode) +{ + const char* md= this->_M_get_mode(mode); + if (md) + { + this->_M_file= fdopen(fd, md); + this->_M_om = mode; + return this; + } + else return nullptr; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::open(FILE* f, std::ios_base::openmode mode) +{ + this->_M_file = f; + this->_M_om = mode; + return this; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::close() +{ + basic_filebuf<CharT, Traits>* rt = nullptr; + if (_M_file) + { + rt = this; + std::unique_ptr<FILE, int(*)(FILE*)> h(_M_file, fclose); + if (sync()) + rt = nullptr; + if (fclose(h.release()) == 0) + _M_file = nullptr; + else + rt = nullptr; + } + return rt; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +typename basic_filebuf<CharT, Traits>::int_type +basic_filebuf<CharT, Traits>::underflow() +{ + if (_M_file == nullptr) + return traits_type::eof(); + bool initial = _M_read_mode(); + char_type buf; + if (this->gptr() == nullptr) + this->setg(&buf, &buf+1, &buf+1); + const size_t unget_sz = initial ? 0 : std::min<size_t>((this->egptr() - this->eback()) / 2, 4); + int_type c = traits_type::eof(); + if (this->gptr() == this->egptr()) + { + memmove(this->eback(), this->egptr() - unget_sz, unget_sz * sizeof(char_type)); + if (_M_always_noconv) + { + size_t nmemb = static_cast<size_t>(this->egptr() - this->eback() - unget_sz); + nmemb = fread(this->eback() + unget_sz, 1, nmemb, _M_file); + if (nmemb != 0) + { + this->setg(this->eback(), + this->eback() + unget_sz, + this->eback() + unget_sz + nmemb); + c = traits_type::to_int_type(*this->gptr()); + } + } + else + { + memmove(_M_extbuf, _M_extbufnext, _M_extbufend - _M_extbufnext); + _M_extbufnext = _M_extbuf + (_M_extbufend - _M_extbufnext); + _M_extbufend = _M_extbuf + (_M_extbuf == _M_extbuf_min ? sizeof(_M_extbuf_min) : _M_ebs); + size_t nmemb = std::min(static_cast<size_t>(_M_ibs - unget_sz), + static_cast<size_t>(_M_extbufend - _M_extbufnext)); + std::codecvt_base::result r; + _M_st_last = _M_st; + size_t nr = fread((void*)_M_extbufnext, 1, nmemb, _M_file); + if (nr != 0) + { + if (!_M_cv) + throw std::bad_cast(); + _M_extbufend = _M_extbufnext + nr; + char_type* inext; + r = _M_cv->in(_M_st, _M_extbuf, _M_extbufend, _M_extbufnext, + this->eback() + unget_sz, + this->eback() + _M_ibs, inext); + if (r == std::codecvt_base::noconv) + { + this->setg((char_type*)_M_extbuf, (char_type*)_M_extbuf, (char_type*)_M_extbufend); + c = traits_type::to_int_type(*this->gptr()); + } + else if (inext != this->eback() + unget_sz) + { + this->setg(this->eback(), this->eback() + unget_sz, inext); + c = traits_type::to_int_type(*this->gptr()); + } + } + } + } + else + c = traits_type::to_int_type(*this->gptr()); + if (this->eback() == &buf) + this->setg(0, 0, 0); + return c; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +typename basic_filebuf<CharT, Traits>::int_type +basic_filebuf<CharT, Traits>::pbackfail(int_type c) +{ + if (_M_file && this->eback() < this->gptr()) + { + if (traits_type::eq_int_type(c, traits_type::eof())) + { + this->gbump(-1); + return traits_type::not_eof(c); + } + if ((_M_om & std::ios_base::out) || + traits_type::eq(traits_type::to_char_type(c), this->gptr()[-1])) + { + this->gbump(-1); + *this->gptr() = traits_type::to_char_type(c); + return c; + } + } + return traits_type::eof(); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +typename basic_filebuf<CharT, Traits>::int_type +basic_filebuf<CharT, Traits>::overflow(int_type c) +{ + if (_M_file == nullptr) + return traits_type::eof(); + _M_write_mode(); + char_type buf; + char_type* pb_save = this->pbase(); + char_type* epb_save = this->epptr(); + if (!traits_type::eq_int_type(c, traits_type::eof())) + { + if (this->pptr() == nullptr) + this->setp(&buf, &buf+1); + *this->pptr() = traits_type::to_char_type(c); + this->pbump(1); + } + if (this->pptr() != this->pbase()) + { + if (_M_always_noconv) + { + size_t nmemb = static_cast<size_t>(this->pptr() - this->pbase()); + if (fwrite(this->pbase(), sizeof(char_type), nmemb, _M_file) != nmemb) + return traits_type::eof(); + } + else + { + char* extbe = _M_extbuf; + std::codecvt_base::result r; + do + { + if (!_M_cv) + throw std::bad_cast(); + const char_type* e; + r = _M_cv->out(_M_st, this->pbase(), this->pptr(), e, + _M_extbuf, _M_extbuf + _M_ebs, extbe); + if (e == this->pbase()) + return traits_type::eof(); + if (r == std::codecvt_base::noconv) + { + size_t nmemb = static_cast<size_t>(this->pptr() - this->pbase()); + if (fwrite(this->pbase(), 1, nmemb, _M_file) != nmemb) + return traits_type::eof(); + } + else if (r == std::codecvt_base::ok || r == std::codecvt_base::partial) + { + size_t nmemb = static_cast<size_t>(extbe - _M_extbuf); + if (fwrite(_M_extbuf, 1, nmemb, _M_file) != nmemb) + return traits_type::eof(); + if (r == std::codecvt_base::partial) + { + this->setp((char_type*)e, this->pptr()); + this->pbump(this->epptr() - this->pbase()); + } + } + else + return traits_type::eof(); + } while (r == std::codecvt_base::partial); + } + this->setp(pb_save, epb_save); + } + return traits_type::not_eof(c); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +std::basic_streambuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::setbuf(char_type* s, std::streamsize n) +{ + this->setg(0, 0, 0); + this->setp(0, 0); + if (_M_owns_eb) + delete [] _M_extbuf; + if (_M_owns_ib) + delete [] _M_intbuf; + _M_ebs = n; + if (_M_ebs > sizeof(_M_extbuf_min)) + { + if (_M_always_noconv && s) + { + _M_extbuf = (char*)s; + _M_owns_eb = false; + } + else + { + _M_extbuf = new char[_M_ebs]; + _M_owns_eb = true; + } + } + else + { + _M_extbuf = _M_extbuf_min; + _M_ebs = sizeof(_M_extbuf_min); + _M_owns_eb = false; + } + if (!_M_always_noconv) + { + _M_ibs = std::max<std::streamsize>(n, sizeof(_M_extbuf_min)); + if (s && _M_ibs >= sizeof(_M_extbuf_min)) + { + _M_intbuf = s; + _M_owns_ib = false; + } + else + { + _M_intbuf = new char_type[_M_ibs]; + _M_owns_ib = true; + } + } + else + { + _M_ibs = 0; + _M_intbuf = 0; + _M_owns_ib = false; + } + return this; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +typename basic_filebuf<CharT, Traits>::pos_type +basic_filebuf<CharT, Traits>::seekoff(off_type off, std::ios_base::seekdir way, + std::ios_base::openmode) +{ + if (!_M_cv) + throw std::bad_cast(); + int width = _M_cv->encoding(); + if (_M_file == nullptr || (width <= 0 && off != 0) || sync()) + return pos_type(off_type(-1)); + // width > 0 || off == 0 + int whence; + switch (way) + { + case std::ios_base::beg: + whence = SEEK_SET; + break; + case std::ios_base::cur: + whence = SEEK_CUR; + break; + case std::ios_base::end: + whence = SEEK_END; + break; + default: + return pos_type(off_type(-1)); + } +#if _WIN32 + if (fseek(_M_file, width > 0 ? width * off : 0, whence)) + return pos_type(off_type(-1)); + pos_type r = ftell(_M_file); +#else + if (fseeko(_M_file, width > 0 ? width * off : 0, whence)) + return pos_type(off_type(-1)); + pos_type r = ftello(_M_file); +#endif + r.state(_M_st); + return r; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +typename basic_filebuf<CharT, Traits>::pos_type +basic_filebuf<CharT, Traits>::seekpos(pos_type sp, std::ios_base::openmode) +{ + if (_M_file == nullptr || sync()) + return pos_type(off_type(-1)); +#if _WIN32 + if (fseek(_M_file, sp, SEEK_SET)) + return pos_type(off_type(-1)); +#else + if (fseeko(_M_file, sp, SEEK_SET)) + return pos_type(off_type(-1)); +#endif + _M_st = sp.state(); + return sp; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +int +basic_filebuf<CharT, Traits>::sync() +{ + if (_M_file == nullptr) + return 0; + if (!_M_cv) + throw std::bad_cast(); + if (_M_cm & std::ios_base::out) + { + if (this->pptr() != this->pbase()) + if (overflow() == traits_type::eof()) + return -1; + std::codecvt_base::result r; + do + { + char* extbe; + r = _M_cv->unshift(_M_st, _M_extbuf, _M_extbuf + _M_ebs, extbe); + size_t nmemb = static_cast<size_t>(extbe - _M_extbuf); + if (fwrite(_M_extbuf, 1, nmemb, _M_file) != nmemb) + return -1; + } while (r == std::codecvt_base::partial); + if (r == std::codecvt_base::error) + return -1; + if (fflush(_M_file)) + return -1; + } + else if (_M_cm & std::ios_base::in) + { + off_type c; + state_type state = _M_st_last; + bool update_st = false; + if (_M_always_noconv) + c = this->egptr() - this->gptr(); + else + { + int width = _M_cv->encoding(); + c = _M_extbufend - _M_extbufnext; + if (width > 0) + c += width * (this->egptr() - this->gptr()); + else + { + if (this->gptr() != this->egptr()) + { + const int off = _M_cv->length(state, _M_extbuf, + _M_extbufnext, + this->gptr() - this->eback()); + c += _M_extbufnext - _M_extbuf - off; + update_st = true; + } + } + } +#if _WIN32 + if (fseek(_M_file_, -c, SEEK_CUR)) + return -1; +#else + if (fseeko(_M_file, -c, SEEK_CUR)) + return -1; +#endif + if (update_st) + _M_st = state; + _M_extbufnext = _M_extbufend = _M_extbuf; + this->setg(0, 0, 0); + _M_cm = std::ios_base::openmode(0); + } + return 0; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +void +basic_filebuf<CharT, Traits>::imbue(const std::locale& loc) +{ + sync(); + _M_cv = &std::use_facet<std::codecvt<char_type, char, state_type> >(loc); + bool old_anc = _M_always_noconv; + _M_always_noconv = _M_cv->always_noconv(); + if (old_anc != _M_always_noconv) + { + this->setg(0, 0, 0); + this->setp(0, 0); + // invariant, char_type is char, else we couldn't get here + if (_M_always_noconv) // need to dump _M_intbuf + { + if (_M_owns_eb) + delete [] _M_extbuf; + _M_owns_eb = _M_owns_ib; + _M_ebs = _M_ibs; + _M_extbuf = (char*)_M_intbuf; + _M_ibs = 0; + _M_intbuf = nullptr; + _M_owns_ib = false; + } + else // need to obtain an _M_intbuf. + { // If _M_extbuf is user-supplied, use it, else new _M_intbuf + if (!_M_owns_eb && _M_extbuf != _M_extbuf_min) + { + _M_ibs = _M_ebs; + _M_intbuf = (char_type*)_M_extbuf; + _M_owns_ib = false; + _M_extbuf = new char[_M_ebs]; + _M_owns_eb = true; + } + else + { + _M_ibs = _M_ebs; + _M_intbuf = new char_type[_M_ibs]; + _M_owns_ib = true; + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +bool +basic_filebuf<CharT, Traits>::_M_read_mode() +{ + if (!(_M_cm & std::ios_base::in)) + { + this->setp(0, 0); + if (_M_always_noconv) + this->setg((char_type*)_M_extbuf, + (char_type*)_M_extbuf + _M_ebs, + (char_type*)_M_extbuf + _M_ebs); + else + this->setg(_M_intbuf, _M_intbuf + _M_ibs, _M_intbuf + _M_ibs); + _M_cm = std::ios_base::in; + return true; + } + return false; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +void +basic_filebuf<CharT, Traits>::_M_write_mode() +{ + if (!(_M_cm & std::ios_base::out)) + { + this->setg(0, 0, 0); + if (_M_ebs > sizeof(_M_extbuf_min)) + { + if (_M_always_noconv) + this->setp((char_type*)_M_extbuf, + (char_type*)_M_extbuf + (_M_ebs - 1)); + else + this->setp(_M_intbuf, _M_intbuf + (_M_ibs - 1)); + } + else + this->setp(0, 0); + _M_cm = std::ios_base::out; + } +} + +/////////////////////////////////////////////////////////////////////////////// +} + +/////////////////////////////////////////////////////////////////////////////// +#endif // KALDI_UTIL_BASIC_FILEBUF_H_ + +/////////////////////////////////////////////////////////////////////////////// + +/* + * ============================================================================ + * libc++ License + * ============================================================================ + * + * The libc++ library is dual licensed under both the University of Illinois + * "BSD-Like" license and the MIT license. As a user of this code you may + * choose to use it under either license. As a contributor, you agree to allow + * your code to be used under both. + * + * Full text of the relevant licenses is included below. + * + * ============================================================================ + * + * University of Illinois/NCSA + * Open Source License + * + * Copyright (c) 2009-2014 by the contributors listed in CREDITS.TXT (included below) + * + * All rights reserved. + * + * Developed by: + * + * LLVM Team + * + * University of Illinois at Urbana-Champaign + * + * http://llvm.org + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal with + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + * of the Software, and to permit persons to whom the Software is furnished to do + * so, subject to the following conditions: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimers. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimers in the + * documentation and/or other materials provided with the distribution. + * + * * Neither the names of the LLVM Team, University of Illinois at + * Urbana-Champaign, nor the names of its contributors may be used to + * endorse or promote products derived from this Software without specific + * prior written permission. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE + * SOFTWARE. + * + * ============================================================================== + * + * Copyright (c) 2009-2014 by the contributors listed in CREDITS.TXT (included below) + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ============================================================================== + * + * This file is a partial list of people who have contributed to the LLVM/libc++ + * project. If you have contributed a patch or made some other contribution to + * LLVM/libc++, please submit a patch to this file to add yourself, and it will be + * done! + * + * The list is sorted by surname and formatted to allow easy grepping and + * beautification by scripts. The fields are: name (N), email (E), web-address + * (W), PGP key ID and fingerprint (P), description (D), and snail-mail address + * (S). + * + * N: Saleem Abdulrasool + * E: [email protected] + * D: Minor patches and Linux fixes. + * + * N: Dimitry Andric + * E: [email protected] + * D: Visibility fixes, minor FreeBSD portability patches. + * + * N: Holger Arnold + * E: [email protected] + * D: Minor fix. + * + * N: Ruben Van Boxem + * E: vanboxem dot ruben at gmail dot com + * D: Initial Windows patches. + * + * N: David Chisnall + * E: theraven at theravensnest dot org + * D: FreeBSD and Solaris ports, libcxxrt support, some atomics work. + * + * N: Marshall Clow + * E: [email protected] + * E: [email protected] + * D: C++14 support, patches and bug fixes. + * + * N: Bill Fisher + * E: [email protected] + * D: Regex bug fixes. + * + * N: Matthew Dempsky + * E: [email protected] + * D: Minor patches and bug fixes. + * + * N: Google Inc. + * D: Copyright owner and contributor of the CityHash algorithm + * + * N: Howard Hinnant + * E: [email protected] + * D: Architect and primary author of libc++ + * + * N: Hyeon-bin Jeong + * E: [email protected] + * D: Minor patches and bug fixes. + * + * N: Argyrios Kyrtzidis + * E: [email protected] + * D: Bug fixes. + * + * N: Bruce Mitchener, Jr. + * E: [email protected] + * D: Emscripten-related changes. + * + * N: Michel Morin + * E: [email protected] + * D: Minor patches to is_convertible. + * + * N: Andrew Morrow + * E: [email protected] + * D: Minor patches and Linux fixes. + * + * N: Arvid Picciani + * E: aep at exys dot org + * D: Minor patches and musl port. + * + * N: Bjorn Reese + * E: [email protected] + * D: Initial regex prototype + * + * N: Nico Rieck + * E: [email protected] + * D: Windows fixes + * + * N: Jonathan Sauer + * D: Minor patches, mostly related to constexpr + * + * N: Craig Silverstein + * E: [email protected] + * D: Implemented Cityhash as the string hash function on 64-bit machines + * + * N: Richard Smith + * D: Minor patches. + * + * N: Joerg Sonnenberger + * E: [email protected] + * D: NetBSD port. + * + * N: Stephan Tolksdorf + * E: [email protected] + * D: Minor <atomic> fix + * + * N: Michael van der Westhuizen + * E: r1mikey at gmail dot com + * + * N: Klaas de Vries + * E: klaas at klaasgaaf dot nl + * D: Minor bug fix. + * + * N: Zhang Xiongpang + * E: [email protected] + * D: Minor patches and bug fixes. + * + * N: Xing Xue + * E: [email protected] + * D: AIX port + * + * N: Zhihao Yuan + * E: [email protected] + * D: Standard compatibility fixes. + * + * N: Jeffrey Yasskin + * E: [email protected] + * E: [email protected] + * D: Linux fixes. + */ diff --git a/kaldi_io/src/kaldi/util/common-utils.h b/kaldi_io/src/kaldi/util/common-utils.h new file mode 100644 index 0000000..9d39f9d --- /dev/null +++ b/kaldi_io/src/kaldi/util/common-utils.h @@ -0,0 +1,31 @@ +// util/common-utils.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_COMMON_UTILS_H_ +#define KALDI_UTIL_COMMON_UTILS_H_ + +#include "base/kaldi-common.h" +#include "util/parse-options.h" +#include "util/kaldi-io.h" +#include "util/simple-io-funcs.h" +#include "util/kaldi-holder.h" +#include "util/kaldi-table.h" +#include "util/table-types.h" +#include "util/text-utils.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/const-integer-set-inl.h b/kaldi_io/src/kaldi/util/const-integer-set-inl.h new file mode 100644 index 0000000..8f92ab2 --- /dev/null +++ b/kaldi_io/src/kaldi/util/const-integer-set-inl.h @@ -0,0 +1,88 @@ +// util/const-integer-set-inl.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_CONST_INTEGER_SET_INL_H_ +#define KALDI_UTIL_CONST_INTEGER_SET_INL_H_ + +// Do not include this file directly. It is included by const-integer-set.h + + +namespace kaldi { + +template<class I> +void ConstIntegerSet<I>::InitInternal() { + KALDI_ASSERT_IS_INTEGER_TYPE(I); + quick_set_.clear(); // just in case we previously had data. + if (slow_set_.size() == 0) { + lowest_member_=(I) 1; + highest_member_=(I) 0; + contiguous_ = false; + quick_ = false; + } else { + lowest_member_ = slow_set_.front(); + highest_member_ = slow_set_.back(); + size_t range = highest_member_ + 1 - lowest_member_; + if (range == slow_set_.size()) { + contiguous_ = true; + quick_=false; + } else { + contiguous_ = false; + if (range < slow_set_.size() * 8 * sizeof(I)) { // If it would be more compact to store as bool + // (assuming 1 bit per element)... + quick_set_.resize(range, false); + for (size_t i = 0;i < slow_set_.size();i++) + quick_set_[slow_set_[i] - lowest_member_] = true; + quick_ = true; + } else { + quick_ = false; + } + } + } +} + +template<class I> +int ConstIntegerSet<I>::count(I i) const { + if (i < lowest_member_ || i > highest_member_) return 0; + else { + if (contiguous_) return true; + if (quick_) return (quick_set_[i-lowest_member_] ? 1 : 0); + else { + bool ans = std::binary_search(slow_set_.begin(), slow_set_.end(), i); + return (ans ? 1 : 0); + } + } +} + +template<class I> +void ConstIntegerSet<I>::Write(std::ostream &os, bool binary) const { + WriteIntegerVector(os, binary, slow_set_); +} + +template<class I> +void ConstIntegerSet<I>::Read(std::istream &is, bool binary) { + ReadIntegerVector(is, binary, &slow_set_); + InitInternal(); +} + + + +} // end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/util/const-integer-set.h b/kaldi_io/src/kaldi/util/const-integer-set.h new file mode 100644 index 0000000..ffdce4d --- /dev/null +++ b/kaldi_io/src/kaldi/util/const-integer-set.h @@ -0,0 +1,95 @@ +// util/const-integer-set.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_CONST_INTEGER_SET_H_ +#define KALDI_UTIL_CONST_INTEGER_SET_H_ +#include <vector> +#include <set> +#include <algorithm> +#include <limits> +#include <cassert> +#include "util/stl-utils.h" + + /* ConstIntegerSet is a way to efficiently test whether something is in a + supplied set of integers. It can be initialized from a vector or set, but + never changed after that. It either uses a sorted vector or an array of + bool, depending on the input. It behaves like a const version of an STL set, with + only a subset of the functionality, except all the member functions are + upper-case. + + Note that we could get rid of the member slow_set_, but we'd have to + do more work to implement an iterator type. This would save memory. + */ + +namespace kaldi { + +template<class I> class ConstIntegerSet { + public: + ConstIntegerSet(): lowest_member_(1), highest_member_(0) { } + + void Init(const std::vector<I> &input) { + slow_set_ = input; + SortAndUniq(&slow_set_); + InitInternal(); + } + + void Init(const std::set<I> &input) { + CopySetToVector(input, &slow_set_); + InitInternal(); + } + + explicit ConstIntegerSet(const std::vector<I> &input): slow_set_(input) { + SortAndUniq(&slow_set_); + InitInternal(); + } + explicit ConstIntegerSet(const std::set<I> &input) { + CopySetToVector(input, &slow_set_); + InitInternal(); + } + explicit ConstIntegerSet(const ConstIntegerSet<I> &other): slow_set_(other.slow_set_) { + InitInternal(); + } + + int count(I i) const; // returns 1 or 0. + + typedef typename std::vector<I>::const_iterator iterator; + iterator begin() const { return slow_set_.begin(); } + iterator end() const { return slow_set_.end(); } + size_t size() const { return slow_set_.size(); } + bool empty() const { return slow_set_.empty(); } + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + private: + I lowest_member_; + I highest_member_; + bool contiguous_; + bool quick_; + std::vector<bool> quick_set_; + std::vector<I> slow_set_; + void InitInternal(); +}; + +} // end namespace kaldi + +#include "const-integer-set-inl.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/edit-distance-inl.h b/kaldi_io/src/kaldi/util/edit-distance-inl.h new file mode 100644 index 0000000..ebbfb71 --- /dev/null +++ b/kaldi_io/src/kaldi/util/edit-distance-inl.h @@ -0,0 +1,189 @@ +// util/edit-distance-inl.h + +// Copyright 2009-2011 Microsoft Corporation; Haihua Xu; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_EDIT_DISTANCE_INL_H_ +#define KALDI_UTIL_EDIT_DISTANCE_INL_H_ +#include "util/stl-utils.h" + + +namespace kaldi { + +template<class T> +int32 LevenshteinEditDistance(const std::vector<T> &a, + const std::vector<T> &b) { + // Algorithm: + // write A and B for the sequences, with elements a_0 .. + // let |A| = M and |B| = N be the lengths, and have + // elements a_0 ... a_{M-1} and b_0 ... b_{N-1}. + // We are computing the recursion + // E(m, n) = min( E(m-1, n-1) + (1-delta(a_{m-1}, b_{n-1})), + // E(m-1, n), + // E(m, n-1) ). + // where E(m, n) is defined for m = 0..M and n = 0..N and out-of- + // bounds quantities are considered to be infinity (i.e. the + // recursion does not visit them). + + // We do this computation using a vector e of size N+1. + // The outer iterations range over m = 0..M. + + int M = a.size(), N = b.size(); + std::vector<int32> e(N+1); + std::vector<int32> e_tmp(N+1); + // initialize e. + for (size_t i = 0; i < e.size(); i++) + e[i] = i; + for (int32 m = 1; m <= M; m++) { + // computing E(m, .) from E(m-1, .) + // handle special case n = 0: + e_tmp[0] = e[0] + 1; + + for (int32 n = 1; n <= N; n++) { + int32 term1 = e[n-1] + (a[m-1] == b[n-1] ? 0 : 1); + int32 term2 = e[n] + 1; + int32 term3 = e_tmp[n-1] + 1; + e_tmp[n] = std::min(term1, std::min(term2, term3)); + } + e = e_tmp; + } + return e.back(); +} +// +struct error_stats{ + int32 ins_num; + int32 del_num; + int32 sub_num; + int32 total_cost; // minimum total cost to the current alignment. +}; +// Note that both hyp and ref should not contain noise word in +// the following implementation. + +template<class T> +int32 LevenshteinEditDistance(const std::vector<T> &ref, + const std::vector<T> &hyp, + int32 *ins, int32 *del, int32 *sub) { + // temp sequence to remember error type and stats. + std::vector<error_stats> e(ref.size()+1); + std::vector<error_stats> cur_e(ref.size()+1); + // initialize the first hypothesis aligned to the reference at each + // position:[hyp_index =0][ref_index] + for (size_t i =0; i < e.size(); i ++) { + e[i].ins_num = 0; + e[i].sub_num = 0; + e[i].del_num = i; + e[i].total_cost = i; + } + + // for other alignments + for (size_t hyp_index = 1; hyp_index <= hyp.size(); hyp_index ++) { + cur_e[0] = e[0]; + cur_e[0].ins_num ++; + cur_e[0].total_cost ++; + for (size_t ref_index = 1; ref_index <= ref.size(); ref_index ++) { + + int32 ins_err = e[ref_index].total_cost + 1; + int32 del_err = cur_e[ref_index-1].total_cost + 1; + int32 sub_err = e[ref_index-1].total_cost; + if (hyp[hyp_index-1] != ref[ref_index-1]) + sub_err ++; + + if (sub_err < ins_err && sub_err < del_err) { + cur_e[ref_index] =e[ref_index-1]; + if (hyp[hyp_index-1] != ref[ref_index-1]) + cur_e[ref_index].sub_num ++; // substitution error should be increased + cur_e[ref_index].total_cost = sub_err; + }else if (del_err < ins_err ) { + cur_e[ref_index] = cur_e[ref_index-1]; + cur_e[ref_index].total_cost = del_err; + cur_e[ref_index].del_num ++; // deletion number is increased. + }else{ + cur_e[ref_index] = e[ref_index]; + cur_e[ref_index].total_cost = ins_err; + cur_e[ref_index].ins_num ++; // insertion number is increased. + } + } + e = cur_e; // alternate for the next recursion. + } + size_t ref_index = e.size()-1; + *ins = e[ref_index].ins_num, *del = e[ref_index].del_num, *sub = e[ref_index].sub_num; + return e[ref_index].total_cost; +} + +template<class T> +int32 LevenshteinAlignment(const std::vector<T> &a, + const std::vector<T> &b, + T eps_symbol, + std::vector<std::pair<T, T> > *output) { + // Check inputs: + { + KALDI_ASSERT(output != NULL); + for (size_t i = 0; i < a.size(); i++) KALDI_ASSERT(a[i] != eps_symbol); + for (size_t i = 0; i < b.size(); i++) KALDI_ASSERT(b[i] != eps_symbol); + } + output->clear(); + // This is very memory-inefficiently implemented using a vector of vectors. + size_t M = a.size(), N = b.size(); + size_t m, n; + std::vector<std::vector<int32> > e(M+1); + for (m = 0; m <=M; m++) e[m].resize(N+1); + for (n = 0; n <= N; n++) + e[0][n] = n; + for (m = 1; m <= M; m++) { + e[m][0] = e[m-1][0] + 1; + for (n = 1; n <= N; n++) { + int32 sub_or_ok = e[m-1][n-1] + (a[m-1] == b[n-1] ? 0 : 1); + int32 del = e[m-1][n] + 1; // assumes a == ref, b == hyp. + int32 ins = e[m][n-1] + 1; + e[m][n] = std::min(sub_or_ok, std::min(del, ins)); + } + } + // get time-reversed output first: trace back. + m = M; n = N; + while (m != 0 || n != 0) { + size_t last_m, last_n; + if (m == 0) { last_m = m; last_n = n-1; } + else if (n == 0) { last_m = m-1; last_n = n; } + else { + int32 sub_or_ok = e[m-1][n-1] + (a[m-1] == b[n-1] ? 0 : 1); + int32 del = e[m-1][n] + 1; // assumes a == ref, b == hyp. + int32 ins = e[m][n-1] + 1; + if (sub_or_ok <= std::min(del, ins)) { // choose sub_or_ok if all else equal. + last_m = m-1; last_n = n-1; + } else { + if (del <= ins) { // choose del over ins if equal. + last_m = m-1; last_n = n; + } else { + last_m = m; last_n = n-1; + } + } + } + T a_sym, b_sym; + a_sym = (last_m == m ? eps_symbol : a[last_m]); + b_sym = (last_n == n ? eps_symbol : b[last_n]); + output->push_back(std::make_pair(a_sym, b_sym)); + m = last_m; + n = last_n; + } + ReverseVector(output); + return e[M][N]; +} + + +} // end namespace kaldi + +#endif // KALDI_UTIL_EDIT_DISTANCE_INL_H_ diff --git a/kaldi_io/src/kaldi/util/edit-distance.h b/kaldi_io/src/kaldi/util/edit-distance.h new file mode 100644 index 0000000..6000622 --- /dev/null +++ b/kaldi_io/src/kaldi/util/edit-distance.h @@ -0,0 +1,63 @@ +// util/edit-distance.h + +// Copyright 2009-2011 Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_EDIT_DISTANCE_H_ +#define KALDI_UTIL_EDIT_DISTANCE_H_ +#include <vector> +#include <set> +#include <algorithm> +#include <limits> +#include <cassert> +#include "base/kaldi-types.h" + +namespace kaldi { + +// Compute the edit-distance between two strings. +template<class T> +int32 LevenshteinEditDistance(const std::vector<T> &a, + const std::vector<T> &b); + + +// edit distance calculation with conventional method. +// note: noise word must be filtered out from the hypothesis and reference sequence +// before the following procedure conducted. +template<class T> +int32 LevenshteinEditDistance(const std::vector<T> &ref, + const std::vector<T> &hyp, + int32 *ins, int32 *del, int32 *sub); + +// This version of the edit-distance computation outputs the alignment +// between the two. This is a vector of pairs of (symbol a, symbol b). +// The epsilon symbol (eps_symbol) must not occur in sequences a or b. +// Where one aligned to no symbol in the other (insertion or deletion), +// epsilon will be the corresponding member of the pair. +// It returns the edit-distance between the two strings. + +template<class T> +int32 LevenshteinAlignment(const std::vector<T> &a, + const std::vector<T> &b, + T eps_symbol, + std::vector<std::pair<T, T> > *output); + +} // end namespace kaldi + +#include "edit-distance-inl.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/hash-list-inl.h b/kaldi_io/src/kaldi/util/hash-list-inl.h new file mode 100644 index 0000000..19c2bb6 --- /dev/null +++ b/kaldi_io/src/kaldi/util/hash-list-inl.h @@ -0,0 +1,183 @@ +// util/hash-list-inl.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_HASH_LIST_INL_H_ +#define KALDI_UTIL_HASH_LIST_INL_H_ + +// Do not include this file directly. It is included by fast-hash.h + + +namespace kaldi { + +template<class I, class T> HashList<I, T>::HashList() { + list_head_ = NULL; + bucket_list_tail_ = static_cast<size_t>(-1); // invalid. + hash_size_ = 0; + freed_head_ = NULL; +} + +template<class I, class T> void HashList<I, T>::SetSize(size_t size) { + hash_size_ = size; + KALDI_ASSERT(list_head_ == NULL && bucket_list_tail_ == static_cast<size_t>(-1)); // make sure empty. + if (size > buckets_.size()) + buckets_.resize(size, HashBucket(0, NULL)); +} + +template<class I, class T> +typename HashList<I, T>::Elem* HashList<I, T>::Clear() { + // Clears the hashtable and gives ownership of the currently contained list to the + // user. + for (size_t cur_bucket = bucket_list_tail_; + cur_bucket != static_cast<size_t>(-1); + cur_bucket = buckets_[cur_bucket].prev_bucket) { + buckets_[cur_bucket].last_elem = NULL; // this is how we indicate "empty". + } + bucket_list_tail_ = static_cast<size_t>(-1); + Elem *ans = list_head_; + list_head_ = NULL; + return ans; +} + +template<class I, class T> +const typename HashList<I, T>::Elem* HashList<I, T>::GetList() const { + return list_head_; +} + +template<class I, class T> +inline void HashList<I, T>::Delete(Elem *e) { + e->tail = freed_head_; + freed_head_ = e; +} + +template<class I, class T> +inline typename HashList<I, T>::Elem* HashList<I, T>::Find(I key) { + size_t index = (static_cast<size_t>(key) % hash_size_); + HashBucket &bucket = buckets_[index]; + if (bucket.last_elem == NULL) { + return NULL; // empty bucket. + } else { + Elem *head = (bucket.prev_bucket == static_cast<size_t>(-1) ? + list_head_ : + buckets_[bucket.prev_bucket].last_elem->tail), + *tail = bucket.last_elem->tail; + for (Elem *e = head; e != tail; e = e->tail) + if (e->key == key) return e; + return NULL; // Not found. + } +} + +template<class I, class T> +inline typename HashList<I, T>::Elem* HashList<I, T>::New() { + if (freed_head_) { + Elem *ans = freed_head_; + freed_head_ = freed_head_->tail; + return ans; + } else { + Elem *tmp = new Elem[allocate_block_size_]; + for (size_t i = 0; i+1 < allocate_block_size_; i++) + tmp[i].tail = tmp+i+1; + tmp[allocate_block_size_-1].tail = NULL; + freed_head_ = tmp; + allocated_.push_back(tmp); + return this->New(); + } +} + +template<class I, class T> +HashList<I, T>::~HashList() { + // First test whether we had any memory leak within the + // HashList, i.e. things for which the user did not call Delete(). + size_t num_in_list = 0, num_allocated = 0; + for (Elem *e = freed_head_; e != NULL; e = e->tail) + num_in_list++; + for (size_t i = 0; i < allocated_.size(); i++) { + num_allocated += allocate_block_size_; + delete[] allocated_[i]; + } + if (num_in_list != num_allocated) { + KALDI_WARN << "Possible memory leak: " << num_in_list + << " != " << num_allocated + << ": you might have forgotten to call Delete on " + << "some Elems"; + } +} + + +template<class I, class T> +void HashList<I, T>::Insert(I key, T val) { + size_t index = (static_cast<size_t>(key) % hash_size_); + HashBucket &bucket = buckets_[index]; + Elem *elem = New(); + elem->key = key; + elem->val = val; + + if (bucket.last_elem == NULL) { // Unoccupied bucket. Insert at + // head of bucket list (which is tail of regular list, they go in + // opposite directions). + if (bucket_list_tail_ == static_cast<size_t>(-1)) { + // list was empty so this is the first elem. + KALDI_ASSERT(list_head_ == NULL); + list_head_ = elem; + } else { + // link in to the chain of Elems + buckets_[bucket_list_tail_].last_elem->tail = elem; + } + elem->tail = NULL; + bucket.last_elem = elem; + bucket.prev_bucket = bucket_list_tail_; + bucket_list_tail_ = index; + } else { + // Already-occupied bucket. Insert at tail of list of elements within + // the bucket. + elem->tail = bucket.last_elem->tail; + bucket.last_elem->tail = elem; + bucket.last_elem = elem; + } +} + +template<class I, class T> +void HashList<I, T>::InsertMore(I key, T val) { + size_t index = (static_cast<size_t>(key) % hash_size_); + HashBucket &bucket = buckets_[index]; + Elem *elem = New(); + elem->key = key; + elem->val = val; + + KALDI_ASSERT(bucket.last_elem != NULL); // we assume there is already one element + if (bucket.last_elem->key == key) { // standard behavior: add as last element + elem->tail = bucket.last_elem->tail; + bucket.last_elem->tail = elem; + bucket.last_elem = elem; + return; + } + Elem *e = (bucket.prev_bucket == static_cast<size_t>(-1) ? + list_head_ : buckets_[bucket.prev_bucket].last_elem->tail); + // find place to insert in linked list + while (e != bucket.last_elem->tail && e->key != key) e = e->tail; + KALDI_ASSERT(e->key == key); // not found? - should not happen + elem->tail = e->tail; + e->tail = elem; +} + + +} // end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/util/hash-list.h b/kaldi_io/src/kaldi/util/hash-list.h new file mode 100644 index 0000000..4524759 --- /dev/null +++ b/kaldi_io/src/kaldi/util/hash-list.h @@ -0,0 +1,140 @@ +// util/hash-list.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_HASH_LIST_H_ +#define KALDI_UTIL_HASH_LIST_H_ +#include <vector> +#include <set> +#include <algorithm> +#include <limits> +#include <cassert> +#include "util/stl-utils.h" + + +/* This header provides utilities for a structure that's used in a decoder (but + is quite generic in nature so we implement and test it separately). + Basically it's a singly-linked list, but implemented in such a way that we + can quickly search for elements in the list. We give it a slightly richer + interface than just a hash and a list. The idea is that we want to separate + the hash part and the list part: basically, in the decoder, we want to have a + single hash for the current frame and the next frame, because by the time we + need to access the hash for the next frame we no longer need the hash for the + previous frame. So we have an operation that clears the hash but leaves the + list structure intact. We also control memory management inside this object, + to avoid repeated new's/deletes. + + See hash-list-test.cc for an example of how to use this object. +*/ + + +namespace kaldi { + +template<class I, class T> class HashList { + + public: + struct Elem { + I key; + T val; + Elem *tail; + }; + + /// Constructor takes no arguments. Call SetSize to inform it of the likely size. + HashList(); + + /// Clears the hash and gives the head of the current list to the user; + /// ownership is transferred to the user (the user must call Delete() + /// for each element in the list, at his/her leisure). + Elem *Clear(); + + /// Gives the head of the current list to the user. Ownership retained in the + /// class. Caution: in December 2013 the return type was changed to const Elem* + /// and this function was made const. You may need to change some types of + /// local Elem* variables to const if this produces compilation errors. + const Elem *GetList() const; + + /// Think of this like delete(). It is to be called for each Elem in turn + /// after you "obtained ownership" by doing Clear(). This is not the opposite of + /// Insert, it is the opposite of New. It's really a memory operation. + inline void Delete(Elem *e); + + /// This should probably not be needed to be called directly by the user. Think of it as opposite + /// to Delete(); + inline Elem *New(); + + /// Find tries to find this element in the current list using the hashtable. + /// It returns NULL if not present. The Elem it returns is not owned by the user, + /// it is part of the internal list owned by this object, but the user is + /// free to modify the "val" element. + inline Elem *Find(I key); + + /// Insert inserts a new element into the hashtable/stored list. By calling this, + /// the user asserts that it is not already present (e.g. Find was called and + /// returned NULL). With current code, calling this if an element already exists will + /// result in duplicate elements in the structure, and Find() will find the + /// first one that was added. [but we don't guarantee this behavior]. + inline void Insert(I key, T val); + + /// Insert inserts another element with same key into the hashtable/stored list. + /// By calling this, the user asserts that one element with that key is already present. + /// We insert it that way, that all elements with the same key follow each other. + /// Find() will return the first one of the elements with the same key. + inline void InsertMore(I key, T val); + + /// SetSize tells the object how many hash buckets to allocate (should typically be + /// at least twice the number of objects we expect to go in the structure, for fastest + /// performance). It must be called while the hash is empty (e.g. after Clear() or + /// after initializing the object, but before adding anything to the hash. + void SetSize(size_t sz); + + /// Returns current number of hash buckets. + inline size_t Size() { return hash_size_; } + + ~HashList(); + private: + + struct HashBucket { + size_t prev_bucket; // index to next bucket (-1 if list tail). Note: list of buckets + // goes in opposite direction to list of Elems. + Elem *last_elem; // pointer to last element in this bucket (NULL if empty) + inline HashBucket(size_t i, Elem *e): prev_bucket(i), last_elem(e) {} + }; + + Elem *list_head_; // head of currently stored list. + size_t bucket_list_tail_; // tail of list of active hash buckets. + + size_t hash_size_; // number of hash buckets. + + std::vector<HashBucket> buckets_; + + Elem *freed_head_; // head of list of currently freed elements. [ready for allocation] + + std::vector<Elem*> allocated_; // list of allocated blocks. + + static const size_t allocate_block_size_ = 1024; // Number of Elements to allocate in one block. Must be + // largish so storing allocated_ doesn't become a problem. +}; + + +} // end namespace kaldi + +#include "hash-list-inl.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-holder-inl.h b/kaldi_io/src/kaldi/util/kaldi-holder-inl.h new file mode 100644 index 0000000..6a66e61 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-holder-inl.h @@ -0,0 +1,800 @@ +// util/kaldi-holder-inl.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_KALDI_HOLDER_INL_H_ +#define KALDI_UTIL_KALDI_HOLDER_INL_H_ + +#include <algorithm> +#include "util/kaldi-io.h" +#include "util/text-utils.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// \addtogroup holders +/// @{ + + +// KaldiObjectHolder is valid only for Kaldi objects with +// copy constructors, default constructors, and "normal" +// Kaldi Write and Read functions. E.g. it works for +// Matrix and Vector. +template<class KaldiType> class KaldiObjectHolder { + public: + typedef KaldiType T; + + KaldiObjectHolder(): t_(NULL) { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + t.Write(os, binary); + return os.good(); + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object: " << e.what(); + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; // Write failure. + } + } + + void Clear() { + if (t_) { + delete t_; + t_ = NULL; + } + } + + // Reads into the holder. + bool Read(std::istream &is) { + if (t_) delete t_; + t_ = new T; + // Don't want any existing state to complicate the read functioN: get new object. + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Reading Table object, failed reading binary header\n"; + return false; + } + try { + t_->Read(is, is_binary); + return true; + } catch (std::exception &e) { + KALDI_WARN << "Exception caught reading Table object "; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + delete t_; + t_ = NULL; + return false; + } + } + + // Kaldi objects always have the stream open in binary mode for + // reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { + // code error if !t_. + if (!t_) KALDI_ERR << "KaldiObjectHolder::Value() called wrongly."; + return *t_; + } + + ~KaldiObjectHolder() { if (t_) delete t_; } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(KaldiObjectHolder); + T *t_; +}; + + +// BasicHolder is valid for float, double, bool, and integer +// types. There will be a compile time error otherwise, because +// we make sure that the {Write, Read}BasicType functions do not +// get instantiated for other types. + +template<class BasicType> class BasicHolder { + public: + typedef BasicType T; + + BasicHolder(): t_(static_cast<T>(-1)) { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + WriteBasicType(os, binary, t); + if (!binary) os << '\n'; // Makes output format more readable and + // easier to manipulate. + return os.good(); + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object: " << e.what(); + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; // Write failure. + } + } + + void Clear() { } + + // Reads into the holder. + bool Read(std::istream &is) { + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Reading Table object [integer type], failed reading binary header\n"; + return false; + } + try { + int c; + if (!is_binary) { // This is to catch errors, the class would work without it.. + // Eat up any whitespace and make sure it's not newline. + while (isspace((c = is.peek())) && c != static_cast<int>('\n')) is.get(); + if (is.peek() == '\n') { + KALDI_WARN << "Found newline but expected basic type."; + return false; // This is just to catch a more- + // likely-than average type of error (empty line before the token), since + // ReadBasicType will eat it up. + } + } + + ReadBasicType(is, is_binary, &t_); + + if (!is_binary) { // This is to catch errors, the class would work without it.. + // make sure there is a newline. + while (isspace((c = is.peek())) && c != static_cast<int>('\n')) is.get(); + if (is.peek() != '\n') { + KALDI_WARN << "BasicHolder::Read, expected newline, got " + << CharToString(is.peek()) << ", position " << is.tellg(); + return false; + } + is.get(); // Consume the newline. + } + return true; + } catch (std::exception &e) { + KALDI_WARN << "Exception caught reading Table object"; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; + } + } + + // Objects read/written with the Kaldi I/O functions always have the stream + // open in binary mode for reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { + return t_; + } + + ~BasicHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicHolder); + + T t_; +}; + + +/// A Holder for a vector of basic types, e.g. +/// std::vector<int32>, std::vector<float>, and so on. +/// Note: a basic type is defined as a type for which ReadBasicType +/// and WriteBasicType are implemented, i.e. integer and floating +/// types, and bool. +template<class BasicType> class BasicVectorHolder { + public: + typedef std::vector<BasicType> T; + + BasicVectorHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + if (binary) { // need to write the size, in binary mode. + KALDI_ASSERT(static_cast<size_t>(static_cast<int32>(t.size())) == t.size()); + // Or this Write routine cannot handle such a large vector. + // use int32 because it's fixed size regardless of compilation. + // change to int64 (plus in Read function) if this becomes a problem. + WriteBasicType(os, binary, static_cast<int32>(t.size())); + for (typename std::vector<BasicType>::const_iterator iter = t.begin(); + iter != t.end(); ++iter) + WriteBasicType(os, binary, *iter); + + } else { + for (typename std::vector<BasicType>::const_iterator iter = t.begin(); + iter != t.end(); ++iter) + WriteBasicType(os, binary, *iter); + os << '\n'; // Makes output format more readable and + // easier to manipulate. In text mode, this function writes something like + // "1 2 3\n". + } + return os.good(); + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object (BasicVector). "; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; // Write failure. + } + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Reading Table object [integer type], failed reading binary header\n"; + return false; + } + if (!is_binary) { + // In text mode, we terminate with newline. + std::string line; + getline(is, line); // this will discard the \n, if present. + if (is.fail()) { + KALDI_WARN << "BasicVectorHolder::Read, error reading line " << (is.eof() ? "[eof]" : ""); + return false; // probably eof. fail in any case. + } + std::istringstream line_is(line); + try { + while (1) { + line_is >> std::ws; // eat up whitespace. + if (line_is.eof()) break; + BasicType bt; + ReadBasicType(line_is, false, &bt); + t_.push_back(bt); + } + return true; + } catch(std::exception &e) { + KALDI_WARN << "BasicVectorHolder::Read, could not interpret line: " << line; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; + } + } else { // binary mode. + size_t filepos = is.tellg(); + try { + int32 size; + ReadBasicType(is, true, &size); + t_.resize(size); + for (typename std::vector<BasicType>::iterator iter = t_.begin(); + iter != t_.end(); + ++iter) { + ReadBasicType(is, true, &(*iter)); + } + return true; + } catch (...) { + KALDI_WARN << "BasicVectorHolder::Read, read error or unexpected data at archive entry beginning at file position " << filepos; + return false; + } + } + } + + // Objects read/written with the Kaldi I/O functions always have the stream + // open in binary mode for reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + ~BasicVectorHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicVectorHolder); + T t_; +}; + + +/// BasicVectorVectorHolder is a Holder for a vector of vector of +/// a basic type, e.g. std::vector<std::vector<int32> >. +/// Note: a basic type is defined as a type for which ReadBasicType +/// and WriteBasicType are implemented, i.e. integer and floating +/// types, and bool. +template<class BasicType> class BasicVectorVectorHolder { + public: + typedef std::vector<std::vector<BasicType> > T; + + BasicVectorVectorHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + if (binary) { // need to write the size, in binary mode. + KALDI_ASSERT(static_cast<size_t>(static_cast<int32>(t.size())) == t.size()); + // Or this Write routine cannot handle such a large vector. + // use int32 because it's fixed size regardless of compilation. + // change to int64 (plus in Read function) if this becomes a problem. + WriteBasicType(os, binary, static_cast<int32>(t.size())); + for (typename std::vector<std::vector<BasicType> >::const_iterator iter = t.begin(); + iter != t.end(); ++iter) { + KALDI_ASSERT(static_cast<size_t>(static_cast<int32>(iter->size())) == iter->size()); + WriteBasicType(os, binary, static_cast<int32>(iter->size())); + for (typename std::vector<BasicType>::const_iterator iter2=iter->begin(); + iter2 != iter->end(); ++iter2) { + WriteBasicType(os, binary, *iter2); + } + } + } else { // text mode... + // In text mode, we write out something like (for integers): + // "1 2 3 ; 4 5 ; 6 ; ; 7 8 9 ;\n" + // where the semicolon is a terminator, not a separator + // (a separator would cause ambiguity between an + // empty list, and a list containing a single empty list). + for (typename std::vector<std::vector<BasicType> >::const_iterator iter = t.begin(); + iter != t.end(); + ++iter) { + for (typename std::vector<BasicType>::const_iterator iter2=iter->begin(); + iter2 != iter->end(); ++iter2) + WriteBasicType(os, binary, *iter2); + os << "; "; + } + os << '\n'; + } + return os.good(); + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object. "; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; // Write failure. + } + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Failed reading binary header\n"; + return false; + } + if (!is_binary) { + // In text mode, we terminate with newline. + try { // catching errors from ReadBasicType.. + std::vector<BasicType> v; // temporary vector + while (1) { + int i = is.peek(); + if (i == -1) { + KALDI_WARN << "Unexpected EOF"; + return false; + } else if (static_cast<char>(i) == '\n') { + if (!v.empty()) { + KALDI_WARN << "No semicolon before newline (wrong format)"; + return false; + } else { is.get(); return true; } + } else if (std::isspace(i)) { + is.get(); + } else if (static_cast<char>(i) == ';') { + t_.push_back(v); + v.clear(); + is.get(); + } else { // some object we want to read... + BasicType b; + ReadBasicType(is, false, &b); // throws on error. + v.push_back(b); + } + } + } catch(std::exception &e) { + KALDI_WARN << "BasicVectorVectorHolder::Read, read error"; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; + } + } else { // binary mode. + size_t filepos = is.tellg(); + try { + int32 size; + ReadBasicType(is, true, &size); + t_.resize(size); + for (typename std::vector<std::vector<BasicType> >::iterator iter = t_.begin(); + iter != t_.end(); + ++iter) { + int32 size2; + ReadBasicType(is, true, &size2); + iter->resize(size2); + for (typename std::vector<BasicType>::iterator iter2 = iter->begin(); + iter2 != iter->end(); + ++iter2) + ReadBasicType(is, true, &(*iter2)); + } + return true; + } catch (...) { + KALDI_WARN << "Read error or unexpected data at archive entry beginning at file position " << filepos; + return false; + } + } + } + + // Objects read/written with the Kaldi I/O functions always have the stream + // open in binary mode for reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + ~BasicVectorVectorHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicVectorVectorHolder); + T t_; +}; + + +/// BasicPairVectorHolder is a Holder for a vector of pairs of +/// a basic type, e.g. std::vector<std::pair<int32> >. +/// Note: a basic type is defined as a type for which ReadBasicType +/// and WriteBasicType are implemented, i.e. integer and floating +/// types, and bool. +template<class BasicType> class BasicPairVectorHolder { + public: + typedef std::vector<std::pair<BasicType, BasicType> > T; + + BasicPairVectorHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + if (binary) { // need to write the size, in binary mode. + KALDI_ASSERT(static_cast<size_t>(static_cast<int32>(t.size())) == t.size()); + // Or this Write routine cannot handle such a large vector. + // use int32 because it's fixed size regardless of compilation. + // change to int64 (plus in Read function) if this becomes a problem. + WriteBasicType(os, binary, static_cast<int32>(t.size())); + for (typename T::const_iterator iter = t.begin(); + iter != t.end(); ++iter) { + WriteBasicType(os, binary, iter->first); + WriteBasicType(os, binary, iter->second); + } + } else { // text mode... + // In text mode, we write out something like (for integers): + // "1 2 ; 4 5 ; 6 7 ; 8 9 \n" + // where the semicolon is a separator, not a terminator. + for (typename T::const_iterator iter = t.begin(); + iter != t.end();) { + WriteBasicType(os, binary, iter->first); + WriteBasicType(os, binary, iter->second); + ++iter; + if (iter != t.end()) + os << "; "; + } + os << '\n'; + } + return os.good(); + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object. "; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; // Write failure. + } + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Reading Table object [integer type], failed reading binary header\n"; + return false; + } + if (!is_binary) { + // In text mode, we terminate with newline. + try { // catching errors from ReadBasicType.. + std::vector<BasicType> v; // temporary vector + while (1) { + int i = is.peek(); + if (i == -1) { + KALDI_WARN << "Unexpected EOF"; + return false; + } else if (static_cast<char>(i) == '\n') { + if (t_.empty() && v.empty()) { + is.get(); + return true; + } else if (v.size() == 2) { + t_.push_back(std::make_pair(v[0], v[1])); + is.get(); + return true; + } else { + KALDI_WARN << "Unexpected newline, reading vector<pair<?> >; got " + << v.size() << " elements, expected 2."; + return false; + } + } else if (std::isspace(i)) { + is.get(); + } else if (static_cast<char>(i) == ';') { + if (v.size() != 2) { + KALDI_WARN << "Wrong input format, reading vector<pair<?> >; got " + << v.size() << " elements, expected 2."; + return false; + } + t_.push_back(std::make_pair(v[0], v[1])); + v.clear(); + is.get(); + } else { // some object we want to read... + BasicType b; + ReadBasicType(is, false, &b); // throws on error. + v.push_back(b); + } + } + } catch(std::exception &e) { + KALDI_WARN << "BasicPairVectorHolder::Read, read error"; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; + } + } else { // binary mode. + size_t filepos = is.tellg(); + try { + int32 size; + ReadBasicType(is, true, &size); + t_.resize(size); + for (typename T::iterator iter = t_.begin(); + iter != t_.end(); + ++iter) { + ReadBasicType(is, true, &(iter->first)); + ReadBasicType(is, true, &(iter->second)); + } + return true; + } catch (...) { + KALDI_WARN << "BasicVectorHolder::Read, read error or unexpected data at archive entry beginning at file position " << filepos; + return false; + } + } + } + + // Objects read/written with the Kaldi I/O functions always have the stream + // open in binary mode for reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + ~BasicPairVectorHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicPairVectorHolder); + T t_; +}; + + + + +// We define a Token as a nonempty, printable, whitespace-free std::string. +// The binary and text formats here are the same (newline-terminated) +// and as such we don't bother with the binary-mode headers. +class TokenHolder { + public: + typedef std::string T; + + TokenHolder() {} + + static bool Write(std::ostream &os, bool, const T &t) { // ignore binary-mode. + KALDI_ASSERT(IsToken(t)); + os << t << '\n'; + return os.good(); + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + is >> t_; + if (is.fail()) return false; + char c; + while (isspace(c = is.peek()) && c!= '\n') is.get(); + if (is.peek() != '\n') { + KALDI_ERR << "TokenHolder::Read, expected newline, got char " << CharToString(is.peek()) + << ", at stream pos " << is.tellg(); + return false; + } + is.get(); // get '\n' + return true; + } + + + // Since this is fundamentally a text format, read in text mode (would work + // fine either way, but doing it this way will exercise more of the code). + static bool IsReadInBinary() { return false; } + + const T &Value() const { return t_; } + + ~TokenHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(TokenHolder); + T t_; +}; + +// A Token is a nonempty, whitespace-free std::string. +// Class TokenVectorHolder is a Holder class for vectors of these. +class TokenVectorHolder { + public: + typedef std::vector<std::string> T; + + TokenVectorHolder() { } + + static bool Write(std::ostream &os, bool, const T &t) { // ignore binary-mode. + for (std::vector<std::string>::const_iterator iter = t.begin(); + iter != t.end(); + ++iter) { + KALDI_ASSERT(IsToken(*iter)); // make sure it's whitespace-free, printable and nonempty. + os << *iter << ' '; + } + os << '\n'; + return os.good(); + } + + void Clear() { t_.clear(); } + + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + + // there is no binary/non-binary mode. + + std::string line; + getline(is, line); // this will discard the \n, if present. + if (is.fail()) { + KALDI_WARN << "BasicVectorHolder::Read, error reading line " << (is.eof() ? "[eof]" : ""); + return false; // probably eof. fail in any case. + } + const char *white_chars = " \t\n\r\f\v"; + SplitStringToVector(line, white_chars, true, &t_); // true== omit empty strings e.g. + // between spaces. + return true; + } + + // Read in text format since it's basically a text-mode thing.. doesn't really matter, + // it would work either way since we ignore the extra '\r'. + static bool IsReadInBinary() { return false; } + + const T &Value() const { return t_; } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(TokenVectorHolder); + T t_; +}; + + +class HtkMatrixHolder { + public: + typedef std::pair<Matrix<BaseFloat>, HtkHeader> T; + + HtkMatrixHolder() {} + + static bool Write(std::ostream &os, bool binary, const T &t) { + if (!binary) + KALDI_ERR << "Non-binary HTK-format write not supported."; + bool ans = WriteHtk(os, t.first, t.second); + if (!ans) + KALDI_WARN << "Error detected writing HTK-format matrix."; + return ans; + } + + void Clear() { t_.first.Resize(0, 0); } + + // Reads into the holder. + bool Read(std::istream &is) { + bool ans = ReadHtk(is, &t_.first, &t_.second); + if (!ans) { + KALDI_WARN << "Error detected reading HTK-format matrix."; + return false; + } + return ans; + } + + // HTK-format matrices only read in binary. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + + // No destructor. + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(HtkMatrixHolder); + T t_; +}; + +// SphinxMatrixHolder can be used to read and write feature files in +// CMU Sphinx format. 13-dimensional big-endian features are assumed. +// The ultimate reference is SphinxBase's source code (for example see +// feat_s2mfc_read() in src/libsphinxbase/feat/feat.c). +// We can't fully automate the detection of machine/feature file endianess +// mismatch here, because for this Sphinx relies on comparing the feature +// file's size with the number recorded in its header. We are working with +// streams, however(what happens if this is a Kaldi archive?). This should +// be no problem, because the usage help of Sphinx' "wave2feat" for example +// says that Sphinx features are always big endian. +// Note: the kFeatDim defaults to 13, see forward declaration in kaldi-holder.h +template<int kFeatDim> class SphinxMatrixHolder { + public: + typedef Matrix<BaseFloat> T; + + SphinxMatrixHolder() {} + + void Clear() { feats_.Resize(0, 0); } + + // Writes Sphinx-format features + static bool Write(std::ostream &os, bool binary, const T &m) { + if (!binary) { + KALDI_WARN << "SphinxMatrixHolder can't write Sphinx features in text "; + return false; + } + + int32 size = m.NumRows() * m.NumCols(); + if (MachineIsLittleEndian()) + KALDI_SWAP4(size); + os.write((char*) &size, sizeof(size)); // write the header + + for (MatrixIndexT i = 0; i < m.NumRows(); i++) { + float32 tmp[m.NumCols()]; + for (MatrixIndexT j = 0; j < m.NumCols(); j++) { + tmp[j] = static_cast<float32>(m(i, j)); + if (MachineIsLittleEndian()) + KALDI_SWAP4(tmp[j]); + } + os.write((char*) tmp, sizeof(tmp)); + } + + return true; + } + + // Reads the features into a Kaldi Matrix + bool Read(std::istream &is) { + int32 nmfcc; + + is.read((char*) &nmfcc, sizeof(nmfcc)); + if (MachineIsLittleEndian()) + KALDI_SWAP4(nmfcc); + KALDI_VLOG(2) << "#feats: " << nmfcc; + int32 nfvec = nmfcc / kFeatDim; + if ((nmfcc % kFeatDim) != 0) { + KALDI_WARN << "Sphinx feature count is inconsistent with vector length "; + return false; + } + + feats_.Resize(nfvec, kFeatDim); + for (MatrixIndexT i = 0; i < feats_.NumRows(); i++) { + if (sizeof(BaseFloat) == sizeof(float32)) { + is.read((char*) feats_.RowData(i), kFeatDim * sizeof(float32)); + if (!is.good()) { + KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + return false; + } + if (MachineIsLittleEndian()) { + for (MatrixIndexT j=0; j < kFeatDim; j++) + KALDI_SWAP4(feats_(i, j)); + } + } else { // KALDI_DOUBLEPRECISION=1 + float32 tmp[kFeatDim]; + is.read((char*) tmp, sizeof(tmp)); + if (!is.good()) { + KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + return false; + } + for (MatrixIndexT j=0; j < kFeatDim; j++) { + if (MachineIsLittleEndian()) + KALDI_SWAP4(tmp[j]); + feats_(i, j) = static_cast<BaseFloat>(tmp[j]); + } + } + } + + return true; + } + + // Only read in binary + static bool IsReadInBinary() { return true; } + + const T &Value() const { return feats_; } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(SphinxMatrixHolder); + T feats_; +}; + + +/// @} end "addtogroup holders" + +} // end namespace kaldi + + + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-holder.h b/kaldi_io/src/kaldi/util/kaldi-holder.h new file mode 100644 index 0000000..95f1183 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-holder.h @@ -0,0 +1,207 @@ +// util/kaldi-holder.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_KALDI_HOLDER_H_ +#define KALDI_UTIL_KALDI_HOLDER_H_ + +#include <algorithm> +#include "util/kaldi-io.h" +#include "util/text-utils.h" +#include "matrix/kaldi-vector.h" + +namespace kaldi { + + +// The Table class uses a Holder class to wrap objects, and make them behave +// in a "normalized" way w.r.t. reading and writing, so the Table class can +// be template-ized without too much trouble. Look below this +// comment (search for GenericHolder) to see what it looks like. +// +// Requirements of the holder class: +// +// They can only contain objects that can be read/written without external +// information; other objects cannot be stored in this type of archive. +// +// In terms of what functions it should have, see GenericHolder below. +// It is just for documentation. +// +// (1) Requirements of the Read and Write functions +// +// The Read and Write functions should have the property that in a longer +// file, if the Read function is started from where the Write function started +// writing, it should go to where the Write function stopped writing, in either +// text or binary mode (but it's OK if it doesn't eat up trailing space). +// +// [Desirable property: when writing in text mode the output should contain +// exactly one newline, at the end of the output; this makes it easier to manipulate] +// +// [Desirable property for classes: the output should just be a binary-mode +// header (if in binary mode and it's a Kaldi object, or no header +// othewise), and then the output of Object.Write(). This means that when +// written to individual files with the scp: type of wspecifier, we can read +// the individual files in the "normal" Kaldi way by reading the binary +// header and then the object.] +// +// +// The Write function takes a 'binary' argument. In general, each object will +// have two formats: text and binary. However, it's permitted to throw() if +// asked to read in the text format if there is none. The file will be open, if +// the file system has binary/text modes, in the corresponding mode. However, +// the object should have a file-mode in which it can read either text or binary +// output. It announces this via the static IsReadInBinary() function. This +// will generally be the binary mode and it means that where necessary, in text +// formats, we must ignore \r characters. +// +// Memory requirements: if it allocates memory, the destructor should +// free that memory. Copying and assignment of Holder objects may be +// disallowed as the Table code never does this. + + +/// GenericHolder serves to document the requirements of the Holder interface; +/// it's not intended to be used. +template<class SomeType> class GenericHolder { + public: + typedef SomeType T; + + /// Must have a constructor that takes no arguments. + GenericHolder() { } + + /// Write writes this object of type T. Possibly also writes a binary-mode + /// header so that the Read function knows which mode to read in (since the + /// Read function does not get this information). It's a static member so we + /// can write those not inside this class (can use this function with Value() + /// to write from this class). The Write method may throw if it cannot write + /// the object in the given (binary/non-binary) mode. The holder object can + /// assume the stream has been opened in the given mode (where relevant). The + /// object can write the data how it likes. + static bool Write(std::ostream &os, bool binary, const T &t); + + /// Reads into the holder. Must work out from the stream (which will be opened + /// on Windows in binary mode if the IsReadInBinary() function of this class + /// returns true, and text mode otherwise) whether the actual data is binary or + /// not (usually via reading the Kaldi binary-mode header). We put the + /// responsibility for reading the Kaldi binary-mode header in the Read + /// function (rather than making the binary mode an argument to this function), + /// so that for non-Kaldi binary files we don't have to write the header, which + /// would prevent the file being read by non-Kaldi programs (e.g. if we write + /// to individual files using an scp). + /// + /// Read must deallocate any existing data we have here, if applicable (must + /// not assume the object was newly constructed). + /// + /// Returns true on success. + bool Read(std::istream &is); + + /// IsReadInBinary() will return true if the object wants the file to be + /// opened in binary for reading (if the file system has binary/text modes), + /// and false otherwise. Static function. Kaldi objects always return true + /// as they always read in binary mode. Note that we must be able to read, in + /// this mode, objects written in both text and binary mode by Write (which + /// may mean ignoring "\r" characters). I doubt we will ever want this + /// function to return false. + static bool IsReadInBinary() { return true; } + + /// Returns the value of the object held here. Will only + /// ever be called if Read() has been previously called and it returned + /// true (so OK to throw exception if no object was read). + const T &Value() const { return t_; } // if t is a pointer, would return *t_; + + /// The Clear() function doesn't have to do anything. Its purpose is to + /// allow the object to free resources if they're no longer needed. + void Clear() { } + + /// If the object held pointers, the destructor would free them. + ~GenericHolder() { } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(GenericHolder); + T t_; // t_ may alternatively be of type T*. +}; + + +// See kaldi-holder-inl.h for examples of some actual Holder +// classes and templates. + + +// The following two typedefs should probably be in their own file, but they're +// here until there are enough of them to warrant their own header. + + +/// \addtogroup holders +/// @{ + +/// KaldiObjectHolder works for Kaldi objects that have the "standard" Read and Write +/// functions, and a copy constructor. +template<class KaldiType> class KaldiObjectHolder; + +/// BasicHolder is valid for float, double, bool, and integer +/// types. There will be a compile time error otherwise, because +/// we make sure that the {Write, Read}BasicType functions do not +/// get instantiated for other types. +template<class BasicType> class BasicHolder; + + +// A Holder for a vector of basic types, e.g. +// std::vector<int32>, std::vector<float>, and so on. +// Note: a basic type is defined as a type for which ReadBasicType +// and WriteBasicType are implemented, i.e. integer and floating +// types, and bool. +template<class BasicType> class BasicVectorHolder; + + +// A holder for vectors of vectors of basic types, e.g. +// std::vector<std::vector<int32> >, and so on. +// Note: a basic type is defined as a type for which ReadBasicType +// and WriteBasicType are implemented, i.e. integer and floating +// types, and bool. +template<class BasicType> class BasicVectorVectorHolder; + +// A holder for vectors of pairsof basic types, e.g. +// std::vector<std::vector<int32> >, and so on. +// Note: a basic type is defined as a type for which ReadBasicType +// and WriteBasicType are implemented, i.e. integer and floating +// types, and bool. Text format is (e.g. for integers), +// "1 12 ; 43 61 ; 17 8 \n" +template<class BasicType> class BasicPairVectorHolder; + +/// We define a Token (not a typedef, just a word) as a nonempty, printable, +/// whitespace-free std::string. The binary and text formats here are the same +/// (newline-terminated) and as such we don't bother with the binary-mode headers. +class TokenHolder; + +/// Class TokenVectorHolder is a Holder class for vectors of Tokens (T == std::string). +class TokenVectorHolder; + +/// A class for reading/writing HTK-format matrices. +/// T == std::pair<Matrix<BaseFloat>, HtkHeader> +class HtkMatrixHolder; + +/// A class for reading/writing Sphinx format matrices. +template<int kFeatDim=13> class SphinxMatrixHolder; + + +/// @} end "addtogroup holders" + + +} // end namespace kaldi + +#include "kaldi-holder-inl.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-io-inl.h b/kaldi_io/src/kaldi/util/kaldi-io-inl.h new file mode 100644 index 0000000..7df7505 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-io-inl.h @@ -0,0 +1,45 @@ +// util/kaldi-io-inl.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_KALDI_IO_INL_H_ +#define KALDI_UTIL_KALDI_IO_INL_H_ + + +namespace kaldi { + +bool Input::Open(const std::string &rxfilename, bool *binary) { + return OpenInternal(rxfilename, true, binary); +} + +bool Input::OpenTextMode(const std::string &rxfilename) { + return OpenInternal(rxfilename, false, NULL); +} + +bool Input::IsOpen() { + return impl_ != NULL; +} + +bool Output::IsOpen() { + return impl_ != NULL; +} + + +} // end namespace kaldi. + + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-io.h b/kaldi_io/src/kaldi/util/kaldi-io.h new file mode 100644 index 0000000..f2c7563 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-io.h @@ -0,0 +1,264 @@ +// util/kaldi-io.h + +// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_KALDI_IO_H_ +#define KALDI_UTIL_KALDI_IO_H_ + +#include <cctype> // For isspace. +#include <limits> +#include <string> +#include "base/kaldi-common.h" +#ifdef _MSC_VER +# include <fcntl.h> +# include <io.h> +#endif + + + +namespace kaldi { + +class OutputImplBase; // Forward decl; defined in a .cc file +class InputImplBase; // Forward decl; defined in a .cc file + +/// \addtogroup io_group +/// @{ + +// The Output and Input classes handle stream-opening for "extended" filenames +// that include actual files, standard-input/standard-output, pipes, and +// offsets into actual files. They also handle reading and writing the +// binary-mode headers for Kaldi files, where applicable. The classes have +// versions of the Open routines that throw and do not throw, depending whether +// the calling code wants to catch the errors or not; there are also versions +// that write (or do not write) the Kaldi binary-mode header that says if it's +// binary mode. Generally files that contain Kaldi objects will have the header +// on, so we know upon reading them whether they have the header. So you would +// use the OpenWithHeader routines for these (or the constructor); but other +// types of objects (e.g. FSTs) would have files without a header so you would +// use OpenNoHeader. + +// We now document the types of extended filenames that we use. +// +// A "wxfilename" is an extended filename for writing. It can take three forms: +// (1) Filename: e.g. "/some/filename", "./a/b/c", "c:\Users\dpovey\My Documents\\boo" +// (whatever the actual file-system interprets) +// (2) Standard output: "" or "-" +// (3) A pipe: e.g. "gunzip -c /tmp/abc.gz |" +// +// +// A "rxfilename" is an extended filename for reading. It can take four forms: +// (1) An actual filename, whatever the file-system can read, e.g. "/my/file". +// (2) Standard input: "" or "-" +// (3) A pipe: e.g. "| gzip -c > /tmp/abc.gz" +// (4) An offset into a file, e.g.: "/mnt/blah/data/1.ark:24871" +// [these are created by the Table and TableWriter classes; I may also write +// a program that creates them for arbitrary files] +// + + +// Typical usage: +// ... +// bool binary; +// MyObject.Write(Output(some_filename, binary).Stream(), binary); +// +// ... more extensive example: +// { +// Output ko(some_filename, binary); +// MyObject1.Write(ko.Stream(), binary); +// MyObject2.Write(ko.Stream(), binary); +// } + + + +enum OutputType { + kNoOutput, + kFileOutput, + kStandardOutput, + kPipeOutput +}; + +/// ClassifyWxfilename interprets filenames as follows: +/// - kNoOutput: invalid filenames (leading or trailing space, things that look +/// like wspecifiers and rspecifiers or like pipes to read from with leading |. +/// - kFileOutput: Normal filenames +/// - kStandardOutput: The empty string or "-", interpreted as standard output +/// - kPipeOutput: pipes, e.g. "gunzip -c some_file.gz |" +OutputType ClassifyWxfilename(const std::string &wxfilename); + +enum InputType { + kNoInput, + kFileInput, + kStandardInput, + kOffsetFileInput, + kPipeInput +}; + +/// ClassifyRxfilenames interprets filenames for reading as follows: +/// - kNoInput: invalid filenames (leading or trailing space, things that +/// look like wspecifiers and rspecifiers or pipes to write to +/// with trailing |. +/// - kFileInput: normal filenames +/// - kStandardInput: the empty string or "-" +/// - kPipeInput: e.g. "| gzip -c > blah.gz" +/// - kOffsetFileInput: offsets into files, e.g. /some/filename:12970 +InputType ClassifyRxfilename(const std::string &rxfilename); + + +class Output { + public: + // The normal constructor, provided for convenience. + // Equivalent to calling with default constructor then Open() + // with these arguments. + Output(const std::string &filename, bool binary, bool write_header = true); + + Output(): impl_(NULL) {}; + + /// This opens the stream, with the given mode (binary or text). It returns + /// true on success and false on failure. However, it will throw if something + /// was already open and could not be closed (to avoid this, call Close() + /// first. if write_header == true and binary == true, it writes the Kaldi + /// binary-mode header ('\0' then 'B'). You may call Open even if it is + /// already open; it will close the existing stream and reopen (however if + /// closing the old stream failed it will throw). + bool Open(const std::string &wxfilename, bool binary, bool write_header); + + inline bool IsOpen(); // return true if we have an open stream. Does not imply + // stream is good for writing. + + std::ostream &Stream(); // will throw if not open; else returns stream. + + // Close closes the stream. Calling Close is never necessary unless you + // want to avoid exceptions being thrown. There are times when calling + // Close will hurt efficiency (basically, when using offsets into files, + // and using the same Input object), + // but most of the time the user won't be doing this directly, it will + // be done in kaldi-table.{h, cc}, so you don't have to worry about it. + bool Close(); + + // This will throw if stream could not be closed (to check error status, + // call Close()). + ~Output(); + + private: + OutputImplBase *impl_; // non-NULL if open. + std::string filename_; + KALDI_DISALLOW_COPY_AND_ASSIGN(Output); +}; + + +// bool binary_in; +// Input ki(some_filename, &binary_in); +// MyObject.Read(ki, binary_in); +// +// ... more extensive example: +// +// { +// bool binary_in; +// Input ki(some_filename, &binary_in); +// MyObject1.Read(ki.Stream(), &binary_in); +// MyObject2.Write(ki.Stream(), &binary_in); +// } +// Note that to catch errors you need to use try.. catch. +// Input communicates errors by throwing exceptions. + + +// Input interprets four kinds of filenames: +// (1) Normal filenames +// (2) The empty string or "-", interpreted as standard output +// (3) Pipes, e.g. "| gzip -c > some_file.gz" +// (4) Offsets into [real] files, e.g. "/my/filename:12049" +// The last one has no correspondence in Output. + + +class Input { + public: + /// The normal constructor. Opens the stream in binary mode. + /// Equivalent to calling the default constructor followed by Open(); then, if + /// binary != NULL, it calls ReadHeader(), putting the output in "binary"; it + /// throws on error. + Input(const std::string &rxfilename, bool *contents_binary = NULL); + + Input(): impl_(NULL) {} + + // Open opens the stream for reading (the mode, where relevant, is binary; use + // OpenTextMode for text-mode, we made this a separate function rather than a + // boolean argument, to avoid confusion with Kaldi's text/binary distinction, + // since reading in the file system's text mode is unusual.) If + // contents_binary != NULL, it reads the binary-mode header and puts it in the + // "binary" variable. Returns true on success. If it returns false it will + // not be open. You may call Open even if it is already open; it will close + // the existing stream and reopen (however if closing the old stream failed it + // will throw). + inline bool Open(const std::string &rxfilename, bool *contents_binary = NULL); + + // As Open but (if the file system has text/binary modes) opens in text mode; + // you shouldn't ever have to use this as in Kaldi we read even text files in + // binary mode (and ignore the \r). + inline bool OpenTextMode(const std::string &rxfilename); + + // Return true if currently open for reading and Stream() will + // succeed. Does not guarantee that the stream is good. + inline bool IsOpen(); + + // It is never necessary or helpful to call Close, except if + // you are concerned about to many filehandles being open. + // Close does not throw. + void Close(); + + // Returns the underlying stream. Throws if !IsOpen() + std::istream &Stream(); + + // Destructor does not throw: input streams may legitimately fail so we + // don't worry about the status when we close them. + ~Input(); + private: + bool OpenInternal(const std::string &rxfilename, bool file_binary, bool *contents_binary); + InputImplBase *impl_; + KALDI_DISALLOW_COPY_AND_ASSIGN(Input); +}; + +template <class C> inline void ReadKaldiObject(const std::string &filename, + C *c) { + bool binary_in; + Input ki(filename, &binary_in); + c->Read(ki.Stream(), binary_in); +} + +template <class C> inline void WriteKaldiObject(const C &c, + const std::string &filename, + bool binary) { + Output ko(filename, binary); + c.Write(ko.Stream(), binary); +} + +/// PrintableRxfilename turns the rxfilename into a more human-readable +/// form for error reporting, i.e. it does quoting and escaping and +/// replaces "" or "-" with "standard input". +std::string PrintableRxfilename(std::string rxfilename); + +/// PrintableWxfilename turns the filename into a more human-readable +/// form for error reporting, i.e. it does quoting and escaping and +/// replaces "" or "-" with "standard output". +std::string PrintableWxfilename(std::string wxfilename); + +/// @} + +} // end namespace kaldi. + +#include "kaldi-io-inl.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-pipebuf.h b/kaldi_io/src/kaldi/util/kaldi-pipebuf.h new file mode 100644 index 0000000..43e5a2e --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-pipebuf.h @@ -0,0 +1,90 @@ +// util/kaldi-pipebuf.h + +// Copyright 2009-2011 Ondrej Glembek + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +/** @file kaldi-pipebuf.h + * This is an Kaldi C++ Library header. + */ + +#ifndef KALDI_UTIL_KALDI_PIPEBUF_H_ +#define KALDI_UTIL_KALDI_PIPEBUF_H_ + +#if defined(_LIBCPP_VERSION) // libc++ +#include "basic-filebuf.h" +#else +#include <fstream> +#endif + +namespace kaldi +{ +// This class provides a way to initialize a filebuf with a FILE* pointer +// directly; it will not close the file pointer when it is deleted. +// The C++ standard does not allow implementations of C++ to provide +// this constructor within basic_filebuf, which makes it hard to deal +// with pipes using completely native C++. This is a workaround + +#ifdef _MSC_VER +#elif defined(_LIBCPP_VERSION) // libc++ +template<class CharType, class Traits = std::char_traits<CharType> > +class basic_pipebuf : public basic_filebuf<CharType, Traits> +{ + public: + typedef basic_pipebuf<CharType, Traits> ThisType; + + public: + basic_pipebuf(FILE *fptr, std::ios_base::openmode mode) + : basic_filebuf<CharType, Traits>() { + this->open(fptr, mode); + if (!this->is_open()) { + KALDI_WARN << "Error initializing pipebuf"; // probably indicates + // code error, if the fptr was good. + return; + } + } +}; // class basic_pipebuf +#else +template<class CharType, class Traits = std::char_traits<CharType> > +class basic_pipebuf : public std::basic_filebuf<CharType, Traits> +{ + public: + typedef basic_pipebuf<CharType, Traits> ThisType; + + public: + basic_pipebuf(FILE *fptr, std::ios_base::openmode mode) + : std::basic_filebuf<CharType, Traits>() { + this->_M_file.sys_open(fptr, mode); + if (!this->is_open()) { + KALDI_WARN << "Error initializing pipebuf"; // probably indicates + // code error, if the fptr was good. + return; + } + this->_M_mode = mode; + this->_M_buf_size = BUFSIZ; + this->_M_allocate_internal_buffer(); + this->_M_reading = false; + this->_M_writing = false; + this->_M_set_buffer(-1); + } +}; // class basic_pipebuf +#endif // _MSC_VER + +}; // namespace kaldi + +#endif // KALDI_UTIL_KALDI_PIPEBUF_H_ + diff --git a/kaldi_io/src/kaldi/util/kaldi-table-inl.h b/kaldi_io/src/kaldi/util/kaldi-table-inl.h new file mode 100644 index 0000000..6b73c88 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-table-inl.h @@ -0,0 +1,2246 @@ +// util/kaldi-table-inl.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_KALDI_TABLE_INL_H_ +#define KALDI_UTIL_KALDI_TABLE_INL_H_ + +#include <algorithm> +#include "util/kaldi-io.h" +#include "util/text-utils.h" +#include "util/stl-utils.h" // for StringHasher. + + +namespace kaldi { + +/// \addtogroup table_impl_types +/// @{ + +template<class Holder> class SequentialTableReaderImplBase { + public: + typedef typename Holder::T T; + // note that Open takes rxfilename not rspecifier. + virtual bool Open(const std::string &rxfilename) = 0; + virtual bool Done() const = 0; + virtual bool IsOpen() const = 0; + virtual std::string Key() = 0; + virtual const T &Value() = 0; + virtual void FreeCurrent() = 0; + virtual void Next() = 0; + virtual bool Close() = 0; + SequentialTableReaderImplBase() { } + virtual ~SequentialTableReaderImplBase() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(SequentialTableReaderImplBase); +}; + + +// This is the implementation for SequentialTableReader +// when it's actually a script file. +template<class Holder> class SequentialTableReaderScriptImpl: + public SequentialTableReaderImplBase<Holder> { + public: + typedef typename Holder::T T; + + SequentialTableReaderScriptImpl(): state_(kUninitialized) { } + + virtual bool Open(const std::string &rspecifier) { + if (state_ != kUninitialized) + if (! Close()) // call Close() yourself to suppress this exception. + KALDI_ERR << "TableReader::Open, error closing previous input: " + << "rspecifier was " << rspecifier_; + bool binary; + rspecifier_ = rspecifier; + RspecifierType rs = ClassifyRspecifier(rspecifier, &script_rxfilename_, + &opts_); + KALDI_ASSERT(rs == kScriptRspecifier); + if (!script_input_.Open(script_rxfilename_, &binary)) { // Failure on Open + KALDI_WARN << "Failed to open script file " + << PrintableRxfilename(script_rxfilename_); + state_ = kUninitialized; + return false; + } else { // Open succeeded. + if (binary) { // script file should not be binary file.. + state_ = kError; // bad script file. + script_input_.Close(); + return false; + } else { + state_ = kFileStart; + Next(); + if (state_ == kError) { + script_input_.Close(); + return false; + } + if (opts_.permissive) { // Next() will have preloaded. + KALDI_ASSERT(state_ == kLoadSucceeded || state_ == kEof); + } else { + KALDI_ASSERT(state_ == kHaveScpLine || state_ == kEof); + } + return true; // Success. + } + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kEof: case kError: case kHaveScpLine: case kLoadSucceeded: case kLoadFailed: return true; + case kUninitialized: return false; + default: KALDI_ERR << "IsOpen() called on invalid object."; // kFileStart is not valid + // state for user to call something on. + return false; + } + } + + virtual bool Done() const { + switch (state_) { + case kHaveScpLine: return false; + case kLoadSucceeded: case kLoadFailed: return false; + // These cases are because we want LoadCurrent() + // to be callable after Next() and to not change the Done() status [only Next() should change + // the Done() status]. + case kEof: case kError: return true; // Error condition, like Eof, counts as Done(); the destructor + // or Close() will inform the user of the error. + default: KALDI_ERR << "Done() called on TableReader object at the wrong time."; + return false; + } + } + + virtual std::string Key() { + // Valid to call this whenever Done() returns false. + switch (state_) { + case kHaveScpLine: case kLoadSucceeded: case kLoadFailed: break; + default: + // coding error. + KALDI_ERR << "Key() called on TableReader object at the wrong time."; + } + return key_; + } + const T &Value() { + StateType orig_state = state_; + if (state_ == kHaveScpLine) LoadCurrent(); // Takes + // state_ to kLoadSucceeded or kLoadFailed. + if (state_ == kLoadFailed) { // this can happen due to + // a file listed in an scp file not existing, or + // read failure, failure of a command, etc. + if (orig_state == kHaveScpLine) + KALDI_ERR << "TableReader: failed to load object from " + << PrintableRxfilename(data_rxfilename_) + << " (to suppress this error, add the permissive " + << "(p, ) option to the rspecifier."; + + else // orig_state_ was kLoadFailed, which only could have happened + // if the user called FreeCurrent(). + KALDI_ERR << "TableReader: you called Value() after FreeCurrent()."; + } else if (state_ != kLoadSucceeded) { + // This would be a coding error. + KALDI_ERR << "TableReader: Value() called at the wrong time."; + } + return holder_.Value(); + } + void FreeCurrent() { + if (state_ == kLoadSucceeded) { + holder_.Clear(); + state_ = kLoadFailed; + } else { + KALDI_WARN << "TableReader: FreeCurrent called at the wrong time."; + } + } + void Next() { + while (1) { + NextScpLine(); + if (Done()) return; + if (opts_.permissive) { + // Permissive mode means, when reading scp files, we treat keys whose scp entry + // cannot be read as nonexistent. This means trying to read. + if (LoadCurrent()) return; // Success. + // else try the next scp line. + } else { + return; // We go the next key; Value() will crash if we can't + // read the scp line. + } + } + } + + virtual bool Close() { + // Close() will succeed if the stream was not in an error + // state. To clean up, it also closes the Input objects if + // they're open. + if (script_input_.IsOpen()) + script_input_.Close(); + if (data_input_.IsOpen()) + data_input_.Close(); + if (state_ == kLoadSucceeded) + holder_.Clear(); + if (!this->IsOpen()) + KALDI_ERR << "Close() called on input that was not open."; + StateType old_state = state_; + state_ = kUninitialized; + if (old_state == kError) { + if (opts_.permissive) { + KALDI_WARN << "Close() called on scp file with read error, ignoring the " + "error because permissive mode specified."; + return true; + } else return false; // User will do something with the error status. + } else return true; + } + + virtual ~SequentialTableReaderScriptImpl() { + if (state_ == kError) + KALDI_ERR << "TableReader: reading script file failed: from scp " + << PrintableRxfilename(script_rxfilename_); + // If you don't want this exception to be thrown you can + // call Close() and check the status. + if (state_ == kLoadSucceeded) + holder_.Clear(); + } + private: + bool LoadCurrent() { + // Attempts to load object whose rxfilename is on the current scp line. + if (state_ != kHaveScpLine) + KALDI_ERR << "TableReader: LoadCurrent() called at the wrong time."; + bool ans; + // note, NULL means it doesn't read the binary-mode header + if (Holder::IsReadInBinary()) ans = data_input_.Open(data_rxfilename_, NULL); + else ans = data_input_.OpenTextMode(data_rxfilename_); + if (!ans) { + // May want to make this warning a VLOG at some point + KALDI_WARN << "TableReader: failed to open file " + << PrintableRxfilename(data_rxfilename_); + state_ = kLoadFailed; + return false; + } else { + if (holder_.Read(data_input_.Stream())) { + state_ = kLoadSucceeded; + return true; + } else { // holder_ will not contain data. + KALDI_WARN << "TableReader: failed to load object from " + << PrintableRxfilename(data_rxfilename_); + state_ = kLoadFailed; + return false; + } + } + } + + // Reads the next line in the script file. + void NextScpLine() { + switch (state_) { + case kLoadSucceeded: holder_.Clear(); break; + case kHaveScpLine: case kLoadFailed: case kFileStart: break; + default: + // No other states are valid to call Next() from. + KALDI_ERR << "Reading script file: Next called wrongly."; + } + std::string line; + if (getline(script_input_.Stream(), line)) { + SplitStringOnFirstSpace(line, &key_, &data_rxfilename_); + if (!key_.empty() && !data_rxfilename_.empty()) { + // Got a valid line. + state_ = kHaveScpLine; + } else { + // Got an invalid line. + state_ = kError; // we can't make sense of this + // scp file and will now die. + } + } else { + state_ = kEof; // nothing more in the scp file. + // Might as well close the input streams as don't need them. + script_input_.Close(); + if (data_input_.IsOpen()) + data_input_.Close(); + } + } + + + Input script_input_; // Input object for the .scp file + Input data_input_; // Input object for the entries in + // the script file. + Holder holder_; // Holds the object. + bool binary_; // Binary-mode archive. + std::string key_; + std::string rspecifier_; + std::string script_rxfilename_; // of the script file. + RspecifierOptions opts_; // options. + std::string data_rxfilename_; // of the file we're reading. + enum StateType { + // [The state of the reading process] [does holder_ [is script_inp_ + // have object] open] + kUninitialized, // Uninitialized or closed. no no + kEof, // We did Next() and found eof in script file. no no + kError, // Some other error no yes + kHaveScpLine, // Just called Open() or Next() and have a no yes + // line of the script file but no data. + kLoadSucceeded, // Called LoadCurrent() and it succeeded. yes yes + kLoadFailed, // Called LoadCurrent() and it failed, no yes + // or the user called FreeCurrent().. note, + // if when called by user we are in this state, + // it means the user called FreeCurrent(). + kFileStart, // [state we only use internally] no yes + } state_; + private: +}; + + +// This is the implementation for SequentialTableReader +// when it's an archive. Note that the archive format is: +// key1 [space] object1 key2 [space] +// object2 ... eof. +// "object1" is the output of the Holder::Write function and will +// typically contain a binary header (in binary mode) and then +// the output of object.Write(os, binary). +// The archive itself does not care whether it is in binary +// or text mode, for reading purposes. + +template<class Holder> class SequentialTableReaderArchiveImpl: + public SequentialTableReaderImplBase<Holder> { + public: + typedef typename Holder::T T; + + SequentialTableReaderArchiveImpl(): state_(kUninitialized) { } + + virtual bool Open(const std::string &rspecifier) { + if (state_ != kUninitialized) { + if (! Close()) { // call Close() yourself to suppress this exception. + if (opts_.permissive) + KALDI_WARN << "TableReader::Open, error closing previous input " + "(only warning, since permissive mode)."; + else + KALDI_ERR << "TableReader::Open, error closing previous input."; + } + } + rspecifier_ = rspecifier; + RspecifierType rs = ClassifyRspecifier(rspecifier, + &archive_rxfilename_, + &opts_); + KALDI_ASSERT(rs == kArchiveRspecifier); + + bool ans; + // NULL means don't expect binary-mode header + if (Holder::IsReadInBinary()) + ans = input_.Open(archive_rxfilename_, NULL); + else + ans = input_.OpenTextMode(archive_rxfilename_); + if (!ans) { // header. + KALDI_WARN << "TableReader: failed to open stream " + << PrintableRxfilename(archive_rxfilename_); + state_ = kUninitialized; // Failure on Open + return false; // User should print the error message. + } + state_ = kFileStart; + Next(); + if (state_ == kError) { + KALDI_WARN << "Error beginning to read archive file (wrong filename?): " + << PrintableRxfilename(archive_rxfilename_); + input_.Close(); + state_ = kUninitialized; + return false; + } + KALDI_ASSERT(state_ == kHaveObject || state_ == kEof); + return true; + } + + virtual void Next() { + switch (state_) { + case kHaveObject: + holder_.Clear(); break; + case kFileStart: case kFreedObject: + break; + default: + KALDI_ERR << "TableReader: Next() called wrongly."; + } + std::istream &is = input_.Stream(); + is.clear(); // Clear any fail bits that may have been set... just in case + // this happened in the Read function. + is >> key_; // This eats up any leading whitespace and gets the string. + if (is.eof()) { + state_ = kEof; + return; + } + if (is.fail()) { // This shouldn't really happen, barring file-system errors. + KALDI_WARN << "Error reading archive " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + int c; + if ((c = is.peek()) != ' ' && c != '\t' && c != '\n') { // We expect a space ' ' after the key. + // We also allow tab [which is consumed] and newline [which is not], just + // so we can read archives generated by scripts that may not be fully + // aware of how this format works. + KALDI_WARN << "Invalid archive file format: expected space after key " + << key_ << ", got character " + << CharToString(static_cast<char>(is.peek())) << ", reading " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + if (c != '\n') is.get(); // Consume the space or tab. + if (holder_.Read(is)) { + state_ = kHaveObject; + return; + } else { + KALDI_WARN << "Object read failed, reading archive " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kEof: case kError: case kHaveObject: case kFreedObject: return true; + case kUninitialized: return false; + default: KALDI_ERR << "IsOpen() called on invalid object."; // kFileStart is not valid + // state for user to call something on. + return false; + } + } + + virtual bool Done() const { + switch (state_) { + case kHaveObject: + return false; + case kEof: case kError: + return true; // Error-state counts as Done(), but destructor + // will fail (unless you check the status with Close()). + default: + KALDI_ERR << "Done() called on TableReader object at the wrong time."; + return false; + } + } + + virtual std::string Key() { + // Valid to call this whenever Done() returns false + switch (state_) { + case kHaveObject: break; // only valid case. + default: + // coding error. + KALDI_ERR << "Key() called on TableReader object at the wrong time."; + } + return key_; + } + const T &Value() { + switch (state_) { + case kHaveObject: + break; // only valid case. + default: + // coding error. + KALDI_ERR << "Value() called on TableReader object at the wrong time."; + } + return holder_.Value(); + } + virtual void FreeCurrent() { + if (state_ == kHaveObject) { + holder_.Clear(); + state_ = kFreedObject; + } else + KALDI_WARN << "TableReader: FreeCurernt called at the wrong time."; + } + + virtual bool Close() { + if (! this->IsOpen()) + KALDI_ERR << "Close() called on TableReader twice or otherwise wrongly."; + if (input_.IsOpen()) + input_.Close(); + if (state_ == kHaveObject) + holder_.Clear(); + bool ans; + if (opts_.permissive) { + ans = true; // always return success. + if (state_ == kError) + KALDI_WARN << "Error detected closing TableReader for archive " + << PrintableRxfilename(archive_rxfilename_) << " but ignoring " + << "it as permissive mode specified."; + } else + ans = (state_ != kError); // If error state, user should detect it. + state_ = kUninitialized; + return ans; + } + + virtual ~SequentialTableReaderArchiveImpl() { + if (state_ == kError) { + if (opts_.permissive) + KALDI_WARN << "Error detected closing TableReader for archive " + << PrintableRxfilename(archive_rxfilename_) << " but ignoring " + << "it as permissive mode specified."; + else + KALDI_ERR << "TableReader: error detected closing archive " + << PrintableRxfilename(archive_rxfilename_); + } + // If you don't want this exception to be thrown you can + // call Close() and check the status. + if (state_ == kHaveObject) + holder_.Clear(); + } + private: + Input input_; // Input object for the archive + Holder holder_; // Holds the object. + std::string key_; + std::string rspecifier_; + std::string archive_rxfilename_; + RspecifierOptions opts_; + enum { // [The state of the reading process] [does holder_ [is input_ + // have object] open] + kUninitialized, // Uninitialized or closed. no no + kFileStart, // [state we use internally: just opened.] no yes + kEof, // We did Next() and found eof in archive no no + kError, // Some other error no no + kHaveObject, // We read the key and the object after it. yes yes + kFreedObject, // The user called FreeCurrent(). no yes + } state_; +}; + + +template<class Holder> +SequentialTableReader<Holder>::SequentialTableReader(const std::string &rspecifier): impl_(NULL) { + if (rspecifier != "" && !Open(rspecifier)) + KALDI_ERR << "Error constructing TableReader: rspecifier is " << rspecifier; +} + +template<class Holder> +bool SequentialTableReader<Holder>::Open(const std::string &rspecifier) { + if (IsOpen()) + if (!Close()) + KALDI_ERR << "Could not close previously open object."; + // now impl_ will be NULL. + + RspecifierType wt = ClassifyRspecifier(rspecifier, NULL, NULL); + switch (wt) { + case kArchiveRspecifier: + impl_ = new SequentialTableReaderArchiveImpl<Holder>(); + break; + case kScriptRspecifier: + impl_ = new SequentialTableReaderScriptImpl<Holder>(); + break; + case kNoRspecifier: default: + KALDI_WARN << "Invalid rspecifier " << rspecifier; + return false; + } + if (!impl_->Open(rspecifier)) { + delete impl_; + impl_ = NULL; + return false; // sub-object will have printed warnings. + } + else return true; +} + +template<class Holder> +bool SequentialTableReader<Holder>::Close() { + CheckImpl(); + bool ans = impl_->Close(); + delete impl_; // We don't keep around empty impl_ objects. + impl_ = NULL; + return ans; +} + + +template<class Holder> +bool SequentialTableReader<Holder>::IsOpen() const { + return (impl_ != NULL); // Because we delete the object whenever + // that object is not open. Thus, the IsOpen functions of the + // Impl objects are not really needed. +} + +template<class Holder> +std::string SequentialTableReader<Holder>::Key() { + CheckImpl(); + return impl_->Key(); // this call may throw if called wrongly in other ways, + // e.g. eof. +} + + +template<class Holder> +void SequentialTableReader<Holder>::FreeCurrent() { + CheckImpl(); + impl_->FreeCurrent(); +} + + +template<class Holder> +const typename SequentialTableReader<Holder>::T & +SequentialTableReader<Holder>::Value() { + CheckImpl(); + return impl_->Value(); // This may throw (if LoadCurrent() returned false you are safe.). +} + + +template<class Holder> +void SequentialTableReader<Holder>::Next() { + CheckImpl(); + impl_->Next(); +} + +template<class Holder> +bool SequentialTableReader<Holder>::Done() { + CheckImpl(); + return impl_->Done(); +} + + +template<class Holder> +SequentialTableReader<Holder>::~SequentialTableReader() { + if (impl_) delete impl_; + // Destructor of impl_ may throw. +} + + + +template<class Holder> class TableWriterImplBase { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &wspecifier) = 0; + + // Write returns true on success, false on failure, but + // some errors may not be detected until we call Close(). + // It throws (via KALDI_ERR) if called wrongly. We could + // have just thrown on all errors, since this is what + // TableWriter does; it was designed this way because originally + // TableWriter::Write returned an exit status. + virtual bool Write(const std::string &key, const T &value) = 0; + + // Flush will flush any archive; it does not return error status, + // any errors will be reported on the next Write or Close. + virtual void Flush() = 0; + + virtual bool Close() = 0; + + virtual bool IsOpen() const = 0; + + // May throw on write error if Close was not called. + virtual ~TableWriterImplBase() { } + + TableWriterImplBase() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(TableWriterImplBase); +}; + + +// The implementation of TableWriter we use when writing directly +// to an archive with no associated scp. +template<class Holder> +class TableWriterArchiveImpl: public TableWriterImplBase<Holder> { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &wspecifier) { + switch (state_) { + case kUninitialized: + break; + case kWriteError: + KALDI_ERR << "TableWriter: opening stream, already open with write error."; + case kOpen: default: + if (!Close()) // throw because this error may not have been previously + // detected by the user. + KALDI_ERR << "TableWriter: opening stream, error closing previously open stream."; + } + wspecifier_ = wspecifier; + WspecifierType ws = ClassifyWspecifier(wspecifier, + &archive_wxfilename_, + NULL, + &opts_); + KALDI_ASSERT(ws == kArchiveWspecifier); // or wrongly called. + + if (output_.Open(archive_wxfilename_, opts_.binary, false)) { // false means no binary header. + state_ = kOpen; + return true; + } else { + // stream will not be open. User will report this error + // (we return bool), so don't bother printing anything. + state_ = kUninitialized; + return false; + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kUninitialized: return false; + case kOpen: case kWriteError: return true; + default: KALDI_ERR << "IsOpen() called on TableWriter in invalid state."; + } + return false; + } + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual bool Write(const std::string &key, const T &value) { + switch (state_) { + case kOpen: break; + case kWriteError: + // user should have known from the last + // call to Write that there was a problem. + KALDI_WARN << "TableWriter: attempting to write to invalid stream."; + return false; + case kUninitialized: default: + KALDI_ERR << "TableWriter: Write called on invalid stream"; + + } + // state is now kOpen or kWriteError. + if (!IsToken(key)) // e.g. empty string or has spaces... + KALDI_ERR << "TableWriter: using invalid key " << key; + output_.Stream() << key << ' '; + if (!Holder::Write(output_.Stream(), opts_.binary, value)) { + KALDI_WARN << "TableWriter: write failure to " + << PrintableWxfilename(archive_wxfilename_); + state_ = kWriteError; + return false; + } + if (state_ == kWriteError) return false; // Even if this Write seems to have + // succeeded, we fail because a previous Write failed and the archive may be + // corrupted and unreadable. + + if (opts_.flush) + Flush(); + return true; + } + + // Flush will flush any archive; it does not return error status, + // any errors will be reported on the next Write or Close. + virtual void Flush() { + switch (state_) { + case kWriteError: case kOpen: + output_.Stream().flush(); // Don't check error status. + return; + default: + KALDI_WARN << "TableWriter: Flush called on not-open writer."; + } + } + + virtual bool Close() { + if (!this->IsOpen() || !output_.IsOpen()) + KALDI_ERR << "TableWriter: Close called on a stream that was not open." << this->IsOpen() << ", " << output_.IsOpen(); + bool close_success = output_.Close(); + if (!close_success) { + KALDI_WARN << "TableWriter: error closing stream: wspecifier is " + << wspecifier_; + state_ = kUninitialized; + return false; + } + if (state_ == kWriteError) { + KALDI_WARN << "TableWriter: closing writer in error state: wspecifier is " + << wspecifier_; + state_ = kUninitialized; + return false; + } + state_ = kUninitialized; + return true; + } + + TableWriterArchiveImpl(): state_(kUninitialized) {} + + // May throw on write error if Close was not called. + virtual ~TableWriterArchiveImpl() { + if (!IsOpen()) return; + else if (!Close()) + KALDI_ERR << "At TableWriter destructor: Write failed or stream close " + << "failed: wspecifier is "<< wspecifier_; + } + + private: + Output output_; + WspecifierOptions opts_; + std::string wspecifier_; + std::string archive_wxfilename_; + enum { // is stream open? + kUninitialized, // no + kOpen, // yes + kWriteError, // yes + } state_; +}; + + + + +// The implementation of TableWriter we use when writing to +// individual files (more generally, wxfilenames) specified +// in an scp file that we read. + +// Note: the code for this class is similar to RandomAccessTableReaderScriptImpl; +// try to keep them in sync. + +template<class Holder> +class TableWriterScriptImpl: public TableWriterImplBase<Holder> { + public: + typedef typename Holder::T T; + + TableWriterScriptImpl(): last_found_(0), state_(kUninitialized) {} + + virtual bool Open(const std::string &wspecifier) { + switch (state_) { + case kReadScript: + KALDI_ERR << " Opening already open TableWriter: call Close first."; + case kUninitialized: case kNotReadScript: + break; + } + wspecifier_ = wspecifier; + WspecifierType ws = ClassifyWspecifier(wspecifier, + NULL, + &script_rxfilename_, + &opts_); + KALDI_ASSERT(ws == kScriptWspecifier); // or wrongly called. + KALDI_ASSERT(script_.empty()); // no way it could be nonempty at this point. + + if (! ReadScriptFile(script_rxfilename_, + true, // print any warnings + &script_)) { // error reading script file or invalid format + state_ = kNotReadScript; + return false; // no need to print further warnings. user gets the error. + } + std::sort(script_.begin(), script_.end()); + for (size_t i = 0; i+1 < script_.size(); i++) { + if (script_[i].first.compare(script_[i+1].first) >= 0) { + // script[i] not < script[i+1] in lexical order... + KALDI_WARN << "Script file " << PrintableRxfilename(script_rxfilename_) + << " contains duplicate key " << script_[i].first; + state_ = kNotReadScript; + return false; + } + } + state_ = kReadScript; + return true; + } + + virtual bool IsOpen() const { return (state_ == kReadScript); } + + virtual bool Close() { + if (!IsOpen()) + KALDI_ERR << "Close() called on TableWriter that was not open."; + state_ = kUninitialized; + last_found_ = 0; + script_.clear(); + return true; + } + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual bool Write(const std::string &key, const T &value) { + if (!IsOpen()) + KALDI_ERR << "TableWriter: Write called on invalid stream"; + + if (!IsToken(key)) // e.g. empty string or has spaces... + KALDI_ERR << "TableWriter: using invalid key " << key; + + std::string wxfilename; + if (!LookupFilename(key, &wxfilename)) { + if (opts_.permissive) { + return true; // In permissive mode, it's as if we're writing to /dev/null + // for missing keys. + } else { + KALDI_WARN << "TableWriter: script file " + << PrintableRxfilename(script_rxfilename_) + << " has no entry for key "<<key; + return false; + } + } + Output output; + if (!output.Open(wxfilename, opts_.binary, false)) { + // Open in the text/binary mode (on Windows) given by member var. "binary" + // (obtained from wspecifier), but do not put the binary-mode header (it + // will be written, if needed, by the Holder::Write function.) + KALDI_WARN << "TableWriter: failed to open stream: " + << PrintableWxfilename(wxfilename); + return false; + } + if (!Holder::Write(output.Stream(), opts_.binary, value) + || !output.Close()) { + KALDI_WARN << "TableWriter: failed to write data to " + << PrintableWxfilename(wxfilename); + return false; + } + return true; + } + + // Flush does nothing in this implementation, there is nothing to flush. + virtual void Flush() { } + + + virtual ~TableWriterScriptImpl() { + // Nothing to do in destructor. + } + + private: + // Note: this function is almost the same as in RandomAccessTableReaderScriptImpl. + bool LookupFilename(const std::string &key, std::string *wxfilename) { + // First, an optimization: if we're going consecutively, this will + // make the lookup very fast. + last_found_++; + if (last_found_ < script_.size() && script_[last_found_].first == key) { + *wxfilename = script_[last_found_].second; + return true; + } + std::pair<std::string, std::string> pr(key, ""); // Important that "" + // compares less than or equal to any string, so lower_bound points to the + // element that has the same key. + typedef typename std::vector<std::pair<std::string, std::string> >::const_iterator + IterType; + IterType iter = std::lower_bound(script_.begin(), script_.end(), pr); + if (iter != script_.end() && iter->first == key) { + last_found_ = iter - script_.begin(); + *wxfilename = iter->second; + return true; + } else { + return false; + } + } + + + WspecifierOptions opts_; + std::string wspecifier_; + std::string script_rxfilename_; + + // the script_ variable contains pairs of (key, filename), sorted using + // std::sort. This can be used with binary_search to look up filenames for + // writing. If this becomes inefficient we can use std::unordered_map (but I + // suspect this wouldn't be significantly faster & would use more memory). + // If memory becomes a problem here, the user should probably be passing + // only the relevant part of the scp file rather than expecting us to get too + // clever in the code. + std::vector<std::pair<std::string, std::string> > script_; + size_t last_found_; // This is for an optimization used in LookupFilename. + + enum { + kUninitialized, + kReadScript, + kNotReadScript, // read of script failed. + } state_; +}; + + +// The implementation of TableWriter we use when writing directly +// to an archive plus an associated scp. +template<class Holder> +class TableWriterBothImpl: public TableWriterImplBase<Holder> { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &wspecifier) { + switch (state_) { + case kUninitialized: + break; + case kWriteError: + KALDI_ERR << "TableWriter: opening stream, already open with write error."; + case kOpen: default: + if (!Close()) // throw because this error may not have been previously detected by user. + KALDI_ERR << "TableWriter: opening stream, error closing previously open stream."; + } + wspecifier_ = wspecifier; + WspecifierType ws = ClassifyWspecifier(wspecifier, + &archive_wxfilename_, + &script_wxfilename_, + &opts_); + KALDI_ASSERT(ws == kBothWspecifier); // or wrongly called. + if (ClassifyWxfilename(archive_wxfilename_) != kFileOutput) + KALDI_WARN << "When writing to both archive and script, the script file " + "will generally not be interpreted correctly unless the archive is " + "an actual file: wspecifier = " << wspecifier; + + if (!archive_output_.Open(archive_wxfilename_, opts_.binary, false)) { // false means no binary header. + state_ = kUninitialized; + return false; + } + if (!script_output_.Open(script_wxfilename_, false, false)) { // first false means text mode: + // script files always text-mode. second false means don't write header (doesn't matter + // for text mode). + archive_output_.Close(); // Don't care about status: error anyway. + state_ = kUninitialized; + return false; + } + state_ = kOpen; + return true; + } + + virtual bool IsOpen() const { + switch (state_) { + case kUninitialized: return false; + case kOpen: case kWriteError: return true; + default: KALDI_ERR << "IsOpen() called on TableWriter in invalid state."; + } + return false; + } + + void MakeFilename(typename std::ostream::pos_type streampos, std::string *output) const { + std::ostringstream ss; + ss << ':' << streampos; + KALDI_ASSERT(ss.str() != ":-1"); + *output = archive_wxfilename_ + ss.str(); + + // e.g. /some/file:12302. + // Note that we warned if archive_wxfilename_ is not an actual filename; + // the philosophy is we give the user rope and if they want to hang + // themselves, with it, fine. + } + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual bool Write(const std::string &key, const T &value) { + switch (state_) { + case kOpen: break; + case kWriteError: + // user should have known from the last + // call to Write that there was a problem. Warn about it. + KALDI_WARN << "TableWriter: writing to non-open TableWriter object."; + return false; + case kUninitialized: default: + KALDI_ERR << "TableWriter: Write called on invalid stream"; + } + // state is now kOpen or kWriteError. + if (!IsToken(key)) // e.g. empty string or has spaces... + KALDI_ERR << "TableWriter: using invalid key " << key; + std::ostream &archive_os = archive_output_.Stream(); + archive_os << key << ' '; + typename std::ostream::pos_type archive_os_pos = archive_os.tellp(); + // position at start of Write() to archive. We will record this in the script file. + std::string offset_rxfilename; // rxfilename with offset into the archive, + // e.g. some_archive_name.ark:431541423 + MakeFilename(archive_os_pos, &offset_rxfilename); + + // Write to the script file first. + // The idea is that we want to get all the information possible into the + // script file, to make it easier to unwind errors later. + std::ostream &script_os = script_output_.Stream(); + script_output_.Stream() << key << ' ' << offset_rxfilename << '\n'; + + if (!Holder::Write(archive_output_.Stream(), opts_.binary, value)) { + KALDI_WARN << "TableWriter: write failure to" + << PrintableWxfilename(archive_wxfilename_); + state_ = kWriteError; + return false; + } + + if (script_os.fail()) { + KALDI_WARN << "TableWriter: write failure to script file detected: " + << PrintableWxfilename(script_wxfilename_); + state_ = kWriteError; + return false; + } + + if (archive_os.fail()) { + KALDI_WARN << "TableWriter: write failure to archive file detected: " + << PrintableWxfilename(archive_wxfilename_); + state_ = kWriteError; + return false; + } + + if (state_ == kWriteError) return false; // Even if this Write seems to have + // succeeded, we fail because a previous Write failed and the archive may be + // corrupted and unreadable. + + if (opts_.flush) + Flush(); + return true; + } + + // Flush will flush any archive; it does not return error status, + // any errors will be reported on the next Write or Close. + virtual void Flush() { + switch (state_) { + case kWriteError: case kOpen: + archive_output_.Stream().flush(); // Don't check error status. + script_output_.Stream().flush(); // Don't check error status. + return; + default: + KALDI_WARN << "TableWriter: Flush called on not-open writer."; + } + } + + virtual bool Close() { + if (!this->IsOpen()) + KALDI_ERR << "TableWriter: Close called on a stream that was not open."; + bool close_success = true; + if (archive_output_.IsOpen()) + if (!archive_output_.Close()) close_success = false; + if (script_output_.IsOpen()) + if (!script_output_.Close()) close_success = false; + bool ans = close_success && (state_ != kWriteError); + state_ = kUninitialized; + return ans; + } + + TableWriterBothImpl(): state_(kUninitialized) {} + + // May throw on write error if Close() was not called. + // User can get the error status by calling Close(). + virtual ~TableWriterBothImpl() { + if (!IsOpen()) return; + else if (!Close()) + KALDI_ERR << "At TableWriter destructor: Write failed or stream close failed: " + << wspecifier_; + } + + private: + Output archive_output_; + Output script_output_; + WspecifierOptions opts_; + std::string archive_wxfilename_; + std::string script_wxfilename_; + std::string wspecifier_; + enum { // is stream open? + kUninitialized, // no + kOpen, // yes + kWriteError, // yes + } state_; +}; + + +template<class Holder> +TableWriter<Holder>::TableWriter(const std::string &wspecifier): impl_(NULL) { + if (wspecifier != "" && !Open(wspecifier)) { + KALDI_ERR << "TableWriter: failed to write to " + << wspecifier; + } +} + +template<class Holder> +bool TableWriter<Holder>::IsOpen() const { + return (impl_ != NULL); +} + + +template<class Holder> +bool TableWriter<Holder>::Open(const std::string &wspecifier) { + + if (IsOpen()) { + if (!Close()) // call Close() yourself to suppress this exception. + KALDI_ERR << "TableWriter::Open, failed to close previously open writer."; + } + KALDI_ASSERT(impl_ == NULL); + WspecifierType wtype = ClassifyWspecifier(wspecifier, NULL, NULL, NULL); + switch (wtype) { + case kBothWspecifier: + impl_ = new TableWriterBothImpl<Holder>(); + break; + case kArchiveWspecifier: + impl_ = new TableWriterArchiveImpl<Holder>(); + break; + case kScriptWspecifier: + impl_ = new TableWriterScriptImpl<Holder>(); + break; + case kNoWspecifier: default: + KALDI_WARN << "ClassifyWspecifier: invalid wspecifier " << wspecifier; + return false; + } + if (impl_->Open(wspecifier)) return true; + else { // The class will have printed a more specific warning. + delete impl_; + impl_ = NULL; + return false; + } +} + +template<class Holder> +void TableWriter<Holder>::Write(const std::string &key, + const T &value) const { + CheckImpl(); + if (!impl_->Write(key, value)) + KALDI_ERR << "Error in TableWriter::Write"; + // More specific warning will have + // been printed in the Write function. +} + +template<class Holder> +void TableWriter<Holder>::Flush() { + CheckImpl(); + impl_->Flush(); +} + +template<class Holder> +bool TableWriter<Holder>::Close() { + CheckImpl(); + bool ans = impl_->Close(); + delete impl_; // We don't keep around non-open impl_ objects [c.f. definition of IsOpen()] + impl_ = NULL; + return ans; +} + +template<class Holder> +TableWriter<Holder>::~TableWriter() { + if (IsOpen() && !Close()) { + KALDI_ERR << "Error closing TableWriter [in destructor]."; + } +} + + +// Types of RandomAccessTableReader: +// In principle, we would like to have four types of RandomAccessTableReader: +// the 4 combinations [scp, archive], [seekable, not-seekable], +// where if something is seekable we only store a file offset. However, +// it seems sufficient for now to only implement two of these, in both +// cases assuming it's not seekable so we never store file offsets and always +// store either the scp line or the data in the archive. The reasons are: +// (1) +// For scp files, storing the actual entry is not that much more expensive +// than storing the file offsets (since the entries are just filenames), and +// avoids a lot of fseek operations that might be expensive. +// (2) +// For archive files, there is no real reason, if you have the archive file +// on disk somewhere, why you wouldn't access it via its associated scp. +// [i.e. write it as ark, scp]. The main reason to read archives directly +// is if they are part of a pipe, and in this case it's not seekable, so +// we implement only this case. +// +// Note that we will rarely in practice have to keep in memory everything in +// the archive, as long as things are only read once from the archive (the +// "o, " or "once" option) and as long as we keep our keys in sorted order; to take +// advantage of this we need the "s, " (sorted) option, so we would read archives +// as e.g. "s, o, ark:-" (this is the rspecifier we would use if it was the +// standard input and these conditions held). + +template<class Holder> class RandomAccessTableReaderImplBase { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &rspecifier) = 0; + + virtual bool HasKey(const std::string &key) = 0; + + virtual const T &Value(const std::string &key) = 0; + + virtual bool Close() = 0; + + virtual ~RandomAccessTableReaderImplBase() {} +}; + + +// Implementation of RandomAccessTableReader for a script file; for simplicity we +// just read it in all in one go, as it's unlikely someone would generate this +// from a pipe. In principle we could read it on-demand as for the archives, but +// this would probably be overkill. + +// Note: the code for this this class is similar to TableWriterScriptImpl: +// try to keep them in sync. +template<class Holder> +class RandomAccessTableReaderScriptImpl: + public RandomAccessTableReaderImplBase<Holder> { + + public: + typedef typename Holder::T T; + + RandomAccessTableReaderScriptImpl(): last_found_(0), state_(kUninitialized) {} + + virtual bool Open(const std::string &rspecifier) { + switch (state_) { + case kNotHaveObject: case kHaveObject: case kGaveObject: + KALDI_ERR << " Opening already open RandomAccessTableReader: call Close first."; + case kUninitialized: case kNotReadScript: + break; + } + rspecifier_ = rspecifier; + RspecifierType rs = ClassifyRspecifier(rspecifier, + &script_rxfilename_, + &opts_); + KALDI_ASSERT(rs == kScriptRspecifier); // or wrongly called. + KALDI_ASSERT(script_.empty()); // no way it could be nonempty at this point. + + if (! ReadScriptFile(script_rxfilename_, + true, // print any warnings + &script_)) { // error reading script file or invalid format + state_ = kNotReadScript; + return false; // no need to print further warnings. user gets the error. + } + + rspecifier_ = rspecifier; + // If opts_.sorted, the user has asserted that the keys are already sorted. + // Although we could easily sort them, we want to let the user know of this + // mistake. This same mistake could have serious effects if used with an + // archive rather than a script. + if (!opts_.sorted) + std::sort(script_.begin(), script_.end()); + for (size_t i = 0; i+1 < script_.size(); i++) { + if (script_[i].first.compare(script_[i+1].first) >= 0) { + // script[i] not < script[i+1] in lexical order... + bool same = (script_[i].first == script_[i+1].first); + KALDI_WARN << "Script file " << PrintableRxfilename(script_rxfilename_) + << (same ? " contains duplicate key: " : + " is not sorted (remove s, option or add ns, option): key is ") + << script_[i].first; + state_ = kNotReadScript; + return false; + } + } + state_ = kNotHaveObject; + return true; + } + + virtual bool IsOpen() const { + return (state_ == kNotHaveObject || state_ == kHaveObject || + state_ == kGaveObject); + } + + virtual bool Close() { + if (!IsOpen()) + KALDI_ERR << "Close() called on RandomAccessTableReader that was not open."; + holder_.Clear(); + state_ = kUninitialized; + last_found_ = 0; + script_.clear(); + current_key_ = ""; + // This one cannot fail because any errors of a "global" + // nature would have been detected when we did Open(). + // With archives it's different. + return true; + } + + virtual bool HasKey(const std::string &key) { + bool preload = opts_.permissive; + // In permissive mode, we have to check that we can read + // the scp entry before we assert that the key is there. + return HasKeyInternal(key, preload); + } + + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual const T& Value(const std::string &key) { + + if (!IsOpen()) + KALDI_ERR << "Value() called on non-open object."; + + if (!((state_ == kHaveObject || state_ == kGaveObject) + && key == current_key_)) { // Not already stored... + bool has_key = HasKeyInternal(key, true); // preload. + if (!has_key) + KALDI_ERR << "Could not get item for key " << key + << ", rspecifier is " << rspecifier_ << "[to ignore this, " + << "add the p, (permissive) option to the rspecifier."; + KALDI_ASSERT(state_ == kHaveObject && key == current_key_); + } + + if (state_ == kHaveObject) { + state_ = kGaveObject; + if (opts_.once) MakeTombstone(key); // make sure that future lookups fail. + return holder_.Value(); + } else { // state_ == kGaveObject + if (opts_.once) + KALDI_ERR << "Value called twice for the same key and ,o (once) option " + << "is used: rspecifier is " << rspecifier_; + return holder_.Value(); + } + } + + virtual ~RandomAccessTableReaderScriptImpl() { + if (state_ == kHaveObject || state_ == kGaveObject) + holder_.Clear(); + } + + private: + // HasKeyInternal when called with preload == false just tells us whether the + // key is in the scp. With preload == true, which happens when the ,p + // (permissive) option is given in the rspecifier, it will also check that we + // can preload the object from disk (loading from the rxfilename in the scp), + // and only return true if we can. This function is called both from HasKey + // and from Value(). + virtual bool HasKeyInternal(const std::string &key, bool preload) { + switch (state_) { + case kUninitialized: case kNotReadScript: + KALDI_ERR << "HasKey called on RandomAccessTableReader object that is not open."; + case kHaveObject: case kGaveObject: + if (key == current_key_) + return true; + break; + default: break; + } + KALDI_ASSERT(IsToken(key)); + size_t key_pos = 0; // set to zero to suppress warning + bool ans = LookupKey(key, &key_pos); + if (!ans) return false; + else { + // First do a check regarding the "once" option. + if (opts_.once && script_[key_pos].second == "") { // A "tombstone"; user is asking about + // already-read key. + KALDI_ERR << "HasKey called on key whose value was already read, and " + " you specified the \"once\" option (o, ): try removing o, or adding no, :" + " rspecifier is " << rspecifier_; + } + if (!preload) + return true; // we have the key. + else { // preload specified, so we have to pre-load the object before returning true. + if (!input_.Open(script_[key_pos].second)) { + KALDI_WARN << "Error opening stream " + << PrintableRxfilename(script_[key_pos].second); + return false; + } else { + // Make sure holder empty. + if (state_ == kHaveObject || state_ == kGaveObject) + holder_.Clear(); + if (holder_.Read(input_.Stream())) { + state_ = kHaveObject; + current_key_ = key; + return true; + } else { + KALDI_WARN << "Error reading object from " + "stream " << PrintableRxfilename(script_[key_pos].second); + state_ = kNotHaveObject; + return false; + } + } + } + } + } + void MakeTombstone(const std::string &key) { + size_t offset; + if (!LookupKey(key, &offset)) + KALDI_ERR << "RandomAccessTableReader object in inconsistent state."; + else + script_[offset].second = ""; + } + bool LookupKey(const std::string &key, size_t *script_offset) { + // First, an optimization: if we're going consecutively, this will + // make the lookup very fast. Since we may call HasKey and then + // Value(), which both may look up the key, we test if either the + // current or next position are correct. + if (last_found_ < script_.size() && script_[last_found_].first == key) { + *script_offset = last_found_; + return true; + } + last_found_++; + if (last_found_ < script_.size() && script_[last_found_].first == key) { + *script_offset = last_found_; + return true; + } + std::pair<std::string, std::string> pr(key, ""); // Important that "" + // compares less than or equal to any string, so lower_bound points to the + // element that has the same key. + typedef typename std::vector<std::pair<std::string, std::string> >::const_iterator + IterType; + IterType iter = std::lower_bound(script_.begin(), script_.end(), pr); + if (iter != script_.end() && iter->first == key) { + last_found_ = *script_offset = iter - script_.begin(); + return true; + } else { + return false; + } + } + + + Input input_; // Use the same input_ object for reading each file, in case + // the scp specifies offsets in an archive (so we can keep the same file open). + RspecifierOptions opts_; + std::string rspecifier_; // rspecifier used to open it; used in debug messages + std::string script_rxfilename_; // filename of script. + + std::string current_key_; // Key of object in holder_ + Holder holder_; + + // the script_ variable contains pairs of (key, filename), sorted using + // std::sort. This can be used with binary_search to look up filenames for + // writing. If this becomes inefficient we can use std::unordered_map (but I + // suspect this wouldn't be significantly faster & would use more memory). + // If memory becomes a problem here, the user should probably be passing + // only the relevant part of the scp file rather than expecting us to get too + // clever in the code. + std::vector<std::pair<std::string, std::string> > script_; + size_t last_found_; // This is for an optimization used in FindFilename. + + enum { // [Do we have [Does holder_ + // script_ set up?] contain object?] + kUninitialized, // no no + kNotReadScript, // no no + kNotHaveObject, // yes no + kHaveObject, // yes yes + kGaveObject, // yes yes + // [kGaveObject is as kHaveObject but we note that the + // user has already read it; this is for checking that + // if "once" is specified, the user actually only reads + // it once. + } state_; + +}; + + + + +// This is the base-class (with some implemented functions) for the +// implementations of RandomAccessTableReader when it's an archive. This +// base-class handles opening the files, storing the state of the reading +// process, and loading objects. This is the only case in which we have +// an intermediate class in the hierarchy between the virtual ImplBase +// class and the actual Impl classes. +// The child classes vary in the assumptions regarding sorting, etc. + +template<class Holder> class RandomAccessTableReaderArchiveImplBase: + public RandomAccessTableReaderImplBase<Holder> { + public: + typedef typename Holder::T T; + + RandomAccessTableReaderArchiveImplBase(): holder_(NULL), state_(kUninitialized) { } + + virtual bool Open(const std::string &rspecifier) { + if (state_ != kUninitialized) { + if (! this->Close()) // call Close() yourself to suppress this exception. + KALDI_ERR << "TableReader::Open, error closing previous input."; + } + rspecifier_ = rspecifier; + RspecifierType rs = ClassifyRspecifier(rspecifier, &archive_rxfilename_, + &opts_); + KALDI_ASSERT(rs == kArchiveRspecifier); + + // NULL means don't expect binary-mode header + bool ans; + if (Holder::IsReadInBinary()) + ans = input_.Open(archive_rxfilename_, NULL); + else + ans = input_.OpenTextMode(archive_rxfilename_); + if (!ans) { // header. + KALDI_WARN << "TableReader: failed to open stream " + << PrintableRxfilename(archive_rxfilename_); + state_ = kUninitialized; // Failure on Open + return false; // User should print the error message. + } else { + state_ = kNoObject; + } + return true; + } + + // ReadNextObject() requires that the state be kNoObject, + // and it will try read the next object. If it succeeds, + // it sets the state to kHaveObject, and + // cur_key_ and holder_ have the key and value. If it fails, + // it sets the state to kError or kEof. + void ReadNextObject() { + if (state_ != kNoObject) + KALDI_ERR << "TableReader: ReadNextObject() called from wrong state."; // Code error + // somewhere in this class or a child class. + std::istream &is = input_.Stream(); + is.clear(); // Clear any fail bits that may have been set... just in case + // this happened in the Read function. + is >> cur_key_; // This eats up any leading whitespace and gets the string. + if (is.eof()) { + state_ = kEof; + return; + } + if (is.fail()) { // This shouldn't really happen, barring file-system errors. + KALDI_WARN << "Error reading archive: rspecifier is " << rspecifier_; + state_ = kError; + return; + } + int c; + if ((c = is.peek()) != ' ' && c != '\t' && c != '\n') { // We expect a space ' ' after the key. + // We also allow tab, just so we can read archives generated by scripts that may + // not be fully aware of how this format works. + KALDI_WARN << "Invalid archive file format: expected space after key " <<cur_key_ + <<", got character " + << CharToString(static_cast<char>(is.peek())) << ", reading archive " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + if (c != '\n') is.get(); // Consume the space or tab. + holder_ = new Holder; + if (holder_->Read(is)) { + state_ = kHaveObject; + return; + } else { + KALDI_WARN << "Object read failed, reading archive " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + delete holder_; + holder_ = NULL; + return; + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kEof: case kError: case kHaveObject: case kNoObject: return true; + case kUninitialized: return false; + default: KALDI_ERR << "IsOpen() called on invalid object."; + return false; + } + } + + // Called by the child-class virutal Close() functions; does the + // shared parts of the cleanup. + bool CloseInternal() { + if (! this->IsOpen()) + KALDI_ERR << "Close() called on TableReader twice or otherwise wrongly."; + if (input_.IsOpen()) + input_.Close(); + if (state_ == kHaveObject) { + KALDI_ASSERT(holder_ != NULL); + delete holder_; + holder_ = NULL; + } else KALDI_ASSERT(holder_ == NULL); + bool ans = (state_ != kError); + state_ = kUninitialized; + if (!ans && opts_.permissive) { + KALDI_WARN << "Error state detected closing reader. " + << "Ignoring it because you specified permissive mode."; + return true; + } + return ans; + } + + ~RandomAccessTableReaderArchiveImplBase() { + // The child class has the responsibility to call CloseInternal(). + KALDI_ASSERT(state_ == kUninitialized && holder_ == NULL); + } + private: + Input input_; // Input object for the archive + protected: + // The variables below are accessed by child classes. + + std::string cur_key_; // current key (if state == kHaveObject). + Holder *holder_; // Holds the object we just read (if state == kHaveObject) + + std::string rspecifier_; + std::string archive_rxfilename_; + RspecifierOptions opts_; + + enum { // [The state of the reading process] [does holder_ [is input_ + // have object] open] + kUninitialized, // Uninitialized or closed no no + kNoObject, // Do not have object in holder_ no yes + kHaveObject, // Have object in holder_ yes yes + kEof, // End of file no yes + kError, // Some kind of error-state in the reading. no yes + } state_; + +}; + + +// RandomAccessTableReaderDSortedArchiveImpl (DSorted for "doubly sorted") is the +// implementation for random-access reading of archives when both the archive, +// and the calling code, are in sorted order (i.e. we ask for the keys in sorted +// order). This is when the s and cs options are both given. It only ever has +// to keep one object in memory. It inherits from +// RandomAccessTableReaderArchiveImplBase which implements the common parts of +// RandomAccessTableReader that are used when it's an archive we're reading from. + +template<class Holder> class RandomAccessTableReaderDSortedArchiveImpl: + public RandomAccessTableReaderArchiveImplBase<Holder> { + using RandomAccessTableReaderArchiveImplBase<Holder>::kUninitialized; + using RandomAccessTableReaderArchiveImplBase<Holder>::kHaveObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kNoObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kEof; + using RandomAccessTableReaderArchiveImplBase<Holder>::kError; + using RandomAccessTableReaderArchiveImplBase<Holder>::state_; + using RandomAccessTableReaderArchiveImplBase<Holder>::opts_; + using RandomAccessTableReaderArchiveImplBase<Holder>::cur_key_; + using RandomAccessTableReaderArchiveImplBase<Holder>::holder_; + using RandomAccessTableReaderArchiveImplBase<Holder>::rspecifier_; + using RandomAccessTableReaderArchiveImplBase<Holder>::archive_rxfilename_; + using RandomAccessTableReaderArchiveImplBase<Holder>::ReadNextObject; + public: + typedef typename Holder::T T; + + RandomAccessTableReaderDSortedArchiveImpl() { } + + virtual bool Close() { + // We don't have anything additional to clean up, so just + // call generic base-class one. + return this->CloseInternal(); + } + + virtual bool HasKey(const std::string &key) { + return FindKeyInternal(key); + } + virtual const T & Value(const std::string &key) { + if (FindKeyInternal(key)) { + KALDI_ASSERT(this->state_ == kHaveObject && key == this->cur_key_ + && holder_ != NULL); + return this->holder_->Value(); + } else { + KALDI_ERR << "Value() called but no such key " << key + << " in archive " << PrintableRxfilename(archive_rxfilename_); + return *(const T*)NULL; // keep compiler happy. + } + } + + virtual ~RandomAccessTableReaderDSortedArchiveImpl() { + if (this->IsOpen()) + if (!Close()) // more specific warning will already have been printed. + // we are in some kind of error state & user did not find out by + // calling Close(). + KALDI_ERR << "Error closing RandomAccessTableReader: rspecifier is " + << rspecifier_; + } + private: + // FindKeyInternal tries to find the key by calling "ReadNextObject()" + // as many times as necessary till we get to it. It is called from + // both FindKey and Value(). + bool FindKeyInternal(const std::string &key) { + // First check that the user is calling us right: should be + // in sorted order. If not, error. + if (!last_requested_key_.empty()) { + if (key.compare(last_requested_key_) < 0) { // key < last_requested_key_ + KALDI_ERR << "You provided the \"cs\" option " + << "but are not calling with keys in sorted order: " + << key << " < " << last_requested_key_ << ": rspecifier is " + << rspecifier_; + } + } + // last_requested_key_ is just for debugging of order of calling. + last_requested_key_ = key; + + if (state_ == kNoObject) + ReadNextObject(); // This can only happen + // once, the first time someone calls HasKey() or Value(). We don't + // do it in the initializer to stop the program hanging too soon, + // if reading from a pipe. + + if (state_ == kEof || state_ == kError) return false; + + if (state_ == kUninitialized) + KALDI_ERR << "Trying to access a RandomAccessTableReader object that is not open."; + + std::string last_key_; // To check that + // the archive we're reading is in sorted order. + while (1) { + KALDI_ASSERT(state_ == kHaveObject); + int compare = key.compare(cur_key_); + if (compare == 0) { // key == key_ + return true; // we got it.. + } else if (compare < 0) { // key < cur_key_, so we already read past the + // place where we want to be. This implies that we will never find it + // [due to the sorting etc., this means it just isn't in the archive]. + return false; + } else { // compare > 0, key > cur_key_. We need to read further ahead. + last_key_ = cur_key_; + // read next object.. we have to set state to kNoObject first. + KALDI_ASSERT(holder_ != NULL); + delete holder_; + holder_ = NULL; + state_ = kNoObject; + ReadNextObject(); + if (state_ != kHaveObject) + return false; // eof or read error. + if (cur_key_.compare(last_key_) <= 0) { + KALDI_ERR << "You provided the \"s\" option " + << " (sorted order), but keys are out of order or duplicated: " + << last_key_ << " is followed by " << cur_key_ + << ": rspecifier is " << rspecifier_; + } + } + } + } + + /// Last string provided to HasKey() or Value(); + std::string last_requested_key_; + + +}; + +// RandomAccessTableReaderSortedArchiveImpl is for random-access reading of +// archives when the user specified the sorted (s) option but not the +// called-sorted (cs) options. +template<class Holder> class RandomAccessTableReaderSortedArchiveImpl: + public RandomAccessTableReaderArchiveImplBase<Holder> { + using RandomAccessTableReaderArchiveImplBase<Holder>::kUninitialized; + using RandomAccessTableReaderArchiveImplBase<Holder>::kHaveObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kNoObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kEof; + using RandomAccessTableReaderArchiveImplBase<Holder>::kError; + using RandomAccessTableReaderArchiveImplBase<Holder>::state_; + using RandomAccessTableReaderArchiveImplBase<Holder>::opts_; + using RandomAccessTableReaderArchiveImplBase<Holder>::cur_key_; + using RandomAccessTableReaderArchiveImplBase<Holder>::holder_; + using RandomAccessTableReaderArchiveImplBase<Holder>::rspecifier_; + using RandomAccessTableReaderArchiveImplBase<Holder>::archive_rxfilename_; + using RandomAccessTableReaderArchiveImplBase<Holder>::ReadNextObject; + + public: + typedef typename Holder::T T; + + RandomAccessTableReaderSortedArchiveImpl(): + last_found_index_(static_cast<size_t>(-1)), + pending_delete_(static_cast<size_t>(-1)) { } + + virtual bool Close() { + for (size_t i = 0; i < seen_pairs_.size(); i++) + if (seen_pairs_[i].second) + delete seen_pairs_[i].second; + seen_pairs_.clear(); + + pending_delete_ = static_cast<size_t>(-1); + last_found_index_ = static_cast<size_t>(-1); + + return this->CloseInternal(); + } + virtual bool HasKey(const std::string &key) { + HandlePendingDelete(); + size_t index; + bool ans = FindKeyInternal(key, &index); + if (ans && opts_.once && seen_pairs_[index].second == NULL) { + // Just do a check RE the once option. "&&opts_.once" is for + // efficiency since this can only happen in that case. + KALDI_ERR << "Error: HasKey called after Value() already called for " + << " that key, and once (o) option specified: rspecifier is " + << rspecifier_; + } + return ans; + } + virtual const T & Value(const std::string &key) { + HandlePendingDelete(); + size_t index; + if (FindKeyInternal(key, &index)) { + if (seen_pairs_[index].second == NULL) { // can happen if opts.once_ + KALDI_ERR << "Error: Value() called more than once for key " + << key << " and once (o) option specified: rspecifier is " + << rspecifier_; + } + if (opts_.once) + pending_delete_ = index; // mark this index to be deleted on next call. + return seen_pairs_[index].second->Value(); + } else { + KALDI_ERR << "Value() called but no such key " << key + << " in archive " << PrintableRxfilename(archive_rxfilename_); + return *(const T*)NULL; // keep compiler happy. + } + } + virtual ~RandomAccessTableReaderSortedArchiveImpl() { + if (this->IsOpen()) + if (!Close()) // more specific warning will already have been printed. + // we are in some kind of error state & user did not find out by + // calling Close(). + KALDI_ERR << "Error closing RandomAccessTableReader: rspecifier is " + << rspecifier_; + } + private: + void HandlePendingDelete() { + const size_t npos = static_cast<size_t>(-1); + if (pending_delete_ != npos) { + KALDI_ASSERT(pending_delete_ < seen_pairs_.size()); + KALDI_ASSERT(seen_pairs_[pending_delete_].second != NULL); + delete seen_pairs_[pending_delete_].second; + seen_pairs_[pending_delete_].second = NULL; + pending_delete_ = npos; + } + } + + // FindKeyInternal tries to find the key in the array "seen_pairs_". + // If it is not already there, it reads ahead as far as necessary + // to determine whether we have the key or not. On success it returns + // true and puts the index into the array seen_pairs_, into "index"; + // on failure it returns false. + // It will leave the state as either kNoObject, kEof or kError. + // FindKeyInternal does not do any checking about whether you are asking + // about a key that has been already given (with the "once" option). + // That is the user's responsibility. + + bool FindKeyInternal(const std::string &key, size_t *index) { + // First, an optimization in case the previous call was for the + // same key, and we found it. + if (last_found_index_ < seen_pairs_.size() + && seen_pairs_[last_found_index_].first == key) { + *index = last_found_index_; + return true; + } + + if (state_ == kUninitialized) + KALDI_ERR << "Trying to access a RandomAccessTableReader object that is not open."; + + // Step one is to see whether we have to read ahead for the object.. + // Note, the possible states right now are kNoObject, kEof or kError. + // We are never in the state kHaveObject except just after calling + // ReadNextObject(). + bool looped = false; + while (state_ == kNoObject && + (seen_pairs_.empty() || key.compare(seen_pairs_.back().first) > 0)) { + looped = true; + // Read this as: + // while ( the stream is potentially good for reading && + // ([got no keys] || key > most_recent_key) ) { ... + // Try to read a new object. + // Note that the keys in seen_pairs_ are ordered from least to greatest. + ReadNextObject(); + if (state_ == kHaveObject) { // Successfully read object. + if (!seen_pairs_.empty() && // This is just a check. + cur_key_.compare(seen_pairs_.back().first) <= 0) { + // read the expression above as: !( cur_key_ > previous_key). + // it means we are not in sorted order [the user specified that we + // are, or we would not be using this implementation]. + KALDI_ERR << "You provided the sorted (s) option but keys in archive " + << PrintableRxfilename(archive_rxfilename_) << " are not " + << "in sorted order: " << seen_pairs_.back().first + << " is followed by " << cur_key_; + } + KALDI_ASSERT(holder_ != NULL); + seen_pairs_.push_back(std::make_pair(cur_key_, holder_)); + holder_ = NULL; + state_ = kNoObject; + } + } + if (looped) { // We only need to check the last element of the seen_pairs_ array, + // since we would not have read more after getting "key". + if (!seen_pairs_.empty() && seen_pairs_.back().first == key) { + last_found_index_ = *index = seen_pairs_.size() - 1; + return true; + } else return false; + } + // Now we have do an actual binary search in the seen_pairs_ array. + std::pair<std::string, Holder*> pr(key, static_cast<Holder*>(NULL)); + typename std::vector<std::pair<std::string, Holder*> >::iterator + iter = std::lower_bound(seen_pairs_.begin(), seen_pairs_.end(), + pr, PairCompare()); + if (iter != seen_pairs_.end() && + key == iter->first) { + last_found_index_ = *index = (iter - seen_pairs_.begin()); + return true; + } else return false; + } + + // These are the pairs of (key, object) we have read. We keep all the keys we + // have read but the actual objects (if they are stored with pointers inside + // the Holder object) may be deallocated if once == true, and the Holder + // pointer set to NULL. + std::vector<std::pair<std::string, Holder*> > seen_pairs_; + size_t last_found_index_; // An optimization s.t. if FindKeyInternal called twice with + // same key (as it often will), it doesn't have to do the key search twice. + size_t pending_delete_; // If opts_.once == true, this is the index of + // element of seen_pairs_ that is pending deletion. + struct PairCompare { + // PairCompare is the Less-than operator for the pairs of(key, Holder). + // compares the keys. + inline bool operator() (const std::pair<std::string, Holder*> &pr1, + const std::pair<std::string, Holder*> &pr2) { + return (pr1.first.compare(pr2.first) < 0); + } + }; +}; + + + +// RandomAccessTableReaderUnsortedArchiveImpl is for random-access reading of +// archives when the user does not specify the sorted (s) option (in this case +// the called-sorted, or "cs" option, is ignored). This is the least efficient +// of the random access archive readers, in general, but it can be as efficient +// as the others, in speed, memory and latency, if the "once" option is specified +// and it happens that the keys of the archive are the same as the keys the code +// is called with (to HasKey() and Value()), and in the same order. However, if +// you ask it for a key that's not present it will have to read the archive till +// the end and store it all in memory. + +template<class Holder> class RandomAccessTableReaderUnsortedArchiveImpl: + public RandomAccessTableReaderArchiveImplBase<Holder> { + using RandomAccessTableReaderArchiveImplBase<Holder>::kUninitialized; + using RandomAccessTableReaderArchiveImplBase<Holder>::kHaveObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kNoObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kEof; + using RandomAccessTableReaderArchiveImplBase<Holder>::kError; + using RandomAccessTableReaderArchiveImplBase<Holder>::state_; + using RandomAccessTableReaderArchiveImplBase<Holder>::opts_; + using RandomAccessTableReaderArchiveImplBase<Holder>::cur_key_; + using RandomAccessTableReaderArchiveImplBase<Holder>::holder_; + using RandomAccessTableReaderArchiveImplBase<Holder>::rspecifier_; + using RandomAccessTableReaderArchiveImplBase<Holder>::archive_rxfilename_; + using RandomAccessTableReaderArchiveImplBase<Holder>::ReadNextObject; + + typedef typename Holder::T T; + + public: + RandomAccessTableReaderUnsortedArchiveImpl(): to_delete_iter_(map_.end()), + to_delete_iter_valid_(false) + { + map_.max_load_factor(0.5); // make it quite empty -> quite efficient. + // default seems to be 1. + } + + virtual bool Close() { + for (typename MapType::iterator iter = map_.begin(); + iter != map_.end(); + ++iter) { + if (iter->second) + delete iter->second; + } + map_.clear(); + first_deleted_string_ = ""; + to_delete_iter_valid_ = false; + return this->CloseInternal(); + } + + virtual bool HasKey(const std::string &key) { + HandlePendingDelete(); + return FindKeyInternal(key, NULL); + } + virtual const T & Value(const std::string &key) { + HandlePendingDelete(); + const T *ans_ptr = NULL; + if (FindKeyInternal(key, &ans_ptr)) + return *ans_ptr; + else + KALDI_ERR << "Value() called but no such key " << key + << " in archive " << PrintableRxfilename(archive_rxfilename_); + return *(const T*)NULL; // keep compiler happy. + } + virtual ~RandomAccessTableReaderUnsortedArchiveImpl() { + if (this->IsOpen()) + if (!Close()) // more specific warning will already have been printed. + // we are in some kind of error state & user did not find out by + // calling Close(). + KALDI_ERR << "Error closing RandomAccessTableReader: rspecifier is " + << rspecifier_; + } + private: + void HandlePendingDelete() { + if (to_delete_iter_valid_) { + to_delete_iter_valid_ = false; + delete to_delete_iter_->second; // Delete Holder object. + if (first_deleted_string_.length() == 0) + first_deleted_string_ = to_delete_iter_->first; + map_.erase(to_delete_iter_); // delete that element. + } + } + + // FindKeyInternal tries to find the key in the map "map_" + // If it is not already there, it reads ahead either until it finds the + // key, or until end of file. If called with value_ptr == NULL, + // it assumes it's called from HasKey() and just returns true or false + // and doesn't otherwise have side effects. If called with value_ptr != + // NULL, it assumes it's called from Value(). Thus, it will crash + // if it cannot find the key. If it can find it it puts its address in + // *value_ptr, and if opts_once == true it will mark that element of the + // map to be deleted. + + bool FindKeyInternal(const std::string &key, const T **value_ptr = NULL) { + typename MapType::iterator iter = map_.find(key); + if (iter != map_.end()) { // Found in the map... + if (value_ptr == NULL) { // called from HasKey + return true; // this is all we have to do. + } else { + *value_ptr = &(iter->second->Value()); + if (opts_.once) { // value won't be needed again, so mark + // for deletion. + to_delete_iter_ = iter; // pending delete. + KALDI_ASSERT(!to_delete_iter_valid_); + to_delete_iter_valid_ = true; + } + return true; + } + } + while (state_ == kNoObject) { + ReadNextObject(); + if (state_ == kHaveObject) { // Successfully read object. + state_ = kNoObject; // we are about to transfer ownership + // of the object in holder_ to map_. + // Insert it into map_. + std::pair<typename MapType::iterator, bool> pr = + map_.insert(typename MapType::value_type(cur_key_, holder_)); + + if (!pr.second) { // Was not inserted-- previous element w/ same key + delete holder_; // map was not changed, no ownership transferred. + holder_ = NULL; + KALDI_ERR << "Error in RandomAccessTableReader: duplicate key " + << cur_key_ << " in archive " << archive_rxfilename_; + } + holder_ = NULL; // ownership transferred to map_. + if (cur_key_ == key) { // the one we wanted.. + if (value_ptr == NULL) { // called from HasKey + return true; + } else { // called from Value() + *value_ptr = &(pr.first->second->Value()); // this gives us the + // Value() from the Holder in the map. + if (opts_.once) { // mark for deletion, as won't be needed again. + to_delete_iter_ = pr.first; + KALDI_ASSERT(!to_delete_iter_valid_); + to_delete_iter_valid_ = true; + } + return true; + } + } + } + } + if (opts_.once && key == first_deleted_string_) { + KALDI_ERR << "You specified the once (o) option but " + << "you are calling using key " << key + << " more than once: rspecifier is " << rspecifier_; + } + return false; // We read the entire archive (or got to error state) and didn't + // find it. + } + + typedef unordered_map<std::string, Holder*, StringHasher> MapType; + MapType map_; + + typename MapType::iterator to_delete_iter_; + bool to_delete_iter_valid_; + + std::string first_deleted_string_; // keep the first string we deleted + // from map_ (if opts_.once == true). It's for an inexact spot-check that the + // "once" option isn't being used incorrectly. + +}; + + + + + +template<class Holder> +RandomAccessTableReader<Holder>::RandomAccessTableReader(const std::string &rspecifier): + impl_(NULL) { + if (rspecifier != "" && !Open(rspecifier)) + KALDI_ERR << "Error opening RandomAccessTableReader object " + " (rspecifier is: " << rspecifier << ")"; +} + +template<class Holder> +bool RandomAccessTableReader<Holder>::Open(const std::string &rspecifier) { + if (IsOpen()) + KALDI_ERR << "Already open."; + RspecifierOptions opts; + RspecifierType rs = ClassifyRspecifier(rspecifier, NULL, &opts); + switch (rs) { + case kScriptRspecifier: + impl_ = new RandomAccessTableReaderScriptImpl<Holder>(); + break; + case kArchiveRspecifier: + if (opts.sorted) { + if (opts.called_sorted) // "doubly" sorted case. + impl_ = new RandomAccessTableReaderDSortedArchiveImpl<Holder>(); + else + impl_ = new RandomAccessTableReaderSortedArchiveImpl<Holder>(); + } else impl_ = new RandomAccessTableReaderUnsortedArchiveImpl<Holder>(); + break; + case kNoRspecifier: default: + KALDI_WARN << "Invalid rspecifier: " + << rspecifier; + return false; + } + if (impl_->Open(rspecifier)) + return true; + else { + // Warning will already have been printed. + delete impl_; + impl_ = NULL; + return false; + } +} + +template<class Holder> +bool RandomAccessTableReader<Holder>::HasKey(const std::string &key) { + CheckImpl(); + if (!IsToken(key)) + KALDI_ERR << "Invalid key \"" << key << '"'; + return impl_->HasKey(key); +} + + +template<class Holder> +const typename RandomAccessTableReader<Holder>::T& +RandomAccessTableReader<Holder>::Value(const std::string &key) { + CheckImpl(); + return impl_->Value(key); +} + +template<class Holder> +bool RandomAccessTableReader<Holder>::Close() { + CheckImpl(); + bool ans =impl_->Close(); + delete impl_; + impl_ = NULL; + return ans; +} + +template<class Holder> +RandomAccessTableReader<Holder>::~RandomAccessTableReader() { + if (IsOpen() && !Close()) // call Close() yourself to stop this being thrown. + KALDI_ERR << "failure detected in destructor."; +} + +template<class Holder> +void SequentialTableReader<Holder>::CheckImpl() const { + if (!impl_) { + KALDI_ERR << "Trying to use empty SequentialTableReader (perhaps you " + << "passed the empty string as an argument to a program?)"; + } +} + +template<class Holder> +void RandomAccessTableReader<Holder>::CheckImpl() const { + if (!impl_) { + KALDI_ERR << "Trying to use empty RandomAccessTableReader (perhaps you " + << "passed the empty string as an argument to a program?)"; + } +} + +template<class Holder> +void TableWriter<Holder>::CheckImpl() const { + if (!impl_) { + KALDI_ERR << "Trying to use empty TableWriter (perhaps you " + << "passed the empty string as an argument to a program?)"; + } +} + +template<class Holder> +RandomAccessTableReaderMapped<Holder>::RandomAccessTableReaderMapped( + const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename): + reader_(table_rxfilename), token_reader_(table_rxfilename.empty() ? "" : + utt2spk_rxfilename), + utt2spk_rxfilename_(utt2spk_rxfilename) { } + +template<class Holder> +bool RandomAccessTableReaderMapped<Holder>::Open( + const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename) { + if (reader_.IsOpen()) reader_.Close(); + if (token_reader_.IsOpen()) token_reader_.Close(); + KALDI_ASSERT(!table_rxfilename.empty()); + if (!reader_.Open(table_rxfilename)) return false; // will have printed + // warning internally, probably. + if (!utt2spk_rxfilename.empty()) { + if (!token_reader_.Open(utt2spk_rxfilename)) { + reader_.Close(); + return false; + } + } + return true; +} + + +template<class Holder> +bool RandomAccessTableReaderMapped<Holder>::HasKey(const std::string &utt) { + // We don't check IsOpen, we let the call go through to the member variable + // (reader_), which will crash with a more informative error message than + // we can give here, as we don't any longer know the rxfilename. + if (token_reader_.IsOpen()) { // We need to map the key from utt to spk. + if (!token_reader_.HasKey(utt)) + KALDI_ERR << "Attempting to read key " << utt << ", which is not present " + << "in utt2spk map or similar map being read from " + << PrintableRxfilename(utt2spk_rxfilename_); + const std::string &spk = token_reader_.Value(utt); + return reader_.HasKey(spk); + } else { + return reader_.HasKey(utt); + } +} + +template<class Holder> +const typename Holder::T& RandomAccessTableReaderMapped<Holder>::Value( + const std::string &utt) { + if (token_reader_.IsOpen()) { // We need to map the key from utt to spk. + if (!token_reader_.HasKey(utt)) + KALDI_ERR << "Attempting to read key " << utt << ", which is not present " + << "in utt2spk map or similar map being read from " + << PrintableRxfilename(utt2spk_rxfilename_); + const std::string &spk = token_reader_.Value(utt); + return reader_.Value(spk); + } else { + return reader_.Value(utt); + } +} + + + +/// @} + +} // end namespace kaldi + + + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-table.h b/kaldi_io/src/kaldi/util/kaldi-table.h new file mode 100644 index 0000000..6f6cb98 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-table.h @@ -0,0 +1,459 @@ +// util/kaldi-table.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_KALDI_TABLE_H_ +#define KALDI_UTIL_KALDI_TABLE_H_ + +#include <string> +#include <vector> +#include <utility> + +#include "base/kaldi-common.h" +#include "util/kaldi-holder.h" + +namespace kaldi { + +// Forward declarations +template<class Holder> class RandomAccessTableReaderImplBase; +template<class Holder> class SequentialTableReaderImplBase; +template<class Holder> class TableWriterImplBase; + +/// \addtogroup table_group +/// @{ + +// This header defines the Table classes (RandomAccessTableReader, +// SequentialTableReader and TableWriter) and explains what the Holder classes, +// which the Table class requires as a template argument, are like. It also +// explains the "rspecifier" and "wspecifier" concepts (these are strings that +// explain how to read/write objects via archives or scp files. A table is +// conceptually a collection of objects of a particular type T indexed by keys +// of type std::string (these Keys additionally have an order within each table). +// The Table classes are templated on a type (call it Holder) such that Holder::T +// is a typedef equal to T. + +// see kaldi-holder.h for detail on the Holder classes. + +typedef std::vector<std::string> KeyList; + +// Documentation for "wspecifier" +// "wspecifier" describes how we write a set of objects indexed by keys. +// The basic, unadorned wspecifiers are as follows: +// +// ark:wxfilename +// scp:rxfilename +// ark,scp:filename,wxfilename +// ark,scp:filename,wxfilename +// +// +// We also allow the following modifiers: +// t means text mode. +// b means binary mode. +// f means flush the stream after writing each entry. +// (nf means don't flush, and isn't very useful as the default is to flush). +// p means permissive mode, when writing to an "scp" file only: will ignore +// missing scp entries, i.e. won't write anything for those files but will +// return success status). +// +// So the following are valid wspecifiers: +// ark,b,f:foo +// "ark,b,b:| gzip -c > foo" +// "ark,scp,t,nf:foo.ark,|gzip -c > foo.scp.gz" +// ark,b:- +// +// The meanings of rxfilename and wxfilename are as described in +// kaldi-stream.h (they are filenames but include pipes, stdin/stdout +// and so on; filename is a regular filename. +// + +// The ark:wxfilename type of wspecifier instructs the class to +// write directly to an archive. For small objects (e.g. lists of ints), +// the text archive format will generally be human readable with one line +// per entry in the archive. +// +// The type "scp:xfilename" refers to an scp file which should +// already exist on disk, and tells us where to write the data for +// each key (usually an actual file); each line of the scp file +// would be: +// key xfilename +// +// The type ark,scp:filename,wxfilename means +// we write both an archive and an scp file that specifies offsets into the +// archive, with lines like: +// key filename:12407 +// where the number is the byte offset into the file. +// In this case we restrict the archive-filename to be an actual filename, +// as we can't see a situtation where an extended filename would make sense +// for this (we can't fseek() in pipes). + +enum WspecifierType { + kNoWspecifier, + kArchiveWspecifier, + kScriptWspecifier, + kBothWspecifier +}; + +struct WspecifierOptions { + bool binary; + bool flush; + bool permissive; // will ignore absent scp entries. + WspecifierOptions(): binary(true), flush(false), permissive(false) { } +}; + +// ClassifyWspecifier returns the type of the wspecifier string, +// and (if pointers are non-NULL) outputs the extra information +// about the options, and the script and archive +// filenames. +WspecifierType ClassifyWspecifier(const std::string &wspecifier, + std::string *archive_wxfilename, + std::string *script_wxfilename, + WspecifierOptions *opts); + +// ReadScriptFile reads an .scp file in its entirety, and appends it +// (in order as it was in the scp file) in script_out_, which contains +// pairs of (key, xfilename). The .scp +// file format is: on each line, key xfilename +// where xfilename means rxfilename or wxfilename, and may contain internal spaces +// (we trim away any leading or trailing space). The key is space-free. +// ReadScriptFile returns true if the format was valid (empty files +// are valid). +// If 'print_warnings', it will print out warning messages that explain what kind +// of error there was. +bool ReadScriptFile(const std::string &rxfilename, + bool print_warnings, + std::vector<std::pair<std::string, std::string> > *script_out); + +// This version of ReadScriptFile works from an istream. +bool ReadScriptFile(std::istream &is, + bool print_warnings, + std::vector<std::pair<std::string, std::string> > *script_out); + +// Writes, for each entry in script, the first element, then ' ', then the second +// element then '\n'. Checks that the keys (first elements of pairs) are valid +// tokens (nonempty, no whitespace), and the values (second elements of pairs) +// are newline-free and contain no leading or trailing space. Returns true on +// success. +bool WriteScriptFile(const std::string &wxfilename, + const std::vector<std::pair<std::string, std::string> > &script); + +// This version writes to an ostream. +bool WriteScriptFile(std::ostream &os, + const std::vector<std::pair<std::string, std::string> > &script); + +// Documentation for "rspecifier" +// "rspecifier" describes how we read a set of objects indexed by keys. +// The possibilities are: +// +// ark:rxfilename +// scp:rxfilename +// +// We also allow various modifiers: +// o means the program will only ask for each key once, which enables +// the reader to discard already-asked-for values. +// s means the keys are sorted on input (means we don't have to read till +// eof if someone asked for a key that wasn't there). +// cs means that it is called in sorted order (we are generally asserting this +// based on knowledge of how the program works). +// p means "permissive", and causes it to skip over keys whose corresponding +// scp-file entries cannot be read. [and to ignore errors in archives and +// script files, and just consider the "good" entries]. +// We allow the negation of the options above, as in no, ns, np, +// but these aren't currently very useful (just equivalent to omitting the +// corresponding option). +// [any of the above options can be prefixed by n to negate them, e.g. no, ns, +// ncs, np; but these aren't currently useful as you could just omit the option]. +// +// b is ignored [for scripting convenience] +// t is ignored [for scripting convenience] +// +// +// So for instance the following would be a valid rspecifier: +// +// "o, s, p, ark:gunzip -c foo.gz|" + +struct RspecifierOptions { + // These options only make a difference for the RandomAccessTableReader class. + bool once; // we assert that the program will only ask for each key once. + bool sorted; // we assert that the keys are sorted. + bool called_sorted; // we assert that the (HasKey(), Value() functions will + // also be called in sorted order. [this implies "once" but not vice versa]. + bool permissive; // If "permissive", when reading from scp files it treats + // scp files that can't be read as if the corresponding key were not there. + // For archive files it will suppress errors getting thrown if the archive + + // is corrupted and can't be read to the end. + + RspecifierOptions(): once(false), sorted(false), + called_sorted(false), permissive(false) { } +}; + +enum RspecifierType { + kNoRspecifier, + kArchiveRspecifier, + kScriptRspecifier +}; + +RspecifierType ClassifyRspecifier(const std::string &rspecifier, std::string *rxfilename, + RspecifierOptions *opts); + +// Class Table<Holder> is useful when you want the entire set of +// objects in memory. NOT IMPLEMENTED YET. +// It is the least scalable way of accessing data in Tables. +// The *TableReader and TableWriter classes are more scalable. + + +/// Allows random access to a collection +/// of objects in an archive or script file; see \ref io_sec_tables. +template<class Holder> +class RandomAccessTableReader { + public: + typedef typename Holder::T T; + + RandomAccessTableReader(): impl_(NULL) { } + + // This constructor equivalent to default constructor + "open", but + // throws on error. + RandomAccessTableReader(const std::string &rspecifier); + + // Opens the table. + bool Open(const std::string &rspecifier); + + // Returns true if table is open. + bool IsOpen() const { return (impl_ != NULL); } + + // Close() will close the table [throws if it was not open], + // and returns true on success (false if we were reading an + // archive and we discovered an error in the archive). + bool Close(); + + // Says if it has this key. + // If you are using the "permissive" (p) read option, + // it will return false for keys whose corresponding entry + // in the scp file cannot be read. + + bool HasKey(const std::string &key); + + // Value() may throw if you are reading an scp file, you + // do not have the "permissive" (p) option, and an entry + // in the scp file cannot be read. Typically you won't + // want to catch this error. + const T &Value(const std::string &key); + + ~RandomAccessTableReader(); + + // Allow copy-constructor only for non-opened readers (needed for inclusion in + // stl vector) + RandomAccessTableReader(const RandomAccessTableReader<Holder> &other): + impl_(NULL) { KALDI_ASSERT(other.impl_ == NULL); } + private: + // Disallow assignment. + RandomAccessTableReader &operator=(const RandomAccessTableReader<Holder>&); + void CheckImpl() const; // Checks that impl_ is non-NULL; prints an error + // message and dies (with KALDI_ERR) if NULL. + RandomAccessTableReaderImplBase<Holder> *impl_; +}; + + + +/// A templated class for reading objects sequentially from an archive or script +/// file; see \ref io_sec_tables. +template<class Holder> +class SequentialTableReader { + public: + typedef typename Holder::T T; + + SequentialTableReader(): impl_(NULL) { } + + // This constructor equivalent to default constructor + "open", but + // throws on error. + SequentialTableReader(const std::string &rspecifier); + + // Opens the table. Returns exit status; but does throw if previously + // open stream was in error state. Call Close to stop this [anyway, + // calling Open more than once is not recommended.] + bool Open(const std::string &rspecifier); + + // Returns true if we're done. It will also return true if there's some kind + // of error and we can't read any more; in this case, you can detect the + // error by calling Close and checking the return status; otherwise + // the destructor will throw. + inline bool Done(); + + // Only valid to call Key() if Done() returned false. + inline std::string Key(); + + // FreeCurrent() is provided as an optimization to save memory, for large + // objects. It instructs the class to deallocate the current value. The + // reference Value() will/ be invalidated by this. + + void FreeCurrent(); + + // Return reference to the current value. + // The reference is valid till next call to this object. + // If will throw if you are reading an scp file, did not + // specify the "permissive" (p) option and the file cannot + // be read. [The permissive option makes it behave as if that + // key does not even exist, if the corresponding file cannot be + // read.] You probably wouldn't want to catch this exception; + // the user can just specify the p option in the rspecifier. + const T &Value(); + + // Next goes to the next key. It will not throw; any error will + // result in Done() returning true, and then the destructor will + // throw unless you call Close(). + void Next(); + + // Returns true if table is open for reading (does not imply + // stream is in good state). + bool IsOpen() const; + + // Close() will return false (failure) if Done() became true + // because of an error/ condition rather than because we are + // really done [e.g. because of an error or early termination + // in the archive]. + // If there is an error and you don't call Close(), the destructor + // will fail. + // Close() + bool Close(); + + // The destructor may throw. This is the desired behaviour, as it's the way we + // signal the error to the user (to detect it, call Close(). The issue is that + // otherwise the user has no way to tell whether Done() returned true because + // we reached the end of the archive or script, or because there was an error + // that prevented further reading. + ~SequentialTableReader(); + + // Allow copy-constructor only for non-opened readers (needed for inclusion in + // stl vector) + SequentialTableReader(const SequentialTableReader<Holder> &other): + impl_(NULL) { KALDI_ASSERT(other.impl_ == NULL); } + private: + // Disallow assignment. + SequentialTableReader &operator = (const SequentialTableReader<Holder>&); + void CheckImpl() const; // Checks that impl_ is non-NULL; prints an error + // message and dies (with KALDI_ERR) if NULL. + SequentialTableReaderImplBase<Holder> *impl_; +}; + + +/// A templated class for writing objects to an +/// archive or script file; see \ref io_sec_tables. +template<class Holder> +class TableWriter { + public: + typedef typename Holder::T T; + + TableWriter(): impl_(NULL) { } + + // This constructor equivalent to default constructor + // + "open", but throws on error. See docs for + // wspecifier above. + TableWriter(const std::string &wspecifier); + + // Opens the table. See docs for wspecifier above. + // If it returns true, it is open. + bool Open(const std::string &wspecifier); + + // Returns true if open for writing. + bool IsOpen() const; + + // Write the object. Throws std::runtime_error on error (via the + // KALDI_ERR macro) + inline void Write(const std::string &key, const T &value) const; + + + // Flush will flush any archive; it does not return error status + // or throw, any errors will be reported on the next Write or Close. + // Useful if we may be writing to a command in a pipe and want + // to ensure good CPU utilization. + void Flush(); + + // Close() is not necessary to call, as the destructor + // closes it; it's mainly useful if you want to handle + // error states because the destructor will throw on + // error if you do not call Close(). + bool Close(); + + ~TableWriter(); + + // Allow copy-constructor only for non-opened writers (needed for inclusion in + // stl vector) + TableWriter(const TableWriter &other): impl_(NULL) { + KALDI_ASSERT(other.impl_ == NULL); + } + private: + TableWriter &operator = (const TableWriter&); // Disallow assignment. + void CheckImpl() const; // Checks that impl_ is non-NULL; prints an error + // message and dies (with KALDI_ERR) if NULL. + TableWriterImplBase<Holder> *impl_; +}; + + +/// This class is for when you are reading something in random access, but +/// it may actually be stored per-speaker (or something similar) but the +/// keys you're using are per utterance. So you also provide an "rxfilename" +/// for a file containing lines like +/// utt1 spk1 +/// utt2 spk1 +/// utt3 spk1 +/// and so on. Note: this is optional; if it is an empty string, we just won't +/// do the mapping. Also, "table_rxfilename" may be the empty string (as for +/// a regular table), in which case the table just won't be opened. +/// We provide only the most frequently used of the functions of RandomAccessTableReader. + +template<class Holder> +class RandomAccessTableReaderMapped { + public: + typedef typename Holder::T T; + /// Note: "utt2spk_rxfilename" will in the normal case be an rxfilename + /// for an utterance to speaker map, but this code is general; it accepts + /// a generic map. + RandomAccessTableReaderMapped(const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename); + + RandomAccessTableReaderMapped() {}; + + /// Note: when calling Open, utt2spk_rxfilename may be empty. + bool Open(const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename); + + bool HasKey(const std::string &key); + const T &Value(const std::string &key); + inline bool IsOpen() const { return reader_.IsOpen(); } + inline bool Close() { return reader_.Close(); } + + + + // The default copy-constructor will do what we want: it will crash + // for already-opened readers, by calling the member-variable copy-constructors. + private: + // Disallow assignment. + RandomAccessTableReaderMapped &operator=(const RandomAccessTableReaderMapped<Holder>&); + RandomAccessTableReader<Holder> reader_; + RandomAccessTableReader<TokenHolder> token_reader_; + std::string utt2spk_rxfilename_; // Used only in diagnostic messages. +}; + + +/// @} end "addtogroup table_group" +} // end namespace kaldi + +#include "kaldi-table-inl.h" + +#endif // KALDI_UTIL_KALDI_TABLE_H_ diff --git a/kaldi_io/src/kaldi/util/parse-options.h b/kaldi_io/src/kaldi/util/parse-options.h new file mode 100644 index 0000000..f563b54 --- /dev/null +++ b/kaldi_io/src/kaldi/util/parse-options.h @@ -0,0 +1,264 @@ +// util/parse-options.h + +// Copyright 2009-2011 Karel Vesely; Microsoft Corporation; +// Saarland University (Author: Arnab Ghoshal); +// Copyright 2012-2013 Frantisek Skala; Arnab Ghoshal + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_PARSE_OPTIONS_H_ +#define KALDI_UTIL_PARSE_OPTIONS_H_ + +#include <map> +#include <string> +#include <vector> + +#include "base/kaldi-common.h" +#include "itf/options-itf.h" + +namespace kaldi { + +/// The class ParseOptions is for parsing command-line options; see +/// \ref parse_options for more documentation. +class ParseOptions : public OptionsItf { + public: + explicit ParseOptions(const char *usage) : + print_args_(true), help_(false), usage_(usage), argc_(0), argv_(NULL), + prefix_(""), other_parser_(NULL) { +#ifndef _MSC_VER // This is just a convenient place to set the stderr to line + setlinebuf(stderr); // buffering mode, since it's called at program start. +#endif // This helps ensure different programs' output is not mixed up. + RegisterStandard("config", &config_, "Configuration file to read (this " + "option may be repeated)"); + RegisterStandard("print-args", &print_args_, + "Print the command line arguments (to stderr)"); + RegisterStandard("help", &help_, "Print out usage message"); + RegisterStandard("verbose", &g_kaldi_verbose_level, + "Verbose level (higher->more logging)"); + } + + /** + This is a constructor for the special case where some options are + registered with a prefix to avoid conflicts. The object thus created will + only be used temporarily to register an options class with the original + options parser (which is passed as the *other pointer) using the given + prefix. It should not be used for any other purpose, and the prefix must + not be the empty string. It seems to be the least bad way of implementing + options with prefixes at this point. + Example of usage is: + ParseOptions po; // original ParseOptions object + ParseOptions po_mfcc("mfcc", &po); // object with prefix. + MfccOptions mfcc_opts; + mfcc_opts.Register(&po_mfcc); + The options will now get registered as, e.g., --mfcc.frame-shift=10.0 + instead of just --frame-shift=10.0 + */ + ParseOptions(const std::string &prefix, OptionsItf *other); + + ~ParseOptions() {} + + // Methods from the interface + void Register(const std::string &name, + bool *ptr, const std::string &doc); + void Register(const std::string &name, + int32 *ptr, const std::string &doc); + void Register(const std::string &name, + uint32 *ptr, const std::string &doc); + void Register(const std::string &name, + float *ptr, const std::string &doc); + void Register(const std::string &name, + double *ptr, const std::string &doc); + void Register(const std::string &name, + std::string *ptr, const std::string &doc); + + /// If called after registering an option and before calling + /// Read(), disables that option from being used. Will crash + /// at runtime if that option had not been registered. + void DisableOption(const std::string &name); + + /// This one is used for registering standard parameters of all the programs + template<typename T> + void RegisterStandard(const std::string &name, + T *ptr, const std::string &doc); + + /** + Parses the command line options and fills the ParseOptions-registered + variables. This must be called after all the variables were registered!!! + + Initially the variables have implicit values, + then the config file values are set-up, + finally the command line vaues given. + Returns the first position in argv that was not used. + [typically not useful: use NumParams() and GetParam(). ] + */ + int Read(int argc, const char *const *argv); + + /// Prints the usage documentation [provided in the constructor]. + void PrintUsage(bool print_command_line = false); + /// Prints the actual configuration of all the registered variables + void PrintConfig(std::ostream &os); + + /// Reads the options values from a config file. Must be called after + /// registering all options. This is usually used internally after the + /// standard --config option is used, but it may also be called from a + /// program. + void ReadConfigFile(const std::string &filename); + + /// Number of positional parameters (c.f. argc-1). + int NumArgs() const; + + /// Returns one of the positional parameters; 1-based indexing for argc/argv + /// compatibility. Will crash if param is not >=1 and <=NumArgs(). + std::string GetArg(int param) const; + + std::string GetOptArg(int param) const { + return (param <= NumArgs() ? GetArg(param) : ""); + } + + /// The following function will return a possibly quoted and escaped + /// version of "str", according to the current shell. Currently + /// this is just hardwired to bash. It's useful for debug output. + static std::string Escape(const std::string &str); + + private: + /// Template to register various variable types, + /// used for program-specific parameters + template<typename T> + void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc); + + // Following functions do just the datatype-specific part of the job + /// Register boolean variable + void RegisterSpecific(const std::string &name, const std::string &idx, + bool *b, const std::string &doc, bool is_standard); + /// Register int32 variable + void RegisterSpecific(const std::string &name, const std::string &idx, + int32 *i, const std::string &doc, bool is_standard); + /// Register unsinged int32 variable + void RegisterSpecific(const std::string &name, const std::string &idx, + uint32 *u, + const std::string &doc, bool is_standard); + /// Register float variable + void RegisterSpecific(const std::string &name, const std::string &idx, + float *f, const std::string &doc, bool is_standard); + /// Register double variable [useful as we change BaseFloat type]. + void RegisterSpecific(const std::string &name, const std::string &idx, + double *f, const std::string &doc, bool is_standard); + /// Register string variable + void RegisterSpecific(const std::string &name, const std::string &idx, + std::string *s, const std::string &doc, + bool is_standard); + + /// Does the actual job for both kinds of parameters + /// Does the common part of the job for all datatypes, + /// then calls RegisterSpecific + template<typename T> + void RegisterCommon(const std::string &name, + T *ptr, const std::string &doc, bool is_standard); + + /// SplitLongArg parses an argument of the form --a=b, --a=, or --a, + /// and sets "has_equal_sign" to true if an equals-sign was parsed.. + /// this is needed in order to correctly allow --x for a boolean option + /// x, and --y= for a string option y, and to disallow --x= and --y. + void SplitLongArg(std::string in, std::string *key, std::string *value, + bool *has_equal_sign); + + void NormalizeArgName(std::string *str); + + /// Set option with name "key" to "value"; will crash if can't do it. + /// "has_equal_sign" is used to allow --x for a boolean option x, + /// and --y=, for a string option y. + bool SetOption(const std::string &key, const std::string &value, + bool has_equal_sign); + + bool ToBool(std::string str); + int32 ToInt(std::string str); + uint32 ToUInt(std::string str); + float ToFloat(std::string str); + double ToDouble(std::string str); + + // maps for option variables + std::map<std::string, bool*> bool_map_; + std::map<std::string, int32*> int_map_; + std::map<std::string, uint32*> uint_map_; + std::map<std::string, float*> float_map_; + std::map<std::string, double*> double_map_; + std::map<std::string, std::string*> string_map_; + + /** + Structure for options' documentation + */ + struct DocInfo { + DocInfo() {} + DocInfo(const std::string &name, const std::string &usemsg) + : name_(name), use_msg_(usemsg), is_standard_(false) {} + DocInfo(const std::string &name, const std::string &usemsg, + bool is_standard) + : name_(name), use_msg_(usemsg), is_standard_(is_standard) {} + + std::string name_; + std::string use_msg_; + bool is_standard_; + }; + typedef std::map<std::string, DocInfo> DocMapType; + DocMapType doc_map_; ///< map for the documentation + + bool print_args_; ///< variable for the implicit --print-args parameter + bool help_; ///< variable for the implicit --help parameter + std::string config_; ///< variable for the implicit --config parameter + std::vector<std::string> positional_args_; + const char *usage_; + int argc_; + const char *const *argv_; + + /// These members are not normally used. They are only used when the object + /// is constructed with a prefix + std::string prefix_; + OptionsItf *other_parser_; +}; + +/// This template is provided for convenience in reading config classes from +/// files; this is not the standard way to read configuration options, but may +/// occasionally be needed. This function assumes the config has a function +/// "void Register(OptionsItf *po)" which it can call to register the +/// ParseOptions object. +template<class C> void ReadConfigFromFile(const std::string config_filename, + C *c) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << config_filename << "'"; + ParseOptions po(usage_str.str().c_str()); + c->Register(&po); + po.ReadConfigFile(config_filename); +} + +/// This variant of the template ReadConfigFromFile is for if you need to read +/// two config classes from the same file. +template<class C1, class C2> void ReadConfigsFromFile(const std::string config_filename, + C1 *c1, C2 *c2) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << config_filename << "'"; + ParseOptions po(usage_str.str().c_str()); + c1->Register(&po); + c2->Register(&po); + po.ReadConfigFile(config_filename); +} + + + +} // namespace kaldi + +#endif // KALDI_UTIL_PARSE_OPTIONS_H_ diff --git a/kaldi_io/src/kaldi/util/simple-io-funcs.h b/kaldi_io/src/kaldi/util/simple-io-funcs.h new file mode 100644 index 0000000..56573e4 --- /dev/null +++ b/kaldi_io/src/kaldi/util/simple-io-funcs.h @@ -0,0 +1,56 @@ +// util/simple-io-funcs.h + +// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_SIMPLE_IO_FUNCS_H_ +#define KALDI_UTIL_SIMPLE_IO_FUNCS_H_ + +#include "kaldi-io.h" + +// This header contains some utilities for reading some common, simple text formats: +// integers in files, one per line, and integers in files, possibly multiple per line. +// these are not really fully native Kaldi formats; they are mostly for small files that +// might be generated by scripts, and can be read all at one time. +// for longer files of this type, we would probably use the Table code. + +namespace kaldi { + +/// WriteToList attempts to write this list of integers, one per line, +/// to the given file, in text format. +/// returns true if succeeded. +bool WriteIntegerVectorSimple(std::string wxfilename, const std::vector<int32> &v); + +/// ReadFromList attempts to read this list of integers, one per line, +/// from the given file, in text format. +/// returns true if succeeded. +bool ReadIntegerVectorSimple(std::string rxfilename, std::vector<int32> *v); + +// This is a file format like: +// 1 2 +// 3 +// +// 4 5 6 +// etc. +bool WriteIntegerVectorVectorSimple(std::string wxfilename, const std::vector<std::vector<int32> > &v); + +bool ReadIntegerVectorVectorSimple(std::string rxfilename, std::vector<std::vector<int32> > *v); + + +} // end namespace kaldi. + + +#endif diff --git a/kaldi_io/src/kaldi/util/simple-options.h b/kaldi_io/src/kaldi/util/simple-options.h new file mode 100644 index 0000000..58816af --- /dev/null +++ b/kaldi_io/src/kaldi/util/simple-options.h @@ -0,0 +1,112 @@ +// util/simple-options.hh + +// Copyright 2013 Tanel Alumae, Tallinn University of Technology + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_SIMPLE_OPTIONS_H_ +#define KALDI_UTIL_SIMPLE_OPTIONS_H_ + +#include <map> +#include <string> +#include <vector> + +#include "base/kaldi-common.h" +#include "itf/options-itf.h" + +namespace kaldi { + + +/// The class SimpleOptions is an implementation of OptionsItf that allows +/// setting and getting option values programmatically, i.e., via getter +/// and setter methods. It doesn't provide any command line parsing functionality. +/// The class ParseOptions should be used for command-line options. +class SimpleOptions : public OptionsItf { + public: + SimpleOptions() { + } + + virtual ~SimpleOptions() { + } + + // Methods from the interface + void Register(const std::string &name, bool *ptr, const std::string &doc); + void Register(const std::string &name, int32 *ptr, const std::string &doc); + void Register(const std::string &name, uint32 *ptr, const std::string &doc); + void Register(const std::string &name, float *ptr, const std::string &doc); + void Register(const std::string &name, double *ptr, const std::string &doc); + void Register(const std::string &name, std::string *ptr, + const std::string &doc); + + // set option with the specified key, return true if successful + bool SetOption(const std::string &key, const bool &value); + bool SetOption(const std::string &key, const int32 &value); + bool SetOption(const std::string &key, const uint32 &value); + bool SetOption(const std::string &key, const float &value); + bool SetOption(const std::string &key, const double &value); + bool SetOption(const std::string &key, const std::string &value); + bool SetOption(const std::string &key, const char* value); + + // get option with the specified key and put to 'value', + // return true if successful + bool GetOption(const std::string &key, bool *value); + bool GetOption(const std::string &key, int32 *value); + bool GetOption(const std::string &key, uint32 *value); + bool GetOption(const std::string &key, float *value); + bool GetOption(const std::string &key, double *value); + bool GetOption(const std::string &key, std::string *value); + + enum OptionType { + kBool, + kInt32, + kUint32, + kFloat, + kDouble, + kString + }; + + struct OptionInfo { + OptionInfo(const std::string &doc, OptionType type) : + doc(doc), type(type) { + } + std::string doc; + OptionType type; + }; + + std::vector<std::pair<std::string, OptionInfo> > GetOptionInfoList(); + + /* + * Puts the type of the option with name 'key' in the argument 'type'. + * Return true if such option is found, false otherwise. + */ + bool GetOptionType(const std::string &key, OptionType *type); + + private: + + std::vector<std::pair<std::string, OptionInfo> > option_info_list_; + + // maps for option variables + std::map<std::string, bool*> bool_map_; + std::map<std::string, int32*> int_map_; + std::map<std::string, uint32*> uint_map_; + std::map<std::string, float*> float_map_; + std::map<std::string, double*> double_map_; + std::map<std::string, std::string*> string_map_; +}; + +} // namespace kaldi + +#endif // KALDI_UTIL_SIMPLE_OPTIONS_H_ diff --git a/kaldi_io/src/kaldi/util/stl-utils.h b/kaldi_io/src/kaldi/util/stl-utils.h new file mode 100644 index 0000000..12526ff --- /dev/null +++ b/kaldi_io/src/kaldi/util/stl-utils.h @@ -0,0 +1,327 @@ +// util/stl-utils.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_STL_UTILS_H_ +#define KALDI_UTIL_STL_UTILS_H_ + +#include <algorithm> +#include <map> +#include <set> +#include <string> +#include <vector> +#include "base/kaldi-common.h" + +#ifdef _MSC_VER +#include <unordered_map> +#include <unordered_set> +using std::unordered_map; +using std::unordered_set; +#elif __cplusplus > 199711L || defined(__GXX_EXPERIMENTAL_CXX0X__) +#include <unordered_map> +#include <unordered_set> +using std::unordered_map; +using std::unordered_set; +#else +#include <tr1/unordered_map> +#include <tr1/unordered_set> +using std::tr1::unordered_map; +using std::tr1::unordered_set; +#endif + + +namespace kaldi { + +/// Sorts and uniq's (removes duplicates) from a vector. +template<typename T> +inline void SortAndUniq(std::vector<T> *vec) { + std::sort(vec->begin(), vec->end()); + vec->erase(std::unique(vec->begin(), vec->end()), vec->end()); +} + + +/// Returns true if the vector is sorted. +template<typename T> +inline bool IsSorted(const std::vector<T> &vec) { + typename std::vector<T>::const_iterator iter = vec.begin(), end = vec.end(); + if (iter == end) return true; + while (1) { + typename std::vector<T>::const_iterator next_iter = iter; + ++next_iter; + if (next_iter == end) return true; // end of loop and nothing out of order + if (*next_iter < *iter) return false; + iter = next_iter; + } +} + + +/// Returns true if the vector is sorted and contains each element +/// only once. +template<typename T> +inline bool IsSortedAndUniq(const std::vector<T> &vec) { + typename std::vector<T>::const_iterator iter = vec.begin(), end = vec.end(); + if (iter == end) return true; + while (1) { + typename std::vector<T>::const_iterator next_iter = iter; + ++next_iter; + if (next_iter == end) return true; // end of loop and nothing out of order + if (*next_iter <= *iter) return false; + iter = next_iter; + } +} + + +/// Removes duplicate elements from a sorted list. +template<typename T> +inline void Uniq(std::vector<T> *vec) { // must be already sorted. + KALDI_PARANOID_ASSERT(IsSorted(*vec)); + KALDI_ASSERT(vec); + vec->erase(std::unique(vec->begin(), vec->end()), vec->end()); +} + +/// Copies the elements of a set to a vector. +template<class T> +void CopySetToVector(const std::set<T> &s, std::vector<T> *v) { + // adds members of s to v, in sorted order from lowest to highest + // (because the set was in sorted order). + KALDI_ASSERT(v != NULL); + v->resize(s.size()); + typename std::set<T>::const_iterator siter = s.begin(), send = s.end(); + typename std::vector<T>::iterator viter = v->begin(); + for (; siter != send; ++siter, ++viter) { + *viter = *siter; + } +} + +template<class T> +void CopySetToVector(const unordered_set<T> &s, std::vector<T> *v) { + // adds members of s to v, in sorted order from lowest to highest + // (because the set was in sorted order). + KALDI_ASSERT(v != NULL); + v->resize(s.size()); + typename unordered_set<T>::const_iterator siter = s.begin(), send = s.end(); + typename std::vector<T>::iterator viter = v->begin(); + for (; siter != send; ++siter, ++viter) { + *viter = *siter; + } +} + + +/// Copies the (key, value) pairs in a map to a vector of pairs. +template<class A, class B> +void CopyMapToVector(const std::map<A, B> &m, + std::vector<std::pair<A, B> > *v) { + KALDI_ASSERT(v != NULL); + v->resize(m.size()); + typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end(); + typename std::vector<std::pair<A, B> >::iterator viter = v->begin(); + for (; miter != mend; ++miter, ++viter) { + *viter = std::make_pair(miter->first, miter->second); + // do it like this because of const casting. + } +} + +/// Copies the keys in a map to a vector. +template<class A, class B> +void CopyMapKeysToVector(const std::map<A, B> &m, std::vector<A> *v) { + KALDI_ASSERT(v != NULL); + v->resize(m.size()); + typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end(); + typename std::vector<A>::iterator viter = v->begin(); + for (; miter != mend; ++miter, ++viter) { + *viter = miter->first; + } +} + +/// Copies the values in a map to a vector. +template<class A, class B> +void CopyMapValuesToVector(const std::map<A, B> &m, std::vector<B> *v) { + KALDI_ASSERT(v != NULL); + v->resize(m.size()); + typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end(); + typename std::vector<B>::iterator viter = v->begin(); + for (; miter != mend; ++miter, ++viter) { + *viter = miter->second; + } +} + +/// Copies the keys in a map to a set. +template<class A, class B> +void CopyMapKeysToSet(const std::map<A, B> &m, std::set<A> *s) { + KALDI_ASSERT(s != NULL); + s->clear(); + typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end(); + for (; miter != mend; ++miter) { + s->insert(s->end(), miter->first); + } +} + +/// Copies the values in a map to a set. +template<class A, class B> +void CopyMapValuesToSet(const std::map<A, B> &m, std::set<B> *s) { + KALDI_ASSERT(s != NULL); + s->clear(); + typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end(); + for (; miter != mend; ++miter) + s->insert(s->end(), miter->second); +} + + +/// Copies the contents of a vector to a set. +template<class A> +void CopyVectorToSet(const std::vector<A> &v, std::set<A> *s) { + KALDI_ASSERT(s != NULL); + s->clear(); + typename std::vector<A>::const_iterator iter = v.begin(), end = v.end(); + for (; iter != end; ++iter) + s->insert(s->end(), *iter); + // s->end() is a hint in case v was sorted. will work regardless. +} + +/// Deletes any non-NULL pointers in the vector v, and sets +/// the corresponding entries of v to NULL +template<class A> +void DeletePointers(std::vector<A*> *v) { + KALDI_ASSERT(v != NULL); + typename std::vector<A*>::iterator iter = v->begin(), end = v->end(); + for (; iter != end; ++iter) { + if (*iter != NULL) { + delete *iter; + *iter = NULL; // set to NULL for extra safety. + } + } +} + +/// Returns true if the vector of pointers contains NULL pointers. +template<class A> +bool ContainsNullPointers(const std::vector<A*> &v) { + typename std::vector<A*>::const_iterator iter = v.begin(), end = v.end(); + for (; iter != end; ++iter) + if (*iter == static_cast<A*> (NULL)) return true; + return false; +} + +/// Copies the contents a vector of one type to a vector +/// of another type. +template<typename A, typename B> +void CopyVectorToVector(const std::vector<A> &vec_in, std::vector<B> *vec_out) { + KALDI_ASSERT(vec_out != NULL); + vec_out->resize(vec_in.size()); + for (size_t i = 0; i < vec_in.size(); i++) + (*vec_out)[i] = static_cast<B> (vec_in[i]); +} + +/// A hashing function-object for vectors. +template<typename Int> +struct VectorHasher { // hashing function for vector<Int>. + size_t operator()(const std::vector<Int> &x) const { + size_t ans = 0; + typename std::vector<Int>::const_iterator iter = x.begin(), end = x.end(); + for (; iter != end; ++iter) { + ans *= kPrime; + ans += *iter; + } + return ans; + } + VectorHasher() { // Check we're instantiated with an integer type. + KALDI_ASSERT_IS_INTEGER_TYPE(Int); + } + private: + static const int kPrime = 7853; +}; + +/// A hashing function-object for pairs of ints +template<typename Int> +struct PairHasher { // hashing function for pair<int> + size_t operator()(const std::pair<Int,Int> &x) const { + return x.first + x.second * kPrime; + } + PairHasher() { // Check we're instantiated with an integer type. + KALDI_ASSERT_IS_INTEGER_TYPE(Int); + } + private: + static const int kPrime = 7853; +}; + + +/// A hashing function object for strings. +struct StringHasher { // hashing function for std::string + size_t operator()(const std::string &str) const { + size_t ans = 0, len = str.length(); + const char *c = str.c_str(), *end = c + len; + for (; c != end; c++) { + ans *= kPrime; + ans += *c; + } + return ans; + } + private: + static const int kPrime = 7853; +}; + +/// Reverses the contents of a vector. +template<typename T> +inline void ReverseVector(std::vector<T> *vec) { + KALDI_ASSERT(vec != NULL); + size_t sz = vec->size(); + for (size_t i = 0; i < sz/2; i++) + std::swap( (*vec)[i], (*vec)[sz-1-i]); +} + + +/// Comparator object for pairs that compares only the first pair. +template<class A, class B> +struct CompareFirstMemberOfPair { + inline bool operator() (const std::pair<A, B> &p1, + const std::pair<A, B> &p2) { + return p1.first < p2.first; + } +}; + +/// For a vector of pair<I, F> where I is an integer and F a floating-point or +/// integer type, this function sorts a vector of type vector<pair<I, F> > on +/// the I value and then merges elements with equal I values, summing these over +/// the F component and then removing any F component with zero value. This +/// is for where the vector of pairs represents a map from the integer to float +/// component, with an "adding" type of semantics for combining the elements. +template<typename I, typename F> +inline void MergePairVectorSumming(std::vector<std::pair<I, F> > *vec) { + KALDI_ASSERT_IS_INTEGER_TYPE(I); + CompareFirstMemberOfPair<I, F> c; + std::sort(vec->begin(), vec->end(), c); // sort on 1st element. + typename std::vector<std::pair<I, F> >::iterator out = vec->begin(), + in = vec->begin(), end = vec->end(); + while (in < end) { + // We reach this point only at the first element of + // each stretch of identical .first elements. + *out = *in; + ++in; + while (in < end && in->first == out->first) { + out->second += in->second; // this is the merge operation. + ++in; + } + if (out->second != static_cast<F>(0)) // Don't keep zero elements. + out++; + } + vec->erase(out, end); +} + +} // namespace kaldi + +#endif // KALDI_UTIL_STL_UTILS_H_ + diff --git a/kaldi_io/src/kaldi/util/table-types.h b/kaldi_io/src/kaldi/util/table-types.h new file mode 100644 index 0000000..313d1aa --- /dev/null +++ b/kaldi_io/src/kaldi/util/table-types.h @@ -0,0 +1,137 @@ +// util/table-types.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_TABLE_TYPES_H_ +#define KALDI_UTIL_TABLE_TYPES_H_ +#include "base/kaldi-common.h" +#include "util/kaldi-table.h" +#include "util/kaldi-holder.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { + +// This header defines typedefs that are specific instantiations of +// the Table types. + +/// \addtogroup table_types +/// @{ + +typedef TableWriter<KaldiObjectHolder<Matrix<BaseFloat> > > BaseFloatMatrixWriter; +typedef SequentialTableReader<KaldiObjectHolder<Matrix<BaseFloat> > > SequentialBaseFloatMatrixReader; +typedef RandomAccessTableReader<KaldiObjectHolder<Matrix<BaseFloat> > > RandomAccessBaseFloatMatrixReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<Matrix<BaseFloat> > > RandomAccessBaseFloatMatrixReaderMapped; + +typedef TableWriter<KaldiObjectHolder<Matrix<double> > > DoubleMatrixWriter; +typedef SequentialTableReader<KaldiObjectHolder<Matrix<double> > > SequentialDoubleMatrixReader; +typedef RandomAccessTableReader<KaldiObjectHolder<Matrix<double> > > RandomAccessDoubleMatrixReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<Matrix<double> > > RandomAccessDoubleMatrixReaderMapped; + +typedef TableWriter<KaldiObjectHolder<CompressedMatrix> > CompressedMatrixWriter; + +typedef TableWriter<KaldiObjectHolder<Vector<BaseFloat> > > BaseFloatVectorWriter; +typedef SequentialTableReader<KaldiObjectHolder<Vector<BaseFloat> > > SequentialBaseFloatVectorReader; +typedef RandomAccessTableReader<KaldiObjectHolder<Vector<BaseFloat> > > RandomAccessBaseFloatVectorReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<Vector<BaseFloat> > > RandomAccessBaseFloatVectorReaderMapped; + +typedef TableWriter<KaldiObjectHolder<Vector<double> > > DoubleVectorWriter; +typedef SequentialTableReader<KaldiObjectHolder<Vector<double> > > SequentialDoubleVectorReader; +typedef RandomAccessTableReader<KaldiObjectHolder<Vector<double> > > RandomAccessDoubleVectorReader; + +typedef TableWriter<KaldiObjectHolder<CuMatrix<BaseFloat> > > BaseFloatCuMatrixWriter; +typedef SequentialTableReader<KaldiObjectHolder<CuMatrix<BaseFloat> > > SequentialBaseFloatCuMatrixReader; +typedef RandomAccessTableReader<KaldiObjectHolder<CuMatrix<BaseFloat> > > RandomAccessBaseFloatCuMatrixReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<CuMatrix<BaseFloat> > > RandomAccessBaseFloatCuMatrixReaderMapped; + +typedef TableWriter<KaldiObjectHolder<CuMatrix<double> > > DoubleCuMatrixWriter; +typedef SequentialTableReader<KaldiObjectHolder<CuMatrix<double> > > SequentialDoubleCuMatrixReader; +typedef RandomAccessTableReader<KaldiObjectHolder<CuMatrix<double> > > RandomAccessDoubleCuMatrixReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<CuMatrix<double> > > RandomAccessDoubleCuMatrixReaderMapped; + +typedef TableWriter<KaldiObjectHolder<CuVector<BaseFloat> > > BaseFloatCuVectorWriter; +typedef SequentialTableReader<KaldiObjectHolder<CuVector<BaseFloat> > > SequentialBaseFloatCuVectorReader; +typedef RandomAccessTableReader<KaldiObjectHolder<CuVector<BaseFloat> > > RandomAccessBaseFloatCuVectorReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<CuVector<BaseFloat> > > RandomAccessBaseFloatCuVectorReaderMapped; + +typedef TableWriter<KaldiObjectHolder<CuVector<double> > > DoubleCuVectorWriter; +typedef SequentialTableReader<KaldiObjectHolder<CuVector<double> > > SequentialDoubleCuVectorReader; +typedef RandomAccessTableReader<KaldiObjectHolder<CuVector<double> > > RandomAccessDoubleCuVectorReader; + + +typedef TableWriter<BasicHolder<int32> > Int32Writer; +typedef SequentialTableReader<BasicHolder<int32> > SequentialInt32Reader; +typedef RandomAccessTableReader<BasicHolder<int32> > RandomAccessInt32Reader; + +typedef TableWriter<BasicVectorHolder<int32> > Int32VectorWriter; +typedef SequentialTableReader<BasicVectorHolder<int32> > SequentialInt32VectorReader; +typedef RandomAccessTableReader<BasicVectorHolder<int32> > RandomAccessInt32VectorReader; + +typedef TableWriter<BasicVectorVectorHolder<int32> > Int32VectorVectorWriter; +typedef SequentialTableReader<BasicVectorVectorHolder<int32> > SequentialInt32VectorVectorReader; +typedef RandomAccessTableReader<BasicVectorVectorHolder<int32> > RandomAccessInt32VectorVectorReader; + +typedef TableWriter<BasicPairVectorHolder<int32> > Int32PairVectorWriter; +typedef SequentialTableReader<BasicPairVectorHolder<int32> > SequentialInt32PairVectorReader; +typedef RandomAccessTableReader<BasicPairVectorHolder<int32> > RandomAccessInt32PairVectorReader; + +typedef TableWriter<BasicPairVectorHolder<BaseFloat> > BaseFloatPairVectorWriter; +typedef SequentialTableReader<BasicPairVectorHolder<BaseFloat> > SequentialBaseFloatPairVectorReader; +typedef RandomAccessTableReader<BasicPairVectorHolder<BaseFloat> > RandomAccessBaseFloatPairVectorReader; + +typedef TableWriter<BasicHolder<BaseFloat> > BaseFloatWriter; +typedef SequentialTableReader<BasicHolder<BaseFloat> > SequentialBaseFloatReader; +typedef RandomAccessTableReader<BasicHolder<BaseFloat> > RandomAccessBaseFloatReader; +typedef RandomAccessTableReaderMapped<BasicHolder<BaseFloat> > RandomAccessBaseFloatReaderMapped; + +typedef TableWriter<BasicHolder<double> > DoubleWriter; +typedef SequentialTableReader<BasicHolder<double> > SequentialDoubleReader; +typedef RandomAccessTableReader<BasicHolder<double> > RandomAccessDoubleReader; + +typedef TableWriter<BasicHolder<bool> > BoolWriter; +typedef SequentialTableReader<BasicHolder<bool> > SequentialBoolReader; +typedef RandomAccessTableReader<BasicHolder<bool> > RandomAccessBoolReader; + + + +/// TokenWriter is a writer specialized for std::string where the strings +/// are nonempty and whitespace-free. T == std::string +typedef TableWriter<TokenHolder> TokenWriter; +typedef SequentialTableReader<TokenHolder> SequentialTokenReader; +typedef RandomAccessTableReader<TokenHolder> RandomAccessTokenReader; + + +/// TokenVectorWriter is a writer specialized for sequences of +/// std::string where the strings are nonempty and whitespace-free. +/// T == std::vector<std::string> +typedef TableWriter<TokenVectorHolder> TokenVectorWriter; +// Ditto for SequentialTokenVectorReader. +typedef SequentialTableReader<TokenVectorHolder> SequentialTokenVectorReader; +typedef RandomAccessTableReader<TokenVectorHolder> RandomAccessTokenVectorReader; + + +/// @} + +// Note: for FST reader/writer, see ../fstext/fstext-utils.h +// [not done yet]. + +} // end namespace kaldi + + + +#endif diff --git a/kaldi_io/src/kaldi/util/text-utils.h b/kaldi_io/src/kaldi/util/text-utils.h new file mode 100644 index 0000000..1d85c47 --- /dev/null +++ b/kaldi_io/src/kaldi/util/text-utils.h @@ -0,0 +1,169 @@ +// util/text-utils.h + +// Copyright 2009-2011 Saarland University; Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_TEXT_UTILS_H_ +#define KALDI_UTIL_TEXT_UTILS_H_ + +#include <algorithm> +#include <map> +#include <set> +#include <string> +#include <vector> +#include <errno.h> + +#include "base/kaldi-common.h" + +namespace kaldi { + +/// Split a string using any of the single character delimiters. +/// If omit_empty_strings == true, the output will contain any +/// nonempty strings after splitting on any of the +/// characters in the delimiter. If omit_empty_strings == false, +/// the output will contain n+1 strings if there are n characters +/// in the set "delim" within the input string. In this case +/// the empty string is split to a single empty string. +void SplitStringToVector(const std::string &full, const char *delim, + bool omit_empty_strings, + std::vector<std::string> *out); + +/// Joins the elements of a vector of strings into a single string using +/// "delim" as the delimiter. If omit_empty_strings == true, any empty strings +/// in the vector are skipped. A vector of empty strings results in an empty +/// string on the output. +void JoinVectorToString(const std::vector<std::string> &vec_in, + const char *delim, bool omit_empty_strings, + std::string *str_out); + + +/// Split a string (e.g. 1:2:3) into a vector of integers. +/// The delimiting char may be any character in "delim". +/// returns true on success, false on failure. +/// If omit_empty_strings == true, 1::2:3: will become +/// { 1, 2, 3 }. Otherwise it would be rejected. +/// Regardless of the value of omit_empty_strings, +/// the empty string is successfully parsed as an empty +/// vector of integers +template<class I> +bool SplitStringToIntegers(const std::string &full, + const char *delim, + bool omit_empty_strings, // typically false [but + // should probably be true + // if "delim" is spaces]. + std::vector<I> *out) { + KALDI_ASSERT(out != NULL); + KALDI_ASSERT_IS_INTEGER_TYPE(I); + if ( *(full.c_str()) == '\0') { + out->clear(); + return true; + } + std::vector<std::string> split; + SplitStringToVector(full, delim, omit_empty_strings, &split); + out->resize(split.size()); + for (size_t i = 0; i < split.size(); i++) { + const char *this_str = split[i].c_str(); + char *end = NULL; + long long int j = 0; + j = KALDI_STRTOLL(this_str, &end); + if (end == this_str || *end != '\0') { + out->clear(); + return false; + } else { + I jI = static_cast<I>(j); + if (static_cast<long long int>(jI) != j) { + // output type cannot fit this integer. + out->clear(); + return false; + } + (*out)[i] = jI; + } + } + return true; +} + +// This is defined for F = float and double. +template<class F> +bool SplitStringToFloats(const std::string &full, + const char *delim, + bool omit_empty_strings, // typically false + std::vector<F> *out); + + +/// Converts a string into an integer via strtoll and returns false if there was +/// any kind of problem (i.e. the string was not an integer or contained extra +/// non-whitespace junk, or the integer was too large to fit into the type it is +/// being converted into). Only sets *out if everything was OK and it returns +/// true. +template<class Int> +bool ConvertStringToInteger(const std::string &str, + Int *out) { + KALDI_ASSERT_IS_INTEGER_TYPE(Int); + const char *this_str = str.c_str(); + char *end = NULL; + errno = 0; + long long int i = KALDI_STRTOLL(this_str, &end); + if (end != this_str) + while (isspace(*end)) end++; + if (end == this_str || *end != '\0' || errno != 0) + return false; + Int iInt = static_cast<Int>(i); + if (static_cast<long long int>(iInt) != i || (i<0 && !std::numeric_limits<Int>::is_signed)) { + return false; + } + *out = iInt; + return true; +} + + +/// ConvertStringToReal converts a string into either float or double via strtod, +/// and returns false if there was any kind of problem (i.e. the string was not a +/// floating point number or contained extra non-whitespace junk. +/// Be careful- this function will successfully read inf's or nan's. +bool ConvertStringToReal(const std::string &str, + double *out); +bool ConvertStringToReal(const std::string &str, + float *out); + + +/// Removes the beginning and trailing whitespaces from a string +void Trim(std::string *str); + + +/// Removes leading and trailing white space from the string, then splits on the +/// first section of whitespace found (if present), putting the part before the +/// whitespace in "first" and the rest in "rest". If there is no such space, +/// everything that remains after removing leading and trailing whitespace goes +/// in "first". +void SplitStringOnFirstSpace(const std::string &line, + std::string *first, + std::string *rest); + + +/// Returns true if "token" is nonempty, and all characters are +/// printable and whitespace-free. +bool IsToken(const std::string &token); + + +/// Returns true if "line" is free of \n characters and unprintable +/// characters, and does not contain leading or trailing whitespace. +bool IsLine(const std::string &line); + + +} // namespace kaldi + +#endif // KALDI_UTIL_TEXT_UTILS_H_ diff --git a/kaldi_io/src/kaldi/util/timer.h b/kaldi_io/src/kaldi/util/timer.h new file mode 100644 index 0000000..e3ee8d5 --- /dev/null +++ b/kaldi_io/src/kaldi/util/timer.h @@ -0,0 +1,27 @@ +// util/timer.h + +// Copyright 2014 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// We are temporarily leaving this file to forward #includes to +// base-timer.h. Its use is deprecated; you should directrly +// #include base/timer.h +#ifndef KALDI_UTIL_TIMER_H_ +#define KALDI_UTIL_TIMER_H_ +#pragma message warning: please do not include util/timer.h, include base/timer.h (it has been moved) +#include "base/timer.h" +#endif |