// hmm/transition-model.h // Copyright 2009-2012 Microsoft Corporation // Johns Hopkins University (author: Guoguo Chen) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_HMM_TRANSITION_MODEL_H_ #define KALDI_HMM_TRANSITION_MODEL_H_ #include "base/kaldi-common.h" #include "tree/context-dep.h" #include "util/const-integer-set.h" #include "fst/fst-decl.h" // forward declarations. #include "hmm/hmm-topology.h" #include "itf/options-itf.h" namespace kaldi { /// \addtogroup hmm_group /// @{ // The class TransitionModel is a repository for the transition probabilities. // It also handles certain integer mappings. // The basic model is as follows. Each phone has a HMM topology defined in // hmm-topology.h. Each HMM-state of each of these phones has a number of // transitions (and final-probs) out of it. Each HMM-state defined in the // HmmTopology class has an associated "pdf_class". This gets replaced with // an actual pdf-id via the tree. The transition model associates the // transition probs with the (phone, HMM-state, pdf-id). We associate with // each such triple a transition-state. Each // transition-state has a number of associated probabilities to estimate; // this depends on the number of transitions/final-probs in the topology for // that (phone, HMM-state). Each probability has an associated transition-index. 
// We associate with each (transition-state, transition-index) a unique transition-id.
// Each individual probability estimated by the transition-model is associated with a
// transition-id.
//
// List of the various types of quantity referred to here and what they mean:
//            phone:  a phone index (1, 2, 3 ...)
//        HMM-state:  a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h)
//           pdf-id:  a number output by the Compute function of ContextDependency (it
//                    indexes pdfs).  Zero-based.
// transition-state:  the states for which we estimate transition probabilities for transitions
//                    out of them.  In some topologies, will map one-to-one with pdf-ids.
//                    One-based, since it appears on FSTs.
// transition-index:  identifier of a transition (or final-prob) in the HMM.  Indexes the
//                    "transitions" vector in HmmTopology::HmmState.  [if it is out of range,
//                    equal to transitions.size(), it refers to the final-prob.]
//                    Zero-based.
//    transition-id:  identifier of a unique parameter of the TransitionModel.
//                    Associated with a (transition-state, transition-index) pair.
//                    One-based, since it appears on FSTs.
//
// List of the possible mappings TransitionModel can do:
//   (phone, HMM-state, pdf-id) -> transition-state
//   (transition-state, transition-index) -> transition-id
// Reverse mappings:
//   transition-id -> transition-state
//   transition-id -> transition-index
//   transition-state -> phone
//   transition-state -> HMM-state
//   transition-state -> pdf-id
//
// The main things the TransitionModel object can do are:
//   Get initialized (need ContextDependency and HmmTopology objects).
//   Read/write.
//   Update [given a vector of counts indexed by transition-id].
//   Do the various integer mappings mentioned above.
//   Get the probability (or log-probability) associated with a particular transition-id.

