From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- kaldi_io/src/tools/openfst/include/fst/cache.h | 861 +++++++++++++++++++++++++ 1 file changed, 861 insertions(+) create mode 100644 kaldi_io/src/tools/openfst/include/fst/cache.h (limited to 'kaldi_io/src/tools/openfst/include/fst/cache.h') diff --git a/kaldi_io/src/tools/openfst/include/fst/cache.h b/kaldi_io/src/tools/openfst/include/fst/cache.h new file mode 100644 index 0000000..7c96fe1 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/cache.h @@ -0,0 +1,861 @@ +// cache.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// An Fst implementation that caches FST elements of a delayed +// computation. + +#ifndef FST_LIB_CACHE_H__ +#define FST_LIB_CACHE_H__ + +#include +using std::vector; +#include + +#include + + +DECLARE_bool(fst_default_cache_gc); +DECLARE_int64(fst_default_cache_gc_limit); + +namespace fst { + +struct CacheOptions { + bool gc; // enable GC + size_t gc_limit; // # of bytes allowed before GC + + CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {} + CacheOptions() + : gc(FLAGS_fst_default_cache_gc), + gc_limit(FLAGS_fst_default_cache_gc_limit) {} +}; + +// A CacheStateAllocator allocates and frees CacheStates +// template +// struct CacheStateAllocator { +// S *Allocate(StateId s); +// void Free(S *state, StateId s); +// }; +// + +// A simple allocator class, can be overridden as needed, +// maintains a single entry cache. +template +struct DefaultCacheStateAllocator { + typedef typename S::Arc::StateId StateId; + + DefaultCacheStateAllocator() : mru_(NULL) { } + + ~DefaultCacheStateAllocator() { + delete mru_; + } + + S *Allocate(StateId s) { + if (mru_) { + S *state = mru_; + mru_ = NULL; + state->Reset(); + return state; + } + return new S(); + } + + void Free(S *state, StateId s) { + if (mru_) { + delete mru_; + } + mru_ = state; + } + + private: + S *mru_; +}; + +// VectorState but additionally has a flags data member (see +// CacheState below). This class is used to cache FST elements with +// the flags used to indicate what has been cached. Use HasStart() +// HasFinal(), and HasArcs() to determine if cached and SetStart(), +// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note +// you must set the final weight even if the state is non-final to +// mark it as cached. If the 'gc' option is 'false', cached items have +// the extent of the FST - minimizing computation. If the 'gc' option +// is 'true', garbage collection of states (not in use in an arc +// iterator and not 'protected') is performed, in a rough +// approximation of LRU order, when 'gc_limit' bytes is reached - +// controlling memory use. When 'gc_limit' is 0, special optimizations +// apply - minimizing memory use. + +template > +class CacheBaseImpl : public VectorFstBaseImpl { + public: + typedef S State; + typedef C Allocator; + typedef typename State::Arc Arc; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + using FstImpl::Type; + using FstImpl::Properties; + using FstImpl::SetProperties; + using VectorFstBaseImpl::NumStates; + using VectorFstBaseImpl::Start; + using VectorFstBaseImpl::AddState; + using VectorFstBaseImpl::SetState; + using VectorFstBaseImpl::ReserveStates; + + explicit CacheBaseImpl(C *allocator = 0) + : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0), + cache_first_state_id_(kNoStateId), cache_first_state_(0), + cache_gc_(FLAGS_fst_default_cache_gc), cache_size_(0), + cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit || + FLAGS_fst_default_cache_gc_limit == 0 ? + FLAGS_fst_default_cache_gc_limit : kMinCacheLimit), + protect_(false) { + allocator_ = allocator ? allocator : new C(); + } + + explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0) + : cache_start_(false), nknown_states_(0), + min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId), + cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0), + cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ? + opts.gc_limit : kMinCacheLimit), + protect_(false) { + allocator_ = allocator ? allocator : new C(); + } + + // Preserve gc parameters. If preserve_cache true, also preserves + // cache data. + CacheBaseImpl(const CacheBaseImpl &impl, bool preserve_cache = false) + : VectorFstBaseImpl(), cache_start_(false), nknown_states_(0), + min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId), + cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0), + cache_limit_(impl.cache_limit_), + protect_(impl.protect_) { + allocator_ = new C(); + if (preserve_cache) { + cache_start_ = impl.cache_start_; + nknown_states_ = impl.nknown_states_; + expanded_states_ = impl.expanded_states_; + min_unexpanded_state_id_ = impl.min_unexpanded_state_id_; + if (impl.cache_first_state_id_ != kNoStateId) { + cache_first_state_id_ = impl.cache_first_state_id_; + cache_first_state_ = allocator_->Allocate(cache_first_state_id_); + *cache_first_state_ = *impl.cache_first_state_; + } + cache_states_ = impl.cache_states_; + cache_size_ = impl.cache_size_; + ReserveStates(impl.NumStates()); + for (StateId s = 0; s < impl.NumStates(); ++s) { + const S *state = + static_cast &>(impl).GetState(s); + if (state) { + S *copied_state = allocator_->Allocate(s); + *copied_state = *state; + AddState(copied_state); + } else { + AddState(0); + } + } + VectorFstBaseImpl::SetStart(impl.Start()); + } + } + + ~CacheBaseImpl() { + allocator_->Free(cache_first_state_, cache_first_state_id_); + delete allocator_; + } + + // Gets a state from its ID; state must exist. + const S *GetState(StateId s) const { + if (s == cache_first_state_id_) + return cache_first_state_; + else + return VectorFstBaseImpl::GetState(s); + } + + // Gets a state from its ID; state must exist. + S *GetState(StateId s) { + if (s == cache_first_state_id_) + return cache_first_state_; + else + return VectorFstBaseImpl::GetState(s); + } + + // Gets a state from its ID; return 0 if it doesn't exist. + const S *CheckState(StateId s) const { + if (s == cache_first_state_id_) + return cache_first_state_; + else if (s < NumStates()) + return VectorFstBaseImpl::GetState(s); + else + return 0; + } + + // Gets a state from its ID; add it if necessary. + S *ExtendState(StateId s); + + void SetStart(StateId s) { + VectorFstBaseImpl::SetStart(s); + cache_start_ = true; + if (s >= nknown_states_) + nknown_states_ = s + 1; + } + + void SetFinal(StateId s, Weight w) { + S *state = ExtendState(s); + state->final = w; + state->flags |= kCacheFinal | kCacheRecent | kCacheModified; + } + + // AddArc adds a single arc to state s and does incremental cache + // book-keeping. For efficiency, prefer PushArc and SetArcs below + // when possible. + void AddArc(StateId s, const Arc &arc) { + S *state = ExtendState(s); + state->arcs.push_back(arc); + if (arc.ilabel == 0) { + ++state->niepsilons; + } + if (arc.olabel == 0) { + ++state->noepsilons; + } + const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back()); + SetProperties(AddArcProperties(Properties(), s, arc, parc)); + state->flags |= kCacheModified; + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { + cache_size_ += sizeof(Arc); + if (cache_size_ > cache_limit_) + GC(s, false); + } + } + + // Adds a single arc to state s but delays cache book-keeping. + // SetArcs must be called when all PushArc calls at a state are + // complete. Do not mix with calls to AddArc. + void PushArc(StateId s, const Arc &arc) { + S *state = ExtendState(s); + state->arcs.push_back(arc); + } + + // Marks arcs of state s as cached and does cache book-keeping after all + // calls to PushArc have been completed. Do not mix with calls to AddArc. + void SetArcs(StateId s) { + S *state = ExtendState(s); + vector &arcs = state->arcs; + state->niepsilons = state->noepsilons = 0; + for (size_t a = 0; a < arcs.size(); ++a) { + const Arc &arc = arcs[a]; + if (arc.nextstate >= nknown_states_) + nknown_states_ = arc.nextstate + 1; + if (arc.ilabel == 0) + ++state->niepsilons; + if (arc.olabel == 0) + ++state->noepsilons; + } + ExpandedState(s); + state->flags |= kCacheArcs | kCacheRecent | kCacheModified; + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { + cache_size_ += arcs.capacity() * sizeof(Arc); + if (cache_size_ > cache_limit_) + GC(s, false); + } + }; + + void ReserveArcs(StateId s, size_t n) { + S *state = ExtendState(s); + state->arcs.reserve(n); + } + + void DeleteArcs(StateId s, size_t n) { + S *state = ExtendState(s); + const vector &arcs = state->arcs; + for (size_t i = 0; i < n; ++i) { + size_t j = arcs.size() - i - 1; + if (arcs[j].ilabel == 0) + --state->niepsilons; + if (arcs[j].olabel == 0) + --state->noepsilons; + } + + state->arcs.resize(arcs.size() - n); + SetProperties(DeleteArcsProperties(Properties())); + state->flags |= kCacheModified; + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { + cache_size_ -= n * sizeof(Arc); + } + } + + void DeleteArcs(StateId s) { + S *state = ExtendState(s); + size_t n = state->arcs.size(); + state->niepsilons = 0; + state->noepsilons = 0; + state->arcs.clear(); + SetProperties(DeleteArcsProperties(Properties())); + state->flags |= kCacheModified; + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { + cache_size_ -= n * sizeof(Arc); + } + } + + void DeleteStates(const vector &dstates) { + size_t old_num_states = NumStates(); + vector newid(old_num_states, 0); + for (size_t i = 0; i < dstates.size(); ++i) + newid[dstates[i]] = kNoStateId; + StateId nstates = 0; + for (StateId s = 0; s < old_num_states; ++s) { + if (newid[s] != kNoStateId) { + newid[s] = nstates; + ++nstates; + } + } + // just for states_.resize(), does unnecessary walk. + VectorFstBaseImpl::DeleteStates(dstates); + SetProperties(DeleteStatesProperties(Properties())); + // Update list of cached states. + typename list::iterator siter = cache_states_.begin(); + while (siter != cache_states_.end()) { + if (newid[*siter] != kNoStateId) { + *siter = newid[*siter]; + ++siter; + } else { + cache_states_.erase(siter++); + } + } + } + + void DeleteStates() { + cache_states_.clear(); + allocator_->Free(cache_first_state_, cache_first_state_id_); + for (int s = 0; s < NumStates(); ++s) { + allocator_->Free(VectorFstBaseImpl::GetState(s), s); + SetState(s, 0); + } + nknown_states_ = 0; + min_unexpanded_state_id_ = 0; + cache_first_state_id_ = kNoStateId; + cache_first_state_ = 0; + cache_size_ = 0; + cache_start_ = false; + VectorFstBaseImpl::DeleteStates(); + SetProperties(DeleteAllStatesProperties(Properties(), + kExpanded | kMutable)); + } + + // Is the start state cached? + bool HasStart() const { + if (!cache_start_ && Properties(kError)) + cache_start_ = true; + return cache_start_; + } + + // Is the final weight of state s cached? + bool HasFinal(StateId s) const { + const S *state = CheckState(s); + if (state && state->flags & kCacheFinal) { + state->flags |= kCacheRecent; + return true; + } else { + return false; + } + } + + // Are arcs of state s cached? + bool HasArcs(StateId s) const { + const S *state = CheckState(s); + if (state && state->flags & kCacheArcs) { + state->flags |= kCacheRecent; + return true; + } else { + return false; + } + } + + Weight Final(StateId s) const { + const S *state = GetState(s); + return state->final; + } + + size_t NumArcs(StateId s) const { + const S *state = GetState(s); + return state->arcs.size(); + } + + size_t NumInputEpsilons(StateId s) const { + const S *state = GetState(s); + return state->niepsilons; + } + + size_t NumOutputEpsilons(StateId s) const { + const S *state = GetState(s); + return state->noepsilons; + } + + // Provides information needed for generic arc iterator. + void InitArcIterator(StateId s, ArcIteratorData *data) const { + const S *state = GetState(s); + data->base = 0; + data->narcs = state->arcs.size(); + data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0; + data->ref_count = &(state->ref_count); + ++(*data->ref_count); + } + + // Number of known states. + StateId NumKnownStates() const { return nknown_states_; } + + // Update number of known states taking in account the existence of state s. + void UpdateNumKnownStates(StateId s) { + if (s >= nknown_states_) + nknown_states_ = s + 1; + } + + // Find the mininum never-expanded state Id + StateId MinUnexpandedState() const { + while (min_unexpanded_state_id_ < expanded_states_.size() && + expanded_states_[min_unexpanded_state_id_]) + ++min_unexpanded_state_id_; + return min_unexpanded_state_id_; + } + + // Removes from cache_states_ and uncaches (not referenced-counted + // or protected) states that have not been accessed since the last + // GC until at most cache_fraction * cache_limit_ bytes are cached. + // If that fails to free enough, recurs uncaching recently visited + // states as well. If still unable to free enough memory, then + // widens cache_limit_ to fulfill condition. + void GC(StateId current, bool free_recent, float cache_fraction = 0.666); + + // Setc/clears GC protection: if true, new states are protected + // from garbage collection. + void GCProtect(bool on) { protect_ = on; } + + void ExpandedState(StateId s) { + if (s < min_unexpanded_state_id_) + return; + while (expanded_states_.size() <= s) + expanded_states_.push_back(false); + expanded_states_[s] = true; + } + + C *GetAllocator() const { + return allocator_; + } + + // Caching on/off switch, limit and size accessors. + bool GetCacheGc() const { return cache_gc_; } + size_t GetCacheLimit() const { return cache_limit_; } + size_t GetCacheSize() const { return cache_size_; } + + private: + static const size_t kMinCacheLimit = 8096; // Minimum (non-zero) cache limit + + static const uint32 kCacheFinal = 0x0001; // Final weight has been cached + static const uint32 kCacheArcs = 0x0002; // Arcs have been cached + static const uint32 kCacheRecent = 0x0004; // Mark as visited since GC + static const uint32 kCacheProtect = 0x0008; // Mark state as GC protected + + public: + static const uint32 kCacheModified = 0x0010; // Mark state as modified + static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent + | kCacheProtect | kCacheModified; + + private: + C *allocator_; // used to allocate new states + mutable bool cache_start_; // Is the start state cached? + StateId nknown_states_; // # of known states + vector expanded_states_; // states that have been expanded + mutable StateId min_unexpanded_state_id_; // minimum never-expanded state Id + StateId cache_first_state_id_; // First cached state id + S *cache_first_state_; // First cached state + list cache_states_; // list of currently cached states + bool cache_gc_; // enable GC + size_t cache_size_; // # of bytes cached + size_t cache_limit_; // # of bytes allowed before GC + bool protect_; // Protect new states from GC + + void operator=(const CacheBaseImpl &impl); // disallow +}; + +// Gets a state from its ID; add it if necessary. +template +S *CacheBaseImpl::ExtendState(typename S::Arc::StateId s) { + // If 'protect_' true and a new state, protects from garbage collection. + if (s == cache_first_state_id_) { + return cache_first_state_; // Return 1st cached state + } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) { + cache_first_state_id_ = s; // Remember 1st cached state + cache_first_state_ = allocator_->Allocate(s); + if (protect_) cache_first_state_->flags |= kCacheProtect; + return cache_first_state_; + } else if (cache_first_state_id_ != kNoStateId && + cache_first_state_->ref_count == 0 && + !(cache_first_state_->flags & kCacheProtect)) { + // With Default allocator, the Free and Allocate will reuse the same S*. + allocator_->Free(cache_first_state_, cache_first_state_id_); + cache_first_state_id_ = s; + cache_first_state_ = allocator_->Allocate(s); + if (protect_) cache_first_state_->flags |= kCacheProtect; + return cache_first_state_; // Return 1st cached state + } else { + while (NumStates() <= s) // Add state to main cache + AddState(0); + S *state = VectorFstBaseImpl::GetState(s); + if (!state) { + state = allocator_->Allocate(s); + if (protect_) state->flags |= kCacheProtect; + SetState(s, state); + if (cache_first_state_id_ != kNoStateId) { // Forget 1st cached state + while (NumStates() <= cache_first_state_id_) + AddState(0); + SetState(cache_first_state_id_, cache_first_state_); + if (cache_gc_ && !(cache_first_state_->flags & kCacheProtect)) { + cache_states_.push_back(cache_first_state_id_); + cache_size_ += sizeof(S) + + cache_first_state_->arcs.capacity() * sizeof(Arc); + } + cache_limit_ = kMinCacheLimit; + cache_first_state_id_ = kNoStateId; + cache_first_state_ = 0; + } + if (cache_gc_ && !protect_) { + cache_states_.push_back(s); + cache_size_ += sizeof(S); + if (cache_size_ > cache_limit_) + GC(s, false); + } + } + return state; + } +} + +// Removes from cache_states_ and uncaches (not referenced-counted or +// protected) states that have not been accessed since the last GC +// until at most cache_fraction * cache_limit_ bytes are cached. If +// that fails to free enough, recurs uncaching recently visited states +// as well. If still unable to free enough memory, then widens cache_limit_ +// to fulfill condition. +template +void CacheBaseImpl::GC(typename S::Arc::StateId current, + bool free_recent, float cache_fraction) { + if (!cache_gc_) + return; + VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this + << "), free recently cached = " << free_recent + << ", cache size = " << cache_size_ + << ", cache frac = " << cache_fraction + << ", cache limit = " << cache_limit_ << "\n"; + typename list::iterator siter = cache_states_.begin(); + + size_t cache_target = cache_fraction * cache_limit_; + while (siter != cache_states_.end()) { + StateId s = *siter; + S* state = VectorFstBaseImpl::GetState(s); + if (cache_size_ > cache_target && state->ref_count == 0 && + (free_recent || !(state->flags & kCacheRecent)) && s != current) { + cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc); + allocator_->Free(state, s); + SetState(s, 0); + cache_states_.erase(siter++); + } else { + state->flags &= ~kCacheRecent; + ++siter; + } + } + if (!free_recent && cache_size_ > cache_target) { // recurses on recent + GC(current, true); + } else if (cache_target > 0) { // widens cache limit + while (cache_size_ > cache_target) { + cache_limit_ *= 2; + cache_target *= 2; + } + } else if (cache_size_ > 0) { + FSTERROR() << "CacheImpl:GC: Unable to free all cached states"; + } + VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this + << "), free recently cached = " << free_recent + << ", cache size = " << cache_size_ + << ", cache frac = " << cache_fraction + << ", cache limit = " << cache_limit_ << "\n"; +} + +template const uint32 CacheBaseImpl::kCacheFinal; +template const uint32 CacheBaseImpl::kCacheArcs; +template const uint32 CacheBaseImpl::kCacheRecent; +template const uint32 CacheBaseImpl::kCacheModified; +template const size_t CacheBaseImpl::kMinCacheLimit; + +// Arcs implemented by an STL vector per state. Similar to VectorState +// but adds flags and ref count to keep track of what has been cached. +template +struct CacheState { + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + CacheState() : final(Weight::Zero()), flags(0), ref_count(0) {} + + void Reset() { + flags = 0; + ref_count = 0; + arcs.resize(0); + } + + Weight final; // Final weight + vector arcs; // Arcs represenation + size_t niepsilons; // # of input epsilons + size_t noepsilons; // # of output epsilons + mutable uint32 flags; + mutable int ref_count; +}; + +// A CacheBaseImpl with a commonly used CacheState. +template +class CacheImpl : public CacheBaseImpl< CacheState > { + public: + typedef CacheState State; + + CacheImpl() {} + + explicit CacheImpl(const CacheOptions &opts) + : CacheBaseImpl< CacheState >(opts) {} + + CacheImpl(const CacheImpl &impl, bool preserve_cache = false) + : CacheBaseImpl(impl, preserve_cache) {} + + private: + void operator=(const CacheImpl &impl); // disallow +}; + + +// Use this to make a state iterator for a CacheBaseImpl-derived Fst, +// which must have type 'State' defined. Note this iterator only +// returns those states reachable from the initial state, so consider +// implementing a class-specific one. +template +class CacheStateIterator : public StateIteratorBase { + public: + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename F::State State; + typedef CacheBaseImpl Impl; + + CacheStateIterator(const F &fst, Impl *impl) + : fst_(fst), impl_(impl), s_(0) { + fst_.Start(); // force start state + } + + bool Done() const { + if (s_ < impl_->NumKnownStates()) + return false; + if (s_ < impl_->NumKnownStates()) + return false; + for (StateId u = impl_->MinUnexpandedState(); + u < impl_->NumKnownStates(); + u = impl_->MinUnexpandedState()) { + // force state expansion + ArcIterator aiter(fst_, u); + aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache); + for (; !aiter.Done(); aiter.Next()) + impl_->UpdateNumKnownStates(aiter.Value().nextstate); + impl_->ExpandedState(u); + if (s_ < impl_->NumKnownStates()) + return false; + } + return true; + } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual bool Done_() const { return Done(); } + virtual StateId Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual void Reset_() { Reset(); } + + const F &fst_; + Impl *impl_; + StateId s_; +}; + + +// Use this to make an arc iterator for a CacheBaseImpl-derived Fst, +// which must have types 'Arc' and 'State' defined. +template > > +class CacheArcIterator { + public: + typedef typename F::Arc Arc; + typedef typename F::State State; + typedef typename Arc::StateId StateId; + typedef CacheBaseImpl Impl; + + CacheArcIterator(Impl *impl, StateId s) : i_(0) { + state_ = impl->ExtendState(s); + ++state_->ref_count; + } + + ~CacheArcIterator() { --state_->ref_count; } + + bool Done() const { return i_ >= state_->arcs.size(); } + + const Arc& Value() const { return state_->arcs[i_]; } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + uint32 Flags() const { + return kArcValueFlags; + } + + void SetFlags(uint32 flags, uint32 mask) {} + + private: + const State *state_; + size_t i_; + + DISALLOW_COPY_AND_ASSIGN(CacheArcIterator); +}; + +// Use this to make a mutable arc iterator for a CacheBaseImpl-derived Fst, +// which must have types 'Arc' and 'State' defined. +template > > +class CacheMutableArcIterator + : public MutableArcIteratorBase { + public: + typedef typename F::State State; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef CacheBaseImpl Impl; + + // You will need to call MutateCheck() in the constructor. + CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) { + state_ = impl_->ExtendState(s_); + ++state_->ref_count; + }; + + ~CacheMutableArcIterator() { + --state_->ref_count; + } + + bool Done() const { return i_ >= state_->arcs.size(); } + + const Arc& Value() const { return state_->arcs[i_]; } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + void SetValue(const Arc& arc) { + state_->flags |= CacheBaseImpl::kCacheModified; + uint64 properties = impl_->Properties(); + Arc& oarc = state_->arcs[i_]; + if (oarc.ilabel != oarc.olabel) + properties &= ~kNotAcceptor; + if (oarc.ilabel == 0) { + --state_->niepsilons; + properties &= ~kIEpsilons; + if (oarc.olabel == 0) + properties &= ~kEpsilons; + } + if (oarc.olabel == 0) { + --state_->noepsilons; + properties &= ~kOEpsilons; + } + if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) + properties &= ~kWeighted; + oarc = arc; + if (arc.ilabel != arc.olabel) { + properties |= kNotAcceptor; + properties &= ~kAcceptor; + } + if (arc.ilabel == 0) { + ++state_->niepsilons; + properties |= kIEpsilons; + properties &= ~kNoIEpsilons; + if (arc.olabel == 0) { + properties |= kEpsilons; + properties &= ~kNoEpsilons; + } + } + if (arc.olabel == 0) { + ++state_->noepsilons; + properties |= kOEpsilons; + properties &= ~kNoOEpsilons; + } + if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) { + properties |= kWeighted; + properties &= ~kUnweighted; + } + properties &= kSetArcProperties | kAcceptor | kNotAcceptor | + kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons | + kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted; + impl_->SetProperties(properties); + } + + uint32 Flags() const { + return kArcValueFlags; + } + + void SetFlags(uint32 f, uint32 m) {} + + private: + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual size_t Position_() const { return Position(); } + virtual void Reset_() { Reset(); } + virtual void Seek_(size_t a) { Seek(a); } + virtual void SetValue_(const Arc &a) { SetValue(a); } + uint32 Flags_() const { return Flags(); } + void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } + + size_t i_; + StateId s_; + Impl *impl_; + State *state_; + + DISALLOW_COPY_AND_ASSIGN(CacheMutableArcIterator); +}; + +} // namespace fst + +#endif // FST_LIB_CACHE_H__ -- cgit v1.2.3