diff options
author | Determinant <[email protected]> | 2015-06-25 12:56:45 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-06-25 12:56:45 +0800 |
commit | a74183ddb4ab8383bfe214b3745eb8a0a99ee47a (patch) | |
tree | d5e69cf8c4c2db2e3a4722778352fc3c95953bb2 /htk_io/src | |
parent | b6301089cde20f4c825c7f5deaf179082aad63da (diff) |
let HTK I/O implementation be a single package
Diffstat (limited to 'htk_io/src')
36 files changed, 11574 insertions, 0 deletions
diff --git a/htk_io/src/KaldiLib/Common.cc b/htk_io/src/KaldiLib/Common.cc new file mode 100644 index 0000000..40909ee --- /dev/null +++ b/htk_io/src/KaldiLib/Common.cc @@ -0,0 +1,277 @@ +#include <string> +#include <stdexcept> +#include <cmath> +#include <cfloat> +#include <cstdio> + +#include "Common.h" +#include "MathAux.h" + + +/// Defines the white chars for string trimming +#if !defined(WHITE_CHARS) +# define WHITE_CHARS " \t" +#endif + +namespace TNet { + +#include <ios> + + // Allocating stream variable used by stream modifier MatrixVectorIostreamControl + const int MATRIX_IOS_FORMAT_IWORD = std::ios_base::xalloc(); + + //*************************************************************************** + //*************************************************************************** + int getHTKstr(char *str) + { + char termChar = '\0'; + char *chrptr = str; + + while (std::isspace(*chrptr)) ++chrptr; + + if (*chrptr == '\'' || *chrptr == '"') { + termChar = *chrptr; + chrptr++; + } + + for (; *chrptr; chrptr++) { + if (*chrptr == '\'' || *chrptr == '"') { + if (termChar == *chrptr) { + termChar = '\0'; + chrptr++; + break; + } + } + + if (std::isspace(*chrptr) && !termChar) { + break; + } + + if (*chrptr == '\\') { + ++chrptr; + if (*chrptr == '\0' || (*chrptr >= '0' && *chrptr <= '7' && + (*++chrptr < '0' || *chrptr > '7' || + *++chrptr < '0' || *chrptr > '7'))) { + return -1; + } + + if (*chrptr >= '0' && *chrptr <= '7') { + *chrptr = (char)((*chrptr - '0') + (chrptr[-1] - '0') * 8 + (chrptr[-2] - '0') * 64); + } + } + *str++ = *chrptr; + } + + if (termChar) { + return -2; + } + + *str = '\0'; + + return 0; + } + + + //***************************************************************************** + //***************************************************************************** + void + ParseHTKString(const std::string & rIn, std::string & rOut) + { + int ret_val; + + // the new string will be at most as long as the original, so we allocate + // space + char* 
new_str = new char[rIn.size() + 1]; + + char* p_htk_str = new_str; + + strcpy(p_htk_str, rIn.c_str()); + ret_val = getHTKstr(p_htk_str); + + // call the function + if (!ret_val) { + rOut = p_htk_str; + } + + delete [] new_str; + + if (ret_val) { + throw std::runtime_error("Error parsing HTK string"); + } + } + + + + //*************************************************************************** + //*************************************************************************** + bool + IsBigEndian() + { + int a = 1; + return (bool) ((char *) &a)[0] != 1; + } + + + //*************************************************************************** + //*************************************************************************** + void + MakeHtkFileName(char* pOutFileName, const char* inFileName, + const char* out_dir, const char* out_ext) + { + const char* base_name; + const char* bname_end = NULL; + const char* chrptr; + + // if (*inFileName == '*' && *++inFileName == '/') ++inFileName; + + // we don't do anything if file is stdin/out + if (!strcmp(inFileName, "-")) + { + pOutFileName[0] = '-'; + pOutFileName[1] = '\0'; + return; + } + + base_name = strrchr(inFileName, '/'); + base_name = base_name != NULL ? 
base_name + 1 : inFileName; + + if (out_ext) bname_end = strrchr(base_name, '.'); + if (!bname_end) bname_end = base_name + strlen(base_name); + + + if ((chrptr = strstr(inFileName, "/./")) != NULL) + { + // what is in path after /./ serve as base name + base_name = chrptr + 3; + } + /* else if (*inFileName != '/') + { + // if inFileName isn't absolut path, don't forget directory structure + base_name = inFileName; + }*/ + + *pOutFileName = '\0'; + if (out_dir) + { + if (*out_dir) + { + strcat(pOutFileName, out_dir); + strcat(pOutFileName, "/"); + } + strncat(pOutFileName, base_name, bname_end-base_name); + } + else + { + strncat(pOutFileName, inFileName, bname_end-inFileName); + } + + if (out_ext && *out_ext) + { + strcat(pOutFileName, "."); + strcat(pOutFileName, out_ext); + } + } + + + //**************************************************************************** + //**************************************************************************** + bool + CloseEnough(const float f1, const float f2, const float nRounds) + { + bool ret_val = (_ABS((f1 - f2) / (f2 == 0.0f ? 1.0f : f2)) + < (nRounds * FLT_EPSILON)); + + return ret_val; + } + + + //**************************************************************************** + //**************************************************************************** + bool + CloseEnough(const double f1, const double f2, const double nRounds) + { + bool ret_val = (_ABS((f1 - f2) / (f2 == 0.0 ? 
1.0 : f2)) + < (nRounds * DBL_EPSILON)); + + return ret_val; + } + + + //**************************************************************************** + //**************************************************************************** + char* + ExpandHtkFilterCmd(const char *command, const char *filename, const char* pFilter) + { + + char *out, *outend; + const char *chrptr = command; + int ndollars = 0; + int fnlen = strlen(filename); + + while (*chrptr++) ndollars += (*chrptr == *pFilter); + + out = (char*) malloc(strlen(command) - ndollars + ndollars * fnlen + 1); + + outend = out; + + for (chrptr = command; *chrptr; chrptr++) { + if (*chrptr == *pFilter) { + strcpy(outend, filename); + outend += fnlen; + } else { + *outend++ = *chrptr; + } + } + *outend = '\0'; + return out; + } + + //*************************************************************************** + //*************************************************************************** + char * + StrToUpper(char *str) + { + char *chptr; + for (chptr = str; *chptr; chptr++) { + *chptr = (char)toupper(*chptr); + } + return str; + } + + + //**************************************************************************** + //**************************************************************************** + std::string& + Trim(std::string& rStr) + { + // WHITE_CHARS is defined in common.h + std::string::size_type pos = rStr.find_last_not_of(WHITE_CHARS); + if(pos != std::string::npos) + { + rStr.erase(pos + 1); + pos = rStr.find_first_not_of(WHITE_CHARS); + if(pos != std::string::npos) rStr.erase(0, pos); + } + else + rStr.erase(rStr.begin(), rStr.end()); + + return rStr; + } + + +} // namespace TNet + +//#ifdef CYGWIN + +void assertf(const char *c, int i, const char *msg){ + printf("Assertion \"%s\" failed: file \"%s\", line %d\n", msg?msg:"(null)", c?c:"(null)", i); + abort(); +} + + +void assertf_throw(const char *c, int i, const char *msg){ + char buf[2000]; + snprintf(buf, 1999, "Assertion \"%s\" failed, throwing 
exception: file \"%s\", line %d\n", msg?msg:"(null)", c?c:"(null)", i); + throw std::runtime_error((std::string)buf); +} +//#endif diff --git a/htk_io/src/KaldiLib/Common.h b/htk_io/src/KaldiLib/Common.h new file mode 100644 index 0000000..9cd9658 --- /dev/null +++ b/htk_io/src/KaldiLib/Common.h @@ -0,0 +1,233 @@ +#ifndef TNet_Common_h +#define TNet_Common_h + +#include <cstdlib> +#include <string.h> // C string stuff like strcpy +#include <string> +#include <sstream> +#include <stdexcept> + +/* Alignment of critical dynamic data structure + * + * Not all platforms support memalign so we provide a stk_memalign wrapper + * void *stk_memalign( size_t align, size_t size, void **pp_orig ) + * *pp_orig is the pointer that has to be freed afterwards. + */ +#ifdef HAVE_POSIX_MEMALIGN +# define stk_memalign(align,size,pp_orig) \ + ( !posix_memalign( pp_orig, align, size ) ? *(pp_orig) : NULL ) +# ifdef STK_MEMALIGN_MANUAL +# undef STK_MEMALIGN_MANUAL +# endif +#elif defined(HAVE_MEMALIGN) + /* Some systems have memalign() but no declaration for it */ + //void * memalign( size_t align, size_t size ); +# define stk_memalign(align,size,pp_orig) \ + ( *(pp_orig) = memalign( align, size ) ) +# ifdef STK_MEMALIGN_MANUAL +# undef STK_MEMALIGN_MANUAL +# endif +#else /* We don't have any choice but to align manually */ +# define stk_memalign(align,size,pp_orig) \ + (( *(pp_orig) = malloc( size + align - 1 )) ? 
\ + (void *)( (((unsigned long)*(pp_orig)) + 15) & ~0xFUL ) : NULL ) +# define STK_MEMALIGN_MANUAL +#endif + + +#define swap8(a) { \ + char t=((char*)&a)[0]; ((char*)&a)[0]=((char*)&a)[7]; ((char*)&a)[7]=t;\ + t=((char*)&a)[1]; ((char*)&a)[1]=((char*)&a)[6]; ((char*)&a)[6]=t;\ + t=((char*)&a)[2]; ((char*)&a)[2]=((char*)&a)[5]; ((char*)&a)[5]=t;\ + t=((char*)&a)[3]; ((char*)&a)[3]=((char*)&a)[4]; ((char*)&a)[4]=t;} +#define swap4(a) { \ + char t=((char*)&a)[0]; ((char*)&a)[0]=((char*)&a)[3]; ((char*)&a)[3]=t;\ + t=((char*)&a)[1]; ((char*)&a)[1]=((char*)&a)[2]; ((char*)&a)[2]=t;} +#define swap2(a) { \ + char t=((char*)&a)[0]; ((char*)&a)[0]=((char*)&a)[1]; ((char*)&a)[1]=t;} + + +namespace TNet +{ + /** ************************************************************************** + ** ************************************************************************** + * @brief Aligns a number to a specified base + * @param n Number of type @c _T to align + * @return Aligned value of type @c _T + */ + template<size_t _align, typename _T> + inline _T + align(const _T n) + { + const _T x(_align - 1); + return (n + x) & ~(x); + } + + + /** + * @brief Returns true if architecture is big endian + */ + bool + IsBigEndian(); + + + /** + * @brief Returns true if two numbers are close enough to each other + * + * @param f1 First operand + * @param f2 Second operand + * @param nRounds Expected number of operations prior to this comparison + */ + bool + CloseEnough(const float f1, const float f2, const float nRounds); + + + /** + * @brief Returns true if two numbers are close enough to each other + * + * @param f1 First operand + * @param f2 Second operand + * @param nRounds Expected number of operations prior to this comparison + */ + bool + CloseEnough(const double f1, const double f2, const double nRounds); + + + /** + * @brief Parses a HTK-style string into a C++ std::string readable + * + * @param rIn HTK input string + * @param rOut output parsed string + */ + void + 
ParseHTKString(const std::string & rIn, std::string & rOut); + + + /** + * @brief Synthesize new file name based on name, path, and extension + * + * @param pOutFileName full ouptut file name + * @param pInFileName file name + * @param pOutDir directory + * @param pOutExt extension + */ + void + MakeHtkFileName(char *pOutFileName, const char* pInFileName, const char *pOutDir, + const char *pOutExt); + + + /** + * @brief Removes the leading and trailing white chars + * + * @param rStr Refference to the string to be processed + * @return Refference to the original string + * + * The white characters are determined by the @c WHITE_CHARS macro defined + * above. + */ + std::string& + Trim(std::string& rStr); + + + char* + StrToUpper(char* pStr); + + char* + ExpandHtkFilterCmd(const char *command, const char *filename, const char* pFilter); + + + template <class T> + std::string to_string(const T& val) + { + std::stringstream ss; + ss << val; + return ss.str(); + } + + inline void + ExpectKeyword(std::istream &i_stream, const char *kwd) + { + std::string token; + i_stream >> token; + if (token != kwd) { + throw std::runtime_error(std::string(kwd) + " expected"); + } + } + + extern const int MATRIX_IOS_FORMAT_IWORD; + + enum MatrixVectorIostreamControlBits { + ACCUMULATE_INPUT = 1, +// BINARY_OUTPUT = 2 + }; + + class MatrixVectorIostreamControl + { + public: + MatrixVectorIostreamControl(enum MatrixVectorIostreamControlBits bitsToBeSet, bool valueToBeSet) + : mBitsToBeSet(bitsToBeSet), mValueToBeSet(valueToBeSet) {} + + static long Flags(std::ios_base &rIos, enum MatrixVectorIostreamControlBits bits) + { return rIos.iword(MATRIX_IOS_FORMAT_IWORD); } + + long mBitsToBeSet; + bool mValueToBeSet; + + friend std::ostream & operator <<(std::ostream &rOs, const MatrixVectorIostreamControl modifier) + { + if(modifier.mValueToBeSet) { + rOs.iword(MATRIX_IOS_FORMAT_IWORD) |= modifier.mBitsToBeSet; + } else { + rOs.iword(MATRIX_IOS_FORMAT_IWORD) &= ~modifier.mBitsToBeSet; + } + 
return rOs; + } + + friend std::istream & operator >>(std::istream &rIs, const MatrixVectorIostreamControl modifier) + { + if(modifier.mValueToBeSet) { + rIs.iword(MATRIX_IOS_FORMAT_IWORD) |= modifier.mBitsToBeSet; + } else { + rIs.iword(MATRIX_IOS_FORMAT_IWORD) &= ~modifier.mBitsToBeSet; + } + return rIs; + } + }; + + + + +} // namespace TNet + +#ifdef __ICC +#pragma warning (disable: 383) // ICPC remark we don't want. +#pragma warning (disable: 810) // ICPC remark we don't want. +#pragma warning (disable: 981) // ICPC remark we don't want. +#pragma warning (disable: 1418) // ICPC remark we don't want. +#pragma warning (disable: 444) // ICPC remark we don't want. +#pragma warning (disable: 869) // ICPC remark we don't want. +#pragma warning (disable: 1287) // ICPC remark we don't want. +#pragma warning (disable: 279) // ICPC remark we don't want. +#pragma warning (disable: 981) // ICPC remark we don't want. +#endif + +//#ifdef CYGWIN +#if 1 +#undef assert +#ifndef NDEBUG +#define assert(e) ((e) ? (void)0 : assertf(__FILE__, __LINE__, #e)) +#else +#define assert(e) ((void)0) +#endif +void assertf(const char *c, int i, const char *msg); // Just make it possible to break into assert on gdb-- has some kind of bug on cygwin. +#else +#include <cassert> +#endif + +#define assert_throw(e) ((e) ? (void)0 : assertf_throw(__FILE__, __LINE__, #e)) +void assertf_throw(const char *c, int i, const char *msg); + +#define DAN_STYLE_IO + +#endif // ifndef TNet_Common_h + diff --git a/htk_io/src/KaldiLib/Error.h b/htk_io/src/KaldiLib/Error.h new file mode 100644 index 0000000..2228dde --- /dev/null +++ b/htk_io/src/KaldiLib/Error.h @@ -0,0 +1,172 @@ +// +// C++ Interface: %{MODULE} +// +// Description: +// +// +// Author: %{AUTHOR} <%{EMAIL}>, (C) %{YEAR} +// +// Copyright: See COPYING file that comes with this distribution +// +// + +/** @file Error.h + * This header defines several types and functions relating to the + * handling of exceptions in STK. 
+ */ + +#ifndef TNET_Error_h +#define TNET_Error_h + +#include <iostream> +#include <stdexcept> +#include <string> +#include <sstream> + +#include <cstdlib> +#include <execinfo.h> +#include <cstdarg> +#include <cstdio> + +// THESE MACROS TERRIBLY CLASH WITH STK!!!! +// WE MUST USE SAME MACROS! +// +//#define Error(msg) _Error_(__func__, __FILE__, __LINE__, msg) +//#define Warning(msg) _Warning_(__func__, __FILE__, __LINE__, msg) +//#define TraceLog(msg) _TraceLog_(__func__, __FILE__, __LINE__, msg) +// + +#ifndef Error + #define Error(...) _Error_(__func__, __FILE__, __LINE__, __VA_ARGS__) +#endif +#ifndef PError + #define PError(...) _PError_(__func__, __FILE__, __LINE__, __VA_ARGS__) +#endif +#ifndef Warning + #define Warning(...) _Warning_(__func__, __FILE__, __LINE__, __VA_ARGS__) +#endif +#ifndef TraceLog + #define TraceLog(...) _TraceLog_(__func__, __FILE__, __LINE__, __VA_ARGS__) +#endif + +namespace TNet { + + + + /** MyException + * Custom exception class, gets the stacktrace + */ + class MyException + : public std::runtime_error + { + public: + explicit MyException(const std::string& what_arg) throw(); + virtual ~MyException() throw(); + + const char* what() const throw() + { return mWhat.c_str(); } + + private: + std::string mWhat; + }; + + /** + * MyException:: implemenatation + */ + inline + MyException:: + MyException(const std::string& what_arg) throw() + : std::runtime_error(what_arg) + { + mWhat = what_arg; + mWhat += "\nTHE STACKTRACE INSIDE MyException OBJECT IS:\n"; + + void *array[10]; + size_t size; + char **strings; + size_t i; + + size = backtrace (array, 10); + strings = backtrace_symbols (array, size); + + //<< 0th string is the MyException ctor, so ignore and start by 1 + for (i = 1; i < size; i++) { + mWhat += strings[i]; + mWhat += "\n"; + } + + free (strings); + } + + + inline + MyException:: + ~MyException() throw() + { } + + + /** + * @brief Error throwing function (with backtrace) + */ + inline void + _Error_(const char *func, const 
char *file, int line, const std::string &msg) + { + std::stringstream ss; + ss << "ERROR (" << func << ':' << file << ':' << line << ") " << msg; + throw MyException(ss.str()); + } + + /** + * @brief Throw a formatted error + */ + inline void _PError_(const char *func, const char *file, int line, const char *fmt, ...) { + va_list ap; + char msg[256]; + va_start(ap, fmt); + vsnprintf(msg, sizeof msg, fmt, ap); + va_end(ap); + _Error_(func, file, line, msg); + } + + /** + * @brief Warning handling function + */ + inline void + _Warning_(const char *func, const char *file, int line, const std::string &msg) + { + std::cout << "WARNING (" << func << ':' << file << ':' << line << ") " << msg << std::endl; + } + + inline void + _TraceLog_(const char *func, const char *file, int line, const std::string &msg) + { + std::cout << "INFO (" << func << ':' << file << ':' << line << ") " << msg << std::endl; + std::cout.flush(); + } + + /** + * New kaldi error handling: + * + * class KaldiErrorMessage is invoked from the KALDI_ERROR macro. + * The destructor throws an exception. 
+ */ + class KaldiErrorMessage { + public: + KaldiErrorMessage(const char *func, const char *file, int line) { + this->stream() << "ERROR (" + << func << "():" + << file << ':' << line << ") "; + } + inline std::ostream &stream() { return ss; } + ~KaldiErrorMessage() { throw MyException(ss.str()); } + private: + std::ostringstream ss; + }; + #define KALDI_ERR TNet::KaldiErrorMessage(__func__, __FILE__, __LINE__).stream() + + + +} // namespace TNet + +//#define TNET_Error_h +#endif diff --git a/htk_io/src/KaldiLib/Features.cc b/htk_io/src/KaldiLib/Features.cc new file mode 100644 index 0000000..64b63e8 --- /dev/null +++ b/htk_io/src/KaldiLib/Features.cc @@ -0,0 +1,1798 @@ + +//enable feature repository profiling +#define PROFILING 1 + +#include <sstream> +#include <map> +#include <list> +#include <cstdio> + +#include "Features.h" +#include "Tokenizer.h" +#include "StkMatch.h" +#include "Types.h" + + + +namespace TNet +{ + const char + FeatureRepository:: + mpParmKindNames[13][16] = + { + {"WAVEFORM"}, + {"LPC"}, + {"LPREFC"}, + {"LPCEPSTRA"}, + {"LPDELCEP"}, + {"IREFC"}, + {"MFCC"}, + {"FBANK"}, + {"MELSPEC"}, + {"USER"}, + {"DISCRETE"}, + {"PLP"}, + {"ANON"} + }; + + //*************************************************************************** + //*************************************************************************** + + FileListElem:: + FileListElem(const std::string & rFileName) + { + std::string::size_type pos; + + mLogical = rFileName; + mWeight = 1.0; + + // some slash-backslash replacement hack + for (size_t i = 0; i < mLogical.size(); i++) { + if (mLogical[i] == '\\') { + mLogical[i] = '/'; + } + } + + // read sentence weight definition if any ( physical_file.fea[s,e]{weight} ) + if ((pos = mLogical.find('{')) != std::string::npos) + { + std::string tmp_weight(mLogical.begin() + pos + 1, mLogical.end()); + std::stringstream tmp_ss(tmp_weight); + + tmp_ss >> mWeight; + mLogical.erase(pos); + } + + // look for "=" symbol and if found, split it + if ((pos 
= mLogical.find('=')) != std::string::npos) + { + // copy all from mLogical[pos+1] till the end to mPhysical + mPhysical.assign(mLogical.begin() + pos + 1, mLogical.end()); + // erase all from pos + 1 till the end from mLogical + mLogical.erase(pos); + // trim the leading and trailing spaces + Trim(mPhysical); + Trim(mLogical); + } + else + { + // trim the leading and trailing spaces + Trim(mLogical); + + mPhysical = mLogical; + } + } + + + //########################################################################### + //########################################################################### + // FeatureRepository section + //########################################################################### + //########################################################################### + + //*************************************************************************** + //*************************************************************************** + void + FeatureRepository:: + ReadCepsNormFile( + const char * pFileName, + char ** pLastFileName, + BaseFloat ** vec_buff, + int sampleKind, + CNFileType type, + int coefs) + { + FILE* fp; + int i; + char s1[64]; + char s2[64]; + const char* typeStr = (type == CNF_Mean ? "MEAN" : + type == CNF_Variance ? "VARIANCE" : "VARSCALE"); + + const char* typeStr2 = (type == CNF_Mean ? "CMN" : + type == CNF_Variance ? 
"CVN" : "VarScale"); + + if (*pLastFileName != NULL && !strcmp(*pLastFileName, pFileName)) { + return; + } + free(*pLastFileName); + *pLastFileName=strdup(pFileName); + *vec_buff = (BaseFloat*) realloc(*vec_buff, coefs * sizeof(BaseFloat)); + + if (*pLastFileName == NULL || *vec_buff== NULL) + throw std::runtime_error("Insufficient memory"); + + if ((fp = fopen(pFileName, "r")) == NULL) { + throw std::runtime_error(std::string("Cannot open ") + typeStr2 + + " pFileName: '" + pFileName + "'"); + } + + if ((type != CNF_VarScale + && (fscanf(fp, " <%64[^>]> <%64[^>]>", s1, s2) != 2 + || strcmp(StrToUpper(s1), "CEPSNORM") + || ReadParmKind(s2, false) != sampleKind)) + || fscanf(fp, " <%64[^>]> %d", s1, &i) != 2 + || strcmp(StrToUpper(s1), typeStr) + || i != coefs) + { + ParmKind2Str(sampleKind, s2); + + //std::cout << "[[[TADY!!!!]]]" << pFileName << "\n" << std::flush; + + throw std::runtime_error(std::string("") + + (type == CNF_VarScale ? "" : "<CEPSNORM> <") + + (type == CNF_VarScale ? "" : s2) + + (type == CNF_VarScale ? "" : ">") + + " <" + typeStr + " ... 
expected in " + typeStr2 + + " file " + pFileName); + } + + for (i = 0; i < coefs; i++) { + if (fscanf(fp, " "FLOAT_FMT, *vec_buff+i) != 1) { + if (fscanf(fp, "%64s", s2) == 1) { + throw std::runtime_error(std::string("Decimal number expected but '") + + s2 + "' found in " + typeStr2 + " file " + pFileName); + } + else if (feof(fp)) { + throw std::runtime_error(std::string("Unexpected end of ") + + typeStr2 + " file "+ pFileName); + } + else { + throw std::runtime_error(std::string("Cannot read ") + typeStr2 + + " file " + pFileName); + } + } + + if (type == CNF_Variance) + (*vec_buff)[i] = BaseFloat(1 / sqrt((*vec_buff)[i])); + else if (type == CNF_VarScale) + (*vec_buff)[i] = BaseFloat(sqrt((*vec_buff)[i])); + } + + if (fscanf(fp, "%64s", s2) == 1) + { + throw std::runtime_error(std::string("End of file expected but '") + + s2 + "' found in " + typeStr2 + " file " + pFileName); + } + + fclose(fp); + } // ReadCepsNormFile(...) + + + //*************************************************************************** + //*************************************************************************** + void + FeatureRepository:: + HtkFilter(const char* pFilter, const char* pValue, FeatureRepository& rOut) + { + std::list<FileListElem>::iterator it; + std::string str; + + rOut.mSwapFeatures = mSwapFeatures; + rOut.mStartFrameExt = mStartFrameExt; + rOut.mEndFrameExt = mEndFrameExt; + rOut.mTargetKind = mTargetKind; + rOut.mDerivOrder = mDerivOrder; + rOut.mDerivWinLengths = mDerivWinLengths; + + rOut.mpCvgFile = mpCvgFile; + rOut.mpCmnPath = mpCmnPath; + rOut.mpCmnMask = mpCmnMask; + rOut.mpCvnPath = mpCvnPath; + rOut.mpCvnMask = mpCvnMask; + + rOut.mInputQueue.clear(); + + // go through all records and check the mask + for (it=mInputQueue.begin(); it!= mInputQueue.end(); ++it) { + if (pFilter == NULL + || (ProcessMask(it->Logical(), pFilter, str) && (str == pValue))) { + rOut.mInputQueue.push_back(*it); + } + } + + // set the queue position to the begining + 
rOut.mInputQueueIterator = mInputQueue.end(); + + rOut.mCurrentIndexFileName = ""; + rOut.mCurrentIndexFileDir = ""; + rOut.mCurrentIndexFileExt = ""; + + mStream.close(); + mStream.clear(); + + rOut.mpLastFileName = NULL; + rOut.mLastFileName = ""; + rOut.mpLastCmnFile = NULL; + rOut.mpLastCvnFile = NULL; + rOut.mpLastCvgFile = NULL; + rOut.mpCmn = NULL; + rOut.mpCvn = NULL; + rOut.mpCvg = NULL; + rOut.mpA = NULL; + rOut.mpB = NULL; + + } + + + //*************************************************************************** + //*************************************************************************** + void + FeatureRepository:: + HtkSelection(const char* pFilter, std::list< std::string >& rOut) + { + std::map< std::string, bool> aux_map; + std::map< std::string, bool>::iterator map_it; + std::list<FileListElem>::iterator it; + std::string str; + + rOut.clear(); + + if(pFilter != NULL) { + // go through all records and check the mask + for (it=mInputQueue.begin(); it!= mInputQueue.end(); ++it) { + if (ProcessMask(it->Logical(), pFilter, str)) { + aux_map[str] = true; + } + } + } else { + aux_map[std::string("default speaker")] = true; + } + + for (map_it = aux_map.begin(); map_it != aux_map.end(); ++map_it) { + rOut.push_back(map_it->first); + } + } + + + //*************************************************************************** + //*************************************************************************** + int + FeatureRepository:: + ParmKind2Str(unsigned parmKind, char *pOutString) + { + // :KLUDGE: Absolutely no idea what this is... 
+ if ((parmKind & 0x003F) >= sizeof(mpParmKindNames)/sizeof(mpParmKindNames[0])) + return 0; + + strcpy(pOutString, mpParmKindNames[parmKind & 0x003F]); + + if (parmKind & PARAMKIND_E) strcat(pOutString, "_E"); + if (parmKind & PARAMKIND_N) strcat(pOutString, "_N"); + if (parmKind & PARAMKIND_D) strcat(pOutString, "_D"); + if (parmKind & PARAMKIND_A) strcat(pOutString, "_A"); + if (parmKind & PARAMKIND_C) strcat(pOutString, "_C"); + if (parmKind & PARAMKIND_Z) strcat(pOutString, "_Z"); + if (parmKind & PARAMKIND_K) strcat(pOutString, "_K"); + if (parmKind & PARAMKIND_0) strcat(pOutString, "_0"); + if (parmKind & PARAMKIND_V) strcat(pOutString, "_V"); + if (parmKind & PARAMKIND_T) strcat(pOutString, "_T"); + + return 1; + } + + + // //*************************************************************************** + // //*************************************************************************** + // void + // AddFileListToFeatureRepositories( + // const char* pFileName, + // const char* pFilter, + // std::queue<FeatureRepository *> &featureRepositoryList) + // { + // IStkStream l_stream; + // std::string file_name; + // Tokenizer file_list(pFileName, ","); + // Tokenizer::iterator p_file_name; + + // //:TODO: error if empty featureRepositoryList + // + // for (p_file_name = file_list.begin(); p_file_name != file_list.end(); ++p_file_name) + // { + // // get rid of initial and trailing blanks + // Trim(*p_file_name); + + // // open file name + // l_stream.open(p_file_name->c_str(), std::ios::in, pFilter); + // + // if (!l_stream.good()) { + // //:TODO: + // // Warning or error ... Why warning? 
-Lukas + // throw std::runtime_error(std::string("Cannot not open list file ") + + // *p_file_name); + // } + + // // read all lines and parse them + // for(;;) + // { + // l_stream >> file_name; + // //:TODO: if(l_stream.badl()) Error() + // // Reading after last token set the fail bit + // if(l_stream.fail()) + // break; + // // we can push_back a std::string as new FileListElem object + // // is created using FileListElem(const std::string&) constructor + // // and logical and physical names are correctly extracted + // featureRepositoryList.front()->mInputQueue.push_back(file_name); + // + // //cycle in the featureRepositoryList + // featureRepositoryList.push(featureRepositoryList.front()); + // featureRepositoryList.pop(); + // } + // l_stream.close(); + // } + // } // AddFileList(const std::string & rFileName) + + + //*************************************************************************** + //*************************************************************************** + void + FeatureRepository:: + Init( + bool swap, + int extLeft, + int extRight, + int targetKind, + int derivOrder, + int* pDerivWinLen, + const char* pCmnPath, + const char* pCmnMask, + const char* pCvnPath, + const char* pCvnMask, + const char* pCvgFile) + { + mSwapFeatures = swap; + mStartFrameExt = extLeft; + mEndFrameExt = extRight; + mTargetKind = targetKind; + mDerivOrder = derivOrder; + mDerivWinLengths = pDerivWinLen; + mpCmnPath = pCmnPath; + mpCmnMask = pCmnMask; + mpCvnPath = pCvnPath; + mpCvnMask = pCvnMask; + mpCvgFile = pCvgFile; + } // Init() + + + //*************************************************************************** + //*************************************************************************** + void + FeatureRepository:: + AddFile(const std::string & rFileName) + { + mInputQueue.push_back(rFileName); + } // AddFile(const std::string & rFileName) + + + //*************************************************************************** + 
//*************************************************************************** + void + FeatureRepository:: + AddFileList(const char* pFileName, const char* pFilter) + { + IStkStream l_stream; + std::string file_name; + Tokenizer file_list(pFileName, ","); + Tokenizer::iterator p_file_name; + + for (p_file_name = file_list.begin(); p_file_name != file_list.end(); ++p_file_name) + { + // get rid of spaces + Trim(*p_file_name); + + // open the file + l_stream.open(p_file_name->c_str(), std::ios::in, pFilter); + + if (!l_stream.good()) + { + //:TODO: + // Warning or error ... Why warning? -Lukas + throw std::runtime_error(std::string("Cannot not open list file ") + + *p_file_name); + } + // read all lines and parse them + for(;;) + { + l_stream >> file_name; + //:TODO: if(l_stream.badl()) Error() + // Reading after last token set the fail bit + if(l_stream.fail()) + break; + // we can push_back a std::string as new FileListElem object + // is created using FileListElem(const std::string&) constructor + // and logical and physical names are correctly extracted + mInputQueue.push_back(file_name); + } + l_stream.close(); + } + } // AddFileList(const std::string & rFileName) + + + //*************************************************************************** + //*************************************************************************** + void + FeatureRepository:: + MoveNext() + { + assert (mInputQueueIterator != mInputQueue.end()); + mInputQueueIterator++; + } // ReadFullMatrix(Matrix<BaseFloat>& rMatrix) + + + //*************************************************************************** + //*************************************************************************** + bool + FeatureRepository:: + ReadFullMatrix(Matrix<BaseFloat>& rMatrix) + { + // clear the matrix + rMatrix.Destroy(); + + // extract index file name + if (!mCurrentIndexFileDir.empty()) + { + char tmp_name[mCurrentIndexFileDir.length() + + mCurrentIndexFileExt.length() + + 
mInputQueueIterator->Physical().length()]; + + MakeHtkFileName(tmp_name, mInputQueueIterator->Physical().c_str(), + mCurrentIndexFileDir.c_str(), mCurrentIndexFileExt.c_str()); + + mCurrentIndexFileName = tmp_name; + } + else + mCurrentIndexFileName = ""; + + //get the 3-letter suffix + int pos_last_three_chars = mInputQueueIterator->Physical().size() - 3; + if (pos_last_three_chars < 0) pos_last_three_chars = 0; + //read the gzipped ascii features + if (mInputQueueIterator->Physical().substr(pos_last_three_chars) == ".gz") { + return ReadGzipAsciiFeatures(*mInputQueueIterator, rMatrix); + } + + // read the matrix and return the result + return ReadHTKFeatures(*mInputQueueIterator, rMatrix); + } // ReadFullMatrix(Matrix<BaseFloat>& rMatrix) + + + + //*************************************************************************** + //*************************************************************************** + bool + FeatureRepository:: + WriteFeatureMatrix(const Matrix<BaseFloat>& rMatrix, const std::string& filename, int targetKind, int samplePeriod) + { + FILE* fp = fopen(filename.c_str(),"w"); + if(NULL == fp) { Error(std::string("Cannot create file:") + filename); return false; } + + WriteHTKFeatures(fp, samplePeriod, targetKind, mSwapFeatures, const_cast<Matrix<BaseFloat>&>(rMatrix)); + + fclose(fp); + + return true; + } + + + //*************************************************************************** + //*************************************************************************** + // private: + int + FeatureRepository:: + ReadHTKHeader() + { + // TODO + // Change this... 
We should read from StkStream + FILE* fp = mStream.fp(); + + if (!fread(&mHeader.mNSamples, sizeof(INT_32), 1, fp)) return -1; + if (!fread(&mHeader.mSamplePeriod, sizeof(INT_32), 1, fp)) return -1; + if (!fread(&mHeader.mSampleSize, sizeof(INT_16), 1, fp)) return -1; + if (!fread(&mHeader.mSampleKind, sizeof(UINT_16), 1, fp)) return -1; + + if (mSwapFeatures) + { + swap4(mHeader.mNSamples); + swap4(mHeader.mSamplePeriod); + swap2(mHeader.mSampleSize); + swap2(mHeader.mSampleKind); + } + + if (mHeader.mSamplePeriod < 0 + || mHeader.mSamplePeriod > 1000000 + || mHeader.mNSamples < 0 + || mHeader.mSampleSize < 0) + { + return -1; + } + + return 0; + } + + + //*************************************************************************** + //*************************************************************************** + // private: + int + FeatureRepository:: + ReadHTKFeature( + BaseFloat* pIn, + size_t feaLen, + bool decompress, + BaseFloat* pScale, + BaseFloat* pBias) + { + FILE* fp = mStream.fp(); + + size_t i; + + if (decompress) + { + INT_16 s; + // BaseFloat pScale = (xmax - xmin) / (2*32767); + // BaseFloat pBias = (xmax + xmin) / 2; + + for (i = 0; i < feaLen; i++) + { + if (fread(&s, sizeof(INT_16), 1, fp) != 1) + return -1; + + if (mSwapFeatures) swap2(s); + pIn[i] = ((BaseFloat)s + pBias[i]) / pScale[i]; + } + + return 0; + } + +#if !DOUBLEPRECISION + if (fread(pIn, sizeof(FLOAT_32), feaLen, fp) != feaLen) + return -1; + + if (mSwapFeatures) + for (i = 0; i < feaLen; i++) + swap4(pIn[i]); +#else + float f; + + for (i = 0; i < feaLen; i++) + { + if (fread(&f, sizeof(FLOAT_32), 1, fp) != 1) + return -1; + + if (mSwapFeatures) + swap4(f); + + pIn[i] = f; + } +#endif + return 0; + } // int ReadHTKFeature + + + + //*************************************************************************** + //*************************************************************************** +/* bool + FeatureRepository:: + ReadHTKFeatures(const std::string& rFileName, Matrix<BaseFloat>& 
rFeatureMatrix) + { + std::string file_name(rFileName); + std::string cmn_file_name; + std::string cvn_file_name; + + int ext_left = mStartFrameExt; + int ext_right = mEndFrameExt; + int from_frame; + int to_frame; + int tot_frames; + int trg_vec_size; + int src_vec_size; + int src_deriv_order; + int lo_src_tgz_deriv_order; + int i; + int j; + int k; + int e; + int coefs; + int trg_E; + int trg_0; + int trg_N; + int src_E; + int src_0; + int src_N; + int comp; + int coef_size; + char* chptr; + + + + // read frame range definition if any ( physical_file.fea[s,e] ) + if ((chptr = strrchr(file_name.c_str(), '[')) == NULL || + ((i=0), sscanf(chptr, "[%d,%d]%n", &from_frame, &to_frame, &i), + chptr[i] != '\0')) + { + chptr = NULL; + } + + if (chptr != NULL) + *chptr = '\0'; + + // Experimental changes... + // if ((strcmp(file_name.c_str(), "-")) + // && (mpLastFileName != NULL) + // && (!strcmp(mpLastFileName, file_name.c_str()))) + // { + // mHeader = mLastHeader; + // } + // else + // { + // if (mpLastFileName) + // { + // //if (mpFp != stdin) + // // fclose(mpFp); + // mStream.close(); + // + // free(mpLastFileName); + // mpLastFileName = NULL; + // } + + if ((file_name != "-" ) + && (!mLastFileName.empty()) + && (mLastFileName == file_name)) + { + mHeader = mLastHeader; + } + else + { + if (!mLastFileName.empty()) + { + mStream.close(); + mLastFileName = ""; + } + + + // open the feature file + mStream.open(file_name.c_str(), ios::binary); + if (!mStream.good()) + { + Error("Cannot open feature file: '%s'", file_name.c_str()); + } + + + if (ReadHTKHeader()) + Error("Invalid HTK header in feature file: '%s'", file_name.c_str()); + + if (mHeader.mSampleKind & PARAMKIND_C) + { + // File is in compressed form, scale and pBias vectors + // are appended after HTK header. 
+ + int coefs = mHeader.mSampleSize/sizeof(INT_16); + mpA = (BaseFloat*) realloc(mpA, coefs * sizeof(BaseFloat)); + mpB = (BaseFloat*) realloc(mpB, coefs * sizeof(BaseFloat)); + if (mpA == NULL || mpB == NULL) Error("Insufficient memory"); + + e = ReadHTKFeature(mpA, coefs, 0, 0, 0); + e |= ReadHTKFeature(mpB, coefs, 0, 0, 0); + + if (e) + Error("Cannot read feature file: '%s'", file_name.c_str()); + + mHeader.mNSamples -= 2 * sizeof(FLOAT_32) / sizeof(INT_16); + } + + // remember current settings + mLastFileName = file_name; + mLastHeader = mHeader; + } + + if (chptr != NULL) + *chptr = '['; + + if (chptr == NULL) + { // Range [s,e] was not specified + from_frame = 0; + to_frame = mHeader.mNSamples-1; + } + + src_deriv_order = PARAMKIND_T & mHeader.mSampleKind ? 3 : + PARAMKIND_A & mHeader.mSampleKind ? 2 : + PARAMKIND_D & mHeader.mSampleKind ? 1 : 0; + src_E = (PARAMKIND_E & mHeader.mSampleKind) != 0; + src_0 = (PARAMKIND_0 & mHeader.mSampleKind) != 0; + src_N = ((PARAMKIND_N & mHeader.mSampleKind) != 0) * (src_E + src_0); + comp = PARAMKIND_C & mHeader.mSampleKind; + + mHeader.mSampleKind &= ~PARAMKIND_C; + + if (mTargetKind == PARAMKIND_ANON) + { + mTargetKind = mHeader.mSampleKind; + } + else if ((mTargetKind & 077) == PARAMKIND_ANON) + { + mTargetKind &= ~077; + mTargetKind |= mHeader.mSampleKind & 077; + } + + trg_E = (PARAMKIND_E & mTargetKind) != 0; + trg_0 = (PARAMKIND_0 & mTargetKind) != 0; + trg_N =((PARAMKIND_N & mTargetKind) != 0) * (trg_E + trg_0); + + coef_size = comp ? sizeof(INT_16) : sizeof(FLOAT_32); + coefs = (mHeader.mSampleSize/coef_size + src_N) / + (src_deriv_order+1) - src_E - src_0; + src_vec_size = (coefs + src_E + src_0) * (src_deriv_order+1) - src_N; + + //Is coefs dividable by 1 + number of derivatives specified in header + if (src_vec_size * coef_size != mHeader.mSampleSize) + { + Error("Invalid HTK header in feature file: '%s'. 
" + "mSampleSize do not match with parmKind", file_name.c_str()); + } + + if (mDerivOrder < 0) + mDerivOrder = src_deriv_order; + + + if ((!src_E && trg_E) || (!src_0 && trg_0) || (src_N && !trg_N) || + (trg_N && !trg_E && !trg_0) || (trg_N && !mDerivOrder) || + (src_N && !src_deriv_order && mDerivOrder) || + ((mHeader.mSampleKind & 077) != (mTargetKind & 077) && + (mHeader.mSampleKind & 077) != PARAMKIND_ANON)) + { + char srcParmKind[64]; + char trgParmKind[64]; + + ParmKind2Str(mHeader.mSampleKind, srcParmKind); + ParmKind2Str(mTargetKind, trgParmKind); + Error("Cannot convert %s to %s", srcParmKind, trgParmKind); + } + + lo_src_tgz_deriv_order = LOWER_OF(src_deriv_order, mDerivOrder); + trg_vec_size = (coefs + trg_E + trg_0) * (mDerivOrder+1) - trg_N; + + i = LOWER_OF(from_frame, mStartFrameExt); + from_frame -= i; + ext_left -= i; + + i = LOWER_OF(mHeader.mNSamples-to_frame-1, mEndFrameExt); + to_frame += i; + ext_right -= i; + + if (from_frame > to_frame || from_frame >= mHeader.mNSamples || to_frame< 0) + Error("Invalid frame range for feature file: '%s'", file_name.c_str()); + + tot_frames = to_frame - from_frame + 1 + ext_left + ext_right; + + // initialize matrix + rFeatureMatrix.Init(tot_frames, trg_vec_size); + + // fill the matrix with features + for (i = 0; i <= to_frame - from_frame; i++) + { + BaseFloat* A = mpA; + BaseFloat* B = mpB; + BaseFloat* mxPtr = rFeatureMatrix[i+ext_left]; + + // seek to the desired position + fseek(mStream.fp(), + sizeof(HtkHeader) + (comp ? 
src_vec_size * 2 * sizeof(FLOAT_32) : 0) + + (from_frame + i) * src_vec_size * coef_size, + SEEK_SET); + + e = ReadHTKFeature(mxPtr, coefs, comp, A, B); + + mxPtr += coefs; + A += coefs; + B += coefs; + + if (src_0 && !src_N) e |= ReadHTKFeature(mxPtr, 1, comp, A++, B++); + if (trg_0 && !trg_N) mxPtr++; + if (src_E && !src_N) e |= ReadHTKFeature(mxPtr, 1, comp, A++, B++); + if (trg_E && !trg_N) mxPtr++; + + for (j = 0; j < lo_src_tgz_deriv_order; j++) + { + e |= ReadHTKFeature(mxPtr, coefs, comp, A, B); + mxPtr += coefs; + A += coefs; + B += coefs; + + if (src_0) e |= ReadHTKFeature(mxPtr, 1, comp, A++, B++); + if (trg_0) mxPtr++; + if (src_E) e |= ReadHTKFeature(mxPtr, 1, comp, A++, B++); + if (trg_E) mxPtr++; + } + + if (e) + Error("Cannot read feature file: '%s' frame %d/%d", file_name.c_str(), + i, to_frame - from_frame + 1); + } + + // From now, coefs includes also trg_0 + trg_E ! + coefs += trg_0 + trg_E; + + // If extension of the matrix to the left or to the right is required, + // perform it here + for (i = 0; i < ext_left; i++) + { + memcpy(rFeatureMatrix[i], + rFeatureMatrix[ext_left], + (coefs * (1+lo_src_tgz_deriv_order) - trg_N) * sizeof(BaseFloat)); + } + + for (i = tot_frames - ext_right; i < tot_frames; i++) + { + memcpy(rFeatureMatrix[i], + rFeatureMatrix[tot_frames - ext_right - 1], + (coefs * (1+lo_src_tgz_deriv_order) - trg_N) * sizeof(BaseFloat)); + } + + // Sentence cepstral mean normalization + if( (mpCmnPath == NULL) + && !(PARAMKIND_Z & mHeader.mSampleKind) + && (PARAMKIND_Z & mTargetKind)) + { + // for each coefficient + for(j=0; j < coefs; j++) + { + BaseFloat norm = 0.0; + for(i=0; i < tot_frames; i++) // for each frame + { + norm += rFeatureMatrix[i][j - trg_N]; + //norm += fea_mx[i*trg_vec_size - trg_N + j]; + } + + norm /= tot_frames; + + for(i=0; i < tot_frames; i++) // for each frame + rFeatureMatrix[i][j - trg_N] -= norm; + //fea_mx[i*trg_vec_size - trg_N + j] -= norm; + } + } + + // Compute missing derivatives + for (; 
src_deriv_order < mDerivOrder; src_deriv_order++) + { + int winLen = mDerivWinLengths[src_deriv_order]; + BaseFloat norm = 0.0; + + for (k = 1; k <= winLen; k++) + { + norm += 2 * k * k; + } + + // for each frame + for (i=0; i < tot_frames; i++) + { + // for each coefficient + for (j=0; j < coefs; j++) + { + //BaseFloat* src = fea_mx + i*trg_vec_size + src_deriv_order*coefs - trg_N + j; + BaseFloat* src = &rFeatureMatrix[i][src_deriv_order*coefs - trg_N + j]; + + *(src + coefs) = 0.0; + + if (i < winLen || i >= tot_frames-winLen) + { // boundaries need special treatment + for (k = 1; k <= winLen; k++) + { + *(src+coefs) += k*(src[ LOWER_OF(tot_frames-1-i,k)*rFeatureMatrix.Stride()] + -src[-LOWER_OF(i, k)*rFeatureMatrix.Stride()]); + } + } + else + { // otherwise use more efficient code + for (k = 1; k <= winLen; k++) + { + *(src+coefs) += k*(src[ k * rFeatureMatrix.Stride()] + -src[-k * rFeatureMatrix.Stride()]); + } + } + *(src + coefs) /= norm; + } + } + } + + mHeader.mNSamples = tot_frames; + mHeader.mSampleSize = trg_vec_size * sizeof(FLOAT_32); + mHeader.mSampleKind = mTargetKind & ~(PARAMKIND_D | PARAMKIND_A | PARAMKIND_T); + + + //////////////////////////////////////////////////////////////////////////// + /////////////// Cepstral mean and variance normalization /////////////////// + //////////////////////////////////////////////////////////////////////////// + //......................................................................... 
+ if (mpCmnPath != NULL + && mpCmnMask != NULL) + { + // retrieve file name + ProcessMask(file_name, mpCmnMask, cmn_file_name); + // add the path correctly + cmn_file_name.insert(0, "/"); + cmn_file_name.insert(0, mpCmnPath); + + // read the file + ReadCepsNormFile(cmn_file_name.c_str(), &mpLastCmnFile, &mpCmn, + mHeader.mSampleKind & ~PARAMKIND_Z, CNF_Mean, coefs); + + // recompute feature values + for (i=0; i < tot_frames; i++) + { + for (j=trg_N; j < coefs; j++) + { + rFeatureMatrix[i][j - trg_N] -= mpCmn[j]; + } + } + } + + mHeader.mSampleKind |= mDerivOrder==3 ? PARAMKIND_D | PARAMKIND_A | PARAMKIND_T : + mDerivOrder==2 ? PARAMKIND_D | PARAMKIND_A : + mDerivOrder==1 ? PARAMKIND_D : 0; + + //......................................................................... + if (mpCvnPath != NULL + && mpCvnMask != NULL) + { + // retrieve file name + ProcessMask(file_name, mpCvnMask, cvn_file_name); + // add the path correctly + cvn_file_name.insert(0, "/"); + cvn_file_name.insert(0, mpCvnPath); + + // read the file + ReadCepsNormFile(cvn_file_name.c_str(), &mpLastCvnFile, &mpCvn, + mHeader.mSampleKind, CNF_Variance, trg_vec_size); + + // recompute feature values + for (i=0; i < tot_frames; i++) + { + for (j=trg_N; j < trg_vec_size; j++) + { + rFeatureMatrix[i][j - trg_N] *= mpCvn[j]; + } + } + } + + //......................................................................... 
  // process the global covariance file
  if (mpCvgFile != NULL)
  {
    ReadCepsNormFile(mpCvgFile, &mpLastCvgFile, &mpCvg,
                     -1, CNF_VarScale, trg_vec_size);

    // recompute feature values
    for (i=0; i < tot_frames; i++)
    {
      for (j=trg_N; j < trg_vec_size; j++)
      {
        rFeatureMatrix[i][j - trg_N] *= mpCvg[j];
      }
    }
  }

  return true;
}
*/

//***************************************************************************
//***************************************************************************





//***************************************************************************
//***************************************************************************
/// Reads the frames of one HTK feature file into rFeatureMatrix and converts
/// them to the configured target kind.
///
/// Steps, in the order they appear below:
///   1. parse an optional "[from,to]" frame range at the end of the physical
///      file name;
///   2. open the file and read its header, or reuse the cached header and the
///      still-open stream when the name equals the last one read;
///   3. derive source/target layout (energy, C0, _N, derivative order) from
///      the header's sample kind and mTargetKind, and reject impossible
///      conversions;
///   4. read the requested frames, widening the range by up to
///      mStartFrameExt/mEndFrameExt real frames and replicating the edge
///      frame for whatever part of the extension the file cannot supply;
///   5. apply per-utterance mean subtraction when the target kind asks for _Z
///      and no CMN path is set, then compute derivative orders the file lacks;
///   6. apply the file-based mean/variance normalisations.
///
/// mHeader is updated to describe the matrix that is returned; mTargetKind
/// and mDerivOrder may be filled in from the file when left unspecified.
///
/// @param rFileNameRecord  record supplying Physical() (file to open) and
///                         Logical() (name matched against the CMN/CVN masks)
/// @param rFeatureMatrix   re-initialised to tot_frames x trg_vec_size
/// @return always true; every failure is reported by exception
/// @throws std::runtime_error on open/read failure, invalid header,
///         impossible kind conversion, invalid frame range, allocation
///         failure, or a CMN mask that yields an empty name
bool
FeatureRepository::
ReadHTKFeatures(const FileListElem&    rFileNameRecord,
                      Matrix<BaseFloat>&        rFeatureMatrix)
{
  std::string           file_name(rFileNameRecord.Physical());
  std::string           cmn_file_name;
  std::string           cvn_file_name;

  int                   ext_left  = mStartFrameExt;
  int                   ext_right = mEndFrameExt;
  int                   from_frame;
  int                   to_frame;
  int                   tot_frames;
  int                   trg_vec_size;
  int                   src_vec_size;
  int                   src_deriv_order;
  int                   lo_src_tgz_deriv_order;
  int                   i;
  int                   j;
  int                   k;
  int                   e;
  int                   coefs;
  int                   trg_E;
  int                   trg_0;
  int                   trg_N;
  int                   src_E;
  int                   src_0;
  int                   src_N;
  int                   comp;
  int                   coef_size;
  char*                 chptr;


  TIMER_START(mTim);

  // read frame range definition if any ( physical_file.fea[s,e] )
  // chptr ends up non-NULL only when the text from the last '[' to the end
  // of the name parses completely as "[%d,%d]" (%n records how far sscanf got)
  // NOTE(review): this writes through the pointer returned by c_str(); the
  // std::string's length is not changed by the write below, so the string
  // comparisons and the copy into mLastFileName still see the full name
  // including the range text - confirm this is intended
  if ((chptr = strrchr((char*)file_name.c_str(), '[')) == NULL ||
      ((i=0), sscanf(chptr, "[%d,%d]%n", &from_frame, &to_frame, &i),
       chptr[i] != '\0'))
  {
    chptr = NULL;
  }

  // temporarily cut the range off so c_str() yields the bare file name;
  // the '[' is restored after the file has been opened
  if (chptr != NULL)
    *chptr = '\0';


  // same name as the previous call (and not "-"): keep the open stream and
  // reuse the cached header
  if ((file_name != "-" )
  &&  (!mLastFileName.empty())
  &&  (mLastFileName == file_name))
  {
    mHeader = mLastHeader;
  }
  else
  {
    if (!mLastFileName.empty())
    {
      mStream.close();
      mLastFileName = "";
    }


    // open the feature file
    mStream.open(file_name.c_str(), std::ios::binary);
    if (!mStream.good())
    {
      throw std::runtime_error(std::string("Cannot open feature file: '") +
          file_name.c_str() + "'");
    }


    if (ReadHTKHeader()) {
      throw std::runtime_error(std::string("Invalid HTK header in feature file: '") +
          file_name.c_str() + "'");
    }

    if (mHeader.mSampleKind & PARAMKIND_C)
    {
      // File is in compressed form, scale and pBias vectors
      // are appended after HTK header.
      coefs = mHeader.mSampleSize/sizeof(INT_16);

      // NOTE(review): if one realloc fails, the previous buffer it was asked
      // to resize is no longer referenced by the member; harmless only
      // because the failure path throws immediately
      mpA = (BaseFloat*) realloc(mpA, coefs * sizeof(BaseFloat));
      mpB = (BaseFloat*) realloc(mpB, coefs * sizeof(BaseFloat));

      if (mpA == NULL || mpB == NULL) {
        throw std::runtime_error("Insufficient memory");
      }

      // both vectors are stored uncompressed (decompress argument is 0)
      e  = ReadHTKFeature(mpA, coefs, 0, 0, 0);
      e |= ReadHTKFeature(mpB, coefs, 0, 0, 0);

      if (e) {
        throw std::runtime_error(std::string("Cannot read feature file: '") +
            file_name.c_str() + "'");
      }

      // the two float vectors occupy as many bytes as 4 compressed frames
      // and are counted in the header's nSamples; remove them
      mHeader.mNSamples -= 2 * sizeof(FLOAT_32) / sizeof(INT_16);
    }

    // remember current settings
    mLastFileName = file_name;
    mLastHeader   = mHeader;
  }

  // put back the '[' that was overwritten above
  if (chptr != NULL) {
    *chptr = '[';
  }

  if (chptr == NULL) {
    // Range [s,e] was not specified
    from_frame = 0;
    to_frame   = mHeader.mNSamples-1;
  }

  // layout of the source vectors as announced by the header's kind flags
  src_deriv_order = PARAMKIND_T & mHeader.mSampleKind ? 3 :
                    PARAMKIND_A & mHeader.mSampleKind ? 2 :
                    PARAMKIND_D & mHeader.mSampleKind ? 1 : 0;
  src_E =  (PARAMKIND_E & mHeader.mSampleKind) != 0;
  src_0 =  (PARAMKIND_0 & mHeader.mSampleKind) != 0;
  // number of static energy/C0 terms suppressed by _N (0, 1 or 2)
  src_N = ((PARAMKIND_N & mHeader.mSampleKind) != 0) * (src_E + src_0);
  comp  =   PARAMKIND_C & mHeader.mSampleKind;

  mHeader.mSampleKind &= ~PARAMKIND_C;

  // an unspecified target kind (or unspecified base kind in the low six
  // bits) is taken from the file
  if (mTargetKind == PARAMKIND_ANON)
  {
    mTargetKind = mHeader.mSampleKind;
  }
  else if ((mTargetKind & 077) == PARAMKIND_ANON)
  {
    mTargetKind &= ~077;
    mTargetKind |= mHeader.mSampleKind & 077;
  }

  trg_E = (PARAMKIND_E & mTargetKind) != 0;
  trg_0 = (PARAMKIND_0 & mTargetKind) != 0;
  trg_N =((PARAMKIND_N & mTargetKind) != 0) * (trg_E + trg_0);

  // coefs = number of base coefficients per derivative block, excluding
  // energy and C0
  coef_size     = comp ? sizeof(INT_16) : sizeof(FLOAT_32);
  coefs         = (mHeader.mSampleSize/coef_size + src_N) /
                  (src_deriv_order+1) - src_E - src_0;
  src_vec_size  = (coefs + src_E + src_0) * (src_deriv_order+1) - src_N;

  //Is coefs dividable by 1 + number of derivatives specified in header
  if (src_vec_size * coef_size != mHeader.mSampleSize)
  {
    throw std::runtime_error(std::string("Invalid HTK header in feature file: '") +
        file_name + "' mSampleSize do not match with parmKind");
  }

  // a negative configured order means "whatever the file provides"
  if (mDerivOrder < 0)
    mDerivOrder = src_deriv_order;


  // conversions that cannot be performed: adding energy/C0 the file lacks,
  // undoing _N, inconsistent _N requests, or a different base kind
  if ((!src_E && trg_E) || (!src_0 && trg_0) || (src_N && !trg_N) ||
      (trg_N && !trg_E && !trg_0) || (trg_N && !mDerivOrder) ||
      (src_N && !src_deriv_order && mDerivOrder) ||
      ((mHeader.mSampleKind & 077) != (mTargetKind & 077) &&
       (mHeader.mSampleKind & 077) != PARAMKIND_ANON))
  {
    char srcParmKind[64];
    char trgParmKind[64];
    memset(srcParmKind,0,64);
    memset(trgParmKind,0,64);

    ParmKind2Str(mHeader.mSampleKind, srcParmKind);
    ParmKind2Str(mTargetKind,       trgParmKind);
    throw std::runtime_error(std::string("Cannot convert ") + srcParmKind +
        " to " + trgParmKind);
  }

  // derivative blocks that can be copied straight from the file
  lo_src_tgz_deriv_order = std::min(src_deriv_order, mDerivOrder);
  trg_vec_size  = (coefs + trg_E + trg_0) * (mDerivOrder+1) - trg_N;

  // satisfy as much of the left/right extension as possible with real
  // frames from the file; ext_left/ext_right keep what must be replicated
  i =  std::min(from_frame, mStartFrameExt);
  from_frame  -= i;
  ext_left     -= i;

  i =  std::min(mHeader.mNSamples-to_frame-1, mEndFrameExt);
  to_frame    += i;
  ext_right    -= i;

  if (from_frame > to_frame || from_frame >= mHeader.mNSamples || to_frame< 0)
    throw std::runtime_error(std::string("Invalid frame range for feature file: '") +
        file_name.c_str() + "'");

  tot_frames = to_frame - from_frame + 1 + ext_left + ext_right;


  TIMER_END(mTim,mTimeOpen);


  // initialize matrix
  rFeatureMatrix.Init(tot_frames, trg_vec_size, false);

  // fill the matrix with features
  for (i = 0; i <= to_frame - from_frame; i++)
  {
    // A/B walk through the scale/bias vectors in step with the source
    // vector; they are only dereferenced when comp is set
    BaseFloat* A      = mpA;
    BaseFloat* B      = mpB;
    BaseFloat* mxPtr  = rFeatureMatrix.pRowData(i+ext_left);

    TIMER_START(mTim);
    // seek to the desired position
    // (header, then the two scale/bias vectors for compressed files, then
    // whole source vectors)
    // NOTE(review): the fseek result is not checked; a failed seek would
    // only surface if the following reads fail
    fseek(mStream.fp(),
          sizeof(HtkHeader) + (comp ? src_vec_size * 2 * sizeof(FLOAT_32) : 0)
          + (from_frame + i) * src_vec_size * coef_size,
          SEEK_SET);
    TIMER_END(mTim,mTimeSeek);

    TIMER_START(mTim);
    // read
    e = ReadHTKFeature(mxPtr, coefs, comp, A, B);
    TIMER_END(mTim,mTimeRead);

    mxPtr += coefs;
    A     += coefs;
    B     += coefs;

    // static C0 / energy: read when present in the source, kept (pointer
    // advanced) only when wanted in the target; otherwise overwritten by
    // the next read
    if (src_0 && !src_N) e |= ReadHTKFeature(mxPtr, 1, comp, A++, B++);
    if (trg_0 && !trg_N) mxPtr++;
    if (src_E && !src_N) e |= ReadHTKFeature(mxPtr, 1, comp, A++, B++);
    if (trg_E && !trg_N) mxPtr++;

    // derivative blocks available in the file and wanted in the target
    for (j = 0; j < lo_src_tgz_deriv_order; j++)
    {
      e |= ReadHTKFeature(mxPtr, coefs, comp, A, B);
      mxPtr += coefs;
      A     += coefs;
      B     += coefs;

      if (src_0) e |= ReadHTKFeature(mxPtr, 1, comp, A++, B++);
      if (trg_0) mxPtr++;
      if (src_E) e |= ReadHTKFeature(mxPtr, 1, comp, A++, B++);
      if (trg_E) mxPtr++;
    }

    if (e) {
      // diagnostic dump to stdout before the exception is raised
      std::cout << mHeader.mNSamples << "\n";
      std::cout << 2 * sizeof(FLOAT_32) / sizeof(INT_16) << "\n";
      std::cout << "from" << from_frame << "to" << to_frame << "i" << i << "\n";

      std::ostringstream s;
      // NOTE(review): the ", s.str()" after the insertion is a comma
      // expression whose result is discarded
      s << i << "/" << to_frame - from_frame + 1, s.str();
      throw std::runtime_error(std::string("Cannot read feature file: '") +
          file_name + "' frame " + s.str());
    }
  }

  // From now, coefs includes also trg_0 + trg_E !
  coefs += trg_0 + trg_E;

  // If extension of the matrix to the left or to the right is required,
  // perform it here
  // (the first/last frame actually read is replicated; only the part of the
  // row that was filled from the file is copied)
  for (i = 0; i < ext_left; i++)
  {
    memcpy(rFeatureMatrix.pRowData(i),
           rFeatureMatrix.pRowData(ext_left),
           (coefs * (1+lo_src_tgz_deriv_order) - trg_N) * sizeof(BaseFloat));
  }

  for (i = tot_frames - ext_right; i < tot_frames; i++)
  {
    memcpy(rFeatureMatrix.pRowData(i),
           rFeatureMatrix.pRowData(tot_frames - ext_right - 1),
           (coefs * (1+lo_src_tgz_deriv_order) - trg_N) * sizeof(BaseFloat));
  }

  // Sentence cepstral mean normalization
  // (only when the target asks for _Z, the file is not already _Z and no
  // CMN file path is configured)
  // NOTE(review): j starts at 0, so when trg_N > 0 the column index
  // j - trg_N is negative for the first iterations, while the file-based
  // CMN loop below starts at j = trg_N - confirm which is intended
  if( (mpCmnPath == NULL)
  && !(PARAMKIND_Z & mHeader.mSampleKind)
  &&  (PARAMKIND_Z & mTargetKind))
  {
    // for each coefficient
    for(j=0; j < coefs; j++)
    {
      BaseFloat norm = 0.0;
      for(i=0; i < tot_frames; i++)      // for each frame
      {
        norm += rFeatureMatrix[i][j - trg_N];
        //norm += fea_mx[i*trg_vec_size - trg_N + j];
      }

      norm /= tot_frames;

      for(i=0; i < tot_frames; i++)      // for each frame
        rFeatureMatrix[i][j - trg_N] -= norm;
        //fea_mx[i*trg_vec_size - trg_N + j] -= norm;
    }
  }

  // Compute missing derivatives
  // each pass derives block (src_deriv_order + 1) from block src_deriv_order
  // as sum_k k*(x[t+k] - x[t-k]) / sum_k 2*k*k over the configured window
  for (; src_deriv_order < mDerivOrder; src_deriv_order++)
  {
    int winLen = mDerivWinLengths[src_deriv_order];
    BaseFloat norm = 0.0;

    for (k = 1; k <= winLen; k++)
    {
      norm += 2 * k * k;
    }

    // for each frame
    for (i=0; i < tot_frames; i++)
    {
      // for each coefficient
      for (j=0; j < coefs; j++)
      {
        // src points at the coefficient in the lower-order block; the
        // result is written coefs elements further on in the same row, and
        // neighbouring frames are reached via the matrix stride
        //BaseFloat* src = fea_mx + i*trg_vec_size + src_deriv_order*coefs - trg_N + j;
        BaseFloat* src = &rFeatureMatrix[i][src_deriv_order*coefs - trg_N + j];

        *(src + coefs) = 0.0;

        if (i < winLen || i >= tot_frames-winLen)
        { // boundaries need special treatment
          // (offsets are clamped so the first/last frame is reused)
          for (k = 1; k <= winLen; k++)
          {
            *(src+coefs) += k*(src[ std::min(tot_frames-1-i,k)*rFeatureMatrix.Stride()]
                               -src[-std::min(i,              k)*rFeatureMatrix.Stride()]);
          }
        }
        else
        { // otherwise use more efficient code
          for (k = 1; k <= winLen; k++)
          {
            *(src+coefs) += k*(src[ k * rFeatureMatrix.Stride()]
                               -src[-k * rFeatureMatrix.Stride()]);
          }
        }
        *(src + coefs) /= norm;
      }
    }
  }

  // the header now describes the returned matrix; the derivative flags are
  // cleared here and set again after the CMN step below
  mHeader.mNSamples    = tot_frames;
  mHeader.mSampleSize  = trg_vec_size * sizeof(FLOAT_32);
  mHeader.mSampleKind  = mTargetKind & ~(PARAMKIND_D | PARAMKIND_A | PARAMKIND_T);


  TIMER_START(mTim);
  ////////////////////////////////////////////////////////////////////////////
  /////////////// Cepstral mean and variance normalization ///////////////////
  ////////////////////////////////////////////////////////////////////////////
  //.........................................................................
  if (mpCmnPath != NULL
  &&  mpCmnMask != NULL)
  {
    // retrieve file name
    ProcessMask(rFileNameRecord.Logical(), mpCmnMask, cmn_file_name);
    // add the path correctly

    if(cmn_file_name == "") {
      throw std::runtime_error("CMN Matching failed");
    }

    cmn_file_name.insert(0, "/");
    cmn_file_name.insert(0, mpCmnPath);

    // read the file
    ReadCepsNormFile(cmn_file_name.c_str(), &mpLastCmnFile, &mpCmn,
        mHeader.mSampleKind & ~PARAMKIND_Z, CNF_Mean, coefs);

    // recompute feature values
    // (mean subtraction covers the static block only)
    for (i=0; i < tot_frames; i++)
    {
      for (j=trg_N; j < coefs; j++)
      {
        rFeatureMatrix[i][j - trg_N] -= mpCmn[j];
      }
    }
  }

  // restore the derivative flags that match the produced derivative order
  mHeader.mSampleKind |= mDerivOrder==3 ? PARAMKIND_D | PARAMKIND_A | PARAMKIND_T :
                         mDerivOrder==2 ? PARAMKIND_D | PARAMKIND_A :
                         mDerivOrder==1 ? PARAMKIND_D : 0;

  //.........................................................................
+ if (mpCvnPath != NULL + && mpCvnMask != NULL) + { + // retrieve file name + ProcessMask(rFileNameRecord.Logical(), mpCvnMask, cvn_file_name); + // add the path correctly + cvn_file_name.insert(0, "/"); + cvn_file_name.insert(0, mpCvnPath); + + // read the file + ReadCepsNormFile(cvn_file_name.c_str(), &mpLastCvnFile, &mpCvn, + mHeader.mSampleKind, CNF_Variance, trg_vec_size); + + // recompute feature values + for (i=0; i < tot_frames; i++) + { + for (j=trg_N; j < trg_vec_size; j++) + { + rFeatureMatrix[i][j - trg_N] *= mpCvn[j]; + } + } + } + + //......................................................................... + // process the global covariance file + if (mpCvgFile != NULL) + { + ReadCepsNormFile(mpCvgFile, &mpLastCvgFile, &mpCvg, + -1, CNF_VarScale, trg_vec_size); + + // recompute feature values + for (i=0; i < tot_frames; i++) + { + for (j=trg_N; j < trg_vec_size; j++) + { + rFeatureMatrix[i][j - trg_N] *= mpCvg[j]; + } + } + } + + TIMER_END(mTim,mTimeNormalize); + + return true; + } + + + //*************************************************************************** + //*************************************************************************** + int + FeatureRepository:: + ReadParmKind(const char *str, bool checkBrackets) + { + unsigned int i; + int parmKind =0; + int slen = strlen(str); + + if (checkBrackets) + { + if (str[0] != '<' || str[slen-1] != '>') return -1; + str++; slen -= 2; + } + + for (; slen >= 0 && str[slen-2] == '_'; slen -= 2) + { + parmKind |= str[slen-1] == 'E' ? PARAMKIND_E : + str[slen-1] == 'N' ? PARAMKIND_N : + str[slen-1] == 'D' ? PARAMKIND_D : + str[slen-1] == 'A' ? PARAMKIND_A : + str[slen-1] == 'C' ? PARAMKIND_C : + str[slen-1] == 'Z' ? PARAMKIND_Z : + str[slen-1] == 'K' ? PARAMKIND_K : + str[slen-1] == '0' ? PARAMKIND_0 : + str[slen-1] == 'V' ? PARAMKIND_V : + str[slen-1] == 'T' ? 
PARAMKIND_T : -1; + + if (parmKind == -1) return -1; + } + + for (i = 0; i < sizeof(mpParmKindNames) / sizeof(char*); i++) + { + if (!strncmp(str, mpParmKindNames[i], slen)) + return parmKind | i; + } + return -1; + } + + + + + //*************************************************************************** + //*************************************************************************** + int + FeatureRepository:: + WriteHTKHeader (FILE * pOutFp, HtkHeader header, bool swap) + { + int cc; + + if (swap) { + swap4(header.mNSamples); + swap4(header.mSamplePeriod); + swap2(header.mSampleSize); + swap2(header.mSampleKind); + } + + fseek (pOutFp, 0L, SEEK_SET); + cc = fwrite(&header, sizeof(HtkHeader), 1, pOutFp); + + if (swap) { + swap4(header.mNSamples); + swap4(header.mSamplePeriod); + swap2(header.mSampleSize); + swap2(header.mSampleKind); + } + + return cc == 1 ? 0 : -1; + } + + + //*************************************************************************** + //*************************************************************************** + int + FeatureRepository:: + WriteHTKFeature( + FILE * pOutFp, + FLOAT * pOut, + size_t feaLen, + bool swap, + bool compress, + FLOAT* pScale, + FLOAT* pBias) + { + size_t i; + size_t cc = 0; + + + if (compress) + { + INT_16 s; + + for (i = 0; i < feaLen; i++) + { + s = pOut[i] * pScale[i] - pBias[i]; + if (swap) + swap2(s); + cc += fwrite(&s, sizeof(INT_16), 1, pOutFp); + } + + } else { + #if !DOUBLEPRECISION + if (swap) + for (i = 0; i < feaLen; i++) + swap4(pOut[i]); + + cc = fwrite(pOut, sizeof(FLOAT_32), feaLen, pOutFp); + + if (swap) + for (i = 0; i < feaLen; i++) + swap4(pOut[i]); + #else + FLOAT_32 f; + + for (i = 0; i < feaLen; i++) + { + f = pOut[i]; + if (swap) + swap4(f); + cc += fwrite(&f, sizeof(FLOAT_32), 1, pOutFp); + } + #endif + } + return cc == feaLen ? 
0 : -1; + } + + //*************************************************************************** + //*************************************************************************** + int + FeatureRepository:: + WriteHTKFeatures( + FILE * pOutFp, + FLOAT * pOut, + int nCoeffs, + int nSamples, + int samplePeriod, + int targetKind, + bool swap) + { + HtkHeader header; + int i, j; + FLOAT *pScale = NULL; + FLOAT *pBias = NULL; + + header.mNSamples = nSamples + ((targetKind & PARAMKIND_C) ? 2 * sizeof(FLOAT_32) / sizeof(INT_16) : 0); + header.mSamplePeriod = samplePeriod; + header.mSampleSize = nCoeffs * ((targetKind & PARAMKIND_C) ? sizeof(INT_16) : sizeof(FLOAT_32));; + header.mSampleKind = targetKind; + + WriteHTKHeader (pOutFp, header, swap); + + if(targetKind & PARAMKIND_C) { + pScale = (FLOAT*) malloc(nCoeffs * sizeof(FLOAT)); + pBias = (FLOAT*) malloc(nCoeffs * sizeof(FLOAT)); + if (pScale == NULL || pBias == NULL) Error("Insufficient memory"); + + for(i = 0; i < nCoeffs; i++) { + float xmin, xmax; + xmin = xmax = pOut[i]; + for(j = 1; j < nSamples; j++) { + if(pOut[j*nCoeffs+i] > xmax) xmax = pOut[j*nCoeffs+i]; + if(pOut[j*nCoeffs+i] < xmin) xmin = pOut[j*nCoeffs+i]; + } + pScale[i] = (2*32767) / (xmax - xmin); + pBias[i] = pScale[i] * (xmax + xmin) / 2; + + + } + if (WriteHTKFeature(pOutFp, pScale, nCoeffs, swap, false, 0, 0) + || WriteHTKFeature(pOutFp, pBias, nCoeffs, swap, false, 0, 0)) { + return -1; + } + } + for(j = 0; j < nSamples; j++) { + if (WriteHTKFeature(pOutFp, &pOut[j*nCoeffs], nCoeffs, swap, targetKind & PARAMKIND_C, pScale, pBias)) { + return -1; + } + } + return 0; + } + + + //*************************************************************************** + //*************************************************************************** + int + FeatureRepository:: + WriteHTKFeatures( + FILE * pOutFp, + int samplePeriod, + int targetKind, + bool swap, + Matrix<BaseFloat>& rFeatureMatrix) + { + HtkHeader header; + size_t i, j; + FLOAT *p_scale = NULL; + FLOAT 
*p_bias = NULL; + size_t n_samples = rFeatureMatrix.Rows(); + size_t n_coeffs = rFeatureMatrix.Cols(); + + header.mNSamples = n_samples + ((targetKind & PARAMKIND_C) ? 2 * sizeof(FLOAT_32) / sizeof(INT_16) : 0); + header.mSamplePeriod = samplePeriod; + header.mSampleSize = n_coeffs * ((targetKind & PARAMKIND_C) ? sizeof(INT_16) : sizeof(FLOAT_32));; + header.mSampleKind = targetKind; + + WriteHTKHeader (pOutFp, header, swap); + + if(targetKind & PARAMKIND_C) { + p_scale = (FLOAT*) malloc(n_coeffs * sizeof(FLOAT)); + p_bias = (FLOAT*) malloc(n_coeffs * sizeof(FLOAT)); + if (p_scale == NULL || p_bias == NULL) Error("Insufficient memory"); + + for(i = 0; i < n_coeffs; i++) { + float xmin, xmax; + xmin = xmax = rFeatureMatrix[0][i]; + + for(j = 1; j < n_samples; j++) { + if(rFeatureMatrix[j][i] > xmax) xmax = rFeatureMatrix[j][i]; + if(rFeatureMatrix[j][i] < xmin) xmin = rFeatureMatrix[j][i]; + } + + p_scale[i] = (2*32767) / (xmax - xmin); + p_bias[i] = p_scale[i] * (xmax + xmin) / 2; + } + + if (WriteHTKFeature(pOutFp, p_scale, n_coeffs, swap, false, 0, 0) + || WriteHTKFeature(pOutFp, p_bias, n_coeffs, swap, false, 0, 0)) { + return -1; + } + } + + for(j = 0; j < n_samples; j++) { + if (WriteHTKFeature(pOutFp, rFeatureMatrix[j].pData(), n_coeffs, swap, targetKind & PARAMKIND_C, p_scale, p_bias)) { + return -1; + } + } + + return 0; + } + + //*************************************************************************** + //*************************************************************************** + + + bool + FeatureRepository:: + ReadGzipAsciiFeatures(const FileListElem& rFileNameRecord, Matrix<BaseFloat>& rFeatureMatrix) + { + //build the command + std::string cmd("gunzip -c "); cmd += rFileNameRecord.Physical(); + + //define buffer + const int buf_size=262144; + char buf[buf_size]; + char vbuf[2*buf_size]; + + TIMER_START(mTim); + //open the pipe + FILE* fp = popen(cmd.c_str(),"r"); + if(fp == NULL) { + //2nd try... 
+ Warning(std::string("2nd try to open pipe: ")+cmd); + sleep(5); + fp = popen(cmd.c_str(),"r"); + if(fp == NULL) { + KALDI_ERR << "Cannot open pipe: " << cmd; + } + } + setvbuf(fp,vbuf,_IOFBF,2*buf_size); + TIMER_END(mTim,mTimeOpen); + + //string will stay allocated across calls + static std::string line; line.resize(0); + + //define matrix storage + static int cols = 131072; + std::list<std::vector<BaseFloat> > matrix(1); + matrix.front().reserve(cols); + + //read all the lines to a vector + int line_ctr=1; + while(1) { + TIMER_START(mTim); + if(NULL == fgets(buf,buf_size,fp)) break; + TIMER_END(mTim,mTimeRead); + + line += buf; + if(*(line.rbegin()) == '\n' || feof(fp)) { + //parse the line of numbers + TIMER_START(mTim); + const char* ptr = line.c_str(); + char* end; + while(1) { + //skip whitespace + while(isspace(*ptr)) ptr++; + if(*ptr == 0) break; + //check that a number follows + switch(*ptr) { + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + case '.': case '+': case '-': + break; + default : KALDI_ERR << "A number was expected:" << ptr + << " reading from" << cmd; + exit(1); + } + //read a number + BaseFloat val = strtof(ptr,&end); ptr=end; + matrix.back().push_back(val); + } + TIMER_END(mTim,mTimeNormalize); + //we have the line of numbers, insert empty row to matrix + if(matrix.back().size() > 0 && !feof(fp)) { + matrix.push_back(std::vector<BaseFloat>()); + matrix.back().reserve(matrix.front().size()); + } + //dispose the current line + line.resize(0);//but stay allocated... 
+ line_ctr++; + } + } + if(matrix.back().size() == 0) matrix.pop_back(); + + //get matrix dimensions + int rows = matrix.size(); + /*int*/ cols = matrix.front().size(); + + //define interators + std::list<std::vector<BaseFloat> >::iterator it_r; + std::vector<BaseFloat>::iterator it_c; + + //check that all lines have same size + int i; + for(i=0,it_r=matrix.begin(); it_r != matrix.end(); ++i,++it_r) { + if(it_r->size() != cols) { + KALDI_ERR << "All rows must have same dimension, 1st line cols: " << cols + << ", " << i << "th line cols: " << it_r->size(); + } + } + + //copy data to matrix + TIMER_START(mTim); + rFeatureMatrix.Init(rows,cols); + int r,c; + for(r=0,it_r=matrix.begin(); it_r!=matrix.end(); ++r,++it_r) { + for(c=0,it_c=it_r->begin(); it_c!=it_r->end(); ++c,++it_c) { + rFeatureMatrix(r,c) = *it_c; + } + } + TIMER_END(mTim,mTimeSeek); + + //close the pipe + if(pclose(fp) == -1) { + KALDI_ERR << "Cannot close pipe: " << cmd; + } + + return true; + } + + + //*************************************************************************** + //*************************************************************************** + +} // namespace TNet diff --git a/htk_io/src/KaldiLib/Features.h b/htk_io/src/KaldiLib/Features.h new file mode 100644 index 0000000..0980ab6 --- /dev/null +++ b/htk_io/src/KaldiLib/Features.h @@ -0,0 +1,597 @@ +// +// C++ Interface: %{MODULE} +// +// Description: +// +// +// Author: %{AUTHOR} <%{EMAIL}>, (C) %{YEAR} +// +// Copyright: See COPYING file that comes with this distribution +// +// + +#ifndef TNet_Features_h +#define TNet_Features_h + +//***************************************************************************** +//***************************************************************************** +// Standard includes +// +#include <list> +#include <queue> +#include <string> + + +//***************************************************************************** +//***************************************************************************** 
+// Specific includes +// +#include "Common.h" +#include "Matrix.h" +#include "StkStream.h" +#include "Types.h" +#include "Timer.h" + + + +// we need these for reading and writing +#define UINT_16 unsigned short +#define UINT_32 unsigned +#define INT_16 short +#define INT_32 int +#define FLOAT_32 float +#define DOUBLE_64 double + + +#define PARAMKIND_WAVEFORM 0 +#define PARAMKIND_LPC 1 +#define PARAMKIND_LPREFC 2 +#define PARAMKIND_LPCEPSTRA 3 +#define PARAMKIND_LPDELCEP 4 +#define PARAMKIND_IREFC 5 +#define PARAMKIND_MFCC 6 +#define PARAMKIND_FBANK 7 +#define PARAMKIND_MELSPEC 8 +#define PARAMKIND_USER 9 +#define PARAMKIND_DISCRETE 10 +#define PARAMKIND_PLP 11 +#define PARAMKIND_ANON 12 + +#define PARAMKIND_E 0000100 /// has energy +#define PARAMKIND_N 0000200 /// absolute energy suppressed +#define PARAMKIND_D 0000400 /// has delta coefficients +#define PARAMKIND_A 0001000 /// has acceleration coefficients +#define PARAMKIND_C 0002000 /// is compressed +#define PARAMKIND_Z 0004000 /// has zero mean static coef. +#define PARAMKIND_K 0010000 /// has CRC checksum +#define PARAMKIND_0 0020000 /// has 0'th cepstral coef. +#define PARAMKIND_V 0040000 /// has VQ codebook index +#define PARAMKIND_T 0100000 /// has triple delta coefficients + + +//***************************************************************************** +//***************************************************************************** +// Code ... 
//

namespace TNet
{

  /** **************************************************************************
   ** **************************************************************************
   *  @brief One record of the feature file list: logical name, physical name
   *         and a weight. The fields are filled by the constructor from
   *         @c rFileName (its implementation is not part of this header).
   */
  class FileListElem
  {
  private:
    std::string         mLogical;     ///< Logical file name representation
    std::string         mPhysical;    ///< Physical file name representation
    float               mWeight;      ///< Weight of the record

  public:
    FileListElem(const std::string & rFileName);
    ~FileListElem() {}

    /// Returns the logical file name
    const std::string &
    Logical() const { return mLogical; }

    /// Returns the physical file name
    const std::string &
    Physical() const { return mPhysical; }

    /// Returns the weight of the record
    const float&
    Weight() const { return mWeight; }
  };

  /** *************************************************************************
   *  @brief Queue of feature files together with the parameters and the
   *         state needed to read (and write) them in HTK format
   */
  class FeatureRepository
  {
  public:
    /**
     * @brief HTK parameter file header (see HTK manual)
     */
    struct HtkHeader
    {
      int    mNSamples;
      int    mSamplePeriod;
      short  mSampleSize;
      short  mSampleKind;

      /// Defaults: no samples, period 100000, sample kind 12 (PARAMKIND_ANON)
      HtkHeader()
        : mNSamples(0),mSamplePeriod(100000),mSampleSize(0),mSampleKind(12)
      { }
    };


    /**
     *  @brief Extension of the HTK header
     */
    struct HtkHeaderExt
    {
      int mHeaderSize;
      int mVersion;
      int mSampSize;
    };


    /**
     * @brief Normalization file type
     */
    enum CNFileType
    {
      CNF_Mean,
      CNF_Variance,
      CNF_VarScale
    };


    static int
    ReadParmKind(const char *pStr, bool checkBrackets);

    static int
    ParmKind2Str(unsigned parmKind, char *pOutstr);

    static void
    ReadCepsNormFile(
      const char*   pFileName,
      char**        lastFile,
      BaseFloat**   vecBuff,
      int           sampleKind,
      CNFileType    type,
      int           coefs);

    /// Table of 13 parameter kind names, each at most 15 characters long
    static const char mpParmKindNames[13][16];



  //////////////////////////////////////////////////////////////////////////////
  //  PUBLIC SECTION
  //////////////////////////////////////////////////////////////////////////////
  public:
    /// Iterates through the list of feature file records
    typedef   std::list<FileListElem>::iterator  ListIterator;

    // some params for loading features
    // (the constructors leave the first five uninitialized; the copy
    //  constructor passes all of them to Init())
    bool                        mSwapFeatures;
    int                         mStartFrameExt;
    int                         mEndFrameExt;
    int                         mTargetKind;
    int                         mDerivOrder;
    int*                        mDerivWinLengths;
    const char*                 mpCvgFile;
    //:TODO: get rid of these
    const char*                 mpCmnPath;
    const char*                 mpCmnMask;
    const char*                 mpCvnPath;
    const char*                 mpCvnMask;

    int                         mTrace;


    // Constructors and destructors
    /**
     * @brief Default constructor that creates an empty repository
     */
    FeatureRepository() : mDerivWinLengths(NULL), mpCvgFile(NULL),
      mpCmnPath(NULL), mpCmnMask(NULL), mpCvnPath(NULL), mpCvnMask(NULL),
      mTrace(0),
      mpLastFileName(NULL), mLastFileName(""), mpLastCmnFile (NULL),
      mpLastCvnFile (NULL), mpLastCvgFile (NULL), mpCmn(NULL),
      mpCvn(NULL), mpCvg(NULL), mpA(NULL), mpB(NULL),
      mTimeOpen(0), mTimeSeek(0), mTimeRead(0), mTimeNormalize(0)
    {
      mInputQueueIterator = mInputQueue.end();
    }

    /**
     * @brief Copy constructor which copies a filled repository
     *
     * Copies the input queue and the header and re-runs Init() with the
     * parameters of the original. The cached buffers, the timers and mTrace
     * are not copied; they start from the same values as in the default
     * constructor.
     */
    FeatureRepository(const FeatureRepository& ori)
     : mDerivWinLengths(NULL), mpCvgFile(NULL),
      mpCmnPath(NULL), mpCmnMask(NULL), mpCvnPath(NULL), mpCvnMask(NULL),
      mTrace(0),
      mpLastFileName(NULL), mLastFileName(""), mpLastCmnFile (NULL),
      mpLastCvnFile (NULL), mpLastCvgFile (NULL), mpCmn(NULL),
      mpCvn(NULL), mpCvg(NULL), mpA(NULL), mpB(NULL),
      mTimeOpen(0), mTimeSeek(0), mTimeRead(0), mTimeNormalize(0)
    {
      //copy all the data from the input queue
      mInputQueue = ori.mInputQueue;

      //initialize like the original
      Init(
        ori.mSwapFeatures,
        ori.mStartFrameExt,
        ori.mEndFrameExt,
        ori.mTargetKind,
        ori.mDerivOrder,
        ori.mDerivWinLengths,
        ori.mpCmnPath,
        ori.mpCmnMask,
        ori.mpCvnPath,
        ori.mpCvnMask,
        ori.mpCvgFile);

      //set on the end
      mInputQueueIterator = mInputQueue.end();
      //copy default header values
      mHeader = ori.mHeader;
    }


    /**
     * @brief Destroys the repository
     *
     * Frees mpA and mpB, clears the input queue and, when bit 4 of mTrace is
     * set, prints the accumulated timers.
     */
    ~FeatureRepository()
    {
      if (NULL != mpA) {
        free(mpA);
      }

      if (NULL != mpB) {
        free(mpB);
      }
      //remove all entries
      mInputQueue.clear();

      if(mTrace&4) {
        std::cout << "[FeatureRepository -- open:" << mTimeOpen << "s seek:" << mTimeSeek << "s read:" << mTimeRead << "s normalize:" << mTimeNormalize << "s]\n";
      }

    }


    /**
     * @brief Initializes the object using the given parameters
     *
     * @param swap          Boolean value specifies whether to swap bytes
     *                      when reading file or not.
     * @param extLeft       Features read from file are extended with extLeft
     *                      initial frames. Normally, these frames are
     *                      repetitions of the first feature frame in file
     *                      (with its derivative, if derivatives are present in
     *                      the file).  However, if segment of feature frames
     *                      is extracted according to range specification, the
     *                      true feature frames from beyond the segment boundary
     *                      are used, wherever it is possible. Note that value
     *                      of extLeft can be also negative. In such case
     *                      corresponding number of initial frames is discarded.
     * @param extRight      The parameter is complementary to parameter extLeft
     *                      and has obvious meaning. (Controls extensions over
     *                      the last frame, last frame from file is repeated
     *                      only if necessary).
     * @param targetKind    The parameter is used to check whether
     *                      pHeader->mSampleKind matches the required targetKind
     *                      and to control suppression of 0'th cepstral or
     *                      energy coefficients according to modifiers _E, _0,
     *                      and _N. Modifiers _D, _A and _T are ignored;
     *                      Computation of derivatives is controlled by parameters
     *                      derivOrder and derivWinLen. Value PARAMKIND_ANON
     *                      ensures that function does not result in targetKind
     *                      mismatch error and causes no _E or _0 suppression.
     * @param derivOrder    Final features will be augmented with their
     *                      derivatives up to 'derivOrder' order. If 'derivOrder'
     *                      is negative value, no new derivatives are appended
     *                      and derivatives that are already present in feature
     *                      file are preserved.  Straight features are considered
     *                      to be of zero order. If some derivatives are already
     *                      present in feature file, these are not computed
     *                      again, only higher order derivatives are appended
     *                      if required. Note, that HTK feature file cannot
     *                      contain higher order derivatives (e.g. double delta)
     *                      without containing lower ones (e.g. delta).
     *                      Derivatives present in feature file that are of
     *                      higher order than is required are discarded.
     *                      Derivatives are computed in the final stage from
     *                      (extracted segment of) feature frames possibly
     *                      extended by repeated frames. Derivatives are
     *                      computed using the same formula that is employed
     *                      also by HTK tools. Lengths of windows used for
     *                      computation of derivatives are passed in parameter
     *                      derivWinLen. To compute derivatives for frames close
     *                      to boundaries, frames before the first and after the
     *                      last frame (of the extracted segment) are considered
     *                      to be (yet another) repetitions of the first and the
     *                      last frame, respectively. If the segment of frames
     *                      is extracted according to range specification and
     *                      parameters extLeft and extRight are set to zero, the
     *                      first and the last frames of the segment are
     *                      considered to be repeated, even though the true feature
     *                      frames from beyond the segment boundary can be
     *                      available in the file. Therefore, segment extracted
     *                      from features that were before augmented with
     *                      derivatives will differ
     *                      from the same segment augmented with derivatives by
     *                      this function. Difference will be of course only on
     *                      boundaries and only in derivatives. This "incorrect"
     *                      behavior was chosen to fully simulate behavior of
     *                      HTK tools. To obtain more correct computation of
     *                      derivatives, use parameters extLeft and extRight,
     *                      which correctly extend segment with the true frames
     *                      (if possible) and in resulting feature matrix ignore
     *                      first extLeft and last extRight frames. For this
     *                      purpose, both extLeft and extRight should be set to
     *                      sum of all values in the array derivWinLen.
     * @param pDerivWinLen  Array of size derivOrder specifying lengths of
     *                      windows used for computation of derivatives.
     *                      Individual values represent one side context
     *                      used in the computation. Each window length is
     *                      therefore twice the value from array plus one.
     *                      Value at index zero specifies window length for first
     *                      order derivatives (delta), higher indices
     *                      correspond to higher order derivatives.
     * @param pCmnPath      Cepstral mean normalization path
     * @param pCmnMask      Cepstral mean normalization mask
     * @param pCvnPath      Cepstral variance normalization path
     * @param pCvnMask      Cepstral variance normalization mask
     * @param pCvgFile      Global variance file to be parsed
     *
     * The given parameters are necessary for proper feature extraction
     */
    void
    Init(
        bool                  swap,
        int                   extLeft,
        int                   extRight,
        int                   targetKind,
        int                   derivOrder,
        int*                  pDerivWinLen,
        const char*           pCmnPath,
        const char*           pCmnMask,
        const char*           pCvnPath,
        const char*           pCvnMask,
        const char*           pCvgFile);


    /// Sets the trace flags (bit 4 makes the destructor print the timers)
    void Trace(int trace)
    { mTrace = trace; }

    /**
     * @brief Returns a reference to the current file header
     */
    const HtkHeader&
    CurrentHeader() const
    { return mHeader; }

    /**
     * @brief Returns a reference to the current extended file header
     */
    const HtkHeaderExt&
    CurrentHeaderExt() const
    { return mHeaderExt; }

    /**
     * @brief Returns the current file details
     *
     * @return Reference to the queue iterator that points to the current
     *         @c FileListElem
     *
     * Logical and physical file names are stored in @c FileListElem class
     */
    const std::list<FileListElem>::iterator&
    pCurrentRecord() const
    { return mInputQueueIterator; }


    /**
     * @brief Returns the following file details
     *
     * @return Reference to a queue iterator to a @c FileListElem
     *
     * Logical and physical file names are stored in @c FileListElem class
     *
     * NOTE(review): returns the very same iterator as pCurrentRecord();
     * whether that iterator already points to the following record depends
     * on when MoveNext() is called -- confirm against the implementation
     */
    const std::list<FileListElem>::iterator&
    pFollowingRecord() const
    { return mInputQueueIterator; }


    /// Moves the current position back to the first record of the queue
    void
    Rewind()
    { mInputQueueIterator = mInputQueue.begin(); }


    /**
     * @brief Adds a single feature file to the repository
     * @param rFileName file to read features from
     */
    void
    AddFile(const std::string & rFileName);


    /**
     * @brief Adds a list of feature files to the repository
     * @param pFileName feature list file to read from
     * @param pFilter   optional filter, empty by default
     */
    void
    AddFileList(const char* pFileName, const char* pFilter = "");


    /// Returns the current record; must not be called when EndOfList() holds
    /// (the iterator is dereferenced without a check)
    const FileListElem&
    Current() const
    { return *mInputQueueIterator; }


    /**
     * @brief Moves to the next record
     */
    void
    MoveNext();

    /**
     * @brief Reads full feature matrix from a feature file
     * @param rMatrix matrix to be created and filled with read data
     * @return bool; the exact meaning is defined by the implementation, which
     *         is not part of this header (TODO confirm)
     */
    bool
    ReadFullMatrix(Matrix<BaseFloat>& rMatrix);

    /**
     * @brief Writes a feature matrix to a file
     * @param rMatrix      features to write
     * @param filename     name of the output file
     * @param targetKind   parameter kind of the written features
     * @param samplePeriod sample period of the written features
     */
    bool
    WriteFeatureMatrix(const Matrix<BaseFloat>& rMatrix, const std::string& filename, int targetKind, int samplePeriod);

    /// Returns the number of records in the input queue
    size_t
    QueueSize() const {return mInputQueue.size(); }

    /**
     * @brief Reads feature vectors from a feature file
     * @param rMatrix matrix to be (only!) filled with read data.
     * @return number of successfully read feature vectors
     *
     * The function tries to fill @c rMatrix with feature vectors coming from
     * the current stream. If there are fewer vectors left in the stream,
     * they are used and the true number of successfully read vectors is returned.
     */
    int
    ReadPartialMatrix(Matrix<BaseFloat>& rMatrix);

    /**
     * @brief Filters the records of this repository based on HTK logical name
     * masking. If pFilter equals to NULL, all source repository entries are
     * copied to rOut repository.
     *
     * @param pFilter HTK mask that defines the filter
     * @param pValue Filter value
     * @param rOut Reference to the new FeatureRepository which will be filled
     *             with the matching records
     */
    void
    HtkFilter(const char* pFilter, const char* pValue, FeatureRepository& rOut);


    /**
     * @brief Filters the records of this repository based on HTK logical name
     * masking and returns list of unique names. If pFilter equals to NULL,
     * single name "default" is returned.
     *
     * @param pFilter HTK mask that defines the filter
     * @param rOut Reference to the list of results (std::list< std::string >)
     */
    void
    HtkSelection(const char* pFilter, std::list< std::string >& rOut);


    /**
     * @brief Returns true if there are no feature files left on input
     */
    bool
    EndOfList() const
    { return mInputQueueIterator == mInputQueue.end(); }

    /// Returns the name of the current index file
    const std::string&
    CurrentIndexFileName() const
    { return mCurrentIndexFileName; }

    friend
    void
    AddFileListToFeatureRepositories(
      const char* pFileName,
      const char* pFilter,
      std::queue<FeatureRepository *> &featureRepositoryList);


////////////////////////////////////////////////////////////////////////////////
//  PRIVATE SECTION
////////////////////////////////////////////////////////////////////////////////
  private:
    /// List (queue) of input feature files
    std::list<FileListElem>             mInputQueue;
    std::list<FileListElem>::iterator   mInputQueueIterator;

    std::string                         mCurrentIndexFileName;
    std::string                         mCurrentIndexFileDir;
    std::string                         mCurrentIndexFileExt;

    /// current stream
    IStkStream                          mStream;

    // stores feature file's HTK header
    HtkHeader                           mHeader;
    HtkHeaderExt                        mHeaderExt;


    // this group of variables serves for working with the same physical
    // file name more than once
    char*                               mpLastFileName;
    std::string                         mLastFileName;
    char*                               mpLastCmnFile;
    char*                               mpLastCvnFile;
    char*                               mpLastCvgFile;
    BaseFloat*                          mpCmn;
    BaseFloat*                          mpCvn;
    BaseFloat*                          mpCvg;
    HtkHeader                           mLastHeader;
    BaseFloat*                          mpA;   // released with free() in the destructor
    BaseFloat*                          mpB;   // released with free() in the destructor



    // timer and accumulated times, printed by the destructor when traced
    Timer                               mTim;
    double                              mTimeOpen;
    double                              mTimeSeek;
    double                              mTimeRead;
    double                              mTimeNormalize;


    // Reads HTK feature file header
    int
    ReadHTKHeader();

    // Reads one feature vector of feaLen values
    int
    ReadHTKFeature(BaseFloat*    pIn,
      size_t    feaLen,
      bool      decompress,
      BaseFloat*    pScale,
      BaseFloat*    pBias);


    bool
    ReadHTKFeatures(const std::string& rFileName, Matrix<BaseFloat>& rFeatureMatrix);

    bool
    ReadHTKFeatures(const FileListElem& rFileNameRecord, Matrix<BaseFloat>& rFeatureMatrix);


    // Writes the HTK header
    int
    WriteHTKHeader  (FILE* fp_out, HtkHeader header, bool swap);

    // Writes one feature vector; a nonzero result is treated as failure by
    // WriteHTKFeatures
    int
    WriteHTKFeature (FILE* fp_out, FLOAT *out, size_t fea_len, bool swap, bool compress, FLOAT* pScale, FLOAT* pBias);

    // Writes header and all frames from a raw buffer; 0 on success, -1 on failure
    int
    WriteHTKFeatures(FILE* pOutFp, FLOAT * pOut, int nCoeffs, int nSamples, int samplePeriod, int targetKind, bool swap);

    // Writes header and all frames from a matrix; 0 on success, -1 on failure
    int
    WriteHTKFeatures(
      FILE *  pOutFp,
      int     samplePeriod,
      int     targetKind,
      bool    swap,
      Matrix<BaseFloat>&        rFeatureMatrix
    );

    // Reads a gzipped ASCII matrix through a "gunzip -c" pipe
    bool
    ReadGzipAsciiFeatures(const FileListElem& rFileNameRecord, Matrix<BaseFloat>& rFeatureMatrix);

  }; // class FeatureRepository

} //namespace TNet

#endif // TNet_Features_h
diff --git a/htk_io/src/KaldiLib/Labels.cc b/htk_io/src/KaldiLib/Labels.cc
new file mode 100644
index 0000000..8b04cde
--- /dev/null
+++ b/htk_io/src/KaldiLib/Labels.cc
@@ -0,0 +1,612 @@
#include "Labels.h"
#include "Timer.h"
#include "Error.h"
#include <cstdio>
#include <sstream>


namespace TNet {


  ////////////////////////////////////////////////////////////////////////
  // Class LabelRepository::

  // Initializes the repository for "map" labels; forwards to InitMap()
  void
  LabelRepository::
  Init(const char* pLabelMlfFile, const char* pOutputLabelMapFile, const char* pLabelDir, const char* pLabelExt) {
    InitMap(pLabelMlfFile, pOutputLabelMapFile, pLabelDir, pLabelExt);
  }

  // Initializes the repository according to fmt: "map" (arg is passed to
  // InitMap) or "raw" (arg is passed to InitRaw); any other fmt is ignored
  void
  LabelRepository::
  InitExt(const char* pLabelMlfFile, const char* fmt, const char *arg, const char* pLabelDir, const char* pLabelExt) {
    if (strcmp(fmt, "map") == 0)
    {
InitMap(pLabelMlfFile, arg, pLabelDir, pLabelExt); + mlf_fmt = MAP; + } + else if (strcmp(fmt, "raw") == 0) + { + InitRaw(pLabelMlfFile, arg, pLabelDir, pLabelExt); + mlf_fmt = RAW; + } + } + + void + LabelRepository:: + InitMap(const char* pLabelMlfFile, const char* pOutputLabelMapFile, const char* pLabelDir, const char* pLabelExt) + { + assert(NULL != pLabelMlfFile); + assert(NULL != pOutputLabelMapFile); + + // initialize the label streams + delete mpLabelStream; //if NULL, does nothing + delete _mpLabelStream; + _mpLabelStream = new std::ifstream(pLabelMlfFile); + mpLabelStream = new IMlfStream(*_mpLabelStream); + + // Label stream is initialized, just test it + if(!mpLabelStream->good()) + Error(std::string("Cannot open Label MLF file: ")+pLabelMlfFile); + + // Index the labels (good for randomized file lists) + Timer tim; tim.Start(); + mpLabelStream->Index(); + tim.End(); mIndexTime += tim.Val(); + + // Read the state-label to state-id map + ReadOutputLabelMap(pOutputLabelMapFile); + + // Store the label dir/ext + mpLabelDir = pLabelDir; + mpLabelExt = pLabelExt; + } + + void + LabelRepository:: + InitRaw(const char* pLabelMlfFile, const char *arg, const char* pLabelDir, const char* pLabelExt) + { + assert(NULL != pLabelMlfFile); + std::istringstream iss(arg); + size_t dim; + iss >> dim; + if (iss.fail() || dim <= 0) + PError("[lab] malformed dimension specification"); + raw_dim = dim; + // initialize the label streams + delete mpLabelStream; //if NULL, does nothing + delete _mpLabelStream; + _mpLabelStream = new std::ifstream(pLabelMlfFile); + mpLabelStream = new IMlfStream(*_mpLabelStream); + + // Label stream is initialized, just test it + if(!mpLabelStream->good()) + Error(std::string("Cannot open Label MLF file: ")+pLabelMlfFile); + + // Index the labels (good for randomized file lists) + Timer tim; tim.Start(); + mpLabelStream->Index(); + tim.End(); mIndexTime += tim.Val(); + + // Read the state-label to state-id map + 
//ReadOutputLabelMap(pOutputLabelMapFile); + + // Store the label dir/ext + mpLabelDir = pLabelDir; + mpLabelExt = pLabelExt; + } + + + void + LabelRepository:: + GenDesiredMatrix(BfMatrix& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical, bool has_vad) + { + //timer + Timer tim; tim.Start(); + + //Get the MLF stream reference... + IMlfStream& mLabelStream = *mpLabelStream; + //Build the file name of the label + MakeHtkFileName(mpLabelFile, pFeatureLogical, mpLabelDir, mpLabelExt); + + //Find block in MLF file + mLabelStream.Open(mpLabelFile); + if(!mLabelStream.good()) { + Error(std::string("Cannot open label MLF record: ") + mpLabelFile); + } + + + //resize the matrix + if(nFrames < 1) { + KALDI_ERR << "Number of frames:" << nFrames << " is lower than 1!!!\n" + << pFeatureLogical; + } + int label_map_size = mLabelMap.size(); + rDesired.Init(nFrames, label_map_size + (has_vad ? 2 : 0), true); //true: Zero() + + //aux variables + std::string line, state; + unsigned long long beg, end; + size_t state_index; + size_t trunc_frames = 0; + TagToIdMap::iterator it; + int vad_state; + + //parse the label file + while(!mLabelStream.eof()) { + std::getline(mLabelStream, line); + if(line == "") continue; //skip newlines/comments from MLF + if(line[0] == '#') continue; + + std::istringstream& iss = mGenDesiredMatrixStream; + iss.clear(); + iss.str(line); + + //parse the line + //begin + iss >> std::ws >> beg; + if(iss.fail()) { + KALDI_ERR << "Cannot parse column 1 (begin)\n" + << "line: " << line << "\n" + << "file: " << mpLabelFile << "\n"; + } + //end + iss >> std::ws >> end; + if(iss.fail()) { + KALDI_ERR << "Cannot parse column 2 (end)\n" + << "line: " << line << "\n" + << "file: " << mpLabelFile << "\n"; + } + //state tag + iss >> std::ws >> state; + if(iss.fail()) { + KALDI_ERR << "Cannot parse column 3 (state_tag)\n" + << "line: " << line << "\n" + << "file: " << mpLabelFile << "\n"; + } + + if (has_vad) /* an additional column for vad */ + { + 
iss >> std::ws >> vad_state; + if(iss.fail()) { + KALDI_ERR << "Cannot parse column 4 (vad_state)\n" + << "line: " << line << "\n" + << "file: " << mpLabelFile << "\n"; + } + } + + //fprintf(stderr, "Parsed: %lld %lld %s\n", beg, end, state.c_str()); + + //divide beg/end by sourceRate and round up to get interval of frames + beg = (beg+sourceRate/2)/sourceRate; + end = (end+sourceRate/2)/sourceRate; + //beg = (int)round(beg / (double)sourceRate); + //end = (int)round(end / (double)sourceRate); + + //find the state id + it = mLabelMap.find(state); + if(mLabelMap.end() == it) { + Error(std::string("Unknown state tag: '") + state + "' file:'" + mpLabelFile); + } + state_index = it->second; + + // Fill the desired matrix + for(unsigned long long frame=beg; frame<end; frame++) { + //don't write after matrix... (possible longer transcript than feature file) + if(frame >= (int)rDesired.Rows()) { trunc_frames++; continue; } + + //check the next frame is empty: + if(0.0 != rDesired[frame].Sum()) { + //ERROR!!! + //find out what was previously filled!!! 
+ BaseFloat max = rDesired[frame].Max(); + int idx = -1; + for(int i=0; i<(int)rDesired[frame].Dim(); i++) { + if(rDesired[frame][i] == max) idx = i; + } + for(it=mLabelMap.begin(); it!=mLabelMap.end(); ++it) { + if((int)it->second == idx) break; + } + std::string state_prev = "error"; + if(it != mLabelMap.end()) { + state_prev = it->first; + } + //print the error message + std::ostringstream os; + os << "Frame already assigned to other state, " + << " file: " << mpLabelFile + << " frame: " << frame + << " nframes: " << nFrames + << " sum: " << rDesired[frame].Sum() + << " previously assigned to: " << state_prev << "(" << idx << ")" + << " now should be assigned to: " << state << "(" << state_index << ")" + << "\n"; + Error(os.str()); + } + + //fill the row + rDesired[(size_t)frame][state_index] = 1.0f; + if (has_vad) + rDesired[(size_t)frame][label_map_size + !!vad_state] = 1.0f; + } + } + + mLabelStream.Close(); + + //check the desired matrix (rows sum up to 1.0) + for(size_t i=0; i<rDesired.Rows(); ++i) { + float desired_row_sum = rDesired[i].Sum(); + if(!desired_row_sum == 1.0) { + std::ostringstream os; + os << "Desired vector sum isn't 1.0, " + << " file: " << mpLabelFile + << " row: " << i + << " nframes: " << nFrames + << " content: " << rDesired[i] + << " sum: " << desired_row_sum << "\n"; + Error(os.str()); + } + } + + //warning when truncating many frames + if(trunc_frames > 10) { + std::ostringstream os; + os << "Truncated frames: " << trunc_frames + << " Check sourcerate in features and validity of labels\n"; + Warning(os.str()); + } + + //timer + tim.End(); mGenDesiredMatrixTime += tim.Val(); + } + + + + void + LabelRepository:: + ReadOutputLabelMap(const char* file) + { + assert(mLabelMap.size() == 0); + int i = 0; + std::string state_tag; + std::ifstream in(file); + if(!in.good()) + Error(std::string("Cannot open OutputLabelMapFile: ")+file); + + in >> std::ws; + while(!in.eof()) { + in >> state_tag; + in >> std::ws; + 
assert(mLabelMap.find(state_tag) == mLabelMap.end()); + mLabelMap[state_tag] = i++; + } + + in.close(); + assert(mLabelMap.size() > 0); + } + + + void + LabelRepository:: + GenDesiredMatrixExt(std::vector<BfMatrix>& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical) { + switch (mlf_fmt) + { + case MAP: GenDesiredMatrixExtMap(rDesired, nFrames, sourceRate, pFeatureLogical); + break; + case RAW: GenDesiredMatrixExtRaw(rDesired, nFrames, sourceRate, pFeatureLogical); + break; + default: assert(0); + } + } + + void + LabelRepository:: + GenDesiredMatrixExtMap(std::vector<BfMatrix>& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical) + { + //timer + Timer tim; tim.Start(); + + //Get the MLF stream reference... + IMlfStream& mLabelStream = *mpLabelStream; + //Build the file name of the label + MakeHtkFileName(mpLabelFile, pFeatureLogical, mpLabelDir, mpLabelExt); + + //Find block in MLF file + mLabelStream.Open(mpLabelFile); + if(!mLabelStream.good()) { + Error(std::string("Cannot open label MLF record: ") + mpLabelFile); + } + + + //resize the matrix + if(nFrames < 1) { + KALDI_ERR << "Number of frames:" << nFrames << " is lower than 1!!!\n" + << pFeatureLogical; + } + + size_t prev = rDesired.size(); + rDesired.resize(prev + 1, BfMatrix()); /* state + vad */ + int label_map_size = mLabelMap.size(); + rDesired[prev].Init(nFrames, 1, true); //true: Zero() + + //aux variables + std::string line, state; + unsigned long long beg, end; + size_t state_index; + size_t trunc_frames = 0; + TagToIdMap::iterator it; + + //parse the label file + while(!mLabelStream.eof()) { + std::getline(mLabelStream, line); + if(line == "") continue; //skip newlines/comments from MLF + if(line[0] == '#') continue; + + std::istringstream& iss = mGenDesiredMatrixStream; + iss.clear(); + iss.str(line); + + //parse the line + //begin + iss >> std::ws >> beg; + if(iss.fail()) { + KALDI_ERR << "Cannot parse column 1 (begin)\n" + << "line: " << line << "\n" + 
<< "file: " << mpLabelFile << "\n"; + } + //end + iss >> std::ws >> end; + if(iss.fail()) { + KALDI_ERR << "Cannot parse column 2 (end)\n" + << "line: " << line << "\n" + << "file: " << mpLabelFile << "\n"; + } + //state tag + iss >> std::ws >> state; + if(iss.fail()) { + KALDI_ERR << "Cannot parse column 3 (state_tag)\n" + << "line: " << line << "\n" + << "file: " << mpLabelFile << "\n"; + } + + + //fprintf(stderr, "Parsed: %lld %lld %s\n", beg, end, state.c_str()); + + //divide beg/end by sourceRate and round up to get interval of frames + beg = (beg+sourceRate/2)/sourceRate; + if (end == (unsigned long long)-1) + end = rDesired[prev].Rows(); + else + end = (end+sourceRate/2)/sourceRate; + //beg = (int)round(beg / (double)sourceRate); + //end = (int)round(end / (double)sourceRate); + + //find the state id + it = mLabelMap.find(state); + if(mLabelMap.end() == it) { + Error(std::string("Unknown state tag: '") + state + "' file:'" + mpLabelFile); + } + state_index = it->second; + + // Fill the desired matrix + for(unsigned long long frame=beg; frame<end; frame++) { + //don't write after matrix... (possible longer transcript than feature file) + if(frame >= (int)rDesired[prev].Rows()) { trunc_frames++; continue; } + + //check the next frame is empty: + if(0.0 != rDesired[prev][frame][0]) { + //ERROR!!! + //find out what was previously filled!!! 
+ /* + BaseFloat max = rDesired[prev][frame].Max(); + int idx = -1; + for(int i=0; i<(int)rDesired[prev][frame].Dim(); i++) { + if(rDesired[prev][frame][i] == max) idx = i; + } + */ + BaseFloat max = rDesired[prev][frame][0]; + int idx = round(max); + for(it=mLabelMap.begin(); it!=mLabelMap.end(); ++it) { + if((int)it->second == idx) break; + } + std::string state_prev = "error"; + if(it != mLabelMap.end()) { + state_prev = it->first; + } + //print the error message + std::ostringstream os; + os << "Frame already assigned to other state, " + << " file: " << mpLabelFile + << " frame: " << frame + << " nframes: " << nFrames + << " sum: " << max + << " previously assigned to: " << state_prev << "(" << idx << ")" + << " now should be assigned to: " << state << "(" << state_index << ")" + << "\n"; + Error(os.str()); + } + + //fill the row + //rDesired[prev][(size_t)frame][state_index] = 1.0f; + rDesired[prev][(size_t)frame][0] = state_index; + } + } + + mLabelStream.Close(); + /* + //check the desired matrix (rows sum up to 1.0) + for(size_t i=0; i<rDesired[prev].Rows(); ++i) { + float desired_row_sum = rDesired[prev][i].Sum(); + if(!desired_row_sum == 1.0) { + std::ostringstream os; + os << "Desired vector sum isn't 1.0, " + << " file: " << mpLabelFile + << " row: " << i + << " nframes: " << nFrames + << " content: " << rDesired[prev][i] + << " sum: " << desired_row_sum << "\n"; + Error(os.str()); + } + } + */ + + //warning when truncating many frames + if(trunc_frames > 10) { + std::ostringstream os; + os << "Truncated frames: " << trunc_frames + << " Check sourcerate in features and validity of labels\n"; + Warning(os.str()); + } + + //timer + tim.End(); mGenDesiredMatrixTime += tim.Val(); + } + + void + LabelRepository:: + GenDesiredMatrixExtRaw(std::vector<BfMatrix>& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical) + { + //timer + Timer tim; tim.Start(); + + //Get the MLF stream reference... 
+ IMlfStream& mLabelStream = *mpLabelStream; + //Build the file name of the label + MakeHtkFileName(mpLabelFile, pFeatureLogical, mpLabelDir, mpLabelExt); + + //Find block in MLF file + mLabelStream.Open(mpLabelFile); + if(!mLabelStream.good()) { + Error(std::string("Cannot open label MLF record: ") + mpLabelFile); + } + + + //resize the matrix + if(nFrames < 1) { + KALDI_ERR << "Number of frames:" << nFrames << " is lower than 1!!!\n" + << pFeatureLogical; + } + + size_t prev = rDesired.size(); + rDesired.resize(prev + 1, BfMatrix()); /* state + vad */ + rDesired[prev].Init(nFrames, raw_dim, true); //true: Zero() + + //aux variables + std::string line, state; + unsigned long long beg, end; + size_t trunc_frames = 0; + Vector<BaseFloat> raw; + raw.Init(raw_dim); + + //parse the label file + while(!mLabelStream.eof()) { + std::getline(mLabelStream, line); + if(line == "") continue; //skip newlines/comments from MLF + if(line[0] == '#') continue; + + std::istringstream& iss = mGenDesiredMatrixStream; + iss.clear(); + iss.str(line); + + //parse the line + //begin + iss >> std::ws >> beg; + if(iss.fail()) { + KALDI_ERR << "Cannot parse column 1 (begin)\n" + << "line: " << line << "\n" + << "file: " << mpLabelFile << "\n"; + } + //end + iss >> std::ws >> end; + if(iss.fail()) { + KALDI_ERR << "Cannot parse column 2 (end)\n" + << "line: " << line << "\n" + << "file: " << mpLabelFile << "\n"; + } + + for (size_t i = 0; i < raw_dim; i++) + { + if (iss.eof()) + PError("[label] insufficient columns for the label: %s", mpLabelFile); + iss >> raw[i]; + if (iss.fail()) + PError("[label] cannot parse raw value for the label: %s", mpLabelFile); + } + /* + for (size_t i = 0; i < raw_dim; i++) + fprintf(stderr, "%.3f", raw[i]); + fprintf(stderr, "\n"); + */ + //divide beg/end by sourceRate and round up to get interval of frames + beg = (beg+sourceRate/2)/sourceRate; + if (end == (unsigned long long)-1) + end = rDesired[prev].Rows(); + else + end = (end+sourceRate/2)/sourceRate; + 
//printf("end:%lld\n", end); + //beg = (int)round(beg / (double)sourceRate); + //end = (int)round(end / (double)sourceRate); + + // Fill the desired matrix + for(unsigned long long frame=beg; frame<end; frame++) { + //don't write after matrix... (possible longer transcript than feature file) + if(frame >= (int)rDesired[prev].Rows()) { trunc_frames++; continue; } + + //check the next frame is empty: + if(0.0 != rDesired[prev][frame][0]) { + //ERROR!!! + //find out what was previously filled!!! + /* + BaseFloat max = rDesired[prev][frame].Max(); + int idx = -1; + for(int i=0; i<(int)rDesired[prev][frame].Dim(); i++) { + if(rDesired[prev][frame][i] == max) idx = i; + } + */ + BaseFloat max = rDesired[prev][frame][0]; + //print the error message + std::ostringstream os; + os << "Frame already assigned to other state, " + << " file: " << mpLabelFile + << " frame: " << frame + << " nframes: " << nFrames + << " sum: " << max + << "\n"; + Error(os.str()); + } + + //fill the row + //rDesired[prev][(size_t)frame][state_index] = 1.0f; + rDesired[prev][(size_t)frame].Copy(raw); + } + } + + mLabelStream.Close(); + /* + //check the desired matrix (rows sum up to 1.0) + for(size_t i=0; i<rDesired[prev].Rows(); ++i) { + float desired_row_sum = rDesired[prev][i].Sum(); + if(!desired_row_sum == 1.0) { + std::ostringstream os; + os << "Desired vector sum isn't 1.0, " + << " file: " << mpLabelFile + << " row: " << i + << " nframes: " << nFrames + << " content: " << rDesired[prev][i] + << " sum: " << desired_row_sum << "\n"; + Error(os.str()); + } + } + */ + + //warning when truncating many frames + if(trunc_frames > 10) { + std::ostringstream os; + os << "Truncated frames: " << trunc_frames + << " Check sourcerate in features and validity of labels\n"; + Warning(os.str()); + } + + //timer + tim.End(); mGenDesiredMatrixTime += tim.Val(); + } + +}//namespace diff --git a/htk_io/src/KaldiLib/Labels.h b/htk_io/src/KaldiLib/Labels.h new file mode 100644 index 0000000..409a080 --- /dev/null 
+++ b/htk_io/src/KaldiLib/Labels.h @@ -0,0 +1,90 @@ +#ifndef _LABELS_H_ +#define _LABELS_H_ + + +#include "Matrix.h" +#include "MlfStream.h" +#include "Features.h" + +#include <map> +#include <iostream> + +namespace TNet { + + + class FeaCatPool; + + /** + * Desired matrix generation object, + * supports background-reading and caching, however can be + * used in foreground as well by GenDesiredMatrix() + */ + class LabelRepository + { + typedef std::map<std::string,size_t> TagToIdMap; + + public: + enum MFormat { + MAP, + RAW + }; + + LabelRepository() + : _mpLabelStream(NULL), mpLabelStream(NULL), mpLabelDir(NULL), mpLabelExt(NULL), mGenDesiredMatrixTime(0), mIndexTime(0), mTrace(0) + { } + + ~LabelRepository() + { + if(mTrace&4) { + std::cout << "[LabelRepository -- indexing:" << mIndexTime << "s" + " genDesiredMatrix:" << mGenDesiredMatrixTime << "s]" << std::endl; + } + delete mpLabelStream; + delete _mpLabelStream; + } + + /// Initialize the LabelRepository + void Init(const char* pLabelMlfFile, const char* pOutputLabelMapFile, const char* pLabelDir, const char* pLabelExt); + void InitExt(const char* pLabelMlfFile, const char* fmt, const char* arg, const char* pLabelDir, const char* pLabelExt); + void InitMap(const char* pLabelMlfFile, const char* pOutputLabelMapFile, const char* pLabelDir, const char* pLabelExt); + void InitRaw(const char* pLabelMlfFile,const char* arg, const char* pLabelDir, const char* pLabelExt); + + /// Set trace level + void Trace(int trace) + { mTrace = trace; } + + /// Get desired matrix from labels + void GenDesiredMatrix(BfMatrix& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical, bool has_vad = false); + + void GenDesiredMatrixExt(std::vector<BfMatrix>& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical); + void GenDesiredMatrixExtMap(std::vector<BfMatrix>& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical); + void GenDesiredMatrixExtRaw(std::vector<BfMatrix>& 
rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical); + size_t getWidth() { return mLabelMap.size(); } + MFormat getFormat() { return mlf_fmt; } + private: + /// Prepare the state-label to state-id map + void ReadOutputLabelMap(const char* file); + + private: + // Streams and state-map + std::ifstream* _mpLabelStream; ///< Helper stream for Label stream + IMlfStream* mpLabelStream; ///< Label stream + std::istringstream mGenDesiredMatrixStream; ///< Label file parsing stream + + const char* mpLabelDir; ///< Label dir in MLF + const char* mpLabelExt; ///< Label ext in MLF + char mpLabelFile[4096]; ///< Buffer for filenames in MLF + + TagToIdMap mLabelMap; ///< Map of state tags to net output indices + + double mGenDesiredMatrixTime; + float mIndexTime; + + int mTrace; + MFormat mlf_fmt; + size_t raw_dim; + }; + +}//namespace + +#endif diff --git a/htk_io/src/KaldiLib/Makefile b/htk_io/src/KaldiLib/Makefile new file mode 100644 index 0000000..61e0a59 --- /dev/null +++ b/htk_io/src/KaldiLib/Makefile @@ -0,0 +1,28 @@ + +include ../tnet.mk + +INCLUDE = -I. 
#ifndef TNet_MathAux_h
#define TNet_MathAux_h

#include <cfloat>     // FLT_EPSILON, DBL_EPSILON
#include <cmath>      // log, exp and the float variants used by the macros below
#include <cstdlib>    // rand, RAND_MAX
#include <stdexcept>  // std::runtime_error


#if !defined(SQR)
# define SQR(x) ((x) * (x))
#endif


#if !defined(LOG_0)
# define LOG_0     (-1.0e10)
#endif

#if !defined(LOG_MIN)
# define LOG_MIN   (0.5 * LOG_0)
#endif


#ifndef DBL_EPSILON
#define DBL_EPSILON 2.2204460492503131e-16
#endif


#ifndef M_PI
# define M_PI 3.1415926535897932384626433832795
#endif

#define M_LOG_2PI 1.8378770664093454835606594728112


#if DOUBLEPRECISION
# define FLOAT double
# define EPSILON DBL_EPSILON
# define FLOAT_FMT "%lg"
# define swapFLOAT swap8
# define _ABS  fabs
# define _COS  cos
# define _EXP  exp
# define _LOG  log
# define _SQRT sqrt
#else
# define FLOAT float
# define EPSILON FLT_EPSILON
# define FLOAT_FMT "%g"
# define swapFLOAT swap4
# define _ABS  fabsf
# define _COS  cosf
# define _EXP  expf
# define _LOG  logf
# define _SQRT sqrtf
#endif

namespace TNet
{
  /// Pseudo-random number computed as (rand()+1) / (RAND_MAX+2) in float
  /// arithmetic. The result is always strictly positive, so its logarithm is
  /// finite. Because of float rounding the upper end may reach exactly 1.0f
  /// when RAND_MAX exceeds the float mantissa range.
  inline float frand(){
    return (float(rand()) + 1.0f) / (float(RAND_MAX)+2.0f);
  }

  /// Normally distributed pseudo-random number obtained from two frand()
  /// draws with the Box-Muller transform.
  inline float gauss_rand(){
    return _SQRT( -2.0f * _LOG(frand()) ) * _COS(2.0f*float(M_PI)*frand());
  }

  /// Below this (negative) difference the smaller LogAdd operand is ignored.
  static const double gMinLogDiff = log(DBL_EPSILON);

  //***************************************************************************
  //***************************************************************************
  /// Returns log(exp(x) + exp(y)).
  /// Yields LOG_0 when the larger operand is below LOG_MIN, and just the
  /// larger operand when the two differ by more than -gMinLogDiff.
  inline double
  LogAdd(double x, double y)
  {
    double diff;

    // make x the larger operand; diff = smaller - larger <= 0
    if (x < y) {
      diff = x - y;
      x = y;
    } else {
      diff = y - x;
    }

    double res;
    if (x >= LOG_MIN) {
      if (diff >= gMinLogDiff) {
        res = x + log(1.0 + exp(diff));
      } else {
        res = x;
      }
    } else {
      res = LOG_0;
    }
    return res;
  }


  //***************************************************************************
  //***************************************************************************
  /// Returns log(exp(x) - exp(y)).
  /// Returns LOG_0 when x == y, or when the two are so close that
  /// 1 - exp(y - x) is not positive in double precision.
  /// @throws std::runtime_error if y > x
  inline double
  LogSub(double x, double y)
  {
    if(y >= x){
      if(y==x) return LOG_0;
      else throw std::runtime_error("LogSub: cannot subtract a larger from a smaller number.");
    }

    const double diff = y - x;  // Will be negative.

    // exp(diff) can round to exactly 1.0 for tiny |diff|; log(0) would then be
    // -inf (not NaN), so test the argument itself. The negated comparison
    // also sends NaN operands to LOG_0.
    const double arg = 1.0 - exp(diff);
    if(!(arg > 0.0))
      return LOG_0;

    return x + log(arg);
  }

} // namespace TNet


#endif
+ assert(mMRows==mMCols); // Can't take determinant of non-square matrix. + *LogDet = 0.0; float prod = 1.0; + for(size_t i=0;i<mMRows;i++){ + prod *= (*this)(i,i); + if(i==mMRows-1 || fabs(prod)<1.0e-10 || fabs(prod)>1.0e+10){ + if(LogDet!=NULL) *LogDet += log(fabs(prod)); + if(DetSign!=NULL) *DetSign *= (prod>0?1.0:-1.0); + prod=1.0; + } + } + } +#if defined(HAVE_CLAPACK) + if(inverse_needed) sgetri_(&M, mpData, &LDA, pivot, p_work, &l_work, &result); + delete [] pivot; +#else + if(inverse_needed) result = clapack_sgetri(CblasColMajor, Rows(), mpData, mStride, pivot); + delete [] pivot; +#endif + assert(result == 0 && "Call to CLAPACK sgetri_ or ATLAS clapack_sgetri called with wrong arguments"); + return *this; + } + + + //*************************************************************************** + //*************************************************************************** + template<> + Matrix<double> & + Matrix<double>:: + Invert(double *LogDet, double *DetSign, bool inverse_needed) + { + assert(Rows() == Cols()); + +#if defined(HAVE_CLAPACK) + integer* pivot = new integer[mMRows]; + integer M = Rows(); + integer N = Cols(); + integer LDA = mStride; + integer result; + integer l_work = std::max<integer>(1, N); + double* p_work = new double[l_work]; + + dgetrf_(&M, &N, mpData, &LDA, pivot, &result); + const int pivot_offset=1; +#else + int* pivot = new int[mMRows]; + int result = clapack_dgetrf(CblasColMajor, Rows(), Cols(), mpData, mStride, pivot); + const int pivot_offset=0; +#endif + assert(result >= 0 && "Call to CLAPACK dgetrf_ or ATLAS clapack_dgetrf called with wrong arguments"); + if(result != 0) { + Error("Matrix is singular"); + } + if(DetSign!=NULL){ *DetSign=1.0; for(size_t i=0;i<mMRows;i++) if(pivot[i]!=(int)i+pivot_offset) *DetSign *= -1.0; } + if(LogDet!=NULL||DetSign!=NULL){ // Compute log determinant... + assert(mMRows==mMCols); // Can't take determinant of non-square matrix. 
+ *LogDet = 0.0; double prod = 1.0; + for(size_t i=0;i<mMRows;i++){ + prod *= (*this)(i,i); + if(i==mMRows-1 || fabs(prod)<1.0e-10 || fabs(prod)>1.0e+10){ + if(LogDet!=NULL) *LogDet += log(fabs(prod)); + if(DetSign!=NULL) *DetSign *= (prod>0?1.0:-1.0); + prod=1.0; + } + } + } +#if defined(HAVE_CLAPACK) + if(inverse_needed) dgetri_(&M, mpData, &LDA, pivot, p_work, &l_work, &result); + delete [] pivot; +#else + if(inverse_needed) result = clapack_dgetri(CblasColMajor, Rows(), mpData, mStride, pivot); + delete [] pivot; +#endif + assert(result == 0 && "Call to CLAPACK dgetri_ or ATLAS clapack_dgetri called with wrong arguments"); + return *this; + } + + template<> + Matrix<float> & + Matrix<float>:: + BlasGer(const float alpha, const Vector<float>& rA, const Vector<float>& rB) + { + assert(rA.Dim() == mMRows && rB.Dim() == mMCols); + cblas_sger(CblasRowMajor, rA.Dim(), rB.Dim(), alpha, rA.pData(), 1, rB.pData(), 1, mpData, mStride); + return *this; + } + + template<> + Matrix<double> & + Matrix<double>:: + BlasGer(const double alpha, const Vector<double>& rA, const Vector<double>& rB) + { + assert(rA.Dim() == mMRows && rB.Dim() == mMCols); + cblas_dger(CblasRowMajor, rA.Dim(), rB.Dim(), alpha, rA.pData(), 1, rB.pData(), 1, mpData, mStride); + return *this; + } + + template<> + Matrix<float>& + Matrix<float>:: + BlasGemm(const float alpha, + const Matrix<float>& rA, MatrixTrasposeType transA, + const Matrix<float>& rB, MatrixTrasposeType transB, + const float beta) + { + assert((transA == NO_TRANS && transB == NO_TRANS && rA.Cols() == rB.Rows() && rA.Rows() == Rows() && rB.Cols() == Cols()) + || (transA == TRANS && transB == NO_TRANS && rA.Rows() == rB.Rows() && rA.Cols() == Rows() && rB.Cols() == Cols()) + || (transA == NO_TRANS && transB == TRANS && rA.Cols() == rB.Cols() && rA.Rows() == Rows() && rB.Rows() == Cols()) + || (transA == TRANS && transB == TRANS && rA.Rows() == rB.Cols() && rA.Cols() == Rows() && rB.Rows() == Cols())); + + cblas_sgemm(CblasRowMajor, 
static_cast<CBLAS_TRANSPOSE>(transA), static_cast<CBLAS_TRANSPOSE>(transB), + Rows(), Cols(), transA == NO_TRANS ? rA.Cols() : rA.Rows(), + alpha, rA.mpData, rA.mStride, rB.mpData, rB.mStride, + beta, mpData, mStride); + return *this; + } + + template<> + Matrix<double>& + Matrix<double>:: + BlasGemm(const double alpha, + const Matrix<double>& rA, MatrixTrasposeType transA, + const Matrix<double>& rB, MatrixTrasposeType transB, + const double beta) + { + assert((transA == NO_TRANS && transB == NO_TRANS && rA.Cols() == rB.Rows() && rA.Rows() == Rows() && rB.Cols() == Cols()) + || (transA == TRANS && transB == NO_TRANS && rA.Rows() == rB.Rows() && rA.Cols() == Rows() && rB.Cols() == Cols()) + || (transA == NO_TRANS && transB == TRANS && rA.Cols() == rB.Cols() && rA.Rows() == Rows() && rB.Rows() == Cols()) + || (transA == TRANS && transB == TRANS && rA.Rows() == rB.Cols() && rA.Cols() == Rows() && rB.Rows() == Cols())); + + cblas_dgemm(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(transA), static_cast<CBLAS_TRANSPOSE>(transB), + Rows(), Cols(), transA == NO_TRANS ? 
rA.Cols() : rA.Rows(), + alpha, rA.mpData, rA.mStride, rB.mpData, rB.mStride, + beta, mpData, mStride); + return *this; + } + + template<> + Matrix<float>& + Matrix<float>:: + Axpy(const float alpha, + const Matrix<float>& rA, MatrixTrasposeType transA){ + int aStride = (int)rA.mStride, stride = mStride; + float *adata=rA.mpData, *data=mpData; + if(transA == NO_TRANS){ + assert(rA.Rows()==Rows() && rA.Cols()==Cols()); + for(size_t row=0;row<mMRows;row++,adata+=aStride,data+=stride) + cblas_saxpy(mMCols, alpha, adata, 1, data, 1); + } else { + assert(rA.Cols()==Rows() && rA.Rows()==Cols()); + for(size_t row=0;row<mMRows;row++,adata++,data+=stride) + cblas_saxpy(mMCols, alpha, adata, aStride, data, 1); + } + return *this; + } + + template<> + Matrix<double>& + Matrix<double>:: + Axpy(const double alpha, + const Matrix<double>& rA, MatrixTrasposeType transA){ + int aStride = (int)rA.mStride, stride = mStride; + double *adata=rA.mpData, *data=mpData; + if(transA == NO_TRANS){ + assert(rA.Rows()==Rows() && rA.Cols()==Cols()); + for(size_t row=0;row<mMRows;row++,adata+=aStride,data+=stride) + cblas_daxpy(mMCols, alpha, adata, 1, data, 1); + } else { + assert(rA.Cols()==Rows() && rA.Rows()==Cols()); + for(size_t row=0;row<mMRows;row++,adata++,data+=stride) + cblas_daxpy(mMCols, alpha, adata, aStride, data, 1); + } + return *this; + } + + template <> //non-member but friend! + double TraceOfProduct(const Matrix<double> &A, const Matrix<double> &B){ // tr(A B), equivalent to sum of each element of A times same element in B' + size_t aStride = A.mStride, bStride = B.mStride; + assert(A.Rows()==B.Cols() && A.Cols()==B.Rows()); + double ans = 0.0; + double *adata=A.mpData, *bdata=B.mpData; + size_t arows=A.Rows(), acols=A.Cols(); + for(size_t row=0;row<arows;row++,adata+=aStride,bdata++) + ans += cblas_ddot(acols, adata, 1, bdata, bStride); + return ans; + } + + template <> //non-member but friend! 
+ double TraceOfProductT(const Matrix<double> &A, const Matrix<double> &B){ // tr(A B), equivalent to sum of each element of A times same element in B' + size_t aStride = A.mStride, bStride = B.mStride; + assert(A.Rows()==B.Rows() && A.Cols()==B.Cols()); + double ans = 0.0; + double *adata=A.mpData, *bdata=B.mpData; + size_t arows=A.Rows(), acols=A.Cols(); + for(size_t row=0;row<arows;row++,adata+=aStride,bdata+=bStride) + ans += cblas_ddot(acols, adata, 1, bdata, 1); + return ans; + } + + + template <> //non-member but friend! + float TraceOfProduct(const Matrix<float> &A, const Matrix<float> &B){ // tr(A B), equivalent to sum of each element of A times same element in B' + size_t aStride = A.mStride, bStride = B.mStride; + assert(A.Rows()==B.Cols() && A.Cols()==B.Rows()); + float ans = 0.0; + float *adata=A.mpData, *bdata=B.mpData; + size_t arows=A.Rows(), acols=A.Cols(); + for(size_t row=0;row<arows;row++,adata+=aStride,bdata++) + ans += cblas_sdot(acols, adata, 1, bdata, bStride); + return ans; + } + + template <> //non-member but friend! 
+ float TraceOfProductT(const Matrix<float> &A, const Matrix<float> &B){ // tr(A B), equivalent to sum of each element of A times same element in B' + size_t aStride = A.mStride, bStride = B.mStride; + assert(A.Rows()==B.Rows() && A.Cols()==B.Cols()); + float ans = 0.0; + float *adata=A.mpData, *bdata=B.mpData; + size_t arows=A.Rows(), acols=A.Cols(); + for(size_t row=0;row<arows;row++,adata+=aStride,bdata+=bStride) + ans += cblas_sdot(acols, adata, 1, bdata, 1); + return ans; + } + + + + +#endif //HAVE_ATLAS + + + +} //namespace STK diff --git a/htk_io/src/KaldiLib/Matrix.h b/htk_io/src/KaldiLib/Matrix.h new file mode 100644 index 0000000..d33cb0c --- /dev/null +++ b/htk_io/src/KaldiLib/Matrix.h @@ -0,0 +1,677 @@ +#ifndef TNet_Matrix_h +#define TNet_Matrix_h + +#include <stddef.h> +#include <stdlib.h> +#include <stdexcept> +#include <iostream> + +#ifdef HAVE_ATLAS +extern "C"{ + #include <cblas.h> + #include <clapack.h> +} +#endif + +#include "Common.h" +#include "MathAux.h" +#include "Types.h" +#include "Error.h" + +//#define TRACE_MATRIX_OPERATIONS +#define CHECKSIZE + +namespace TNet +{ + + + // class matrix_error : public std::logic_error {}; + // class matrix_sizes_error : public matrix_error {}; + + // declare the class so the header knows about it + template<typename _ElemT> class Vector; + template<typename _ElemT> class SubVector; + template<typename _ElemT> class Matrix; + template<typename _ElemT> class SubMatrix; + + // we need to declare the friend << operator here + template<typename _ElemT> + std::ostream & operator << (std::ostream & rOut, const Matrix<_ElemT> & rM); + + // we need to declare the friend << operator here + template<typename _ElemT> + std::istream & operator >> (std::istream & rIn, Matrix<_ElemT> & rM); + + // we need to declare this friend function here + template<typename _ElemT> + _ElemT TraceOfProduct(const Matrix<_ElemT> &A, const Matrix<_ElemT> &B); // tr(A B) + + // we need to declare this friend function here + 
template<typename _ElemT> + _ElemT TraceOfProductT(const Matrix<_ElemT> &A, const Matrix<_ElemT> &B); // tr(A B^T)==tr(A^T B) + + + /** ************************************************************************** + ** ************************************************************************** + * @brief Provides a matrix class + * + * This class provides a way to work with matrices in TNet. + * It encapsulates basic operations and memory optimizations. + * + */ + template<typename _ElemT> + class Matrix + { + public: + /// defines a transpose type + + struct HtkHeader + { + INT_32 mNSamples; + INT_32 mSamplePeriod; + INT_16 mSampleSize; + UINT_16 mSampleKind; + }; + + + /** + * @brief Extension of the HTK header + */ + struct HtkHeaderExt + { + INT_32 mHeaderSize; + INT_32 mVersion; + INT_32 mSampSize; + }; + + + + + /// defines a type of this + typedef Matrix<_ElemT> ThisType; + + // Constructors + + /// Empty constructor + Matrix<_ElemT> (): + mpData(NULL), mMCols(0), mMRows(0), mStride(0) +#ifdef STK_MEMALIGN_MANUAL + , mpFreeData(NULL) +#endif + {} + + /// Copy constructor + Matrix<_ElemT> (const Matrix<_ElemT> & rM, MatrixTrasposeType trans=NO_TRANS): + mpData(NULL) + { if(trans==NO_TRANS){ Init(rM.mMRows, rM.mMCols); Copy(rM); } else { Init(rM.mMCols,rM.mMRows); Copy(rM,TRANS); } } + + /// Copy constructor from another type. 
+ template<typename _ElemU> + explicit Matrix<_ElemT> (const Matrix<_ElemU> & rM, MatrixTrasposeType trans=NO_TRANS): + mpData(NULL) + { if(trans==NO_TRANS){ Init(rM.Rows(), rM.Cols()); Copy(rM); } else { Init(rM.Cols(),rM.Rows()); Copy(rM,TRANS); } } + + /// Basic constructor + Matrix(const size_t r, const size_t c, bool clear=true) + { mpData=NULL; Init(r, c, clear); } + + + Matrix<_ElemT> &operator = (const Matrix <_ElemT> &other) { Init(other.Rows(), other.Cols()); Copy(other); return *this; } // Needed for inclusion in std::vector + + /// Destructor + ~Matrix() + { Destroy(); } + + + /// Initializes matrix (if not done by constructor) + ThisType & + Init(const size_t r, + const size_t c, bool clear=true); + + /** + * @brief Dealocates the matrix from memory and resets the dimensions to (0, 0) + */ + void + Destroy(); + + + ThisType & + Zero(); + + ThisType & + Unit(); // set to unit. + + /** + * @brief Copies the contents of a matrix + * @param rM Source data matrix + * @return Returns reference to this + */ + template<typename _ElemU> ThisType & + Copy(const Matrix<_ElemU> & rM, MatrixTrasposeType Trans=NO_TRANS); + + + + /** + * @brief Copies the elements of a vector row-by-row into a matrix + * @param rV Source vector + * @param nRows Number of rows of returned matrix + * @param nCols Number of columns of returned matrix + * + * Note that rV.Dim() must equal nRows*nCols + */ + ThisType & + CopyVectorSplicedRows(const Vector<_ElemT> &rV, const size_t nRows, const size_t nCols); + + /** + * @brief Returns @c true if matrix is initialized + */ + bool + IsInitialized() const + { return mpData != NULL; } + + /// Returns number of rows in the matrix + inline size_t + Rows() const + { + return mMRows; + } + + /// Returns number of columns in the matrix + inline size_t + Cols() const + { + return mMCols; + } + + /// Returns number of columns in the matrix memory + inline size_t + Stride() const + { + return mStride; + } + + + /** + * @brief Gives access to a 
specified matrix row without range check + * @return Pointer to the const array + */ + inline const _ElemT* __attribute__((aligned(16))) + pData () const + { + return mpData; + } + + + /** + * @brief Gives access to a specified matrix row without range check + * @return Pointer to the non-const data array + */ + inline _ElemT* __attribute__((aligned(16))) + pData () + { + return mpData; + } + + + /** + * @brief pData_workaround is a workaround that allows SubMatrix to get a + * @return pointer to non-const data even though the Matrix is const... + */ + protected: + inline _ElemT* __attribute__((aligned(16))) + pData_workaround () const + { + return mpData; + } + public: + + + /// Returns size of matrix in memory + size_t + MSize() const + { + return mMRows * mStride * sizeof(_ElemT); + } + + /// Checks the content of the matrix for nan and inf values + void + CheckData(const std::string file = "") const + { + for(size_t row=0; row<Rows(); row++) { + for(size_t col=0; col<Cols(); col++) { + if(isnan((*this)(row,col)) || isinf((*this)(row,col))) { + std::ostringstream os; + os << "Invalid value: " << (*this)(row,col) + << " in matrix row: " << row + << " col: " << col + << " file: " << file; + Error(os.str()); + } + } + } + } + + /** + * ********************************************************************** + * ********************************************************************** + * @defgroup RESHAPE Matrix reshaping rutines + * ********************************************************************** + * ********************************************************************** + * @{ + */ + + /** + * @brief Removes one row from the matrix. The memory is not reallocated. 
*/
    ThisType &
    RemoveRow(size_t i);

    /** @} */

    /**
     *  **********************************************************************
     *  **********************************************************************
     *  @defgroup ACCESS Access functions and operators
     *  **********************************************************************
     *  **********************************************************************
     *  @{
     */

    /**
     *  @brief Gives access to a specified matrix row (const version)
     *
     *  The index is validated only through assert(), so there is no check at
     *  all once asserts are compiled out.
     *
     *  @param i Row index, expected to be smaller than mMRows
     *  @return SubVector built over Cols() elements starting at
     *          mpData + i * mStride
     */
    inline const SubVector<_ElemT>
    operator []  (size_t i) const
    {
      assert(i < mMRows);
      return SubVector<_ElemT>(mpData + (i * mStride), Cols());
    }

    /**
     *  @brief Gives access to a specified matrix row (mutable version)
     *  @param i Row index, expected to be smaller than mMRows (assert only)
     *  @return SubVector built over Cols() elements starting at
     *          mpData + i * mStride
     */
    inline SubVector<_ElemT>
    operator []  (size_t i)
    {
      assert(i < mMRows);
      return SubVector<_ElemT>(mpData + (i * mStride), Cols());
    }

    /**
     *  @brief Gives raw access to a specified matrix row
     *  @param i Row index, expected to be smaller than mMRows (assert only)
     *  @return pointer to the first field of the row, i.e. mpData + i * mStride
     */
    inline _ElemT*
    pRowData(size_t i)
    {
      assert(i < mMRows);
      return mpData + i * mStride;
    }

    /**
     *  @brief Gives raw access to a specified matrix row (const version)
     *  @param i Row index, expected to be smaller than mMRows (assert only)
     *  @return pointer to the first field of the row, i.e. mpData + i * mStride
     */
    inline const _ElemT*
    pRowData(size_t i) const
    {
      assert(i < mMRows);
      return mpData + i * mStride;
    }

    /**
     *  @brief Gives access to matrix elements (row, col)
     *
     *  Indices are asserted only when PARANOID is defined.
     *
     *  @param r Row index
     *  @param c Column index
     *  @return reference to the desired field
     */
    inline _ElemT&
    operator () (size_t r, size_t c)
    {
#ifdef PARANOID
      assert(r < mMRows && c < mMCols);
#endif
      return *(mpData + r * mStride + c);
    }

    /**
     *  @brief Gives access to matrix elements (row, col), const version
     *
     *  Indices are asserted only when PARANOID is defined.
     *
     *  @param r Row index
     *  @param c Column index
     *  @return copy of the desired field (returned by value, not by pointer
     *          or reference)
     */
    inline const _ElemT
    operator () (size_t r, size_t c) const
    {
#ifdef PARANOID
      assert(r < mMRows && c < mMCols);
#endif
      return *(mpData + r * mStride + c);
    }

    /**
     *  @brief Returns a matrix sub-range
     *  @param ro Row offset
     *  @param r  Rows in range
     *  @param co Column offset
     *  @param c  Columns in range
     *  @return SubMatrix constructed from (*this, ro, r, co, c)
     *  See @c SubMatrix class for details
     */
    SubMatrix<_ElemT>
    Range(const size_t ro, const size_t r,
          const size_t co, const size_t c)
    { return SubMatrix<_ElemT>(*this, ro, r, co, c); }

    /// Const overload of Range(); builds the SubMatrix from the same arguments.
    const SubMatrix<_ElemT>
    Range(const size_t ro, const size_t r,
          const size_t co, const size_t c) const
    { return SubMatrix<_ElemT>(*this, ro, r, co, c); }
    /** @} */


    /**
     *  **********************************************************************
     *  **********************************************************************
     *  @defgroup MATH ROUTINES
     *  **********************************************************************
     *  **********************************************************************
     *  @{
     **/

    /**
     *  @brief Returns sum of all elements
     *
     *  NOTE(review): declared as returning a non-const reference, while the
     *  out-of-line definition in Matrix.tcc returns a local accumulator.
     *  Confirm whether return-by-value was intended.
     */
    _ElemT&
    Sum() const;

    /// Element-wise product: each element of this matrix is multiplied by the
    /// element of @p a at the same position.
    ThisType &
    DotMul(const ThisType& a);

    /// Multiplies every element by @p alpha.
    ThisType &
    Scale(_ElemT alpha);

    ThisType &
    ScaleCols(const Vector<_ElemT> &scale); // Equivalent to (*this) = (*this) * diag(scale).

    ThisType &
    ScaleRows(const Vector<_ElemT> &scale); // Equivalent to  (*this) = diag(scale) * (*this);

    /// Sum another matrix rMatrix with this matrix
    ThisType&
    Add(const Matrix<_ElemT>& rMatrix);


    /// Sum scaled matrix rMatrix with this matrix
    ThisType&
    AddScaled(_ElemT alpha, const Matrix<_ElemT>& rMatrix);

    /// Apply log to all items of the matrix
    ThisType&
    ApplyLog();

    /**
     *  @brief Returns the LogDet value that Invert() reports for this matrix
     *
     *  The out-of-line definition runs Invert() on a copy with the inverse
     *  not requested, so this matrix itself is left unchanged.
     *
     *  @param DetSign Forwarded unchanged to Invert(); may be NULL
     *  @return The value Invert() stored through its LogDet argument
     *  @ingroup MATH
     *
     */
    _ElemT LogAbsDeterminant(_ElemT *DetSign=NULL);


    /**
     *  @brief Performs matrix inplace inversion
     *  @param LogDet         Optional output, may be NULL
     *  @param DetSign        Optional output, may be NULL
     *  @param inverse_needed Passed as false by LogAbsDeterminant(), which
     *                        only wants the LogDet output
     */
    ThisType &
    Invert(_ElemT *LogDet=NULL, _ElemT *DetSign=NULL, bool inverse_needed=true);

    /**
     *  @brief Performs matrix inplace inversion in double precision, even if this object is not double precision.
     *
     *  Copies this matrix into a Matrix<double>, inverts that copy, and
     *  copies the result back only when @p inverse_needed is true. LogDet
     *  and DetSign are written only when the corresponding pointer is
     *  non-NULL.
     */
    ThisType &
    InvertDouble(_ElemT *LogDet=NULL, _ElemT *DetSign=NULL, bool inverse_needed=true){
      double LogDet_tmp, DetSign_tmp;
      Matrix<double> dmat(*this); dmat.Invert(&LogDet_tmp, &DetSign_tmp, inverse_needed); if(inverse_needed) (*this).Copy(dmat);
      if(LogDet) *LogDet = LogDet_tmp; if(DetSign) *DetSign = DetSign_tmp;
      return *this;
    }


    /**
     *  @brief Inplace matrix transposition. Applicable only to square matrices
     *
     *  Swaps every element below the diagonal with its mirror above it; the
     *  square shape is enforced by assert() only.
     */
    ThisType &
    Transpose()
    {
      assert(Rows()==Cols());
      size_t M=Rows();
      for(size_t i=0;i<M;i++)
        for(size_t j=0;j<i;j++){
          _ElemT &a = (*this)(i,j), &b = (*this)(j,i);
          std::swap(a,b);
        }
      return *this;
    }





    /// False for non-square matrices. Otherwise true when the summed
    /// |0.5*(a_ij - a_ji)| over the lower triangle does not exceed cutoff
    /// times the summed |0.5*(a_ij + a_ji)| plus the summed |a_ii|.
    bool IsSymmetric(_ElemT cutoff = 1.0e-05) const;

    /// True unless the sum of off-diagonal entries exceeds cutoff times the
    /// sum of diagonal entries.
    /// NOTE(review): the definition sums signed values (no fabs), unlike the
    /// other predicates here, so entries of opposite sign cancel. Confirm
    /// this is intended.
    bool IsDiagonal(_ElemT cutoff = 1.0e-05) const;

    /// False for non-square matrices. Otherwise true when the summed absolute
    /// difference from the identity matrix does not exceed cutoff.
    bool IsUnit(_ElemT cutoff = 1.0e-05) const;

    /// True when the summed absolute value of all elements does not exceed
    /// cutoff.
    bool IsZero(_ElemT cutoff = 1.0e-05) const;

    _ElemT FrobeniusNorm() const; // sqrt of sum of square elements.

    _ElemT LargestAbsElem() const; // largest absolute value.


    friend _ElemT TNet::TraceOfProduct<_ElemT>(const Matrix<_ElemT> &A, const Matrix<_ElemT> &B); // tr(A B)
    friend _ElemT TNet::TraceOfProductT<_ElemT>(const Matrix<_ElemT> &A, const Matrix<_ElemT> &B); // tr(A B^T)==tr(A^T B)
    friend class SubMatrix<_ElemT>; // so it can get around const restrictions on the pointer to mpData.
+ + /** ********************************************************************** + * ********************************************************************** + * @defgroup BLAS_ROUTINES BLAS ROUTINES + * @ingroup MATH + * ********************************************************************** + * ********************************************************************** + **/ + + ThisType & + BlasGer(const _ElemT alpha, const Vector<_ElemT>& rA, const Vector<_ElemT>& rB); + + ThisType & + Axpy(const _ElemT alpha, const Matrix<_ElemT> &rM, MatrixTrasposeType transA=NO_TRANS); + + ThisType & + BlasGemm(const _ElemT alpha, + const ThisType& rA, MatrixTrasposeType transA, + const ThisType& rB, MatrixTrasposeType transB, + const _ElemT beta = 0.0); + + + /** @} */ + + + /** ********************************************************************** + * ********************************************************************** + * @defgroup IO Input/Output ROUTINES + * ********************************************************************** + * ********************************************************************** + * @{ + **/ + + friend std::ostream & + operator << <> (std::ostream & out, const ThisType & m); + + void PrintOut(char *file); + void ReadIn(char *file); + + + bool + LoadHTK(const char* pFileName); + + /** @} */ + + + protected: +// inline void swap4b(void *a); +// inline void swap2b(void *a); + + + protected: + /// data memory area + _ElemT* mpData; + + /// these atributes store the real matrix size as it is stored in memory + /// including memalignment + size_t mMCols; ///< Number of columns + size_t mMRows; ///< Number of rows + size_t mStride; ///< true number of columns for the internal matrix. 
+ ///< This number may differ from M_cols as memory + ///< alignment might be used + +#ifdef STK_MEMALIGN_MANUAL + /// data to be freed (in case of manual memalignment use, see Common.h) + _ElemT* mpFreeData; +#endif + }; // class Matrix + + template<> Matrix<float> & Matrix<float>::Invert(float *LogDet, float *DetSign, bool inverse_needed); // state that we will implement separately for float and double. + template<> Matrix<double> & Matrix<double>::Invert(double *LogDet, double *DetSign, bool inverse_needed); + + + + /** ************************************************************************** + ** ************************************************************************** + * @brief Sub-matrix representation + * + * This class provides a way to work with matrix cutouts in STK. + * + * + */ + template<typename _ElemT> + class SubMatrix : public Matrix<_ElemT> + { + typedef SubMatrix<_ElemT> ThisType; + + public: + /// Constructor + SubMatrix(const Matrix<_ElemT>& rT, // Input matrix cannot be const because SubMatrix can change its contents. 
+ const size_t ro, + const size_t r, + const size_t co, + const size_t c); + + + /// The destructor + ~SubMatrix<_ElemT>() + { +#ifndef STK_MEMALIGN_MANUAL + Matrix<_ElemT>::mpData = NULL; +#else + Matrix<_ElemT>::mpFreeData = NULL; +#endif + } + + /// Assign operator + ThisType& operator=(const ThisType& rSrc) + { + //std::cout << "[PERFORMing operator= SubMatrix&^2]" << std::flush; + this->mpData = rSrc.mpData; + this->mMCols = rSrc.mMCols; + this->mMRows = rSrc.mMRows; + this->mStride = rSrc.mStride; + this->mpFreeData = rSrc.mpFreeData; + return *this; + } + + + + /// Initializes matrix (if not done by constructor) + ThisType & + Init(const size_t r, + const size_t c, bool clear=true) + { Error("Submatrix cannot do Init"); return *this; } + + /** + * @brief Dealocates the matrix from memory and resets the dimensions to (0, 0) + */ + void + Destroy() + { Error("Submatrix cannot do Destroy"); } + + + + }; + + + + //Create useful shortcuts + typedef Matrix<BaseFloat> BfMatrix; + typedef SubMatrix<BaseFloat> BfSubMatrix; + + /** + * Function for summing matrices of different types + */ + template<typename _ElemT, typename _ElemU> + void Add(Matrix<_ElemT>& rDst, const Matrix<_ElemU>& rSrc) { + assert(rDst.Cols() == rSrc.Cols()); + assert(rDst.Rows() == rSrc.Rows()); + + for(size_t i=0; i<rDst.Rows(); i++) { + const _ElemU* p_src = rSrc.pRowData(i); + _ElemT* p_dst = rDst.pRowData(i); + for(size_t j=0; j<rDst.Cols(); j++) { + *p_dst++ += (_ElemT)*p_src++; + } + } + } + + /** + * Function for summing matrices of different types + */ + template<typename _ElemT, typename _ElemU> + void AddScaled(Matrix<_ElemT>& rDst, const Matrix<_ElemU>& rSrc, _ElemT scale) { + assert(rDst.Cols() == rSrc.Cols()); + assert(rDst.Rows() == rSrc.Rows()); + + Vector<_ElemT> tmp(rDst[0]); + + for(size_t i=0; i<rDst.Rows(); i++) { + tmp.Copy(rSrc[i]); + rDst[i].BlasAxpy(scale, tmp); + + /* + const _ElemU* p_src = rSrc.pRowData(i); + _ElemT* p_dst = rDst.pRowData(i); + for(size_t j=0; 
j<rDst.Cols(); j++) { + *p_dst++ += (_ElemT)(*p_src++) * scale; + } + */ + } + } + + + + + +} // namespace STK + + + +//***************************************************************************** +//***************************************************************************** +// we need to include the implementation +#include "Matrix.tcc" +//***************************************************************************** +//***************************************************************************** + + +/****************************************************************************** + ****************************************************************************** + * The following section contains specialized template definitions + * whose implementation is in Matrix.cc + */ + + +//#ifndef TNet_Matrix_h +#endif diff --git a/htk_io/src/KaldiLib/Matrix.tcc b/htk_io/src/KaldiLib/Matrix.tcc new file mode 100644 index 0000000..f6ffb8f --- /dev/null +++ b/htk_io/src/KaldiLib/Matrix.tcc @@ -0,0 +1,796 @@ + +/** @file Matrix.tcc + * This is an internal header file, included by other library headers. + * You should not attempt to use it directly. 
+ */ + + +#ifndef TNet_Matrix_tcc +#define TNet_Matrix_tcc + +//#pragma GCC system_header + +#include <cstdlib> +#include <cmath> +#include <cfloat> +#include <fstream> +#include <iomanip> +#include <typeinfo> +#include <algorithm> +#include <limits> +#include <vector> +#include "Common.h" + +#ifndef _XOPEN_SOURCE + #define _XOPEN_SOURCE 600 +#endif + + +#ifdef HAVE_ATLAS +extern "C"{ + #include <cblas.h> +} +#endif + + +#include "Common.h" +#include "Vector.h" +namespace TNet +{ + +//****************************************************************************** + template<typename _ElemT> + Matrix<_ElemT> & + Matrix<_ElemT>:: + Init(const size_t rows, + const size_t cols, + bool clear) + { + if(mpData != NULL) Destroy(); + if(rows*cols == 0){ + assert(rows==0 && cols==0); + mMRows=rows; + mMCols=cols; +#ifdef STK_MEMALIGN_MANUAL + mpFreeData=NULL; +#endif + mpData=NULL; + return *this; + } + // initialize some helping vars + size_t skip; + size_t real_cols; + size_t size; + void* data; // aligned memory block + void* free_data; // memory block to be really freed + + // compute the size of skip and real cols + skip = ((16 / sizeof(_ElemT)) - cols % (16 / sizeof(_ElemT))) % (16 / sizeof(_ElemT)); + real_cols = cols + skip; + size = rows * real_cols * sizeof(_ElemT); + + // allocate the memory and set the right dimensions and parameters + + if (NULL != (data = stk_memalign(16, size, &free_data))) + { + mpData = static_cast<_ElemT *> (data); +#ifdef STK_MEMALIGN_MANUAL + mpFreeData = static_cast<_ElemT *> (free_data); +#endif + mMRows = rows; + mMCols = cols; + mStride = real_cols; + } + else + { + throw std::bad_alloc(); + } + if(clear) Zero(); + return *this; + } // + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + template<typename _ElemU> + Matrix<_ElemT> & + Matrix<_ElemT>:: + Copy(const Matrix<_ElemU> & rM, 
MatrixTrasposeType Trans) + { + if(Trans==NO_TRANS){ + assert(mMRows == rM.Rows() && mMCols == rM.Cols()); + for(size_t i = 0; i < mMRows; i++) + (*this)[i].Copy(rM[i]); + return *this; + } else { + assert(mMCols == rM.Rows() && mMRows == rM.Cols()); + for(size_t i = 0; i < mMRows; i++) + for(size_t j = 0; j < mMCols; j++) + (*this)(i,j) = rM(j,i); + return *this; + } + } + + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT> & + Matrix<_ElemT>:: + CopyVectorSplicedRows(const Vector<_ElemT> &rV, const size_t nRows, const size_t nCols) { + assert(rV.Dim() == nRows*nCols); + mMRows = nRows; + mMCols = nCols; + Init(nRows,nCols,true); + for(size_t r=0; r<mMRows; r++) + for(size_t c=0; c<mMCols; c++) + (*this)(r,c) = rV(r*mMCols + c); + + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT> & + Matrix<_ElemT>:: + RemoveRow(size_t i) + { + assert(i < mMRows && "Access out of matrix"); + for(size_t j = i + 1; j < mMRows; j++) + (*this)[j - 1].Copy((*this)[j]); + mMRows--; + return *this; + } + + + //**************************************************************************** + //**************************************************************************** + // The destructor + template<typename _ElemT> + void + Matrix<_ElemT>:: + Destroy() + { + // we need to free the data block if it was defined +#ifndef STK_MEMALIGN_MANUAL + if (NULL != mpData) free(mpData); +#else + if (NULL != mpData) free(mpFreeData); + mpFreeData = NULL; +#endif + + mpData = NULL; + mMRows = mMCols = 0; + } + + //**************************************************************************** + //**************************************************************************** +// 
template<typename _ElemT> +// void +// Matrix<_ElemT>:: +// VectorizeRows(Vector<_ElemT> &rV) { +//#ifdef PARANIOD +// assert(rV.Dim() == mMRows*mMCols); +//#endif +// for(size_t r=0; r<mMRows; r++) { +// rV.Range((r-1)*mMCols, mMCols).Copy((*this)[r]); +// } +// } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + bool + Matrix<_ElemT>:: + LoadHTK(const char* pFileName) + { + HtkHeader htk_hdr; + + FILE *fp = fopen(pFileName, "rb"); + if(!fp) + { + return false; + } + + read(fileno(fp), &htk_hdr, sizeof(htk_hdr)); + + swap4(htk_hdr.mNSamples); + swap4(htk_hdr.mSamplePeriod); + swap2(htk_hdr.mSampleSize); + swap2(htk_hdr.mSampleKind); + + Init(htk_hdr.mNSamples, htk_hdr.mSampleSize / sizeof(float)); + + size_t i; + size_t j; + if (typeid(_ElemT) == typeid(float)) + { + for (i=0; i< Rows(); ++i) { + read(fileno(fp), (*this).pRowData(i), Cols() * sizeof(float)); + + for(j = 0; j < Cols(); j++) { + swap4(((*this)(i,j))); + } + } + } + else + { + float *pmem = new (std::nothrow) float[Cols()]; + if (!pmem) + { + fclose(fp); + return false; + } + + for(i = 0; i < Rows(); i++) { + read(fileno(fp), pmem, Cols() * sizeof(float)); + + for (j = 0; j < Cols(); ++j) { + swap4(pmem[j]); + (*this)(i,j) = static_cast<_ElemT>(pmem[j]); + } + } + delete [] pmem; + } + + fclose(fp); + + return true; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT> & + Matrix<_ElemT>:: + DotMul(const ThisType& a) + { + size_t i; + size_t j; + + for (i = 0; i < mMRows; ++i) { + for (j = 0; j < mMCols; ++j) { + (*this)(i,j) *= a(i,j); + } + } + return *this; + } + + //**************************************************************************** + 
//**************************************************************************** + template<typename _ElemT> + _ElemT & + Matrix<_ElemT>:: + Sum() const + { + double sum = 0.0; + + for (size_t i = 0; i < Rows(); ++i) { + for (size_t j = 0; j < Cols(); ++j) { + sum += (*this)(i,j); + } + } + + return sum; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT>& + Matrix<_ElemT>:: + Scale(_ElemT alpha) + { +#if 0 + for (size_t i = 0; i < Rows(); ++i) + for (size_t j = 0; j < Cols(); ++j) + (*this)(i,j) *= alpha; +#else + for (size_t i = 0; i < Rows(); ++i) { + _ElemT* p_data = pRowData(i); + for (size_t j = 0; j < Cols(); ++j) { + *p_data++ *= alpha; + } + } +#endif + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT>& + Matrix<_ElemT>:: + ScaleRows(const Vector<_ElemT>& scale) // scales each row by scale[i]. + { + assert(scale.Dim() == Rows()); + size_t M = Rows(), N = Cols(); + + for (size_t i = 0; i < M; i++) { + _ElemT this_scale = scale(i); + for (size_t j = 0; j < N; j++) { + (*this)(i,j) *= this_scale; + } + } + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT>& + Matrix<_ElemT>:: + ScaleCols(const Vector<_ElemT>& scale) // scales each column by scale[i]. 
+ { + assert(scale.Dim() == Cols()); + for (size_t i = 0; i < Rows(); i++) { + for (size_t j = 0; j < Cols(); j++) { + _ElemT this_scale = scale(j); + (*this)(i,j) *= this_scale; + } + } + return *this; + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT>& + Matrix<_ElemT>:: + Add(const Matrix<_ElemT>& rMatrix) + { + assert(rMatrix.Cols() == Cols()); + assert(rMatrix.Rows() == Rows()); + +#if 0 + //this can be slow + for (size_t i = 0; i < Rows(); i++) { + for (size_t j = 0; j < Cols(); j++) { + (*this)(i,j) += rMatrix(i,j); + } + } +#else + //this will be faster (but less secure) + for(size_t i=0; i<Rows(); i++) { + const _ElemT* p_src = rMatrix.pRowData(i); + _ElemT* p_dst = pRowData(i); + for(size_t j=0; j<Cols(); j++) { + *p_dst++ += *p_src++; + } + } +#endif + return *this; + } + + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT>& + Matrix<_ElemT>:: + AddScaled(_ElemT alpha, const Matrix<_ElemT>& rMatrix) + { + assert(rMatrix.Cols() == Cols()); + assert(rMatrix.Rows() == Rows()); + +#if 0 + //this can be slow + for (size_t i = 0; i < Rows(); i++) { + for (size_t j = 0; j < Cols(); j++) { + (*this)(i,j) += rMatrix(i,j) * alpha; + } + } +#else + /* + //this will be faster (but less secure) + for(size_t i=0; i<Rows(); i++) { + const _ElemT* p_src = rMatrix.pRowData(i); + _ElemT* p_dst = pRowData(i); + for(size_t j=0; j<Cols(); j++) { + *p_dst++ += *p_src++ * alpha; + } + } + */ + + //let's use BLAS + for(size_t i=0; i<Rows(); i++) { + (*this)[i].BlasAxpy(alpha, rMatrix[i]); + } +#endif + return *this; + } + + + + //**************************************************************************** + 
//**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT>& + Matrix<_ElemT>:: + ApplyLog() + { + +#if 0 + //this can be slow + for (size_t i = 0; i < Rows(); i++) { + for (size_t j = 0; j < Cols(); j++) { + (*this)(i,j) = += _LOG((*this)(i,j)); + } + } +#else + //this will be faster (but less secure) + for(size_t i=0; i<Rows(); i++) { + _ElemT* p_data = pRowData(i); + for(size_t j=0; j<Cols(); j++) { + *p_data = _LOG(*p_data); + p_data++; + } + } +#endif + return *this; + } + + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT> & + Matrix<_ElemT>:: + Zero() + { + for(size_t row=0;row<mMRows;row++) + memset(mpData + row*mStride, 0, sizeof(_ElemT)*mMCols); + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Matrix<_ElemT> & + Matrix<_ElemT>:: + Unit() + { + for(size_t row=0;row<std::min(mMRows,mMCols);row++){ + memset(mpData + row*mStride, 0, sizeof(_ElemT)*mMCols); + (*this)(row,row) = 1.0; + } + return *this; + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + void + Matrix<_ElemT>:: + PrintOut(char* file) + { + FILE* f = fopen(file, "w"); + unsigned i,j; + fprintf(f, "%dx%d\n", this->mMRows, this->mMCols); + + for(i=0; i<this->mMRows; i++) + { + _ElemT* row = (*this)[i]; + + for(j=0; j<this->mStride; j++){ + fprintf(f, "%20.17f ",row[j]); + } + fprintf(f, "\n"); + } + + fclose(f); + } + + + //**************************************************************************** + 
//**************************************************************************** + template<typename _ElemT> + void + Matrix<_ElemT>:: + ReadIn(char* file) + { + FILE* f = fopen(file, "r"); + int i = 0; + int j = 0; + fscanf(f, "%dx%d\n", &i,&j); + fprintf(stderr, "%dx%d\n", i,j); + + for(i=0; i<this->mMRows; i++) + { + _ElemT* row = (*this)[i]; + + for(j=0; j<this->mStride; j++){ + fscanf(f, "%f ",&row[j]); + } + //fprintf(f, "\n"); + } + + fclose(f); + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + void Save (std::ostream &rOut, const Matrix<_ElemT> &rM) + { + for (size_t i = 0; i < rM.Rows(); i++) { + for (size_t j = 0; j < rM.Cols(); j++) { + rOut << rM(i,j) << ' '; + } + rOut << '\n'; + } + if(rOut.fail()) + throw std::runtime_error("Failed to write matrix to stream"); + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + std::ostream & + operator << (std::ostream & rOut, const Matrix<_ElemT> & rM) + { + rOut << "m " << rM.Rows() << ' ' << rM.Cols() << '\n'; + Save(rOut, rM); + return rOut; + } + + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + void Load (std::istream & rIn, Matrix<_ElemT> & rM) + { + if(MatrixVectorIostreamControl::Flags(rIn, ACCUMULATE_INPUT)) { + for (size_t i = 0; i < rM.Rows(); i++) { + std::streamoff pos = rIn.tellg(); + for (size_t j = 0; j < rM.Cols(); j++) { + _ElemT tmp; + rIn >> tmp; + rM(i,j) += tmp; + if(rIn.fail()){ + throw std::runtime_error("Failed to read matrix from stream. 
File position is "+to_string(pos)); + } + } + } + } else { + for (size_t i = 0; i < rM.Rows(); i++) { + std::streamoff pos = rIn.tellg(); + for (size_t j = 0; j < rM.Cols(); j++) { + rIn >> rM(i,j); + if(rIn.fail()){ + throw std::runtime_error("Failed to read matrix from stream. File position is "+to_string(pos)); + } + + } + } + } + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + std::istream & + operator >> (std::istream & rIn, Matrix<_ElemT> & rM) + { + while(isascii(rIn.peek()) && isspace(rIn.peek())) rIn.get(); // eat up space. + if(rIn.peek() == 'm'){ // "new" format: m <nrows> <ncols> \n 1.0 0.2 4.3 ... + rIn.get();// eat up the 'm'. + long long int nrows=-1; rIn>>nrows; + long long int ncols=-1; rIn>>ncols; + if(rIn.fail()||nrows<0||ncols<0){ throw std::runtime_error("Failed to read matrix from stream: no size\n"); } + + size_t nrows2 = size_t(nrows), ncols2 = size_t(ncols); + assert((long long int)nrows2 == nrows && (long long int)ncols2 == ncols); + + if(rM.Rows()!=nrows2 || rM.Cols()!=ncols2) rM.Init(nrows2,ncols2); + } + Load(rIn,rM); + return rIn; + } + + + + //**************************************************************************** + //**************************************************************************** + // Constructor + template<typename _ElemT> + SubMatrix<_ElemT>:: + SubMatrix(const Matrix<_ElemT>& rT, // Matrix cannot be const because SubMatrix can change its contents. Would have to have a ConstSubMatrix or something... 
+ const size_t ro, + const size_t r, + const size_t co, + const size_t c) + { + assert(ro >= 0 && ro <= rT.Rows()); + assert(co >= 0 && co <= rT.Cols()); + assert(r > 0 && r <= rT.Rows() - ro); + assert(c > 0 && c <= rT.Cols() - co); + // point to the begining of window + Matrix<_ElemT>::mMRows = r; + Matrix<_ElemT>::mMCols = c; + Matrix<_ElemT>::mStride = rT.Stride(); + Matrix<_ElemT>::mpData = rT.pData_workaround() + co + ro * rT.Stride(); + } + + + +#ifdef HAVE_ATLAS + + template<> + Matrix<float> & + Matrix<float>:: + BlasGer(const float alpha, const Vector<float>& rA, const Vector<float>& rB); + + + template<> + Matrix<double> & + Matrix<double>:: + BlasGer(const double alpha, const Vector<double>& rA, const Vector<double>& rB); + + + template<> + Matrix<float>& + Matrix<float>:: + BlasGemm(const float alpha, + const Matrix<float>& rA, MatrixTrasposeType transA, + const Matrix<float>& rB, MatrixTrasposeType transB, + const float beta); + + template<> + Matrix<double>& + Matrix<double>:: + BlasGemm(const double alpha, + const Matrix<double>& rA, MatrixTrasposeType transA, + const Matrix<double>& rB, MatrixTrasposeType transB, + const double beta); + + template<> + Matrix<float>& + Matrix<float>:: + Axpy(const float alpha, + const Matrix<float>& rA, MatrixTrasposeType transA); + + template<> + Matrix<double>& + Matrix<double>:: + Axpy(const double alpha, + const Matrix<double>& rA, MatrixTrasposeType transA); + + template <> // non-member so automatic namespace lookup can occur. + double TraceOfProduct(const Matrix<double> &A, const Matrix<double> &B); + + template <> // non-member so automatic namespace lookup can occur. + double TraceOfProductT(const Matrix<double> &A, const Matrix<double> &B); + + template <> // non-member so automatic namespace lookup can occur. + float TraceOfProduct(const Matrix<float> &A, const Matrix<float> &B); + + template <> // non-member so automatic namespace lookup can occur. 
+ float TraceOfProductT(const Matrix<float> &A, const Matrix<float> &B); + + + +#else // HAVE_ATLAS + #error Routines in this section are not implemented yet without BLAS +#endif // HAVE_ATLAS + + template<class _ElemT> + bool + Matrix<_ElemT>:: + IsSymmetric(_ElemT cutoff) const { + size_t R=Rows(), C=Cols(); + if(R!=C) return false; + _ElemT bad_sum=0.0, good_sum=0.0; + for(size_t i=0;i<R;i++){ + for(size_t j=0;j<i;j++){ + _ElemT a=(*this)(i,j),b=(*this)(j,i), avg=0.5*(a+b), diff=0.5*(a-b); + good_sum += fabs(avg); bad_sum += fabs(diff); + } + good_sum += fabs((*this)(i,i)); + } + if(bad_sum > cutoff*good_sum) return false; + return true; + } + + template<class _ElemT> + bool + Matrix<_ElemT>:: + IsDiagonal(_ElemT cutoff) const{ + size_t R=Rows(), C=Cols(); + _ElemT bad_sum=0.0, good_sum=0.0; + for(size_t i=0;i<R;i++){ + for(size_t j=0;j<C;j++){ + if(i==j) good_sum += (*this)(i,j); + else bad_sum += (*this)(i,j); + } + } + return (!(bad_sum > good_sum * cutoff)); + } + + template<class _ElemT> + bool + Matrix<_ElemT>:: + IsUnit(_ElemT cutoff) const { + size_t R=Rows(), C=Cols(); + if(R!=C) return false; + _ElemT bad_sum=0.0; + for(size_t i=0;i<R;i++) + for(size_t j=0;j<C;j++) + bad_sum += fabs( (*this)(i,j) - (i==j?1.0:0.0)); + return (bad_sum <= cutoff); + } + + template<class _ElemT> + bool + Matrix<_ElemT>:: + IsZero(_ElemT cutoff)const { + size_t R=Rows(), C=Cols(); + _ElemT bad_sum=0.0; + for(size_t i=0;i<R;i++) + for(size_t j=0;j<C;j++) + bad_sum += fabs( (*this)(i,j) ); + return (bad_sum <= cutoff); + } + + template<class _ElemT> + _ElemT + Matrix<_ElemT>:: + FrobeniusNorm() const{ + size_t R=Rows(), C=Cols(); + _ElemT sum=0.0; + for(size_t i=0;i<R;i++) + for(size_t j=0;j<C;j++){ + _ElemT tmp = (*this)(i,j); + sum += tmp*tmp; + } + return sqrt(sum); + } + + template<class _ElemT> + _ElemT + Matrix<_ElemT>:: + LargestAbsElem() const{ + size_t R=Rows(), C=Cols(); + _ElemT largest=0.0; + for(size_t i=0;i<R;i++) + for(size_t j=0;j<C;j++) + largest = 
std::max(largest, (_ElemT)fabs((*this)(i,j))); + return largest; + } + + + + // Uses SVD to compute the eigenvalue decomposition of a symmetric positive semidefinite + // matrix: + // (*this) = rU * diag(rS) * rU^T, with rU an orthogonal matrix so rU^{-1} = rU^T. + // Does this by computing svd (*this) = U diag(rS) V^T ... answer is just U diag(rS) U^T. + // Throws exception if this failed to within supplied precision (typically because *this was not + // symmetric positive definite). + + + + template<class _ElemT> + _ElemT + Matrix<_ElemT>:: + LogAbsDeterminant(_ElemT *DetSign){ + _ElemT LogDet; + Matrix<_ElemT> tmp(*this); + tmp.Invert(&LogDet, DetSign, false); // false== output not needed (saves some computation). + return LogDet; + } + +}// namespace TNet + +// #define TNet_Matrix_tcc +#endif diff --git a/htk_io/src/KaldiLib/MlfStream.cc b/htk_io/src/KaldiLib/MlfStream.cc new file mode 100644 index 0000000..a2f6478 --- /dev/null +++ b/htk_io/src/KaldiLib/MlfStream.cc @@ -0,0 +1,268 @@ +#include "MlfStream.h" +#include "Common.h" +#include "Error.h" + + +namespace TNet +{ + //****************************************************************************** + LabelContainer:: + ~LabelContainer() + { + while (!this->mLabelList.empty()) + { + delete this->mLabelList.back(); + this->mLabelList.pop_back(); + } + } + + //****************************************************************************** + size_t + LabelContainer:: + DirDepth(const std::string & rPath) + { + size_t depth = 0; + size_t length = rPath.length(); + const char * s = rPath.c_str(); + + for (size_t i = 0; i < length; i++) + { + if (*s == '/' || *s == '\\') + { + depth++; + } + s++; + } + return depth; + } + + + //****************************************************************************** + void + LabelContainer:: + Insert(const std::string & rLabel, + std::streampos Pos) + { + LabelRecord ls; + size_t depth; + LabelRecord tmp_ls; + + // we need to compute the depth of the label path if + // 
wildcard is used + // do we have a wildcard??? + if (rLabel[0] == '*') + { + depth = this->DirDepth(rLabel); + } + else + { + depth = MAX_LABEL_DEPTH; + } + + // perhaps we want to store the depth of the path in the label for the wildcards + // to work + this->mDepths.insert(depth); + + // store the values + ls.mStreamPos = Pos; + ls.miLabelListLimit = mLabelList.end(); + + + if (mLabelList.begin() != mLabelList.end()) { + ls.miLabelListLimit--; + } + + // if no wildcard chars, then we try to store in hash, otherwise store in + // list + if (rLabel.find_first_of("*?%",1) == rLabel.npos) + { + if (!Find(rLabel, tmp_ls)) + { + // look in the + this->mLabelMap[rLabel] = ls; + } + else { + ; + //Warning("More general definition found when inserting " + rLabel + " ... label: " + MatchedPattern()); + } + } + else + { + this->mLabelList.push_back(new std::pair<std::string,LabelRecord>(rLabel, ls)); + } + } + + + //****************************************************************************** + bool + LabelContainer:: + FindInHash(const std::string & rLabel, LabelRecord & rLS) + { + bool found = false; + + std::string str; + + // current depth within the str + DepthType current_depth = MAX_LABEL_DEPTH; + + // current search position within the str + size_t prev = rLabel.size() + 1; + + // we will walk through the set depts bacwards so we begin at the end and move + // to the front... 
+ std::set<DepthType>::reverse_iterator ri (this->mDepths.end()); + std::set<DepthType>::reverse_iterator rlast (this->mDepths.begin()); + LabelHashType::iterator lab; + + // we perform the search until we run to the end of the set or we find something + while ((!found) && (ri != rlast)) + { + // we don't need to do anything with the string if the depth is set to + // max label depth since it contains no * + if (*ri == MAX_LABEL_DEPTH) + { + found = ((lab=this->mLabelMap.find(rLabel)) != this->mLabelMap.end()); + if (found) str = rLabel; + } + // we will crop the string and put * in the begining and try to search + else + { + // we know that we walk backwards in the depths, so we need to first find + // the last / and + if (current_depth == MAX_LABEL_DEPTH) + { + if (*ri > 0) + { + // we find the ri-th / from back + for (DepthType i=1; (i <= *ri) && (prev != rLabel.npos); i++) + { + prev = rLabel.find_last_of("/\\", prev-1); + } + } + else + { + prev = 0; + } + + // check if finding succeeded (prev == str.npos => failure, see STL) + if (prev != rLabel.npos) + { + // construct the new string beign sought for + str.assign(rLabel, prev, rLabel.size()); + str = '*' + str; + + // now we try to find + found = ((lab=this->mLabelMap.find(str)) != this->mLabelMap.end()); + + // say, that current depth is *ri + current_depth = *ri; + } + else + { + prev = rLabel.size() + 1; + } + } // if (current_depth == MAX_LABEL_DEPTH) + else + { + // now we know at which / we are from the back, so we search forward now + // and we need to reach the ri-th / + while (current_depth > *ri) + { + // we try to find next / + if ((prev = rLabel.find_first_of("/\\", prev+1)) != rLabel.npos) + current_depth--; + else + return false; + } + + // construct the new string beign sought for + str.assign(rLabel, prev, rLabel.size()); + str = '*' + str; + + // now we try to find + found = ((lab=this->mLabelMap.find(str)) != this->mLabelMap.end()); + } + } + + // move one element further (jump to next 
observed depth) + ri++; + } // while (run) + + // some debug info + if (found) + { + rLS = lab->second; + this->mMatchedPattern = str; + } + + return found; + } + + + //****************************************************************************** + bool + LabelContainer:: + FindInList(const std::string & rLabel, LabelRecord & rLS, bool limitSearch) + { + + bool found = false; + std::string str; + LabelListType::iterator lab = mLabelList.begin(); + LabelListType::iterator limit; + + if (limitSearch && (rLS.miLabelListLimit != mLabelList.end())) + { + limit = rLS.miLabelListLimit; + limit++; + } + else + { + limit = this->mLabelList.end(); + } + + // we perform sequential search until we run to the end of the list or we find + // something + while ((!found) && (lab != limit)) + { + if (ProcessMask(rLabel, (*lab)->first, str)) + { + found = true; + } + else + { + lab++; + } + } // while (run) + + // some debug info + if (found) + { + rLS = (*lab)->second; + this->mMatchedPattern = (*lab)->first; + this->mMatchedPatternMask = str; + } + return found; + } + + + //****************************************************************************** + bool + LabelContainer:: + Find(const std::string & rLabel, LabelRecord & rLS) + { + // try to find the label in the Hash + if (FindInHash(rLabel, rLS)) + { + // we look in the list, but we limit the search. + FindInList(rLabel, rLS, true); + return true; + } //if (this->mLabelContainer.FindInHash(rLabel, label_stream)) + else + { + // we didn't find it in the hash so we look in the list + return FindInList(rLabel, rLS); + } + } + +} // namespace TNet + diff --git a/htk_io/src/KaldiLib/MlfStream.h b/htk_io/src/KaldiLib/MlfStream.h new file mode 100644 index 0000000..d643f5c --- /dev/null +++ b/htk_io/src/KaldiLib/MlfStream.h @@ -0,0 +1,639 @@ +/** @file MlfStream.h + * This is an TNet C++ Library header. 
+ * + * The naming convention in this file coppies the std::* naming as well as STK + */ + + +#ifndef STK_MlfStream_h +#define STK_MlfStream_h + +#include <iostream> +#include <vector> +#include <map> +#include <list> +#include <set> + + +namespace TNet +{ + class LabelRecord; + class LabelContainer; + + + /// this container stores the lables in linear order as they came + /// i.e. they cannot be hashed + typedef std::list< std::pair<std::string,LabelRecord> *> LabelListType; + + /// type of the container used to store the labels + typedef std::map<std::string, LabelRecord> LabelHashType; + + + + /** + * @brief Describes type of MLF definition + * + * See HTK book for MLF structure. Terms used in TNet are + * compatible with those in HTK book. + */ + enum MlfDefType + { + MLF_DEF_UNKNOWN = 0, ///< unknown definition + MLF_DEF_IMMEDIATE_TRANSCRIPTION, ///< immediate transcription + MLF_DEF_SUB_DIR_DEF ///< subdirectory definition + }; + + + + /** ************************************************************************** + * @brief Holds association between label and stream + */ + class LabelRecord + { + + public: + LabelRecord() : miLabelListLimit(NULL) + { } + + ~LabelRecord() + { } + + /// definition type + MlfDefType mDefType; + + /// position of the label in the stream + std::streampos mStreamPos; + + /** + * @brief points to the current end of the LabelList + * + * The reason for storing this value is to know when we inserted + * a label into the hash. It is possible, that the hash label came + * after list label, in which case the list label is prefered + */ + LabelListType::iterator miLabelListLimit; + + }; + + + + + /** + * @brief Provides an interface to label hierarchy and searching + * + * This class stores label files in a map structure. When a wildcard + * convence is used, the class stores the labels in separate maps according + * to level of wildcard abstraction. By level we mean the directory structure + * depth. 
+ */ + class LabelContainer + { + public: + /// The constructor + LabelContainer() : mUseHashedSearch(true) {} + + /// The destructor + ~LabelContainer(); + + /** + * @brief Inserts new label to the hash structure + */ + void + Insert( + const std::string & rLabel, + std::streampos Pos); + + + /** + * @brief Looks for a record in the hash + */ + bool + FindInHash( + const std::string& rLabel, + LabelRecord& rLS); + + /** + * @brief Looks for a record in the list + * @param rLabel Label to look for + * @param rLS Structure to fill with found data + * @param limitSearch If true @p rLS's @c mLabelListLimit gives the limiting position in the list + */ + bool + FindInList( + const std::string& rLabel, + LabelRecord& rLS, + bool limitSearch = false); + + /** + * @brief Looks for a record + */ + bool + Find( + const std::string & rLabel, + LabelRecord & rLS); + + /** + * @brief Returns the matched pattern + */ + const std::string & + MatchedPattern() const + { + return mMatchedPattern; + } + + /** + * @brief Returns the matched pattern mask (%%%) + */ + const std::string & + MatchedPatternMask() const + { + return mMatchedPatternMask; + } + + /** + * @brief Writes contents to stream (text) + * @param rOStream stream to write to + */ + void + Write(std::ostream& rOStream); + + private: + /// type used for directory depth notation + typedef size_t DepthType; + + + /// this set stores depths of * labels observed at insertion + std::set<DepthType> mDepths; + + /// stores the labels + LabelHashType mLabelMap; + LabelListType mLabelList; + + /// true if labels are to be sought by hashing function (fast) or by + /// sequential search (slow) + bool mUseHashedSearch; + + /// if Find matches the label, this var stores the pattern that matched the + /// query + std::string mMatchedPattern; + + /// if Find matches the label, this var stores the the masked characters. 
+ /// The mask is given by '%' symbols + std::string mMatchedPatternMask; + + /** + * @brief Returns the directory depth of path + */ + size_t + DirDepth(const std::string & path); + + + }; + + + /** + * @brief MLF output buffer definition + */ + template< + typename _CharT, + typename _Traits = std::char_traits<_CharT>, + typename _CharTA = std::allocator<_CharT>, + typename ByteT = char, + typename ByteAT = std::allocator<ByteT> + > + class BasicOMlfStreamBuf + : public std::basic_streambuf<_CharT, _Traits> + { + public: + // necessary typedefs .................................................... + typedef BasicOMlfStreamBuf<_CharT,_Traits,_CharTA,ByteT,ByteAT> + this_type; + typedef std::basic_ostream<_CharT, _Traits>& + OStreamReference; + typedef std::basic_streambuf<_CharT, _Traits> + StreamBufType; + typedef _CharTA char_allocator_type; + typedef _CharT char_type; + typedef typename _Traits::int_type int_type; + typedef typename _Traits::pos_type pos_type; + typedef ByteT byte_type; + typedef ByteAT byte_allocator_type; + typedef byte_type* byte_buffer_type; + typedef std::vector<byte_type, byte_allocator_type > byte_vector_type; + typedef std::vector<char_type, char_allocator_type > char_vector_type; + + + BasicOMlfStreamBuf(OStreamReference rOStream, size_t bufferSize); + + ~BasicOMlfStreamBuf(); + + // virtual functions inherited from basic_streambuf....................... + int + sync(); + + /** + * @brief Write character in the case of overflow + * @param c Character to be written. + * @return A value different than EOF (or traits::eof() for other traits) + * signals success. If the function fails, either EOF + * (or traits::eof() for other traits) is returned or an + * exception is thrown. + */ + int_type + overflow(int_type c = _Traits::eof()); + + + // MLF specific functions ................................................ 
+ /** + * @brief Creates a new MLF block + * @param rFileName filename to be opened + */ + this_type* + Open(const std::string& rFileName); + + /** + * @brief Closes MLF block + */ + void + Close(); + + /** + * @brief Returns true if the MLF is now in open state + */ + bool + IsOpen() const + { return mIsOpen; } + + LabelContainer& + rLabels() + { return mLabels; } + + private: + bool mIsOpen; + char_type mLastChar; + OStreamReference mOStream; + LabelContainer mLabels; + }; // class BasicOMlfStreamBuf + + + + /** + * @brief MLF input buffer definition + */ + template< + typename _CharT, + typename _Traits = std::char_traits<_CharT>, + typename _CharTA = std::allocator<_CharT>, + typename ByteT = char, + typename ByteAT = std::allocator<ByteT> + > + class BasicIMlfStreamBuf + : public std::basic_streambuf<_CharT, _Traits> + { + private: + // internal automaton states + static const int IN_HEADER_STATE = 0; + static const int OUT_OF_BODY_STATE = 1; + static const int IN_TITLE_STATE = 2; + static const int IN_BODY_STATE = 3; + + + public: // necessary typedefs .............................................. + typedef BasicIMlfStreamBuf<_CharT,_Traits,_CharTA,ByteT,ByteAT> + this_type; + typedef std::basic_istream<_CharT, _Traits>& IStreamReference; + typedef std::basic_streambuf<_CharT, _Traits> + StreamBufType; + typedef _CharTA char_allocator_type; + typedef _CharT char_type; + typedef typename _Traits::int_type int_type; + typedef typename _Traits::pos_type pos_type; + typedef ByteT byte_type; + typedef ByteAT byte_allocator_type; + typedef byte_type* byte_buffer_type; + typedef std::vector<byte_type, byte_allocator_type > byte_vector_type; + typedef std::vector<char_type, char_allocator_type > char_vector_type; + + + public: + // constructors and destructors .......................................... 
+ BasicIMlfStreamBuf(IStreamReference rIStream, size_t bufferSize = 1024); + + ~BasicIMlfStreamBuf(); + + // virtual functions inherited from basic_streambuf....................... + /** + * @brief Get character in the case of underflow + * + * @return The new character available at the get pointer position, if + * any. Otherwise, traits::eof() is returned. + */ + int_type + underflow(); + + + // MLF specific functions ................................................ + /** + * @brief Creates a new MLF block + * @param rFileName filename to be opened + */ + this_type* + Open(const std::string& rFileName); + + /** + * @brief Closes MLF block + */ + this_type* + Close(); + + /** + * @brief Returns true if the MLF is now in open state + */ + bool + IsOpen() const + { return mIsOpen; } + + /** + * @brief Parses the stream (if possible) and stores positions to the + * label titles + */ + void + Index(); + + bool + IsHashed() const + { return mIsHashed; } + + /** + * @brief Jumps to next label definition + * @param rName std::string to be filled with the label name + * @return true on success + * + * The procedure automatically tries to hash the labels. + */ + bool + JumpToNextDefinition(std::string& rName); + + /** + * @brief Returns reference to the base stream + * @return reference to the stream + * + */ + IStreamReference + GetBaseStream() + { + return mIStream; + } + + private: // auxillary functions ............................................ + /** + * @brief Fills the line buffer with next line and updates the internal + * state of the finite automaton + */ + void + FillLineBuffer(); + + + private: // atributes ...................................................... 
+ // some flags + bool mIsOpen; + bool mIsHashed; + bool mIsEof; + + /// internal state of the finite automaton + int mState; + + IStreamReference mIStream; + LabelContainer mLabels; + + std::vector<char_type> mLineBuffer; + }; // class BasicIMlfStreamBuf + + + + + /** + * @brief Base class with type-independent members for the Mlf Output + * Stram class + * + * This is a derivative of the basic_ios class. We derive it as we need + * to override some member functions + */ + template< + typename Elem, + typename Tr = std::char_traits<Elem>, + typename ElemA = std::allocator<Elem>, + typename ByteT = char, + typename ByteAT = std::allocator<ByteT> + > + class BasicOMlfStreamBase + : virtual public std::basic_ios<Elem,Tr> + { + public: + typedef std::basic_ostream<Elem, Tr>& OStreamReference; + typedef BasicOMlfStreamBuf < + Elem,Tr,ElemA,ByteT,ByteAT> OMlfStreamBufType; + + /** + * @brief constructor + * + * @param rOStream user defined output stream + */ + BasicOMlfStreamBase(OStreamReference rOStream, + size_t bufferSize) + : mBuf(rOStream, bufferSize) + { this->init(&mBuf); }; + + /** + * @brief Returns a pointer to the buffer object for this stream + */ + OMlfStreamBufType* + rdbuf() + { return &mBuf; }; + + private: + OMlfStreamBufType mBuf; + }; + + + template< + typename Elem, + typename Tr = std::char_traits<Elem>, + typename ElemA = std::allocator<Elem>, + typename ByteT = char, + typename ByteAT = std::allocator<ByteT> + > + class BasicIMlfStreamBase + : virtual public std::basic_ios<Elem,Tr> + { + public: + typedef std::basic_istream<Elem, Tr>& IStreamReference; + typedef BasicIMlfStreamBuf < + Elem,Tr,ElemA,ByteT,ByteAT> IMlfStreamBufType; + + BasicIMlfStreamBase( IStreamReference rIStream, + size_t bufferSize) + : mBuf(rIStream, bufferSize) + { this->init(&mBuf ); }; + + IMlfStreamBufType* + rdbuf() + { return &mBuf; }; + + IStreamReference + GetBaseStream() + { return mBuf.GetBaseStream(); } + + private: + IMlfStreamBufType mBuf; + }; + + + template< + 
typename Elem, + typename Tr = std::char_traits<Elem>, + typename ElemA = std::allocator<Elem>, + typename ByteT = char, + typename ByteAT = std::allocator<ByteT> + > + class BasicOMlfStream + : public BasicOMlfStreamBase<Elem,Tr,ElemA,ByteT,ByteAT>, + public std::basic_ostream<Elem,Tr> + { + public: + typedef BasicOMlfStreamBase< Elem,Tr,ElemA,ByteT,ByteAT> + BasicOMlfStreamBaseType; + typedef std::basic_ostream<Elem,Tr> OStreamType; + typedef OStreamType& OStreamReference; + + BasicOMlfStream(OStreamReference rOStream, size_t bufferSize = 32) + : BasicOMlfStreamBaseType(rOStream, bufferSize), + OStreamType(BasicOMlfStreamBaseType::rdbuf()) + { } + + /** + * @brief Destructor closes the stream + */ + ~BasicOMlfStream() + { } + + + /** + * @brief Creates a new MLF block + * @param rFileName filename to be opened + */ + void + Open(const std::string& rFileName) + { BasicOMlfStreamBaseType::rdbuf()->Open(rFileName); } + + /** + * @brief Closes MLF block + */ + void + Close() + { BasicOMlfStreamBaseType::rdbuf()->Close(); } + + /** + * @brief Returns true if the MLF is now in open state + */ + bool + IsOpen() const + { return BasicOMlfStreamBaseType::rdbuf()->IsOpen(); } + + /** + * @brief Accessor to the label container + * @return Reference to the label container + */ + LabelContainer& + rLabels() + { return BasicOMlfStreamBaseType::rdbuf()->rLabels(); } + }; + + + + template< + typename Elem, + typename Tr = std::char_traits<Elem>, + typename ElemA = std::allocator<Elem>, + typename ByteT = char, + typename ByteAT = std::allocator<ByteT> + > + class BasicIMlfStream + : public BasicIMlfStreamBase<Elem,Tr,ElemA,ByteT,ByteAT>, + public std::basic_istream<Elem,Tr> + { + public: + typedef BasicIMlfStreamBase <Elem,Tr,ElemA,ByteT,ByteAT> + BasicIMlfStreamBaseType; + typedef std::basic_istream<Elem,Tr> IStreamType; + typedef IStreamType& IStreamReference; + typedef unsigned char byte_type; + + BasicIMlfStream(IStreamReference rIStream, size_t bufferSize = 32) + : 
BasicIMlfStreamBaseType(rIStream, bufferSize), + IStreamType(BasicIMlfStreamBaseType::rdbuf()) + {}; + + + /** + * @brief Creates a new MLF block + * @param rFileName filename to be opened + */ + void + Open(const std::string& rFileName) + { + std::basic_streambuf<Elem, Tr>* p_buf; + + p_buf = BasicIMlfStreamBaseType::rdbuf()->Open(rFileName); + + if (NULL == p_buf) { + IStreamType::clear(IStreamType::rdstate() | std::ios::failbit); + } + else { + IStreamType::clear(); + } + } + + /** + * @brief Closes MLF block. + * In fact, nothing is done + */ + void + Close() + { + if (NULL == BasicIMlfStreamBaseType::rdbuf()->Close()) { + IStreamType::clear(IStreamType::rdstate() | std::ios::failbit); + } + } + + void + Index() + { BasicIMlfStreamBaseType::rdbuf()->Index(); } + + bool + IsHashed() const + { return BasicIMlfStreamBaseType::rdbuf()->IsHashed(); } + + }; + + + + // MAIN TYPEDEFS.............................................................. + typedef BasicOMlfStream<char> OMlfStream; + typedef BasicOMlfStream<wchar_t> WOMlfStream; + typedef BasicIMlfStream<char> IMlfStream; + typedef BasicIMlfStream<wchar_t> WIMlfStream; + + +#ifdef PATH_MAX + const size_t MAX_LABEL_DEPTH = PATH_MAX; +#else + const size_t MAX_LABEL_DEPTH = 1024; +#endif + + +} // namespace TNet + +#include "MlfStream.tcc" + +#endif diff --git a/htk_io/src/KaldiLib/MlfStream.tcc b/htk_io/src/KaldiLib/MlfStream.tcc new file mode 100644 index 0000000..8978545 --- /dev/null +++ b/htk_io/src/KaldiLib/MlfStream.tcc @@ -0,0 +1,517 @@ +#ifndef STK_MlfStream_tcc +#define STK_MlfStream_tcc + +#include <algorithm> + +#include "Common.h" +#include "StkMatch.h" + +namespace TNet +{ + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + BasicOMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, 
ByteAT>:: + BasicOMlfStreamBuf(OStreamReference rOStream, size_t bufferSize) + : mIsOpen(false), mOStream(rOStream) + { } + + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + BasicOMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + ~BasicOMlfStreamBuf() + { + mOStream.flush(); + } + + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + int + BasicOMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + sync() + { + mOStream.flush(); + return 0; + } + + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + typename _Traits::int_type + BasicOMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + overflow(typename _Traits::int_type c) + { + // we don't use buffer here... 
+ if (mIsOpen) { + if (_Traits::eof() == c) { + return _Traits::not_eof(c); + } + // only pass the character to the stream + mOStream.rdbuf()->sputc(c); + + // remember last char (in case we want to close) + mLastChar = c; + + return c; + } + else { + return _Traits::eof(); + } + } + + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + void + BasicOMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + Close() + { + // if last character was not EOL, we need to insert it + if (mLastChar != '\n') { + mOStream.put('\n'); + } + mOStream << ".\n"; + + // flush the stream and declare the stream closed + mOStream.flush(); + mIsOpen = false; + } + + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + BasicOMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT> * + BasicOMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + Open(const std::string& rFileName) + { + // retreive position + std::streampos pos = mOStream.tellp(); + + // write the initial "filename" in parantheses + mOStream << '"' << rFileName << '"' << std::endl; + mLastChar = '\n'; + + // return NULL if we canot open + if (!mOStream.good()) { + return NULL; + } + + // if ok, store the name position + if (-1 != pos) { + pos = mOStream.tellp(); + mLabels.Insert(rFileName, pos); + } + + // set open flag and return this + mIsOpen = true; + return this; + } + + + //**************************************************************************** + //**************************************************************************** + // BasicIMlfStreamBuf section + // + 
//**************************************************************************** + //**************************************************************************** + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + BasicIMlfStreamBuf(IStreamReference rIStream, size_t bufferSize) + : mIsOpen(false), mIsHashed(false), mIsEof(true), mState(IN_HEADER_STATE), + mIStream(rIStream), mLineBuffer() + { + // we reserve some place for the buffer... + mLineBuffer.reserve(bufferSize); + + //StreamBufType::setg(mpBuffer, mpBuffer + bufferSize, mpBuffer + bufferSize); + StreamBufType::setg(&(mLineBuffer.front()), &(mLineBuffer.back()), &(mLineBuffer.back())); + } + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + ~BasicIMlfStreamBuf() + { + } + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + void + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + Index() + { + // retreive position + std::streampos orig_pos = mIStream.tellg(); + int orig_state = mState; + + // for streams like stdin, pos will by definition be -1, so we can only + // rely on sequential access and cannot hash it. + if (-1 != orig_pos) { + std::string aux_name; + // we will constantly jump to next definition. 
the function automatically + // hashes the stream if possible + while (JumpToNextDefinition(aux_name)) + { } + + // move to the original position + mIStream.clear(); + mIStream.seekg(orig_pos); + mState = orig_state; + + // set as hashed + mIsHashed=true; + } + } + + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + bool + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + JumpToNextDefinition(std::string& rName) + { + if (!mIStream.good()) { + return false; + } + + // if we can, we will try to index the label + std::streampos pos = mIStream.tellg(); + + // we might be at a definition already, so first move one line further + FillLineBuffer(); + + // read lines till we get to definition again + while (mIStream.good() && mState != IN_TITLE_STATE) { + FillLineBuffer(); + } + + // decide what happened + if (IN_TITLE_STATE == mState) { + // if we can, we will try to index the label + pos = mIStream.tellg(); + + if (pos != static_cast<const std::streampos>(-1)) { + // if (pos !=std::string::npos) { // This line does not work under MinGW + std::string line_buffer(mLineBuffer.begin(), mLineBuffer.end()); + TNet::ParseHTKString(line_buffer, rName); + mLabels.Insert(rName, pos); + } + + return true; + } + else { + // we have been hashing all the way through so we know that if this is + // is the EOF, we are done hashing this stream + if (pos != static_cast<const std::streampos>(-1)) { + mIsHashed = true; + } + + // we are not in body state, so we just return false + return false; + } + } + + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + 
typename ByteAT + > + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>* + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + Close() + { + if (!mIsOpen) { + mIsEof = true; + return NULL; + } + else { + // if we try to close while in the body, we need to reach the end + if (mState == IN_BODY_STATE) { + while (mState == IN_BODY_STATE) { + FillLineBuffer(); + } + } + + // disable buffer mechanism + StreamBufType::setg(&(mLineBuffer.front()), &(mLineBuffer.front()), + &(mLineBuffer.front())); + + mIsEof = true; + mIsOpen = false; + + return this; + } + } + + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>* + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + Open(const std::string& rFileName) + { + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>* ret_val = NULL; + + // this behavior is compatible with ifstream + if (mIsOpen) { + Close(); + return NULL; + } + + // retreive position + std::streampos pos = mIStream.tellg(); + LabelRecord label_record; + + // for streams like stdin, pos will by definition be -1, so we can only + // rely on sequential access. 
At this place, we decide what to do + if ((-1 != pos) && (mLabels.Find(rFileName, label_record))) { + mIStream.seekg(label_record.mStreamPos); + mState = IN_TITLE_STATE; + + // we don't want the other stream to be bad, so we transfer the + // flagbits to this stream + if (!mIStream.good()) { + mIStream.clear(); + mIsOpen = false; + ret_val = NULL; + } + else { + mIsOpen = true; + mIsEof = false; + ret_val = this; + } + } + + // we don't have sequential stream and we didn't find the label, but + // we are hashed, so we can be sure, that we failed + else if ((-1 != pos) && mIsHashed) { + mIsOpen = false; + ret_val = NULL; + } + + // we either have sequential stream or didn't find anything, but we can + // still try to sequentially go and look for it + else { + bool found = false; + std::string aux_name; + std::string aux_name2; + + while ((!found) && JumpToNextDefinition(aux_name)) { + if (TNet::ProcessMask(rFileName, aux_name, aux_name2)) { + mIsOpen = true; + mIsEof = false; + found = true; + ret_val = this; + } + } + + if (!found) { + mIsOpen = false; + ret_val = NULL; + } + } + + return ret_val; + } + + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + typename _Traits::int_type + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + underflow() + { + // we don't do anything if EOF + if (mIsEof) { + StreamBufType::setg(&(mLineBuffer.front()), &(mLineBuffer.front()), + &(mLineBuffer.front())); + return _Traits::eof(); + } + + // read from buffer if we can + if (StreamBufType::gptr() && (StreamBufType::gptr() < StreamBufType::egptr())) { + return _Traits::not_eof(*StreamBufType::gptr()); + } + + // might happen that stream is in !good state + if (!mIStream.good()) { + mIsEof = true; + StreamBufType::setg(&(mLineBuffer.front()), 
&(mLineBuffer.front()), + &(mLineBuffer.front())); + return _Traits::eof(); + } + + // fill the line buffer and update my state + FillLineBuffer(); + + // if the whole line is just period or it's eof, declare EOF + if (mState == OUT_OF_BODY_STATE) { + mIsEof = true; + StreamBufType::setg(&(mLineBuffer.front()), &(mLineBuffer.front()), + &(mLineBuffer.front())); + return _Traits::eof(); + } + + // restore the buffer mechanism + StreamBufType::setg(&(mLineBuffer.front()), &(mLineBuffer.front()), + &(mLineBuffer.back()) + 1); + + return *StreamBufType::gptr(); + } + + + //**************************************************************************** + //**************************************************************************** + template< + typename _CharT, + typename _Traits, + typename _CharTA, + typename ByteT, + typename ByteAT + > + void + BasicIMlfStreamBuf<_CharT, _Traits, _CharTA, ByteT, ByteAT>:: + FillLineBuffer() + { + // reset line buffer + size_t capacity = mLineBuffer.capacity(); + mLineBuffer.clear(); + mLineBuffer.reserve(capacity); + + // read one line into buffer + int c; + while ((c = mIStream.get()) != '\n' && c != _Traits::eof()) { + mLineBuffer.push_back(char(c)); + } + + // we want to be able to pass last eol symbol + if (c == '\n') { + mLineBuffer.push_back(char(c)); + } + + // we will decide where we are + switch (mState) { + case IN_HEADER_STATE: + + case OUT_OF_BODY_STATE: + if (mLineBuffer[0] != '#') { + mState = IN_TITLE_STATE; + } + break; + + case IN_TITLE_STATE: + if (mLineBuffer[0] == '.' && (mLineBuffer.back() == '\n' || mIStream.eof())) { + mState = OUT_OF_BODY_STATE; + } + else { + mState = IN_BODY_STATE; + } + break; + + case IN_BODY_STATE: + // period or EOF will end the file + if (mLineBuffer[0] == '.' 
&& (mLineBuffer.back() == '\n' || mIStream.eof())) { + mState = OUT_OF_BODY_STATE; + } + if (mLineBuffer.size() == 0) { + mState = OUT_OF_BODY_STATE; + } + break; + } + } +} // namespace TNet + + +#endif // STK_MlfStream_tcc diff --git a/htk_io/src/KaldiLib/StkMatch.cc b/htk_io/src/KaldiLib/StkMatch.cc new file mode 100644 index 0000000..4ff4b18 --- /dev/null +++ b/htk_io/src/KaldiLib/StkMatch.cc @@ -0,0 +1,582 @@ +/* + EPSHeader + + File: filmatch.c + Author: J. Kercheval + Created: Thu, 03/14/1991 22:22:01 +*/ + +/* + EPSRevision History + O. Glembek Thu, 03/11/2005 01:58:00 Added Mask extraction support (char % does this) + J. Kercheval Wed, 02/20/1991 22:29:01 Released to Public Domain + J. Kercheval Fri, 02/22/1991 15:29:01 fix '\' bugs (two :( of them) + J. Kercheval Sun, 03/10/1991 19:31:29 add error return to matche() + J. Kercheval Sun, 03/10/1991 20:11:11 add is_valid_pattern code + J. Kercheval Sun, 03/10/1991 20:37:11 beef up main() + J. Kercheval Tue, 03/12/1991 22:25:10 Released as V1.1 to Public Domain + J. Kercheval Thu, 03/14/1991 22:22:25 remove '\' for DOS file parsing + J. Kercheval Thu, 03/28/1991 20:58:27 include filmatch.h +*/ + +/* + Wildcard Pattern Matching +*/ + + +#include "StkMatch.h" +#include "Common.h" + +namespace TNet +{ + //#define TEST + static int matche_after_star (register const char *pattern, register const char *text, register char *s); + // following function is not defined or used. 
// static int fast_match_after_star (register const char *pattern, register const char *text);

/*----------------------------------------------------------------------------
 *
 * Return true if PATTERN contains at least one of the wildcard characters
 * '?', '*', '%' or '['.
 *
 ----------------------------------------------------------------------------*/

bool is_pattern (const char *p)
{
  for ( ; *p != '\0'; ++p ) {
    const char c = *p;

    if ( c == '?' || c == '*' || c == '%' || c == '[' ) {
      return true;
    }
  }

  return false;
}
construct */ + while ( *p != ']' ) + { + /* check for literal escape */ + if ( *p == '\\' ) + { + p++; + + /* if end of pattern here then bad pattern */ + if ( !*p++ ) { + *error_type = PATTERN_ESC; + return false; + } + } + else + p++; + + /* if end of pattern here then bad pattern */ + if ( !*p ) + { + *error_type = PATTERN_CLOSE; + return false; + } + + /* if this a range */ + if ( *p == '-' ) + { + /* we must have an end of range */ + if ( !*++p || *p == ']' ) + { + *error_type = PATTERN_RANGE; + return false; + } + else + { + + /* check for literal escape */ + if ( *p == '\\' ) + p++; + + /* if end of pattern here then bad pattern */ + if ( !*p++ ) + { + *error_type = PATTERN_ESC; + return false; + } + } + } + } + break; + } //case '[': + + + /* all other characters are valid pattern elements */ + case '*': + case '?': + case '%': + default: + p++; /* "normal" character */ + break; + } // switch ( *p ) + } // while ( *p ) + + return true; + } //bool is_valid_pattern (const char *p, int *error_type) + + + /*---------------------------------------------------------------------------- + * + * Match the pattern PATTERN against the string TEXT; + * + * returns MATCH_VALID if pattern matches, or an errorcode as follows + * otherwise: + * + * MATCH_PATTERN - bad pattern + * MATCH_RANGE - match failure on [..] construct + * MATCH_ABORT - premature end of text string + * MATCH_END - premature end of pattern string + * MATCH_VALID - valid match + * + * + * A match means the entire string TEXT is used up in matching. + * + * In the pattern string: + * `*' matches any sequence of characters (zero or more) + * `?' matches any character + * `%' matches any character and stores it in the s string + * [SET] matches any character in the specified set, + * [!SET] or [^SET] matches any character not in the specified set. 
+ * \ is allowed within a set to escape a character like ']' or '-' + * + * A set is composed of characters or ranges; a range looks like + * character hyphen character (as in 0-9 or A-Z). [0-9a-zA-Z_] is the + * minimal set of characters allowed in the [..] pattern construct. + * Other characters are allowed (ie. 8 bit characters) if your system + * will support them. + * + * To suppress the special syntactic significance of any of `[]*?%!^-\', + * within a [..] construct and match the character exactly, precede it + * with a `\'. + * + ----------------------------------------------------------------------------*/ + + int matche ( register const char *p, register const char *t, register char *s ) + { + register char range_start, range_end; /* start and end in range */ + + bool invert; /* is this [..] or [!..] */ + bool member_match; /* have I matched the [..] construct? */ + bool loop; /* should I terminate? */ + + for ( ; *p; p++, t++ ) { + + /* if this is the end of the text then this is the end of the match */ + if (!*t) { + return ( *p == '*' && *++p == '\0' ) ? MATCH_VALID : MATCH_ABORT; + } + + /* determine and react to pattern type */ + switch ( *p ) { + + /* single any character match */ + case '?': + break; + + /* single any character match, with extraction*/ + case '%': { + *s++ = *t; + *s = '\0'; + break; + } + + /* multiple any character match */ + case '*': + return matche_after_star (p, t, s); + + /* [..] construct, single member/exclusion character match */ + case '[': { + /* move to beginning of range */ + p++; + + /* check if this is a member match or exclusion match */ + invert = false; + if ( *p == '!' 
|| *p == '^') { + invert = true; + p++; + } + + /* if closing bracket here or at range start then we have a + malformed pattern */ + if ( *p == ']' ) { + return MATCH_PATTERN; + } + + member_match = false; + loop = true; + + while ( loop ) { + + /* if end of construct then loop is done */ + if (*p == ']') { + loop = false; + continue; + } + + /* matching a '!', '^', '-', '\' or a ']' */ + if ( *p == '\\' ) { + range_start = range_end = *++p; + } + else { + range_start = range_end = *p; + } + + /* if end of pattern then bad pattern (Missing ']') */ + if (!*p) + return MATCH_PATTERN; + + /* check for range bar */ + if (*++p == '-') { + + /* get the range end */ + range_end = *++p; + + /* if end of pattern or construct then bad pattern */ + if (range_end == '\0' || range_end == ']') + return MATCH_PATTERN; + + /* special character range end */ + if (range_end == '\\') { + range_end = *++p; + + /* if end of text then we have a bad pattern */ + if (!range_end) + return MATCH_PATTERN; + } + + /* move just beyond this range */ + p++; + } + + /* if the text character is in range then match found. + make sure the range letters have the proper + relationship to one another before comparison */ + if ( range_start < range_end ) { + if (*t >= range_start && *t <= range_end) { + member_match = true; + loop = false; + } + } + else { + if (*t >= range_end && *t <= range_start) { + member_match = true; + loop = false; + } + } + } + + /* if there was a match in an exclusion set then no match */ + /* if there was no match in a member set then no match */ + if ((invert && member_match) || + !(invert || member_match)) + return MATCH_RANGE; + + /* if this is not an exclusion then skip the rest of the [...] + construct that already matched. 
*/ + if (member_match) { + while (*p != ']') { + + /* bad pattern (Missing ']') */ + if (!*p) + return MATCH_PATTERN; + + /* skip exact match */ + if (*p == '\\') { + p++; + + /* if end of text then we have a bad pattern */ + if (!*p) + return MATCH_PATTERN; + } + + /* move to next pattern char */ + p++; + } + } + + break; + } // case ']' + + /* must match this character exactly */ + default: + if (*p != *t) + return MATCH_LITERAL; + } + } + + //*s = '\0'; + /* if end of text not reached then the pattern fails */ + if ( *t ) + return MATCH_END; + else + return MATCH_VALID; + } + + + /*---------------------------------------------------------------------------- + * + * recursively call matche() with final segment of PATTERN and of TEXT. + * + ----------------------------------------------------------------------------*/ + + static int matche_after_star (register const char *p, register const char *t, register char *s) + { + register int match = 0; + register char nextp; + + /* pass over existing ? and * in pattern */ + while ( *p == '?' || *p == '%' || *p == '*' ) { + + /* take one char for each ? and + */ + if ( *p == '?') { + + /* if end of text then no match */ + if ( !*t++ ) { + return MATCH_ABORT; + } + } + + if ( *p == '%') { + *s++ = *t; + *s = '\0'; + /* if end of text then no match */ + if ( !*t++ ) { + return MATCH_ABORT; + } + } + + /* move to next char in pattern */ + p++; + } + + /* if end of pattern we have matched regardless of text left */ + if ( !*p ) { + return MATCH_VALID; + } + + /* get the next character to match which must be a literal or '[' */ + nextp = *p; + + /* Continue until we run out of text or definite result seen */ + do { + + /* a precondition for matching is that the next character + in the pattern match the next character in the text or that + the next pattern char is the beginning of a range. 
Increment + text pointer as we go here */ + if ( nextp == *t || nextp == '[' ) { + match = matche(p, t, s); + } + + /* if the end of text is reached then no match */ + if ( !*t++ ) match = MATCH_ABORT; + + } while ( match != MATCH_VALID && + match != MATCH_ABORT && + match != MATCH_PATTERN); + + /* return result */ + return match; + } + + + /*---------------------------------------------------------------------------- + * + * match() is a shell to matche() to return only bool values. + * + ----------------------------------------------------------------------------*/ + + bool match(const char *p, const char *t, char *s) + { + int error_type; + error_type = matche(p,t,s); + return (error_type != MATCH_VALID ) ? false : true; + } + + + //*************************************************************************** + //*************************************************************************** + bool + ProcessMask(const std::string & rString, + const std::string & rWildcard, + std::string & rSubstr) + { + char * substr; + int percent_count = 0; + int ret ; + size_t pos = 0; + + // let's find how many % to allocate enough space for the return substring + while ((pos = rWildcard.find('%', pos)) != rWildcard.npos) + { + percent_count++; + pos++; + } + + // allocate space for the substring + substr = new char[percent_count + 1]; + substr[percent_count] = 0; + substr[0] = '\0'; + + // optionally prepend '*/' to wildcard + std::string wildcard(rWildcard); + if(wildcard[0] != '*') { + wildcard = "*/" + wildcard; + } + + //optionally prepend '/' to string + std::string string1(rString); + if(string1[0] != '/') { + string1 = "/" + string1; + } + + // parse the string + if (0 != (ret = match(wildcard.c_str(), string1.c_str(), substr))) + { + rSubstr = substr; + } + delete[] substr; + return ret; + } // ProcessMask +} + + +#ifdef TEST + +/* +* This test main expects as first arg the pattern and as second arg +* the match string. Output is yaeh or nay on match. 
If nay on +* match then the error code is parsed and written. +*/ + +#include <stdio.h> + +int main(int argc, char *argv[]) +{ + int error; + int is_valid_error; + + char * tmp = argv[0]; + int i = 0; + for (; *tmp; tmp++) + if (*tmp=='%') i++; + + char s[i+1]; + + + if (argc != 3) { + printf("Usage: MATCH Pattern Text\n"); + } + else { + printf("Pattern: %s\n", argv[1]); + printf("Text : %s\n", argv[2]); + + if (!is_pattern(argv[1])) { + printf(" First Argument Is Not A Pattern\n"); + } + else { + match(argv[1],argv[2], s) ? printf("true") : printf("false"); + error = matche(argv[1],argv[2], s); + is_valid_pattern(argv[1],&is_valid_error); + + switch ( error ) { + case MATCH_VALID: + printf(" Match Successful"); + if (is_valid_error != PATTERN_VALID) + printf(" -- is_valid_pattern() is complaining\n"); + else + printf("\n"); + printf("%s\n", s); + + break; + case MATCH_RANGE: + printf(" Match Failed on [..]\n"); + break; + case MATCH_ABORT: + printf(" Match Failed on Early Text Termination\n"); + break; + case MATCH_END: + printf(" Match Failed on Early Pattern Termination\n"); + break; + case MATCH_PATTERN: + switch ( is_valid_error ) { + case PATTERN_VALID: + printf(" Internal Disagreement On Pattern\n"); + break; + case PATTERN_RANGE: + printf(" No End of Range in [..] Construct\n"); + break; + case PATTERN_CLOSE: + printf(" [..] Construct is Open\n"); + break; + case PATTERN_EMPTY: + printf(" [..] Construct is Empty\n"); + break; + default: + printf(" Internal Error in is_valid_pattern()\n"); + } + break; + default: + printf(" Internal Error in matche()\n"); + break; + } + } + + } + return(0); +} + +#endif diff --git a/htk_io/src/KaldiLib/StkMatch.h b/htk_io/src/KaldiLib/StkMatch.h new file mode 100644 index 0000000..42c6b97 --- /dev/null +++ b/htk_io/src/KaldiLib/StkMatch.h @@ -0,0 +1,123 @@ +#ifndef TNet_StkMatch_h +#define TNet_StkMatch_h + +#include <string> +namespace TNet +{ + /* + EPSHeader + + File: filmatch.h + Author: J. 
Kercheval + Created: Thu, 03/14/1991 22:24:34 + */ + + /* + EPSRevision History + O. Glembek Thu, 03/11/2005 01:58:00 Added Mask extraction support (char % does this) + J. Kercheval Wed, 02/20/1991 22:28:37 Released to Public Domain + J. Kercheval Sun, 03/10/1991 18:02:56 add is_valid_pattern + J. Kercheval Sun, 03/10/1991 18:25:48 add error_type in is_valid_pattern + J. Kercheval Sun, 03/10/1991 18:47:47 error return from matche() + J. Kercheval Tue, 03/12/1991 22:24:49 Released as V1.1 to Public Domain + J. Kercheval Thu, 03/14/1991 22:25:00 remove '\' for DOS file matching + J. Kercheval Thu, 03/28/1991 21:03:59 add in PATTERN_ESC & MATCH_LITERAL + */ + + /* + Wildcard Pattern Matching + */ + + + /* match defines */ +#define MATCH_PATTERN 6 /* bad pattern */ +#define MATCH_LITERAL 5 /* match failure on literal match */ +#define MATCH_RANGE 4 /* match failure on [..] construct */ +#define MATCH_ABORT 3 /* premature end of text string */ +#define MATCH_END 2 /* premature end of pattern string */ +#define MATCH_VALID 1 /* valid match */ + + /* pattern defines */ +#define PATTERN_VALID 0 /* valid pattern */ +#define PATTERN_ESC -1 /* literal escape at end of pattern */ +#define PATTERN_RANGE -2 /* malformed range in [..] construct */ +#define PATTERN_CLOSE -3 /* no end bracket in [..] construct */ +#define PATTERN_EMPTY -4 /* [..] contstruct is empty */ + + + /*---------------------------------------------------------------------------- + * + * Match the pattern PATTERN against the string TEXT; + * + * match() returns TRUE if pattern matches, FALSE otherwise. + * matche() returns MATCH_VALID if pattern matches, or an errorcode + * as follows otherwise: + * + * MATCH_PATTERN - bad pattern + * MATCH_RANGE - match failure on [..] construct + * MATCH_ABORT - premature end of text string + * MATCH_END - premature end of pattern string + * MATCH_VALID - valid match + * + * + * A match means the entire string TEXT is used up in matching. 
+ * + * In the pattern string: + * `*' matches any sequence of characters (zero or more) + * `?' matches any character + * [SET] matches any character in the specified set, + * [!SET] or [^SET] matches any character not in the specified set. + * + * A set is composed of characters or ranges; a range looks like + * character hyphen character (as in 0-9 or A-Z). [0-9a-zA-Z_] is the + * minimal set of characters allowed in the [..] pattern construct. + * Other characters are allowed (ie. 8 bit characters) if your system + * will support them. + * + * To suppress the special syntactic significance of any of `[]*?!^-\', + * in a [..] construct and match the character exactly, precede it + * with a `\'. + * + ----------------------------------------------------------------------------*/ + bool + match (const char *pattern, const char *text, char *s); + + int + matche(register const char *pattern, register const char *text, register char *s); + + + /*---------------------------------------------------------------------------- + * + * Return TRUE if PATTERN has any special wildcard characters + * + ----------------------------------------------------------------------------*/ + bool + is_pattern (const char *pattern); + + + /** -------------------------------------------------------------------------- + * + * Return TRUE if PATTERN has is a well formed regular expression according + * to the above syntax + * + * error_type is a return code based on the type of pattern error. Zero is + * returned in error_type if the pattern is a valid one. error_type return + * values are as follows: + * + * PATTERN_VALID - pattern is well formed + * PATTERN_RANGE - [..] construct has a no end range in a '-' pair (ie [a-]) + * PATTERN_CLOSE - [..] construct has no end bracket (ie [abc-g ) + * PATTERN_EMPTY - [..] 
construct is empty (ie []) + * -------------------------------------------------------------------------- + **/ + bool + is_valid_pattern (const char *pattern, int *error_type); + + + //**************************************************************************** + //**************************************************************************** + bool + ProcessMask(const std::string & rString, const std::string & rWildcard, + std::string & rSubstr); +} +#endif diff --git a/htk_io/src/KaldiLib/StkStream.h b/htk_io/src/KaldiLib/StkStream.h new file mode 100644 index 0000000..ca8de30 --- /dev/null +++ b/htk_io/src/KaldiLib/StkStream.h @@ -0,0 +1,526 @@ + + +/** @file stkstream.h + * This is an TNet C++ Library header. + */ + + +#ifndef TNet_StkStream_h +#define TNet_StkStream_h + +#include <fstream> +#include <string> +#include <vector> +#include <list> +#include <stdexcept> + +#pragma GCC system_header + + +//extern const char * gpFilterWldcrd; + +namespace TNet +{ + + /** + * @brief Expands a filter command into a runnable form + * + * This function replaces all occurances of *filter_wldcard in *command by + * *filename + */ + //char * ExpandFilterCommand(const char *command, const char *filename); + + /** + * @brief Provides a layer of compatibility for C/POSIX. + * + * This GNU extension provides extensions for working with standard C + * FILE*'s and POSIX file descriptors. It must be instantiated by the + * user with the type of character used in the file stream, e.g., + * basic_stkbuf<char>. 
+ */ + template< + typename _CharT, + typename _Traits = std::char_traits<_CharT> + > + class basic_stkbuf : public std::basic_filebuf<_CharT, _Traits> + { + public: + + typedef basic_stkbuf<_CharT, _Traits> this_type; + + // Types: + typedef _CharT char_type; + typedef _Traits traits_type; + + typedef typename traits_type::int_type int_type; + typedef typename traits_type::pos_type pos_type; + typedef typename traits_type::off_type off_type; + typedef std::size_t size_t; + + public: + + /// @{ + /// Type of streambuffer + static const unsigned int t_undef = 0; ///< undefined + static const unsigned int t_file = 1; ///< file stream + static const unsigned int t_pipe = 2; ///< pipe + static const unsigned int t_filter = 4; ///< filter + static const unsigned int t_stdio = 8; ///< standard input/output + /// @} + + public: + + /** + * deferred initialization + */ + basic_stkbuf() : std::basic_filebuf<_CharT, _Traits>(), + mFilename(""), mpFilePtr(0), mStreamType(t_undef){} + + /** + * @brief Opens a stream. + * @param fName The name of the file. + * @param m The open mode flags. + * @param pFilter The pFilter command to use + * @return @c this on success, NULL on failure + * + * If a file is already open, this function immediately fails. + * Otherwise it tries to open the file named @a s using the flags + * given in @a mode. + * + * [Table 92 gives the relation between openmode combinations and the + * equivalent fopen() flags, but the table has not been copied yet.] + */ + basic_stkbuf(const char* fName, std::ios_base::openmode m, const char* pFilter=""); + + + /** + * @return The underlying FILE*. + * + * This function can be used to access the underlying "C" file pointer. + * Note that there is no way for the library to track what you do + * with the file, so be careful. + */ + std::__c_file* + file() { return this->_M_file.file(); } + + + /** + * @return The underlying FILE*. + * + * This function can be used to access the underlying "C" file pointer. 
+ * Note that there is no way for the library to track what you do + * with the file, so be careful. + */ + std::__c_file* + fp() { return this->_M_file.file(); } + + + /** + * @brief Opens an external file. + * @param fName The name of the file. + * @param m The open mode flags. + * @param pFilter The pFilter command to use + * @return @c this on success, NULL on failure + * + * If a file is already open, this function immediately fails. + * Otherwise it tries to open the file named @a s using the flags + * given in @a mode. + * + * [Table 92 gives the relation between openmode combinations and the + * equivalent fopen() flags, but the table has not been copied yet.] + */ + this_type* + open(const char* pFName, std::ios_base::openmode m, const char* pFilter=""); + + /** + * @brief Closes the currently associated file. + * @return @c this on success, NULL on failure + * + * If no file is currently open, this function immediately fails. + * + * If a "put buffer area" exists, @c overflow(eof) is called to flush + * all the characters. The file is then closed. + * + * If any operations fail, this function also fails. + */ + this_type* + close(); + + /** + * Closes the external data stream if the file descriptor constructor + * was used. + */ + virtual + ~basic_stkbuf() + {close();}; + + /// Returns the file name + const std::string + name() const + {return mFilename;} + + + private: + /// converts the ios::xxx mode to stdio style + static void open_mode(std::ios_base::openmode __mode, int&, int&, char* __c_mode); + + /** + * @param __f An open @c FILE*. + * @param __mode Same meaning as in a standard filebuf. + * @param __size Optimal or preferred size of internal buffer, in chars. + * Defaults to system's @c BUFSIZ. + * + * This method associates a file stream buffer with an open + * C @c FILE*. The @c FILE* will not be automatically closed when the + * basic_stkbuf is closed/destroyed. 
It is equivalent to one of the constructors + * of the stdio_filebuf class defined in GNU ISO C++ ext/stdio_filebuf.h + */ + void superopen(std::__c_file* __f, std::ios_base::openmode __mode, + size_t __size = static_cast<size_t>(BUFSIZ)); + + + private: + /// Holds the full file name + std::string mFilename; + + std::ios_base::openmode mMode; + + /// Holds a pointer to the main FILE structure + FILE * mpFilePtr; + + /// tells what kind of stream we use (stdio, file, pipe) + unsigned int mStreamType; + + }; + + + + /** + * @brief This extension wraps stkbuf stream buffer into the standard ios class. + * + * This class is inherited by (i/o)stkstream classes which make explicit use of + * the custom stream buffer + */ + template< + typename _CharT, + typename _Traits = std::char_traits<_CharT> + > + class BasicStkIos + : virtual public std::basic_ios<_CharT, _Traits> + { + public: + typedef basic_stkbuf <_CharT,_Traits> StkBufType; + + BasicStkIos() + : mBuf() + { this->init(&mBuf) ;}; + + BasicStkIos(const char* fName, std::ios::openmode m, const char* pFilter) + : mBuf(fName, m, pFilter) + { this->init(&mBuf) ; } + + StkBufType* + rdbuf() + { return &mBuf; } + + protected: + StkBufType mBuf; + }; + + + /** + * @brief Controlling input for files. + * + * This class supports reading from named files, using the inherited + * functions from std::istream. To control the associated + * sequence, an instance of std::stkbuf is used. + */ + template< + typename _CharT, + typename _Traits = std::char_traits<_CharT> + > + class BasicIStkStream + : public BasicStkIos<_CharT, _Traits>, + public std::basic_istream<_CharT, _Traits> + { + public: + typedef BasicStkIos<_CharT, _Traits> BasicStkIosType; + typedef std::basic_istream<_CharT,_Traits> IStreamType; + + + // Constructors: + /** + * @brief Default constructor. + * + * Initializes @c mBuf using its default constructor, and passes + * @c &sb to the base class initializer. 
Does not open any files + * (you haven't given it a filename to open). + */ + BasicIStkStream() + : BasicStkIosType(), + IStreamType(BasicStkIosType::rdbuf()) + {}; + + /** + * @brief Create an input file stream. + * @param fName String specifying the filename. + * @param m Open file in specified mode (see std::ios_base). + * @param pFilter String specifying pFilter command to use on fName + * + * @c ios_base::in is automatically included in + * @a m. + * + * Tip: When using std::string to hold the filename, you must use + * .c_str() before passing it to this constructor. + */ + BasicIStkStream(const char* pFName, std::ios::openmode m=std::ios::out, const char* pFilter="") + : BasicStkIosType(), + IStreamType(BasicStkIosType::rdbuf()) + {this->open(pFName, std::ios::in, pFilter);} + + ~BasicIStkStream() + { + this->close(); + } + + /** + * @brief Opens an external file. + * @param s The name of the file. + * @param mode The open mode flags. + * @param pFilter The pFilter command to use + * + * Calls @c std::basic_filebuf::open(s,mode|in). If that function + * fails, @c failbit is set in the stream's error state. + * + * Tip: When using std::string to hold the filename, you must use + * .c_str() before passing it to this constructor. + */ + void open(const char* pFName, std::ios::openmode m=std::ios::in, const char* pFilter = "") + { + if (!BasicStkIosType::mBuf.open(pFName, m | std::ios_base::in, pFilter)) { + this->setstate(std::ios_base::failbit); + } + else { + // Closing an fstream should clear error state + BasicStkIosType::clear(); + } + } + + /** + * @brief Returns true if the external file is open. 
+ */ + bool is_open() const {return BasicStkIosType::mBuf.is_open();} + + + /** + * @brief Closes the stream + */ + void close() {BasicStkIosType::mBuf.close();} + + /** + * @brief Returns the filename + */ + const std::string name() const {return BasicStkIosType::mBuf.name();} + + /// Returns a pointer to the main FILE structure + std::__c_file* + file() {return BasicStkIosType::mBuf.file();} + + /// Returns a pointer to the main FILE structure + std::__c_file* + fp() {return BasicStkIosType::mBuf.fp();} + + // /** + // * @brief Reads a single line + // * + // * This is a specialized function as std::getline does not provide a way to + // * read multiple end-of-line symbols (we need both '\n' and EOF to delimit + // * the line) + // */ + // void + // GetLine(string& rLine); + + }; // class BasicIStkStream + + + /** + * @brief Controlling output for files. + * + * This class supports reading from named files, using the inherited + * functions from std::ostream. To control the associated + * sequence, an instance of TNet::stkbuf is used. + */ + template< + typename _CharT, + typename _Traits = std::char_traits<_CharT> + > + class BasicOStkStream + : public BasicStkIos<_CharT, _Traits>, + public std::basic_ostream<_CharT, _Traits> + { + public: + typedef BasicStkIos<_CharT, _Traits> BasicStkIosType; + typedef std::basic_ostream<_CharT,_Traits> OStreamType; + + // Constructors: + /** + * @brief Default constructor. + * + * Initializes @c sb using its default constructor, and passes + * @c &sb to the base class initializer. Does not open any files + * (you haven't given it a filename to open). + */ + BasicOStkStream() + : BasicStkIosType(), + OStreamType(BasicStkIosType::rdbuf()) + {}; + + /** + * @brief Create an output file stream. + * @param fName String specifying the filename. + * @param m Open file in specified mode (see std::ios_base). 
+ * @param pFilter String specifying pFilter command to use on fName + * + * @c ios_base::out|ios_base::trunc is automatically included in + * @a mode. + * + * Tip: When using std::string to hold the filename, you must use + * .c_str() before passing it to this constructor. + */ + BasicOStkStream(const char* pFName, std::ios::openmode m=std::ios::out, const char* pFilter="") + : BasicStkIosType(), + OStreamType(BasicStkIosType::rdbuf()) + {this->open(pFName, std::ios::out, pFilter);} + + /** + * @brief Opens an external file. + * @param fName The name of the file. + * @param m The open mode flags. + * @param pFilter String specifying pFilter command to use on fName + * + * Calls @c std::basic_filebuf::open(s,mode|out). If that function + * fails, @c failbit is set in the stream's error state. + * + * Tip: When using std::string to hold the filename, you must use + * .c_str() before passing it to this constructor. + */ + void open(const char* pFName, std::ios::openmode m=std::ios::out, const char* pFilter="") + { + if (!BasicStkIosType::mBuf.open(pFName, m | std::ios_base::out, pFilter)) + this->setstate(std::ios_base::failbit); + else + // Closing an fstream should clear error state + this->clear(); + } + + /** + * @brief Returns true if the external file is open. 
+ */ + bool is_open() const + { return BasicStkIosType::mBuf.is_open();} + + /** + * @brief Closes the stream + */ + void close() + { BasicStkIosType::mBuf.close();} + + /** + * @brief Returns the filename + */ + const std::string name() const + { return BasicStkIosType::mBuf.name();} + + /// Returns a pointer to the main FILE structure + std::__c_file* + file() + { return BasicStkIosType::mBuf.file();} + + /// Returns a pointer to the main FILE structure + std::__c_file* + fp() + { return BasicStkIosType::mBuf.fp();} + + }; // class BasicOStkStream + + + /** + * We define some implicit stkbuf class + */ + ///@{ +#ifndef _GLIBCPP_USE_WCHAR_T + typedef BasicOStkStream<char> OStkStream; + typedef BasicOStkStream<wchar_t> WOStkStream; + typedef BasicIStkStream<char> IStkStream; + typedef BasicIStkStream<wchar_t> WIStkStream; +#else + typedef BasicOStkStream<char> WOStkStream; + typedef BasicOStkStream<wchar_t> WOStkStream; + typedef BasicIStkStream<char> WIStkStream; + typedef BasicIStkStream<wchar_t> WIStkStream; +#endif + /// @} + + /* + template<class T,class char_type> inline + BasicOStkStream<char_type>& operator << (BasicOStkStream<char_type> &ostream, const std::vector<T> &vec){ + ostream << vec.size() << std::endl; + for(size_t i=0;i<vec.size();i++) ostream << vec[i]; + return ostream; + } + + template<class T,class char_type> inline BasicIStkStream<char_type> &operator >> (BasicIStkStream<char_type> &istream, std::vector<T> &vec){ + size_t sz; + istream >> sz; if(!istream.good()){ throw std::runtime_error(std::string("Error reading to vector of [something]: stream bad\n")); } + int ch = istream.get(); if(ch!='\n' || !istream.good()){ throw std::runtime_error(std::string("Expecting newline after vector size, got " + (std::string)(char)ch));} // TODO: This code may not be right for wchar. 
+ vec.resize(sz); + for(size_t i=0;i<vec.size();i++) istream >> vec[i]; + return istream; + }*/ + + template<class T> inline + std::ostream & operator << (std::ostream &ostream, const std::vector<T> &vec){ + ostream << vec.size() << std::endl; + for(size_t i=0;i<vec.size();i++) ostream << vec[i] << "\n"; // '\n' is necessary in case item is atomic e.g. a number. + return ostream; + } + + template<class T> inline std::istream& operator >> (std::istream &istream, std::vector<T> &vec){ + size_t sz; + istream >> sz; if(!istream.good()){ throw std::runtime_error(std::string("Error reading to vector of [something]: stream bad\n")); } + // int ch = istream.get(); if(ch!='\n' || !istream.good()){ throw std::runtime_error(std::string("Expecting newline after vector size\n")); // TODO: This code may not be right for wchar. + vec.resize(sz); + for(size_t i=0;i<vec.size();i++) istream >> vec[i]; + return istream; + } + + template<class T> inline + std::ostream & operator << (std::ostream &ostream, const std::list<T> &lst){ + ostream << lst.size() << std::endl; + typename std::list<T>::iterator it; + for(it = lst.begin(); it != lst.end(); it++) + ostream << *it << "\n"; // '\n' is necessary in case item is atomic e.g. a number. 
+ return ostream; + } + + template<class T> inline std::istream& operator >> (std::istream &istream, std::list<T> &lst){ + size_t sz; + istream >> sz; if(!istream.good()){ throw std::runtime_error(std::string("Error reading to list of [something]: stream bad\n")); } + lst.resize(sz); + typename std::list<T>::iterator it; + for(it = lst.begin(); it != lst.end(); it++) + istream >> *it; + return istream; + } + +}; // namespace TNet + + +using TNet::operator >>; +using TNet::operator <<; + + +# include "StkStream.tcc" + +// TNet_StkStream_h +#endif diff --git a/htk_io/src/KaldiLib/StkStream.tcc b/htk_io/src/KaldiLib/StkStream.tcc new file mode 100644 index 0000000..e3de1ae --- /dev/null +++ b/htk_io/src/KaldiLib/StkStream.tcc @@ -0,0 +1,228 @@ +#ifndef TNet_StkStream_tcc +#define TNet_StkStream_tcc + +#include <cstring> +#include <iostream> + +#include "Common.h" + +#pragma GCC system_header + +namespace TNet +{ + + //****************************************************************************** + template< + typename _CharT, + typename _Traits + > + basic_stkbuf<_CharT, _Traits> * + basic_stkbuf<_CharT, _Traits>:: + close(void) + { + // we only want to close an opened file + if (this->is_open()) + { + // we want to call the parent close() procedure + std::basic_filebuf<_CharT, _Traits>::close(); + + // and for different stream type we perform different closing + if (mStreamType == basic_stkbuf::t_file) + { + fclose(mpFilePtr); + } + else if (mStreamType == basic_stkbuf::t_pipe) + { + pclose(mpFilePtr); + } + else if (mStreamType == basic_stkbuf::t_stdio) + { + + } + + mpFilePtr = NULL; + mFilename = ""; + mMode = std::ios_base::openmode(0); + mStreamType = basic_stkbuf::t_undef; + return this; + } + else + return 0; + } + + + template< + typename _CharT, + typename _Traits + > + void + basic_stkbuf<_CharT, _Traits>:: + open_mode(std::ios_base::openmode __mode, int&, int&, char* __c_mode) + { + bool __testb = __mode & std::ios_base::binary; + bool __testi = __mode & 
std::ios_base::in; + bool __testo = __mode & std::ios_base::out; + bool __testt = __mode & std::ios_base::trunc; + bool __testa = __mode & std::ios_base::app; + + if (!__testi && __testo && !__testt && !__testa) + strcpy(__c_mode, "w"); + if (!__testi && __testo && !__testt && __testa) + strcpy(__c_mode, "a"); + if (!__testi && __testo && __testt && !__testa) + strcpy(__c_mode, "w"); + if (__testi && !__testo && !__testt && !__testa) + strcpy(__c_mode, "r"); + if (__testi && __testo && !__testt && !__testa) + strcpy(__c_mode, "r+"); + if (__testi && __testo && __testt && !__testa) + strcpy(__c_mode, "w+"); + if (__testb) + strcat(__c_mode, "b"); + } + + + //****************************************************************************** + template< + typename _CharT, + typename _Traits + > + basic_stkbuf<_CharT, _Traits> * + basic_stkbuf<_CharT, _Traits>:: + open(const char* pFName, std::ios::openmode m, const char* pFilter) + { + basic_stkbuf<_CharT, _Traits>* p_ret = NULL; + + if (NULL == pFName) + return NULL; + + // we need to assure, that the stream is not open + if (!this->is_open()) + { + char mstr[4] = {'\0', '\0', '\0', '\0'}; + int __p_mode = 0; + int __rw_mode = 0; + + // now we decide, what kind of file we open + if (!strcmp(pFName,"-")) + { + if ((m & std::ios::in) && !(m & std::ios::out)) + { + mpFilePtr = stdin; + mMode = std::ios::in; + mFilename = pFName; + mStreamType = t_stdio; + p_ret = this; + } + else if ((m & std::ios::out) && !(m & std::ios::in)) + { + mpFilePtr = stdout; + mMode = std::ios::out; + mFilename = pFName; + mStreamType = t_stdio; + p_ret = this; + } + else + p_ret = NULL; + } + else if ( pFName[0] == '|' ) + { + const char* command = pFName + 1; + + if ((m & std::ios::in) && !(m & std::ios::out)) m = std::ios::in; + else if ((m & std::ios::out) && !(m & std::ios::in)) m = std::ios::out; + else return NULL; + + // we need to make some conversion + // iostream -> stdio open mode string + this->open_mode(m, __p_mode, __rw_mode, 
mstr); + + if ((mpFilePtr = popen(command, mstr))) + { + mFilename = command; + mMode = m; + mStreamType = t_pipe; + p_ret = this; + } + else + p_ret = 0; + } + else + { + // maybe we have a filter specified + if ( pFilter + && ('\0' != pFilter[0])) + { + char* command = ExpandHtkFilterCmd(pFilter, pFName, "$"); + + if ((m & std::ios::in) && !(m & std::ios::out)) m = std::ios::in; + else if ((m & std::ios::out) && !(m & std::ios::in)) m = std::ios::out; + else return NULL; + + // we need to make some conversion + // iostream -> stdio open mode string + this->open_mode(m, __p_mode, __rw_mode, mstr); + + if ((mpFilePtr = popen(command, mstr))) + { + mFilename = pFName; + mMode = m; + mStreamType = t_pipe; + p_ret = this; + } + else + p_ret = 0; + } + else // if (!filter.empty()) + { + // we need to make some conversion + // iostream -> stdio open mode string + this->open_mode(m, __p_mode, __rw_mode, mstr); + + if ((mpFilePtr = fopen(pFName, mstr))) + { + mFilename = pFName; + mMode = m; + mStreamType = t_file; + p_ret = this; + } + else { + p_ret = NULL; + } + } + } + + // here we perform what the stdio_filebuf would do + if (p_ret) { + superopen(mpFilePtr, m); + } + } //if (!isopen) + + return p_ret; + } + + //****************************************************************************** + template< + typename _CharT, + typename _Traits + > + void + basic_stkbuf<_CharT, _Traits>:: + superopen(std::__c_file* __f, std::ios_base::openmode __mode, + size_t __size) + { + this->_M_file.sys_open(__f, __mode); + if (this->is_open()) + { + this->_M_mode = __mode; + this->_M_buf_size = __size; + this->_M_allocate_internal_buffer(); + this->_M_reading = false; + this->_M_writing = false; + this->_M_set_buffer(-1); + } + } +} + +// TNet_StkStream_tcc +#endif diff --git a/htk_io/src/KaldiLib/Timer.cc b/htk_io/src/KaldiLib/Timer.cc new file mode 100644 index 0000000..692969b --- /dev/null +++ b/htk_io/src/KaldiLib/Timer.cc @@ -0,0 +1,5 @@ +#include "Timer.h" + +/* +TNet::Timer 
gTimer; +*/ diff --git a/htk_io/src/KaldiLib/Timer.h b/htk_io/src/KaldiLib/Timer.h new file mode 100644 index 0000000..b220b93 --- /dev/null +++ b/htk_io/src/KaldiLib/Timer.h @@ -0,0 +1,103 @@ +#ifndef Timer_h +#define Timer_h + +#include "Error.h" +#include <sstream> + + + +#if defined(_WIN32) || defined(MINGW) + +# include <windows.h> + +namespace TNet +{ + class Timer { + public: + void + Start(void) + { + static int first = 1; + + if(first) { + QueryPerformanceFrequency(&mFreq); + first = 0; + } + QueryPerformanceCounter(&mTStart); + } + + void + End(void) + { QueryPerformanceCounter(&mTEnd); } + + double + Val() + { + return ((double)mTEnd.QuadPart - (double)mTStart.QuadPart) / + ((double)mFreq.QuadPart); + } + + private: + LARGE_INTEGER mTStart; + LARGE_INTEGER mTEnd; + LARGE_INTEGER mFreq; + }; +} + +#else + +# include <sys/time.h> +# include <unistd.h> + +namespace TNet +{ + class Timer + { + public: + void + Start() + { gettimeofday(&this->mTStart, &mTz); } + + void + End() + { gettimeofday(&mTEnd,&mTz); } + + double + Val() + { + double t1, t2; + + t1 = (double)mTStart.tv_sec + (double)mTStart.tv_usec/(1000*1000); + t2 = (double)mTEnd.tv_sec + (double)mTEnd.tv_usec/(1000*1000); + return t2-t1; + } + + private: + struct timeval mTStart; + struct timeval mTEnd; + struct timezone mTz; + }; +} + +#endif + + + + + + + +/////////////////////////////////////////////////////////////// +// Macros for adding the time intervals to time accumulator +#if PROFILING==1 +# define TIMER_START(timer) timer.Start() +# define TIMER_END(timer,sum) timer.End(); sum += timer.Val() +#else +# define TIMER_START(timer) +# define TIMER_END(timer,sum) +#endif + +#endif + + + diff --git a/htk_io/src/KaldiLib/Tokenizer.cc b/htk_io/src/KaldiLib/Tokenizer.cc new file mode 100644 index 0000000..0c49050 --- /dev/null +++ b/htk_io/src/KaldiLib/Tokenizer.cc @@ -0,0 +1,53 @@ +#include "Tokenizer.h" +#include "string.h" + +namespace TNet +{ + 
//**************************************************************************** + //**************************************************************************** + void + Tokenizer:: + AddString(const char* pString) + { + // copy into string struct, which is more convenient + std::string aux_string(pString); + std::string aux_record; + std::string::size_type cur_pos = 0; + std::string::size_type old_pos = 0; + std::string::size_type search_start = 0; + + // make sure we have enough space + aux_record.reserve(aux_string.length()); + + // find all of separators and make a list of tokens + while(old_pos < std::string::npos) { + // find the next separator + cur_pos = aux_string.find_first_of(mSeparator, search_start); + + // if backslash is in front of separator, ignore this separator + if (cur_pos != 0 && cur_pos != std::string::npos && + pString[cur_pos - 1] == '\\') { + search_start = cur_pos + 1; + continue; + } + + // we don't want to have empty records + if (!(cur_pos == old_pos && mSkipEmpty)) { + // extract token + aux_record.insert(0, pString+old_pos, cur_pos==std::string::npos ? strlen(pString+old_pos) : cur_pos - old_pos); + // insert to list + this->push_back(aux_record); + + // we don't need the contents of the token + aux_record.erase(); + } + + // update old position so that it points behind the separator + old_pos = cur_pos < std::string::npos ? cur_pos + 1 : cur_pos; + search_start = old_pos; + } + } + + +} // namespace TNet + diff --git a/htk_io/src/KaldiLib/Tokenizer.h b/htk_io/src/KaldiLib/Tokenizer.h new file mode 100644 index 0000000..1be717b --- /dev/null +++ b/htk_io/src/KaldiLib/Tokenizer.h @@ -0,0 +1,45 @@ +#include <list> +#include <string> + +namespace TNet { + /** + * @brief General string tokenizer + */ + class Tokenizer + : public std::list<std::string> + { + public: + // Constructors and Destructors ............................................ 
+ Tokenizer(const char* pSeparator, bool skipEmpty = false) + : std::list<std::string>(), mSeparator(pSeparator), mSkipEmpty(skipEmpty) + {} + + Tokenizer(const char* pString, const char* pSeparator, bool skipEmpty = false) + : std::list<std::string>(), mSeparator(pSeparator), mSkipEmpty(skipEmpty) + { AddString(pString); } + + ~Tokenizer() + {} + + /** + * @brief Parses a string and appends the tokens to the list + * @param pString string to parse + */ + void + AddString(const char* pString); + + /** + * @brief Constant accessor to the separators string + * @return Const refference + */ + const std::string& + Separator() const + {return mSeparator;} + + private: + std::string mSeparator; ///< holds the list of separators + bool mSkipEmpty; ///< if true, multiple separators will be regarded as one + }; // class Tokenizer +} // namespace TNet + + diff --git a/htk_io/src/KaldiLib/Types.h b/htk_io/src/KaldiLib/Types.h new file mode 100644 index 0000000..6a5bfac --- /dev/null +++ b/htk_io/src/KaldiLib/Types.h @@ -0,0 +1,78 @@ +#ifndef TNet_Types_h +#define TNet_Types_h + +#ifdef HAVE_ATLAS +extern "C"{ + #include <cblas.h> + #include <clapack.h> +} +#endif + + +namespace TNet +{ + // TYPEDEFS .................................................................. +#if DOUBLEPRECISION + typedef double BaseFloat; +#else + typedef float BaseFloat; +#endif + +#ifndef UINT_16 + typedef unsigned short UINT_16 ; + typedef unsigned UINT_32 ; + typedef short INT_16 ; + typedef int INT_32 ; + typedef float FLOAT_32 ; + typedef double DOUBLE_64 ; +#endif + + + + // ........................................................................... + // The following declaration assumes that SSE instructions are enabled + // and that we are using GNU C/C++ compiler, which defines the __attribute__ + // notation. + // + // ENABLE_SSE is defined in <config.h>. 
Its value depends on options given + // in the configure phase of builidng the library +#if defined(__GNUC__ ) + // vector of four single floats + typedef float v4sf __attribute__((vector_size(16))); + // vector of two single doubles + typedef double v2sd __attribute__((vector_size(16))); + + typedef BaseFloat BaseFloat16Aligned __attribute__((aligned(16))) ; + + typedef union + { + v4sf v; + float f[4]; + } f4vector; + + typedef union + { + v2sd v; + double f[2]; + } d2vector; +#endif // ENABLE_SSE && defined(__GNUC__ ) + + + + typedef enum + { +#ifdef HAVE_ATLAS + TRANS = CblasTrans, + NO_TRANS = CblasNoTrans +#else + TRANS = 'T', + NO_TRANS = 'N' +#endif + } MatrixTrasposeType; + + + +} // namespace TNet + +#endif // #ifndef TNet_Types_h + diff --git a/htk_io/src/KaldiLib/UserInterface.cc b/htk_io/src/KaldiLib/UserInterface.cc new file mode 100644 index 0000000..b59a6c5 --- /dev/null +++ b/htk_io/src/KaldiLib/UserInterface.cc @@ -0,0 +1,669 @@ +#include <stdexcept> +#include <sstream> +#include <stdarg.h> + +#include "UserInterface.h" +#include "StkStream.h" +#include "Features.h" + +namespace TNet +{ + //*************************************************************************** + //*************************************************************************** + int + npercents(const char *str) + { + int ret = 0; + while (*str) if (*str++ == '%') ret++; + return ret; + } + + + //*************************************************************************** + //*************************************************************************** + void + UserInterface:: + ReadConfig(const char *file_name) + { + std::string line_buf; + std::string::iterator chptr; + std::string key; + std::string value; + std::ostringstream ss; + int line_no = 0; + IStkStream i_stream; + + + i_stream.open(file_name, std::ios::binary); + if (!i_stream.good()) { + throw std::runtime_error(std::string("Cannot open input config file ") + + file_name); + } + i_stream >> std::ws; + + while 
(!i_stream.eof()) { + size_t i_pos; + + // read line + std::getline(i_stream, line_buf); + i_stream >> std::ws; + + if (i_stream.fail()) { + throw std::runtime_error(std::string("Error reading (") + + file_name + ":" + (ss << line_no,ss).str() + ")"); + } + + // increase line counter + line_no++; + + // cut comments + if (std::string::npos != (i_pos = line_buf.find('#'))) { + line_buf.erase(i_pos); + } + + // cut leading and trailing spaces + Trim(line_buf); + + // if empty line, then skip it + if (0 == line_buf.length()) { + continue; + } + + // line = line_buf.c_str(); + // chptr = parptr; + + chptr = line_buf.begin(); + + for (;;) { + // Replace speces by '_', which is removed in InsertConfigParam + while (isalnum(*chptr) || *chptr == '_' || *chptr == '-') { + chptr++; + } + + while (std::isspace(*chptr)) { + *chptr = '_'; + chptr++; + } + + if (*chptr != ':') { + break; + } + + chptr++; + + while (std::isspace(*chptr)) { + *chptr = '_'; + chptr++; + } + } + + if (*chptr != '=') { + throw std::runtime_error(std::string("Character '=' expected (") + + file_name + ":" + (ss.str(""),ss<<line_no,ss).str() + ")"); + } + + key.assign(line_buf.begin(), chptr); + + chptr++; + + value.assign(chptr, line_buf.end()); + + ParseHTKString(value, value); + InsertConfigParam(key.c_str(), value.c_str(), 'C'); + } + + i_stream.close(); + } + + + //*************************************************************************** + //*************************************************************************** + void + UserInterface:: + InsertConfigParam(const char *pParamName, const char *value, int optionChar) + { + std::string key(pParamName); + std::string::iterator i_key = key.begin(); + + while (i_key != key.end()) { + if (*i_key == '-' || *i_key == '_') { + i_key = key.erase(i_key); + } + else { + *i_key = toupper(*i_key); + i_key ++; + } + } + + mMap[key].mValue = value; + mMap[key].mRead = false; + mMap[key].mOption = optionChar; + } + + 
//*************************************************************************** + //*************************************************************************** + int + UserInterface:: + ParseOptions( + int argc, + char* argv[], + const char* pOptionMapping, + const char* pToolName) + { + int i; + int opt = '?'; + int optind; + bool option_must_follow = false; + char param[1024]; + char* value; + const char* optfmt; + const char* optarg; + char* chptr; + char* bptr; + char tstr[4] = " -?"; + unsigned long long option_mask = 0; + std::ostringstream ss; + + #define MARK_OPTION(ch) {if (isalpha(ch)) option_mask |= 1ULL << ((ch) - 'A');} + #define OPTION_MARK(ch) (isalpha(ch) && ((1ULL << ((ch) - 'A')) & option_mask)) + #define IS_OPTION(str) ((str)[0] == '-' && (isalpha((str)[1]) || (str)[1] == '-')) + + //search for the -A param + for (optind = 1; optind < argc; optind++) { + // we found "--", no -A + if (!strcmp(argv[optind], "--")) { + break; + } + + //repeat till we find -A + if (argv[optind][0] != '-' || argv[optind][1] != 'A') { + continue; + } + + // just "-A" form + if (argv[optind][2] != '\0') { + throw std::runtime_error(std::string("Unexpected argument '") + + (argv[optind] + 2) + "' after option '-A'"); + } + + for (i=0; i < argc; i++) { + // display all params + if(strchr(argv[i], ' ') || strchr(argv[i], '*')) + std::cout << '\'' << argv[i] << '\'' << " "; + else std::cout << argv[i] << " "; + } + + std::cout << std::endl; + + break; + } + + for (optind = 1; optind < argc; optind++) { + // find the '-C?' 
parameter (possible two configs) + if (!strcmp(argv[optind], "--")) break; + if (argv[optind][0] != '-' || argv[optind][1] != 'C') continue; + if (argv[optind][2] != '\0') { + ReadConfig(argv[optind] + 2); + } else if (optind+1 < argc && !IS_OPTION(argv[optind+1])) { + ReadConfig(argv[++optind]); + } else { + throw std::runtime_error("Config file name expected after option '-C'"); + } + } + + for (optind = 1; optind < argc; optind++) { + if (!strcmp(argv[optind], "--")) break; + if (argv[optind][0] != '-' || argv[optind][1] != '-') continue; + + bptr = new char[strlen(pToolName) + strlen(argv[optind]+2) + 2]; + strcat(strcat(strcpy(bptr, pToolName), ":"), argv[optind]+2); + value = strchr(bptr, '='); + if (!value) { + throw std::runtime_error(std::string("Character '=' expected after option '") + + argv[optind] + "'"); + } + + *value++ = '\0'; + + InsertConfigParam(bptr, value /*? value : "TRUE"*/, '-'); + delete [] bptr; + } + + for (optind = 1; optind < argc && IS_OPTION(argv[optind]); optind++) { + option_must_follow = false; + tstr[2] = opt = argv[optind][1]; + optarg = argv[optind][2] != '\0' ? 
argv[optind] + 2 : NULL; + + if (opt == '-' && !optarg) { // '--' terminates the option list + return optind+1; + } + if (opt == 'C' || opt == '-') { // C, A and long options have been already processed + if (!optarg) optind++; + continue; + } + if (opt == 'A') continue; + + chptr = strstr((char*)pOptionMapping, tstr); + if (chptr == NULL) { + throw std::runtime_error(std::string("Invalid command line option '-") + + static_cast<char>(opt) + "'"); + } + + chptr += 3; + while (std::isspace(*chptr)) { + chptr++; + } + + if (!chptr || chptr[0] == '-') {// Option without format string will be ignored + optfmt = " "; + } else { + optfmt = chptr; + while (*chptr && !std::isspace(*chptr)) { + chptr++; + } + if (!*chptr) { + throw std::runtime_error("Fatal: Unexpected end of optionMap string"); + } + } + for (i = 0; !std::isspace(*optfmt); optfmt++) { + while (std::isspace(*chptr)) chptr++; + value = chptr; + while (*chptr && !std::isspace(*chptr)) chptr++; + assert(static_cast<unsigned int>(chptr-value+1) < sizeof(param)); + strncat(strcat(strcpy(param, pToolName), ":"), value, chptr-value); + param[chptr-value+strlen(pToolName)+1] = '\0'; + switch (*optfmt) { + case 'n': + value = strchr(param, '='); + if (value) *value = '\0'; + InsertConfigParam(param, + value ? value + 1: "TRUE", opt); + break; + + case 'l': + case 'o': + case 'r': + i++; + if (!optarg && (optind+1==argc || IS_OPTION(argv[optind+1]))) { + if (*optfmt == 'r' || *optfmt == 'l') { + throw std::runtime_error(std::string("Argument ") + + (ss<<i,ss).str() + " of option '-" + + static_cast<char>(opt) + "' expected"); + } + optfmt = " "; // Stop reading option arguments + break; + } + if (!optarg) optarg = argv[++optind]; + if (*optfmt == 'o') { + option_must_follow = (bool) 1; + } + bptr = NULL; + + // For repeated use of option with 'l' (list) format, append + // ',' and argument string to existing config parameter value. 
+ if (*optfmt == 'l' && OPTION_MARK(opt)) { + bptr = strdup(GetStr(param, "")); + if (bptr == NULL) throw std::runtime_error("Insufficient memory"); + bptr = (char*) realloc(bptr, strlen(bptr) + strlen(optarg) + 2); + if (bptr == NULL) throw std::runtime_error("Insufficient memory"); + strcat(strcat(bptr, ","), optarg); + optarg = bptr; + } + MARK_OPTION(opt); + InsertConfigParam(param, optarg, opt); + free(bptr); + optarg = NULL; + break; + + default : + throw std::runtime_error(std::string("Fatal: Invalid character '") + + *optfmt + "' in optionMap after " + tstr); + } + } + if (optarg) { + throw std::runtime_error(std::string("Unexpected argument '") + + optarg + "' after option '-" + + static_cast<char>(opt) + "'"); + } + } + + for (i = optind; i < argc && !IS_OPTION(argv[i]); i++) + {} + + if (i < argc) { + throw std::runtime_error(std::string("No option expected after first non-option argument '") + + argv[optind] + "'"); + } + + if (option_must_follow) { + throw std::runtime_error(std::string("Option '-") + + static_cast<char>(opt) + + "' with optional argument must not be the last option"); + } + + return optind; + } + + + //*************************************************************************** + //*************************************************************************** + int + UserInterface:: + GetFeatureParams( + int * derivOrder, + int ** derivWinLens, + int * startFrmExt, + int * endFrmExt, + char ** CMNPath, + char ** CMNFile, + const char ** CMNMask, + char ** CVNPath, + char ** CVNFile, + const char ** CVNMask, + const char ** CVGFile, + const char * pToolName, + int pseudoModeule) + { + const char * str; + int targetKind; + char * chrptr; + char paramName[32]; + const char * CMNDir; + const char * CVNDir; + + strcpy(paramName, pToolName); + strcat(paramName, pseudoModeule == 1 ? "SPARM1:" : + pseudoModeule == 2 ? 
"SPARM2:" : ""); + + chrptr = paramName + strlen(paramName); + + strcpy(chrptr, "STARTFRMEXT"); + *startFrmExt = GetInt(paramName, 0); + strcpy(chrptr, "ENDFRMEXT"); + *endFrmExt = GetInt(paramName, 0); + + *CMNPath = *CVNPath = NULL; + strcpy(chrptr, "CMEANDIR"); + CMNDir = GetStr(paramName, NULL); + strcpy(chrptr, "CMEANMASK"); + *CMNMask = GetStr(paramName, NULL); + + if (*CMNMask != NULL) { + *CMNPath = (char*) malloc((CMNDir ? strlen(CMNDir) : 0) + npercents(*CMNMask) + 2); + if (*CMNPath == NULL) throw std::runtime_error("Insufficient memory"); + if (CMNDir != NULL) strcat(strcpy(*CMNPath, CMNDir), "/"); + *CMNFile = *CMNPath + strlen(*CMNPath); + } + strcpy(chrptr, "VARSCALEDIR"); + CVNDir = GetStr(paramName, NULL); + strcpy(chrptr, "VARSCALEMASK"); + *CVNMask = GetStr(paramName, NULL); + + + if (*CVNMask != NULL) { + *CVNPath = (char*) malloc((CVNDir ? strlen(CVNDir) : 0) + npercents(*CVNMask) + 2); + if (*CVNPath == NULL) throw std::runtime_error("Insufficient memory"); + if (CVNDir != NULL) strcat(strcpy(*CVNPath, CVNDir), "/"); + *CVNFile = *CVNPath + strlen(*CVNPath); + } + strcpy(chrptr, "VARSCALEFN"); + *CVGFile = GetStr(paramName, NULL); + strcpy(chrptr, "TARGETKIND"); + str = GetStr(paramName, "ANON"); + + targetKind = FeatureRepository::ReadParmKind(str, false); + + if (targetKind == -1) { + throw std::runtime_error(std::string("Invalid TARGETKIND = '") + + str + "'"); + } + + strcpy(chrptr, "DERIVWINDOWS"); + if ((str = GetStr(paramName, NULL)) != NULL) { + long lval; + *derivOrder = 0; + *derivWinLens = NULL; + + if (NULL != str) + { + while ((str = strtok((char *) str, " \t_")) != NULL) + { + lval = strtol(str, &chrptr, 0); + if (!*str || *chrptr) { + throw std::runtime_error("Integers separated by '_' expected for parameter DERIVWINDOWS"); + } + *derivWinLens = (int *)realloc(*derivWinLens, ++*derivOrder*sizeof(int)); + if (*derivWinLens == NULL) throw std::runtime_error("Insufficient memory"); + (*derivWinLens)[*derivOrder-1] = lval; + str = 
NULL; + } + } + + return targetKind; + } + *derivOrder = targetKind & PARAMKIND_T ? 3 : + targetKind & PARAMKIND_A ? 2 : + targetKind & PARAMKIND_D ? 1 : 0; + + if (*derivOrder || targetKind != PARAMKIND_ANON) { + *derivWinLens = (int *) malloc(3 * sizeof(int)); + if (*derivWinLens == NULL) throw std::runtime_error("Insufficient memory"); + + strcpy(chrptr, "DELTAWINDOW"); + (*derivWinLens)[0] = GetInt(paramName, 2); + strcpy(chrptr, "ACCWINDOW"); + (*derivWinLens)[1] = GetInt(paramName, 2); + strcpy(chrptr, "THIRDWINDOW"); + (*derivWinLens)[2] = GetInt(paramName, 2); + return targetKind; + } + *derivWinLens = NULL; + *derivOrder = -1; + return targetKind; + } + + + //*************************************************************************** + //*************************************************************************** + UserInterface::ValueRecord* + UserInterface:: + GetParam(const char* pParamName) + { + MapType::iterator it; + + // this is done only for convenience. in the loop we will increase the + // pointer again + pParamName--; + + // we iteratively try to find the param name in the map. 
if an attempt + // fails, we strip off all characters until the first ':' and we search + // again + do { + pParamName++; + it = mMap.find(pParamName); + } while ((it == mMap.end()) && (NULL != (pParamName = strchr(pParamName, ':')))); + + if (it == mMap.end()) { + return NULL; + } + else { + it->second.mRead = true; + return &(it->second); + } + } + + + //*************************************************************************** + //*************************************************************************** + const char * + UserInterface:: + GetStr( + const char * pParamName, + const char * default_value) + { + ValueRecord* p_val = GetParam(pParamName); + + if (NULL == p_val) { + return default_value; + } + else { + return p_val->mValue.c_str(); + } + } + + + //*************************************************************************** + //*************************************************************************** + long + UserInterface:: + GetInt( + const char *pParamName, + long default_value) + { + char *chrptr; + ValueRecord* p_val = GetParam(pParamName); + + if (NULL == p_val) { + return default_value; + } + + const char *val = p_val->mValue.c_str(); + default_value = strtol(val, &chrptr, 0); + if (!*val || *chrptr) { + throw std::runtime_error(std::string("Integer number expected for ") + + pParamName + " but found '" + val + "'"); + } + return default_value; + } + + //*************************************************************************** + //*************************************************************************** + float + UserInterface:: + GetFlt( + const char * pParamName, + float default_value) + { + char *chrptr; + ValueRecord* p_val = GetParam(pParamName); + + if (NULL == p_val) { + return default_value; + } + + const char *val = p_val->mValue.c_str(); + default_value = strtod(val, &chrptr); + if (!*val || *chrptr) { + throw std::runtime_error(std::string("Decimal number expected for ") + + pParamName + " but found '" + val + "'"); + } + 
return default_value; + } + + //*************************************************************************** + //*************************************************************************** + bool + UserInterface:: + GetBool( + const char * pParamName, + bool default_value) + { + ValueRecord* p_val = GetParam(pParamName); + + if (NULL == p_val) { + return default_value; + } + + const char* val = p_val->mValue.c_str(); + + if (!strcasecmp(val, "TRUE") || !strcmp(val, "T")) return 1; + if (strcasecmp(val, "FALSE") && strcmp(val, "F")) { + throw std::runtime_error(std::string("TRUE or FALSE expected for ") + + pParamName + " but found '" + val + "'"); + } + return false; + } + + //*************************************************************************** + //*************************************************************************** + // '...' are pairs: string and corresponding integer value , terminated by NULL + int + UserInterface:: + GetEnum( + const char * pParamName, + int default_value, + ...) 
+ { + ValueRecord* p_val = GetParam(pParamName); + + if (NULL == p_val) { + return default_value; + } + + const char* val = p_val->mValue.c_str(); + char* s; + int i = 0, cnt = 0, l = 0; + va_list ap; + + va_start(ap, default_value); + while ((s = va_arg(ap, char *)) != NULL) { + l += strlen(s) + 2; + ++cnt; + i = va_arg(ap, int); + if (!strcmp(val, s)) break; + } + va_end(ap); + + if (s) { + return i; + } + + //To report error, create string listing all possible values + s = (char*) malloc(l + 1); + s[0] = '\0'; + va_start(ap, default_value); + for (i = 0; i < cnt; i++) { + strcat(s, va_arg(ap, char *)); + va_arg(ap, int); + if (i < cnt - 2) strcat(s, ", "); + else if (i == cnt - 2) strcat(s, " or "); + } + + va_end(ap); + + throw std::runtime_error(std::string(s) + " expected for " + + pParamName + " but found '" + val + "'"); + + return 0; + } + + + //*************************************************************************** + //*************************************************************************** + void + UserInterface:: + PrintConfig(std::ostream& rStream) + { + rStream << "Configuration Parameters[" << mMap.size() << "]\n"; + for (MapType::iterator it = mMap.begin(); it != mMap.end(); ++it) { + rStream << (it->second.mRead ? 
" " : "# ") + << std::setw(35) << std::left << it->first << " = " + << std::setw(30) << std::left << it->second.mValue + << " # -" << it->second.mOption << std::endl; + } + } + + //*************************************************************************** + //*************************************************************************** + void + UserInterface:: + CheckCommandLineParamUse() + { + for (MapType::iterator it = mMap.begin(); it != mMap.end(); ++it) { + if (!it->second.mRead && it->second.mOption != 'C') { + Error("Unexpected command line parameter " + it->first); + } + } + } + +} diff --git a/htk_io/src/KaldiLib/UserInterface.h b/htk_io/src/KaldiLib/UserInterface.h new file mode 100644 index 0000000..fa189e7 --- /dev/null +++ b/htk_io/src/KaldiLib/UserInterface.h @@ -0,0 +1,166 @@ +#ifndef TNet_UserInterface_h +#define TNet_UserInterface_h + +#include <iostream> +#include <cstdlib> +#include <string> +#include <map> + +namespace TNet +{ + /** ************************************************************************** + ** ************************************************************************** + */ + class UserInterface + { + public: + struct ValueRecord { + std::string mValue; + char mOption; + bool mRead; + }; + + + void InsertConfigParam( + const char *param_name, + const char *value, + int optionChar); + + + void + ReadConfig(const char *pFileName); + + + void + CheckCommandLineParamUse(); + + + /** + * @brief Retreives the content of a parameter + * @param pParamName Name of the parameter to look for + * @return Returns the pointer to the ValueRecord structure if success, + * otherwise return NULL + * + * We iteratively try to find the param name in the map. 
If an attempt + * fails, we strip off all characters until the first occurance of ':' + * and we search again + */ + ValueRecord* + GetParam(const char* pParamName); + + + /** + * @brief Returns the parameter's value as string + * + * @param param_name Parameter name + * @param default_value Value, which is returned in case the parameter + * was not found + * + * @return Pointer to the begining of the string if success, default_value + * otherwise + */ + const char* + GetStr( const char *param_name, const char *default_value); + + + /** + * @brief Returns the parameter's value as int + * + * @param param_name Parameter name + * @param default_value Value, which is returned in case the parameter + * was not found + * + * @return Returns the integer value if success, default_value + * otherwise + */ + long + GetInt( const char *param_name, long default_value); + + + /** + * @brief Returns the parameter's value as float + * + * @param param_name Parameter name + * @param default_value Value, which is returned in case the parameter + * was not found + * + * @return Returns the float value if success, default_value + * otherwise + */ + float + GetFlt( const char *param_name, float default_value); + + + /** + * @brief Returns the parameter's value as bool + * + * @param param_name Parameter name + * @param default_value Value, which is returned in case the parameter + * was not found + * + * @return Returns the bool value if success, default_value + * otherwise + * + * Note that true is returned if the value is 'TRUE' or 'T', false is + * returned if the value is 'FALSE' or 'F'. 
Otherwise exception is thrown + */ + bool + GetBool(const char *param_name, bool default_value); + + + /** + * @brief Returns the parameter's value as enum integer + * + * @param param_name Parameter name + * @param default_value Value, which is returned in case the parameter + * was not found + * + * @return Returns the index value if success, default_value + * otherwise + * + * Variable arguments specify the possible values of this parameter. If the + * value does not match any of these, exception is thrown. + */ + int + GetEnum( const char *param_name, int default_value, ...); + + + int GetFeatureParams( + int *derivOrder, + int **derivWinLens, + int *startFrmExt, + int *endFrmExt, + char **CMNPath, + char **CMNFile, + const char **CMNMask, + char **CVNPath, + char **CVNFile, + const char **CVNMask, + const char **CVGFile, + const char *toolName, + int pseudoModeule); + + + int ParseOptions( + int argc, + char* argv[], + const char* optionMapping, + const char* toolName); + + + /** + * @brief Send the defined paramaters to a stream + * + * @param rStream stream to use + */ + void + PrintConfig(std::ostream& rStream); + + public: + typedef std::map<std::string, ValueRecord> MapType; + MapType mMap; + }; +} + +#endif + diff --git a/htk_io/src/KaldiLib/Vector.cc b/htk_io/src/KaldiLib/Vector.cc new file mode 100644 index 0000000..020bae2 --- /dev/null +++ b/htk_io/src/KaldiLib/Vector.cc @@ -0,0 +1,110 @@ +#ifndef TNet_Vector_cc +#define TNet_Vector_cc + +#include <cstdlib> +#include <cmath> +#include <cstring> +#include <fstream> +#include <iomanip> +#include "Common.h" + +#ifdef HAVE_ATLAS +extern "C"{ + #include <cblas.h> +} +#endif + +#include "Common.h" +#include "Matrix.h" +#include "Vector.h" + +namespace TNet +{ + +#ifdef HAVE_ATLAS + template<> + float + BlasDot<>(const Vector<float>& rA, const Vector<float>& rB) + { + assert(rA.mDim == rB.mDim); + return cblas_sdot(rA.mDim, rA.pData(), 1, rB.pData(), 1); + } + + template<> + double + BlasDot<>(const 
Vector<double>& rA, const Vector<double>& rB) + { + assert(rA.mDim == rB.mDim); + return cblas_ddot(rA.mDim, rA.pData(), 1, rB.pData(), 1); + } + + template<> + Vector<float>& + Vector<float>:: + BlasAxpy(const float alpha, const Vector<float>& rV) + { + assert(mDim == rV.mDim); + cblas_saxpy(mDim, alpha, rV.pData(), 1, mpData, 1); + return *this; + } + + template<> + Vector<double>& + Vector<double>:: + BlasAxpy(const double alpha, const Vector<double>& rV) + { + assert(mDim == rV.mDim); + cblas_daxpy(mDim, alpha, rV.pData(), 1, mpData, 1); + return *this; + } + + template<> + Vector<int>& + Vector<int>:: + BlasAxpy(const int alpha, const Vector<int>& rV) + { + assert(mDim == rV.mDim); + for(int i=0; i<Dim(); i++) { + (*this)[i] += rV[i]; + } + return *this; + } + + + template<> + Vector<float>& + Vector<float>:: + BlasGemv(const float alpha, const Matrix<float>& rM, MatrixTrasposeType trans, const Vector<float>& rV, const float beta) + { + assert((trans == NO_TRANS && rM.Cols() == rV.mDim && rM.Rows() == mDim) + || (trans == TRANS && rM.Rows() == rV.mDim && rM.Cols() == mDim)); + + cblas_sgemv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), rM.Rows(), rM.Cols(), alpha, rM.pData(), rM.Stride(), + rV.pData(), 1, beta, mpData, 1); + return *this; + } + + + + template<> + Vector<double>& + Vector<double>:: + BlasGemv(const double alpha, const Matrix<double>& rM, MatrixTrasposeType trans, const Vector<double>& rV, const double beta) + { + assert((trans == NO_TRANS && rM.Cols() == rV.mDim && rM.Rows() == mDim) + || (trans == TRANS && rM.Rows() == rV.mDim && rM.Cols() == mDim)); + + cblas_dgemv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), rM.Rows(), rM.Cols(), alpha, rM.pData(), rM.Stride(), + rV.pData(), 1, beta, mpData, 1); + return *this; + } + + +#else + #error Routines in this section are not implemented yet without BLAS +#endif + +} // namespace TNet + + +#endif // TNet_Vector_tcc diff --git a/htk_io/src/KaldiLib/Vector.h b/htk_io/src/KaldiLib/Vector.h 
new file mode 100644 index 0000000..384c5d2 --- /dev/null +++ b/htk_io/src/KaldiLib/Vector.h @@ -0,0 +1,496 @@ +// +// C++ Interface: %{MODULE} +// +// Description: +// +// +// Author: %{AUTHOR} <%{EMAIL}>, (C) %{YEAR} +// +// Copyright: See COPYING file that comes with this distribution +// +// + +#ifndef TNet_Vector_h +#define TNet_Vector_h + +#include <cstddef> +#include <cstdlib> +#include <stdexcept> +#include <iostream> + +#ifdef HAVE_ATLAS +extern "C"{ + #include <cblas.h> + #include <clapack.h> +} +#endif + +#include "Common.h" +#include "MathAux.h" +#include "Types.h" +#include "Error.h" + +namespace TNet +{ + template<typename _ElemT> class Vector; + template<typename _ElemT> class SubVector; + template<typename _ElemT> class Matrix; + template<typename _ElemT> class SpMatrix; + + // we need to declare the friend functions here + template<typename _ElemT> + std::ostream & operator << (std::ostream & rOut, const Vector<_ElemT> & rV); + + template<typename _ElemT> + std::istream & operator >> (std::istream & rIn, Vector<_ElemT> & rV); + + template<typename _ElemT> + _ElemT + BlasDot(const Vector<_ElemT>& rA, const Vector<_ElemT>& rB); + + /** ************************************************************************** + ** ************************************************************************** + * @brief Provides a matrix abstraction class + * + * This class provides a way to work with matrices in TNet. + * It encapsulates basic operations and memory optimizations. + * + */ + template<typename _ElemT> + class Vector + { + public: + + /// defines a type of this + typedef Vector<_ElemT> ThisType; + + + Vector(): mpData(NULL) +#ifdef STK_MEMALIGN_MANUAL + ,mpFreeData(NULL) +#endif + , mDim(0) + {} + + /** + * @brief Copy constructor + * @param rV + */ + Vector(const Vector<_ElemT>& rV) + { mpData=NULL; Init(rV.Dim()); Copy(rV); } + + + /* Type conversion constructor. 
*/ + template<typename _ElemU> + explicit Vector(const Vector<_ElemU>& rV) + { mpData=NULL; Init(rV.Dim()); Copy(rV); } + + + Vector(const _ElemT* ppData, const size_t s) + { mpData=NULL; Init(s); Copy(ppData); } + + explicit Vector(const size_t s, bool clear=true) + { mpData=NULL; Init(s,clear); } + + ~Vector() + { Destroy(); } + + Vector<_ElemT> &operator = (const Vector <_ElemT> &other) + { Init(other.Dim()); Copy(other); return *this; } // Needed for inclusion in std::vector + + Vector<_ElemT>& + Init(size_t length, bool clear=true); + + /** + * @brief Dealocates the window from memory and resets the dimensions to (0) + */ + void + Destroy(); + + /** + * @brief Returns @c true if vector is initialized + */ + bool + IsInitialized() const + { return mpData != NULL; } + + /** + * @brief Sets all elements to 0 + */ + void + Zero(); + + void + Set(_ElemT f); + + inline size_t + Dim() const + { return mDim; } + + /** + * @brief Returns size of matrix in memory (in bytes) + */ + inline size_t + MSize() const + { + return (mDim + (((16 / sizeof(_ElemT)) - mDim%(16 / sizeof(_ElemT))) + % (16 / sizeof(_ElemT)))) * sizeof(_ElemT); + } + + /** + * @brief Gives access to the vector memory area + * @return pointer to the first field + */ + inline _ElemT* + pData() + { return mpData; } + + /** + * @brief Gives access to the vector memory area + * @return pointer to the first field (const version) + */ + inline const _ElemT* + pData() const + { return mpData; } + + /** + * @brief Gives access to a specified vector element (const). + */ + inline _ElemT + operator [] (size_t i) const + { +#ifdef PARANOID + assert(i<mDim); +#endif + return *(mpData + i); + } + + /** + * @brief Gives access to a specified vector element (non-const). + */ + inline _ElemT & + operator [] (size_t i) + { +#ifdef PARANOID + assert(i<mDim); +#endif + return *(mpData + i); + } + + /** + * @brief Gives access to a specified vector element (const). 
+ */ + inline _ElemT + operator () (size_t i) const + { +#ifdef PARANOID + assert(i<mDim); +#endif + return *(mpData + i); + } + + /** + * @brief Gives access to a specified vector element (non-const). + */ + inline _ElemT & + operator () (size_t i) + { +#ifdef PARANOID + assert(i<mDim); +#endif + return *(mpData + i); + } + + /** + * @brief Returns a matrix sub-range + * @param o Origin + * @param l Length + * See @c SubVector class for details + */ + SubVector<_ElemT> + Range(const size_t o, const size_t l) + { return SubVector<_ElemT>(*this, o, l); } + + /** + * @brief Returns a matrix sub-range + * @param o Origin + * @param l Length + * See @c SubVector class for details + */ + const SubVector<_ElemT> + Range(const size_t o, const size_t l) const + { return SubVector<_ElemT>(*this, o, l); } + + + + //######################################################################## + //######################################################################## + + /// Copy data from another vector + Vector<_ElemT>& + Copy(const Vector<_ElemT>& rV); + + /// Copy data from another vector of a different type. + template<typename _ElemU> Vector<_ElemT>& + Copy(const Vector<_ElemU>& rV); + + + /// Load data into the vector + Vector<_ElemT>& + Copy(const _ElemT* ppData); + + Vector<_ElemT>& + CopyVectorizedMatrixRows(const Matrix<_ElemT> &rM); + + Vector<_ElemT>& + RemoveElement(size_t i); + + Vector<_ElemT>& + ApplyLog(); + + Vector<_ElemT>& + ApplyLog(const Vector<_ElemT>& rV);//ApplyLog to rV and put the result in (*this) + + Vector<_ElemT>& + ApplyExp(); + + Vector<_ElemT>& + ApplySoftMax(); + + Vector<_ElemT>& + Invert(); + + Vector<_ElemT>& + DotMul(const Vector<_ElemT>& rV); // Multiplies each element (*this)(i) by rV(i). 
+ + Vector<_ElemT>& + BlasAxpy(const _ElemT alpha, const Vector<_ElemT>& rV); + + Vector<_ElemT>& + BlasGemv(const _ElemT alpha, const Matrix<_ElemT>& rM, const MatrixTrasposeType trans, const Vector<_ElemT>& rV, const _ElemT beta = 0.0); + + + //######################################################################## + //######################################################################## + + Vector<_ElemT>& + Add(const Vector<_ElemT>& rV) + { return BlasAxpy(1.0, rV); } + + Vector<_ElemT>& + Subtract(const Vector<_ElemT>& rV) + { return BlasAxpy(-1.0, rV); } + + Vector<_ElemT>& + AddScaled(_ElemT alpha, const Vector<_ElemT>& rV) + { return BlasAxpy(alpha, rV); } + + Vector<_ElemT>& + Add(_ElemT c); + + Vector<_ElemT>& + MultiplyElements(const Vector<_ElemT>& rV); + + // @brief elementwise : rV.*rR+beta*this --> this + Vector<_ElemT>& + MultiplyElements(_ElemT alpha, const Vector<_ElemT>& rV, const Vector<_ElemT>& rR,_ElemT beta); + + Vector<_ElemT>& + DivideElements(const Vector<_ElemT>& rV); + + /// @brief elementwise : rV./rR+beta*this --> this + Vector<_ElemT>& + DivideElements(_ElemT alpha, const Vector<_ElemT>& rV, const Vector<_ElemT>& rR,_ElemT beta); + + Vector<_ElemT>& + Subtract(_ElemT c); + + Vector<_ElemT>& + Scale(_ElemT c); + + + //######################################################################## + //######################################################################## + + /// Performs a row stack of the matrix rMa + Vector<_ElemT>& + MatrixRowStack(const Matrix<_ElemT>& rMa); + + // Extracts a row of the matrix rMa. .. could also do this with vector.Copy(rMa[row]). + Vector<_ElemT>& + Row(const Matrix<_ElemT>& rMa, size_t row); + + // Extracts a column of the matrix rMa. + Vector<_ElemT>& + Col(const Matrix<_ElemT>& rMa, size_t col); + + // Takes all elements to a power. 
+ Vector<_ElemT>& + Power(_ElemT power); + + _ElemT + Max() const; + + _ElemT + Min() const; + + /// Returns sum of the elements + _ElemT + Sum() const; + + /// Returns sum of the elements + Vector<_ElemT>& + AddRowSum(const Matrix<_ElemT>& rM); + + /// Returns sum of the elements + Vector<_ElemT>& + AddColSum(const Matrix<_ElemT>& rM); + + /// Returns log(sum(exp())) without exp overflow + _ElemT + LogSumExp() const; + + //######################################################################## + //######################################################################## + + friend std::ostream & + operator << <> (std::ostream& rOut, const Vector<_ElemT>& rV); + + friend _ElemT + BlasDot<>(const Vector<_ElemT>& rA, const Vector<_ElemT>& rB); + + /** + * Computes v1^T * M * v2. + * Not as efficient as it could be where v1==v2 (but no suitable blas + * routines available). + */ + _ElemT + InnerProduct(const Vector<_ElemT> &v1, const Matrix<_ElemT> &M, const Vector<_ElemT> &v2) const; + + + //########################################################################## + //########################################################################## + //protected: + public: + /// data memory area + _ElemT* mpData; +#ifdef STK_MEMALIGN_MANUAL + /// data to be freed (in case of manual memalignment use, see common.h) + _ElemT* mpFreeData; +#endif + size_t mDim; ///< Number of elements + }; // class Vector + + + + + /** + * @brief Represents a non-allocating general vector which can be defined + * as a sub-vector of higher-level vector + */ + template<typename _ElemT> + class SubVector : public Vector<_ElemT> + { + protected: + /// Constructor + SubVector(const Vector<_ElemT>& rT, + const size_t origin, + const size_t length) + { + assert(origin+length <= rT.mDim); + Vector<_ElemT>::mpData = rT.mpData+origin; + Vector<_ElemT>::mDim = length; + } + //only Vector class can call this protected constructor + friend class Vector<_ElemT>; + + public: + /// Constructor + 
SubVector(Vector<_ElemT>& rT, + const size_t origin, + const size_t length) + { + assert(origin+length <= rT.mDim); + Vector<_ElemT>::mpData = rT.mpData+origin; + Vector<_ElemT>::mDim = length; + } + + + /** + * @brief Constructs a vector representation out of a standard array + * + * @param pData pointer to data array to associate with this vector + * @param length length of this vector + */ + inline + SubVector(_ElemT *ppData, + size_t length) + { + Vector<_ElemT>::mpData = ppData; + Vector<_ElemT>::mDim = length; + } + + + /** + * @brief Destructor + */ + ~SubVector() + { + Vector<_ElemT>::mpData = NULL; + } + }; + + + // Useful shortcuts + typedef Vector<BaseFloat> BfVector; + typedef SubVector<BaseFloat> BfSubVector; + + //Adding two vectors of different types + template <typename _ElemT, typename _ElemU> + void Add(Vector<_ElemT>& rDst, const Vector<_ElemU>& rSrc) + { + assert(rDst.Dim() == rSrc.Dim()); + const _ElemU* p_src = rSrc.pData(); + _ElemT* p_dst = rDst.pData(); + + for(size_t i=0; i<rSrc.Dim(); i++) { + *p_dst++ += (_ElemT)*p_src++; + } + } + + + //Scales adding two vectors of different types + template <typename _ElemT, typename _ElemU> + void AddScaled(Vector<_ElemT>& rDst, const Vector<_ElemU>& rSrc, _ElemT scale) + { + assert(rDst.Dim() == rSrc.Dim()); + + Vector<_ElemT> tmp(rSrc); + rDst.BlasAxpy(scale, tmp); + +/* + const _ElemU* p_src = rSrc.pData(); + _ElemT* p_dst = rDst.pData(); + + for(size_t i=0; i<rDst.Dim(); i++) { + *p_dst++ += *p_src++ * scale; + } +*/ + } + + +} // namespace TNet + +//***************************************************************************** +//***************************************************************************** +// we need to include the implementation +#include "Vector.tcc" + +/****************************************************************************** + ****************************************************************************** + * The following section contains specialized template definitions 
+ * whose implementation is in Vector.cc + */ + + +#endif // #ifndef TNet_Vector_h diff --git a/htk_io/src/KaldiLib/Vector.tcc b/htk_io/src/KaldiLib/Vector.tcc new file mode 100644 index 0000000..751ffa7 --- /dev/null +++ b/htk_io/src/KaldiLib/Vector.tcc @@ -0,0 +1,638 @@ +/** @file Vector.tcc + * This is an internal header file, included by other library headers. + * You should not attempt to use it directly. + */ + +#ifndef TNet_Vector_tcc +#define TNet_Vector_tcc + +#include <cstdlib> +#include <cmath> +#include <cstring> +#include <fstream> +#include <iomanip> +#include "Common.h" + +#ifdef HAVE_ATLAS +extern "C"{ + #include <cblas.h> +} +#endif + +#include "Common.h" +#include "MathAux.h" +#include "Matrix.h" + +namespace TNet +{ + //****************************************************************************** + //****************************************************************************** + template<typename _ElemT> + inline Vector<_ElemT>& + Vector<_ElemT>:: + Init(const size_t length, bool clear) + { + if(mpData != NULL) Destroy(); + if(length==0){ + mpData=NULL; +#ifdef STK_MEMALIGN_MANUAL + mpFreeData=NULL; +#endif + mDim=0; + return *this; + } + size_t size; + void* data; + void* free_data; + + size = align<16>(length * sizeof(_ElemT)); + + if (NULL != (data = stk_memalign(16, size, &free_data))) { + mpData = static_cast<_ElemT*> (data); +#ifdef STK_MEMALIGN_MANUAL + mpFreeData = static_cast<_ElemT*> (free_data); +#endif + mDim = length; + } else { + throw std::bad_alloc(); + } + if(clear) Zero(); + return *this; + } + + + //****************************************************************************** + //****************************************************************************** + /// Copy data from another vector + template<typename _ElemT> + inline Vector<_ElemT>& + Vector<_ElemT>:: + Copy(const Vector<_ElemT>& rV) { + assert(Dim() == rV.Dim()); + Copy(rV.mpData); + return *this; + } + + /// Load data into the vector + template<typename _ElemT> 
+ Vector<_ElemT>& + Vector<_ElemT>:: + Copy(const _ElemT* ppData) { + std::memcpy(this->mpData, ppData, Dim() * sizeof(_ElemT)); + return *this; + } + + template<typename _ElemT> + template<typename _ElemU> + Vector<_ElemT>& + Vector<_ElemT>:: + Copy(const Vector<_ElemU> &other){ + assert(Dim()==other.Dim()); + size_t D=Dim(); + for(size_t d=0;d<D;d++) (*this)(d) = (_ElemT) other[d]; + return *this; + } + + + //****************************************************************************** + //****************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + CopyVectorizedMatrixRows(const Matrix<_ElemT> &rM) { +// TraceLog("Dim = "+to_string(Dim())+", Rows = "+to_string(rM.Rows())+", Cols = "+to_string(rM.Cols())); + assert(Dim() == rM.Cols()*rM.Rows()); + size_t nCols = rM.Cols(); + for(size_t r=0; r<rM.Rows(); r++) + Range(r*nCols, nCols).Copy(rM[r]); + return *this; + } + + + //**************************************************************************** + //**************************************************************************** + // Remove element from the vector. 
The vector is non reallocated + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + RemoveElement(size_t i) { + assert(i < mDim && "Access out of vector"); + for(size_t j = i + 1; j < mDim; j++) + this->mpData[j - 1] = this->mpData[j]; + mDim--; + return *this; + } + + //**************************************************************************** + //**************************************************************************** + // The destructor + template<typename _ElemT> + inline void + Vector<_ElemT>:: + Destroy() + { + // we need to free the data block if it was defined +#ifndef STK_MEMALIGN_MANUAL + if (NULL != mpData) free(mpData); +#else + if (NULL != mpData) free(mpFreeData); + mpFreeData = NULL; +#endif + + mpData = NULL; + mDim = 0; + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + inline void + Vector<_ElemT>:: + Zero() + { + std::memset(mpData, 0, mDim * sizeof(_ElemT)); + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + inline void + Vector<_ElemT>:: + Set(_ElemT f) + { + for(size_t i=0;i<mDim;i++) mpData[i] = f; + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + MatrixRowStack(const Matrix<_ElemT>& rMa) + { + assert(mDim == rMa.Cols() * rMa.Rows()); + + _ElemT* inc_data = mpData; + const size_t cols = rMa.Cols(); + + for (size_t i = 0; i < rMa.Rows(); i++) + { + // copy the data to the propper position + memcpy(inc_data, rMa[i], cols * sizeof(_ElemT)); + + // set new copy position + inc_data += cols; + } + } + + + 
//**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + Row(const Matrix<_ElemT> &rMa, size_t row) + { + assert(row < rMa.Rows()); + const _ElemT *mRow = rMa.pRowData(row); + // if(mDim != rMa.Cols()) Init(rMa.Cols()); // automatically resize. + memcpy(mpData, mRow, sizeof(_ElemT)*mDim); + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + Power(_ElemT power) // takes elements to a power. Throws exception if could not. + { + for(size_t i=0;i<Dim();i++){ + _ElemT tmp = (*this)(i); + (*this)(i) = pow(tmp, power); + if((*this)(i) == HUGE_VAL) + throw std::runtime_error((std::string)"Error in Vector::power, could not take " +to_string(tmp)+ " to power " +to_string((*this)(i))); + } + return (*this); + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + _ElemT + Vector<_ElemT>:: + Max() const + { + if(Dim()==0) throw std::runtime_error("Error in Vector::Max(), empty vector\n"); + _ElemT ans = (*this)(0); + for(size_t i=1;i<Dim();i++) ans = std::max(ans, (*this)(i)); + return ans; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + _ElemT + Vector<_ElemT>:: + Min() const + { + if(Dim()==0) throw std::runtime_error("Error in Vector::Min(), empty vector\n"); + _ElemT ans = (*this)(0); + for(size_t i=1;i<Dim();i++) ans = std::min(ans, (*this)(i)); + return ans; + } + + + + 
//**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + Col(const Matrix<_ElemT> &rMa, size_t col) + { + assert(col < rMa.Cols()); + // if(mDim != rMa.Cols()) Init(rMa.Cols()); // automatically resize. + for(size_t i=0;i<mDim;i++) + mpData[i] = rMa(i,col); // can't do this efficiently so don't really bother. + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + _ElemT + Vector<_ElemT>:: + Sum() const + { + //note the double accumulator + double sum = 0.0; + + for (size_t i = 0; i < mDim; ++i) { + sum += mpData[i]; + } + return (_ElemT)sum; + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + AddColSum(const Matrix<_ElemT>& rM) + { + // note the double accumulator + double sum; + + assert(mDim == rM.Cols()); + + for (size_t i = 0; i < mDim; ++i) { + sum = 0.0; + for (size_t j = 0; j < rM.Rows(); ++j) { + sum += rM[j][i]; + } + mpData[i] += sum; + } + return *this; + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + AddRowSum(const Matrix<_ElemT>& rM) + { + // note the double accumulator + double sum; + + assert(mDim == rM.Rows()); + + for (size_t i = 0; i < mDim; ++i) { + sum = 0.0; + for (size_t j = 0; j < rM.Cols(); ++j) { + sum += rM[i][j]; + } + mpData[i] += sum; + } + return *this; + } + + + //**************************************************************************** + 
//**************************************************************************** + template<typename _ElemT> + _ElemT + Vector<_ElemT>:: + LogSumExp() const + { + double sum = LOG_0; + + for (size_t i = 0; i < mDim; ++i) { + sum = LogAdd(sum, mpData[i]); + } + return sum; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + Invert() { + for (size_t i = 0; i < mDim; ++i) { + mpData[i] = static_cast<_ElemT>(1 / mpData[i]); + } + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + ApplyLog() { + for (size_t i = 0; i < mDim; ++i) { + mpData[i] = _LOG(mpData[i]); + } + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + ApplyLog(const Vector<_ElemT>& rV) { + assert(mDim==rV.Dim()); + for (size_t i = 0; i < mDim; ++i) { + mpData[i] = log(rV[i]); + } + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + ApplyExp() { + for (size_t i = 0; i < mDim; ++i) { + mpData[i] = _EXP(mpData[i]); + } + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + ApplySoftMax() { + _ElemT lse = LogSumExp(); + + for (size_t i = 0; i < mDim; ++i) 
{ + mpData[i] = exp(mpData[i] - lse); + } + return *this; + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + Add(_ElemT c) + { + for(size_t i = 0; i < mDim; i++) { + mpData[i] += c; + } + return *this; + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + Subtract(_ElemT c) + { + for(size_t i = 0; i < mDim; i++) { + mpData[i] -= c; + } + return *this; + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + Scale(_ElemT c) + { + for(size_t i = 0; i < mDim; i++) { + mpData[i] *= c; + } + return *this; + } + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + MultiplyElements(const Vector<_ElemT>& rV) + { + assert(mDim == rV.Dim()); + for(size_t i = 0; i < mDim; i++) { + mpData[i] *= rV[i]; + } + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + MultiplyElements(_ElemT alpha, const Vector<_ElemT>& rV, const Vector<_ElemT>& rR, _ElemT beta) + { + assert((mDim == rV.Dim() && mDim == rR.Dim())); + for(size_t i = 0; i < mDim; i++) { + mpData[i] = alpha * rV[i] * rR[i] + beta * mpData[i]; + } + return *this; + } + + 
//**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + DivideElements(const Vector<_ElemT>& rV) + { + assert(mDim == rV.Dim()); + for(size_t i = 0; i < mDim; i++) { + mpData[i] /= rV[i]; + } + return *this; + } + + //**************************************************************************** + //**************************************************************************** + + template<typename _ElemT> + Vector<_ElemT>& + Vector<_ElemT>:: + DivideElements(_ElemT alpha, const Vector<_ElemT>& rV, const Vector<_ElemT>& rR, _ElemT beta) + { + assert((mDim == rV.Dim() && mDim == rR.Dim())); + for(size_t i = 0; i < mDim; i++) { + mpData[i] = alpha * rV[i]/rR[i] + beta * mpData[i] ; + } + return *this; + } + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + void Load(std::istream& rIn, Vector<_ElemT>& rV) + { + std::streamoff pos = rIn.tellg(); + if(MatrixVectorIostreamControl::Flags(rIn, ACCUMULATE_INPUT)) { + for (size_t i = 0; i < rV.Dim(); i++) { + _ElemT tmp; + rIn >> tmp; + rV[i] += tmp; + } + } else { + for (size_t i = 0; i < rV.Dim(); i++) { + rIn >> rV[i]; + } + } + if(rIn.fail()) { + throw std::runtime_error("Failed to read vector from stream. File position is "+to_string(pos)); + } + } + + template<typename _ElemT> + std::istream & + operator >> (std::istream& rIn, Vector<_ElemT>& rV) + { + rIn >> std::ws; + if(rIn.peek() == 'v'){ // "new" format: v <dim> 1.0 0.2 4.3 ... 
+ rIn.get(); + long long int tmp=-1; + rIn >> tmp; + if(rIn.fail() || tmp<0) { + throw std::runtime_error("Failed to read vector from stream: no size"); + } + size_t tmp2 = size_t(tmp); + assert((long long int)tmp2 == tmp); + + if(rV.Dim() != tmp2) rV.Init(tmp2); + } + Load(rIn,rV); + return rIn; + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + void Save (std::ostream& rOut, const Vector<_ElemT>& rV) + { + + for (size_t i = 0; i < rV.Dim(); i++) { + rOut << rV[i] << ' '; + } + if(rOut.fail()) { + throw std::runtime_error("Failed to write vector to stream"); + } + } + + + //**************************************************************************** + //**************************************************************************** + template<typename _ElemT> + std::ostream & + operator << (std::ostream& rOut, const Vector<_ElemT>& rV) + { + rOut << "v " << rV.Dim() << " "; + Save(rOut,rV); + return rOut; + } + + + + //**************************************************************************** + //**************************************************************************** + +#ifdef HAVE_ATLAS + template<> + float + BlasDot<>(const Vector<float>& rA, const Vector<float>& rB); + + template<> + double + BlasDot<>(const Vector<double>& rA, const Vector<double>& rB); + + template<typename _ElemT> + inline Vector<_ElemT>& + Vector<_ElemT>:: + DotMul(const Vector<_ElemT> &rV){ + assert(mDim == rV.mDim); + const _ElemT *other_data = rV.pData(); + _ElemT *my_data = mpData, *my_data_end = my_data+mDim; + for(;my_data<my_data_end;) *(my_data++) *= *(other_data++); + return *this; + } + + template<> + Vector<float>& + Vector<float>:: + BlasAxpy(const float alpha, const Vector<float>& rV); + + + template<> + Vector<double>& + Vector<double>:: + BlasAxpy(const double alpha, const Vector<double>& rV); + + + template<> + 
Vector<float>& + Vector<float>:: + BlasGemv(const float alpha, const Matrix<float>& rM, MatrixTrasposeType trans, const Vector<float>& rV, const float beta); + + template<> + Vector<double>& + Vector<double>:: + BlasGemv(const double alpha, const Matrix<double>& rM, MatrixTrasposeType trans, const Vector<double>& rV, const double beta); + +#else + #error Routines in this section are not implemented yet without BLAS +#endif + + + template<class _ElemT> + _ElemT + InnerProduct(const Vector<_ElemT> &v1, const Matrix<_ElemT> &M, const Vector<_ElemT> &v2){ + assert(v1.size()==M.Rows() && v2.size()==M.Cols()); + Vector<_ElemT> vtmp(M.Rows()); + vtmp.BlasGemv(1.0, M, NO_TRANS, v2, 0.0); + return BlasDot(v1, vtmp); + } + + +} // namespace TNet + + +#endif // TNet_Vector_tcc diff --git a/htk_io/src/KaldiLib/clapack.cc b/htk_io/src/KaldiLib/clapack.cc new file mode 100644 index 0000000..a486bef --- /dev/null +++ b/htk_io/src/KaldiLib/clapack.cc @@ -0,0 +1,61 @@ + +extern "C" { + + /** + * Wrapper to GotoBLAS lapack for STK and TNet (sgetrf sgetri dgetrf dgetri) + */ + typedef float real; + typedef double doublereal; + typedef int integer; + + + /** + * The lapack interface (used in gotoblas) + */ + /* Subroutine */ int sgetrf_(integer *m, integer *n, real *a, integer *lda, + integer *ipiv, integer *info); + /* Subroutine */ int sgetri_(integer *n, real *a, integer *lda, integer *ipiv, + real *work, integer *lwork, integer *info); + /* Subroutine */ int dgetrf_(integer *m, integer *n, doublereal *a, integer * + lda, integer *ipiv, integer *info); + /* Subroutine */ int dgetri_(integer *n, doublereal *a, integer *lda, integer + *ipiv, doublereal *work, integer *lwork, integer *info); + + + + + + /** + * The clapack interface as used by ATLAS (used in STK, + */ + enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 }; + + int clapack_sgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + float *A, const int lda, int *ipiv) + { + return sgetrf_((int*)&M, (int*)&N, A, 
(int*)&lda, (int*)ipiv, 0); + } + + + int clapack_sgetri(const enum CBLAS_ORDER Order, const int N, float *A, + const int lda, const int *ipiv) + { + return sgetri_((int*)&N, A, (int*)&lda, (int*)ipiv, 0, 0, 0); + } + + + int clapack_dgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + double *A, const int lda, int *ipiv) + { + return dgetrf_((int*)&M, (int*)&N, A, (int*)&lda, (int*)ipiv, 0); + } + + + int clapack_dgetri(const enum CBLAS_ORDER Order, const int N, double *A, + const int lda, const int *ipiv) + { + return dgetri_((int*)&N, A, (int*)&lda, (int*)ipiv, 0, 0, 0); + } + + +} diff --git a/htk_io/src/KaldiLib/clapack.h b/htk_io/src/KaldiLib/clapack.h new file mode 100644 index 0000000..0c6855d --- /dev/null +++ b/htk_io/src/KaldiLib/clapack.h @@ -0,0 +1,149 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.2 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef CLAPACK_H + +#define CLAPACK_H +#include "cblas.h" + +#ifndef ATLAS_ORDER + #define ATLAS_ORDER CBLAS_ORDER +#endif +#ifndef ATLAS_UPLO + #define ATLAS_UPLO CBLAS_UPLO +#endif +#ifndef ATLAS_DIAG + #define ATLAS_DIAG CBLAS_DIAG +#endif +int clapack_sgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, + float *A, const int lda, int *ipiv, + float *B, const int ldb); +int clapack_sgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + float *A, const int lda, int *ipiv); +int clapack_sgetrs + (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const float *A, const int lda, + const int *ipiv, float *B, const int ldb); +int clapack_sgetri(const enum CBLAS_ORDER Order, const int N, float *A, + const int lda, const int *ipiv); +int clapack_sposv(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, const int NRHS, float *A, const int lda, + float *B, const int ldb); +int clapack_spotrf(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, float *A, const int lda); +int clapack_spotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const float *A, const int lda, + float *B, const int ldb); +int clapack_spotri(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, float *A, const int lda); +int clapack_slauum(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO 
Uplo, + const int N, float *A, const int lda); +int clapack_strtri(const enum ATLAS_ORDER Order,const enum ATLAS_UPLO Uplo, + const enum ATLAS_DIAG Diag,const int N, float *A, const int lda); + +int clapack_dgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, + double *A, const int lda, int *ipiv, + double *B, const int ldb); +int clapack_dgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + double *A, const int lda, int *ipiv); +int clapack_dgetrs + (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const double *A, const int lda, + const int *ipiv, double *B, const int ldb); +int clapack_dgetri(const enum CBLAS_ORDER Order, const int N, double *A, + const int lda, const int *ipiv); +int clapack_dposv(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, const int NRHS, double *A, const int lda, + double *B, const int ldb); +int clapack_dpotrf(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, double *A, const int lda); +int clapack_dpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const double *A, const int lda, + double *B, const int ldb); +int clapack_dpotri(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, double *A, const int lda); +int clapack_dlauum(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, double *A, const int lda); +int clapack_dtrtri(const enum ATLAS_ORDER Order,const enum ATLAS_UPLO Uplo, + const enum ATLAS_DIAG Diag,const int N, double *A, const int lda); + +int clapack_cgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, + void *A, const int lda, int *ipiv, + void *B, const int ldb); +int clapack_cgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + void *A, const int lda, int *ipiv); +int clapack_cgetrs + (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const void *A, const 
int lda, + const int *ipiv, void *B, const int ldb); +int clapack_cgetri(const enum CBLAS_ORDER Order, const int N, void *A, + const int lda, const int *ipiv); +int clapack_cposv(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, const int NRHS, void *A, const int lda, + void *B, const int ldb); +int clapack_cpotrf(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_cpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const void *A, const int lda, + void *B, const int ldb); +int clapack_cpotri(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_clauum(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_ctrtri(const enum ATLAS_ORDER Order,const enum ATLAS_UPLO Uplo, + const enum ATLAS_DIAG Diag,const int N, void *A, const int lda); + +int clapack_zgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, + void *A, const int lda, int *ipiv, + void *B, const int ldb); +int clapack_zgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + void *A, const int lda, int *ipiv); +int clapack_zgetrs + (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const void *A, const int lda, + const int *ipiv, void *B, const int ldb); +int clapack_zgetri(const enum CBLAS_ORDER Order, const int N, void *A, + const int lda, const int *ipiv); +int clapack_zposv(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, const int NRHS, void *A, const int lda, + void *B, const int ldb); +int clapack_zpotrf(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_zpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const void *A, const int lda, + void *B, const int ldb); +int 
clapack_zpotri(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_zlauum(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_ztrtri(const enum ATLAS_ORDER Order,const enum ATLAS_UPLO Uplo, + const enum ATLAS_DIAG Diag,const int N, void *A, const int lda); + +#endif diff --git a/htk_io/src/cwrapper.cpp b/htk_io/src/cwrapper.cpp new file mode 100644 index 0000000..b7ce2d5 --- /dev/null +++ b/htk_io/src/cwrapper.cpp @@ -0,0 +1,148 @@ +#include "KaldiLib/Features.h" +#include "KaldiLib/Labels.h" +#include "KaldiLib/Common.h" +#include "KaldiLib/UserInterface.h" +#include <string> +#define SNAME "TNET" + +extern "C" { +#include "cwrapper.h" +#include "string.h" +#include "nerv/common.h" + + extern Matrix *nerv_matrix_host_float_create(long nrow, long ncol, Status *status); + + struct TNetFeatureRepo { + TNet::FeatureRepository feature_repo; + TNet::UserInterface ui; + bool swap_features; + int target_kind; + int deriv_order; + int* p_deriv_win_lenghts; + int start_frm_ext; + int end_frm_ext; + char* cmn_path; + char* cmn_file; + const char* cmn_mask; + char* cvn_path; + char* cvn_file; + const char* cvn_mask; + const char* cvg_file; + TNet::Matrix<float> feats_host; /* KaldiLib implementation */ + }; + + TNetFeatureRepo *tnet_feature_repo_new(const char *p_script, const char *config, int context) { + TNetFeatureRepo *repo = new TNetFeatureRepo(); + repo->ui.ReadConfig(config); + repo->swap_features = !repo->ui.GetBool(SNAME":NATURALREADORDER", TNet::IsBigEndian()); + /* load defaults */ + repo->target_kind = repo->ui.GetFeatureParams(&repo->deriv_order, + &repo->p_deriv_win_lenghts, + &repo->start_frm_ext, &repo->end_frm_ext, + &repo->cmn_path, &repo->cmn_file, &repo->cmn_mask, + &repo->cvn_path, &repo->cvn_file, &repo->cvn_mask, + &repo->cvg_file, SNAME":", 0); + repo->start_frm_ext = repo->end_frm_ext = context; + 
repo->feature_repo.Init(repo->swap_features, + repo->start_frm_ext, repo->end_frm_ext, repo->target_kind, + repo->deriv_order, repo->p_deriv_win_lenghts, + repo->cmn_path, repo->cmn_mask, + repo->cvn_path, repo->cvn_mask, repo->cvg_file); + repo->feature_repo.AddFileList(p_script); + repo->feature_repo.Rewind(); + return repo; + } + + Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L, int debug) { + Matrix *mat; /* nerv implementation */ + repo->feature_repo.ReadFullMatrix(repo->feats_host); + std::string utter_str = repo->feature_repo.Current().Logical(); + repo->feats_host.CheckData(utter_str); + int n = repo->feats_host.Rows(); + int m = repo->feats_host.Cols(); + Status status; + mat = nerv_matrix_host_float_create(n, m, &status); + NERV_LUA_CHECK_STATUS(L, status); + size_t stride = mat->stride; + if (debug) + fprintf(stderr, "[tnet] feature: %s %d %d\n", utter_str.c_str(), n, m); + for (int i = 0; i < n; i++) + { + float *row = repo->feats_host.pRowData(i); + float *nerv_row = (float *)((char *)mat->data.f + i * stride); + /* use memmove to copy the row, since KaldiLib uses compact storage */ + memmove(nerv_row, row, sizeof(float) * m); + } + return mat; + } + + void tnet_feature_repo_next(TNetFeatureRepo *repo) { + repo->feature_repo.MoveNext(); + } + + int tnet_feature_repo_is_end(TNetFeatureRepo *repo) { + return repo->feature_repo.EndOfList(); + } + + size_t tnet_feature_repo_current_samplerate(TNetFeatureRepo *repo) { + return repo->feature_repo.CurrentHeader().mSamplePeriod; + } + + const char *tnet_feature_repo_current_tag(TNetFeatureRepo *repo) { + return repo->feature_repo.Current().Logical().c_str(); + } + + void tnet_feature_repo_destroy(TNetFeatureRepo *repo) { + if (repo->cmn_mask) + free(repo->cmn_path); + if (repo->cvn_mask) + free(repo->cvn_path); + free(repo->p_deriv_win_lenghts); + delete repo; + } + + struct TNetLabelRepo { + TNet::LabelRepository label_repo; + }; + + TNetLabelRepo *tnet_label_repo_new(const char 
*mlf, const char *fmt, + const char *fmt_arg, const char *dir, + const char *ext) { + TNetLabelRepo *repo = new TNetLabelRepo(); + repo->label_repo.InitExt(mlf, fmt, fmt_arg, dir, ext); + /* repo->label_repo.Init(mlf, fmt_arg, dir, ext); */ + return repo; + } + + Matrix *tnet_label_repo_read_utterance(TNetLabelRepo *repo, + size_t frames, + size_t sample_rate, + const char *tag, + lua_State *L, + int debug) { + std::vector<TNet::Matrix<float> > labs_hosts; /* KaldiLib implementation */ + Matrix *mat; + repo->label_repo.GenDesiredMatrixExt(labs_hosts, frames, + sample_rate, tag); + int n = labs_hosts[0].Rows(); + int m = labs_hosts[0].Cols(); + Status status; + mat = nerv_matrix_host_float_create(n, m, &status); + NERV_LUA_CHECK_STATUS(L, status); + size_t stride = mat->stride; + if (debug) + fprintf(stderr, "[tnet] label: %s %d %d\n", tag, n, m); + for (int i = 0; i < n; i++) + { + float *row = labs_hosts[0].pRowData(i); + float *nerv_row = (float *)((char *)mat->data.f + i * stride); + /* use memmove to copy the row, since KaldiLib uses compact storage */ + memmove(nerv_row, row, sizeof(float) * m); + } + return mat; + } + + void tnet_label_repo_destroy(TNetLabelRepo *repo) { + delete repo; + } +} diff --git a/htk_io/src/cwrapper.h b/htk_io/src/cwrapper.h new file mode 100644 index 0000000..e1bce6e --- /dev/null +++ b/htk_io/src/cwrapper.h @@ -0,0 +1,37 @@ +#ifndef NERV_TNET_IO_CWRAPPER +#define NERV_TNET_IO_CWRAPPER +#include "nerv/matrix/matrix.h" +#include "nerv/common.h" +#ifdef __cplusplus +extern "C" { +#endif + + typedef struct TNetFeatureRepo TNetFeatureRepo; + + TNetFeatureRepo *tnet_feature_repo_new(const char *scp, + const char *config, int context); + Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L, int debug); + size_t tnet_feature_repo_current_samplerate(TNetFeatureRepo *repo); + const char *tnet_feature_repo_current_tag(TNetFeatureRepo *repo); + void tnet_feature_repo_next(TNetFeatureRepo *repo); + int 
tnet_feature_repo_is_end(TNetFeatureRepo *repo);
+    void tnet_feature_repo_destroy(TNetFeatureRepo *repo);
+
+    typedef struct TNetLabelRepo TNetLabelRepo;
+
+    TNetLabelRepo *tnet_label_repo_new(const char *mlf, const char *fmt,
+                                       const char *fmt_arg, const char *dir,
+                                       const char *ext);
+
+    Matrix *tnet_label_repo_read_utterance(TNetLabelRepo *repo,
+                                           size_t frames,
+                                           size_t sample_rate,
+                                           const char *tag,
+                                           lua_State *L,
+                                           int debug);
+
+    void tnet_label_repo_destroy(TNetLabelRepo *repo);
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/htk_io/src/init.c b/htk_io/src/init.c
new file mode 100644
index 0000000..8a1ec3b
--- /dev/null
+++ b/htk_io/src/init.c
@@ -0,0 +1,118 @@
+#include "nerv/common.h"
+#include "cwrapper.h"
+#include <stdio.h>
+/* Type names passed to luaT for the userdata pushed and checked in this file. */
+const char *nerv_tnet_feat_repo_tname = "nerv.TNetFeatureRepo";
+const char *nerv_tnet_label_repo_tname = "nerv.TNetLabelRepo";
+const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat";
+/* Lua constructor of the feature repo.
+ * Arguments: 1 = script (scp) file name, 2 = config file name,
+ * 3 = frame extension (integer).  Pushes the new repo as userdata. */
+static int feat_repo_new(lua_State *L) {
+    const char *scp_file = luaL_checkstring(L, 1);
+    const char *conf = luaL_checkstring(L, 2);
+    int frm_ext = luaL_checkinteger(L, 3);
+    TNetFeatureRepo *repo = tnet_feature_repo_new(scp_file, conf, frm_ext);
+    luaT_pushudata(L, repo, nerv_tnet_feat_repo_tname);
+    return 1;
+}
+/* Destroys the feature repo at stack index 1.  It is handed to
+ * luaT_newmetatable() in feat_repo_init() -- presumably as the
+ * destructor hook; confirm against the luaT API. */
+static int feat_repo_destroy(lua_State *L) {
+    TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
+    tnet_feature_repo_destroy(repo);
+    return 0;
+}
+/* Lua method "cur_tag": pushes the tag of the current utterance as a string. */
+static int feat_repo_current_tag(lua_State *L) {
+    TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
+    lua_pushstring(L, tnet_feature_repo_current_tag(repo));
+    return 1;
+}
+/* Lua method "cur_utter": argument 2 must be a boolean debug flag.
+ * Reads the current utterance and pushes it as a nerv.MMatrixFloat userdata. */
+static int feat_repo_current_utterance(lua_State *L) {
+    TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
+    int debug;
+    if (!lua_isboolean(L, 2))
+        nerv_error(L, "debug flag should be a boolean");
+    debug = lua_toboolean(L, 2);
+    Matrix *utter = tnet_feature_repo_read_utterance(repo, L, debug);
+    luaT_pushudata(L, utter, nerv_matrix_host_float_tname);
+    return 1;
+}
+/* Lua method "next": advances the repo to the next utterance; no results. */
+static int feat_repo_next(lua_State *L) {
+    TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
+    tnet_feature_repo_next(repo);
+    return 0;
+}
+/* Lua method "is_end": pushes the result of tnet_feature_repo_is_end() as a boolean. */
+static int feat_repo_is_end(lua_State *L) {
+    TNetFeatureRepo *repo = luaT_checkudata(L, 1, nerv_tnet_feat_repo_tname);
+    lua_pushboolean(L, tnet_feature_repo_is_end(repo));
+    return 1;
+}
+/* Methods registered for the feature repo by feat_repo_init(). */
+static const luaL_Reg feat_repo_methods[] = {
+    {"cur_utter", feat_repo_current_utterance},
+    {"cur_tag", feat_repo_current_tag},
+    {"next", feat_repo_next},
+    {"is_end", feat_repo_is_end},
+    {NULL, NULL}
+};
+/* Lua constructor of the label repo.  Arguments 1..5 are strings
+ * (mlf file, fmt, arg, dir, ext) passed in that order to
+ * tnet_label_repo_new().  Pushes the new repo as userdata. */
+static int label_repo_new(lua_State *L) {
+    const char *mlf_file = luaL_checkstring(L, 1);
+    const char *fmt = luaL_checkstring(L, 2);
+    const char *arg = luaL_checkstring(L, 3);
+    const char *dir = luaL_checkstring(L, 4);
+    const char *ext = luaL_checkstring(L, 5);
+    TNetLabelRepo *repo = tnet_label_repo_new(
+                                mlf_file, fmt, arg,
+                                dir, ext);
+    luaT_pushudata(L, repo, nerv_tnet_label_repo_tname);
+    return 1;
+}
+/* Lua method "get_utter": arguments are 2 = feature repo, 3 = number of
+ * frames (integer), 4 = boolean debug flag.  The sample rate and the tag are
+ * taken from the feature repo's current utterance.  Pushes the label matrix
+ * as a nerv.MMatrixFloat userdata. */
+static int label_repo_read_utterance(lua_State *L) {
+    TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
+    TNetFeatureRepo *feat_repo = luaT_checkudata(L, 2, nerv_tnet_feat_repo_tname);
+    size_t frames = luaL_checkinteger(L, 3);
+    int debug;
+    if (!lua_isboolean(L, 4))
+        nerv_error(L, "debug flag should be a boolean");
+    debug = lua_toboolean(L, 4);
+    Matrix *utter = tnet_label_repo_read_utterance(repo,
+                        frames,
+                        tnet_feature_repo_current_samplerate(feat_repo),
+                        tnet_feature_repo_current_tag(feat_repo), L, debug);
+    luaT_pushudata(L, utter, nerv_matrix_host_float_tname);
+    return 1;
+}
+/* Destroys the label repo at stack index 1; handed to luaT_newmetatable()
+ * in label_repo_init(). */
+static int label_repo_destroy(lua_State *L) {
+    TNetLabelRepo *repo = luaT_checkudata(L, 1, nerv_tnet_label_repo_tname);
+    tnet_label_repo_destroy(repo);
+    return 0;
+}
+/* Methods registered for the label repo by label_repo_init(). */
+static const luaL_Reg label_repo_methods[] = {
+    {"get_utter", label_repo_read_utterance},
+    {NULL, NULL}
+};
+/* Creates the feature-repo metatable, registers feat_repo_methods into the
+ * table left on top of the stack, then pops that table. */
+static void feat_repo_init(lua_State *L) {
+    luaT_newmetatable(L, nerv_tnet_feat_repo_tname, NULL,
+                        feat_repo_new, feat_repo_destroy, NULL);
+    luaL_register(L, NULL, feat_repo_methods);
+    lua_pop(L, 1);
+}
+/* Same as feat_repo_init(), for the label repo. */
+static void label_repo_init(lua_State *L) {
+    luaT_newmetatable(L, nerv_tnet_label_repo_tname, NULL,
+                        label_repo_new, label_repo_destroy, NULL);
+    luaL_register(L, NULL, label_repo_methods);
+    lua_pop(L, 1);
+}
+/* Entry point of this file: sets up both repo types in the given Lua state. */
+void tnet_io_init(lua_State *L) {
+    feat_repo_init(L);
+    label_repo_init(L);
+}
diff --git a/htk_io/src/test.c b/htk_io/src/test.c
new file mode 100644
index 0000000..6812ef1
--- /dev/null
+++ b/htk_io/src/test.c
@@ -0,0 +1,40 @@
+#include "cwrapper.h"
+#include <stdio.h>
+/* Prints every element of `mat` to stderr, one matrix row per output line. */
+void print_nerv_matrix(Matrix *mat) {
+    int n = mat->nrow;
+    int m = mat->ncol;
+    int i, j;
+    size_t stride = mat->stride;
+    for (i = 0; i < n; i++)
+    {
+        /* stride is applied as a byte offset, hence the char * arithmetic */
+        float *nerv_row = (float *)((char *)mat->data.f + i * stride);
+        for (j = 0; j < m; j++)
+            fprintf(stderr, "%.8f ", nerv_row[j]);
+        fprintf(stderr, "\n");
+    }
+}
+/* Manual smoke test: reads the first utterance and its labels from
+ * hard-coded paths and prints the label matrix.  NULL is passed as the
+ * lua_State, and neither repo is destroyed before returning. */
+int main() {
+    fprintf(stderr, "init repo\n");
+    TNetFeatureRepo *feat_repo = tnet_feature_repo_new(
+                                    "/slfs1/users/mfy43/swb_ivec/train_bp.scp",
+                                    "/slfs1/users/mfy43/swb_ivec/plp_0_d_a.conf", 5);
+    Matrix *feat_utter;
+    feat_utter = tnet_feature_repo_read_utterance(feat_repo, NULL, 1);
+    /* labels for the same utterance, looked up by the feature repo's current tag */
+    TNetLabelRepo *lab_repo = tnet_label_repo_new(
+                                    "/slfs1/users/mfy43/swb_ivec/ref.mlf",
+                                    "map",
+                                    "/slfs1/users/mfy43/swb_ivec/dict",
+                                    "*/",
+                                    "lab");
+    Matrix *lab_utter = tnet_label_repo_read_utterance(lab_repo,
+                                    feat_utter->nrow - 5 * 2, /* presumably removes the 5-frame context on each side (context = 5 above) -- confirm */
+                                    tnet_feature_repo_current_samplerate(feat_repo),
+                                    tnet_feature_repo_current_tag(feat_repo), NULL,
+                                    1);
+    print_nerv_matrix(lab_utter);
+    return 0;
+}
diff --git a/htk_io/src/tnet.mk b/htk_io/src/tnet.mk
new file mode 100644
index 0000000..9f933db
--- /dev/null
+++ b/htk_io/src/tnet.mk
@@ -0,0 +1,83 @@
+#
+# This makefile contains some global definitions
+# that are used during the build process.
+# It is included by all the subdirectory libraries.
+#
+
+# BITS64 is assigned 'true' below, so the 64-bit branch is taken unless that assignment is changed or overridden.
+##############################################################
+##### 64-BIT CROSS-COMPILATION #####
+CXXFLAGS=
+FWDPARAM=
+BITS64=true
+ifeq ($(BITS64), true)
+  ##### CHANGE WHEN DIFFERENT 64BIT g++ PREFIX #####
+  CROSS_COMPILE = x86_64-linux-
+  ##### CHANGE WHEN DIFFERENT 64BIT g++ PREFIX #####
+  CXXFLAGS += -m64
+  FWDPARAM += BITS64=true
+else
+  CXXFLAGS += -m32
+endif
+# CXX2 is the basename of whatever `which` finds for $(CXX); it is empty when nothing is found.
+# disable the cross-compile prefix if the prefixed g++ does not exist
+CXX=$(CROSS_COMPILE)g++
+CXX2=$(notdir $(shell which $(CXX) 2>/dev/null))
+ifneq ("$(CXX)", "$(CXX2)")
+  CROSS_COMPILE=
+endif
+# NOTE(review): CC is set to g++ as well -- presumably so that every source is built as C++; confirm.
+# compilation tools (defined after the probe above, so they see the final prefix)
+CC = $(CROSS_COMPILE)g++
+CXX = $(CROSS_COMPILE)g++
+AR = $(CROSS_COMPILE)ar
+RANLIB = $(CROSS_COMPILE)ranlib
+AS = $(CROSS_COMPILE)as
+
+
+
+
+##############################################################
+##### PATH TO CUDA TOOLKIT #####
+#CUDA_TK_BASE=/usr/local/share/cuda-3.2.12
+#CUDA_TK_BASE=/usr/local/cuda
+##### PATH TO CUDA TOOLKIT #####
+# Both CUDA_TK_BASE assignments above are commented out; this file leaves the variable unset.
+
+
+
+# compilation args
+CXXFLAGS += -g -Wall -O2 -DHAVE_ATLAS -rdynamic -fPIC
+CXXFLAGS += -Wshadow -Wpointer-arith -Wcast-qual -Wcast-align -Wwrite-strings -Wconversion
+# FWDPARAM accumulates VAR=value settings; presumably they are forwarded to sub-makes -- confirm.
+# enable double-precision
+ifeq ($(DOUBLEPRECISION), true)
+  CXXFLAGS += -DDOUBLEPRECISION
+  FWDPARAM += DOUBLEPRECISION=true
+endif
+
+# NOTE(review): OBJ_DIR is not defined in this file; presumably the including makefile sets it -- confirm.
+# compile all the source .cc files
+SRC=$(wildcard *.cc)
+OBJ=$(addprefix $(OBJ_DIR)/,$(patsubst %.cc, %.o, $(SRC)))
+
+
+
+
+#########################################################
+# CONFIGURATION CHECKS
+#
+# HAVE_CUDA is set when $(CUDA_TK_BASE)/bin/nvcc exists; otherwise the build aborts only if CUDA is 'true'.
+#check that CUDA_TK_BASE is set correctly
+ifeq ("$(wildcard $(CUDA_TK_BASE)/bin/nvcc)", "$(CUDA_TK_BASE)/bin/nvcc")
+  HAVE_CUDA=true
+else
+  ifeq ($(CUDA), true)
+    $(error %%% CUDA not found! Incorrect path in CUDA_TK_BASE: $(CUDA_TK_BASE) in 'trunk/src/tnet.mk')
+  endif
+endif
+# NOTE(review): the error text above names 'trunk/src/tnet.mk', but this diff adds the file as htk_io/src/tnet.mk.
+#
+#########################################################
+
+
 |