// Note: this was previously called TransitionUpdateConfig.
struct MleTransitionUpdateConfig { BaseFloat floor; BaseFloat mincount; bool share_for_pdfs; // If true, share all transition parameters that have the same pdf. MleTransitionUpdateConfig(BaseFloat floor = 0.01, BaseFloat mincount = 5.0, bool share_for_pdfs = false): floor(floor), mincount(mincount), share_for_pdfs(share_for_pdfs) {} void Register (OptionsItf *po) { po->Register("transition-floor", &floor, "Floor for transition probabilities"); po->Register("transition-min-count", &mincount, "Minimum count required to update transitions from a state"); po->Register("share-for-pdfs", &share_for_pdfs, "If true, share all transition parameters where the states " "have the same pdf."); } }; struct MapTransitionUpdateConfig { BaseFloat tau; bool share_for_pdfs; // If true, share all transition parameters that have the same pdf. MapTransitionUpdateConfig(): tau(5.0), share_for_pdfs(false) { } void Register (OptionsItf *po) { po->Register("transition-tau", &tau, "Tau value for MAP estimation of transition " "probabilities."); po->Register("share-for-pdfs", &share_for_pdfs, "If true, share all transition parameters where the states " "have the same pdf."); } }; class TransitionModel { public: /// Initialize the object [e.g. at the start of training]. /// The class keeps a copy of the HmmTopology object, but not /// the ContextDependency object. TransitionModel(const ContextDependency &ctx_dep, const HmmTopology &hmm_topo); /// Constructor that takes no arguments: typically used prior to calling Read. TransitionModel() { } void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. void Write(std::ostream &os, bool binary) const; /// return reference to HMM-topology object. 
const HmmTopology &GetTopo() const { return topo_; } /// \name Integer mapping functions /// @{ int32 TripleToTransitionState(int32 phone, int32 hmm_state, int32 pdf) const; int32 PairToTransitionId(int32 trans_state, int32 trans_index) const; int32 TransitionIdToTransitionState(int32 trans_id) const; int32 TransitionIdToTransitionIndex(int32 trans_id) const; int32 TransitionStateToPhone(int32 trans_state) const; int32 TransitionStateToHmmState(int32 trans_state) const; int32 TransitionStateToPdf(int32 trans_state) const; int32 SelfLoopOf(int32 trans_state) const; // returns the self-loop transition-id, or zero if // this state doesn't have a self-loop. inline int32 TransitionIdToPdf(int32 trans_id) const; int32 TransitionIdToPhone(int32 trans_id) const; int32 TransitionIdToPdfClass(int32 trans_id) const; int32 TransitionIdToHmmState(int32 trans_id) const; /// @} bool IsFinal(int32 trans_id) const; // returns true if this trans_id goes to the final state // (which is bound to be nonemitting). bool IsSelfLoop(int32 trans_id) const; // return true if this trans_id corresponds to a self-loop. /// Returns the total number of transition-ids (note, these are one-based). inline int32 NumTransitionIds() const { return id2state_.size()-1; } /// Returns the number of transition-indices for a particular transition-state. /// Note: "Indices" is the plural of "index". Index is not the same as "id", /// here. A transition-index is a zero-based offset into the transitions /// out of a particular transition state. int32 NumTransitionIndices(int32 trans_state) const; /// Returns the total number of transition-states (note, these are one-based). int32 NumTransitionStates() const { return triples_.size(); } // NumPdfs() actually returns the highest-numbered pdf we ever saw, plus one. 
// In normal cases this should equal the number of pdfs in the system, but if you // initialized this object with fewer than all the phones, and it happens that // an unseen phone has the highest-numbered pdf, this might be different. int32 NumPdfs() const { return num_pdfs_; } // This loops over the triples and finds the highest phone index present. If // the FST symbol table for the phones is created in the expected way, i.e.: // starting from 1 ( is 0) and numbered contiguously till the last phone, // this will be the total number of phones. int32 NumPhones() const; /// Returns a sorted, unique list of phones. const std::vector &GetPhones() const { return topo_.GetPhones(); } // Transition-parameter-getting functions: BaseFloat GetTransitionProb(int32 trans_id) const; BaseFloat GetTransitionLogProb(int32 trans_id) const; // The following functions are more specialized functions for getting // transition probabilities, that are provided for convenience. /// Returns the log-probability of a particular non-self-loop transition /// after subtracting the probability mass of the self-loop and renormalizing; /// will crash if called on a self-loop. Specifically: /// for non-self-loops it returns the log of that prob divided by (1 minus /// self-loop-prob-for-that-state). BaseFloat GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const; /// Returns the log-prob of the non-self-loop probability /// mass for this transition state. (you can get the self-loop prob, if a self-loop /// exists, by calling GetTransitionLogProb(SelfLoopOf(trans_state)). BaseFloat GetNonSelfLoopLogProb(int32 trans_state) const; /// Does Maximum Likelihood estimation. The stats are counts/weights, indexed /// by transition-id. This was previously called Update(). void MleUpdate(const Vector &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out); /// Does Maximum A Posteriori (MAP) estimation. The stats are counts/weights, /// indexed by transition-id. 
void MapUpdate(const Vector &stats, const MapTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out); /// Print will print the transition model in a human-readable way, for purposes of human /// inspection. The "occs" are optional (they are indexed by pdf-id). void Print(std::ostream &os, const std::vector &phone_names, const Vector *occs = NULL); void InitStats(Vector *stats) const { stats->Resize(NumTransitionIds()+1); } void Accumulate(BaseFloat prob, int32 trans_id, Vector *stats) const { KALDI_ASSERT(trans_id <= NumTransitionIds()); (*stats)(trans_id) += prob; // This is trivial and doesn't require class members, but leaves us more open // to design changes than doing it manually. } /// returns true if all the integer class members are identical (but does not /// compare the transition probabilities. bool Compatible(const TransitionModel &other) const; private: void MleUpdateShared(const Vector &stats, const MleTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out); void MapUpdateShared(const Vector &stats, const MapTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out); void ComputeTriples(const ContextDependency &ctx_dep); // called from constructor. initializes triples_. void ComputeDerived(); // called from constructor and Read function: computes state2id_ and id2state_. void ComputeDerivedOfProbs(); // computes quantities derived from log-probs (currently just // non_self_loop_log_probs_; called whenever log-probs change. void InitializeProbs(); // called from constructor. 
void Check() const; struct Triple { int32 phone; int32 hmm_state; int32 pdf; Triple() { } Triple(int32 phone, int32 hmm_state, int32 pdf): phone(phone), hmm_state(hmm_state), pdf(pdf) { } bool operator < (const Triple &other) const { if (phone < other.phone) return true; else if (phone > other.phone) return false; else if (hmm_state < other.hmm_state) return true; else if (hmm_state > other.hmm_state) return false; else return pdf < other.pdf; } bool operator == (const Triple &other) const { return (phone == other.phone && hmm_state == other.hmm_state && pdf == other.pdf); } }; HmmTopology topo_; /// Triples indexed by transition state minus one; /// the triples are in sorted order which allows us to do the reverse mapping from /// triple to transition state std::vector triples_; /// Gives the first transition_id of each transition-state; indexed by /// the transition-state. Array indexed 1..num-transition-states+1 (the last one /// is needed so we can know the num-transitions of the last transition-state. std::vector state2id_; /// For each transition-id, the corresponding transition /// state (indexed by transition-id). std::vector id2state_; /// For each transition-id, the corresponding log-prob. Indexed by transition-id. Vector log_probs_; /// For each transition-state, the log of (1 - self-loop-prob). Indexed by /// transition-state. Vector non_self_loop_log_probs_; /// This is actually one plus the highest-numbered pdf we ever got back from the /// tree (but the tree numbers pdfs contiguously from zero so this is the number /// of pdfs). int32 num_pdfs_; DISALLOW_COPY_AND_ASSIGN(TransitionModel); }; inline int32 TransitionModel::TransitionIdToPdf(int32 trans_id) const { // If a lot of time is spent here we may create an extra array // to handle this. 
KALDI_ASSERT(static_cast(trans_id) < id2state_.size() && "Likely graph/model mismatch (graph built from wrong model?)"); int32 trans_state = id2state_[trans_id]; return triples_[trans_state-1].pdf; } /// Works out which pdfs might correspond to the given phones. Will return true /// if these pdfs correspond *just* to these phones, false if these pdfs are also /// used by other phones. /// @param trans_model [in] Transition-model used to work out this information /// @param phones [in] A sorted, uniq vector that represents a set of phones /// @param pdfs [out] Will be set to a sorted, uniq list of pdf-ids that correspond /// to one of this set of phones. /// @return Returns true if all of the pdfs output to "pdfs" correspond to phones from /// just this set (false if they may be shared with phones outside this set). bool GetPdfsForPhones(const TransitionModel &trans_model, const std::vector &phones, std::vector *pdfs); /// Works out which phones might correspond to the given pdfs. Similar to the /// above GetPdfsForPhones(, ,) bool GetPhonesForPdfs(const TransitionModel &trans_model, const std::vector &pdfs, std::vector *phones); /// @} } // end namespace kaldi #endif