// relabel.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: johans@google.com (Johan Schalkwyk)
//
// \file
// Functions and classes to relabel an Fst (either on input or output)
//
#ifndef FST_LIB_RELABEL_H__
#define FST_LIB_RELABEL_H__
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <string>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;
#include <fst/cache.h>
#include <fst/test-properties.h>
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
namespace fst {
//
// Relabels the input and/or output labels of a mutable FST in place.
// Each (old, new) pair in 'ipairs' maps an old input label to its
// replacement, and each pair in 'opairs' does the same for output labels.
// Any label that has no pair keeps its current value (identity mapping).
// If the same old label appears in several pairs, the last pair wins.
//
// If an arc carries a label whose replacement is kNoLabel, an error is
// logged, kError is set on 'fst' and the function returns at once. Arcs
// visited before that point have already been rewritten.
//
// \param fst input fst, must be mutable
// \param ipairs vector of input label pairs indicating old to new mapping
// \param opairs vector of output label pairs indicating old to new mapping
//
template <class A>
void Relabel(
    MutableFst<A> *fst,
    const vector<pair<typename A::Label, typename A::Label> >& ipairs,
    const vector<pair<typename A::Label, typename A::Label> >& opairs) {
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;
  typedef unordered_map<Label, Label> LabelMap;
  typedef typename vector<pair<Label, Label> >::const_iterator PairIter;

  // Snapshot the properties before any arc is modified; they are
  // transformed by RelabelProperties() once relabeling has succeeded.
  const uint64 old_props = fst->Properties(kFstProperties, false);

  // Build the old->new lookup tables. operator[] assignment means a later
  // pair for the same old label overwrites an earlier one.
  LabelMap imap;
  for (PairIter p = ipairs.begin(); p != ipairs.end(); ++p)
    imap[p->first] = p->second;
  LabelMap omap;
  for (PairIter p = opairs.begin(); p != opairs.end(); ++p)
    omap[p->first] = p->second;

  for (StateIterator<MutableFst<A> > states(*fst);
       !states.Done(); states.Next()) {
    const StateId state = states.Value();
    for (MutableArcIterator<MutableFst<A> > arcs(fst, state);
         !arcs.Done(); arcs.Next()) {
      A relabeled = arcs.Value();

      // Input side: rewrite only labels that have an explicit mapping.
      typename LabelMap::const_iterator hit = imap.find(relabeled.ilabel);
      if (hit != imap.end()) {
        if (hit->second == kNoLabel) {
          FSTERROR() << "Input symbol id " << relabeled.ilabel
                     << " missing from target vocabulary";
          fst->SetProperties(kError, kError);
          return;
        }
        relabeled.ilabel = hit->second;
      }

      // Output side: same rule, using the output table.
      hit = omap.find(relabeled.olabel);
      if (hit != omap.end()) {
        if (hit->second == kNoLabel) {
          FSTERROR() << "Output symbol id " << relabeled.olabel
                     << " missing from target vocabulary";
          fst->SetProperties(kError, kError);
          return;
        }
        relabeled.olabel = hit->second;
      }

      arcs.SetValue(relabeled);
    }
  }
  fst->SetProperties(RelabelProperties(old_props), kFstProperties);
}
//
// Relabels either the input labels or output labels. The old to
// new labels mappings are specified using an input Symbol set.
// Any label associations not specified are assumed to be identity
// mapping.
//
// \param fst input fst, must be mutable
// \param new_isymbols symbol set indicating new mapping of input symbols
// \param new_osymbols symbol set indicating new mapping of output symbols
//
template<class A>
void Relabel(MutableFst<A> *fst,
const SymbolTable* new_isymbols,
const SymbolTable* new_osymbols) {
Relabel(fst,
fst->InputSymbols(), new_isymbols, true,
fst->OutputSymbols(), new_osymbols, true);
}
//
// Relabels the input and/or output labels of 'fst' by matching symbol
// strings between an old and a new symbol table.
//
// Take each (symbol, id) entry of 'old_isymbols'. Any arc whose input label
// is that id gets the id that 'new_isymbols' assigns to the same symbol
// string. Output labels are handled the same way, using 'old_osymbols' and
// 'new_osymbols'. Labels absent from the old table are left unchanged.
// A side is processed only when both of its tables are non-null.
// Attachment of the new table happens before any arc is rewritten, so it
// stays attached even if the subsequent relabeling reports an error.
//
// NOTE(review): a symbol missing from the new table is presumably reported by
// SymbolTable::Find() as kNoSymbol. The pair-based Relabel() treats a target
// of kNoLabel as an error, so this relies on kNoSymbol == kNoLabel; confirm.
//
// \param fst input fst, must be mutable
// \param old_isymbols table the current input labels refer to
// \param new_isymbols table giving the new input label ids
// \param attach_new_isymbols if true (and both input tables are non-null),
//        new_isymbols is set as the fst's input symbol table
// \param old_osymbols table the current output labels refer to
// \param new_osymbols table giving the new output label ids
// \param attach_new_osymbols if true (and both output tables are non-null),
//        new_osymbols is set as the fst's output symbol table
//
template<class A>
void Relabel(MutableFst<A> *fst,
             const SymbolTable* old_isymbols,
             const SymbolTable* new_isymbols,
             bool attach_new_isymbols,
             const SymbolTable* old_osymbols,
             const SymbolTable* new_osymbols,
             bool attach_new_osymbols) {
  typedef typename A::Label Label;

  // Symbol ids are held as Label rather than 'int'. With 'int', arc types
  // that use 64-bit labels would have their ids truncated on the way into
  // the relabel pairs.
  vector<pair<Label, Label> > ipairs;
  if (old_isymbols && new_isymbols) {
    for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
         syms_iter.Next()) {
      const string isymbol = syms_iter.Symbol();
      const Label isymbol_val = syms_iter.Value();
      const Label new_isymbol_val = new_isymbols->Find(isymbol);
      ipairs.push_back(make_pair(isymbol_val, new_isymbol_val));
    }
    if (attach_new_isymbols)
      fst->SetInputSymbols(new_isymbols);
  }

