diff options
author | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-08-14 11:51:42 +0800 |
commit | 96a32415ab43377cf1575bd3f4f2980f58028209 (patch) | |
tree | 30a2d92d73e8f40ac87b79f6f56e227bfc4eea6e /kaldi_io/src/tools/openfst/include/fst/vector-fst.h | |
parent | c177a7549bd90670af4b29fa813ddea32cfe0f78 (diff) |
add implementation for kaldi io (by ymz)
Diffstat (limited to 'kaldi_io/src/tools/openfst/include/fst/vector-fst.h')
-rw-r--r-- | kaldi_io/src/tools/openfst/include/fst/vector-fst.h | 731 |
1 files changed, 731 insertions, 0 deletions
diff --git a/kaldi_io/src/tools/openfst/include/fst/vector-fst.h b/kaldi_io/src/tools/openfst/include/fst/vector-fst.h new file mode 100644 index 0000000..8b80876 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/vector-fst.h @@ -0,0 +1,731 @@ +// vector-fst.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Simple concrete, mutable FST whose states and arcs are stored in STL +// vectors. + +#ifndef FST_LIB_VECTOR_FST_H__ +#define FST_LIB_VECTOR_FST_H__ + +#include <string> +#include <vector> +using std::vector; + +#include <fst/mutable-fst.h> +#include <fst/test-properties.h> + + +namespace fst { + +template <class A> class VectorFst; +template <class F, class G> void Cast(const F &, G *); + + +// States and arcs implemented by STL vectors, templated on the +// State definition. This does not manage the Fst properties. +template <class State> +class VectorFstBaseImpl : public FstImpl<typename State::Arc> { + public: + typedef typename State::Arc Arc; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + VectorFstBaseImpl() : start_(kNoStateId) {} + + ~VectorFstBaseImpl() { + for (StateId s = 0; s < states_.size(); ++s) + delete states_[s]; + } + + StateId Start() const { return start_; } + + Weight Final(StateId s) const { return states_[s]->final; } + + StateId NumStates() const { return states_.size(); } + + size_t NumArcs(StateId s) const { return states_[s]->arcs.size(); } + + void SetStart(StateId s) { start_ = s; } + + void SetFinal(StateId s, Weight w) { states_[s]->final = w; } + + StateId AddState() { + states_.push_back(new State); + return states_.size() - 1; + } + + StateId AddState(State *state) { + states_.push_back(state); + return states_.size() - 1; + } + + void AddArc(StateId s, const Arc &arc) { + states_[s]->arcs.push_back(arc); + } + + void DeleteStates(const vector<StateId>& dstates) { + vector<StateId> newid(states_.size(), 0); + for (size_t i = 0; i < dstates.size(); ++i) + newid[dstates[i]] = kNoStateId; + StateId nstates = 0; + for (StateId s = 0; s < states_.size(); ++s) { + if (newid[s] != kNoStateId) { + newid[s] = nstates; + if (s != nstates) + states_[nstates] = states_[s]; + ++nstates; + } else { + delete states_[s]; + } + } + states_.resize(nstates); + for (StateId s = 0; s < states_.size(); ++s) { + vector<Arc> &arcs = states_[s]->arcs; + size_t narcs = 0; + for (size_t i = 0; i < arcs.size(); ++i) { + StateId t = newid[arcs[i].nextstate]; + if (t != kNoStateId) { + arcs[i].nextstate = t; + if (i != narcs) + arcs[narcs] = arcs[i]; + ++narcs; + } else { + if (arcs[i].ilabel == 0) + --states_[s]->niepsilons; + if (arcs[i].olabel == 0) + --states_[s]->noepsilons; + } + } + arcs.resize(narcs); + } + if (Start() != kNoStateId) + SetStart(newid[Start()]); + } + + void DeleteStates() { + for (StateId s = 0; s < states_.size(); ++s) + delete states_[s]; + states_.clear(); + SetStart(kNoStateId); + } + + void DeleteArcs(StateId s, size_t n) { + states_[s]->arcs.resize(states_[s]->arcs.size() - n); + } + + void DeleteArcs(StateId s) { states_[s]->arcs.clear(); } + + State *GetState(StateId s) { return states_[s]; } + + const State *GetState(StateId s) const { return states_[s]; } + + void SetState(StateId s, State *state) { states_[s] = state; } + + void ReserveStates(StateId n) { states_.reserve(n); } + + void ReserveArcs(StateId s, size_t n) { states_[s]->arcs.reserve(n); } + + // Provide information needed for generic state iterator + void InitStateIterator(StateIteratorData<Arc> *data) const { + data->base = 0; + data->nstates = states_.size(); + } + + // Provide information needed for generic arc iterator + void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + data->base = 0; + data->narcs = states_[s]->arcs.size(); + data->arcs = data->narcs > 0 ? &states_[s]->arcs[0] : 0; + data->ref_count = 0; + } + + private: + vector<State *> states_; // States represenation. + StateId start_; // initial state + + DISALLOW_COPY_AND_ASSIGN(VectorFstBaseImpl); +}; + +// Arcs implemented by an STL vector per state. +template <class A> +struct VectorState { + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + VectorState() : final(Weight::Zero()), niepsilons(0), noepsilons(0) {} + + Weight final; // Final weight + vector<A> arcs; // Arcs represenation + size_t niepsilons; // # of input epsilons + size_t noepsilons; // # of output epsilons +}; + +// This is a VectorFstBaseImpl container that holds VectorState's. It +// manages Fst properties and the # of input and output epsilons. +template <class A> +class VectorFstImpl : public VectorFstBaseImpl< VectorState<A> > { + public: + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + + using VectorFstBaseImpl<VectorState<A> >::Start; + using VectorFstBaseImpl<VectorState<A> >::NumStates; + using VectorFstBaseImpl<VectorState<A> >::GetState; + using VectorFstBaseImpl<VectorState<A> >::ReserveArcs; + + friend class MutableArcIterator< VectorFst<A> >; + + typedef VectorFstBaseImpl< VectorState<A> > BaseImpl; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + VectorFstImpl() { + SetType("vector"); + SetProperties(kNullProperties | kStaticProperties); + } + explicit VectorFstImpl(const Fst<A> &fst); + + static VectorFstImpl<A> *Read(istream &strm, const FstReadOptions &opts); + + size_t NumInputEpsilons(StateId s) const { return GetState(s)->niepsilons; } + + size_t NumOutputEpsilons(StateId s) const { return GetState(s)->noepsilons; } + + void SetStart(StateId s) { + BaseImpl::SetStart(s); + SetProperties(SetStartProperties(Properties())); + } + + void SetFinal(StateId s, Weight w) { + Weight ow = BaseImpl::Final(s); + BaseImpl::SetFinal(s, w); + SetProperties(SetFinalProperties(Properties(), ow, w)); + } + + StateId AddState() { + StateId s = BaseImpl::AddState(); + SetProperties(AddStateProperties(Properties())); + return s; + } + + void AddArc(StateId s, const A &arc) { + VectorState<A> *state = GetState(s); + if (arc.ilabel == 0) { + ++state->niepsilons; + } + if (arc.olabel == 0) { + ++state->noepsilons; + } + + const A *parc = state->arcs.empty() ? 0 : &(state->arcs.back()); + SetProperties(AddArcProperties(Properties(), s, arc, parc)); + + BaseImpl::AddArc(s, arc); + } + + void DeleteStates(const vector<StateId> &dstates) { + BaseImpl::DeleteStates(dstates); + SetProperties(DeleteStatesProperties(Properties())); + } + + void DeleteStates() { + BaseImpl::DeleteStates(); + SetProperties(DeleteAllStatesProperties(Properties(), + kStaticProperties)); + } + + void DeleteArcs(StateId s, size_t n) { + const vector<A> &arcs = GetState(s)->arcs; + for (size_t i = 0; i < n; ++i) { + size_t j = arcs.size() - i - 1; + if (arcs[j].ilabel == 0) + --GetState(s)->niepsilons; + if (arcs[j].olabel == 0) + --GetState(s)->noepsilons; + } + BaseImpl::DeleteArcs(s, n); + SetProperties(DeleteArcsProperties(Properties())); + } + + void DeleteArcs(StateId s) { + GetState(s)->niepsilons = 0; + GetState(s)->noepsilons = 0; + BaseImpl::DeleteArcs(s); + SetProperties(DeleteArcsProperties(Properties())); + } + + // Properties always true of this Fst class + static const uint64 kStaticProperties = kExpanded | kMutable; + + private: + // Current file format version + static const int kFileVersion = 2; + // Minimum file format version supported + static const int kMinFileVersion = 1; + + DISALLOW_COPY_AND_ASSIGN(VectorFstImpl); +}; + +template <class A> const uint64 VectorFstImpl<A>::kStaticProperties; +template <class A> const int VectorFstImpl<A>::kFileVersion; +template <class A> const int VectorFstImpl<A>::kMinFileVersion; + + +template <class A> +VectorFstImpl<A>::VectorFstImpl(const Fst<A> &fst) { + SetType("vector"); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + BaseImpl::SetStart(fst.Start()); + if (fst.Properties(kExpanded, false)) + BaseImpl::ReserveStates(CountStates(fst)); + + for (StateIterator< Fst<A> > siter(fst); + !siter.Done(); + siter.Next()) { + StateId s = siter.Value(); + BaseImpl::AddState(); + BaseImpl::SetFinal(s, fst.Final(s)); + ReserveArcs(s, fst.NumArcs(s)); + for (ArcIterator< Fst<A> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + BaseImpl::AddArc(s, arc); + if (arc.ilabel == 0) + ++GetState(s)->niepsilons; + if (arc.olabel == 0) + ++GetState(s)->noepsilons; + } + } + SetProperties(fst.Properties(kCopyProperties, false) | kStaticProperties); +} + +template <class A> +VectorFstImpl<A> *VectorFstImpl<A>::Read(istream &strm, + const FstReadOptions &opts) { + VectorFstImpl<A> *impl = new VectorFstImpl; + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) { + delete impl; + return 0; + } + impl->BaseImpl::SetStart(hdr.Start()); + if (hdr.NumStates() != kNoStateId) { + impl->ReserveStates(hdr.NumStates()); + } + + StateId s = 0; + for (;hdr.NumStates() == kNoStateId || s < hdr.NumStates(); ++s) { + typename A::Weight final; + if (!final.Read(strm)) break; + impl->BaseImpl::AddState(); + VectorState<A> *state = impl->GetState(s); + state->final = final; + int64 narcs; + ReadType(strm, &narcs); + if (!strm) { + LOG(ERROR) << "VectorFst::Read: read failed: " << opts.source; + delete impl; + return 0; + } + impl->ReserveArcs(s, narcs); + for (size_t j = 0; j < narcs; ++j) { + A arc; + ReadType(strm, &arc.ilabel); + ReadType(strm, &arc.olabel); + arc.weight.Read(strm); + ReadType(strm, &arc.nextstate); + if (!strm) { + LOG(ERROR) << "VectorFst::Read: read failed: " << opts.source; + delete impl; + return 0; + } + impl->BaseImpl::AddArc(s, arc); + if (arc.ilabel == 0) + ++state->niepsilons; + if (arc.olabel == 0) + ++state->noepsilons; + } + } + if (hdr.NumStates() != kNoStateId && s != hdr.NumStates()) { + LOG(ERROR) << "VectorFst::Read: unexpected end of file: " << opts.source; + delete impl; + return 0; + } + return impl; +} + +// Converts a string into a weight. +template <class W> class WeightFromString { + public: + W operator()(const string &s); +}; + +// Generic case fails. +template <class W> inline +W WeightFromString<W>::operator()(const string &s) { + FSTERROR() << "VectorFst::Read: Obsolete file format"; + return W::NoWeight(); +} + +// TropicalWeight version. +template <> inline +TropicalWeight WeightFromString<TropicalWeight>::operator()(const string &s) { + float f; + memcpy(&f, s.data(), sizeof(f)); + return TropicalWeight(f); +} + +// LogWeight version. +template <> inline +LogWeight WeightFromString<LogWeight>::operator()(const string &s) { + float f; + memcpy(&f, s.data(), sizeof(f)); + return LogWeight(f); +} + +// Simple concrete, mutable FST. This class attaches interface to +// implementation and handles reference counting, delegating most +// methods to ImplToMutableFst. Supports additional operations: +// ReserveStates and ReserveArcs (cf. STL vectors). +template <class A> +class VectorFst : public ImplToMutableFst< VectorFstImpl<A> > { + public: + friend class StateIterator< VectorFst<A> >; + friend class ArcIterator< VectorFst<A> >; + friend class MutableArcIterator< VectorFst<A> >; + template <class F, class G> friend void Cast(const F &, G *); + + typedef A Arc; + typedef typename A::StateId StateId; + typedef VectorFstImpl<A> Impl; + + VectorFst() : ImplToMutableFst<Impl>(new Impl) {} + + explicit VectorFst(const Fst<A> &fst) + : ImplToMutableFst<Impl>(new Impl(fst)) {} + + VectorFst(const VectorFst<A> &fst) : ImplToMutableFst<Impl>(fst) {} + + // Get a copy of this VectorFst. See Fst<>::Copy() for further doc. + virtual VectorFst<A> *Copy(bool safe = false) const { + return new VectorFst<A>(*this); + } + + VectorFst<A> &operator=(const VectorFst<A> &fst) { + SetImpl(fst.GetImpl(), false); + return *this; + } + + virtual VectorFst<A> &operator=(const Fst<A> &fst) { + if (this != &fst) SetImpl(new Impl(fst)); + return *this; + } + + // Read a VectorFst from an input stream; return NULL on error + static VectorFst<A> *Read(istream &strm, const FstReadOptions &opts) { + Impl* impl = Impl::Read(strm, opts); + return impl ? new VectorFst<A>(impl) : 0; + } + + // Read a VectorFst from a file; return NULL on error + // Empty filename reads from standard input + static VectorFst<A> *Read(const string &filename) { + Impl* impl = ImplToExpandedFst<Impl, MutableFst<A> >::Read(filename); + return impl ? new VectorFst<A>(impl) : 0; + } + + virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { + return WriteFst(*this, strm, opts); + } + + virtual bool Write(const string &filename) const { + return Fst<A>::WriteFile(filename); + } + + template <class F> + static bool WriteFst(const F &fst, ostream &strm, + const FstWriteOptions &opts); + + void ReserveStates(StateId n) { + MutateCheck(); + GetImpl()->ReserveStates(n); + } + + void ReserveArcs(StateId s, size_t n) { + MutateCheck(); + GetImpl()->ReserveArcs(s, n); + } + + virtual void InitStateIterator(StateIteratorData<Arc> *data) const { + GetImpl()->InitStateIterator(data); + } + + virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + virtual inline + void InitMutableArcIterator(StateId s, MutableArcIteratorData<A> *); + + private: + explicit VectorFst(Impl *impl) : ImplToMutableFst<Impl>(impl) {} + + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst< Impl, MutableFst<A> >::GetImpl(); } + + void SetImpl(Impl *impl, bool own_impl = true) { + ImplToFst< Impl, MutableFst<A> >::SetImpl(impl, own_impl); + } + + void MutateCheck() { return ImplToMutableFst<Impl>::MutateCheck(); } +}; + +// Specialization for VectorFst; see generic version in fst.h +// for sample usage (but use the VectorFst type!). This version +// should inline. +template <class A> +class StateIterator< VectorFst<A> > { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const VectorFst<A> &fst) + : nstates_(fst.GetImpl()->NumStates()), s_(0) {} + + bool Done() const { return s_ >= nstates_; } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + StateId nstates_; + StateId s_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + +// Writes Fst to file, will call CountStates so may involve two passes if +// called from an Fst that is not derived from Expanded. +template <class A> +template <class F> +bool VectorFst<A>::WriteFst(const F &fst, ostream &strm, + const FstWriteOptions &opts) { + static const int kFileVersion = 2; + bool update_header = true; + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(kNoStateId); + size_t start_offset = 0; + if (fst.Properties(kExpanded, false) || (start_offset = strm.tellp()) != -1) { + hdr.SetNumStates(CountStates(fst)); + update_header = false; + } + uint64 properties = fst.Properties(kCopyProperties, false) | + VectorFstImpl<A>::kStaticProperties; + FstImpl<A>::WriteFstHeader(fst, strm, opts, kFileVersion, "vector", + properties, &hdr); + StateId num_states = 0; + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + typename A::StateId s = siter.Value(); + fst.Final(s).Write(strm); + int64 narcs = fst.NumArcs(s); + WriteType(strm, narcs); + for (ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { + const A &arc = aiter.Value(); + WriteType(strm, arc.ilabel); + WriteType(strm, arc.olabel); + arc.weight.Write(strm); + WriteType(strm, arc.nextstate); + } + num_states++; + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "VectorFst::Write: write failed: " << opts.source; + return false; + } + if (update_header) { + hdr.SetNumStates(num_states); + return FstImpl<A>::UpdateFstHeader(fst, strm, opts, kFileVersion, "vector", + properties, &hdr, start_offset); + } else { + if (num_states != hdr.NumStates()) { + LOG(ERROR) << "Inconsistent number of states observed during write"; + return false; + } + } + return true; +} + +// Specialization for VectorFst; see generic version in fst.h +// for sample usage (but use the VectorFst type!). This version +// should inline. +template <class A> +class ArcIterator< VectorFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const VectorFst<A> &fst, StateId s) + : arcs_(fst.GetImpl()->GetState(s)->arcs), i_(0) {} + + bool Done() const { return i_ >= arcs_.size(); } + + const A& Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + size_t Position() const { return i_; } + + uint32 Flags() const { + return kArcValueFlags; + } + + void SetFlags(uint32 f, uint32 m) {} + + private: + const vector<A>& arcs_; + size_t i_; + + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +// Specialization for VectorFst; see generic version in fst.h +// for sample usage (but use the VectorFst type!). This version +// should inline. +template <class A> +class MutableArcIterator< VectorFst<A> > + : public MutableArcIteratorBase<A> { + public: + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + MutableArcIterator(VectorFst<A> *fst, StateId s) : i_(0) { + fst->MutateCheck(); + state_ = fst->GetImpl()->GetState(s); + properties_ = &fst->GetImpl()->properties_; + } + + bool Done() const { return i_ >= state_->arcs.size(); } + + const A& Value() const { return state_->arcs[i_]; } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + void SetValue(const A &arc) { + A& oarc = state_->arcs[i_]; + if (oarc.ilabel != oarc.olabel) + *properties_ &= ~kNotAcceptor; + if (oarc.ilabel == 0) { + --state_->niepsilons; + *properties_ &= ~kIEpsilons; + if (oarc.olabel == 0) + *properties_ &= ~kEpsilons; + } + if (oarc.olabel == 0) { + --state_->noepsilons; + *properties_ &= ~kOEpsilons; + } + if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) + *properties_ &= ~kWeighted; + oarc = arc; + if (arc.ilabel != arc.olabel) { + *properties_ |= kNotAcceptor; + *properties_ &= ~kAcceptor; + } + if (arc.ilabel == 0) { + ++state_->niepsilons; + *properties_ |= kIEpsilons; + *properties_ &= ~kNoIEpsilons; + if (arc.olabel == 0) { + *properties_ |= kEpsilons; + *properties_ &= ~kNoEpsilons; + } + } + if (arc.olabel == 0) { + ++state_->noepsilons; + *properties_ |= kOEpsilons; + *properties_ &= ~kNoOEpsilons; + } + if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) { + *properties_ |= kWeighted; + *properties_ &= ~kUnweighted; + } + *properties_ &= kSetArcProperties | kAcceptor | kNotAcceptor | + kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons | + kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted; + } + + uint32 Flags() const { + return kArcValueFlags; + } + + void SetFlags(uint32 f, uint32 m) {} + + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual bool Done_() const { return Done(); } + virtual const A& Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual size_t Position_() const { return Position(); } + virtual void Reset_() { Reset(); } + virtual void Seek_(size_t a) { Seek(a); } + virtual void SetValue_(const A &a) { SetValue(a); } + uint32 Flags_() const { return Flags(); } + void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } + + struct VectorState<A> *state_; + uint64 *properties_; + size_t i_; + + DISALLOW_COPY_AND_ASSIGN(MutableArcIterator); +}; + +// Provide information needed for the generic mutable arc iterator +template <class A> inline +void VectorFst<A>::InitMutableArcIterator( + StateId s, MutableArcIteratorData<A> *data) { + data->base = new MutableArcIterator< VectorFst<A> >(this, s); +} + +// A useful alias when using StdArc. +typedef VectorFst<StdArc> StdVectorFst; + +} // namespace fst + +#endif // FST_LIB_VECTOR_FST_H__ |