summaryrefslogtreecommitdiff
path: root/kaldi_io/src/kaldi/hmm/transition-model.h
blob: ccc4f11e6493f58f05e186f84188289fc7558a8d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
// hmm/transition-model.h

// Copyright 2009-2012  Microsoft Corporation
//                      Johns Hopkins University (author: Guoguo Chen)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_HMM_TRANSITION_MODEL_H_
#define KALDI_HMM_TRANSITION_MODEL_H_

#include "base/kaldi-common.h"
#include "tree/context-dep.h"
#include "util/const-integer-set.h"
#include "fst/fst-decl.h" // forward declarations.
#include "hmm/hmm-topology.h"
#include "itf/options-itf.h"

namespace kaldi {

/// \addtogroup hmm_group
/// @{

// The class TransitionModel is a repository for the transition probabilities.
// It also handles certain integer mappings.
// The basic model is as follows.  Each phone has a HMM topology defined in
// hmm-topology.h.  Each HMM-state of each of these phones has a number of
// transitions (and final-probs) out of it.  Each HMM-state defined in the
// HmmTopology class has an associated "pdf_class".  This gets replaced with
// an actual pdf-id via the tree.  The transition model associates the
// transition probs with the (phone, HMM-state, pdf-id).  We associate with
// each such triple a transition-state.  Each
// transition-state has a number of associated probabilities to estimate;
// this depends on the number of transitions/final-probs in the topology for
// that (phone, HMM-state).  Each probability has an associated transition-index.
// We associate with each (transition-state, transition-index) a unique transition-id.
// Each individual probability estimated by the transition-model is associated with a
// transition-id.
//
// List of the various types of quantity referred to here and what they mean:
//           phone:  a phone index (1, 2, 3 ...)
//       HMM-state:  a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h)
//          pdf-id:  a number output by the Compute function of ContextDependency (it
//                   indexes pdf's).  Zero-based.
// transition-state:  the states for which we estimate transition probabilities for transitions
//                    out of them.  In some topologies, will map one-to-one with pdf-ids.
//                    One-based, since it appears on FSTs.
// transition-index:  identifier of a transition (or final-prob) in the HMM.  Indexes the
//                    "transitions" vector in HmmTopology::HmmState.  [if it is out of range,
//                    equal to transitions.size(), it refers to the final-prob.]
//                    Zero-based.
//   transition-id:   identifier of a unique parameter of the TransitionModel.
//                    Associated with a (transition-state, transition-index) pair.
//                    One-based, since it appears on FSTs.
//
// List of the possible mappings TransitionModel can do:
//             (phone, HMM-state, pdf-id) -> transition-state
//   (transition-state, transition-index) -> transition-id
//  Reverse mappings:
//                        transition-id -> transition-state
//                        transition-id -> transition-index
//                     transition-state -> phone
//                     transition-state -> HMM-state
//                     transition-state -> pdf-id
//
// The main things the TransitionModel object can do are:
//    Get initialized (need ContextDependency and HmmTopology objects).
//    Read/write.
//    Update [given a vector of counts indexed by transition-id].
//    Do the various integer mappings mentioned above.
//    Get the probability (or log-probability) associated with a particular transition-id.


/// Options for the maximum-likelihood update of the transition
/// probabilities (the config type taken by TransitionModel::MleUpdate()).
/// Note: this was previously called TransitionUpdateConfig.
struct MleTransitionUpdateConfig {
  /// Floor for transition probabilities (option --transition-floor).
  BaseFloat floor;
  /// Minimum count required to update transitions from a state
  /// (option --transition-min-count).
  BaseFloat mincount;
  /// If true, share all transition parameters that have the same pdf
  /// (option --share-for-pdfs).
  bool share_for_pdfs;

  /// All arguments are optional; the defaults are the option defaults.
  MleTransitionUpdateConfig(BaseFloat floor = 0.01,
                            BaseFloat mincount = 5.0,
                            bool share_for_pdfs = false)
      : floor(floor),
        mincount(mincount),
        share_for_pdfs(share_for_pdfs) { }

  /// Registers the three fields of this struct as options on *po, so that
  /// the option-handling object can write the parsed values into them.
  void Register(OptionsItf *po) {
    po->Register("transition-floor",
                 &floor,
                 "Floor for transition probabilities");
    po->Register("transition-min-count",
                 &mincount,
                 "Minimum count required to update transitions from a state");
    po->Register("share-for-pdfs",
                 &share_for_pdfs,
                 "If true, share all transition parameters where "
                 "the states have the same pdf.");
  }
};

/// Options for the MAP (Maximum A Posteriori) update of the transition
/// probabilities (the config type taken by TransitionModel::MapUpdate()).
struct MapTransitionUpdateConfig {
  /// Tau value for MAP estimation of transition probabilities
  /// (option --transition-tau).
  BaseFloat tau;
  /// If true, share all transition parameters that have the same pdf
  /// (option --share-for-pdfs).
  bool share_for_pdfs;

