summaryrefslogtreecommitdiff
path: root/kaldi_io/src/kaldi/hmm/transition-model.h
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/src/kaldi/hmm/transition-model.h')
-rw-r--r--kaldi_io/src/kaldi/hmm/transition-model.h345
1 files changed, 345 insertions, 0 deletions
diff --git a/kaldi_io/src/kaldi/hmm/transition-model.h b/kaldi_io/src/kaldi/hmm/transition-model.h
new file mode 100644
index 0000000..ccc4f11
--- /dev/null
+++ b/kaldi_io/src/kaldi/hmm/transition-model.h
@@ -0,0 +1,345 @@
+// hmm/transition-model.h
+
+// Copyright 2009-2012 Microsoft Corporation
+// Johns Hopkins University (author: Guoguo Chen)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_HMM_TRANSITION_MODEL_H_
+#define KALDI_HMM_TRANSITION_MODEL_H_
+
+#include "base/kaldi-common.h"
+#include "tree/context-dep.h"
+#include "util/const-integer-set.h"
+#include "fst/fst-decl.h" // forward declarations.
+#include "hmm/hmm-topology.h"
+#include "itf/options-itf.h"
+
+namespace kaldi {
+
+/// \addtogroup hmm_group
+/// @{
+
+// The class TransitionModel is a repository for the transition probabilities.
+// It also handles certain integer mappings.
+// The basic model is as follows. Each phone has a HMM topology defined in
+// hmm-topology.h. Each HMM-state of each of these phones has a number of
+// transitions (and final-probs) out of it. Each HMM-state defined in the
+// HmmTopology class has an associated "pdf_class". This gets replaced with
+// an actual pdf-id via the tree. The transition model associates the
+// transition probs with the (phone, HMM-state, pdf-id). We associate with
+// each such triple a transition-state. Each
+// transition-state has a number of associated probabilities to estimate;
+// this depends on the number of transitions/final-probs in the topology for
+// that (phone, HMM-state). Each probability has an associated transition-index.
+// We associate with each (transition-state, transition-index) a unique transition-id.
+// Each individual probability estimated by the transition-model is associated with a
+// transition-id.
+//
+// List of the various types of quantity referred to here and what they mean:
+// phone: a phone index (1, 2, 3 ...)
+// HMM-state: a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h)
+// pdf-id: a number output by the Compute function of ContextDependency (it
+// indexes pdf's). Zero-based.
+// transition-state: the states for which we estimate transition probabilities for transitions
+// out of them. In some topologies, will map one-to-one with pdf-ids.
+// One-based, since it appears on FSTs.
+// transition-index: identifier of a transition (or final-prob) in the HMM. Indexes the
+// "transitions" vector in HmmTopology::HmmState. [if it is out of range,
+// equal to transitions.size(), it refers to the final-prob.]
+// Zero-based.
+// transition-id: identifier of a unique parameter of the TransitionModel.
+// Associated with a (transition-state, transition-index) pair.
+// One-based, since it appears on FSTs.
+//
+// List of the possible mappings TransitionModel can do:
+// (phone, HMM-state, pdf-id) -> transition-state
+// (transition-state, transition-index) -> transition-id
+// Reverse mappings:
+// transition-id -> transition-state
+// transition-id -> transition-index
+// transition-state -> phone
+// transition-state -> HMM-state
+// transition-state -> pdf-id
+//
+// The main things the TransitionModel object can do are:
+// Get initialized (need ContextDependency and HmmTopology objects).
+// Read/write.
+// Update [given a vector of counts indexed by transition-id].
+// Do the various integer mappings mentioned above.
+// Get the probability (or log-probability) associated with a particular transition-id.
+
+
+// Options for TransitionModel::MleUpdate(). Note: this was previously called TransitionUpdateConfig.
+struct MleTransitionUpdateConfig {
+ BaseFloat floor; // Floor for transition probabilities (registered as "transition-floor"); default 0.01.
+ BaseFloat mincount; // Minimum count required to update transitions from a state ("transition-min-count"); default 5.0.
+ bool share_for_pdfs; // If true, share all transition parameters that have the same pdf.
+ MleTransitionUpdateConfig(BaseFloat floor = 0.01,
+ BaseFloat mincount = 5.0,
+ bool share_for_pdfs = false):
+ floor(floor), mincount(mincount), share_for_pdfs(share_for_pdfs) {}
+ // Registers the three fields with the options object "po" (by pointer) under the names given below.
+ void Register (OptionsItf *po) {
+ po->Register("transition-floor", &floor,
+ "Floor for transition probabilities");
+ po->Register("transition-min-count", &mincount,
+ "Minimum count required to update transitions from a state");
+ po->Register("share-for-pdfs", &share_for_pdfs,
+ "If true, share all transition parameters where the states "
+ "have the same pdf.");
+ }
+};
+
+struct MapTransitionUpdateConfig { // Options for TransitionModel::MapUpdate().
+ BaseFloat tau; // Tau value for MAP estimation of transition probabilities ("transition-tau"); default 5.0.
+ bool share_for_pdfs; // If true, share all transition parameters that have the same pdf.
+ MapTransitionUpdateConfig(): tau(5.0), share_for_pdfs(false) { }
+ // Registers both fields with the options object "po" (by pointer) under the names given below.
+ void Register (OptionsItf *po) {
+ po->Register("transition-tau", &tau, "Tau value for MAP estimation of transition "
+ "probabilities.");
+ po->Register("share-for-pdfs", &share_for_pdfs,
+ "If true, share all transition parameters where the states "
+ "have the same pdf.");
+ }
+};
+
+class TransitionModel {
+ // Stores the transition log-probs and does the integer mappings between triples, transition-states and transition-ids.
+ public:
+ /// Initialize the object [e.g. at the start of training].
+ /// The class keeps a copy of the HmmTopology object, but not
+ /// the ContextDependency object.
+ TransitionModel(const ContextDependency &ctx_dep,
+ const HmmTopology &hmm_topo);
+
+ /// Constructor that takes no arguments: typically used prior to calling Read.
+ /// NOTE(review): this constructor does not initialize num_pdfs_; presumably Read() sets it -- confirm.
+ TransitionModel() { }
+
+ void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols.
+ void Write(std::ostream &os, bool binary) const;
+
+
+ /// return reference to HMM-topology object.
+ const HmmTopology &GetTopo() const { return topo_; }
+
+ /// \name Integer mapping functions
+ /// @{
+ // In these functions trans_state and trans_id are one-based, while trans_index, hmm_state and pdf are zero-based.
+ int32 TripleToTransitionState(int32 phone, int32 hmm_state, int32 pdf) const;
+ int32 PairToTransitionId(int32 trans_state, int32 trans_index) const;
+ int32 TransitionIdToTransitionState(int32 trans_id) const;
+ int32 TransitionIdToTransitionIndex(int32 trans_id) const;
+ int32 TransitionStateToPhone(int32 trans_state) const;
+ int32 TransitionStateToHmmState(int32 trans_state) const;
+ int32 TransitionStateToPdf(int32 trans_state) const;
+ int32 SelfLoopOf(int32 trans_state) const; // returns the self-loop transition-id, or zero if
+ // this state doesn't have a self-loop.
+
+ inline int32 TransitionIdToPdf(int32 trans_id) const;
+ int32 TransitionIdToPhone(int32 trans_id) const;
+ int32 TransitionIdToPdfClass(int32 trans_id) const;
+ int32 TransitionIdToHmmState(int32 trans_id) const;
+
+ /// @}
+
+ bool IsFinal(int32 trans_id) const; // returns true if this trans_id goes to the final state
+ // (which is bound to be nonemitting).
+ bool IsSelfLoop(int32 trans_id) const; // return true if this trans_id corresponds to a self-loop.
+
+ /// Returns the total number of transition-ids (note, these are one-based).
+ inline int32 NumTransitionIds() const { return id2state_.size()-1; }
+
+ /// Returns the number of transition-indices for a particular transition-state.
+ /// Note: "Indices" is the plural of "index". Index is not the same as "id",
+ /// here. A transition-index is a zero-based offset into the transitions
+ /// out of a particular transition state.
+ int32 NumTransitionIndices(int32 trans_state) const;
+
+ /// Returns the total number of transition-states (note, these are one-based).
+ int32 NumTransitionStates() const { return triples_.size(); }
+
+ // NumPdfs() actually returns the highest-numbered pdf we ever saw, plus one.
+ // In normal cases this should equal the number of pdfs in the system, but if you
+ // initialized this object with fewer than all the phones, and it happens that
+ // an unseen phone has the highest-numbered pdf, this might be different.
+ int32 NumPdfs() const { return num_pdfs_; }
+
+ // This loops over the triples and finds the highest phone index present. If
+ // the FST symbol table for the phones is created in the expected way, i.e.:
+ // starting from 1 (<eps> is 0) and numbered contiguously till the last phone,
+ // this will be the total number of phones.
+ int32 NumPhones() const;
+
+ /// Returns a sorted, unique list of phones.
+ const std::vector<int32> &GetPhones() const { return topo_.GetPhones(); }
+
+ // Transition-parameter-getting functions:
+ BaseFloat GetTransitionProb(int32 trans_id) const;
+ BaseFloat GetTransitionLogProb(int32 trans_id) const;
+
+ // The following functions are more specialized functions for getting
+ // transition probabilities, that are provided for convenience.
+
+ /// Returns the log-probability of a particular non-self-loop transition
+ /// after subtracting the probability mass of the self-loop and renormalizing;
+ /// will crash if called on a self-loop. Specifically:
+ /// for non-self-loops it returns the log of that prob divided by (1 minus
+ /// self-loop-prob-for-that-state).
+ BaseFloat GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const;
+
+ /// Returns the log-prob of the non-self-loop probability
+ /// mass for this transition state. (you can get the self-loop prob, if a self-loop
+ /// exists, by calling GetTransitionLogProb(SelfLoopOf(trans_state))).
+ BaseFloat GetNonSelfLoopLogProb(int32 trans_state) const;
+
+ /// Does Maximum Likelihood estimation. The stats are counts/weights, indexed
+ /// by transition-id. This was previously called Update().
+ void MleUpdate(const Vector<double> &stats,
+ const MleTransitionUpdateConfig &cfg,
+ BaseFloat *objf_impr_out,
+ BaseFloat *count_out);
+
+ /// Does Maximum A Posteriori (MAP) estimation. The stats are counts/weights,
+ /// indexed by transition-id.
+ void MapUpdate(const Vector<double> &stats,
+ const MapTransitionUpdateConfig &cfg,
+ BaseFloat *objf_impr_out,
+ BaseFloat *count_out);
+
+ /// Print will print the transition model in a human-readable way, for purposes of human
+ /// inspection. The "occs" are optional (they are indexed by pdf-id).
+ void Print(std::ostream &os,
+ const std::vector<std::string> &phone_names,
+ const Vector<double> *occs = NULL);
+
+ /// Resizes "stats" to NumTransitionIds()+1 elements, so it can be indexed directly by the one-based transition-id.
+ void InitStats(Vector<double> *stats) const { stats->Resize(NumTransitionIds()+1); }
+ /// Adds "prob" to the element of "stats" that is indexed by trans_id.
+ void Accumulate(BaseFloat prob, int32 trans_id, Vector<double> *stats) const {
+ KALDI_ASSERT(trans_id <= NumTransitionIds()); // only the upper bound of trans_id is checked here.
+ (*stats)(trans_id) += prob;
+ // This is trivial and doesn't require class members, but leaves us more open
+ // to design changes than doing it manually.
+ }
+
+ /// returns true if all the integer class members are identical (but does not
+ /// compare the transition probabilities).
+ bool Compatible(const TransitionModel &other) const;
+
+ private:
+ void MleUpdateShared(const Vector<double> &stats,
+ const MleTransitionUpdateConfig &cfg,
+ BaseFloat *objf_impr_out, BaseFloat *count_out);
+ void MapUpdateShared(const Vector<double> &stats,
+ const MapTransitionUpdateConfig &cfg,
+ BaseFloat *objf_impr_out, BaseFloat *count_out);
+ void ComputeTriples(const ContextDependency &ctx_dep); // called from constructor. initializes triples_.
+ void ComputeDerived(); // called from constructor and Read function: computes state2id_ and id2state_.
+ void ComputeDerivedOfProbs(); // computes quantities derived from log-probs (currently just
+ // non_self_loop_log_probs_); called whenever log-probs change.
+ void InitializeProbs(); // called from constructor.
+ void Check() const;
+ /// A (phone, HMM-state, pdf-id) triple; operator < orders by phone, then hmm_state, then pdf.
+ struct Triple {
+ int32 phone;
+ int32 hmm_state;
+ int32 pdf;
+ Triple() { } // leaves the three members uninitialized.
+ Triple(int32 phone, int32 hmm_state, int32 pdf):
+ phone(phone), hmm_state(hmm_state), pdf(pdf) { }
+ bool operator < (const Triple &other) const {
+ if (phone < other.phone) return true;
+ else if (phone > other.phone) return false;
+ else if (hmm_state < other.hmm_state) return true;
+ else if (hmm_state > other.hmm_state) return false;
+ else return pdf < other.pdf;
+ }
+ bool operator == (const Triple &other) const {
+ return (phone == other.phone && hmm_state == other.hmm_state
+ && pdf == other.pdf);
+ }
+ };
+
+ HmmTopology topo_; // the topology object returned by GetTopo().
+
+ /// Triples indexed by transition state minus one;
+ /// the triples are in sorted order which allows us to do the reverse mapping from
+ /// triple to transition state
+ std::vector<Triple> triples_;
+
+ /// Gives the first transition_id of each transition-state; indexed by
+ /// the transition-state. Array indexed 1..num-transition-states+1 (the last one
+ /// is needed so we can know the num-transitions of the last transition-state).
+ std::vector<int32> state2id_;
+
+ /// For each transition-id, the corresponding transition
+ /// state (indexed by transition-id).
+ std::vector<int32> id2state_;
+
+ /// For each transition-id, the corresponding log-prob. Indexed by transition-id.
+ Vector<BaseFloat> log_probs_;
+
+ /// For each transition-state, the log of (1 - self-loop-prob). Indexed by
+ /// transition-state.
+ Vector<BaseFloat> non_self_loop_log_probs_;
+
+ /// This is actually one plus the highest-numbered pdf we ever got back from the
+ /// tree (but the tree numbers pdfs contiguously from zero so this is the number
+ /// of pdfs).
+ int32 num_pdfs_;
+
+
+ DISALLOW_COPY_AND_ASSIGN(TransitionModel);
+
+};
+
+inline int32 TransitionModel::TransitionIdToPdf(int32 trans_id) const {
+ // Maps a transition-id to its pdf-id in two steps: id2state_ gives the transition-state, triples_ (indexed by state minus one) gives the pdf.
+ // If a lot of time is spent here we may create an extra array to handle this.
+ KALDI_ASSERT(static_cast<size_t>(trans_id) < id2state_.size() && // the cast to size_t makes a negative trans_id fail this bound check too.
+ "Likely graph/model mismatch (graph built from wrong model?)");
+ int32 trans_state = id2state_[trans_id]; // NOTE(review): trans_id == 0 passes the assert above although ids are one-based -- confirm what id2state_[0] holds.
+ return triples_[trans_state-1].pdf; // triples_ is indexed by transition-state minus one.
+}
+
+/// Works out which pdfs might correspond to the given phones. Will return true
+/// if these pdfs correspond *just* to these phones, false if these pdfs are also
+/// used by other phones.
+/// @param trans_model [in] Transition-model used to work out this information
+/// @param phones [in] A sorted, unique vector that represents a set of phones
+/// @param pdfs [out] Will be set to a sorted, unique list of pdf-ids that correspond
+/// to one of this set of phones.
+/// @return Returns true if all of the pdfs output to "pdfs" correspond to phones from
+/// just this set (false if they may be shared with phones outside this set).
+bool GetPdfsForPhones(const TransitionModel &trans_model,
+ const std::vector<int32> &phones,
+ std::vector<int32> *pdfs);
+
+/// Works out which phones might correspond to the given pdfs; the counterpart of GetPdfsForPhones() declared above.
+/// NOTE(review): inputs, the "phones" output and the return value presumably mirror GetPdfsForPhones() -- confirm against the definition.
+bool GetPhonesForPdfs(const TransitionModel &trans_model,
+ const std::vector<int32> &pdfs,
+ std::vector<int32> *phones);
+/// @}
+
+
+} // end namespace kaldi
+
+
+#endif