  vector<pair<Label, Label> > opairs;
  if (old_osymbols && new_osymbols) {
    for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
         syms_iter.Next()) {
      const string osymbol = syms_iter.Symbol();
      const Label osymbol_val = syms_iter.Value();
      const Label new_osymbol_val = new_osymbols->Find(osymbol);
      opairs.push_back(make_pair(osymbol_val, new_osymbol_val));
    }
    if (attach_new_osymbols)
      fst->SetOutputSymbols(new_osymbols);
  }

  // Apply the label mapping using the pair-based overload.
  Relabel(fst, ipairs, opairs);
}
// Options accepted by the delayed relabeling FST; currently an alias of
// CacheOptions (the RelabelFstImpl constructor forwards them to CacheImpl).
typedef CacheOptions RelabelFstOptions;
// Forward declaration: RelabelFstImpl names StateIterator< RelabelFst<A> >
// as a friend, so RelabelFst must be declared before that class.
template <class A> class RelabelFst;
//
// \class RelabelFstImpl
// \brief Implementation for delayed relabeling
//
// Relabels an FST from one symbol set to another. Relabeling
// can be applied to the input side, the output side, or both. RelabelFst
// implements a delayed version of the relabel. Arcs are relabeled on the fly
// and not cached, i.e., each request is recomputed.
//
template<class A>
class RelabelFstImpl : public CacheImpl<A> {
friend class StateIterator< RelabelFst<A> >;
public:
using FstImpl<A>::SetType;
using FstImpl<A>::SetProperties;
using FstImpl<A>::WriteHeader;
using FstImpl<A>::SetInputSymbols;
using FstImpl<A>::SetOutputSymbols;
using CacheImpl<A>::PushArc;
using CacheImpl<A>::HasArcs;
using CacheImpl<A>::HasFinal;
using CacheImpl<A>::HasStart;
using CacheImpl<A>::SetArcs;
using CacheImpl<A>::SetFinal;
using CacheImpl<A>::SetStart;
typedef A Arc;
typedef typename A::Label Label;
typedef typename A::Weight Weight;
typedef typename A::StateId StateId;
typedef CacheState<A> State;
RelabelFstImpl(const Fst<A>& fst,
const vector<pair<Label, Label> >& ipairs,
const vector<pair<Label, Label> >& opairs,
const RelabelFstOptions &opts)
: CacheImpl<A>(opts), fst_(fst.Copy()),
relabel_input_(false), relabel_output_(false) {
uint64 props = fst.Properties(kCopyProperties, false);
SetProperties(RelabelProperties(props));
SetType("relabel");
// create input label map
if (ipairs.size()