  /// Sets the option defaults: tau = 5.0, no sharing across pdfs.
  MapTransitionUpdateConfig()
      : tau(5.0),
        share_for_pdfs(false) { }

  /// Registers both fields of this struct as options on *po, so that the
  /// option-handling object can write the parsed values into them.
  void Register(OptionsItf *po) {
    po->Register("transition-tau",
                 &tau,
                 "Tau value for MAP estimation of transition probabilities.");
    po->Register("share-for-pdfs",
                 &share_for_pdfs,
                 "If true, share all transition parameters where "
                 "the states have the same pdf.");
  }
};

/// Repository for the HMM transition probabilities; it also owns the integer
/// mappings between (phone, HMM-state, pdf-id) triples, transition-states,
/// transition-indices and transition-ids.
///
/// Indexing conventions used throughout this class:
///   - transition-state and transition-id are one-based (they appear on FSTs);
///   - transition-index and pdf-id are zero-based.
class TransitionModel {

 public:
  /// Initialize the object [e.g. at the start of training].
  /// The class keeps a copy of the HmmTopology object, but not
  /// the ContextDependency object.
  TransitionModel(const ContextDependency &ctx_dep,
                  const HmmTopology &hmm_topo);


  /// Constructor that takes no arguments: typically used prior to calling Read.
  /// num_pdfs_ is explicitly set to zero so that NumPdfs() returns a
  /// well-defined value (instead of an indeterminate one) on an object that
  /// has not been filled in yet.
  TransitionModel(): num_pdfs_(0) { }

  /// Reads the model from a stream.  Note, no symbol table: the topo object is
  /// always read/written w/o symbols.
  void Read(std::istream &is, bool binary);
  /// Writes the model to a stream (again without symbols for the topology).
  void Write(std::ostream &os, bool binary) const;


  /// return reference to HMM-topology object.
  const HmmTopology &GetTopo() const { return topo_; }

  /// \name Integer mapping functions
  /// @{

  /// (phone, HMM-state, pdf-id) -> transition-state.
  int32 TripleToTransitionState(int32 phone, int32 hmm_state, int32 pdf) const;
  /// (transition-state, transition-index) -> transition-id.
  int32 PairToTransitionId(int32 trans_state, int32 trans_index) const;
  /// transition-id -> transition-state.
  int32 TransitionIdToTransitionState(int32 trans_id) const;
  /// transition-id -> transition-index.
  int32 TransitionIdToTransitionIndex(int32 trans_id) const;
  /// transition-state -> phone.
  int32 TransitionStateToPhone(int32 trans_state) const;
  /// transition-state -> HMM-state.
  int32 TransitionStateToHmmState(int32 trans_state) const;
  /// transition-state -> pdf-id.
  int32 TransitionStateToPdf(int32 trans_state) const;
  /// Returns the self-loop transition-id, or zero if this state doesn't have
  /// a self-loop.
  int32 SelfLoopOf(int32 trans_state) const;

  /// transition-id -> pdf-id (defined inline, after the class).
  inline int32 TransitionIdToPdf(int32 trans_id) const;
  /// transition-id -> phone.
  int32 TransitionIdToPhone(int32 trans_id) const;
  /// transition-id -> pdf-class.
  int32 TransitionIdToPdfClass(int32 trans_id) const;
  /// transition-id -> HMM-state.
  int32 TransitionIdToHmmState(int32 trans_id) const;

  /// @}

  /// Returns true if this trans_id goes to the final state (which is bound to
  /// be nonemitting).
  bool IsFinal(int32 trans_id) const;
  /// Returns true if this trans_id corresponds to a self-loop.
  bool IsSelfLoop(int32 trans_id) const;

  /// Returns the total number of transition-ids (note, these are one-based).
  /// id2state_ is indexed directly by the one-based transition-id, hence the
  /// minus one.
  inline int32 NumTransitionIds() const { return id2state_.size()-1; }

  /// Returns the number of transition-indices for a particular transition-state.
  /// Note: "Indices" is the plural of "index".   Index is not the same as "id",
  /// here.  A transition-index is a zero-based offset into the transitions
  /// out of a particular transition state.
  int32 NumTransitionIndices(int32 trans_state) const;

  /// Returns the total number of transition-states (note, these are one-based).
  int32 NumTransitionStates() const { return triples_.size(); }

