// vector-fst.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Simple concrete, mutable FST whose states and arcs are stored in STL
// vectors.

#ifndef FST_LIB_VECTOR_FST_H__
#define FST_LIB_VECTOR_FST_H__

#include <algorithm>
#include <string>
#include <vector>
using std::vector;

#include <fst/mutable-fst.h>
#include <fst/test-properties.h>


namespace fst {

template <class A> class VectorFst;
template <class F, class G> void Cast(const F &, G *);


// States and arcs implemented by STL vectors, templated on the
// State definition. This does not manage the Fst properties.
template class VectorFstBaseImpl : public FstImpl { public: typedef typename State::Arc Arc; typedef typename Arc::Weight Weight; typedef typename Arc::StateId StateId; VectorFstBaseImpl() : start_(kNoStateId) {} ~VectorFstBaseImpl() { for (StateId s = 0; s < states_.size(); ++s) delete states_[s]; } StateId Start() const { return start_; } Weight Final(StateId s) const { return states_[s]->final; } StateId NumStates() const { return states_.size(); } size_t NumArcs(StateId s) const { return states_[s]->arcs.size(); } void SetStart(StateId s) { start_ = s; } void SetFinal(StateId s, Weight w) { states_[s]->final = w; } StateId AddState() { states_.push_back(new State); return states_.size() - 1; } StateId AddState(State *state) { states_.push_back(state); return states_.size() - 1; } void AddArc(StateId s, const Arc &arc) { states_[s]->arcs.push_back(arc); } void DeleteStates(const vector& dstates) { vector newid(states_.size(), 0); for (size_t i = 0; i < dstates.size(); ++i) newid[dstates[i]] = kNoStateId; StateId nstates = 0; for (StateId s = 0; s < states_.size(); ++s) { if (newid[s] != kNoStateId) { newid[s] = nstates; if (s != nstates) states_[nstates] = states_[s]; ++nstates; } else { delete states_[s]; } } states_.resize(nstates); for (StateId s = 0; s < states_.size(); ++s) { vector &arcs = states_[s]->arcs; size_t narcs = 0; for (size_t i = 0; i < arcs.size(); ++i) { StateId t = newid[arcs[i].nextstate]; if (t != kNoStateId) { arcs[i].nextstate = t; if (i != narcs) arcs[narcs] = arcs[i]; ++narcs; } else { if (arcs[i].ilabel == 0) --states_[s]->niepsilons; if (arcs[i].olabel == 0) --states_[s]->noepsilons; } } arcs.resize(narcs); } if (Start() != kNoStateId) SetStart(newid[Start()]); } void DeleteStates() { for (StateId s = 0; s < states_.size(); ++s) delete states_[s]; states_.clear(); SetStart(kNoStateId); } void DeleteArcs(StateId s, size_t n) { states_[s]->arcs.resize(states_[s]->arcs.size() - n); } void DeleteArcs(StateId s) { states_[s]->arcs.clear(); 
} State *GetState(StateId s) { return states_[s]; } const State *GetState(StateId s) const { return states_[s]; } void SetState(StateId s, State *state) { states_[s] = state; } void ReserveStates(StateId n) { states_.reserve(n); } void ReserveArcs(StateId s, size_t n) { states_[s]->arcs.reserve(n); } // Provide information needed for generic state iterator void InitStateIterator(StateIteratorData *data) const { data->base = 0; data->nstates = states_.size(); } // Provide information needed for generic arc iterator void InitArcIterator(StateId s, ArcIteratorData *data) const { data->base = 0; data->narcs = states_[s]->arcs.size(); data->arcs = data->narcs > 0 ? &states_[s]->arcs[0] : 0; data->ref_count = 0; } private: vector states_; // States represenation. StateId start_; // initial state DISALLOW_COPY_AND_ASSIGN(VectorFstBaseImpl); }; // Arcs implemented by an STL vector per state. template struct VectorState { typedef A Arc; typedef typename A::Weight Weight; typedef typename A::StateId StateId; VectorState() : final(Weight::Zero()), niepsilons(0), noepsilons(0) {} Weight final; // Final weight vector arcs; // Arcs represenation size_t niepsilons; // # of input epsilons size_t noepsilons; // # of output epsilons }; // This is a VectorFstBaseImpl container that holds VectorState's. It // manages Fst properties and the # of input and output epsilons. 
template class VectorFstImpl : public VectorFstBaseImpl< VectorState > { public: using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::Properties; using VectorFstBaseImpl >::Start; using VectorFstBaseImpl >::NumStates; using VectorFstBaseImpl >::GetState; using VectorFstBaseImpl >::ReserveArcs; friend class MutableArcIterator< VectorFst >; typedef VectorFstBaseImpl< VectorState > BaseImpl; typedef typename A::Weight Weight; typedef typename A::StateId StateId; VectorFstImpl() { SetType("vector"); SetProperties(kNullProperties | kStaticProperties); } explicit VectorFstImpl(const Fst &fst); static VectorFstImpl *Read(istream &strm, const FstReadOptions &opts); size_t NumInputEpsilons(StateId s) const { return GetState(s)->niepsilons; } size_t NumOutputEpsilons(StateId s) const { return GetState(s)->noepsilons; } void SetStart(StateId s) { BaseImpl::SetStart(s); SetProperties(SetStartProperties(Properties())); } void SetFinal(StateId s, Weight w) { Weight ow = BaseImpl::Final(s); BaseImpl::SetFinal(s, w); SetProperties(SetFinalProperties(Properties(), ow, w)); } StateId AddState() { StateId s = BaseImpl::AddState(); SetProperties(AddStateProperties(Properties())); return s; } void AddArc(StateId s, const A &arc) { VectorState *state = GetState(s); if (arc.ilabel == 0) { ++state->niepsilons; } if (arc.olabel == 0) { ++state->noepsilons; } const A *parc = state->arcs.empty() ? 
0 : &(state->arcs.back()); SetProperties(AddArcProperties(Properties(), s, arc, parc)); BaseImpl::AddArc(s, arc); } void DeleteStates(const vector &dstates) { BaseImpl::DeleteStates(dstates); SetProperties(DeleteStatesProperties(Properties())); } void DeleteStates() { BaseImpl::DeleteStates(); SetProperties(DeleteAllStatesProperties(Properties(), kStaticProperties)); } void DeleteArcs(StateId s, size_t n) { const vector &arcs = GetState(s)->arcs; for (size_t i = 0; i < n; ++i) { size_t j = arcs.size() - i - 1; if (arcs[j].ilabel == 0) --GetState(s)->niepsilons; if (arcs[j].olabel == 0) --GetState(s)->noepsilons; } BaseImpl::DeleteArcs(s, n); SetProperties(DeleteArcsProperties(Properties())); } void DeleteArcs(StateId s) { GetState(s)->niepsilons = 0; GetState(s)->noepsilons = 0; BaseImpl::DeleteArcs(s); SetProperties(DeleteArcsProperties(Properties())); } // Properties always true of this Fst class static const uint64 kStaticProperties = kExpanded | kMutable; private: // Current file format version static const int kFileVersion = 2; // Minimum file format version supported static const int kMinFileVersion = 1; DISALLOW_COPY_AND_ASSIGN(VectorFstImpl); }; template const uint64 VectorFstImpl::kStaticProperties; template const int VectorFstImpl::kFileVersion; template const int VectorFstImpl::kMinFileVersion; template VectorFstImpl::VectorFstImpl(const Fst &fst) { SetType("vector"); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); BaseImpl::SetStart(fst.Start()); if (fst.Properties(kExpanded, false)) BaseImpl::ReserveStates(CountStates(fst)); for (StateIterator< Fst > siter(fst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); BaseImpl::AddState(); BaseImpl::SetFinal(s, fst.Final(s)); ReserveArcs(s, fst.NumArcs(s)); for (ArcIterator< Fst > aiter(fst, s); !aiter.Done(); aiter.Next()) { const A &arc = aiter.Value(); BaseImpl::AddArc(s, arc); if (arc.ilabel == 0) ++GetState(s)->niepsilons; if (arc.olabel == 0) 
++GetState(s)->noepsilons; } } SetProperties(fst.Properties(kCopyProperties, false) | kStaticProperties); } template VectorFstImpl *VectorFstImpl::Read(istream &strm, const FstReadOptions &opts) { VectorFstImpl *impl = new VectorFstImpl; FstHeader hdr; if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) { delete impl; return 0; } impl->BaseImpl::SetStart(hdr.Start()); if (hdr.NumStates() != kNoStateId) { impl->ReserveStates(hdr.NumStates()); } StateId s = 0; for (;hdr.NumStates() == kNoStateId || s < hdr.NumStates(); ++s) { typename A::Weight final; if (!final.Read(strm)) break; impl->BaseImpl::AddState(); VectorState *state = impl->GetState(s); state->final = final; int64 narcs; ReadType(strm, &narcs); if (!strm) { LOG(ERROR) << "VectorFst::Read: read failed: " << opts.source; delete impl; return 0; } impl->ReserveArcs(s, narcs); for (size_t j = 0; j < narcs; ++j) { A arc; ReadType(strm, &arc.ilabel); ReadType(strm, &arc.olabel); arc.weight.Read(strm); ReadType(strm, &arc.nextstate); if (!strm) { LOG(ERROR) << "VectorFst::Read: read failed: " << opts.source; delete impl; return 0; } impl->BaseImpl::AddArc(s, arc); if (arc.ilabel == 0) ++state->niepsilons; if (arc.olabel == 0) ++state->noepsilons; } } if (hdr.NumStates() != kNoStateId && s != hdr.NumStates()) { LOG(ERROR) << "VectorFst::Read: unexpected end of file: " << opts.source; delete impl; return 0; } return impl; } // Converts a string into a weight. template class WeightFromString { public: W operator()(const string &s); }; // Generic case fails. template inline W WeightFromString::operator()(const string &s) { FSTERROR() << "VectorFst::Read: Obsolete file format"; return W::NoWeight(); } // TropicalWeight version. template <> inline TropicalWeight WeightFromString::operator()(const string &s) { float f; memcpy(&f, s.data(), sizeof(f)); return TropicalWeight(f); } // LogWeight version. 
template <> inline LogWeight WeightFromString::operator()(const string &s) { float f; memcpy(&f, s.data(), sizeof(f)); return LogWeight(f); } // Simple concrete, mutable FST. This class attaches interface to // implementation and handles reference counting, delegating most // methods to ImplToMutableFst. Supports additional operations: // ReserveStates and ReserveArcs (cf. STL vectors). template class VectorFst : public ImplToMutableFst< VectorFstImpl > { public: friend class StateIterator< VectorFst >; friend class ArcIterator< VectorFst >; friend class MutableArcIterator< VectorFst >; template friend void Cast(const F &, G *); typedef A Arc; typedef typename A::StateId StateId; typedef VectorFstImpl Impl; VectorFst() : ImplToMutableFst(new Impl) {} explicit VectorFst(const Fst &fst) : ImplToMutableFst(new Impl(fst)) {} VectorFst(const VectorFst &fst) : ImplToMutableFst(fst) {} // Get a copy of this VectorFst. See Fst<>::Copy() for further doc. virtual VectorFst *Copy(bool safe = false) const { return new VectorFst(*this); } VectorFst &operator=(const VectorFst &fst) { SetImpl(fst.GetImpl(), false); return *this; } virtual VectorFst &operator=(const Fst &fst) { if (this != &fst) SetImpl(new Impl(fst)); return *this; } // Read a VectorFst from an input stream; return NULL on error static VectorFst *Read(istream &strm, const FstReadOptions &opts) { Impl* impl = Impl::Read(strm, opts); return impl ? new VectorFst(impl) : 0; } // Read a VectorFst from a file; return NULL on error // Empty filename reads from standard input static VectorFst *Read(const string &filename) { Impl* impl = ImplToExpandedFst >::Read(filename); return impl ? 
new VectorFst(impl) : 0; } virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { return WriteFst(*this, strm, opts); } virtual bool Write(const string &filename) const { return Fst::WriteFile(filename); } template static bool WriteFst(const F &fst, ostream &strm, const FstWriteOptions &opts); void ReserveStates(StateId n) { MutateCheck(); GetImpl()->ReserveStates(n); } void ReserveArcs(StateId s, size_t n) { MutateCheck(); GetImpl()->ReserveArcs(s, n); } virtual void InitStateIterator(StateIteratorData *data) const { GetImpl()->InitStateIterator(data); } virtual void InitArcIterator(StateId s, ArcIteratorData *data) const { GetImpl()->InitArcIterator(s, data); } virtual inline void InitMutableArcIterator(StateId s, MutableArcIteratorData *); private: explicit VectorFst(Impl *impl) : ImplToMutableFst(impl) {} // Makes visible to friends. Impl *GetImpl() const { return ImplToFst< Impl, MutableFst >::GetImpl(); } void SetImpl(Impl *impl, bool own_impl = true) { ImplToFst< Impl, MutableFst >::SetImpl(impl, own_impl); } void MutateCheck() { return ImplToMutableFst::MutateCheck(); } }; // Specialization for VectorFst; see generic version in fst.h // for sample usage (but use the VectorFst type!). This version // should inline. template class StateIterator< VectorFst > { public: typedef typename A::StateId StateId; explicit StateIterator(const VectorFst &fst) : nstates_(fst.GetImpl()->NumStates()), s_(0) {} bool Done() const { return s_ >= nstates_; } StateId Value() const { return s_; } void Next() { ++s_; } void Reset() { s_ = 0; } private: StateId nstates_; StateId s_; DISALLOW_COPY_AND_ASSIGN(StateIterator); }; // Writes Fst to file, will call CountStates so may involve two passes if // called from an Fst that is not derived from Expanded. 
template template bool VectorFst::WriteFst(const F &fst, ostream &strm, const FstWriteOptions &opts) { static const int kFileVersion = 2; bool update_header = true; FstHeader hdr; hdr.SetStart(fst.Start()); hdr.SetNumStates(kNoStateId); size_t start_offset = 0; if (fst.Properties(kExpanded, false) || (start_offset = strm.tellp()) != -1) { hdr.SetNumStates(CountStates(fst)); update_header = false; } uint64 properties = fst.Properties(kCopyProperties, false) | VectorFstImpl::kStaticProperties; FstImpl::WriteFstHeader(fst, strm, opts, kFileVersion, "vector", properties, &hdr); StateId num_states = 0; for (StateIterator siter(fst); !siter.Done(); siter.Next()) { typename A::StateId s = siter.Value(); fst.Final(s).Write(strm); int64 narcs = fst.NumArcs(s); WriteType(strm, narcs); for (ArcIterator aiter(fst, s); !aiter.Done(); aiter.Next()) { const A &arc = aiter.Value(); WriteType(strm, arc.ilabel); WriteType(strm, arc.olabel); arc.weight.Write(strm); WriteType(strm, arc.nextstate); } num_states++; } strm.flush(); if (!strm) { LOG(ERROR) << "VectorFst::Write: write failed: " << opts.source; return false; } if (update_header) { hdr.SetNumStates(num_states); return FstImpl::UpdateFstHeader(fst, strm, opts, kFileVersion, "vector", properties, &hdr, start_offset); } else { if (num_states != hdr.NumStates()) { LOG(ERROR) << "Inconsistent number of states observed during write"; return false; } } return true; } // Specialization for VectorFst; see generic version in fst.h // for sample usage (but use the VectorFst type!). This version // should inline. 
template class ArcIterator< VectorFst > { public: typedef typename A::StateId StateId; ArcIterator(const VectorFst &fst, StateId s) : arcs_(fst.GetImpl()->GetState(s)->arcs), i_(0) {} bool Done() const { return i_ >= arcs_.size(); } const A& Value() const { return arcs_[i_]; } void Next() { ++i_; } void Reset() { i_ = 0; } void Seek(size_t a) { i_ = a; } size_t Position() const { return i_; } uint32 Flags() const { return kArcValueFlags; } void SetFlags(uint32 f, uint32 m) {} private: const vector& arcs_; size_t i_; DISALLOW_COPY_AND_ASSIGN(ArcIterator); }; // Specialization for VectorFst; see generic version in fst.h // for sample usage (but use the VectorFst type!). This version // should inline. template class MutableArcIterator< VectorFst > : public MutableArcIteratorBase { public: typedef typename A::StateId StateId; typedef typename A::Weight Weight; MutableArcIterator(VectorFst *fst, StateId s) : i_(0) { fst->MutateCheck(); state_ = fst->GetImpl()->GetState(s); properties_ = &fst->GetImpl()->properties_; } bool Done() const { return i_ >= state_->arcs.size(); } const A& Value() const { return state_->arcs[i_]; } void Next() { ++i_; } size_t Position() const { return i_; } void Reset() { i_ = 0; } void Seek(size_t a) { i_ = a; } void SetValue(const A &arc) { A& oarc = state_->arcs[i_]; if (oarc.ilabel != oarc.olabel) *properties_ &= ~kNotAcceptor; if (oarc.ilabel == 0) { --state_->niepsilons; *properties_ &= ~kIEpsilons; if (oarc.olabel == 0) *properties_ &= ~kEpsilons; } if (oarc.olabel == 0) { --state_->noepsilons; *properties_ &= ~kOEpsilons; } if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) *properties_ &= ~kWeighted; oarc = arc; if (arc.ilabel != arc.olabel) { *properties_ |= kNotAcceptor; *properties_ &= ~kAcceptor; } if (arc.ilabel == 0) { ++state_->niepsilons; *properties_ |= kIEpsilons; *properties_ &= ~kNoIEpsilons; if (arc.olabel == 0) { *properties_ |= kEpsilons; *properties_ &= ~kNoEpsilons; } } if (arc.olabel == 0) { 
++state_->noepsilons; *properties_ |= kOEpsilons; *properties_ &= ~kNoOEpsilons; } if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) { *properties_ |= kWeighted; *properties_ &= ~kUnweighted; } *properties_ &= kSetArcProperties | kAcceptor | kNotAcceptor | kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted; } uint32 Flags() const { return kArcValueFlags; } void SetFlags(uint32 f, uint32 m) {} private: // This allows base-class virtual access to non-virtual derived- // class members of the same name. It makes the derived class more // efficient to use but unsafe to further derive. virtual bool Done_() const { return Done(); } virtual const A& Value_() const { return Value(); } virtual void Next_() { Next(); } virtual size_t Position_() const { return Position(); } virtual void Reset_() { Reset(); } virtual void Seek_(size_t a) { Seek(a); } virtual void SetValue_(const A &a) { SetValue(a); } uint32 Flags_() const { return Flags(); } void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } struct VectorState *state_; uint64 *properties_; size_t i_; DISALLOW_COPY_AND_ASSIGN(MutableArcIterator); }; // Provide information needed for the generic mutable arc iterator template inline void VectorFst::InitMutableArcIterator( StateId s, MutableArcIteratorData *data) { data->base = new MutableArcIterator< VectorFst >(this, s); } // A useful alias when using StdArc. typedef VectorFst StdVectorFst; } // namespace fst #endif // FST_LIB_VECTOR_FST_H__