// encode.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: johans@google.com (Johan Schalkwyk)
//
// \file
// Class to encode and decode an fst.
#ifndef FST_LIB_ENCODE_H__
#define FST_LIB_ENCODE_H__
#include <climits>
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <string>
#include <vector>
using std::vector;
#include <fst/arc-map.h>
#include <fst/rmfinalepsilon.h>
namespace fst {
// Bit flags selecting which arc fields Encode() folds into the encoded
// label (the input label is always included).
static const uint32 kEncodeLabels = 0x0001;   // Also encode the output label
static const uint32 kEncodeWeights = 0x0002;  // Also encode the arc weight
// Union of all user-visible (non-internal) flags.
static const uint32 kEncodeFlags = kEncodeLabels | kEncodeWeights;
// Internal bookkeeping bits recording whether a symbol table is attached.
static const uint32 kEncodeHasISymbols = 0x0004;
static const uint32 kEncodeHasOSymbols = 0x0008;

// Direction of an encoding operation.
enum EncodeType { ENCODE = 1, DECODE = 2 };

// Identifies stream data as an encode table (and its endianness).
static const int32 kEncodeMagicNumber = 2129983209;
// The following class encapsulates implementation details for the
// encoding and decoding of label/weight tuples used for encoding
// and decoding of Fsts. The EncodeTable is bidirectional, i.e. it
// maps each tuple of encoded labels and weight to a unique label,
// and maps each such label back to its tuple.
template <class A> class EncodeTable {
public:
// Label and weight types taken from the arc type A.
typedef typename A::Label Label;
typedef typename A::Weight Weight;
// Encoded data: an arc's input label, output label and weight. Encode()
// stores 0 as the output label and Weight::One() as the weight for fields
// that the encode flags do not select.
struct Tuple {
  // Value-initializes both labels (to 0) so a default-constructed Tuple
  // never holds indeterminate values; the weight uses Weight's own default
  // constructor.
  Tuple() : ilabel(), olabel() {}
  Tuple(Label ilabel_, Label olabel_, Weight weight_)
      : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
  // The implicitly generated copy constructor and copy assignment copy all
  // three members, which is exactly what a hand-written one would do.
  Label ilabel;
  Label olabel;
  Weight weight;
};
// Comparison object for hashing EncodeTable Tuple(s).
class TupleEqual {
public:
bool operator()(const Tuple* x, const Tuple* y) const {
return (x->ilabel == y->ilabel &&
x->olabel == y->olabel &&
x->weight == y->weight);
}
};
// Hash function for EncodeTabe Tuples. Based on the encode flags
// we either hash the labels, weights or combination of them.
class TupleKey {
public:
TupleKey()
: encode_flags_(kEncodeLabels | kEncodeWeights) {}
TupleKey(const TupleKey& key)
: encode_flags_(key.encode_flags_) {}
explicit TupleKey(uint32 encode_flags)
: encode_flags_(encode_flags) {}
size_t operator()(const Tuple* x) const {
size_t hash = x->ilabel;
const int lshift = 5;
const int rshift = CHAR_BIT * sizeof(size_t) - 5;
if (encode_flags_ & kEncodeLabels)
hash = hash << lshift ^ hash >> rshift ^ x->olabel;
if (encode_flags_ & kEncodeWeights)
hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash();
return hash;
}
private:
int32 encode_flags_;
};
// Maps each distinct Tuple (hashed and compared by value, not by pointer)
// to the label assigned to it.
typedef unordered_map<const Tuple*, Label, TupleKey, TupleEqual> EncodeHash;
// Creates an empty table. The flags select which arc fields are encoded
// and are also handed to the hash functor; 1024 is the initial bucket
// count for the hash map. No symbol tables are attached initially.
explicit EncodeTable(uint32 encode_flags)
    : flags_(encode_flags),
      encode_hash_(1024, TupleKey(encode_flags)),
      isymbols_(0),
      osymbols_(0) {
}
// Frees every Tuple the table allocated, plus any attached symbol-table
// copies.
~EncodeTable() {
  const size_t num_tuples = encode_tuples_.size();
  for (size_t n = 0; n < num_tuples; ++n)
    delete encode_tuples_[n];
  delete isymbols_;
  delete osymbols_;
}
// Given an arc, encodes its input label together with its output label
// and/or weight (as selected by the flags) into a single label. Fields
// that are not selected are normalized to 0 (output label) or One()
// (weight). A tuple seen for the first time receives the next label,
// counting from 1; a known tuple gets its previously assigned label.
Label Encode(const A &arc) {
  const Label olabel = (flags_ & kEncodeLabels) ? arc.olabel : 0;
  const Weight weight =
      (flags_ & kEncodeWeights) ? arc.weight : Weight::One();
  const Tuple tuple(arc.ilabel, olabel, weight);

  typename EncodeHash::const_iterator found = encode_hash_.find(&tuple);
  if (found != encode_hash_.end())
    return found->second;

  // First occurrence: keep an owned heap copy; its 1-based position in
  // encode_tuples_ becomes its label (Decode() indexes with key - 1).
  encode_tuples_.push_back(new Tuple(tuple));
  const Label label = encode_tuples_.size();
  encode_hash_[encode_tuples_.back()] = label;
  return label;
}
// Looks up the label already assigned to the arc's tuple (normalized the
// same way as in Encode()) without modifying the table. Returns kNoLabel
// if the tuple has never been encoded.
Label GetLabel(const A &arc) const {
  const Tuple key(arc.ilabel,
                  (flags_ & kEncodeLabels) ? arc.olabel : 0,
                  (flags_ & kEncodeWeights) ? arc.weight : Weight::One());
  typename EncodeHash::const_iterator pos = encode_hash_.find(&key);
  return pos == encode_hash_.end() ? kNoLabel : pos->second;
}
// Given an encoded label, returns the tuple (input/output labels and
// weight) it stands for. Valid keys are 1..Size(); for any other key an
// error is logged and 0 is returned. The table keeps ownership of the
// returned Tuple.
const Tuple* Decode(Label key) const {
  // The lower bound is checked first, so only positive keys reach the cast
  // to size_t; comparing them unsigned avoids a signed/unsigned comparison.
  if (key < 1 || static_cast<size_t>(key) > encode_tuples_.size()) {
    LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key;
    return 0;
  }
  return encode_tuples_[key - 1];
}
// Number of distinct tuples stored, which is also the largest label issued.
size_t Size() const {
  return encode_tuples_.size();
}
// Serializes the table to strm; source names the stream. Definition is not
// visible here — presumably returns true on success and uses source in
// error messages (TODO confirm against the definition).
bool Write(ostream &strm, const string &source) const;
// Reads a table from strm (presumably one produced by Write()). Definition
// is not visible here — NOTE(review): confirm it returns 0 on failure and
// that the caller takes ownership of the result.
static EncodeTable<A> *Read(istream &strm, const string &source);
// Returns only the user-visible encode flags (kEncodeLabels and
// kEncodeWeights), masking off internal bits such as kEncodeHasISymbols.
// Returned as a plain uint32: a top-level const on a by-value scalar return
// has no effect and only triggers -Wignored-qualifiers.
uint32 flags() const { return flags_ & kEncodeFlags; }
// Reference-count accessors; each forwards to the embedded ref_count_ object
// and returns whatever it reports.
int RefCount() const {
  return ref_count_.count();
}
int IncrRefCount() {
  return ref_count_.Incr();
}
int DecrRefCount() {
  return ref_count_.Decr();
}
// Symbol tables attached to this table, or 0 when none is set. The table
// keeps ownership (its destructor deletes them), so callers must not free
// the returned pointers.
SymbolTable *InputSymbols() const {
  return isymbols_;
}
SymbolTable *OutputSymbols() const {
  return osymbols_;
}
void SetInputSymbols(const SymbolTable* syms) {
if (isymbols_) delete isymbols_;
if (syms) {
isymbols_ = syms->Copy();
flags_ |= kEncodeHasISymbols;
} else {
isymbols_ = 0;
flags_ &= ~kEncodeHasISymbols;
}
}
void SetOutputSymbols(const SymbolTable* syms) {
if (osymbols_) delete osymbols_;
if (syms)