summaryrefslogtreecommitdiff
path: root/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h')
-rw-r--r--kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h934
1 files changed, 0 insertions, 934 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h
deleted file mode 100644
index d113fb3..0000000
--- a/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h
+++ /dev/null
@@ -1,934 +0,0 @@
-
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// Copyright 2005-2010 Google, Inc.
-// Author: sorenj@google.com (Jeffrey Sorensen)
-//
-#ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
-#define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
-
-#include <stddef.h>
-#include <string.h>
-#include <algorithm>
-#include <string>
-#include <vector>
-using std::vector;
-
-#include <fst/compat.h>
-#include <fst/fstlib.h>
-#include <fst/mapped-file.h>
-#include <fst/extensions/ngram/bitmap-index.h>
-
-// NgramFst implements a n-gram language model based upon the LOUDS data
-// structure. Please refer to "Unary Data Strucutres for Language Models"
-// http://research.google.com/pubs/archive/37218.pdf
-
-namespace fst {
-template <class A> class NGramFst;
-template <class A> class NGramFstMatcher;
-
-// Instance data containing mutable state for bookkeeping repeated access to
-// the same state.
-template <class A>
-struct NGramFstInst {
- typedef typename A::Label Label;
- typedef typename A::StateId StateId;
- typedef typename A::Weight Weight;
- StateId state_;
- size_t num_futures_;
- size_t offset_;
- size_t node_;
- StateId node_state_;
- vector<Label> context_;
- StateId context_state_;
- NGramFstInst()
- : state_(kNoStateId), node_state_(kNoStateId),
- context_state_(kNoStateId) { }
-};
-
-// Implementation class for LOUDS based NgramFst interface
-template <class A>
-class NGramFstImpl : public FstImpl<A> {
- using FstImpl<A>::SetInputSymbols;
- using FstImpl<A>::SetOutputSymbols;
- using FstImpl<A>::SetType;
- using FstImpl<A>::WriteHeader;
-
- friend class ArcIterator<NGramFst<A> >;
- friend class NGramFstMatcher<A>;
-
- public:
- using FstImpl<A>::InputSymbols;
- using FstImpl<A>::SetProperties;
- using FstImpl<A>::Properties;
-
- typedef A Arc;
- typedef typename A::Label Label;
- typedef typename A::StateId StateId;
- typedef typename A::Weight Weight;
-
- NGramFstImpl() : data_region_(0), data_(0), owned_(false) {
- SetType("ngram");
- SetInputSymbols(NULL);
- SetOutputSymbols(NULL);
- SetProperties(kStaticProperties);
- }
-
- NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out);
-
- ~NGramFstImpl() {
- if (owned_) {
- delete [] data_;
- }
- delete data_region_;
- }
-
- static NGramFstImpl<A>* Read(istream &strm, // NOLINT
- const FstReadOptions &opts) {
- NGramFstImpl<A>* impl = new NGramFstImpl();
- FstHeader hdr;
- if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0;
- uint64 num_states, num_futures, num_final;
- const size_t offset = sizeof(num_states) + sizeof(num_futures) +
- sizeof(num_final);
- // Peek at num_states and num_futures to see how much more needs to be read.
- strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
- strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
- strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
- size_t size = Storage(num_states, num_futures, num_final);
- MappedFile *data_region = MappedFile::Allocate(size);
- char *data = reinterpret_cast<char *>(data_region->mutable_data());
- // Copy num_states, num_futures and num_final back into data.
- memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
- memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
- sizeof(num_futures));
- memcpy(data + sizeof(num_states) + sizeof(num_futures),
- reinterpret_cast<char *>(&num_final), sizeof(num_final));
- strm.read(data + offset, size - offset);
- if (!strm) {
- delete impl;
- return NULL;
- }
- impl->Init(data, false, data_region);
- return impl;
- }
-
- bool Write(ostream &strm, // NOLINT
- const FstWriteOptions &opts) const {
- FstHeader hdr;
- hdr.SetStart(Start());
- hdr.SetNumStates(num_states_);
- WriteHeader(strm, opts, kFileVersion, &hdr);
- strm.write(data_, Storage(num_states_, num_futures_, num_final_));
- return strm;
- }
-
- StateId Start() const {
- return 1;
- }
-
- Weight Final(StateId state) const {
- if (final_index_.Get(state)) {
- return final_probs_[final_index_.Rank1(state)];
- } else {
- return Weight::Zero();
- }
- }
-
- size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const {
- if (inst == NULL) {
- const size_t next_zero = future_index_.Select0(state + 1);
- const size_t this_zero = future_index_.Select0(state);
- return next_zero - this_zero - 1;
- }
- SetInstFuture(state, inst);
- return inst->num_futures_ + ((state == 0) ? 0 : 1);
- }
-
- size_t NumInputEpsilons(StateId state) const {
- // State 0 has no parent, thus no backoff.
- if (state == 0) return 0;
- return 1;
- }
-
- size_t NumOutputEpsilons(StateId state) const {
- return NumInputEpsilons(state);
- }
-
- StateId NumStates() const {
- return num_states_;
- }
-
- void InitStateIterator(StateIteratorData<A>* data) const {
- data->base = 0;
- data->nstates = num_states_;
- }
-
- static size_t Storage(uint64 num_states, uint64 num_futures,
- uint64 num_final) {
- uint64 b64;
- Weight weight;
- Label label;
- size_t offset = sizeof(num_states) + sizeof(num_futures) +
- sizeof(num_final);
- offset += sizeof(b64) * (
- BitmapIndex::StorageSize(num_states * 2 + 1) +
- BitmapIndex::StorageSize(num_futures + num_states + 1) +
- BitmapIndex::StorageSize(num_states));
- offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
- // Pad for alignemnt, see
- // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
- offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
- offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
- (num_futures + 1) * sizeof(weight);
- return offset;
- }
-
- void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
- if (inst->state_ != state) {
- inst->state_ = state;
- const size_t next_zero = future_index_.Select0(state + 1);
- const size_t this_zero = future_index_.Select0(state);
- inst->num_futures_ = next_zero - this_zero - 1;
- inst->offset_ = future_index_.Rank1(future_index_.Select0(state) + 1);
- }
- }
-
- void SetInstNode(NGramFstInst<A> *inst) const {
- if (inst->node_state_ != inst->state_) {
- inst->node_state_ = inst->state_;
- inst->node_ = context_index_.Select1(inst->state_);
- }
- }
-
- void SetInstContext(NGramFstInst<A> *inst) const {
- SetInstNode(inst);
- if (inst->context_state_ != inst->state_) {
- inst->context_state_ = inst->state_;
- inst->context_.clear();
- size_t node = inst->node_;
- while (node != 0) {
- inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
- node = context_index_.Select1(context_index_.Rank0(node) - 1);
- }
- }
- }
-
- // Access to the underlying representation
- const char* GetData(size_t* data_size) const {
- *data_size = Storage(num_states_, num_futures_, num_final_);
- return data_;
- }
-
- void Init(const char* data, bool owned, MappedFile *file = 0);
-
- const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
- SetInstFuture(s, inst);
- SetInstContext(inst);
- return inst->context_;
- }
-
- private:
- StateId Transition(const vector<Label> &context, Label future) const;
-
- // Properties always true for this Fst class.
- static const uint64 kStaticProperties = kAcceptor | kIDeterministic |
- kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted |
- kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted |
- kAccessible | kCoAccessible | kNotString | kExpanded;
- // Current file format version.
- static const int kFileVersion = 4;
- // Minimum file format version supported.
- static const int kMinFileVersion = 4;
-
- MappedFile *data_region_;
- const char* data_;
- bool owned_; // True if we own data_
- uint64 num_states_, num_futures_, num_final_;
- size_t root_num_children_;
- const Label *root_children_;
- size_t root_first_child_;
- // borrowed references
- const uint64 *context_, *future_, *final_;
- const Label *context_words_, *future_words_;
- const Weight *backoff_, *final_probs_, *future_probs_;
- BitmapIndex context_index_;
- BitmapIndex future_index_;
- BitmapIndex final_index_;
-
- void operator=(const NGramFstImpl<A> &); // Disallow
-};
-
-template<typename A>
-NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
- : data_region_(0), data_(0), owned_(false) {
- typedef A Arc;
- typedef typename Arc::Label Label;
- typedef typename Arc::Weight Weight;
- typedef typename Arc::StateId StateId;
- SetType("ngram");
- SetInputSymbols(fst.InputSymbols());
- SetOutputSymbols(fst.OutputSymbols());
- SetProperties(kStaticProperties);
-
- // Check basic requirements for an OpenGRM language model Fst.
- int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted;
- if (fst.Properties(props, true) != props) {
- FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input";
- SetProperties(kError, kError);
- return;
- }
-
- int64 num_states = CountStates(fst);
- Label* context = new Label[num_states];
-
- // Find the unigram state by starting from the start state, following
- // epsilons.
- StateId unigram = fst.Start();
- while (1) {
- if (unigram == kNoStateId) {
- FSTERROR() << "Could not identify unigram state.";
- SetProperties(kError, kError);
- return;
- }
- ArcIterator<Fst<A> > aiter(fst, unigram);
- if (aiter.Done()) {
- LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
- break;
- }
- if (aiter.Value().ilabel != 0) break;
- unigram = aiter.Value().nextstate;
- }
-
- // Each state's context is determined by the subtree it is under from the
- // unigram state.
- queue<pair<StateId, Label> > label_queue;
- vector<bool> visited(num_states);
- // Force an epsilon link to the start state.
- label_queue.push(make_pair(fst.Start(), 0));
- for (ArcIterator<Fst<A> > aiter(fst, unigram);
- !aiter.Done(); aiter.Next()) {
- label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
- }
- // investigate states in breadth first fashion to assign context words.
- while (!label_queue.empty()) {
- pair<StateId, Label> &now = label_queue.front();
- if (!visited[now.first]) {
- context[now.first] = now.second;
- visited[now.first] = true;
- for (ArcIterator<Fst<A> > aiter(fst, now.first);
- !aiter.Done(); aiter.Next()) {
- const Arc &arc = aiter.Value();
- if (arc.ilabel != 0) {
- label_queue.push(make_pair(arc.nextstate, now.second));
- }
- }
- }
- label_queue.pop();
- }
- visited.clear();
-
- // The arc from the start state should be assigned an epsilon to put it
- // in front of the all other labels (which makes Start state 1 after
- // unigram which is state 0).
- context[fst.Start()] = 0;
-
- // Build the tree of contexts fst by reversing the epsilon arcs from fst.
- VectorFst<Arc> context_fst;
- uint64 num_final = 0;
- for (int i = 0; i < num_states; ++i) {
- if (fst.Final(i) != Weight::Zero()) {
- ++num_final;
- }
- context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
- }
- context_fst.SetStart(unigram);
- context_fst.SetInputSymbols(fst.InputSymbols());
- context_fst.SetOutputSymbols(fst.OutputSymbols());
- int64 num_context_arcs = 0;
- int64 num_futures = 0;
- for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) {
- const StateId &state = siter.Value();
- num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
- ArcIterator<Fst<A> > aiter(fst, state);
- if (!aiter.Done()) {
- const Arc &arc = aiter.Value();
- // this arc goes from state to arc.nextstate, so create an arc from
- // arc.nextstate to state to reverse it.
- if (arc.ilabel == 0) {
- context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
- arc.weight, state));
- num_context_arcs++;
- }
- }
- }
- if (num_context_arcs != context_fst.NumStates() - 1) {
- FSTERROR() << "Number of contexts arcs != number of states - 1";
- SetProperties(kError, kError);
- return;
- }
- if (context_fst.NumStates() != num_states) {
- FSTERROR() << "Number of contexts != number of states";
- SetProperties(kError, kError);
- return;
- }
- int64 context_props = context_fst.Properties(kIDeterministic |
- kILabelSorted, true);
- if (!(context_props & kIDeterministic)) {
- FSTERROR() << "Input fst is not structured properly";
- SetProperties(kError, kError);
- return;
- }
- if (!(context_props & kILabelSorted)) {
- ArcSort(&context_fst, ILabelCompare<Arc>());
- }
-
- delete [] context;
-
- uint64 b64;
- Weight weight;
- Label label = kNoLabel;
- const size_t storage = Storage(num_states, num_futures, num_final);
- MappedFile *data_region = MappedFile::Allocate(storage);
- char *data = reinterpret_cast<char *>(data_region->mutable_data());
- memset(data, 0, storage);
- size_t offset = 0;
- memcpy(data + offset, reinterpret_cast<char *>(&num_states),
- sizeof(num_states));
- offset += sizeof(num_states);
- memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
- sizeof(num_futures));
- offset += sizeof(num_futures);
- memcpy(data + offset, reinterpret_cast<char *>(&num_final),
- sizeof(num_final));
- offset += sizeof(num_final);
- uint64* context_bits = reinterpret_cast<uint64*>(data + offset);
- offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
- uint64* future_bits = reinterpret_cast<uint64*>(data + offset);
- offset +=
- BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
- uint64* final_bits = reinterpret_cast<uint64*>(data + offset);
- offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
- Label* context_words = reinterpret_cast<Label*>(data + offset);
- offset += (num_states + 1) * sizeof(label);
- Label* future_words = reinterpret_cast<Label*>(data + offset);
- offset += num_futures * sizeof(label);
- offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
- Weight* backoff = reinterpret_cast<Weight*>(data + offset);
- offset += (num_states + 1) * sizeof(weight);
- Weight* final_probs = reinterpret_cast<Weight*>(data + offset);
- offset += num_final * sizeof(weight);
- Weight* future_probs = reinterpret_cast<Weight*>(data + offset);
- int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
- final_bit = 0;
-
- // pseudo-root bits
- BitmapIndex::Set(context_bits, context_bit++);
- ++context_bit;
- context_words[context_arc] = label;
- backoff[context_arc] = Weight::Zero();
- context_arc++;
-
- ++future_bit;
- if (order_out) {
- order_out->clear();
- order_out->resize(num_states);
- }
-
- queue<StateId> context_q;
- context_q.push(context_fst.Start());
- StateId state_number = 0;
- while (!context_q.empty()) {
- const StateId &state = context_q.front();
- if (order_out) {
- (*order_out)[state] = state_number;
- }
-
- const Weight &final = context_fst.Final(state);
- if (final != Weight::Zero()) {
- BitmapIndex::Set(final_bits, state_number);
- final_probs[final_bit] = final;
- ++final_bit;
- }
-
- for (ArcIterator<VectorFst<A> > aiter(context_fst, state);
- !aiter.Done(); aiter.Next()) {
- const Arc &arc = aiter.Value();
- context_words[context_arc] = arc.ilabel;
- backoff[context_arc] = arc.weight;
- ++context_arc;
- BitmapIndex::Set(context_bits, context_bit++);
- context_q.push(arc.nextstate);
- }
- ++context_bit;
-
- for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) {
- const Arc &arc = aiter.Value();
- if (arc.ilabel != 0) {
- future_words[future_arc] = arc.ilabel;
- future_probs[future_arc] = arc.weight;
- ++future_arc;
- BitmapIndex::Set(future_bits, future_bit++);
- }
- }
- ++future_bit;
- ++state_number;
- context_q.pop();
- }
-
- if ((state_number != num_states) ||
- (context_bit != num_states * 2 + 1) ||
- (context_arc != num_states) ||
- (future_arc != num_futures) ||
- (future_bit != num_futures + num_states + 1) ||
- (final_bit != num_final)) {
- FSTERROR() << "Structure problems detected during construction";
- SetProperties(kError, kError);
- return;
- }
-
- Init(data, false, data_region);
-}
-
-template<typename A>
-inline void NGramFstImpl<A>::Init(const char* data, bool owned,
- MappedFile *data_region) {
- if (owned_) {
- delete [] data_;
- }
- delete data_region_;
- data_region_ = data_region;
- owned_ = owned;
- data_ = data;
- size_t offset = 0;
- num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset));
- offset += sizeof(num_states_);
- num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset));
- offset += sizeof(num_futures_);
- num_final_ = *(reinterpret_cast<const uint64*>(data_ + offset));
- offset += sizeof(num_final_);
- uint64 bits;
- size_t context_bits = num_states_ * 2 + 1;
- size_t future_bits = num_futures_ + num_states_ + 1;
- context_ = reinterpret_cast<const uint64*>(data_ + offset);
- offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
- future_ = reinterpret_cast<const uint64*>(data_ + offset);
- offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
- final_ = reinterpret_cast<const uint64*>(data_ + offset);
- offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
- context_words_ = reinterpret_cast<const Label*>(data_ + offset);
- offset += (num_states_ + 1) * sizeof(*context_words_);
- future_words_ = reinterpret_cast<const Label*>(data_ + offset);
- offset += num_futures_ * sizeof(*future_words_);
- offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
- backoff_ = reinterpret_cast<const Weight*>(data_ + offset);
- offset += (num_states_ + 1) * sizeof(*backoff_);
- final_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
- offset += num_final_ * sizeof(*final_probs_);
- future_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
-
- context_index_.BuildIndex(context_, context_bits);
- future_index_.BuildIndex(future_, future_bits);
- final_index_.BuildIndex(final_, num_states_);
-
- const size_t node_rank = context_index_.Rank1(0);
- root_first_child_ = context_index_.Select0(node_rank) + 1;
- if (context_index_.Get(root_first_child_) == false) {
- FSTERROR() << "Missing unigrams";
- SetProperties(kError, kError);
- return;
- }
- const size_t last_child = context_index_.Select0(node_rank + 1) - 1;
- root_num_children_ = last_child - root_first_child_ + 1;
- root_children_ = context_words_ + context_index_.Rank1(root_first_child_);
-}
-
-template<typename A>
-inline typename A::StateId NGramFstImpl<A>::Transition(
- const vector<Label> &context, Label future) const {
- size_t num_children = root_num_children_;
- const Label *children = root_children_;
- const Label *loc = lower_bound(children, children + num_children, future);
- if (loc == children + num_children || *loc != future) {
- return context_index_.Rank1(0);
- }
- size_t node = root_first_child_ + loc - children;
- size_t node_rank = context_index_.Rank1(node);
- size_t first_child = context_index_.Select0(node_rank) + 1;
- if (context_index_.Get(first_child) == false) {
- return context_index_.Rank1(node);
- }
- size_t last_child = context_index_.Select0(node_rank + 1) - 1;
- num_children = last_child - first_child + 1;
- for (int word = context.size() - 1; word >= 0; --word) {
- children = context_words_ + context_index_.Rank1(first_child);
- loc = lower_bound(children, children + last_child - first_child + 1,
- context[word]);
- if (loc == children + last_child - first_child + 1 ||
- *loc != context[word]) {
- break;
- }
- node = first_child + loc - children;
- node_rank = context_index_.Rank1(node);
- first_child = context_index_.Select0(node_rank) + 1;
- if (context_index_.Get(first_child) == false) break;
- last_child = context_index_.Select0(node_rank + 1) - 1;
- }
- return context_index_.Rank1(node);
-}
-
-/*****************************************************************************/
-template<class A>
-class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
- friend class ArcIterator<NGramFst<A> >;
- friend class NGramFstMatcher<A>;
-
- public:
- typedef A Arc;
- typedef typename A::StateId StateId;
- typedef typename A::Label Label;
- typedef typename A::Weight Weight;
- typedef NGramFstImpl<A> Impl;
-
- explicit NGramFst(const Fst<A> &dst)
- : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {}
-
- NGramFst(const Fst<A> &fst, vector<StateId>* order_out)
- : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {}
-
- // Because the NGramFstImpl is a const stateless data structure, there
- // is never a need to do anything beside copy the reference.
- NGramFst(const NGramFst<A> &fst, bool safe = false)
- : ImplToExpandedFst<Impl>(fst, false) {}
-
- NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {}
-
- // Non-standard constructor to initialize NGramFst directly from data.
- NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) {
- GetImpl()->Init(data, owned, NULL);
- }
-
- // Get method that gets the data associated with Init().
- const char* GetData(size_t* data_size) const {
- return GetImpl()->GetData(data_size);
- }
-
- const vector<Label> GetContext(StateId s) const {
- return GetImpl()->GetContext(s, &inst_);
- }
-
- virtual size_t NumArcs(StateId s) const {
- return GetImpl()->NumArcs(s, &inst_);
- }
-
- virtual NGramFst<A>* Copy(bool safe = false) const {
- return new NGramFst(*this, safe);
- }
-
- static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) {
- Impl* impl = Impl::Read(strm, opts);
- return impl ? new NGramFst<A>(impl) : 0;
- }
-
- static NGramFst<A>* Read(const string &filename) {
- if (!filename.empty()) {
- ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
- if (!strm) {
- LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename;
- return 0;
- }
- return Read(strm, FstReadOptions(filename));
- } else {
- return Read(cin, FstReadOptions("standard input"));
- }
- }
-
- virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
- return GetImpl()->Write(strm, opts);
- }
-
- virtual bool Write(const string &filename) const {
- return Fst<A>::WriteFile(filename);
- }
-
- virtual inline void InitStateIterator(StateIteratorData<A>* data) const {
- GetImpl()->InitStateIterator(data);
- }
-
- virtual inline void InitArcIterator(
- StateId s, ArcIteratorData<A>* data) const;
-
- virtual MatcherBase<A>* InitMatcher(MatchType match_type) const {
- return new NGramFstMatcher<A>(*this, match_type);
- }
-
- private:
- explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {}
-
- Impl* GetImpl() const {
- return
- ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl();
- }
-
- void SetImpl(Impl* impl, bool own_impl = true) {
- ImplToExpandedFst<Impl, Fst<A> >::SetImpl(impl, own_impl);
- }
-
- mutable NGramFstInst<A> inst_;
-};
-
-template <class A> inline void
-NGramFst<A>::InitArcIterator(StateId s, ArcIteratorData<A>* data) const {
- GetImpl()->SetInstFuture(s, &inst_);
- GetImpl()->SetInstNode(&inst_);
- data->base = new ArcIterator<NGramFst<A> >(*this, s);
-}
-
-/*****************************************************************************/
-template <class A>
-class NGramFstMatcher : public MatcherBase<A> {
- public:
- typedef A Arc;
- typedef typename A::Label Label;
- typedef typename A::StateId StateId;
- typedef typename A::Weight Weight;
-
- NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type)
- : fst_(fst), inst_(fst.inst_), match_type_(match_type),
- current_loop_(false),
- loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
- if (match_type_ == MATCH_OUTPUT) {
- swap(loop_.ilabel, loop_.olabel);
- }
- }
-
- NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
- : fst_(matcher.fst_), inst_(matcher.inst_),
- match_type_(matcher.match_type_), current_loop_(false),
- loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
- if (match_type_ == MATCH_OUTPUT) {
- swap(loop_.ilabel, loop_.olabel);
- }
- }
-
- virtual NGramFstMatcher<A>* Copy(bool safe = false) const {
- return new NGramFstMatcher<A>(*this, safe);
- }
-
- virtual MatchType Type(bool test) const {
- return match_type_;
- }
-
- virtual const Fst<A> &GetFst() const {
- return fst_;
- }
-
- virtual uint64 Properties(uint64 props) const {
- return props;
- }
-
- private:
- virtual void SetState_(StateId s) {
- fst_.GetImpl()->SetInstFuture(s, &inst_);
- current_loop_ = false;
- }
-
- virtual bool Find_(Label label) {
- const Label nolabel = kNoLabel;
- done_ = true;
- if (label == 0 || label == nolabel) {
- if (label == 0) {
- current_loop_ = true;
- loop_.nextstate = inst_.state_;
- }
- // The unigram state has no epsilon arc.
- if (inst_.state_ != 0) {
- arc_.ilabel = arc_.olabel = 0;
- fst_.GetImpl()->SetInstNode(&inst_);
- arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
- fst_.GetImpl()->context_index_.Select1(
- fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
- arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
- done_ = false;
- }
- } else {
- const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
- const Label *end = start + inst_.num_futures_;
- const Label* search = lower_bound(start, end, label);
- if (search != end && *search == label) {
- size_t state = search - start;
- arc_.ilabel = arc_.olabel = label;
- arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
- fst_.GetImpl()->SetInstContext(&inst_);
- arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
- done_ = false;
- }
- }
- return !Done_();
- }
-
- virtual bool Done_() const {
- return !current_loop_ && done_;
- }
-
- virtual const Arc& Value_() const {
- return (current_loop_) ? loop_ : arc_;
- }
-
- virtual void Next_() {
- if (current_loop_) {
- current_loop_ = false;
- } else {
- done_ = true;
- }
- }
-
- const NGramFst<A>& fst_;
- NGramFstInst<A> inst_;
- MatchType match_type_; // Supplied by caller
- bool done_;
- Arc arc_;
- bool current_loop_; // Current arc is the implicit loop
- Arc loop_;
-};
-
-/*****************************************************************************/
-template<class A>
-class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> {
- public:
- typedef A Arc;
- typedef typename A::Label Label;
- typedef typename A::StateId StateId;
- typedef typename A::Weight Weight;
-
- ArcIterator(const NGramFst<A> &fst, StateId state)
- : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
- inst_ = fst.inst_;
- impl_->SetInstFuture(state, &inst_);
- impl_->SetInstNode(&inst_);
- }
-
- bool Done() const {
- return i_ >= ((inst_.node_ == 0) ? inst_.num_futures_ :
- inst_.num_futures_ + 1);
- }
-
- const Arc &Value() const {
- bool eps = (inst_.node_ != 0 && i_ == 0);
- StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
- if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
- arc_.ilabel =
- arc_.olabel = eps ? 0 : impl_->future_words_[inst_.offset_ + state];
- lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
- }
- if (flags_ & lazy_ & kArcNextStateValue) {
- if (eps) {
- arc_.nextstate = impl_->context_index_.Rank1(
- impl_->context_index_.Select1(
- impl_->context_index_.Rank0(inst_.node_) - 1));
- } else {
- if (lazy_ & kArcNextStateValue) {
- impl_->SetInstContext(&inst_); // first time only.
- }
- arc_.nextstate =
- impl_->Transition(inst_.context_,
- impl_->future_words_[inst_.offset_ + state]);
- }
- lazy_ &= ~kArcNextStateValue;
- }
- if (flags_ & lazy_ & kArcWeightValue) {
- arc_.weight = eps ? impl_->backoff_[inst_.state_] :
- impl_->future_probs_[inst_.offset_ + state];
- lazy_ &= ~kArcWeightValue;
- }
- return arc_;
- }
-
- void Next() {
- ++i_;
- lazy_ = ~0;
- }
-
- size_t Position() const { return i_; }
-
- void Reset() {
- i_ = 0;
- lazy_ = ~0;
- }
-
- void Seek(size_t a) {
- if (i_ != a) {
- i_ = a;
- lazy_ = ~0;
- }
- }
-
- uint32 Flags() const {
- return flags_;
- }
-
- void SetFlags(uint32 f, uint32 m) {
- flags_ &= ~m;
- flags_ |= (f & kArcValueFlags);
- }
-
- private:
- virtual bool Done_() const { return Done(); }
- virtual const Arc& Value_() const { return Value(); }
- virtual void Next_() { Next(); }
- virtual size_t Position_() const { return Position(); }
- virtual void Reset_() { Reset(); }
- virtual void Seek_(size_t a) { Seek(a); }
- uint32 Flags_() const { return Flags(); }
- void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }
-
- mutable Arc arc_;
- mutable uint32 lazy_;
- const NGramFstImpl<A> *impl_;
- mutable NGramFstInst<A> inst_;
-
- size_t i_;
- uint32 flags_;
-
- DISALLOW_COPY_AND_ASSIGN(ArcIterator);
-};
-
-/*****************************************************************************/
-// Specialization for NGramFst; see generic version in fst.h
-// for sample usage (but use the ProdLmFst type!). This version
-// should inline.
-template <class A>
-class StateIterator<NGramFst<A> > : public StateIteratorBase<A> {
- public:
- typedef typename A::StateId StateId;
-
- explicit StateIterator(const NGramFst<A> &fst)
- : s_(0), num_states_(fst.NumStates()) { }
-
- bool Done() const { return s_ >= num_states_; }
- StateId Value() const { return s_; }
- void Next() { ++s_; }
- void Reset() { s_ = 0; }
-
- private:
- virtual bool Done_() const { return Done(); }
- virtual StateId Value_() const { return Value(); }
- virtual void Next_() { Next(); }
- virtual void Reset_() { Reset(); }
-
- StateId s_, num_states_;
-
- DISALLOW_COPY_AND_ASSIGN(StateIterator);
-};
-} // namespace fst
-#endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_