  // NumPdfs() actually returns the highest-numbered pdf we ever saw, plus one.
  // In normal cases this should equal the number of pdfs in the system, but if you
  // initialized this object with fewer than all the phones, and it happens that
  // an unseen phone has the highest-numbered pdf, this might be different.
  // On a default-constructed object this returns zero.
  int32 NumPdfs() const { return num_pdfs_; }

  // This loops over the triples and finds the highest phone index present. If
  // the FST symbol table for the phones is created in the expected way, i.e.:
  // starting from 1 (<eps> is 0) and numbered contiguously till the last phone,
  // this will be the total number of phones.
  int32 NumPhones() const;

  /// Returns a sorted, unique list of phones.
  const std::vector<int32> &GetPhones() const { return topo_.GetPhones(); }

  // Transition-parameter-getting functions:

  /// Returns the probability associated with a transition-id.
  BaseFloat GetTransitionProb(int32 trans_id) const;
  /// Returns the log-probability associated with a transition-id.
  BaseFloat GetTransitionLogProb(int32 trans_id) const;

  // The following functions are more specialized functions for getting
  // transition probabilities, that are provided for convenience.

  /// Returns the log-probability of a particular non-self-loop transition
  /// after subtracting the probability mass of the self-loop and renormalizing;
  /// will crash if called on a self-loop.  Specifically:
  /// for non-self-loops it returns the log of that prob divided by (1 minus
  /// self-loop-prob-for-that-state).
  BaseFloat GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const;

  /// Returns the log-prob of the non-self-loop probability
  /// mass for this transition state. (you can get the self-loop prob, if a self-loop
  /// exists, by calling GetTransitionLogProb(SelfLoopOf(trans_state))).
  BaseFloat GetNonSelfLoopLogProb(int32 trans_state) const;

  /// Does Maximum Likelihood estimation.  The stats are counts/weights, indexed
  /// by transition-id.  This was previously called Update().
  void MleUpdate(const Vector<double> &stats,
                 const MleTransitionUpdateConfig &cfg,
                 BaseFloat *objf_impr_out,
                 BaseFloat *count_out);

  /// Does Maximum A Posteriori (MAP) estimation.  The stats are counts/weights,
  /// indexed by transition-id.
  void MapUpdate(const Vector<double> &stats,
                 const MapTransitionUpdateConfig &cfg,
                 BaseFloat *objf_impr_out,
                 BaseFloat *count_out);

  /// Print will print the transition model in a human-readable way, for purposes of human
  /// inspection.  The "occs" are optional (they are indexed by pdf-id).
  void Print(std::ostream &os,
             const std::vector<std::string> &phone_names,
             const Vector<double> *occs = NULL);


  /// Sizes *stats so that it can be indexed directly by the one-based
  /// transition-id (i.e. to NumTransitionIds()+1 elements).
  void InitStats(Vector<double> *stats) const { stats->Resize(NumTransitionIds()+1); }

  /// Adds "prob" to the element of *stats indexed by trans_id.
  /// The transition-id must lie in [0, NumTransitionIds()]; a negative id
  /// would index before the start of the stats vector, so it is rejected too.
  void Accumulate(BaseFloat prob, int32 trans_id, Vector<double> *stats) const {
    KALDI_ASSERT(trans_id >= 0 && trans_id <= NumTransitionIds());
    (*stats)(trans_id) += prob;
    // This is trivial and doesn't require class members, but leaves us more open
    // to design changes than doing it manually.
  }

  /// returns true if all the integer class members are identical (but does not
  /// compare the transition probabilities).
  bool Compatible(const TransitionModel &other) const;

 private:
  // NOTE(review): from the names and the share_for_pdfs option these look like
  // the variants of MleUpdate/MapUpdate used when parameters are shared across
  // states with the same pdf -- confirm against the definitions.
  void MleUpdateShared(const Vector<double> &stats,
                       const MleTransitionUpdateConfig &cfg,
                       BaseFloat *objf_impr_out, BaseFloat *count_out);
  void MapUpdateShared(const Vector<double> &stats,
                       const MapTransitionUpdateConfig &cfg,
                       BaseFloat *objf_impr_out, BaseFloat *count_out);
  void ComputeTriples(const ContextDependency &ctx_dep);  // called from constructor.  initializes triples_.
  void ComputeDerived();  // called from constructor and Read function: computes state2id_ and id2state_.
  void ComputeDerivedOfProbs();  // computes quantities derived from log-probs (currently just
  // non_self_loop_log_probs_); called whenever log-probs change.
  void InitializeProbs();  // called from constructor.
  void Check() const;

  /// A (phone, HMM-state, pdf-id) triple; one of these is stored per
  /// transition-state.
  struct Triple {
    int32 phone;
    int32 hmm_state;
    int32 pdf;
    /// Note: this constructor does not initialize the fields.
    Triple() { }
    Triple(int32 phone, int32 hmm_state, int32 pdf):
        phone(phone), hmm_state(hmm_state), pdf(pdf) { }
    /// Lexicographic ordering on (phone, hmm_state, pdf).
    bool operator < (const Triple &other) const {
      if (phone != other.phone) return phone < other.phone;
      if (hmm_state != other.hmm_state) return hmm_state < other.hmm_state;
      return pdf < other.pdf;
    }
    /// Two triples are equal when all three fields are equal.
    bool operator == (const Triple &other) const {
      return (phone == other.phone && hmm_state == other.hmm_state
              && pdf == other.pdf);
    }
  };

  /// Our own copy of the HMM topology.
  HmmTopology topo_;

  /// Triples indexed by transition state minus one;
  /// the triples are in sorted order which allows us to do the reverse mapping from
  /// triple to transition state
  std::vector<Triple> triples_;

  /// Gives the first transition_id of each transition-state; indexed by
  /// the transition-state.  Array indexed 1..num-transition-states+1 (the last one
  /// is needed so we can know the num-transitions of the last transition-state).
  std::vector<int32> state2id_;

  /// For each transition-id, the corresponding transition
  /// state (indexed by transition-id).
  std::vector<int32> id2state_;

  /// For each transition-id, the corresponding log-prob.  Indexed by transition-id.
  Vector<BaseFloat> log_probs_;

  /// For each transition-state, the log of (1 - self-loop-prob).  Indexed by
  /// transition-state.
  Vector<BaseFloat> non_self_loop_log_probs_;

  /// This is actually one plus the highest-numbered pdf we ever got back from the
  /// tree (but the tree numbers pdfs contiguously from zero so this is the number
  /// of pdfs).  Zero on a default-constructed object.
  int32 num_pdfs_;


  DISALLOW_COPY_AND_ASSIGN(TransitionModel);

};

/// Maps a transition-id to its pdf-id by going through the transition-state:
/// id2state_ gives the (one-based) transition-state, and triples_ -- which is
/// indexed by transition-state minus one -- holds the pdf-id.
/// NOTE(review): the assertion below only rejects ids that are negative or
/// past the end of id2state_; it does not reject trans_id == 0 even though
/// transition-ids are documented as one-based.
inline int32 TransitionModel::TransitionIdToPdf(int32 trans_id) const {
  // Should this two-step lookup ever show up as a hot spot, a dedicated
  // transition-id -> pdf-id array could be added to avoid it.
  KALDI_ASSERT(static_cast<size_t>(trans_id) < id2state_.size() &&
               "Likely graph/model mismatch (graph built from wrong model?)");
  const Triple &entry = triples_[id2state_[trans_id] - 1];
  return entry.pdf;
}

/// Works out which pdfs might correspond to the given phones.  Will return true
/// if these pdfs correspond *just* to these phones, false if these pdfs are also
/// used by other phones.
/// (Declaration only; the definition is not part of this header.)
/// @param trans_model [in] Transition-model used to work out this information
/// @param phones [in] A sorted, unique vector that represents a set of phones
/// @param pdfs [out] Will be set to a sorted, unique list of pdf-ids that correspond
///                   to one of this set of phones.
/// @return  Returns true if all of the pdfs output to "pdfs" correspond to phones from
///          just this set (false if they may be shared with phones outside this set).
bool GetPdfsForPhones(const TransitionModel &trans_model,
                      const std::vector<int32> &phones,
                      std::vector<int32> *pdfs);

/// Works out which phones might correspond to the given pdfs. Similar to
/// GetPdfsForPhones(), with the roles of phones and pdfs exchanged.
/// (Declaration only; the definition is not part of this header.)
/// @param trans_model [in] Transition-model used to work out this information
/// @param pdfs [in] The set of pdf-ids to look up (presumably expected to be
///                  sorted and unique, as for GetPdfsForPhones -- TODO confirm
///                  against the definition)
/// @param phones [out] Receives the phones found for those pdfs
/// @return  NOTE(review): by analogy with GetPdfsForPhones, presumably true if
///          the output phones use only pdfs from the given set -- confirm
///          against the definition.
bool GetPhonesForPdfs(const TransitionModel &trans_model,
                      const std::vector<int32> &pdfs,
                      std::vector<int32> *phones);
/// @}


} // end namespace kaldi


#endif