summaryrefslogtreecommitdiff
path: root/kaldi_io/src/kaldi/tree
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-08-14 11:51:42 +0800
committerDeterminant <[email protected]>2015-08-14 11:51:42 +0800
commit96a32415ab43377cf1575bd3f4f2980f58028209 (patch)
tree30a2d92d73e8f40ac87b79f6f56e227bfc4eea6e /kaldi_io/src/kaldi/tree
parentc177a7549bd90670af4b29fa813ddea32cfe0f78 (diff)
add implementation for kaldi io (by ymz)
Diffstat (limited to 'kaldi_io/src/kaldi/tree')
-rw-r--r--kaldi_io/src/kaldi/tree/build-tree-questions.h133
-rw-r--r--kaldi_io/src/kaldi/tree/build-tree-utils.h324
-rw-r--r--kaldi_io/src/kaldi/tree/build-tree.h250
-rw-r--r--kaldi_io/src/kaldi/tree/cluster-utils.h291
-rw-r--r--kaldi_io/src/kaldi/tree/clusterable-classes.h158
-rw-r--r--kaldi_io/src/kaldi/tree/context-dep.h166
-rw-r--r--kaldi_io/src/kaldi/tree/event-map.h365
-rw-r--r--kaldi_io/src/kaldi/tree/tree-renderer.h84
8 files changed, 1771 insertions, 0 deletions
diff --git a/kaldi_io/src/kaldi/tree/build-tree-questions.h b/kaldi_io/src/kaldi/tree/build-tree-questions.h
new file mode 100644
index 0000000..a6bcfdd
--- /dev/null
+++ b/kaldi_io/src/kaldi/tree/build-tree-questions.h
@@ -0,0 +1,133 @@
+// tree/build-tree-questions.h
+
+// Copyright 2009-2011 Microsoft Corporation
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_TREE_BUILD_TREE_QUESTIONS_H_
+#define KALDI_TREE_BUILD_TREE_QUESTIONS_H_
+
+#include "util/stl-utils.h"
+#include "tree/context-dep.h"
+
+namespace kaldi {
+
+
+/// \addtogroup tree_group
+/// @{
+/// Typedef for statistics to build trees.
+typedef std::vector<std::pair<EventType, Clusterable*> > BuildTreeStatsType;
+
+/// Typedef used when we get "all keys" from a set of stats-- used in specifying
+/// which kinds of questions to ask.
+typedef enum { kAllKeysInsistIdentical, kAllKeysIntersection, kAllKeysUnion } AllKeysType;
+
+/// @}
+
+/// \defgroup tree_group_questions Question sets for decision-tree clustering
+/// See \ref tree_internals (and specifically \ref treei_func_questions) for context.
+/// \ingroup tree_group
+/// @{
+
+/// QuestionsForKey is a class used to define the questions for a key,
+/// and also options that allow us to refine the question during tree-building
+/// (i.e. make a question specific to the location in the tree).
+/// The Questions class handles aggregating these options for a set
+/// of different keys.
+struct QuestionsForKey { // Configuration class associated with a particular key
+ // (of type EventKeyType). It also contains the questions themselves.
+ std::vector<std::vector<EventValueType> > initial_questions;
+ RefineClustersOptions refine_opts; // if refine_opts.max_iter == 0,
+ // we just pick from the initial questions.
+
+ QuestionsForKey(int32 num_iters = 5): refine_opts(num_iters, 2) {
+ // refine_cfg with 5 iters and top-n = 2 (this is no restriction because
+ // RefineClusters called with 2 clusters; would get set to that anyway as
+ // it's the only possible value for 2 clusters). User has to add questions.
+ // This config won't work as-is, as it has no questions.
+ }
+
+ void Check() const {
+ for (size_t i = 0;i < initial_questions.size();i++) KALDI_ASSERT(IsSorted(initial_questions[i]));
+ }
+
+ void Write(std::ostream &os, bool binary) const;
+ void Read(std::istream &is, bool binary);
+
+ // copy and assign allowed.
+};
+
+/// This class defines, for each EventKeyType, a set of initial questions that
+/// it tries and also a number of iterations for which to refine the questions to increase
+/// likelihood. It is perhaps a bit more than an options class, as it contains the
+/// actual questions.
+class Questions { // careful, this is a class.
+ public:
+ const QuestionsForKey &GetQuestionsOf(EventKeyType key) const {
+ std::map<EventKeyType, size_t>::const_iterator iter;
+ if ( (iter = key_idx_.find(key)) == key_idx_.end()) {
+ KALDI_ERR << "Questions: no options for key "<< key;
+ }
+ size_t idx = iter->second;
+ KALDI_ASSERT(idx < key_options_.size());
+ key_options_[idx]->Check();
+ return *(key_options_[idx]);
+ }
+ void SetQuestionsOf(EventKeyType key, const QuestionsForKey &options_of_key) {
+ options_of_key.Check();
+ if (key_idx_.count(key) == 0) {
+ key_idx_[key] = key_options_.size();
+ key_options_.push_back(new QuestionsForKey());
+ *(key_options_.back()) = options_of_key;
+ } else {
+ size_t idx = key_idx_[key];
+ KALDI_ASSERT(idx < key_options_.size());
+ *(key_options_[idx]) = options_of_key;
+ }
+ }
+ void GetKeysWithQuestions(std::vector<EventKeyType> *keys_out) const {
+ KALDI_ASSERT(keys_out != NULL);
+ CopyMapKeysToVector(key_idx_, keys_out);
+ }
+ const bool HasQuestionsForKey(EventKeyType key) const { return (key_idx_.count(key) != 0); }
+ ~Questions() { kaldi::DeletePointers(&key_options_); }
+
+
+ /// Initializer with arguments. After using this you would have to set up the config for each key you
+ /// are going to use, or use InitRand().
+ Questions() { }
+
+
+ /// InitRand attempts to generate "reasonable" random questions. Only
+ /// of use for debugging. This initializer creates a config that is
+ /// ready to use.
+ /// e.g. num_iters_refine = 0 means just use stated questions (if >1, will use
+ /// different questions at each split of the tree).
+ void InitRand(const BuildTreeStatsType &stats, int32 num_quest, int32 num_iters_refine, AllKeysType all_keys_type);
+
+ void Write(std::ostream &os, bool binary) const;
+ void Read(std::istream &is, bool binary);
+ private:
+ std::vector<QuestionsForKey*> key_options_;
+ std::map<EventKeyType, size_t> key_idx_;
+ KALDI_DISALLOW_COPY_AND_ASSIGN(Questions);
+};
+
+/// @}
+
+}// end namespace kaldi
+
+#endif // KALDI_TREE_BUILD_TREE_QUESTIONS_H_
diff --git a/kaldi_io/src/kaldi/tree/build-tree-utils.h b/kaldi_io/src/kaldi/tree/build-tree-utils.h
new file mode 100644
index 0000000..464fc6b
--- /dev/null
+++ b/kaldi_io/src/kaldi/tree/build-tree-utils.h
@@ -0,0 +1,324 @@
+// tree/build-tree-utils.h
+
+// Copyright 2009-2011 Microsoft Corporation
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_TREE_BUILD_TREE_UTILS_H_
+#define KALDI_TREE_BUILD_TREE_UTILS_H_
+
+#include "tree/build-tree-questions.h"
+
+// build-tree-questions.h needed for this typedef:
+// typedef std::vector<std::pair<EventType, Clusterable*> > BuildTreeStatsType;
+// and for other #includes.
+
+namespace kaldi {
+
+
+/// \defgroup tree_group_lower Low-level functions for manipulating statistics and event-maps
+/// See \ref tree_internals and specifically \ref treei_func for context.
+/// \ingroup tree_group
+///
+/// @{
+
+
+
+/// This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL.
+/// Does not delete the pointer "stats" itself.
+void DeleteBuildTreeStats(BuildTreeStatsType *stats);
+
+/// Writes BuildTreeStats object. This works even if pointers are NULL.
+void WriteBuildTreeStats(std::ostream &os, bool binary,
+ const BuildTreeStatsType &stats);
+
+/// Reads BuildTreeStats object. The "example" argument must be of the same
+/// type as the stats on disk, and is needed for access to the correct "Read"
+/// function. It was organized this way for easier extensibility (so adding new
+/// Clusterable derived classes isn't painful)
+void ReadBuildTreeStats(std::istream &is, bool binary,
+ const Clusterable &example, BuildTreeStatsType *stats);
+
+/// Convenience function e.g. to work out possible values of the phones from just the stats.
+/// Returns true if key was always defined inside the stats.
+/// May be used with and == NULL to find out of key was always defined.
+bool PossibleValues(EventKeyType key, const BuildTreeStatsType &stats,
+ std::vector<EventValueType> *ans);
+
+
+/// Splits stats according to the EventMap, indexing them at output by the
+/// leaf type. A utility function. NOTE-- pointers in stats_out point to
+/// the same memory location as those in stats. No copying of Clusterable*
+/// objects happens. Will add to stats in stats_out if non-empty at input.
+/// This function may increase the size of vector stats_out as necessary
+/// to accommodate stats, but will never decrease the size.
+void SplitStatsByMap(const BuildTreeStatsType &stats_in, const EventMap &e,
+ std::vector<BuildTreeStatsType> *stats_out);
+
+/// SplitStatsByKey splits stats up according to the value of a particular key,
+/// which must be always defined and nonnegative. Like MapStats. Pointers to
+/// Clusterable* in stats_out are not newly allocated-- they are the same as the
+/// ones in stats_in. Generally they will still be owned at stats_in (user can
+/// decide where to allocate ownership).
+void SplitStatsByKey(const BuildTreeStatsType &stats_in, EventKeyType key,
+ std::vector<BuildTreeStatsType> *stats_out);
+
+
+/// Converts stats from a given context-window (N) and central-position (P) to a
+/// different N and P, by possibly reducing context. This function does a job
+/// that's quite specific to the "normal" stats format we use. See \ref
+/// tree_window for background. This function may delete some keys and change
+/// others, depending on the N and P values. It expects that at input, all keys
+/// will either be -1 or lie between 0 and oldN-1. At output, keys will be
+/// either -1 or between 0 and newN-1.
+/// Returns false if we could not convert the stats (e.g. because newN is larger
+/// than oldN).
+bool ConvertStats(int32 oldN, int32 oldP, int32 newN, int32 newP,
+ BuildTreeStatsType *stats);
+
+
+/// FilterStatsByKey filters the stats according the value of a specified key.
+/// If include_if_present == true, it only outputs the stats whose key is in
+/// "values"; otherwise it only outputs the stats whose key is not in "values".
+/// At input, "values" must be sorted and unique, and all stats in "stats_in"
+/// must have "key" defined. At output, pointers to Clusterable* in stats_out
+/// are not newly allocated-- they are the same as the ones in stats_in.
+void FilterStatsByKey(const BuildTreeStatsType &stats_in,
+ EventKeyType key,
+ std::vector<EventValueType> &values,
+ bool include_if_present, // true-> retain only if in "values",
+ // false-> retain only if not in "values".
+ BuildTreeStatsType *stats_out);
+
+
+/// Sums stats, or returns NULL stats_in has no non-NULL stats.
+/// Stats are newly allocated, owned by caller.
+Clusterable *SumStats(const BuildTreeStatsType &stats_in);
+
+/// Sums the normalizer [typically, data-count] over the stats.
+BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in);
+
+/// Sums the objective function over the stats.
+BaseFloat SumObjf(const BuildTreeStatsType &stats_in);
+
+
+/// Sum a vector of stats. Leaves NULL as pointer if no stats available.
+/// The pointers in stats_out are owned by caller. At output, there may be
+/// NULLs in the vector stats_out.
+void SumStatsVec(const std::vector<BuildTreeStatsType> &stats_in, std::vector<Clusterable*> *stats_out);
+
+/// Cluster the stats given the event map return the total objf given those clusters.
+BaseFloat ObjfGivenMap(const BuildTreeStatsType &stats_in, const EventMap &e);
+
+
+/// FindAllKeys puts in *keys the (sorted, unique) list of all key identities in the stats.
+/// If type == kAllKeysInsistIdentical, it will insist that this set of keys is the same for all the
+/// stats (else exception is thrown).
+/// if type == kAllKeysIntersection, it will return the smallest common set of keys present in
+/// the set of stats
+/// if type== kAllKeysUnion (currently probably not so useful since maps will return "undefined"
+/// if key is not present), it will return the union of all the keys present in the stats.
+void FindAllKeys(const BuildTreeStatsType &stats, AllKeysType keys_type,
+ std::vector<EventKeyType> *keys);
+
+
+/// @}
+
+
+/**
+ \defgroup tree_group_intermediate Intermediate-level functions used in building the tree
+ These functions are are used in top-level tree-building code (\ref tree_group_top); see
+ \ref tree_internals for documentation.
+ \ingroup tree_group
+ @{
+*/
+
+
+/// Returns a tree with just one node. Used @ start of tree-building process.
+/// Not really used in current recipes.
+inline EventMap *TrivialTree(int32 *num_leaves) {
+ KALDI_ASSERT(*num_leaves == 0); // in envisaged usage.
+ return new ConstantEventMap( (*num_leaves)++ );
+}
+
+/// DoTableSplit does a complete split on this key (e.g. might correspond to central phone
+/// (key = P-1), or HMM-state position (key == kPdfClass == -1). Stats used to work out possible
+/// values of the event. "num_leaves" is used to allocate new leaves. All stats must have
+/// this key defined, or this function will crash.
+EventMap *DoTableSplit(const EventMap &orig, EventKeyType key,
+ const BuildTreeStatsType &stats, int32 *num_leaves);
+
+
+/// DoTableSplitMultiple does a complete split on all the keys, in order from keys[0],
+/// keys[1]
+/// and so on. The stats are used to work out possible values corresponding to the key.
+/// "num_leaves" is used to allocate new leaves. All stats must have
+/// the keys defined, or this function will crash.
+/// Returns a newly allocated event map.
+EventMap *DoTableSplitMultiple(const EventMap &orig,
+ const std::vector<EventKeyType> &keys,
+ const BuildTreeStatsType &stats,
+ int32 *num_leaves);
+
+
+/// "ClusterEventMapGetMapping" clusters the leaves of the EventMap, with "thresh" a delta-likelihood
+/// threshold to control how many leaves we combine (might be the same as the delta-like
+/// threshold used in splitting.
+// The function returns the #leaves we combined. The same leaf-ids of the leaves being clustered
+// will be used for the clustered leaves (but other than that there is no special rule which
+// leaf-ids should be used at output).
+// It outputs the mapping for leaves, in "mapping", which may be empty at the start
+// but may also contain mappings for other parts of the tree, which must contain
+// disjoint leaves from this part. This is so that Cluster can
+// be called multiple times for sub-parts of the tree (with disjoint sets of leaves),
+// e.g. if we want to avoid sharing across phones. Afterwards you can use Copy function
+// of EventMap to apply the mapping, i.e. call e_in.Copy(mapping) to get the new map.
+// Note that the application of Cluster creates gaps in the leaves. You should then
+// call RenumberEventMap(e_in.Copy(mapping), num_leaves).
+// *If you only want to cluster a subset of the leaves (e.g. just non-silence, or just
+// a particular phone, do this by providing a set of "stats" that correspond to just
+// this subset of leaves*. Leaves with no stats will not be clustered.
+// See build-tree.cc for an example of usage.
+int ClusterEventMapGetMapping(const EventMap &e_in, const BuildTreeStatsType &stats,
+ BaseFloat thresh, std::vector<EventMap*> *mapping);
+
+/// This is as ClusterEventMapGetMapping but a more convenient interface
+/// that exposes less of the internals. It uses a bottom-up clustering to
+/// combine the leaves, until the log-likelihood decrease from combinging two
+/// leaves exceeds the threshold.
+EventMap *ClusterEventMap(const EventMap &e_in, const BuildTreeStatsType &stats,
+ BaseFloat thresh, int32 *num_removed);
+
+/// This is as ClusterEventMap, but first splits the stats on the keys specified
+/// in "keys" (e.g. typically keys = [ -1, P ]), and only clusters within the
+/// classes defined by that splitting.
+/// Note-- leaves will be non-consecutive at output, use RenumberEventMap.
+EventMap *ClusterEventMapRestrictedByKeys(const EventMap &e_in,
+ const BuildTreeStatsType &stats,
+ BaseFloat thresh,
+ const std::vector<EventKeyType> &keys,
+ int32 *num_removed);
+
+
+/// This version of ClusterEventMapRestricted restricts the clustering to only
+/// allow things that "e_restrict" maps to the same value to be clustered
+/// together.
+EventMap *ClusterEventMapRestrictedByMap(const EventMap &e_in,
+ const BuildTreeStatsType &stats,
+ BaseFloat thresh,
+ const EventMap &e_restrict,
+ int32 *num_removed);
+
+
+/// RenumberEventMap [intended to be used after calling ClusterEventMap] renumbers
+/// an EventMap so its leaves are consecutive.
+/// It puts the number of leaves in *num_leaves. If later you need the mapping of
+/// the leaves, modify the function and add a new argument.
+EventMap *RenumberEventMap(const EventMap &e_in, int32 *num_leaves);
+
+/// This function remaps the event-map leaves using this mapping,
+/// indexed by the number at leaf.
+EventMap *MapEventMapLeaves(const EventMap &e_in,
+ const std::vector<int32> &mapping);
+
+
+
+/// ShareEventMapLeaves performs a quite specific function that allows us to
+/// generate trees where, for a certain list of phones, and for all states in
+/// the phone, all the pdf's are shared.
+/// Each element of "values" contains a list of phones (may be just one phone),
+/// all states of which we want shared together). Typically at input, "key" will
+/// equal P, the central-phone position, and "values" will contain just one
+/// list containing the silence phone.
+/// This function renumbers the event map leaves after doing the sharing, to
+/// make the event-map leaves contiguous.
+EventMap *ShareEventMapLeaves(const EventMap &e_in, EventKeyType key,
+ std::vector<std::vector<EventValueType> > &values,
+ int32 *num_leaves);
+
+
+
+/// Does a decision-tree split at the leaves of an EventMap.
+/// @param orig [in] The EventMap whose leaves we want to split. [may be either a trivial or a
+/// non-trivial one].
+/// @param stats [in] The statistics for splitting the tree; if you do not want a particular
+/// subset of leaves to be split, make sure the stats corresponding to those leaves
+/// are not present in "stats".
+/// @param qcfg [in] Configuration class that contains initial questions (e.g. sets of phones)
+/// for each key and says whether to refine these questions during tree building.
+/// @param thresh [in] A log-likelihood threshold (e.g. 300) that can be used to
+/// limit the number of leaves; you can use zero and set max_leaves instead.
+/// @param max_leaves [in] Will stop leaves being split after they reach this number.
+/// @param num_leaves [in,out] A pointer used to allocate leaves; always corresponds to the
+/// current number of leaves (is incremented when this is increased).
+/// @param objf_impr_out [out] If non-NULL, will be set to the objective improvement due to splitting
+/// (not normalized by the number of frames).
+/// @param smallest_split_change_out If non-NULL, will be set to the smallest objective-function
+/// improvement that we got from splitting any leaf; useful to provide a threshold
+/// for ClusterEventMap.
+/// @return The EventMap after splitting is returned; pointer is owned by caller.
+EventMap *SplitDecisionTree(const EventMap &orig,
+ const BuildTreeStatsType &stats,
+ Questions &qcfg,
+ BaseFloat thresh,
+ int32 max_leaves, // max_leaves<=0 -> no maximum.
+ int32 *num_leaves,
+ BaseFloat *objf_impr_out,
+ BaseFloat *smallest_split_change_out);
+
+/// CreateRandomQuestions will initialize a Questions randomly, in a reasonable
+/// way [for testing purposes, or when hand-designed questions are not available].
+/// e.g. num_quest = 5 might be a reasonable value if num_iters > 0, or num_quest = 20 otherwise.
+void CreateRandomQuestions(const BuildTreeStatsType &stats, int32 num_quest, Questions *cfg_out);
+
+
+/// FindBestSplitForKey is a function used in DoDecisionTreeSplit.
+/// It finds the best split for this key, given these stats.
+/// It will return 0 if the key was not always defined for the stats.
+BaseFloat FindBestSplitForKey(const BuildTreeStatsType &stats,
+ const Questions &qcfg,
+ EventKeyType key,
+ std::vector<EventValueType> *yes_set);
+
+
+/// GetStubMap is used in tree-building functions to get the initial
+/// to-states map, before the decision-tree-building process. It creates
+/// a simple map that splits on groups of phones. For the set of phones in
+/// phone_sets[i] it creates either: if share_roots[i] == true, a single
+/// leaf node, or if share_roots[i] == false, separate root nodes for
+/// each HMM-position (it goes up to the highest position for any
+/// phone in the set, although it will warn if you share roots between
+/// phones with different numbers of states, which is a weird thing to
+/// do but should still work. If any phone is present
+/// in "phone_sets" but "phone2num_pdf_classes" does not map it to a length,
+/// it is an error. Note that the behaviour of the resulting map is
+/// undefined for phones not present in "phone_sets".
+/// At entry, this function should be called with (*num_leaves == 0).
+/// It will number the leaves starting from (*num_leaves).
+
+EventMap *GetStubMap(int32 P,
+ const std::vector<std::vector<int32> > &phone_sets,
+ const std::vector<int32> &phone2num_pdf_classes,
+ const std::vector<bool> &share_roots, // indexed by index into phone_sets.
+ int32 *num_leaves);
+/// Note: GetStubMap with P = 0 can be used to get a standard monophone system.
+
+/// @}
+
+
+}// end namespace kaldi
+
+#endif
diff --git a/kaldi_io/src/kaldi/tree/build-tree.h b/kaldi_io/src/kaldi/tree/build-tree.h
new file mode 100644
index 0000000..37bb108
--- /dev/null
+++ b/kaldi_io/src/kaldi/tree/build-tree.h
@@ -0,0 +1,250 @@
+// tree/build-tree.h
+
+// Copyright 2009-2011 Microsoft Corporation
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_TREE_BUILD_TREE_H_
+#define KALDI_TREE_BUILD_TREE_H_
+
+// The file build-tree.h contains outer-level routines used in tree-building
+// and related tasks, that are directly called by the command-line tools.
+
+#include "tree/build-tree-utils.h"
+#include "tree/context-dep.h"
+namespace kaldi {
+
+/// \defgroup tree_group_top Top-level tree-building functions
+/// See \ref tree_internals for context.
+/// \ingroup tree_group
+/// @{
+
+// Note, in tree_group_top we also include AccumulateTreeStats, in
+// ../hmm/tree-accu.h (it has some extra dependencies so we didn't
+// want to include it here).
+
+/**
+ * BuildTree is the normal way to build a set of decision trees.
+ * The sets "phone_sets" dictate how we set up the roots of the decision trees.
+ * each set of phones phone_sets[i] has shared decision-tree roots, and if
+ * the corresponding variable share_roots[i] is true, the root will be shared
+ * for the different HMM-positions in the phone. All phones in "phone_sets"
+ * should be in the stats (use FixUnseenPhones to ensure this).
+ * if for any i, do_split[i] is false, we will not do any tree splitting for
+ * phones in that set.
+ * @param qopts [in] Questions options class, contains questions for each key
+ * (e.g. each phone position)
+ * @param phone_sets [in] Each element of phone_sets is a set of phones whose
+ * roots are shared together (prior to decision-tree splitting).
+ * @param phone2num_pdf_classes [in] A map from phones to the number of
+ * \ref pdf_class "pdf-classes"
+ * in the phone (this info is derived from the HmmTopology object)
+ * @param share_roots [in] A vector the same size as phone_sets; says for each
+ * phone set whether the root should be shared among all the
+ * pdf-classes or not.
+ * @param do_split [in] A vector the same size as phone_sets; says for each
+ * phone set whether decision-tree splitting should be done
+ * (generally true for non-silence phones).
+ * @param stats [in] The statistics used in tree-building.
+ * @param thresh [in] Threshold used in decision-tree splitting (e.g. 1000),
+ * or you may use 0 in which case max_leaves becomes the
+ * constraint.
+ * @param max_leaves [in] Maximum number of leaves it will create; set this
+ * to a large number if you want to just specify "thresh".
+ * @param cluster_thresh [in] Threshold for clustering leaves after decision-tree
+ * splitting (only within each phone-set); leaves will be combined
+ * if log-likelihood change is less than this. A value about equal
+ * to "thresh" is suitable
+ * if thresh != 0; otherwise, zero will mean no clustering is done,
+ * or a negative value (e.g. -1) sets it to the smallest likelihood
+ * change seen during the splitting algorithm; this typically causes
+ * about a 20% reduction in the number of leaves.
+
+ * @param P [in] The central position of the phone context window, e.g. 1 for a
+ * triphone system.
+ * @return Returns a pointer to an EventMap object that is the tree.
+
+*/
+
+EventMap *BuildTree(Questions &qopts,
+ const std::vector<std::vector<int32> > &phone_sets,
+ const std::vector<int32> &phone2num_pdf_classes,
+ const std::vector<bool> &share_roots,
+ const std::vector<bool> &do_split,
+ const BuildTreeStatsType &stats,
+ BaseFloat thresh,
+ int32 max_leaves,
+ BaseFloat cluster_thresh, // typically == thresh. If negative, use smallest split.
+ int32 P);
+
+
+/**
+ *
+ * BuildTreeTwoLevel builds a two-level tree, useful for example in building tied mixture
+ * systems with multiple codebooks. It first builds a small tree by splitting to
+ * "max_leaves_first". It then splits at the leaves of "max_leaves_first" (think of this
+ * as creating multiple little trees at the leaves of the first tree), until the total
+ * number of leaves reaches "max_leaves_second". It then outputs the second tree, along
+ * with a mapping from the leaf-ids of the second tree to the leaf-ids of the first tree.
+ * Note that the interface is similar to BuildTree, and in fact it calls BuildTree
+ * internally.
+ *
+ * The sets "phone_sets" dictate how we set up the roots of the decision trees.
+ * each set of phones phone_sets[i] has shared decision-tree roots, and if
+ * the corresponding variable share_roots[i] is true, the root will be shared
+ * for the different HMM-positions in the phone. All phones in "phone_sets"
+ * should be in the stats (use FixUnseenPhones to ensure this).
+ * if for any i, do_split[i] is false, we will not do any tree splitting for
+ * phones in that set.
+ *
+ * @param qopts [in] Questions options class, contains questions for each key
+ * (e.g. each phone position)
+ * @param phone_sets [in] Each element of phone_sets is a set of phones whose
+ * roots are shared together (prior to decision-tree splitting).
+ * @param phone2num_pdf_classes [in] A map from phones to the number of
+ * \ref pdf_class "pdf-classes"
+ * in the phone (this info is derived from the HmmTopology object)
+ * @param share_roots [in] A vector the same size as phone_sets; says for each
+ * phone set whether the root should be shared among all the
+ * pdf-classes or not.
+ * @param do_split [in] A vector the same size as phone_sets; says for each
+ * phone set whether decision-tree splitting should be done
+ * (generally true for non-silence phones).
+ * @param stats [in] The statistics used in tree-building.
+ * @param max_leaves_first [in] Maximum number of leaves it will create in first
+ * level of decision tree.
+ * @param max_leaves_second [in] Maximum number of leaves it will create in second
+ * level of decision tree. Must be > max_leaves_first.
+ * @param cluster_leaves [in] Boolean value; if true, we post-cluster the leaves produced
+ * in the second level of decision-tree split; if false, we don't.
+ * The threshold for post-clustering is the log-like change of the last
+ * decision-tree split; this typically causes about a 20% reduction in
+ * the number of leaves.
+ * @param P [in] The central position of the phone context window, e.g. 1 for a
+ * triphone system.
+ * @param leaf_map [out] Will be set to be a mapping from the leaves of the
+ * "big" tree to the leaves of the "little" tree, which you can
+ * view as cluster centers.
+ * @return Returns a pointer to an EventMap object that is the (big) tree.
+
+*/
+
+EventMap *BuildTreeTwoLevel(Questions &qopts,
+ const std::vector<std::vector<int32> > &phone_sets,
+ const std::vector<int32> &phone2num_pdf_classes,
+ const std::vector<bool> &share_roots,
+ const std::vector<bool> &do_split,
+ const BuildTreeStatsType &stats,
+ int32 max_leaves_first,
+ int32 max_leaves_second,
+ bool cluster_leaves,
+ int32 P,
+ std::vector<int32> *leaf_map);
+
+
+/// GenRandStats generates random statistics of the form used by BuildTree.
+/// It tries to do so in such a way that they mimic "real" stats. The event keys
+/// and their corresponding values are:
+/// - key == -1 == kPdfClass -> pdf-class, generally corresponds to
+/// zero-based position in HMM (0, 1, 2 .. hmm_lengths[phone]-1)
+/// - key == 0 -> phone-id of left-most context phone.
+/// - key == 1 -> phone-id of one-from-left-most context phone.
+/// - key == P-1 -> phone-id of central phone.
+/// - key == N-1 -> phone-id of right-most context phone.
+/// GenRandStats is useful only for testing but it serves to document the format of
+/// stats used by BuildTreeDefault.
+/// if is_ctx_dep[phone] is set to false, GenRandStats will not define the keys for
+/// other than the P-1'th phone.
+
+/// @param dim [in] dimension of features.
+/// @param num_stats [in] approximate number of separate phones-in-context wanted.
+/// @param N [in] context-size (typically 3)
+/// @param P [in] central-phone position in zero-based numbering (typically 1)
+/// @param phone_ids [in] integer ids of phones
+/// @param hmm_lengths [in] lengths of hmm for phone, indexed by phone.
+/// @param is_ctx_dep [in] boolean array indexed by phone, saying whether each phone
+/// is context dependent.
+/// @param ensure_all_phones_covered [in] Boolean argument: if true, GenRandStats
+/// ensures that every phone is seen at least once in the central position (P).
+/// @param stats_out [out] The statistics that this routine outputs.
+
+void GenRandStats(int32 dim, int32 num_stats, int32 N, int32 P,
+ const std::vector<int32> &phone_ids,
+ const std::vector<int32> &hmm_lengths,
+ const std::vector<bool> &is_ctx_dep,
+ bool ensure_all_phones_covered,
+ BuildTreeStatsType *stats_out);
+
+
+/// included here because it's used in some tree-building
+/// calling code. Reads an OpenFst symbl table,
+/// discards the symbols and outputs the integers
+void ReadSymbolTableAsIntegers(std::string filename,
+ bool include_eps,
+ std::vector<int32> *syms);
+
+
+
+/**
+ * Outputs sets of phones that are reasonable for questions
+ * to ask in the tree-building algorithm. These are obtained by tree
+ * clustering of the phones; for each node in the tree, all the leaves
+ * accessible from that node form one of the sets of phones.
+ * @param stats [in] The statistics as used for normal tree-building.
+ * @param phone_sets_in [in] All the phones, pre-partitioned into sets.
+ * The output sets will be various unions of these sets. These sets
+ * will normally correspond to "real phones", in cases where the phones
+ * have stress and position markings.
+ * @param all_pdf_classes_in [in] All the \ref pdf_class "pdf-classes"
+ * that we consider for clustering. In the normal case this is the singleton
+ * set {1}, which means that we only consider the central hmm-position
+ * of the standard 3-state HMM, for clustering purposes.
+ * @param P [in] The central position in the phone context window; normally
+ * 1 for triphone system.s
+ * @param questions_out [out] The questions (sets of phones) are output to here.
+ **/
+void AutomaticallyObtainQuestions(BuildTreeStatsType &stats,
+ const std::vector<std::vector<int32> > &phone_sets_in,
+ const std::vector<int32> &all_pdf_classes_in,
+ int32 P,
+ std::vector<std::vector<int32> > *questions_out);
+
+/// This function clusters the phones (or some initially specified sets of phones)
+/// into sets of phones, using a k-means algorithm. Useful, for example, in building
+/// simple models for purposes of adaptation.
+
+void KMeansClusterPhones(BuildTreeStatsType &stats,
+ const std::vector<std::vector<int32> > &phone_sets_in,
+ const std::vector<int32> &all_pdf_classes_in,
+ int32 P,
+ int32 num_classes,
+ std::vector<std::vector<int32> > *sets_out);
+
+/// Reads the roots file (throws on error). Format is lines like:
+/// "shared split 1 2 3 4",
+/// "not-shared not-split 5",
+/// and so on. The numbers are indexes of phones.
+void ReadRootsFile(std::istream &is,
+ std::vector<std::vector<int32> > *phone_sets,
+ std::vector<bool> *is_shared_root,
+ std::vector<bool> *is_split_root);
+
+
+/// @}
+
+}// end namespace kaldi
+
+#endif
diff --git a/kaldi_io/src/kaldi/tree/cluster-utils.h b/kaldi_io/src/kaldi/tree/cluster-utils.h
new file mode 100644
index 0000000..55583a2
--- /dev/null
+++ b/kaldi_io/src/kaldi/tree/cluster-utils.h
@@ -0,0 +1,291 @@
+// tree/cluster-utils.h
+
+// Copyright 2012 Arnab Ghoshal
+// Copyright 2009-2011 Microsoft Corporation; Saarland University
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_TREE_CLUSTER_UTILS_H_
+#define KALDI_TREE_CLUSTER_UTILS_H_
+
+#include <vector>
+#include "matrix/matrix-lib.h"
+#include "itf/clusterable-itf.h"
+
+namespace kaldi {
+
+/// \addtogroup clustering_group_simple
+/// @{
+
+/// Returns the total objective function after adding up all the
+/// statistics in the vector (pointers may be NULL).
+BaseFloat SumClusterableObjf(const std::vector<Clusterable*> &vec);
+
+/// Returns the total normalizer (usually count) of the cluster (pointers may be NULL).
+BaseFloat SumClusterableNormalizer(const std::vector<Clusterable*> &vec);
+
+/// Sums stats (ptrs may be NULL). Returns NULL if no non-NULL stats present.
+Clusterable *SumClusterable(const std::vector<Clusterable*> &vec);
+
+/** Fills in any (NULL) holes in "stats" vector, with empty stats, because
+ * certain algorithms require non-NULL stats. If "stats" nonempty, requires it
+ * to contain at least one non-NULL pointer that we can call Copy() on.
+ */
+void EnsureClusterableVectorNotNull(std::vector<Clusterable*> *stats);
+
+
+/** Given stats and a vector "assignments" of the same size (that maps to
+ * cluster indices), sums the stats up into "clusters." It will add to any
+ * stats already present in "clusters" (although typically "clusters" will be
+ * empty when called), and it will extend with NULL pointers for any unseen
+ * indices. Call EnsureClusterableStatsNotNull afterwards if you want to ensure
+ * all non-NULL clusters. Pointer in "clusters" are owned by caller. Pointers in
+ * "stats" do not have to be non-NULL.
+ */
+void AddToClusters(const std::vector<Clusterable*> &stats,
+ const std::vector<int32> &assignments,
+ std::vector<Clusterable*> *clusters);
+
+
+/// AddToClustersOptimized does the same as AddToClusters (it sums up the stats
+/// within each cluster, except it uses the sum of all the stats ("total") to
+/// optimize the computation for speed, if possible. This will generally only be
+/// a significant speedup in the case where there are just two clusters, which
+/// can happen in algorithms that are doing binary splits; the idea is that we
+/// sum up all the stats in one cluster (the one with the fewest points in it),
+/// and then subtract from the total.
+void AddToClustersOptimized(const std::vector<Clusterable*> &stats,
+ const std::vector<int32> &assignments,
+ const Clusterable &total,
+ std::vector<Clusterable*> *clusters);
+
+/// @} end "addtogroup clustering_group_simple"
+
+/// \addtogroup clustering_group_algo
+/// @{
+
+// Note, in the algorithms below, it is assumed that the input "points" (which
+// is std::vector<Clusterable*>) is all non-NULL.
+
+/** A bottom-up clustering algorithm. There are two parameters that control how
+ * many clusters we get: a "max_merge_thresh" which is a threshold for merging
+ * clusters, and a min_clust which puts a floor on the number of clusters we want. Set
+ * max_merge_thresh = large to use the min_clust only, or min_clust to 0 to use
+ * the max_merge_thresh only.
+ *
+ * The algorithm is:
+ * \code
+ * while (num-clusters > min_clust && smallest_merge_cost <= max_merge_thresh)
+ * merge the closest two clusters.
+ * \endcode
+ *
+ * @param points [in] Points to be clustered (may not contain NULL pointers)
+ * @param thresh [in] Threshold on cost change from merging clusters; clusters
+ * won't be merged if the cost is more than this
+ * @param min_clust [in] Minimum number of clusters desired; we'll stop merging
+ * after reaching this number.
+ * @param clusters_out [out] If non-NULL, will be set to a vector of size equal
+ * to the number of output clusters, containing the clustered
+ * statistics. Must be empty when called.
+ * @param assignments_out [out] If non-NULL, will be resized to the number of
+ * points, and each element is the index of the cluster that point
+ * was assigned to.
+ * @return Returns the total objf change relative to all clusters being separate, which is
+ * a negative. Note that this is not the same as what the other clustering algorithms return.
+ */
+BaseFloat ClusterBottomUp(const std::vector<Clusterable*> &points,
+ BaseFloat thresh,
+ int32 min_clust,
+ std::vector<Clusterable*> *clusters_out,
+ std::vector<int32> *assignments_out);
+
+/** This is a bottom-up clustering where the points are pre-clustered in a set
+ * of compartments, such that only points in the same compartment are clustered
+ * together. The compartment and pair of points with the smallest merge cost
+ * is selected and the points are clustered. The result stays in the same
+ * compartment. The code does not merge compartments, and hence assumes that
+ * the number of compartments is smaller than the 'min_clust' option.
+ * The clusters in "clusters_out" are newly allocated and owned by the caller.
+ */
+BaseFloat ClusterBottomUpCompartmentalized(
+ const std::vector< std::vector<Clusterable*> > &points, BaseFloat thresh,
+ int32 min_clust, std::vector< std::vector<Clusterable*> > *clusters_out,
+ std::vector< std::vector<int32> > *assignments_out);
+
+
+struct RefineClustersOptions {
+ int32 num_iters; // must be >= 0. If zero, does nothing.
+ int32 top_n; // must be >= 2.
+ RefineClustersOptions() : num_iters(100), top_n(5) {}
+ RefineClustersOptions(int32 num_iters_in, int32 top_n_in)
+ : num_iters(num_iters_in), top_n(top_n_in) {}
+ // include Write and Read functions because this object gets written/read as
+ // part of the QuestionsForKeyOptions class.
+ void Write(std::ostream &os, bool binary) const;
+ void Read(std::istream &is, bool binary);
+};
+
+/** RefineClusters is mainly used internally by other clustering algorithms.
+ *
+ * It starts with a given assignment of points to clusters and
+ * keeps trying to improve it by moving points from cluster to cluster, up to
+ * a maximum number of iterations.
+ *
+ * "clusters" and "assignments" are both input and output variables, and so
+ * both MUST be non-NULL.
+ *
+ * "top_n" (>=2) is a pruning value: more is more exact, fewer is faster. The
+ * algorithm initially finds the "top_n" closest clusters to any given point,
+ * and from that point only consider move to those "top_n" clusters. Since
+ * RefineClusters is called multiple times from ClusterKMeans (for instance),
+ * this is not really a limitation.
+ */
+BaseFloat RefineClusters(const std::vector<Clusterable*> &points,
+ std::vector<Clusterable*> *clusters /*non-NULL*/,
+ std::vector<int32> *assignments /*non-NULL*/,
+ RefineClustersOptions cfg = RefineClustersOptions());
+
+struct ClusterKMeansOptions {
+ RefineClustersOptions refine_cfg;
+ int32 num_iters;
+ int32 num_tries; // if >1, try whole procedure >once and pick best.
+ bool verbose;
+ ClusterKMeansOptions()
+ : refine_cfg(), num_iters(20), num_tries(2), verbose(true) {}
+};
+
+/** ClusterKMeans is a K-means-like clustering algorithm. It starts with
+ * pseudo-random initialization of points to clusters and uses RefineClusters
+ * to iteratively improve the cluster assignments. It does this for
+ * multiple iterations and picks the result with the best objective function.
+ *
+ *
+ * ClusterKMeans implicitly uses Rand(). It will not necessarily return
+ * the same value on different calls. Use sRand() if you want consistent
+ * results.
+ * The algorithm used in ClusterKMeans is a "k-means-like" algorithm that tries
+ * to be as efficient as possible. Firstly, since the algorithm it uses
+ * includes random initialization, it tries the whole thing cfg.num_tries times
+ * and picks the one with the best objective function. Each try, it does as
+ * follows: it randomly initializes points to clusters, and then for
+ * cfg.num_iters iterations it calls RefineClusters(). The options to
+ * RefineClusters() are given by cfg.refine_cfg. Calling RefineClusters once
+ * will always be at least as good as doing one iteration of reassigning points to
+ * clusters, but will generally be quite a bit better (without taking too
+ * much extra time).
+ *
+ * @param points [in] points to be clustered (must be all non-NULL).
+ * @param num_clust [in] number of clusters requested (it will always return exactly
+ * this many, or will fail if num_clust > points.size()).
+ * @param clusters_out [out] may be NULL; if non-NULL, should be empty when called.
+ * Will be set to a vector of statistics corresponding to the output clusters.
+ * @param assignments_out [out] may be NULL; if non-NULL, will be set to a vector of
+ * same size as "points", which says for each point which cluster
+ * it is assigned to.
+ * @param cfg [in] configuration class specifying options to the algorithm.
+ * @return Returns the objective function improvement versus everything being
+ * in the same cluster.
+ *
+ */
+BaseFloat ClusterKMeans(const std::vector<Clusterable*> &points,
+ int32 num_clust, // exact number of clusters
+ std::vector<Clusterable*> *clusters_out, // may be NULL
+ std::vector<int32> *assignments_out, // may be NULL
+ ClusterKMeansOptions cfg = ClusterKMeansOptions());
+
+struct TreeClusterOptions {
+ ClusterKMeansOptions kmeans_cfg;
+ int32 branch_factor;
+ BaseFloat thresh; // Objf change: if >0, may be used to control number of leaves.
+ TreeClusterOptions()
+ : kmeans_cfg(), branch_factor(2), thresh(0) {
+ kmeans_cfg.verbose = false;
+ }
+};
+
+/** TreeCluster is a top-down clustering algorithm, using a binary tree (not
+ * necessarily balanced). Returns objf improvement versus having all points
+ * in one cluster. The algorithm is:
+ * - Initialize to 1 cluster (tree with 1 node).
+ * - Maintain, for each cluster, a "best-binary-split" (using ClusterKMeans
+ * to do so). Always split the highest scoring cluster, until we can do no
+ * more splits.
+ *
+ * @param points [in] Data points to be clustered
+ * @param max_clust [in] Maximum number of clusters (you will get exactly this number,
+ * if there are at least this many points, except if you set the
+ * cfg.thresh value nonzero, in which case that threshold may limit
+ * the number of clusters.
+ * @param clusters_out [out] If non-NULL, will be set to the a vector whose first
+ * (*num_leaves_out) elements are the leaf clusters, and whose
+ * subsequent elements are the nonleaf nodes in the tree, in
+ * topological order with the root node last. Must be empty vector
+ * when this function is called.
+ * @param assignments_out [out] If non-NULL, will be set to a vector to a vector the
+ * same size as "points", where assignments[i] is the leaf node index i
+ * to which the i'th point gets clustered.
+ * @param clust_assignments_out [out] If non-NULL, will be set to a vector the same size
+ * as clusters_out which says for each node (leaf or nonleaf), the
+ * index of its parent. For the root node (which is last),
+ * assignments_out[i] == i. For each i, assignments_out[i]>=i, i.e.
+ * any node's parent is higher numbered than itself. If you don't need
+ * this information, consider using instead the ClusterTopDown function.
+ * @param num_leaves_out [out] If non-NULL, will be set to the number of leaf nodes
+ * in the tree.
+ * @param cfg [in] Configuration object that controls clustering behavior. Most
+ * important value is "thresh", which provides an alternative mechanism
+ * [other than max_clust] to limit the number of leaves.
+ */
+BaseFloat TreeCluster(const std::vector<Clusterable*> &points,
+ int32 max_clust, // max number of leaf-level clusters.
+ std::vector<Clusterable*> *clusters_out,
+ std::vector<int32> *assignments_out,
+ std::vector<int32> *clust_assignments_out,
+ int32 *num_leaves_out,
+ TreeClusterOptions cfg = TreeClusterOptions());
+
+
+/**
+ * A clustering algorithm that internally uses TreeCluster,
+ * but does not give you the information about the structure of the tree.
+ * The "clusters_out" and "assignments_out" may be NULL if the outputs are not
+ * needed.
+ *
+ * @param points [in] points to be clustered (must be all non-NULL).
+ * @param max_clust [in] Maximum number of clusters (you will get exactly this number,
+ * if there are at least this many points, except if you set the
+ * cfg.thresh value nonzero, in which case that threshold may limit
+ * the number of clusters.
+ * @param clusters_out [out] may be NULL; if non-NULL, should be empty when called.
+ * Will be set to a vector of statistics corresponding to the output clusters.
+ * @param assignments_out [out] may be NULL; if non-NULL, will be set to a vector of
+ * same size as "points", which says for each point which cluster
+ * it is assigned to.
+ * @param cfg [in] Configuration object that controls clustering behavior. Most
+ * important value is "thresh", which provides an alternative mechanism
+ * [other than max_clust] to limit the number of leaves.
+*/
+BaseFloat ClusterTopDown(const std::vector<Clusterable*> &points,
+ int32 max_clust, // max number of clusters.
+ std::vector<Clusterable*> *clusters_out,
+ std::vector<int32> *assignments_out,
+ TreeClusterOptions cfg = TreeClusterOptions());
+
+/// @} end of "addtogroup clustering_group_algo"
+
+} // end namespace kaldi.
+
+#endif // KALDI_TREE_CLUSTER_UTILS_H_
diff --git a/kaldi_io/src/kaldi/tree/clusterable-classes.h b/kaldi_io/src/kaldi/tree/clusterable-classes.h
new file mode 100644
index 0000000..817d0c6
--- /dev/null
+++ b/kaldi_io/src/kaldi/tree/clusterable-classes.h
@@ -0,0 +1,158 @@
+// tree/clusterable-classes.h
+
+// Copyright 2009-2011 Microsoft Corporation; Saarland University
+// 2014 Daniel Povey
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_TREE_CLUSTERABLE_CLASSES_H_
+#define KALDI_TREE_CLUSTERABLE_CLASSES_H_ 1
+
+#include <string>
+#include "itf/clusterable-itf.h"
+#include "matrix/matrix-lib.h"
+
+namespace kaldi {
+
+// Note: see sgmm/sgmm-clusterable.h for an SGMM-based clusterable
+// class. We didn't include it here, to avoid adding an extra
+// dependency to this directory.
+
+/// \addtogroup clustering_group
+/// @{
+
+/// ScalarClusterable clusters scalars with x^2 loss.
+class ScalarClusterable: public Clusterable {
+ public:
+ ScalarClusterable(): x_(0), x2_(0), count_(0) {}
+ explicit ScalarClusterable(BaseFloat x): x_(x), x2_(x*x), count_(1) {}
+ virtual std::string Type() const { return "scalar"; }
+ virtual BaseFloat Objf() const;
+ virtual void SetZero() { count_ = x_ = x2_ = 0.0; }
+ virtual void Add(const Clusterable &other_in);
+ virtual void Sub(const Clusterable &other_in);
+ virtual Clusterable* Copy() const;
+ virtual BaseFloat Normalizer() const {
+ return static_cast<BaseFloat>(count_);
+ }
+
+ // Function to write data to stream. Will organize input later [more complex]
+ virtual void Write(std::ostream &os, bool binary) const;
+ virtual Clusterable* ReadNew(std::istream &is, bool binary) const;
+
+ std::string Info(); // For debugging.
+ BaseFloat Mean() { return (count_ != 0 ? x_/count_ : 0.0); }
+ private:
+ BaseFloat x_;
+ BaseFloat x2_;
+ BaseFloat count_;
+
+ void Read(std::istream &is, bool binary);
+};
+
+
+/// GaussClusterable wraps Gaussian statistics in a form accessible
+/// to generic clustering algorithms.
+class GaussClusterable: public Clusterable {
+ public:
+ GaussClusterable(): count_(0.0), var_floor_(0.0) {}
+ GaussClusterable(int32 dim, BaseFloat var_floor):
+ count_(0.0), stats_(2, dim), var_floor_(var_floor) {}
+
+ GaussClusterable(const Vector<BaseFloat> &x_stats,
+ const Vector<BaseFloat> &x2_stats,
+ BaseFloat var_floor, BaseFloat count);
+
+ virtual std::string Type() const { return "gauss"; }
+ void AddStats(const VectorBase<BaseFloat> &vec, BaseFloat weight = 1.0);
+ virtual BaseFloat Objf() const;
+ virtual void SetZero();
+ virtual void Add(const Clusterable &other_in);
+ virtual void Sub(const Clusterable &other_in);
+ virtual BaseFloat Normalizer() const { return count_; }
+ virtual Clusterable *Copy() const;
+ virtual void Scale(BaseFloat f);
+ virtual void Write(std::ostream &os, bool binary) const;
+ virtual Clusterable *ReadNew(std::istream &is, bool binary) const;
+ virtual ~GaussClusterable() {}
+
+ BaseFloat count() const { return count_; }
+ // The next two functions are not const-correct, because of SubVector.
+ SubVector<double> x_stats() const { return stats_.Row(0); }
+ SubVector<double> x2_stats() const { return stats_.Row(1); }
+ private:
+ double count_;
+ Matrix<double> stats_; // two rows: sum, then sum-squared.
+ double var_floor_; // should be common for all objects created.
+
+ void Read(std::istream &is, bool binary);
+};
+
+/// @} end of "addtogroup clustering_group"
+
+inline void GaussClusterable::SetZero() {
+ count_ = 0;
+ stats_.SetZero();
+}
+
+inline GaussClusterable::GaussClusterable(const Vector<BaseFloat> &x_stats,
+ const Vector<BaseFloat> &x2_stats,
+ BaseFloat var_floor, BaseFloat count):
+ count_(count), stats_(2, x_stats.Dim()), var_floor_(var_floor) {
+ stats_.Row(0).CopyFromVec(x_stats);
+ stats_.Row(1).CopyFromVec(x2_stats);
+}
+
+
+/// VectorClusterable wraps vectors in a form accessible to generic clustering
+/// algorithms. Each vector is associated with a weight; these could be 1.0.
+/// The objective function (to be maximized) is the negated sum of squared
+/// distances from the cluster center to each vector, times that vector's
+/// weight.
+class VectorClusterable: public Clusterable {
+ public:
+ VectorClusterable(): weight_(0.0), sumsq_(0.0) {}
+
+ VectorClusterable(const Vector<BaseFloat> &vector,
+ BaseFloat weight);
+
+ virtual std::string Type() const { return "vector"; }
+ // Objf is negated weighted sum of squared distances.
+ virtual BaseFloat Objf() const;
+ virtual void SetZero() { weight_ = 0.0; sumsq_ = 0.0; stats_.Set(0.0); }
+ virtual void Add(const Clusterable &other_in);
+ virtual void Sub(const Clusterable &other_in);
+ virtual BaseFloat Normalizer() const { return weight_; }
+ virtual Clusterable *Copy() const;
+ virtual void Scale(BaseFloat f);
+ virtual void Write(std::ostream &os, bool binary) const;
+ virtual Clusterable *ReadNew(std::istream &is, bool binary) const;
+ virtual ~VectorClusterable() {}
+
+ private:
+ double weight_; // sum of weights of the source vectors. Never negative.
+ Vector<double> stats_; // Equals the weighted sum of the source vectors.
+ double sumsq_; // Equals the sum over all sources, of weight_ * vec.vec,
+ // where vec = stats_ / weight_. Used in computing
+ // the objective function.
+ void Read(std::istream &is, bool binary);
+};
+
+
+
+} // end namespace kaldi.
+
+#endif // KALDI_TREE_CLUSTERABLE_CLASSES_H_
diff --git a/kaldi_io/src/kaldi/tree/context-dep.h b/kaldi_io/src/kaldi/tree/context-dep.h
new file mode 100644
index 0000000..307fcd4
--- /dev/null
+++ b/kaldi_io/src/kaldi/tree/context-dep.h
@@ -0,0 +1,166 @@
+// tree/context-dep.h
+
+// Copyright 2009-2011 Microsoft Corporation
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_TREE_CONTEXT_DEP_H_
+#define KALDI_TREE_CONTEXT_DEP_H_
+
+#include "itf/context-dep-itf.h"
+#include "tree/event-map.h"
+#include "matrix/matrix-lib.h"
+#include "tree/cluster-utils.h"
+
+/*
+ This header provides the declarations for the class ContextDependency, which inherits
+ from the interface class "ContextDependencyInterface" in itf/context-dep-itf.h.
+ This is basically a wrapper around an EventMap. The EventMap
+ (tree/event-map.h) declares most of the internals of the class, and the building routines are
+ in build-tree.h which uses build-tree-utils.h, which uses cluster-utils.h . */
+
+
+namespace kaldi {
+
+static const EventKeyType kPdfClass = -1; // The "name" to which we assign the
+// pdf-class (generally corresponds ot position in the HMM, zero-based);
+// must not be used for any other event. I.e. the value corresponding to
+// this key is the pdf-class (see hmm-topology.h for explanation of what this is).
+
+
+/* ContextDependency is quite a generic decision tree.
+
+ It does not actually do very much-- all the magic is in the EventMap object.
+ All this class does is to encode the phone context as a sequence of events, and
+ pass this to the EventMap object to turn into what it will interpret as a
+ vector of pdfs.
+
+ Different versions of the ContextDependency class that are written in the future may
+ have slightly different interfaces and pass more stuff in as events, to the
+ EventMap object.
+
+ In order to separate the process of training decision trees from the process
+ of actually using them, we do not put any training code into the ContextDependency class.
+ */
+class ContextDependency: public ContextDependencyInterface {
+ public:
+ virtual int32 ContextWidth() const { return N_; }
+ virtual int32 CentralPosition() const { return P_; }
+
+
+ /// returns success or failure; outputs pdf to pdf_id
+ virtual bool Compute(const std::vector<int32> &phoneseq,
+ int32 pdf_class, int32 *pdf_id) const;
+
+ virtual int32 NumPdfs() const {
+ // this routine could be simplified to return to_pdf_->MaxResult()+1. we're a
+ // bit more paranoid than that.
+ if (!to_pdf_) return 0;
+ EventAnswerType max_result = to_pdf_->MaxResult();
+ if (max_result < 0 ) return 0;
+ else return (int32) max_result+1;
+ }
+ virtual ContextDependencyInterface *Copy() const {
+ return new ContextDependency(N_, P_, to_pdf_->Copy());
+ }
+
+ /// Read context-dependency object from disk; throws on error
+ void Read (std::istream &is, bool binary);
+
+ // Constructor with no arguments; will normally be called
+ // prior to Read()
+ ContextDependency(): N_(0), P_(0), to_pdf_(NULL) { }
+
+ // Constructor takes ownership of pointers.
+ ContextDependency(int32 N, int32 P,
+ EventMap *to_pdf):
+ N_(N), P_(P), to_pdf_(to_pdf) { }
+ void Write (std::ostream &os, bool binary) const;
+
+ ~ContextDependency() { if (to_pdf_ != NULL) delete to_pdf_; }
+
+ const EventMap &ToPdfMap() const { return *to_pdf_; }
+
+ /// GetPdfInfo returns a vector indexed by pdf-id, saying for each pdf which
+ /// pairs of (phone, pdf-class) it can correspond to. (Usually just one).
+ /// c.f. hmm/hmm-topology.h for meaning of pdf-class.
+
+ void GetPdfInfo(const std::vector<int32> &phones, // list of phones
+ const std::vector<int32> &num_pdf_classes, // indexed by phone,
+ std::vector<std::vector<std::pair<int32, int32> > > *pdf_info)
+ const;
+
+ private:
+ int32 N_; //
+ int32 P_;
+ EventMap *to_pdf_; // owned here.
+
+ KALDI_DISALLOW_COPY_AND_ASSIGN(ContextDependency);
+};
+
+/// GenRandContextDependency is mainly of use for debugging. Phones must be sorted and uniq
+/// on input.
+/// @param phones [in] A vector of phone id's [must be sorted and uniq].
+/// @param ensure_all_covered [in] boolean argument; if true, GenRandContextDependency
+/// generates a context-dependency object that "works" for all phones [no gaps].
+/// @param num_pdf_classes [out] outputs a vector indexed by phone, of the number
+/// of pdf classes (e.g. states) for that phone.
+/// @return Returns the a context dependency object.
+ContextDependency *GenRandContextDependency(const std::vector<int32> &phones,
+ bool ensure_all_covered,
+ std::vector<int32> *num_pdf_classes);
+
+/// GenRandContextDependencyLarge is like GenRandContextDependency but generates a larger tree
+/// with specified N and P for use in "one-time" larger-scale tests.
+ContextDependency *GenRandContextDependencyLarge(const std::vector<int32> &phones,
+ int N, int P,
+ bool ensure_all_covered,
+ std::vector<int32> *num_pdf_classes);
+
+// MonophoneContextDependency() returns a new ContextDependency object that
+// corresponds to a monophone system.
+// The map phone2num_pdf_classes maps from the phone id to the number of
+// pdf-classes we have for that phone (e.g. 3, so the pdf-classes would be
+// 0, 1, 2).
+
+ContextDependency*
+MonophoneContextDependency(const std::vector<int32> phones,
+ const std::vector<int32> phone2num_pdf_classes);
+
+// MonophoneContextDependencyShared is as MonophoneContextDependency but lets
+// you define classes of phones which share pdfs (e.g. different stress-markers of a single
+// phone.) Each element of phone_classes is a set of phones that are in that class.
+ContextDependency*
+MonophoneContextDependencyShared(const std::vector<std::vector<int32> > phone_classes,
+ const std::vector<int32> phone2num_pdf_classes);
+
+
+// Important note:
+// Statistics for training decision trees will be of type:
+// std::vector<std::pair<EventType, Clusterable*> >
+// We don't make this a typedef as it doesn't add clarity.
+// they will be sorted and unique on the EventType member, which
+// itself is sorted and unique on the name (see event-map.h).
+
+// See build-tree.h for functions relating to actually building the decision trees.
+
+
+
+
+} // namespace Kaldi
+
+
+#endif
diff --git a/kaldi_io/src/kaldi/tree/event-map.h b/kaldi_io/src/kaldi/tree/event-map.h
new file mode 100644
index 0000000..07fcc2b
--- /dev/null
+++ b/kaldi_io/src/kaldi/tree/event-map.h
@@ -0,0 +1,365 @@
+// tree/event-map.h
+
+// Copyright 2009-2011 Microsoft Corporation; Haihua Xu
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_TREE_EVENT_MAP_H_
+#define KALDI_TREE_EVENT_MAP_H_
+
+#include <vector>
+#include <map>
+#include <algorithm>
+#include "base/kaldi-common.h"
+#include "util/stl-utils.h"
+#include "util/const-integer-set.h"
+
+namespace kaldi {
+
+/// \defgroup event_map_group Event maps
+/// \ingroup tree_group
+/// See \ref tree_internals for overview, and specifically \ref treei_event_map.
+
+
+// Note RE negative values: some of this code will not work if things of type
+// EventValueType are negative. In particular, TableEventMap can't be used if
+// things of EventValueType are negative, and additionally TableEventMap won't
+// be efficient if things of EventValueType take on extremely large values. The
+// EventKeyType can be negative though.
+
+/// Things of type EventKeyType can take any value. The code does not assume they are contiguous.
+/// So values like -1, 1000000 and the like are acceptable.
+typedef int32 EventKeyType;
+
+/// Given current code, things of type EventValueType should generally be nonnegative and in a
+/// reasonably small range (e.g. not one million), as we sometimes construct vectors of the size:
+/// [largest value we saw for this key]. This deficiency may be fixed in future [would require
+/// modifying TableEventMap]
+typedef int32 EventValueType;
+
+/// As far as the event-map code itself is concerned, things of type EventAnswerType may take
+/// any value except kNoAnswer (== -1). However, some specific uses of EventMap (e.g. in
+/// build-tree-utils.h) assume these quantities are nonnegative.
+typedef int32 EventAnswerType;
+
+typedef std::vector<std::pair<EventKeyType, EventValueType> > EventType;
+// It is required to be sorted and have unique keys-- i.e. functions assume this when called
+// with this type.
+
+inline std::pair<EventKeyType, EventValueType> MakeEventPair (EventKeyType k, EventValueType v) {
+ return std::pair<EventKeyType, EventValueType>(k, v);
+}
+
+void WriteEventType(std::ostream &os, bool binary, const EventType &vec);
+void ReadEventType(std::istream &is, bool binary, EventType *vec);
+
+std::string EventTypeToString(const EventType &evec); // so we can print events out in error messages.
+
+struct EventMapVectorHash { // Hashing object for EventMapVector. Works for both pointers and references.
+ // Not used in event-map.{h, cc}
+ size_t operator () (const EventType &vec);
+ size_t operator () (const EventType *ptr) { return (*this)(*ptr); }
+};
+struct EventMapVectorEqual { // Equality object for EventType pointers-- test equality of underlying vector.
+ // Not used in event-map.{h, cc}
+ size_t operator () (const EventType *p1, const EventType *p2) { return (*p1 == *p2); }
+};
+
+
+/// A class that is capable of representing a generic mapping from
+/// EventType (which is a vector of (key, value) pairs) to
+/// EventAnswerType which is just an integer. See \ref tree_internals
+/// for overview.
+class EventMap {
+ public:
+ static void Check(const EventType &event); // will crash if not sorted and unique on key.
+ static bool Lookup(const EventType &event, EventKeyType key, EventValueType *ans);
+
+ // Maps events to the answer type. input must be sorted.
+ virtual bool Map(const EventType &event, EventAnswerType *ans) const = 0;
+
+ // MultiMap maps a partially specified set of events to the set of answers it might
+ // map to. It appends these to "ans". "ans" is
+ // **not guaranteed unique at output** if the
+ // tree contains duplicate answers at leaves -- you should sort & uniq afterwards.
+ // e.g.: SortAndUniq(ans).
+ virtual void MultiMap(const EventType &event, std::vector<EventAnswerType> *ans) const = 0;
+
+ // GetChildren() returns the EventMaps that are immediate children of this
+ // EventMap (if they exist), by putting them in *out. Useful for
+ // determining the structure of the event map.
+ virtual void GetChildren(std::vector<EventMap*> *out) const = 0;
+
+ // This Copy() does a deep copy of the event map.
+ // If new_leaves is nonempty when it reaches a leaf with value l s.t. new_leaves[l] != NULL,
+ // it replaces it with a copy of that EventMap. This makes it possible to extend and modify
+ // It's the way we do splits of trees, and clustering of trees. Think about this carefully, because
+ // the EventMap structure does not support modification of an existing tree. Do not be tempted
+ // to do this differently, because other kinds of mechanisms would get very messy and unextensible.
+ // Copy() is the only mechanism to modify a tree. It's similar to a kind of function composition.
+ // Copy() does not take ownership of the pointers in new_leaves (it uses the Copy() function of those
+ // EventMaps).
+ virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const = 0;
+
+ EventMap *Copy() const { std::vector<EventMap*> new_leaves; return Copy(new_leaves); }
+
+ // The function MapValues() is intended to be used to map phone-sets between
+ // different integer representations. For all the keys in the set
+ // "keys_to_map", it will map the corresponding values using the map
+ // "value_map". Note: these values are the values in the key->value pairs of
+ // the EventMap, which really correspond to phones in the usual case; they are
+ // not the "answers" of the EventMap which correspond to clustered states. In
+ // case multiple values are mapped to the same value, it will try to deal with
+ // it gracefully where it can, but will crash if, for example, this would
+ // cause problems with the TableEventMap. It will also crash if any values
+ // used for keys in "keys_to_map" are not mapped by "value_map". This
+ // function is not currently used.
+ virtual EventMap *MapValues(
+ const unordered_set<EventKeyType> &keys_to_map,
+ const unordered_map<EventValueType,EventValueType> &value_map) const = 0;
+
+ // The function Prune() is like Copy(), except it removes parts of the tree
+ // that return only -1 (it will return NULL if this EventMap returns only -1).
+ // This is a mechanism to remove parts of the tree-- you would first use the
+ // Copy() function with a vector of EventMap*, and for the parts you don't
+ // want, you'd put a ConstantEventMap with -1; you'd then call
+ // Prune() on the result. This function is not currently used.
+ virtual EventMap *Prune() const = 0;
+
+ virtual EventAnswerType MaxResult() const { // child classes may override this for efficiency; here is basic version.
+ // returns -1 if nothing found.
+ std::vector<EventAnswerType> tmp; EventType empty_event;
+ MultiMap(empty_event, &tmp);
+ if (tmp.empty()) {
+ KALDI_WARN << "EventMap::MaxResult(), empty result";
+ return std::numeric_limits<EventAnswerType>::min();
+ }
+ else { return * std::max_element(tmp.begin(), tmp.end()); }
+ }
+
+ /// Write to stream.
+ virtual void Write(std::ostream &os, bool binary) = 0;
+
+ virtual ~EventMap() {}
+
+ /// a Write function that takes care of NULL pointers.
+ static void Write(std::ostream &os, bool binary, EventMap *emap);
+ /// a Read function that reads an arbitrary EventMap; also
+ /// works for NULL pointers.
+ static EventMap *Read(std::istream &is, bool binary);
+};
+
+
+class ConstantEventMap: public EventMap {
+ public:
+ virtual bool Map(const EventType &event, EventAnswerType *ans) const {
+ *ans = answer_;
+ return true;
+ }
+
+ virtual void MultiMap(const EventType &,
+ std::vector<EventAnswerType> *ans) const {
+ ans->push_back(answer_);
+ }
+
+ virtual void GetChildren(std::vector<EventMap*> *out) const { out->clear(); }
+
+ virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const {
+ if (answer_ < 0 || answer_ >= (EventAnswerType)new_leaves.size() ||
+ new_leaves[answer_] == NULL)
+ return new ConstantEventMap(answer_);
+ else return new_leaves[answer_]->Copy();
+ }
+
+ virtual EventMap *MapValues(
+ const unordered_set<EventKeyType> &keys_to_map,
+ const unordered_map<EventValueType,EventValueType> &value_map) const {
+ return new ConstantEventMap(answer_);
+ }
+
+ virtual EventMap *Prune() const {
+ return (answer_ == -1 ? NULL : new ConstantEventMap(answer_));
+ }
+
+ explicit ConstantEventMap(EventAnswerType answer): answer_(answer) { }
+
+ virtual void Write(std::ostream &os, bool binary);
+ static ConstantEventMap *Read(std::istream &is, bool binary);
+ private:
+ EventAnswerType answer_;
+ KALDI_DISALLOW_COPY_AND_ASSIGN(ConstantEventMap);
+};
+
+class TableEventMap: public EventMap {
+ public:
+
+ virtual bool Map(const EventType &event, EventAnswerType *ans) const {
+ EventValueType tmp; *ans = -1; // means no answer
+ if (Lookup(event, key_, &tmp) && tmp >= 0
+ && tmp < (EventValueType)table_.size() && table_[tmp] != NULL) {
+ return table_[tmp]->Map(event, ans);
+ }
+ return false;
+ }
+
+ virtual void GetChildren(std::vector<EventMap*> *out) const {
+ out->clear();
+ for (size_t i = 0; i<table_.size(); i++)
+ if (table_[i] != NULL) out->push_back(table_[i]);
+ }
+
+ virtual void MultiMap(const EventType &event, std::vector<EventAnswerType> *ans) const {
+ EventValueType tmp;
+ if (Lookup(event, key_, &tmp)) {
+ if (tmp >= 0 && tmp < (EventValueType)table_.size() && table_[tmp] != NULL)
+ return table_[tmp]->MultiMap(event, ans);
+ // else no answers.
+ } else { // all answers are possible if no such key.
+ for (size_t i = 0;i < table_.size();i++)
+ if (table_[i] != NULL) table_[i]->MultiMap(event, ans); // append.
+ }
+ }
+
+ virtual EventMap *Prune() const;
+
+ virtual EventMap *MapValues(
+ const unordered_set<EventKeyType> &keys_to_map,
+ const unordered_map<EventValueType,EventValueType> &value_map) const;
+
+ /// Takes ownership of pointers.
+ explicit TableEventMap(EventKeyType key, const std::vector<EventMap*> &table): key_(key), table_(table) {}
+ /// Takes ownership of pointers.
+ explicit TableEventMap(EventKeyType key, const std::map<EventValueType, EventMap*> &map_in);
+ /// This initializer creates a ConstantEventMap for each value in the map.
+ explicit TableEventMap(EventKeyType key, const std::map<EventValueType, EventAnswerType> &map_in);
+
+ virtual void Write(std::ostream &os, bool binary);
+ static TableEventMap *Read(std::istream &is, bool binary);
+
+ virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const {
+ std::vector<EventMap*> new_table_(table_.size(), NULL);
+ for (size_t i = 0;i<table_.size();i++) if (table_[i]) new_table_[i]=table_[i]->Copy(new_leaves);
+ return new TableEventMap(key_, new_table_);
+ }
+ virtual ~TableEventMap() {
+ DeletePointers(&table_);
+ }
+ private:
+ EventKeyType key_;
+ std::vector<EventMap*> table_;
+ KALDI_DISALLOW_COPY_AND_ASSIGN(TableEventMap);
+};
+
+
+
+
+class SplitEventMap: public EventMap { // A decision tree [non-leaf] node.
+ public:
+
+ virtual bool Map(const EventType &event, EventAnswerType *ans) const {
+ EventValueType value;
+ if (Lookup(event, key_, &value)) {
+ // if (std::binary_search(yes_set_.begin(), yes_set_.end(), value)) {
+ if (yes_set_.count(value)) {
+ return yes_->Map(event, ans);
+ }
+ return no_->Map(event, ans);
+ }
+ return false;
+ }
+
+ virtual void MultiMap(const EventType &event, std::vector<EventAnswerType> *ans) const {
+ EventValueType tmp;
+ if (Lookup(event, key_, &tmp)) {
+ if (std::binary_search(yes_set_.begin(), yes_set_.end(), tmp))
+ yes_->MultiMap(event, ans);
+ else
+ no_->MultiMap(event, ans);
+ } else { // both yes and no contribute.
+ yes_->MultiMap(event, ans);
+ no_->MultiMap(event, ans);
+ }
+ }
+
+ virtual void GetChildren(std::vector<EventMap*> *out) const {
+ out->clear();
+ out->push_back(yes_);
+ out->push_back(no_);
+ }
+
+ virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const {
+ return new SplitEventMap(key_, yes_set_, yes_->Copy(new_leaves), no_->Copy(new_leaves));
+ }
+
+ virtual void Write(std::ostream &os, bool binary);
+ static SplitEventMap *Read(std::istream &is, bool binary);
+
+ virtual EventMap *Prune() const;
+
+ virtual EventMap *MapValues(
+ const unordered_set<EventKeyType> &keys_to_map,
+ const unordered_map<EventValueType,EventValueType> &value_map) const;
+
+ virtual ~SplitEventMap() { Destroy(); }
+
+ /// This constructor takes ownership of the "yes" and "no" arguments.
+ SplitEventMap(EventKeyType key, const std::vector<EventValueType> &yes_set,
+ EventMap *yes, EventMap *no): key_(key), yes_set_(yes_set), yes_(yes), no_(no) {
+ KALDI_PARANOID_ASSERT(IsSorted(yes_set));
+ KALDI_ASSERT(yes_ != NULL && no_ != NULL);
+ }
+
+
+ private:
+ /// This constructor used in the Copy() function.
+ SplitEventMap(EventKeyType key, const ConstIntegerSet<EventValueType> &yes_set,
+ EventMap *yes, EventMap *no): key_(key), yes_set_(yes_set), yes_(yes), no_(no) {
+ KALDI_ASSERT(yes_ != NULL && no_ != NULL);
+ }
+ void Destroy() {
+ delete yes_; delete no_;
+ }
+ EventKeyType key_;
+ // std::vector<EventValueType> yes_set_;
+ ConstIntegerSet<EventValueType> yes_set_; // more efficient Map function.
+ EventMap *yes_; // owned here.
+ EventMap *no_; // owned here.
+ SplitEventMap &operator = (const SplitEventMap &other); // Disallow.
+};
+
+/**
+ This function gets the tree structure of the EventMap "map" in a convenient form.
+ If "map" corresponds to a tree structure (not necessarily binary) with leaves
+ uniquely numbered from 0 to num_leaves-1, then the function will return true,
+ output "num_leaves", and set "parent" to a vector of size equal to the number of
+ nodes in the tree (nonleaf and leaf), where each index corresponds to a node
+ and the leaf indices correspond to the values returned by the EventMap from
+ that leaf; for an index i, parent[i] equals the parent of that node in the tree
+ structure, where parent[i] > i, except for the last (root) node where parent[i] == i.
+ If the EventMap does not have this structure (e.g. if multiple different leaf nodes share
+ the same number), then it will return false.
+*/
+
+bool GetTreeStructure(const EventMap &map,
+ int32 *num_leaves,
+ std::vector<int32> *parents);
+
+
+/// @} end "addtogroup event_map_group"
+
+}
+
+#endif
diff --git a/kaldi_io/src/kaldi/tree/tree-renderer.h b/kaldi_io/src/kaldi/tree/tree-renderer.h
new file mode 100644
index 0000000..5e0b0d8
--- /dev/null
+++ b/kaldi_io/src/kaldi/tree/tree-renderer.h
@@ -0,0 +1,84 @@
+// tree/tree-renderer.h
+
+// Copyright 2012 Vassil Panayotov
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_TREE_TREE_RENDERER_H_
+#define KALDI_TREE_TREE_RENDERER_H_
+
+#include "base/kaldi-common.h"
+#include "tree/event-map.h"
+#include "util/common-utils.h"
+#include "hmm/transition-model.h"
+#include "fst/fstlib.h"
+
+namespace kaldi {
+
+// Parses a decision tree file and outputs its description in GraphViz format
+class TreeRenderer {
+ public:
+ const static int32 kEdgeWidth; // normal width of the edges and state contours
+ const static int32 kEdgeWidthQuery; // edge and state width when in query
+ const static std::string kEdgeColor; // normal color for states and edges
+ const static std::string kEdgeColorQuery; // edge and state color when in query
+
+ TreeRenderer(std::istream &is, bool binary, std::ostream &os,
+ fst::SymbolTable &phone_syms, bool use_tooltips)
+ : phone_syms_(phone_syms), is_(is), out_(os), binary_(binary),
+ N_(-1), use_tooltips_(use_tooltips), next_id_(0) {}
+
+ // Renders the tree and if the "query" parameter is not NULL
+ // a distinctly colored trace corresponding to the event.
+ void Render(const EventType *query);
+
+ private:
+ // Looks-up the next token from the stream and invokes
+ // the appropriate render method to visualize it
+ void RenderSubTree(const EventType *query, int32 id);
+
+ // Renders a leaf node (constant event map)
+ void RenderConstant(const EventType *query, int32 id);
+
+ // Renders a split event map node and the edges to the nodes
+ // representing YES and NO sets
+ void RenderSplit(const EventType *query, int32 id);
+
+ // Renders a table event map node and the edges to its (non-null) children
+ void RenderTable(const EventType *query, int32 id);
+
+ // Makes a comma-separated string from the elements of a set of identifiers
+ // If the identifiers represent phones, their symbolic representations are used
+ std::string MakeEdgeLabel(const EventKeyType &key,
+ const ConstIntegerSet<EventValueType> &intset);
+
+ // Writes the GraphViz representation of a non-leaf node to the out stream
+ // A question about a phone from the context window or about pdf-class
+ // is used as a label.
+ void RenderNonLeaf(int32 id, const EventKeyType &key, bool in_query);
+
+ fst::SymbolTable &phone_syms_; // phone symbols to be used as edge labels
+ std::istream &is_; // the stream from which the tree is read
+ std::ostream &out_; // the GraphViz representation is written to this stream
+ bool binary_; // is the input stream binary?
+ int32 N_, P_; // context-width and central position
+ bool use_tooltips_; // use tooltips(useful in e.g. SVG) instead of labels
+ int32 next_id_; // the first unused GraphViz node ID
+};
+
+} // namespace kaldi
+
+#endif // KALDI_TREE_TREE_RENDERER_H_