From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- kaldi_io/src/kaldi/tree/build-tree-questions.h | 133 +++++++++ kaldi_io/src/kaldi/tree/build-tree-utils.h | 324 ++++++++++++++++++++++ kaldi_io/src/kaldi/tree/build-tree.h | 250 +++++++++++++++++ kaldi_io/src/kaldi/tree/cluster-utils.h | 291 ++++++++++++++++++++ kaldi_io/src/kaldi/tree/clusterable-classes.h | 158 +++++++++++ kaldi_io/src/kaldi/tree/context-dep.h | 166 +++++++++++ kaldi_io/src/kaldi/tree/event-map.h | 365 +++++++++++++++++++++++++ kaldi_io/src/kaldi/tree/tree-renderer.h | 84 ++++++ 8 files changed, 1771 insertions(+) create mode 100644 kaldi_io/src/kaldi/tree/build-tree-questions.h create mode 100644 kaldi_io/src/kaldi/tree/build-tree-utils.h create mode 100644 kaldi_io/src/kaldi/tree/build-tree.h create mode 100644 kaldi_io/src/kaldi/tree/cluster-utils.h create mode 100644 kaldi_io/src/kaldi/tree/clusterable-classes.h create mode 100644 kaldi_io/src/kaldi/tree/context-dep.h create mode 100644 kaldi_io/src/kaldi/tree/event-map.h create mode 100644 kaldi_io/src/kaldi/tree/tree-renderer.h (limited to 'kaldi_io/src/kaldi/tree') diff --git a/kaldi_io/src/kaldi/tree/build-tree-questions.h b/kaldi_io/src/kaldi/tree/build-tree-questions.h new file mode 100644 index 0000000..a6bcfdd --- /dev/null +++ b/kaldi_io/src/kaldi/tree/build-tree-questions.h @@ -0,0 +1,133 @@ +// tree/build-tree-questions.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_BUILD_TREE_QUESTIONS_H_ +#define KALDI_TREE_BUILD_TREE_QUESTIONS_H_ + +#include "util/stl-utils.h" +#include "tree/context-dep.h" + +namespace kaldi { + + +/// \addtogroup tree_group +/// @{ +/// Typedef for statistics to build trees. +typedef std::vector > BuildTreeStatsType; + +/// Typedef used when we get "all keys" from a set of stats-- used in specifying +/// which kinds of questions to ask. +typedef enum { kAllKeysInsistIdentical, kAllKeysIntersection, kAllKeysUnion } AllKeysType; + +/// @} + +/// \defgroup tree_group_questions Question sets for decision-tree clustering +/// See \ref tree_internals (and specifically \ref treei_func_questions) for context. +/// \ingroup tree_group +/// @{ + +/// QuestionsForKey is a class used to define the questions for a key, +/// and also options that allow us to refine the question during tree-building +/// (i.e. make a question specific to the location in the tree). +/// The Questions class handles aggregating these options for a set +/// of different keys. +struct QuestionsForKey { // Configuration class associated with a particular key + // (of type EventKeyType). It also contains the questions themselves. 
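  // Illustrative usage sketch (an editorial example, not part of the original
  // header; the key value 0 and the phone-ids are hypothetical -- in a
  // triphone system key 0 would be the left-context phone):
  //   \code
  //     QuestionsForKey qk(5);               // 5 refinement iterations.
  //     std::vector<EventValueType> q1, q2;
  //     q1.push_back(1); q1.push_back(2);    // one question: the phone set {1, 2}
  //     q2.push_back(3);                     // another question: the phone set {3}
  //     qk.initial_questions.push_back(q1);
  //     qk.initial_questions.push_back(q2);
  //     qk.Check();                          // each question must be sorted.
  //     Questions qo;                        // Questions is declared below.
  //     qo.SetQuestionsOf(0, qk);
  //   \endcode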
+ std::vector > initial_questions; + RefineClustersOptions refine_opts; // if refine_opts.max_iter == 0, + // we just pick from the initial questions. + + QuestionsForKey(int32 num_iters = 5): refine_opts(num_iters, 2) { + // refine_cfg with 5 iters and top-n = 2 (this is no restriction because + // RefineClusters called with 2 clusters; would get set to that anyway as + // it's the only possible value for 2 clusters). User has to add questions. + // This config won't work as-is, as it has no questions. + } + + void Check() const { + for (size_t i = 0;i < initial_questions.size();i++) KALDI_ASSERT(IsSorted(initial_questions[i])); + } + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + // copy and assign allowed. +}; + +/// This class defines, for each EventKeyType, a set of initial questions that +/// it tries and also a number of iterations for which to refine the questions to increase +/// likelihood. It is perhaps a bit more than an options class, as it contains the +/// actual questions. +class Questions { // careful, this is a class. + public: + const QuestionsForKey &GetQuestionsOf(EventKeyType key) const { + std::map::const_iterator iter; + if ( (iter = key_idx_.find(key)) == key_idx_.end()) { + KALDI_ERR << "Questions: no options for key "<< key; + } + size_t idx = iter->second; + KALDI_ASSERT(idx < key_options_.size()); + key_options_[idx]->Check(); + return *(key_options_[idx]); + } + void SetQuestionsOf(EventKeyType key, const QuestionsForKey &options_of_key) { + options_of_key.Check(); + if (key_idx_.count(key) == 0) { + key_idx_[key] = key_options_.size(); + key_options_.push_back(new QuestionsForKey()); + *(key_options_.back()) = options_of_key; + } else { + size_t idx = key_idx_[key]; + KALDI_ASSERT(idx < key_options_.size()); + *(key_options_[idx]) = options_of_key; + } + } + void GetKeysWithQuestions(std::vector *keys_out) const { + KALDI_ASSERT(keys_out != NULL); + CopyMapKeysToVector(key_idx_, keys_out); + } + const bool HasQuestionsForKey(EventKeyType key) const { return (key_idx_.count(key) != 0); } + ~Questions() { kaldi::DeletePointers(&key_options_); } + + + /// Initializer with arguments. After using this you would have to set up the config for each key you + /// are going to use, or use InitRand(). + Questions() { } + + + /// InitRand attempts to generate "reasonable" random questions. Only + /// of use for debugging. This initializer creates a config that is + /// ready to use. + /// e.g. num_iters_refine = 0 means just use stated questions (if >1, will use + /// different questions at each split of the tree). + void InitRand(const BuildTreeStatsType &stats, int32 num_quest, int32 num_iters_refine, AllKeysType all_keys_type); + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + private: + std::vector key_options_; + std::map key_idx_; + KALDI_DISALLOW_COPY_AND_ASSIGN(Questions); +}; + +/// @} + +}// end namespace kaldi + +#endif // KALDI_TREE_BUILD_TREE_QUESTIONS_H_ diff --git a/kaldi_io/src/kaldi/tree/build-tree-utils.h b/kaldi_io/src/kaldi/tree/build-tree-utils.h new file mode 100644 index 0000000..464fc6b --- /dev/null +++ b/kaldi_io/src/kaldi/tree/build-tree-utils.h @@ -0,0 +1,324 @@ +// tree/build-tree-utils.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_BUILD_TREE_UTILS_H_ +#define KALDI_TREE_BUILD_TREE_UTILS_H_ + +#include "tree/build-tree-questions.h" + +// build-tree-questions.h needed for this typedef: +// typedef std::vector > BuildTreeStatsType; +// and for other #includes. + +namespace kaldi { + + +/// \defgroup tree_group_lower Low-level functions for manipulating statistics and event-maps +/// See \ref tree_internals and specifically \ref treei_func for context. +/// \ingroup tree_group +/// +/// @{ + + + +/// This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL. +/// Does not delete the pointer "stats" itself. +void DeleteBuildTreeStats(BuildTreeStatsType *stats); + +/// Writes BuildTreeStats object. This works even if pointers are NULL. +void WriteBuildTreeStats(std::ostream &os, bool binary, + const BuildTreeStatsType &stats); + +/// Reads BuildTreeStats object. The "example" argument must be of the same +/// type as the stats on disk, and is needed for access to the correct "Read" +/// function. It was organized this way for easier extensibility (so adding new +/// Clusterable derived classes isn't painful) +void ReadBuildTreeStats(std::istream &is, bool binary, + const Clusterable &example, BuildTreeStatsType *stats); + +/// Convenience function e.g. to work out possible values of the phones from just the stats. +/// Returns true if key was always defined inside the stats. +/// May be used with and == NULL to find out of key was always defined. +bool PossibleValues(EventKeyType key, const BuildTreeStatsType &stats, + std::vector *ans); + + +/// Splits stats according to the EventMap, indexing them at output by the +/// leaf type. A utility function. NOTE-- pointers in stats_out point to +/// the same memory location as those in stats. No copying of Clusterable* +/// objects happens. Will add to stats in stats_out if non-empty at input. +/// This function may increase the size of vector stats_out as necessary +/// to accommodate stats, but will never decrease the size. +void SplitStatsByMap(const BuildTreeStatsType &stats_in, const EventMap &e, + std::vector *stats_out); + +/// SplitStatsByKey splits stats up according to the value of a particular key, +/// which must be always defined and nonnegative. Like MapStats. Pointers to +/// Clusterable* in stats_out are not newly allocated-- they are the same as the +/// ones in stats_in. Generally they will still be owned at stats_in (user can +/// decide where to allocate ownership). +void SplitStatsByKey(const BuildTreeStatsType &stats_in, EventKeyType key, + std::vector *stats_out); + + +/// Converts stats from a given context-window (N) and central-position (P) to a +/// different N and P, by possibly reducing context. This function does a job +/// that's quite specific to the "normal" stats format we use. See \ref +/// tree_window for background. This function may delete some keys and change +/// others, depending on the N and P values. It expects that at input, all keys +/// will either be -1 or lie between 0 and oldN-1. 
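/// As an illustrative example (editorial, not from the original header;
/// "stats" is assumed to have been accumulated with a triphone context
/// window), stats with N = 3, P = 1 can be reduced to monophone context:
/// \code
///   BuildTreeStatsType stats;               // assumed already read/accumulated.
///   if (!ConvertStats(3, 1, 1, 0, &stats))  // oldN, oldP, newN, newP
///     KALDI_ERR << "Could not convert stats.";
/// \endcode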
At output, keys will be +/// either -1 or between 0 and newN-1. +/// Returns false if we could not convert the stats (e.g. because newN is larger +/// than oldN). +bool ConvertStats(int32 oldN, int32 oldP, int32 newN, int32 newP, + BuildTreeStatsType *stats); + + +/// FilterStatsByKey filters the stats according the value of a specified key. +/// If include_if_present == true, it only outputs the stats whose key is in +/// "values"; otherwise it only outputs the stats whose key is not in "values". +/// At input, "values" must be sorted and unique, and all stats in "stats_in" +/// must have "key" defined. At output, pointers to Clusterable* in stats_out +/// are not newly allocated-- they are the same as the ones in stats_in. +void FilterStatsByKey(const BuildTreeStatsType &stats_in, + EventKeyType key, + std::vector &values, + bool include_if_present, // true-> retain only if in "values", + // false-> retain only if not in "values". + BuildTreeStatsType *stats_out); + + +/// Sums stats, or returns NULL stats_in has no non-NULL stats. +/// Stats are newly allocated, owned by caller. +Clusterable *SumStats(const BuildTreeStatsType &stats_in); + +/// Sums the normalizer [typically, data-count] over the stats. +BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in); + +/// Sums the objective function over the stats. +BaseFloat SumObjf(const BuildTreeStatsType &stats_in); + + +/// Sum a vector of stats. Leaves NULL as pointer if no stats available. +/// The pointers in stats_out are owned by caller. At output, there may be +/// NULLs in the vector stats_out. +void SumStatsVec(const std::vector &stats_in, std::vector *stats_out); + +/// Cluster the stats given the event map return the total objf given those clusters. +BaseFloat ObjfGivenMap(const BuildTreeStatsType &stats_in, const EventMap &e); + + +/// FindAllKeys puts in *keys the (sorted, unique) list of all key identities in the stats. +/// If type == kAllKeysInsistIdentical, it will insist that this set of keys is the same for all the +/// stats (else exception is thrown). +/// if type == kAllKeysIntersection, it will return the smallest common set of keys present in +/// the set of stats +/// if type== kAllKeysUnion (currently probably not so useful since maps will return "undefined" +/// if key is not present), it will return the union of all the keys present in the stats. +void FindAllKeys(const BuildTreeStatsType &stats, AllKeysType keys_type, + std::vector *keys); + + +/// @} + + +/** + \defgroup tree_group_intermediate Intermediate-level functions used in building the tree + These functions are are used in top-level tree-building code (\ref tree_group_top); see + \ref tree_internals for documentation. + \ingroup tree_group + @{ +*/ + + +/// Returns a tree with just one node. Used @ start of tree-building process. +/// Not really used in current recipes. +inline EventMap *TrivialTree(int32 *num_leaves) { + KALDI_ASSERT(*num_leaves == 0); // in envisaged usage. + return new ConstantEventMap( (*num_leaves)++ ); +} + +/// DoTableSplit does a complete split on this key (e.g. might correspond to central phone +/// (key = P-1), or HMM-state position (key == kPdfClass == -1). Stats used to work out possible +/// values of the event. "num_leaves" is used to allocate new leaves. All stats must have +/// this key defined, or this function will crash. 
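/// An illustrative sketch (editorial, not from the original header; "stats"
/// is assumed to hold the accumulated statistics, and P = 1 as in a triphone
/// system):
/// \code
///   int32 num_leaves = 0;
///   EventMap *trivial = TrivialTree(&num_leaves);
///   // Complete split on HMM-state position, then on the central phone:
///   EventMap *split1 = DoTableSplit(*trivial, kPdfClass, stats, &num_leaves);
///   EventMap *split2 = DoTableSplit(*split1, 1 /* key == P */, stats, &num_leaves);
///   // Each call returns a newly allocated map owned by the caller:
///   delete trivial;
///   delete split1;
/// \endcode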
+EventMap *DoTableSplit(const EventMap &orig, EventKeyType key, + const BuildTreeStatsType &stats, int32 *num_leaves); + + +/// DoTableSplitMultiple does a complete split on all the keys, in order from keys[0], +/// keys[1] +/// and so on. The stats are used to work out possible values corresponding to the key. +/// "num_leaves" is used to allocate new leaves. All stats must have +/// the keys defined, or this function will crash. +/// Returns a newly allocated event map. +EventMap *DoTableSplitMultiple(const EventMap &orig, + const std::vector &keys, + const BuildTreeStatsType &stats, + int32 *num_leaves); + + +/// "ClusterEventMapGetMapping" clusters the leaves of the EventMap, with "thresh" a delta-likelihood +/// threshold to control how many leaves we combine (might be the same as the delta-like +/// threshold used in splitting. +// The function returns the #leaves we combined. The same leaf-ids of the leaves being clustered +// will be used for the clustered leaves (but other than that there is no special rule which +// leaf-ids should be used at output). +// It outputs the mapping for leaves, in "mapping", which may be empty at the start +// but may also contain mappings for other parts of the tree, which must contain +// disjoint leaves from this part. This is so that Cluster can +// be called multiple times for sub-parts of the tree (with disjoint sets of leaves), +// e.g. if we want to avoid sharing across phones. Afterwards you can use Copy function +// of EventMap to apply the mapping, i.e. call e_in.Copy(mapping) to get the new map. +// Note that the application of Cluster creates gaps in the leaves. You should then +// call RenumberEventMap(e_in.Copy(mapping), num_leaves). +// *If you only want to cluster a subset of the leaves (e.g. just non-silence, or just +// a particular phone, do this by providing a set of "stats" that correspond to just +// this subset of leaves*. Leaves with no stats will not be clustered. +// See build-tree.cc for an example of usage. +int ClusterEventMapGetMapping(const EventMap &e_in, const BuildTreeStatsType &stats, + BaseFloat thresh, std::vector *mapping); + +/// This is as ClusterEventMapGetMapping but a more convenient interface +/// that exposes less of the internals. It uses a bottom-up clustering to +/// combine the leaves, until the log-likelihood decrease from combinging two +/// leaves exceeds the threshold. +EventMap *ClusterEventMap(const EventMap &e_in, const BuildTreeStatsType &stats, + BaseFloat thresh, int32 *num_removed); + +/// This is as ClusterEventMap, but first splits the stats on the keys specified +/// in "keys" (e.g. typically keys = [ -1, P ]), and only clusters within the +/// classes defined by that splitting. +/// Note-- leaves will be non-consecutive at output, use RenumberEventMap. +EventMap *ClusterEventMapRestrictedByKeys(const EventMap &e_in, + const BuildTreeStatsType &stats, + BaseFloat thresh, + const std::vector &keys, + int32 *num_removed); + + +/// This version of ClusterEventMapRestricted restricts the clustering to only +/// allow things that "e_restrict" maps to the same value to be clustered +/// together. +EventMap *ClusterEventMapRestrictedByMap(const EventMap &e_in, + const BuildTreeStatsType &stats, + BaseFloat thresh, + const EventMap &e_restrict, + int32 *num_removed); + + +/// RenumberEventMap [intended to be used after calling ClusterEventMap] renumbers +/// an EventMap so its leaves are consecutive. +/// It puts the number of leaves in *num_leaves. 
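/// An illustrative sketch of the cluster-then-renumber pattern described
/// above (editorial example; "split_tree", "stats", "thresh" and "num_leaves"
/// are assumed to come from an earlier decision-tree-splitting step):
/// \code
///   std::vector<EventKeyType> keys;
///   keys.push_back(kPdfClass);  // -1: HMM-state position.
///   keys.push_back(1);          // P == 1: central phone of a triphone window.
///   int32 num_removed;
///   EventMap *clustered = ClusterEventMapRestrictedByKeys(*split_tree, stats,
///                                                         thresh, keys,
///                                                         &num_removed);
///   EventMap *renumbered = RenumberEventMap(*clustered, &num_leaves);
/// \endcode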
If later you need the mapping of +/// the leaves, modify the function and add a new argument. +EventMap *RenumberEventMap(const EventMap &e_in, int32 *num_leaves); + +/// This function remaps the event-map leaves using this mapping, +/// indexed by the number at leaf. +EventMap *MapEventMapLeaves(const EventMap &e_in, + const std::vector &mapping); + + + +/// ShareEventMapLeaves performs a quite specific function that allows us to +/// generate trees where, for a certain list of phones, and for all states in +/// the phone, all the pdf's are shared. +/// Each element of "values" contains a list of phones (may be just one phone), +/// all states of which we want shared together). Typically at input, "key" will +/// equal P, the central-phone position, and "values" will contain just one +/// list containing the silence phone. +/// This function renumbers the event map leaves after doing the sharing, to +/// make the event-map leaves contiguous. +EventMap *ShareEventMapLeaves(const EventMap &e_in, EventKeyType key, + std::vector > &values, + int32 *num_leaves); + + + +/// Does a decision-tree split at the leaves of an EventMap. +/// @param orig [in] The EventMap whose leaves we want to split. [may be either a trivial or a +/// non-trivial one]. +/// @param stats [in] The statistics for splitting the tree; if you do not want a particular +/// subset of leaves to be split, make sure the stats corresponding to those leaves +/// are not present in "stats". +/// @param qcfg [in] Configuration class that contains initial questions (e.g. sets of phones) +/// for each key and says whether to refine these questions during tree building. +/// @param thresh [in] A log-likelihood threshold (e.g. 300) that can be used to +/// limit the number of leaves; you can use zero and set max_leaves instead. +/// @param max_leaves [in] Will stop leaves being split after they reach this number. +/// @param num_leaves [in,out] A pointer used to allocate leaves; always corresponds to the +/// current number of leaves (is incremented when this is increased). +/// @param objf_impr_out [out] If non-NULL, will be set to the objective improvement due to splitting +/// (not normalized by the number of frames). +/// @param smallest_split_change_out If non-NULL, will be set to the smallest objective-function +/// improvement that we got from splitting any leaf; useful to provide a threshold +/// for ClusterEventMap. +/// @return The EventMap after splitting is returned; pointer is owned by caller. +EventMap *SplitDecisionTree(const EventMap &orig, + const BuildTreeStatsType &stats, + Questions &qcfg, + BaseFloat thresh, + int32 max_leaves, // max_leaves<=0 -> no maximum. + int32 *num_leaves, + BaseFloat *objf_impr_out, + BaseFloat *smallest_split_change_out); + +/// CreateRandomQuestions will initialize a Questions randomly, in a reasonable +/// way [for testing purposes, or when hand-designed questions are not available]. +/// e.g. num_quest = 5 might be a reasonable value if num_iters > 0, or num_quest = 20 otherwise. +void CreateRandomQuestions(const BuildTreeStatsType &stats, int32 num_quest, Questions *cfg_out); + + +/// FindBestSplitForKey is a function used in DoDecisionTreeSplit. +/// It finds the best split for this key, given these stats. +/// It will return 0 if the key was not always defined for the stats. 
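/// For example (an editorial sketch; "stats_for_leaf" and "qopts" are assumed
/// to exist already, and key 0 would be the left-context phone of a triphone
/// window):
/// \code
///   std::vector<EventValueType> yes_set;
///   BaseFloat improvement =
///       FindBestSplitForKey(stats_for_leaf, qopts, 0, &yes_set);
///   // "improvement" is the objective-function gain of the best question;
///   // "yes_set" holds the values that answer "yes" to that question.
/// \endcode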
+BaseFloat FindBestSplitForKey(const BuildTreeStatsType &stats, + const Questions &qcfg, + EventKeyType key, + std::vector *yes_set); + + +/// GetStubMap is used in tree-building functions to get the initial +/// to-states map, before the decision-tree-building process. It creates +/// a simple map that splits on groups of phones. For the set of phones in +/// phone_sets[i] it creates either: if share_roots[i] == true, a single +/// leaf node, or if share_roots[i] == false, separate root nodes for +/// each HMM-position (it goes up to the highest position for any +/// phone in the set, although it will warn if you share roots between +/// phones with different numbers of states, which is a weird thing to +/// do but should still work. If any phone is present +/// in "phone_sets" but "phone2num_pdf_classes" does not map it to a length, +/// it is an error. Note that the behaviour of the resulting map is +/// undefined for phones not present in "phone_sets". +/// At entry, this function should be called with (*num_leaves == 0). +/// It will number the leaves starting from (*num_leaves). + +EventMap *GetStubMap(int32 P, + const std::vector > &phone_sets, + const std::vector &phone2num_pdf_classes, + const std::vector &share_roots, // indexed by index into phone_sets. + int32 *num_leaves); +/// Note: GetStubMap with P = 0 can be used to get a standard monophone system. + +/// @} + + +}// end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/tree/build-tree.h b/kaldi_io/src/kaldi/tree/build-tree.h new file mode 100644 index 0000000..37bb108 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/build-tree.h @@ -0,0 +1,250 @@ +// tree/build-tree.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_BUILD_TREE_H_ +#define KALDI_TREE_BUILD_TREE_H_ + +// The file build-tree.h contains outer-level routines used in tree-building +// and related tasks, that are directly called by the command-line tools. + +#include "tree/build-tree-utils.h" +#include "tree/context-dep.h" +namespace kaldi { + +/// \defgroup tree_group_top Top-level tree-building functions +/// See \ref tree_internals for context. +/// \ingroup tree_group +/// @{ + +// Note, in tree_group_top we also include AccumulateTreeStats, in +// ../hmm/tree-accu.h (it has some extra dependencies so we didn't +// want to include it here). + +/** + * BuildTree is the normal way to build a set of decision trees. + * The sets "phone_sets" dictate how we set up the roots of the decision trees. + * each set of phones phone_sets[i] has shared decision-tree roots, and if + * the corresponding variable share_roots[i] is true, the root will be shared + * for the different HMM-positions in the phone. All phones in "phone_sets" + * should be in the stats (use FixUnseenPhones to ensure this). 
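 *
 * An illustrative top-level call (an editorial sketch, not from the original
 * header; the variable names and numeric values are hypothetical -- typically
 * phone_sets / share_roots / do_split come from ReadRootsFile() and qopts from
 * AutomaticallyObtainQuestions() plus SetQuestionsOf()):
 * \code
 *   EventMap *to_pdf = BuildTree(qopts, phone_sets, phone2num_pdf_classes,
 *                                share_roots, do_split, stats,
 *                                0.0,    // thresh: 0 -> rely on max_leaves.
 *                                2500,   // max_leaves
 *                                -1.0,   // cluster_thresh: use smallest split change.
 *                                1);     // P: central position of a triphone window.
 *   ContextDependency ctx_dep(3, 1, to_pdf);  // N = 3, P = 1; takes ownership.
 * \endcode
 *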
+ * if for any i, do_split[i] is false, we will not do any tree splitting for + * phones in that set. + * @param qopts [in] Questions options class, contains questions for each key + * (e.g. each phone position) + * @param phone_sets [in] Each element of phone_sets is a set of phones whose + * roots are shared together (prior to decision-tree splitting). + * @param phone2num_pdf_classes [in] A map from phones to the number of + * \ref pdf_class "pdf-classes" + * in the phone (this info is derived from the HmmTopology object) + * @param share_roots [in] A vector the same size as phone_sets; says for each + * phone set whether the root should be shared among all the + * pdf-classes or not. + * @param do_split [in] A vector the same size as phone_sets; says for each + * phone set whether decision-tree splitting should be done + * (generally true for non-silence phones). + * @param stats [in] The statistics used in tree-building. + * @param thresh [in] Threshold used in decision-tree splitting (e.g. 1000), + * or you may use 0 in which case max_leaves becomes the + * constraint. + * @param max_leaves [in] Maximum number of leaves it will create; set this + * to a large number if you want to just specify "thresh". + * @param cluster_thresh [in] Threshold for clustering leaves after decision-tree + * splitting (only within each phone-set); leaves will be combined + * if log-likelihood change is less than this. A value about equal + * to "thresh" is suitable + * if thresh != 0; otherwise, zero will mean no clustering is done, + * or a negative value (e.g. -1) sets it to the smallest likelihood + * change seen during the splitting algorithm; this typically causes + * about a 20% reduction in the number of leaves. + + * @param P [in] The central position of the phone context window, e.g. 1 for a + * triphone system. + * @return Returns a pointer to an EventMap object that is the tree. + +*/ + +EventMap *BuildTree(Questions &qopts, + const std::vector > &phone_sets, + const std::vector &phone2num_pdf_classes, + const std::vector &share_roots, + const std::vector &do_split, + const BuildTreeStatsType &stats, + BaseFloat thresh, + int32 max_leaves, + BaseFloat cluster_thresh, // typically == thresh. If negative, use smallest split. + int32 P); + + +/** + * + * BuildTreeTwoLevel builds a two-level tree, useful for example in building tied mixture + * systems with multiple codebooks. It first builds a small tree by splitting to + * "max_leaves_first". It then splits at the leaves of "max_leaves_first" (think of this + * as creating multiple little trees at the leaves of the first tree), until the total + * number of leaves reaches "max_leaves_second". It then outputs the second tree, along + * with a mapping from the leaf-ids of the second tree to the leaf-ids of the first tree. + * Note that the interface is similar to BuildTree, and in fact it calls BuildTree + * internally. + * + * The sets "phone_sets" dictate how we set up the roots of the decision trees. + * each set of phones phone_sets[i] has shared decision-tree roots, and if + * the corresponding variable share_roots[i] is true, the root will be shared + * for the different HMM-positions in the phone. All phones in "phone_sets" + * should be in the stats (use FixUnseenPhones to ensure this). + * if for any i, do_split[i] is false, we will not do any tree splitting for + * phones in that set. + * + * @param qopts [in] Questions options class, contains questions for each key + * (e.g. 
each phone position) + * @param phone_sets [in] Each element of phone_sets is a set of phones whose + * roots are shared together (prior to decision-tree splitting). + * @param phone2num_pdf_classes [in] A map from phones to the number of + * \ref pdf_class "pdf-classes" + * in the phone (this info is derived from the HmmTopology object) + * @param share_roots [in] A vector the same size as phone_sets; says for each + * phone set whether the root should be shared among all the + * pdf-classes or not. + * @param do_split [in] A vector the same size as phone_sets; says for each + * phone set whether decision-tree splitting should be done + * (generally true for non-silence phones). + * @param stats [in] The statistics used in tree-building. + * @param max_leaves_first [in] Maximum number of leaves it will create in first + * level of decision tree. + * @param max_leaves_second [in] Maximum number of leaves it will create in second + * level of decision tree. Must be > max_leaves_first. + * @param cluster_leaves [in] Boolean value; if true, we post-cluster the leaves produced + * in the second level of decision-tree split; if false, we don't. + * The threshold for post-clustering is the log-like change of the last + * decision-tree split; this typically causes about a 20% reduction in + * the number of leaves. + * @param P [in] The central position of the phone context window, e.g. 1 for a + * triphone system. + * @param leaf_map [out] Will be set to be a mapping from the leaves of the + * "big" tree to the leaves of the "little" tree, which you can + * view as cluster centers. + * @return Returns a pointer to an EventMap object that is the (big) tree. + +*/ + +EventMap *BuildTreeTwoLevel(Questions &qopts, + const std::vector > &phone_sets, + const std::vector &phone2num_pdf_classes, + const std::vector &share_roots, + const std::vector &do_split, + const BuildTreeStatsType &stats, + int32 max_leaves_first, + int32 max_leaves_second, + bool cluster_leaves, + int32 P, + std::vector *leaf_map); + + +/// GenRandStats generates random statistics of the form used by BuildTree. +/// It tries to do so in such a way that they mimic "real" stats. The event keys +/// and their corresponding values are: +/// - key == -1 == kPdfClass -> pdf-class, generally corresponds to +/// zero-based position in HMM (0, 1, 2 .. hmm_lengths[phone]-1) +/// - key == 0 -> phone-id of left-most context phone. +/// - key == 1 -> phone-id of one-from-left-most context phone. +/// - key == P-1 -> phone-id of central phone. +/// - key == N-1 -> phone-id of right-most context phone. +/// GenRandStats is useful only for testing but it serves to document the format of +/// stats used by BuildTreeDefault. +/// if is_ctx_dep[phone] is set to false, GenRandStats will not define the keys for +/// other than the P-1'th phone. + +/// @param dim [in] dimension of features. +/// @param num_stats [in] approximate number of separate phones-in-context wanted. +/// @param N [in] context-size (typically 3) +/// @param P [in] central-phone position in zero-based numbering (typically 1) +/// @param phone_ids [in] integer ids of phones +/// @param hmm_lengths [in] lengths of hmm for phone, indexed by phone. +/// @param is_ctx_dep [in] boolean array indexed by phone, saying whether each phone +/// is context dependent. +/// @param ensure_all_phones_covered [in] Boolean argument: if true, GenRandStats +/// ensures that every phone is seen at least once in the central position (P). 
+/// @param stats_out [out] The statistics that this routine outputs. + +void GenRandStats(int32 dim, int32 num_stats, int32 N, int32 P, + const std::vector &phone_ids, + const std::vector &hmm_lengths, + const std::vector &is_ctx_dep, + bool ensure_all_phones_covered, + BuildTreeStatsType *stats_out); + + +/// included here because it's used in some tree-building +/// calling code. Reads an OpenFst symbl table, +/// discards the symbols and outputs the integers +void ReadSymbolTableAsIntegers(std::string filename, + bool include_eps, + std::vector *syms); + + + +/** + * Outputs sets of phones that are reasonable for questions + * to ask in the tree-building algorithm. These are obtained by tree + * clustering of the phones; for each node in the tree, all the leaves + * accessible from that node form one of the sets of phones. + * @param stats [in] The statistics as used for normal tree-building. + * @param phone_sets_in [in] All the phones, pre-partitioned into sets. + * The output sets will be various unions of these sets. These sets + * will normally correspond to "real phones", in cases where the phones + * have stress and position markings. + * @param all_pdf_classes_in [in] All the \ref pdf_class "pdf-classes" + * that we consider for clustering. In the normal case this is the singleton + * set {1}, which means that we only consider the central hmm-position + * of the standard 3-state HMM, for clustering purposes. + * @param P [in] The central position in the phone context window; normally + * 1 for triphone system.s + * @param questions_out [out] The questions (sets of phones) are output to here. + **/ +void AutomaticallyObtainQuestions(BuildTreeStatsType &stats, + const std::vector > &phone_sets_in, + const std::vector &all_pdf_classes_in, + int32 P, + std::vector > *questions_out); + +/// This function clusters the phones (or some initially specified sets of phones) +/// into sets of phones, using a k-means algorithm. Useful, for example, in building +/// simple models for purposes of adaptation. + +void KMeansClusterPhones(BuildTreeStatsType &stats, + const std::vector > &phone_sets_in, + const std::vector &all_pdf_classes_in, + int32 P, + int32 num_classes, + std::vector > *sets_out); + +/// Reads the roots file (throws on error). Format is lines like: +/// "shared split 1 2 3 4", +/// "not-shared not-split 5", +/// and so on. The numbers are indexes of phones. +void ReadRootsFile(std::istream &is, + std::vector > *phone_sets, + std::vector *is_shared_root, + std::vector *is_split_root); + + +/// @} + +}// end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/tree/cluster-utils.h b/kaldi_io/src/kaldi/tree/cluster-utils.h new file mode 100644 index 0000000..55583a2 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/cluster-utils.h @@ -0,0 +1,291 @@ +// tree/cluster-utils.h + +// Copyright 2012 Arnab Ghoshal +// Copyright 2009-2011 Microsoft Corporation; Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_CLUSTER_UTILS_H_ +#define KALDI_TREE_CLUSTER_UTILS_H_ + +#include +#include "matrix/matrix-lib.h" +#include "itf/clusterable-itf.h" + +namespace kaldi { + +/// \addtogroup clustering_group_simple +/// @{ + +/// Returns the total objective function after adding up all the +/// statistics in the vector (pointers may be NULL). +BaseFloat SumClusterableObjf(const std::vector &vec); + +/// Returns the total normalizer (usually count) of the cluster (pointers may be NULL). +BaseFloat SumClusterableNormalizer(const std::vector &vec); + +/// Sums stats (ptrs may be NULL). Returns NULL if no non-NULL stats present. +Clusterable *SumClusterable(const std::vector &vec); + +/** Fills in any (NULL) holes in "stats" vector, with empty stats, because + * certain algorithms require non-NULL stats. If "stats" nonempty, requires it + * to contain at least one non-NULL pointer that we can call Copy() on. + */ +void EnsureClusterableVectorNotNull(std::vector *stats); + + +/** Given stats and a vector "assignments" of the same size (that maps to + * cluster indices), sums the stats up into "clusters." It will add to any + * stats already present in "clusters" (although typically "clusters" will be + * empty when called), and it will extend with NULL pointers for any unseen + * indices. Call EnsureClusterableStatsNotNull afterwards if you want to ensure + * all non-NULL clusters. Pointer in "clusters" are owned by caller. Pointers in + * "stats" do not have to be non-NULL. + */ +void AddToClusters(const std::vector &stats, + const std::vector &assignments, + std::vector *clusters); + + +/// AddToClustersOptimized does the same as AddToClusters (it sums up the stats +/// within each cluster, except it uses the sum of all the stats ("total") to +/// optimize the computation for speed, if possible. This will generally only be +/// a significant speedup in the case where there are just two clusters, which +/// can happen in algorithms that are doing binary splits; the idea is that we +/// sum up all the stats in one cluster (the one with the fewest points in it), +/// and then subtract from the total. +void AddToClustersOptimized(const std::vector &stats, + const std::vector &assignments, + const Clusterable &total, + std::vector *clusters); + +/// @} end "addtogroup clustering_group_simple" + +/// \addtogroup clustering_group_algo +/// @{ + +// Note, in the algorithms below, it is assumed that the input "points" (which +// is std::vector) is all non-NULL. + +/** A bottom-up clustering algorithm. There are two parameters that control how + * many clusters we get: a "max_merge_thresh" which is a threshold for merging + * clusters, and a min_clust which puts a floor on the number of clusters we want. Set + * max_merge_thresh = large to use the min_clust only, or min_clust to 0 to use + * the max_merge_thresh only. + * + * The algorithm is: + * \code + * while (num-clusters > min_clust && smallest_merge_cost <= max_merge_thresh) + * merge the closest two clusters. + * \endcode + * + * @param points [in] Points to be clustered (may not contain NULL pointers) + * @param thresh [in] Threshold on cost change from merging clusters; clusters + * won't be merged if the cost is more than this + * @param min_clust [in] Minimum number of clusters desired; we'll stop merging + * after reaching this number. 
+ * @param clusters_out [out] If non-NULL, will be set to a vector of size equal + * to the number of output clusters, containing the clustered + * statistics. Must be empty when called. + * @param assignments_out [out] If non-NULL, will be resized to the number of + * points, and each element is the index of the cluster that point + * was assigned to. + * @return Returns the total objf change relative to all clusters being separate, which is + * a negative. Note that this is not the same as what the other clustering algorithms return. + */ +BaseFloat ClusterBottomUp(const std::vector &points, + BaseFloat thresh, + int32 min_clust, + std::vector *clusters_out, + std::vector *assignments_out); + +/** This is a bottom-up clustering where the points are pre-clustered in a set + * of compartments, such that only points in the same compartment are clustered + * together. The compartment and pair of points with the smallest merge cost + * is selected and the points are clustered. The result stays in the same + * compartment. The code does not merge compartments, and hence assumes that + * the number of compartments is smaller than the 'min_clust' option. + * The clusters in "clusters_out" are newly allocated and owned by the caller. + */ +BaseFloat ClusterBottomUpCompartmentalized( + const std::vector< std::vector > &points, BaseFloat thresh, + int32 min_clust, std::vector< std::vector > *clusters_out, + std::vector< std::vector > *assignments_out); + + +struct RefineClustersOptions { + int32 num_iters; // must be >= 0. If zero, does nothing. + int32 top_n; // must be >= 2. + RefineClustersOptions() : num_iters(100), top_n(5) {} + RefineClustersOptions(int32 num_iters_in, int32 top_n_in) + : num_iters(num_iters_in), top_n(top_n_in) {} + // include Write and Read functions because this object gets written/read as + // part of the QuestionsForKeyOptions class. + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); +}; + +/** RefineClusters is mainly used internally by other clustering algorithms. + * + * It starts with a given assignment of points to clusters and + * keeps trying to improve it by moving points from cluster to cluster, up to + * a maximum number of iterations. + * + * "clusters" and "assignments" are both input and output variables, and so + * both MUST be non-NULL. + * + * "top_n" (>=2) is a pruning value: more is more exact, fewer is faster. The + * algorithm initially finds the "top_n" closest clusters to any given point, + * and from that point only consider move to those "top_n" clusters. Since + * RefineClusters is called multiple times from ClusterKMeans (for instance), + * this is not really a limitation. + */ +BaseFloat RefineClusters(const std::vector &points, + std::vector *clusters /*non-NULL*/, + std::vector *assignments /*non-NULL*/, + RefineClustersOptions cfg = RefineClustersOptions()); + +struct ClusterKMeansOptions { + RefineClustersOptions refine_cfg; + int32 num_iters; + int32 num_tries; // if >1, try whole procedure >once and pick best. + bool verbose; + ClusterKMeansOptions() + : refine_cfg(), num_iters(20), num_tries(2), verbose(true) {} +}; + +/** ClusterKMeans is a K-means-like clustering algorithm. It starts with + * pseudo-random initialization of points to clusters and uses RefineClusters + * to iteratively improve the cluster assignments. It does this for + * multiple iterations and picks the result with the best objective function. + * + * + * ClusterKMeans implicitly uses Rand(). 
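 * An illustrative call (an editorial sketch; the points are assumed to be
 * non-NULL Clusterable-derived statistics, e.g. GaussClusterable objects,
 * built up elsewhere):
 * \code
 *   std::vector<Clusterable*> points;   // all non-NULL, owned by us.
 *   // ... fill "points" ...
 *   std::vector<Clusterable*> clusters;
 *   std::vector<int32> assignments;
 *   BaseFloat impr = ClusterKMeans(points, 10, &clusters, &assignments);
 *   // ... use clusters / assignments ...
 *   DeletePointers(&clusters);          // output stats are owned by the caller.
 *   DeletePointers(&points);
 * \endcode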
It will not necessarily return + * the same value on different calls. Use sRand() if you want consistent + * results. + * The algorithm used in ClusterKMeans is a "k-means-like" algorithm that tries + * to be as efficient as possible. Firstly, since the algorithm it uses + * includes random initialization, it tries the whole thing cfg.num_tries times + * and picks the one with the best objective function. Each try, it does as + * follows: it randomly initializes points to clusters, and then for + * cfg.num_iters iterations it calls RefineClusters(). The options to + * RefineClusters() are given by cfg.refine_cfg. Calling RefineClusters once + * will always be at least as good as doing one iteration of reassigning points to + * clusters, but will generally be quite a bit better (without taking too + * much extra time). + * + * @param points [in] points to be clustered (must be all non-NULL). + * @param num_clust [in] number of clusters requested (it will always return exactly + * this many, or will fail if num_clust > points.size()). + * @param clusters_out [out] may be NULL; if non-NULL, should be empty when called. + * Will be set to a vector of statistics corresponding to the output clusters. + * @param assignments_out [out] may be NULL; if non-NULL, will be set to a vector of + * same size as "points", which says for each point which cluster + * it is assigned to. + * @param cfg [in] configuration class specifying options to the algorithm. + * @return Returns the objective function improvement versus everything being + * in the same cluster. + * + */ +BaseFloat ClusterKMeans(const std::vector &points, + int32 num_clust, // exact number of clusters + std::vector *clusters_out, // may be NULL + std::vector *assignments_out, // may be NULL + ClusterKMeansOptions cfg = ClusterKMeansOptions()); + +struct TreeClusterOptions { + ClusterKMeansOptions kmeans_cfg; + int32 branch_factor; + BaseFloat thresh; // Objf change: if >0, may be used to control number of leaves. + TreeClusterOptions() + : kmeans_cfg(), branch_factor(2), thresh(0) { + kmeans_cfg.verbose = false; + } +}; + +/** TreeCluster is a top-down clustering algorithm, using a binary tree (not + * necessarily balanced). Returns objf improvement versus having all points + * in one cluster. The algorithm is: + * - Initialize to 1 cluster (tree with 1 node). + * - Maintain, for each cluster, a "best-binary-split" (using ClusterKMeans + * to do so). Always split the highest scoring cluster, until we can do no + * more splits. + * + * @param points [in] Data points to be clustered + * @param max_clust [in] Maximum number of clusters (you will get exactly this number, + * if there are at least this many points, except if you set the + * cfg.thresh value nonzero, in which case that threshold may limit + * the number of clusters. + * @param clusters_out [out] If non-NULL, will be set to the a vector whose first + * (*num_leaves_out) elements are the leaf clusters, and whose + * subsequent elements are the nonleaf nodes in the tree, in + * topological order with the root node last. Must be empty vector + * when this function is called. + * @param assignments_out [out] If non-NULL, will be set to a vector to a vector the + * same size as "points", where assignments[i] is the leaf node index i + * to which the i'th point gets clustered. + * @param clust_assignments_out [out] If non-NULL, will be set to a vector the same size + * as clusters_out which says for each node (leaf or nonleaf), the + * index of its parent. 
For the root node (which is last), + * assignments_out[i] == i. For each i, assignments_out[i]>=i, i.e. + * any node's parent is higher numbered than itself. If you don't need + * this information, consider using instead the ClusterTopDown function. + * @param num_leaves_out [out] If non-NULL, will be set to the number of leaf nodes + * in the tree. + * @param cfg [in] Configuration object that controls clustering behavior. Most + * important value is "thresh", which provides an alternative mechanism + * [other than max_clust] to limit the number of leaves. + */ +BaseFloat TreeCluster(const std::vector &points, + int32 max_clust, // max number of leaf-level clusters. + std::vector *clusters_out, + std::vector *assignments_out, + std::vector *clust_assignments_out, + int32 *num_leaves_out, + TreeClusterOptions cfg = TreeClusterOptions()); + + +/** + * A clustering algorithm that internally uses TreeCluster, + * but does not give you the information about the structure of the tree. + * The "clusters_out" and "assignments_out" may be NULL if the outputs are not + * needed. + * + * @param points [in] points to be clustered (must be all non-NULL). + * @param max_clust [in] Maximum number of clusters (you will get exactly this number, + * if there are at least this many points, except if you set the + * cfg.thresh value nonzero, in which case that threshold may limit + * the number of clusters. + * @param clusters_out [out] may be NULL; if non-NULL, should be empty when called. + * Will be set to a vector of statistics corresponding to the output clusters. + * @param assignments_out [out] may be NULL; if non-NULL, will be set to a vector of + * same size as "points", which says for each point which cluster + * it is assigned to. + * @param cfg [in] Configuration object that controls clustering behavior. Most + * important value is "thresh", which provides an alternative mechanism + * [other than max_clust] to limit the number of leaves. +*/ +BaseFloat ClusterTopDown(const std::vector &points, + int32 max_clust, // max number of clusters. + std::vector *clusters_out, + std::vector *assignments_out, + TreeClusterOptions cfg = TreeClusterOptions()); + +/// @} end of "addtogroup clustering_group_algo" + +} // end namespace kaldi. + +#endif // KALDI_TREE_CLUSTER_UTILS_H_ diff --git a/kaldi_io/src/kaldi/tree/clusterable-classes.h b/kaldi_io/src/kaldi/tree/clusterable-classes.h new file mode 100644 index 0000000..817d0c6 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/clusterable-classes.h @@ -0,0 +1,158 @@ +// tree/clusterable-classes.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University +// 2014 Daniel Povey + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_TREE_CLUSTERABLE_CLASSES_H_ +#define KALDI_TREE_CLUSTERABLE_CLASSES_H_ 1 + +#include +#include "itf/clusterable-itf.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { + +// Note: see sgmm/sgmm-clusterable.h for an SGMM-based clusterable +// class. We didn't include it here, to avoid adding an extra +// dependency to this directory. + +/// \addtogroup clustering_group +/// @{ + +/// ScalarClusterable clusters scalars with x^2 loss. +class ScalarClusterable: public Clusterable { + public: + ScalarClusterable(): x_(0), x2_(0), count_(0) {} + explicit ScalarClusterable(BaseFloat x): x_(x), x2_(x*x), count_(1) {} + virtual std::string Type() const { return "scalar"; } + virtual BaseFloat Objf() const; + virtual void SetZero() { count_ = x_ = x2_ = 0.0; } + virtual void Add(const Clusterable &other_in); + virtual void Sub(const Clusterable &other_in); + virtual Clusterable* Copy() const; + virtual BaseFloat Normalizer() const { + return static_cast(count_); + } + + // Function to write data to stream. Will organize input later [more complex] + virtual void Write(std::ostream &os, bool binary) const; + virtual Clusterable* ReadNew(std::istream &is, bool binary) const; + + std::string Info(); // For debugging. + BaseFloat Mean() { return (count_ != 0 ? x_/count_ : 0.0); } + private: + BaseFloat x_; + BaseFloat x2_; + BaseFloat count_; + + void Read(std::istream &is, bool binary); +}; + + +/// GaussClusterable wraps Gaussian statistics in a form accessible +/// to generic clustering algorithms. +class GaussClusterable: public Clusterable { + public: + GaussClusterable(): count_(0.0), var_floor_(0.0) {} + GaussClusterable(int32 dim, BaseFloat var_floor): + count_(0.0), stats_(2, dim), var_floor_(var_floor) {} + + GaussClusterable(const Vector &x_stats, + const Vector &x2_stats, + BaseFloat var_floor, BaseFloat count); + + virtual std::string Type() const { return "gauss"; } + void AddStats(const VectorBase &vec, BaseFloat weight = 1.0); + virtual BaseFloat Objf() const; + virtual void SetZero(); + virtual void Add(const Clusterable &other_in); + virtual void Sub(const Clusterable &other_in); + virtual BaseFloat Normalizer() const { return count_; } + virtual Clusterable *Copy() const; + virtual void Scale(BaseFloat f); + virtual void Write(std::ostream &os, bool binary) const; + virtual Clusterable *ReadNew(std::istream &is, bool binary) const; + virtual ~GaussClusterable() {} + + BaseFloat count() const { return count_; } + // The next two functions are not const-correct, because of SubVector. + SubVector x_stats() const { return stats_.Row(0); } + SubVector x2_stats() const { return stats_.Row(1); } + private: + double count_; + Matrix stats_; // two rows: sum, then sum-squared. + double var_floor_; // should be common for all objects created. + + void Read(std::istream &is, bool binary); +}; + +/// @} end of "addtogroup clustering_group" + +inline void GaussClusterable::SetZero() { + count_ = 0; + stats_.SetZero(); +} + +inline GaussClusterable::GaussClusterable(const Vector &x_stats, + const Vector &x2_stats, + BaseFloat var_floor, BaseFloat count): + count_(count), stats_(2, x_stats.Dim()), var_floor_(var_floor) { + stats_.Row(0).CopyFromVec(x_stats); + stats_.Row(1).CopyFromVec(x2_stats); +} + + +/// VectorClusterable wraps vectors in a form accessible to generic clustering +/// algorithms. Each vector is associated with a weight; these could be 1.0. 
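/// An illustrative construction (an editorial sketch; the dimension, values
/// and weights are hypothetical), suitable as input to the algorithms in
/// cluster-utils.h:
/// \code
///   std::vector<Clusterable*> points;
///   for (int32 i = 0; i < 100; i++) {
///     Vector<BaseFloat> v(5);
///     v.SetRandn();
///     points.push_back(new VectorClusterable(v, 1.0 /* weight */));
///   }
///   // ... e.g. ClusterKMeans(points, 8, &clusters, &assignments); ...
///   DeletePointers(&points);
/// \endcode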
+/// The objective function (to be maximized) is the negated sum of squared +/// distances from the cluster center to each vector, times that vector's +/// weight. +class VectorClusterable: public Clusterable { + public: + VectorClusterable(): weight_(0.0), sumsq_(0.0) {} + + VectorClusterable(const Vector &vector, + BaseFloat weight); + + virtual std::string Type() const { return "vector"; } + // Objf is negated weighted sum of squared distances. + virtual BaseFloat Objf() const; + virtual void SetZero() { weight_ = 0.0; sumsq_ = 0.0; stats_.Set(0.0); } + virtual void Add(const Clusterable &other_in); + virtual void Sub(const Clusterable &other_in); + virtual BaseFloat Normalizer() const { return weight_; } + virtual Clusterable *Copy() const; + virtual void Scale(BaseFloat f); + virtual void Write(std::ostream &os, bool binary) const; + virtual Clusterable *ReadNew(std::istream &is, bool binary) const; + virtual ~VectorClusterable() {} + + private: + double weight_; // sum of weights of the source vectors. Never negative. + Vector stats_; // Equals the weighted sum of the source vectors. + double sumsq_; // Equals the sum over all sources, of weight_ * vec.vec, + // where vec = stats_ / weight_. Used in computing + // the objective function. + void Read(std::istream &is, bool binary); +}; + + + +} // end namespace kaldi. + +#endif // KALDI_TREE_CLUSTERABLE_CLASSES_H_ diff --git a/kaldi_io/src/kaldi/tree/context-dep.h b/kaldi_io/src/kaldi/tree/context-dep.h new file mode 100644 index 0000000..307fcd4 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/context-dep.h @@ -0,0 +1,166 @@ +// tree/context-dep.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_CONTEXT_DEP_H_ +#define KALDI_TREE_CONTEXT_DEP_H_ + +#include "itf/context-dep-itf.h" +#include "tree/event-map.h" +#include "matrix/matrix-lib.h" +#include "tree/cluster-utils.h" + +/* + This header provides the declarations for the class ContextDependency, which inherits + from the interface class "ContextDependencyInterface" in itf/context-dep-itf.h. + This is basically a wrapper around an EventMap. The EventMap + (tree/event-map.h) declares most of the internals of the class, and the building routines are + in build-tree.h which uses build-tree-utils.h, which uses cluster-utils.h . */ + + +namespace kaldi { + +static const EventKeyType kPdfClass = -1; // The "name" to which we assign the +// pdf-class (generally corresponds ot position in the HMM, zero-based); +// must not be used for any other event. I.e. the value corresponding to +// this key is the pdf-class (see hmm-topology.h for explanation of what this is). + + +/* ContextDependency is quite a generic decision tree. + + It does not actually do very much-- all the magic is in the EventMap object. 
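   As an illustrative sketch (editorial, not from the original header; the
   phone-ids and the source of "ctx_dep" are hypothetical), looking up a
   pdf-id for a triphone window with N = 3, P = 1 looks like this:
   \code
     std::vector<int32> phone_window(3);
     phone_window[0] = 10;   // left-context phone
     phone_window[1] = 12;   // central phone
     phone_window[2] = 14;   // right-context phone
     int32 pdf_id;
     if (!ctx_dep.Compute(phone_window, 0 /* pdf_class */, &pdf_id))
       KALDI_ERR << "No pdf-id for this phone window.";
   \endcode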
+ All this class does is to encode the phone context as a sequence of events, and + pass this to the EventMap object to turn into what it will interpret as a + vector of pdfs. + + Different versions of the ContextDependency class that are written in the future may + have slightly different interfaces and pass more stuff in as events, to the + EventMap object. + + In order to separate the process of training decision trees from the process + of actually using them, we do not put any training code into the ContextDependency class. + */ +class ContextDependency: public ContextDependencyInterface { + public: + virtual int32 ContextWidth() const { return N_; } + virtual int32 CentralPosition() const { return P_; } + + + /// returns success or failure; outputs pdf to pdf_id + virtual bool Compute(const std::vector &phoneseq, + int32 pdf_class, int32 *pdf_id) const; + + virtual int32 NumPdfs() const { + // this routine could be simplified to return to_pdf_->MaxResult()+1. we're a + // bit more paranoid than that. + if (!to_pdf_) return 0; + EventAnswerType max_result = to_pdf_->MaxResult(); + if (max_result < 0 ) return 0; + else return (int32) max_result+1; + } + virtual ContextDependencyInterface *Copy() const { + return new ContextDependency(N_, P_, to_pdf_->Copy()); + } + + /// Read context-dependency object from disk; throws on error + void Read (std::istream &is, bool binary); + + // Constructor with no arguments; will normally be called + // prior to Read() + ContextDependency(): N_(0), P_(0), to_pdf_(NULL) { } + + // Constructor takes ownership of pointers. + ContextDependency(int32 N, int32 P, + EventMap *to_pdf): + N_(N), P_(P), to_pdf_(to_pdf) { } + void Write (std::ostream &os, bool binary) const; + + ~ContextDependency() { if (to_pdf_ != NULL) delete to_pdf_; } + + const EventMap &ToPdfMap() const { return *to_pdf_; } + + /// GetPdfInfo returns a vector indexed by pdf-id, saying for each pdf which + /// pairs of (phone, pdf-class) it can correspond to. (Usually just one). + /// c.f. hmm/hmm-topology.h for meaning of pdf-class. + + void GetPdfInfo(const std::vector &phones, // list of phones + const std::vector &num_pdf_classes, // indexed by phone, + std::vector > > *pdf_info) + const; + + private: + int32 N_; // + int32 P_; + EventMap *to_pdf_; // owned here. + + KALDI_DISALLOW_COPY_AND_ASSIGN(ContextDependency); +}; + +/// GenRandContextDependency is mainly of use for debugging. Phones must be sorted and uniq +/// on input. +/// @param phones [in] A vector of phone id's [must be sorted and uniq]. +/// @param ensure_all_covered [in] boolean argument; if true, GenRandContextDependency +/// generates a context-dependency object that "works" for all phones [no gaps]. +/// @param num_pdf_classes [out] outputs a vector indexed by phone, of the number +/// of pdf classes (e.g. states) for that phone. +/// @return Returns the a context dependency object. +ContextDependency *GenRandContextDependency(const std::vector &phones, + bool ensure_all_covered, + std::vector *num_pdf_classes); + +/// GenRandContextDependencyLarge is like GenRandContextDependency but generates a larger tree +/// with specified N and P for use in "one-time" larger-scale tests. +ContextDependency *GenRandContextDependencyLarge(const std::vector &phones, + int N, int P, + bool ensure_all_covered, + std: