diff options
Diffstat (limited to 'kaldi_io')
304 files changed, 96497 insertions, 0 deletions
diff --git a/kaldi_io/Makefile b/kaldi_io/Makefile new file mode 100644 index 0000000..59393cf --- /dev/null +++ b/kaldi_io/Makefile @@ -0,0 +1,45 @@ +.PHONY: kaldi +SHELL := /bin/bash +BUILD_DIR := $(CURDIR)/build +INC_PATH := $(LUA_BINDIR)/../include/ +OBJS := init.o src/cwrapper_kaldi.o src/init.o +LIBS := libkaldiio.so +LUA_LIBS := init.lua +INCLUDE := -I $(LUA_INCDIR) -I $(INC_PATH) -DLUA_USE_APICHECK + +SUBDIR := src +OBJ_DIR := $(BUILD_DIR)/objs +LUA_DIR := $(INST_LUADIR)/kaldi_io +KALDIINCLUDE := -I src/tools/ATLAS/include/ -I src/kaldi/ -I src/tools/openfst/include/ + +OBJS := $(addprefix $(OBJ_DIR)/,$(OBJS)) +LIBS := $(addprefix $(INST_LIBDIR)/,$(LIBS)) +OBJ_SUBDIR := $(addprefix $(OBJ_DIR)/,$(SUBDIR)) +LUA_SUBDIR := $(addprefix $(LUA_DIR)/,$(SUBDIR)) +LUA_LIBS := $(addprefix $(LUA_DIR)/,$(LUA_LIBS)) +LIB_PATH := $(LUA_BINDIR)/../lib + +build: $(OBJ_DIR) $(OBJ_SUBDIR) $(OBJS) $(OBJ_DIR)/src/test +install: $(LUA_DIR) $(LUA_SUBDIR) $(LUA_LIBS) $(LIBS) + +include kaldi.mk + +KL := /home/stuymf/kaldi-trunk/src/feat/kaldi-feat.a /home/stuymf/kaldi-trunk/src/matrix/kaldi-matrix.a /home/stuymf/kaldi-trunk/src/base/kaldi-base.a /home/stuymf/kaldi-trunk/src/util/kaldi-util.a /home/stuymf/kaldi-trunk/src/hmm/kaldi-hmm.a /home/stuymf/kaldi-trunk/src/tree/kaldi-tree.a /usr/lib/libatlas.so.3 /usr/lib/libf77blas.so.3 /usr/lib/libcblas.so.3 /usr/lib/liblapack_atlas.so.3 + + +$(OBJ_DIR) $(LUA_DIR) $(OBJ_SUBDIR) $(LUA_SUBDIR): + -mkdir -p $@ +$(LUA_DIR)/%.lua: %.lua + cp $< $@ +$(LIBS): $(OBJ_DIR)/src/cwrapper_kaldi.o $(OBJ_DIR)/init.o $(OBJ_DIR)/src/init.o + gcc -shared -o $@ $(OBJ_DIR)/src/cwrapper_kaldi.o $(OBJ_DIR)/init.o $(OBJ_DIR)/src/init.o -lstdc++ -Wl,-rpath=$(LIB_PATH) -L$(LIB_PATH) -lnervcore -lluaT $(KL) + g++ -o $@ -c $< -DHAVE_ATLAS $(KALDIINCLUDE) -g -fPIC $(INCLUDE) -DKALDI_DOUBLEPRECISION=0 -msse2 -DHAVE_POSIX_MEMALIGN +$(OBJ_DIR)/src/cwrapper_kaldi.o: src/cwrapper_kaldi.cpp + g++ -o $@ -c $< -DHAVE_ATLAS $(KALDIINCLUDE) -g -fPIC $(INCLUDE) -DKALDI_DOUBLEPRECISION=0 -msse2 -DHAVE_POSIX_MEMALIGN +$(OBJ_DIR)/src/test: $(OBJ_DIR)/src/cwrapper_kaldi.o $(OBJ_DIR)/src/test.o + gcc -o $@ $^ -Wl,-rpath=$(LIB_PATH) -L$(LIB_PATH) $(INCLUDE) $(KALDIINCLUDE) -lnervcore -Wl,-rpath=$(LUA_LIBDIR) -L$(LUA_LIBDIR) -lluajit-5.1 -lstdc++ -lm $(KL) +$(OBJ_DIR)/%.o: %.c + gcc -o $@ -c $< -g $(INCLUDE) -fPIC +clean: + -rm $(OBJ_DIR)/src/*.o + diff --git a/kaldi_io/example/kaldi_io_example.lua b/kaldi_io/example/kaldi_io_example.lua new file mode 100644 index 0000000..8fd068a --- /dev/null +++ b/kaldi_io/example/kaldi_io_example.lua @@ -0,0 +1,8 @@ +require 'kaldi_io' + +frm_ext = 5 +feat_repo = nerv.KaldiFeatureRepo("ark:/slfs6/users/ymz09/kaldi/src/featbin/copy-feats scp:/slfs6/users/ymz09/swb_ivec/train_bp.scp ark:- |") + +feat_utter = feat_repo:cur_utter(true) +print(feat_utter) + diff --git a/kaldi_io/example/swb_baseline.lua b/kaldi_io/example/swb_baseline.lua new file mode 100644 index 0000000..8b1e122 --- /dev/null +++ b/kaldi_io/example/swb_baseline.lua @@ -0,0 +1,193 @@ +require 'kaldi_io' +gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, + cumat_type = nerv.CuMatrixFloat, + mmat_type = nerv.MMatrixFloat, + frm_ext = 5, + tr_scp = "ark:/slfs6/users/ymz09/kaldi/src/featbin/copy-feats scp:/slfs6/users/ymz09/swb_ivec/train_bp.scp ark:- |", + cv_scp = "ark:/slfs6/users/ymz09/kaldi/src/featbin/copy-feats scp:/slfs6/users/ymz09/swb_ivec/train_cv.scp ark:- |", + initialized_param = {"/slfs6/users/ymz09/swb_ivec/swb_init.nerv", + "/slfs6/users/ymz09/swb_ivec/swb_global_transf.nerv"}, + debug = false} + +function make_layer_repo(param_repo) + local layer_repo = nerv.LayerRepo( + { + -- global transf + ["nerv.BiasLayer"] = + { + blayer1 = {{bias = "bias1"}, {dim_in = {429}, dim_out = {429}}}, + blayer2 = {{bias = "bias2"}, {dim_in = {429}, dim_out = {429}}} + }, + ["nerv.WindowLayer"] = + { + wlayer1 = {{window = "window1"}, {dim_in = {429}, dim_out = {429}}}, + wlayer2 = {{window = "window2"}, {dim_in = {429}, dim_out = {429}}} + }, + -- biased linearity + ["nerv.AffineLayer"] = + { + affine0 = {{ltp = "affine0_ltp", bp = "affine0_bp"}, + {dim_in = {429}, dim_out = {2048}}}, + affine1 = {{ltp = "affine1_ltp", bp = "affine1_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine2 = {{ltp = "affine2_ltp", bp = "affine2_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine3 = {{ltp = "affine3_ltp", bp = "affine3_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine4 = {{ltp = "affine4_ltp", bp = "affine4_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine5 = {{ltp = "affine5_ltp", bp = "affine5_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine6 = {{ltp = "affine6_ltp", bp = "affine6_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine7 = {{ltp = "affine7_ltp", bp = "affine7_bp"}, + {dim_in = {2048}, dim_out = {3001}}} + }, + ["nerv.SigmoidLayer"] = + { + sigmoid0 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid1 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid2 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid3 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid4 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid5 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid6 = {{}, {dim_in = {2048}, dim_out = {2048}}} + }, + ["nerv.SoftmaxCELayer"] = -- softmax + ce criterion layer for finetune output + { + ce_crit = {{}, {dim_in = {3001, 1}, dim_out = {1}, compressed = true}} + }, + ["nerv.SoftmaxLayer"] = -- softmax for decode output + { + softmax = {{}, {dim_in = {3001}, dim_out = {3001}}} + } + }, param_repo, gconf) + + layer_repo:add_layers( + { + ["nerv.DAGLayer"] = + { + global_transf = {{}, { + dim_in = {429}, dim_out = {429}, + sub_layers = layer_repo, + connections = { + ["<input>[1]"] = "blayer1[1]", + ["blayer1[1]"] = "wlayer1[1]", + ["wlayer1[1]"] = "blayer2[1]", + ["blayer2[1]"] = "wlayer2[1]", + ["wlayer2[1]"] = "<output>[1]" + } + }}, + main = {{}, { + dim_in = {429}, dim_out = {3001}, + sub_layers = layer_repo, + connections = { + ["<input>[1]"] = "affine0[1]", + ["affine0[1]"] = "sigmoid0[1]", + ["sigmoid0[1]"] = "affine1[1]", + ["affine1[1]"] = "sigmoid1[1]", + ["sigmoid1[1]"] = "affine2[1]", + ["affine2[1]"] = "sigmoid2[1]", + ["sigmoid2[1]"] = "affine3[1]", + ["affine3[1]"] = "sigmoid3[1]", + ["sigmoid3[1]"] = "affine4[1]", + ["affine4[1]"] = "sigmoid4[1]", + ["sigmoid4[1]"] = "affine5[1]", + ["affine5[1]"] = "sigmoid5[1]", + ["sigmoid5[1]"] = "affine6[1]", + ["affine6[1]"] = "sigmoid6[1]", + ["sigmoid6[1]"] = "affine7[1]", + ["affine7[1]"] = "<output>[1]" + } + }} + } + }, param_repo, gconf) + + layer_repo:add_layers( + { + ["nerv.DAGLayer"] = + { + ce_output = {{}, { + dim_in = {429, 1}, dim_out = {1}, + sub_layers = layer_repo, + connections = { + ["<input>[1]"] = "main[1]", + ["main[1]"] = "ce_crit[1]", + ["<input>[2]"] = "ce_crit[2]", + ["ce_crit[1]"] = "<output>[1]" + } + }}, + softmax_output = {{}, { + dim_in = {429}, dim_out = {3001}, + sub_layers = layer_repo, + connections = { + ["<input>[1]"] = "main[1]", + ["main[1]"] = "softmax[1]", + ["softmax[1]"] = "<output>[1]" + } + }} + } + }, param_repo, gconf) + + return layer_repo +end + +function get_network(layer_repo) + return layer_repo:get_layer("ce_output") +end + +function get_decode_network(layer_repo) + return layer_repo:get_layer("softmax_output") +end + +function get_global_transf(layer_repo) + return layer_repo:get_layer("global_transf") +end + +function make_readers(feature_rspecifier, layer_repo) + return { + {reader = nerv.KaldiReader(gconf, + { + id = "main_scp", + feature_rspecifier = feature_rspecifier, + frm_ext = gconf.frm_ext, + mlfs = { + phone_state = { + targets_rspecifier = "ark:/slfs6/users/ymz09/kaldi/src/bin/ali-to-pdf /slfs6/users/ymz09/swb_ivec/final.mdl \"ark:gunzip -c /slfs6/users/ymz09/swb_ivec/ali.*.gz |\" ark:- | /slfs6/users/ymz09/kaldi/src/bin/ali-to-post ark:- ark:- |", + format = "map" + } + }, + global_transf = layer_repo:get_layer("global_transf") + }), + data = {main_scp = 429, phone_state = 1}} + } +end + +function make_buffer(readers) + return nerv.SGDBuffer(gconf, + { + buffer_size = gconf.buffer_size, + randomize = gconf.randomize, + readers = readers + }) +end + +function get_input_order() + return {"main_scp", "phone_state"} +end + +function get_accuracy(layer_repo) + local ce_crit = layer_repo:get_layer("ce_crit") + return ce_crit.total_correct / ce_crit.total_frames * 100 +end + +function print_stat(layer_repo) + local ce_crit = layer_repo:get_layer("ce_crit") + nerv.info("*** training stat begin ***") + nerv.printf("cross entropy:\t\t%.8f\n", ce_crit.total_ce) + nerv.printf("correct:\t\t%d\n", ce_crit.total_correct) + nerv.printf("frames:\t\t\t%d\n", ce_crit.total_frames) + nerv.printf("err/frm:\t\t%.8f\n", ce_crit.total_ce / ce_crit.total_frames) + nerv.printf("accuracy:\t\t%.3f%%\n", get_accuracy(layer_repo)) + nerv.info("*** training stat end ***") +end diff --git a/kaldi_io/example/swb_baseline_basic.lua b/kaldi_io/example/swb_baseline_basic.lua new file mode 100644 index 0000000..e6c8145 --- /dev/null +++ b/kaldi_io/example/swb_baseline_basic.lua @@ -0,0 +1,157 @@ +require 'kaldi_io' +gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, + cumat_type = nerv.CuMatrixFloat, + mmat_type = nerv.MMatrixFloat, + frm_ext = 5, + tr_rspecifier = "ark:/slfs6/users/ymz09/kaldi/src/featbin/copy-feats scp:/slfs6/users/ymz09/swb_ivec/train_bp.scp ark:- |", + cv_rspecifier = "ark:/slfs6/users/ymz09/kaldi/src/featbin/copy-feats scp:/slfs6/users/ymz09/swb_ivec/train_cv.scp ark:- |", + initialized_param = {"/slfs6/users/ymz09/swb_ivec/swb_init.nerv", + "/slfs6/users/ymz09/swb_ivec/swb_global_transf.nerv"}, + debug = false} + +function make_sublayer_repo(param_repo) + return nerv.LayerRepo( + { + -- global transf + ["nerv.BiasLayer"] = + { + blayer1 = {{bias = "bias1"}, {dim_in = {429}, dim_out = {429}}}, + blayer2 = {{bias = "bias2"}, {dim_in = {429}, dim_out = {429}}} + }, + ["nerv.WindowLayer"] = + { + wlayer1 = {{window = "window1"}, {dim_in = {429}, dim_out = {429}}}, + wlayer2 = {{window = "window2"}, {dim_in = {429}, dim_out = {429}}} + }, + -- biased linearity + ["nerv.AffineLayer"] = + { + affine0 = {{ltp = "affine0_ltp", bp = "affine0_bp"}, + {dim_in = {429}, dim_out = {2048}}}, + affine1 = {{ltp = "affine1_ltp", bp = "affine1_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine2 = {{ltp = "affine2_ltp", bp = "affine2_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine3 = {{ltp = "affine3_ltp", bp = "affine3_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine4 = {{ltp = "affine4_ltp", bp = "affine4_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine5 = {{ltp = "affine5_ltp", bp = "affine5_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine6 = {{ltp = "affine6_ltp", bp = "affine6_bp"}, + {dim_in = {2048}, dim_out = {2048}}}, + affine7 = {{ltp = "affine7_ltp", bp = "affine7_bp"}, + {dim_in = {2048}, dim_out = {3001}}} + }, + ["nerv.SigmoidLayer"] = + { + sigmoid0 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid1 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid2 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid3 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid4 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid5 = {{}, {dim_in = {2048}, dim_out = {2048}}}, + sigmoid6 = {{}, {dim_in = {2048}, dim_out = {2048}}} + }, + ["nerv.SoftmaxCELayer"] = + { + ce_crit = {{}, {dim_in = {3001, 1}, dim_out = {1}, compressed = true}} + } + }, param_repo, gconf) +end + +function make_layer_repo(sublayer_repo, param_repo) + return nerv.LayerRepo( + { + ["nerv.DAGLayer"] = + { + global_transf = {{}, { + dim_in = {429}, dim_out = {429}, + sub_layers = sublayer_repo, + connections = { + ["<input>[1]"] = "blayer1[1]", + ["blayer1[1]"] = "wlayer1[1]", + ["wlayer1[1]"] = "blayer2[1]", + ["blayer2[1]"] = "wlayer2[1]", + ["wlayer2[1]"] = "<output>[1]" + } + }}, + main = {{}, { + dim_in = {429, 1}, dim_out = {1}, + sub_layers = sublayer_repo, + connections = { + ["<input>[1]"] = "affine0[1]", + ["affine0[1]"] = "sigmoid0[1]", + ["sigmoid0[1]"] = "affine1[1]", + ["affine1[1]"] = "sigmoid1[1]", + ["sigmoid1[1]"] = "affine2[1]", + ["affine2[1]"] = "sigmoid2[1]", + ["sigmoid2[1]"] = "affine3[1]", + ["affine3[1]"] = "sigmoid3[1]", + ["sigmoid3[1]"] = "affine4[1]", + ["affine4[1]"] = "sigmoid4[1]", + ["sigmoid4[1]"] = "affine5[1]", + ["affine5[1]"] = "sigmoid5[1]", + ["sigmoid5[1]"] = "affine6[1]", + ["affine6[1]"] = "sigmoid6[1]", + ["sigmoid6[1]"] = "affine7[1]", + ["affine7[1]"] = "ce_crit[1]", + ["<input>[2]"] = "ce_crit[2]", + ["ce_crit[1]"] = "<output>[1]" + } + }} + } + }, param_repo, gconf) +end + +function get_network(layer_repo) + return layer_repo:get_layer("main") +end + +function make_readers(feature_rspecifier, layer_repo) + return { + {reader = nerv.KaldiReader(gconf, + { + id = "main_scp", + feature_rspecifier = feature_rspecifier, + frm_ext = gconf.frm_ext, + mlfs = { + phone_state = { + targets_rspecifier = "ark:/slfs6/users/ymz09/kaldi/src/bin/ali-to-pdf /slfs6/users/ymz09/swb_ivec/final.mdl \"ark:gunzip -c /slfs6/users/ymz09/swb_ivec/ali.*.gz |\" ark:- | /slfs6/users/ymz09/kaldi/src/bin/ali-to-post ark:- ark:- |", + format = "map" + } + }, + global_transf = layer_repo:get_layer("global_transf") + }), + data = {main_scp = 429, phone_state = 1}} + } +end + +function make_buffer(readers) + return nerv.SGDBuffer(gconf, + { + buffer_size = gconf.buffer_size, + randomize = gconf.randomize, + readers = readers + }) +end + +function get_input_order() + return {"main_scp", "phone_state"} +end + +function get_accuracy(sublayer_repo) + local ce_crit = sublayer_repo:get_layer("ce_crit") + return ce_crit.total_correct / ce_crit.total_frames * 100 +end + +function print_stat(sublayer_repo) + local ce_crit = sublayer_repo:get_layer("ce_crit") + nerv.info("*** training stat begin ***") + nerv.printf("cross entropy:\t\t%.8f\n", ce_crit.total_ce) + nerv.printf("correct:\t\t%d\n", ce_crit.total_correct) + nerv.printf("frames:\t\t\t%d\n", ce_crit.total_frames) + nerv.printf("err/frm:\t\t%.8f\n", ce_crit.total_ce / ce_crit.total_frames) + nerv.printf("accuracy:\t\t%.3f%%\n", get_accuracy(sublayer_repo)) + nerv.info("*** training stat end ***") +end diff --git a/kaldi_io/init.c b/kaldi_io/init.c new file mode 100644 index 0000000..fe2f967 --- /dev/null +++ b/kaldi_io/init.c @@ -0,0 +1,8 @@ +#include "../nerv/common.h" +#include <stdio.h> + +extern void kaldi_io_init(lua_State *L); +int luaopen_libkaldiio(lua_State *L) { + kaldi_io_init(L); + return 1; +} diff --git a/kaldi_io/init.lua b/kaldi_io/init.lua new file mode 100644 index 0000000..0ad3a60 --- /dev/null +++ b/kaldi_io/init.lua @@ -0,0 +1,47 @@ +require 'libkaldiio' +require 'speech_utils' +local KaldiReader = nerv.class("nerv.KaldiReader", "nerv.DataReader") + +function KaldiReader:__init(global_conf, reader_conf) + self.feat_id = reader_conf.id + self.frm_ext = reader_conf.frm_ext + self.gconf = global_conf + self.global_transf = reader_conf.global_transf + self.debug = global_conf.debug + if self.debug == nil then + self.debug = false + end + self.feat_repo = nerv.KaldiFeatureRepo(reader_conf.feature_rspecifier) + + self.lab_repo = {} + for id, mlf_spec in pairs(reader_conf.mlfs) do + self.lab_repo[id] = nerv.KaldiLabelRepo(mlf_spec.targets_rspecifier, + mlf_spec.format) + end +end + +function KaldiReader:get_data() + if self.feat_repo:is_end() then + return nil + end + local res = {} + -- read Kaldi feature + local feat_utter = self.feat_repo:cur_utter(self.debug) + -- global transf + local transformed = nerv.speech_utils.global_transf(feat_utter, + self.global_transf, self.frm_ext, 0, self.gconf) + res[self.feat_id] = transformed + -- add corresponding labels + for id, repo in pairs(self.lab_repo) do + local lab_utter = repo:get_utter(self.feat_repo, + self.frm_ext, + transformed:nrow(), + self.debug) + res[id] = lab_utter + --print(lab_utter) + end + -- move the pointer to next + self.feat_repo:next() + collectgarbage("collect") + return res +end diff --git a/kaldi_io/kaldi.mk b/kaldi_io/kaldi.mk new file mode 100644 index 0000000..4a397f0 --- /dev/null +++ b/kaldi_io/kaldi.mk @@ -0,0 +1,70 @@ +# This file was generated using the following command: +# ./configure + +# Rules that enable valgrind debugging ("make valgrind") + +valgrind: .valgrind + +.valgrind: + echo -n > valgrind.out + for x in $(TESTFILES); do echo $$x>>valgrind.out; valgrind ./$$x >/dev/null 2>> valgrind.out; done + ! ( grep 'ERROR SUMMARY' valgrind.out | grep -v '0 errors' ) + ! ( grep 'definitely lost' valgrind.out | grep -v -w 0 ) + rm valgrind.out + touch .valgrind + + +CONFIGURE_VERSION := 2 +OPENFSTLIBS = -L/slwork/users/wd007/src/kaldi/tools/openfst/lib -lfst +OPENFSTLDFLAGS = -Wl,-rpath=/slwork/users/wd007/src/kaldi/tools/openfst/lib +FSTROOT = /slwork/users/wd007/src/kaldi/tools/openfst +ATLASINC = /slwork/users/wd007/src/kaldi/tools/ATLAS/include +ATLASLIBS = -L/usr/lib -llapack -lcblas -latlas -lf77blas +# You have to make sure ATLASLIBS is set... + +ifndef FSTROOT +$(error FSTROOT not defined.) +endif + +ifndef ATLASINC +$(error ATLASINC not defined.) +endif + +ifndef ATLASLIBS +$(error ATLASLIBS not defined.) +endif + + +CXXFLAGS = -msse -msse2 -Wall -I.. \ + -fPIC \ + -DKALDI_DOUBLEPRECISION=0 -DHAVE_POSIX_MEMALIGN \ + -Wno-sign-compare -Wno-unused-local-typedefs -Winit-self \ + -DHAVE_EXECINFO_H=1 -rdynamic -DHAVE_CXXABI_H \ + -DHAVE_ATLAS -I$(ATLASINC) \ + -I$(FSTROOT)/include \ + $(EXTRA_CXXFLAGS) \ + -g # -O0 -DKALDI_PARANOID + +ifeq ($(KALDI_FLAVOR), dynamic) +CXXFLAGS += -fPIC +endif + +LDFLAGS = -rdynamic $(OPENFSTLDFLAGS) +LDLIBS = $(EXTRA_LDLIBS) $(OPENFSTLIBS) $(ATLASLIBS) -lm -lpthread -ldl +CC = g++ +CXX = g++ +AR = ar +AS = as +RANLIB = ranlib + +#Next section enables CUDA for compilation +CUDA = true +CUDATKDIR = /usr/local/cuda + +CUDA_INCLUDE= -I$(CUDATKDIR)/include +CUDA_FLAGS = -g -Xcompiler -fPIC --verbose --machine 64 -DHAVE_CUDA + +CXXFLAGS += -DHAVE_CUDA -I$(CUDATKDIR)/include +CUDA_LDFLAGS += -L$(CUDATKDIR)/lib64 -Wl,-rpath,$(CUDATKDIR)/lib64 +CUDA_LDLIBS += -lcublas -lcudart #LDLIBS : The libs are loaded later than static libs in implicit rule + diff --git a/kaldi_io/kaldi_io-scm-1.rockspec b/kaldi_io/kaldi_io-scm-1.rockspec new file mode 100644 index 0000000..7c9f8d8 --- /dev/null +++ b/kaldi_io/kaldi_io-scm-1.rockspec @@ -0,0 +1,36 @@ +package = "kaldi_io" +version = "scm-1" +source = { + url = "https://github.com/Nerv-SJTU/nerv-speech.git" +} +description = { + summary = "Kaldi I/O support (Kaldi I/O wrapper) for Nerv", + detailed = [[ + ]], + homepage = "https://github.com/Nerv-SJTU/nerv-speech", + license = "BSD" +} +dependencies = { + "nerv >= scm-1", + "lua >= 5.1" +} +build = { + type = "make", + build_variables = { + CFLAGS="$(CFLAGS)", + LIBFLAG="$(LIBFLAG)", + LUA_LIBDIR="$(LUA_LIBDIR)", + LUA_BINDIR="$(LUA_BINDIR)", + LUA_INCDIR="$(LUA_INCDIR)", + INST_PREFIX="$(PREFIX)", + LUA="$(LUA)", + }, + install_variables = { + LUA_BINDIR="$(LUA_BINDIR)", + INST_PREFIX="$(PREFIX)", + INST_BINDIR="$(BINDIR)", + INST_LIBDIR="$(LIBDIR)", + INST_LUADIR="$(LUADIR)", + INST_CONFDIR="$(CONFDIR)", + }, +} diff --git a/kaldi_io/src/cwrapper_kaldi.cpp b/kaldi_io/src/cwrapper_kaldi.cpp new file mode 100644 index 0000000..3dd055f --- /dev/null +++ b/kaldi_io/src/cwrapper_kaldi.cpp @@ -0,0 +1,111 @@ +#include <string> +#include "kaldi/base/kaldi-common.h" +#include "kaldi/hmm/posterior.h" +#include "kaldi/util/table-types.h" +typedef kaldi::BaseFloat BaseFloat; + +extern "C" { +#include "cwrapper_kaldi.h" +#include "string.h" +#include "assert.h" +#include "nerv/common.h" + + extern Matrix *nerv_matrix_host_float_create(long nrow, long ncol, Status *status); + extern Matrix *nerv_matrix_host_double_create(long nrow, long ncol, Status *status); + + struct KaldiFeatureRepo { + kaldi::SequentialBaseFloatMatrixReader* feature_reader; + string utt; + }; + + KaldiFeatureRepo *kaldi_feature_repo_new(const char *feature_rspecifier) { + KaldiFeatureRepo *repo = new KaldiFeatureRepo(); + repo->feature_reader = new kaldi::SequentialBaseFloatMatrixReader(string(feature_rspecifier)); + return repo; + } + + Matrix *kaldi_feature_repo_read_utterance(KaldiFeatureRepo *repo, lua_State *L, int debug) { + Matrix *mat; /* nerv implementation */ + + repo->utt = repo->feature_reader->Key(); + kaldi::Matrix<BaseFloat> kmat = repo->feature_reader->Value(); + + int n = kmat.NumRows(); + int m = kmat.NumCols(); + Status status; + assert(sizeof(BaseFloat) == sizeof(float)); + if(sizeof(BaseFloat) == sizeof(float)) + mat = nerv_matrix_host_float_create(n, m, &status); + else if(sizeof(BaseFloat) == sizeof(double)) + mat = nerv_matrix_host_double_create(n, m, &status); + NERV_LUA_CHECK_STATUS(L, status); + size_t stride = mat->stride; + if (debug) + fprintf(stderr, "[kaldi] feature: %s %d %d\n", repo->utt.c_str(), n, m); + + for (int i = 0; i < n; i++) + { + const BaseFloat *row = kmat.RowData(i); + BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride); + /* use memmove to copy the row, since KaldiLib uses compact storage */ + memmove(nerv_row, row, sizeof(BaseFloat) * m); + } + return mat; + } + + void kaldi_feature_repo_next(KaldiFeatureRepo *repo) { + repo->feature_reader->Next(); + } + + int kaldi_feature_repo_is_end(KaldiFeatureRepo *repo) { + return repo->feature_reader->Done(); + } + + void kaldi_feature_repo_destroy(KaldiFeatureRepo *repo) { + if (repo->feature_reader) + delete repo->feature_reader; + delete repo; + } + + struct KaldiLabelRepo { + kaldi::RandomAccessPosteriorReader *targets_reader; + }; + + KaldiLabelRepo *kaldi_label_repo_new(const char *targets_rspecifier, const char *fmt) { + KaldiLabelRepo *repo = new KaldiLabelRepo(); + repo->targets_reader = new kaldi::RandomAccessPosteriorReader(string(targets_rspecifier)); + return repo; + } + + Matrix *kaldi_label_repo_read_utterance(KaldiLabelRepo *repo, KaldiFeatureRepo *frepo, int frm_ext, int nframes, + lua_State *L, + int debug) { + Matrix *mat; + kaldi::Posterior targets = repo->targets_reader->Value(frepo->utt); + + int n = targets.size() < nframes ? targets.size() : nframes; + int m = (int)targets[0].size(); + + Status status; + assert(sizeof(BaseFloat) == sizeof(float)); + if(sizeof(BaseFloat) == sizeof(float)) + mat = nerv_matrix_host_float_create(n, m, &status); + else if(sizeof(BaseFloat) == sizeof(double)) + mat = nerv_matrix_host_double_create(n, m, &status); + NERV_LUA_CHECK_STATUS(L, status); + size_t stride = mat->stride; + + if (debug) + fprintf(stderr, "[kaldi] label: %s %d %d\n", frepo->utt.c_str(), n, m); + for (int i = 0; i < n; i++) + for(int j = 0; j < m; j++) + *((BaseFloat *)((char *)mat->data.f + (i * stride + j))) = (BaseFloat)targets[i][j].first; + return mat; + } + + void kaldi_label_repo_destroy(KaldiLabelRepo *repo) { + if(repo->targets_reader) + delete repo->targets_reader; + delete repo; + } +} diff --git a/kaldi_io/src/cwrapper_kaldi.h b/kaldi_io/src/cwrapper_kaldi.h new file mode 100644 index 0000000..e34cb5a --- /dev/null +++ b/kaldi_io/src/cwrapper_kaldi.h @@ -0,0 +1,29 @@ +#ifndef NERV_kaldi_KALDI_IO_CWRAPPER +#define NERV_kaldi_KALDI_IO_CWRAPPER +#include "nerv/matrix/matrix.h" +#include "nerv/common.h" +#ifdef __cplusplus +extern "C" { +#endif + + typedef struct KaldiFeatureRepo KaldiFeatureRepo; + + KaldiFeatureRepo *kaldi_feature_repo_new(const char *); + Matrix *kaldi_feature_repo_read_utterance(KaldiFeatureRepo *repo, lua_State *L, int debug); + void kaldi_feature_repo_next(KaldiFeatureRepo *repo); + int kaldi_feature_repo_is_end(KaldiFeatureRepo *repo); + void kaldi_feature_repo_destroy(KaldiFeatureRepo *repo); + + typedef struct KaldiLabelRepo KaldiLabelRepo; + + KaldiLabelRepo *kaldi_label_repo_new(const char *, const char *fmt); + + Matrix *kaldi_label_repo_read_utterance(KaldiLabelRepo *repo, KaldiFeatureRepo *, int, int, + lua_State *L, + int debug); + + void kaldi_label_repo_destroy(KaldiLabelRepo *repo); +#ifdef __cplusplus +} +#endif +#endif diff --git a/kaldi_io/src/init.c b/kaldi_io/src/init.c new file mode 100644 index 0000000..413452c --- /dev/null +++ b/kaldi_io/src/init.c @@ -0,0 +1,106 @@ +#include "nerv/common.h" +#include "cwrapper_kaldi.h" +#include <stdio.h> + +const char *nerv_kaldi_feat_repo_tname = "nerv.KaldiFeatureRepo"; +const char *nerv_kaldi_label_repo_tname = "nerv.KaldiLabelRepo"; +const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat"; + +static int feat_repo_new(lua_State *L) { + const char *feature_rsepcifier = luaL_checkstring(L, 1); + KaldiFeatureRepo *repo = kaldi_feature_repo_new(feature_rsepcifier); + luaT_pushudata(L, repo, nerv_kaldi_feat_repo_tname); + return 1; +} + +static int feat_repo_destroy(lua_State *L) { + KaldiFeatureRepo *repo = luaT_checkudata(L, 1, nerv_kaldi_feat_repo_tname); + kaldi_feature_repo_destroy(repo); + return 0; +} + +static int feat_repo_current_utterance(lua_State *L) { + KaldiFeatureRepo *repo = luaT_checkudata(L, 1, nerv_kaldi_feat_repo_tname); + int debug; + if (!lua_isboolean(L, 2)) + nerv_error(L, "debug flag should be a boolean"); + debug = lua_toboolean(L, 2); + Matrix *utter = kaldi_feature_repo_read_utterance(repo, L, debug); + luaT_pushudata(L, utter, nerv_matrix_host_float_tname); + return 1; +} + +static int feat_repo_next(lua_State *L) { + KaldiFeatureRepo *repo = luaT_checkudata(L, 1, nerv_kaldi_feat_repo_tname); + kaldi_feature_repo_next(repo); + return 0; +} + +static int feat_repo_is_end(lua_State *L) { + KaldiFeatureRepo *repo = luaT_checkudata(L, 1, nerv_kaldi_feat_repo_tname); + lua_pushboolean(L, kaldi_feature_repo_is_end(repo)); + return 1; +} + +static const luaL_Reg feat_repo_methods[] = { + {"cur_utter", feat_repo_current_utterance}, + {"next", feat_repo_next}, + {"is_end", feat_repo_is_end}, + {NULL, NULL} +}; + +static int label_repo_new(lua_State *L) { + const char *targets_rspecifier = luaL_checkstring(L, 1); + const char *fmt = luaL_checkstring(L, 2); + KaldiLabelRepo *repo = kaldi_label_repo_new(targets_rspecifier, fmt); + luaT_pushudata(L, repo, nerv_kaldi_label_repo_tname); + return 1; +} + +static int label_repo_read_utterance(lua_State *L) { + KaldiLabelRepo *repo = luaT_checkudata(L, 1, nerv_kaldi_label_repo_tname); + KaldiFeatureRepo *feat_repo = luaT_checkudata(L, 2, nerv_kaldi_feat_repo_tname); + int frm_ext, nframes, debug; + if (!lua_isnumber(L, 3)) + nerv_error(L, "frm_ext should be a number"); + frm_ext = lua_tonumber(L, 3); + if (!lua_isnumber(L, 4)) + nerv_error(L, "nframes should be a number"); + nframes = lua_tonumber(L, 4); + if (!lua_isboolean(L, 5)) + nerv_error(L, "debug flag should be a boolean"); + debug = lua_toboolean(L, 5); + Matrix *utter = kaldi_label_repo_read_utterance(repo, feat_repo, frm_ext, nframes, L, debug); + luaT_pushudata(L, utter, nerv_matrix_host_float_tname); + return 1; +} + +static int label_repo_destroy(lua_State *L) { + KaldiLabelRepo *repo = luaT_checkudata(L, 1, nerv_kaldi_label_repo_tname); + kaldi_label_repo_destroy(repo); + return 0; +} + +static const luaL_Reg label_repo_methods[] = { + {"get_utter", label_repo_read_utterance}, + {NULL, NULL} +}; + +static void feat_repo_init(lua_State *L) { + luaT_newmetatable(L, nerv_kaldi_feat_repo_tname, NULL, + feat_repo_new, feat_repo_destroy, NULL); + luaL_register(L, NULL, feat_repo_methods); + lua_pop(L, 1); +} + +static void label_repo_init(lua_State *L) { + luaT_newmetatable(L, nerv_kaldi_label_repo_tname, NULL, + label_repo_new, label_repo_destroy, NULL); + luaL_register(L, NULL, label_repo_methods); + lua_pop(L, 1); +} + +void kaldi_io_init(lua_State *L) { + feat_repo_init(L); + label_repo_init(L); +} diff --git a/kaldi_io/src/kaldi/base/io-funcs-inl.h b/kaldi_io/src/kaldi/base/io-funcs-inl.h new file mode 100644 index 0000000..e55458e --- /dev/null +++ b/kaldi_io/src/kaldi/base/io-funcs-inl.h @@ -0,0 +1,219 @@ +// base/io-funcs-inl.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin Qian; Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_IO_FUNCS_INL_H_ +#define KALDI_BASE_IO_FUNCS_INL_H_ 1 + +// Do not include this file directly. It is included by base/io-funcs.h + +#include <limits> +#include <vector> + +namespace kaldi { + +// Template that covers integers. +template<class T> void WriteBasicType(std::ostream &os, + bool binary, T t) { + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + char len_c = (std::numeric_limits<T>::is_signed ? 1 : -1) + * static_cast<char>(sizeof(t)); + os.put(len_c); + os.write(reinterpret_cast<const char *>(&t), sizeof(t)); + } else { + if (sizeof(t) == 1) + os << static_cast<int16>(t) << " "; + else + os << t << " "; + } + if (os.fail()) { + throw std::runtime_error("Write failure in WriteBasicType."); + } +} + +// Template that covers integers. +template<class T> inline void ReadBasicType(std::istream &is, + bool binary, T *t) { + KALDI_PARANOID_ASSERT(t != NULL); + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + int len_c_in = is.get(); + if (len_c_in == -1) + KALDI_ERR << "ReadBasicType: encountered end of stream."; + char len_c = static_cast<char>(len_c_in), len_c_expected + = (std::numeric_limits<T>::is_signed ? 1 : -1) + * static_cast<char>(sizeof(*t)); + + if (len_c != len_c_expected) { + KALDI_ERR << "ReadBasicType: did not get expected integer type, " + << static_cast<int>(len_c) + << " vs. " << static_cast<int>(len_c_expected) + << ". You can change this code to successfully" + << " read it later, if needed."; + // insert code here to read "wrong" type. Might have a switch statement. + } + is.read(reinterpret_cast<char *>(t), sizeof(*t)); + } else { + if (sizeof(*t) == 1) { + int16 i; + is >> i; + *t = i; + } else { + is >> *t; + } + } + if (is.fail()) { + KALDI_ERR << "Read failure in ReadBasicType, file position is " + << is.tellg() << ", next char is " << is.peek(); + } +} + + +template<class T> inline void WriteIntegerVector(std::ostream &os, bool binary, + const std::vector<T> &v) { + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + char sz = sizeof(T); // this is currently just a check. + os.write(&sz, 1); + int32 vecsz = static_cast<int32>(v.size()); + KALDI_ASSERT((size_t)vecsz == v.size()); + os.write(reinterpret_cast<const char *>(&vecsz), sizeof(vecsz)); + if (vecsz != 0) { + os.write(reinterpret_cast<const char *>(&(v[0])), sizeof(T)*vecsz); + } + } else { + // focus here is on prettiness of text form rather than + // efficiency of reading-in. + // reading-in is dominated by low-level operations anyway: + // for efficiency use binary. + os << "[ "; + typename std::vector<T>::const_iterator iter = v.begin(), end = v.end(); + for (; iter != end; ++iter) { + if (sizeof(T) == 1) + os << static_cast<int16>(*iter) << " "; + else + os << *iter << " "; + } + os << "]\n"; + } + if (os.fail()) { + throw std::runtime_error("Write failure in WriteIntegerType."); + } +} + + +template<class T> inline void ReadIntegerVector(std::istream &is, + bool binary, + std::vector<T> *v) { + KALDI_ASSERT_IS_INTEGER_TYPE(T); + KALDI_ASSERT(v != NULL); + if (binary) { + int sz = is.peek(); + if (sz == sizeof(T)) { + is.get(); + } else { // this is currently just a check. + KALDI_ERR << "ReadIntegerVector: expected to see type of size " + << sizeof(T) << ", saw instead " << sz << ", at file position " + << is.tellg(); + } + int32 vecsz; + is.read(reinterpret_cast<char *>(&vecsz), sizeof(vecsz)); + if (is.fail() || vecsz < 0) goto bad; + v->resize(vecsz); + if (vecsz > 0) { + is.read(reinterpret_cast<char *>(&((*v)[0])), sizeof(T)*vecsz); + } + } else { + std::vector<T> tmp_v; // use temporary so v doesn't use extra memory + // due to resizing. + is >> std::ws; + if (is.peek() != static_cast<int>('[')) { + KALDI_ERR << "ReadIntegerVector: expected to see [, saw " + << is.peek() << ", at file position " << is.tellg(); + } + is.get(); // consume the '['. + is >> std::ws; // consume whitespace. + while (is.peek() != static_cast<int>(']')) { + if (sizeof(T) == 1) { // read/write chars as numbers. + int16 next_t; + is >> next_t >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back((T)next_t); + } else { + T next_t; + is >> next_t >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back(next_t); + } + } + is.get(); // get the final ']'. + *v = tmp_v; // could use std::swap to use less temporary memory, but this + // uses less permanent memory. + } + if (!is.fail()) return; + bad: + KALDI_ERR << "ReadIntegerVector: read failure at file position " + << is.tellg(); +} + +// Initialize an opened stream for writing by writing an optional binary +// header and modifying the floating-point precision. +inline void InitKaldiOutputStream(std::ostream &os, bool binary) { + // This does not throw exceptions (does not check for errors). + if (binary) { + os.put('\0'); + os.put('B'); + } + // Note, in non-binary mode we may at some point want to mess with + // the precision a bit. + // 7 is a bit more than the precision of float.. + if (os.precision() < 7) + os.precision(7); +} + +/// Initialize an opened stream for reading by detecting the binary header and +// setting the "binary" value appropriately. +inline bool InitKaldiInputStream(std::istream &is, bool *binary) { + // Sets the 'binary' variable. + // Throws exception in the very unusual situation that stream + // starts with '\0' but not then 'B'. + + if (is.peek() == '\0') { // seems to be binary + is.get(); + if (is.peek() != 'B') { + return false; + } + is.get(); + *binary = true; + return true; + } else { + *binary = false; + return true; + } +} + +} // end namespace kaldi. + +#endif // KALDI_BASE_IO_FUNCS_INL_H_ diff --git a/kaldi_io/src/kaldi/base/io-funcs.h b/kaldi_io/src/kaldi/base/io-funcs.h new file mode 100644 index 0000000..2bc9da8 --- /dev/null +++ b/kaldi_io/src/kaldi/base/io-funcs.h @@ -0,0 +1,231 @@ +// base/io-funcs.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_IO_FUNCS_H_ +#define KALDI_BASE_IO_FUNCS_H_ + +// This header only contains some relatively low-level I/O functions. +// The full Kaldi I/O declarations are in ../util/kaldi-io.h +// and ../util/kaldi-table.h +// They were put in util/ in order to avoid making the Matrix library +// dependent on them. + +#include <cctype> +#include <vector> +#include <string> +#include "base/kaldi-common.h" + +namespace kaldi { + + + +/* + This comment describes the Kaldi approach to I/O. All objects can be written + and read in two modes: binary and text. In addition we want to make the I/O + work if we redefine the typedef "BaseFloat" between floats and doubles. + We also want to have control over whitespace in text mode without affecting + the meaning of the file, for pretty-printing purposes. + + Errors are handled by throwing an exception (std::runtime_error). + + For integer and floating-point types (and boolean values): + + WriteBasicType(std::ostream &, bool binary, const T&); + ReadBasicType(std::istream &, bool binary, T*); + + and we expect these functions to be defined in such a way that they work when + the type T changes between float and double, so you can read float into double + and vice versa]. Note that for efficiency and space-saving reasons, the Vector + and Matrix classes do not use these functions [but they preserve the type + interchangeability in their own way] + + For a class (or struct) C: + class C { + .. + Write(std::ostream &, bool binary, [possibly extra optional args for specific classes]) const; + Read(std::istream &, bool binary, [possibly extra optional args for specific classes]); + .. + } + NOTE: The only actual optional args we used are the "add" arguments in + Vector/Matrix classes, which specify whether we should sum the data already + in the class with the data being read. + + For types which are typedef's involving stl classes, I/O is as follows: + typedef std::vector<std::pair<A, B> > MyTypedefName; + + The user should define something like: + + WriteMyTypedefName(std::ostream &, bool binary, const MyTypedefName &t); + ReadMyTypedefName(std::ostream &, bool binary, MyTypedefName *t); + + The user would have to write these functions. + + For a type std::vector<T>: + + void WriteIntegerVector(std::ostream &os, bool binary, const std::vector<T> &v); + void ReadIntegerVector(std::istream &is, bool binary, std::vector<T> *v); + + For other types, e.g. vectors of pairs, the user should create a routine of the + type WriteMyTypedefName. This is to avoid introducing confusing templated functions; + we could easily create templated functions to handle most of these cases but they + would have to share the same name. + + It also often happens that the user needs to write/read special tokens as part + of a file. These might be class headers, or separators/identifiers in the class. + We provide special functions for manipulating these. These special tokens must + be nonempty and must not contain any whitespace. + + void WriteToken(std::ostream &os, bool binary, const char*); + void WriteToken(std::ostream &os, bool binary, const std::string & token); + int Peek(std::istream &is, bool binary); + void ReadToken(std::istream &is, bool binary, std::string *str); + void PeekToken(std::istream &is, bool binary, std::string *str); + + + WriteToken writes the token and one space (whether in binary or text mode). + + Peek returns the first character of the next token, by consuming whitespace + (in text mode) and then returning the peek() character. It returns -1 at EOF; + it doesn't throw. It's useful if a class can have various forms based on + typedefs and virtual classes, and wants to know which version to read. + + ReadToken allow the caller to obtain the next token. PeekToken works just + like ReadToken, but seeks back to the beginning of the token. A subsequent + call to ReadToken will read the same token again. This is useful when + different object types are written to the same file; using PeekToken one can + decide which of the objects to read. + + There is currently no special functionality for writing/reading strings (where the strings + contain data rather than "special tokens" that are whitespace-free and nonempty). This is + because Kaldi is structured in such a way that strings don't appear, except as OpenFst symbol + table entries (and these have their own format). + + + NOTE: you should not call ReadIntegerType and WriteIntegerType with types, + such as int and size_t, that are machine-independent -- at least not + if you want your file formats to port between machines. Use int32 and + int64 where necessary. There is no way to detect this using compile-time + assertions because C++ only keeps track of the internal representation of + the type. +*/ + +/// \addtogroup io_funcs_basic +/// @{ + + +/// WriteBasicType is the name of the write function for bool, integer types, +/// and floating-point types. They all throw on error. +template<class T> void WriteBasicType(std::ostream &os, bool binary, T t); + +/// ReadBasicType is the name of the read function for bool, integer types, +/// and floating-point types. They all throw on error. +template<class T> void ReadBasicType(std::istream &is, bool binary, T *t); + + +// Declare specialization for bool. +template<> +void WriteBasicType<bool>(std::ostream &os, bool binary, bool b); + +template <> +void ReadBasicType<bool>(std::istream &is, bool binary, bool *b); + +// Declare specializations for float and double. +template<> +void WriteBasicType<float>(std::ostream &os, bool binary, float f); + +template<> +void WriteBasicType<double>(std::ostream &os, bool binary, double f); + +template<> +void ReadBasicType<float>(std::istream &is, bool binary, float *f); + +template<> +void ReadBasicType<double>(std::istream &is, bool binary, double *f); + +// Define ReadBasicType that accepts an "add" parameter to add to +// the destination. Caution: if used in Read functions, be careful +// to initialize the parameters concerned to zero in the default +// constructor. +template<class T> +inline void ReadBasicType(std::istream &is, bool binary, T *t, bool add) { + if (!add) { + ReadBasicType(is, binary, t); + } else { + T tmp = T(0); + ReadBasicType(is, binary, &tmp); + *t += tmp; + } +} + +/// Function for writing STL vectors of integer types. +template<class T> inline void WriteIntegerVector(std::ostream &os, bool binary, + const std::vector<T> &v); + +/// Function for reading STL vector of integer types. +template<class T> inline void ReadIntegerVector(std::istream &is, bool binary, + std::vector<T> *v); + +/// The WriteToken functions are for writing nonempty sequences of non-space +/// characters. They are not for general strings. +void WriteToken(std::ostream &os, bool binary, const char *token); +void WriteToken(std::ostream &os, bool binary, const std::string & token); + +/// Peek consumes whitespace (if binary == false) and then returns the peek() +/// value of the stream. +int Peek(std::istream &is, bool binary); + +/// ReadToken gets the next token and puts it in str (exception on failure). +void ReadToken(std::istream &is, bool binary, std::string *token); + +/// PeekToken will return the first character of the next token, or -1 if end of +/// file. It's the same as Peek(), except if the first character is '<' it will +/// skip over it and will return the next character. It will unget the '<' so +/// the stream is where it was before you did PeekToken(). +int PeekToken(std::istream &is, bool binary); + +/// ExpectToken tries to read in the given token, and throws an exception +/// on failure. +void ExpectToken(std::istream &is, bool binary, const char *token); +void ExpectToken(std::istream &is, bool binary, const std::string & token); + +/// ExpectPretty attempts to read the text in "token", but only in non-binary +/// mode. Throws exception on failure. It expects an exact match except that +/// arbitrary whitespace matches arbitrary whitespace. +void ExpectPretty(std::istream &is, bool binary, const char *token); +void ExpectPretty(std::istream &is, bool binary, const std::string & token); + +/// @} end "addtogroup io_funcs_basic" + + +/// InitKaldiOutputStream initializes an opened stream for writing by writing an +/// optional binary header and modifying the floating-point precision; it will +/// typically not be called by users directly. +inline void InitKaldiOutputStream(std::ostream &os, bool binary); + +/// InitKaldiInputStream initializes an opened stream for reading by detecting +/// the binary header and setting the "binary" value appropriately; +/// It will typically not be called by users directly. +inline bool InitKaldiInputStream(std::istream &is, bool *binary); + +} // end namespace kaldi. + +#include "base/io-funcs-inl.h" + +#endif // KALDI_BASE_IO_FUNCS_H_ diff --git a/kaldi_io/src/kaldi/base/kaldi-common.h b/kaldi_io/src/kaldi/base/kaldi-common.h new file mode 100644 index 0000000..33f6f31 --- /dev/null +++ b/kaldi_io/src/kaldi/base/kaldi-common.h @@ -0,0 +1,41 @@ +// base/kaldi-common.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_COMMON_H_ +#define KALDI_BASE_KALDI_COMMON_H_ 1 + +#include <cstddef> +#include <cstdlib> +#include <cstring> // C string stuff like strcpy +#include <string> +#include <sstream> +#include <stdexcept> +#include <cassert> +#include <vector> +#include <iostream> +#include <fstream> + +#include "base/kaldi-utils.h" +#include "base/kaldi-error.h" +#include "base/kaldi-types.h" +#include "base/io-funcs.h" +#include "base/kaldi-math.h" + +#endif // KALDI_BASE_KALDI_COMMON_H_ + diff --git a/kaldi_io/src/kaldi/base/kaldi-error.h b/kaldi_io/src/kaldi/base/kaldi-error.h new file mode 100644 index 0000000..8334e42 --- /dev/null +++ b/kaldi_io/src/kaldi/base/kaldi-error.h @@ -0,0 +1,153 @@ +// base/kaldi-error.h + +// Copyright 2009-2011 Microsoft Corporation; Ondrej Glembek; Lukas Burget; +// Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_ERROR_H_ +#define KALDI_BASE_KALDI_ERROR_H_ 1 + +#include <stdexcept> +#include <string> +#include <cstring> +#include <sstream> +#include <cstdio> + +#ifdef _MSC_VER +#define NOEXCEPT(Predicate) +#elif __cplusplus > 199711L || defined(__GXX_EXPERIMENTAL_CXX0X__) +#define NOEXCEPT(Predicate) noexcept((Predicate)) +#else +#define NOEXCEPT(Predicate) +#endif + +#include "base/kaldi-types.h" +#include "base/kaldi-utils.h" + +/* Important that this file does not depend on any other kaldi headers. */ + + +namespace kaldi { + +/// \addtogroup error_group +/// @{ + +/// This is set by util/parse-options.{h, cc} if you set --verbose = ? option +extern int32 g_kaldi_verbose_level; + +/// This is set by util/parse-options.{h, cc} (from argv[0]) and used (if set) +/// in error reporting code to display the name of the program (this is because +/// in our scripts, we often mix together the stderr of many programs). it is +/// the base-name of the program (no directory), followed by ':' We don't use +/// std::string, due to the static initialization order fiasco. +extern const char *g_program_name; + +inline int32 GetVerboseLevel() { return g_kaldi_verbose_level; } + +/// This should be rarely used; command-line programs set the verbose level +/// automatically from ParseOptions. +inline void SetVerboseLevel(int32 i) { g_kaldi_verbose_level = i; } + +// Class KaldiLogMessage is invoked from the KALDI_WARN, KALDI_VLOG and +// KALDI_LOG macros. It prints the message to stderr. Note: we avoid +// using cerr, due to problems with thread safety. fprintf is guaranteed +// thread-safe. + +// class KaldiWarnMessage is invoked from the KALDI_WARN macro. +class KaldiWarnMessage { + public: + inline std::ostream &stream() { return ss; } + KaldiWarnMessage(const char *func, const char *file, int32 line); + ~KaldiWarnMessage() { fprintf(stderr, "%s\n", ss.str().c_str()); } + private: + std::ostringstream ss; +}; + +// class KaldiLogMessage is invoked from the KALDI_LOG macro. +class KaldiLogMessage { + public: + inline std::ostream &stream() { return ss; } + KaldiLogMessage(const char *func, const char *file, int32 line); + ~KaldiLogMessage() { fprintf(stderr, "%s\n", ss.str().c_str()); } + private: + std::ostringstream ss; +}; + +// Class KaldiVlogMessage is invoked from the KALDI_VLOG macro. +class KaldiVlogMessage { + public: + KaldiVlogMessage(const char *func, const char *file, int32 line, + int32 verbose_level); + inline std::ostream &stream() { return ss; } + ~KaldiVlogMessage() { fprintf(stderr, "%s\n", ss.str().c_str()); } + private: + std::ostringstream ss; +}; + + +// class KaldiErrorMessage is invoked from the KALDI_ERROR macro. +// The destructor throws an exception. +class KaldiErrorMessage { + public: + KaldiErrorMessage(const char *func, const char *file, int32 line); + inline std::ostream &stream() { return ss; } + ~KaldiErrorMessage() NOEXCEPT(false); // defined in kaldi-error.cc + private: + std::ostringstream ss; +}; + + + +#ifdef _MSC_VER +#define __func__ __FUNCTION__ +#endif + +#ifndef NDEBUG +#define KALDI_ASSERT(cond) \ + if (!(cond)) kaldi::KaldiAssertFailure_(__func__, __FILE__, __LINE__, #cond); +#else +#define KALDI_ASSERT(cond) +#endif +// also see KALDI_COMPILE_TIME_ASSERT, defined in base/kaldi-utils.h, +// and KALDI_ASSERT_IS_INTEGER_TYPE and KALDI_ASSERT_IS_FLOATING_TYPE, +// also defined there. +#ifdef KALDI_PARANOID // some more expensive asserts only checked if this defined +#define KALDI_PARANOID_ASSERT(cond) \ + if (!(cond)) kaldi::KaldiAssertFailure_(__func__, __FILE__, __LINE__, #cond); +#else +#define KALDI_PARANOID_ASSERT(cond) +#endif + +#define KALDI_ERR kaldi::KaldiErrorMessage(__func__, __FILE__, __LINE__).stream() +#define KALDI_WARN kaldi::KaldiWarnMessage(__func__, __FILE__, __LINE__).stream() +#define KALDI_LOG kaldi::KaldiLogMessage(__func__, __FILE__, __LINE__).stream() + +#define KALDI_VLOG(v) if (v <= kaldi::g_kaldi_verbose_level) \ + kaldi::KaldiVlogMessage(__func__, __FILE__, __LINE__, v).stream() + +inline bool IsKaldiError(const std::string &str) { + return(!strncmp(str.c_str(), "ERROR ", 6)); +} + +void KaldiAssertFailure_(const char *func, const char *file, + int32 line, const char *cond_str); + +/// @} end "addtogroup error_group" + +} // namespace kaldi + +#endif // KALDI_BASE_KALDI_ERROR_H_ diff --git a/kaldi_io/src/kaldi/base/kaldi-math.h b/kaldi_io/src/kaldi/base/kaldi-math.h new file mode 100644 index 0000000..4f60d00 --- /dev/null +++ b/kaldi_io/src/kaldi/base/kaldi-math.h @@ -0,0 +1,346 @@ +// base/kaldi-math.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian; +// Jan Silovsky; Saarland University +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_MATH_H_ +#define KALDI_BASE_KALDI_MATH_H_ 1 + +#ifdef _MSC_VER +#include <float.h> +#endif + +#include <cmath> +#include <limits> +#include <vector> + +#include "base/kaldi-types.h" +#include "base/kaldi-common.h" + + +#ifndef DBL_EPSILON +#define DBL_EPSILON 2.2204460492503131e-16 +#endif +#ifndef FLT_EPSILON +#define FLT_EPSILON 1.19209290e-7f +#endif + +#ifndef M_PI +# define M_PI 3.1415926535897932384626433832795 +#endif + +#ifndef M_SQRT2 +# define M_SQRT2 1.4142135623730950488016887 +#endif + + +#ifndef M_2PI +# define M_2PI 6.283185307179586476925286766559005 +#endif + +#ifndef M_SQRT1_2 +# define M_SQRT1_2 0.7071067811865475244008443621048490 +#endif + +#ifndef M_LOG_2PI +#define M_LOG_2PI 1.8378770664093454835606594728112 +#endif + +#ifndef M_LN2 +#define M_LN2 0.693147180559945309417232121458 +#endif + +#ifdef _MSC_VER +# define KALDI_ISNAN _isnan +# define KALDI_ISINF(x) (!_isnan(x) && _isnan(x-x)) +# define KALDI_ISFINITE _finite +#else +# define KALDI_ISNAN std::isnan +# define KALDI_ISINF std::isinf +# define KALDI_ISFINITE(x) std::isfinite(x) +#endif +#if !defined(KALDI_SQR) +# define KALDI_SQR(x) ((x) * (x)) +#endif + +namespace kaldi { + +// -infinity +const float kLogZeroFloat = -std::numeric_limits<float>::infinity(); +const double kLogZeroDouble = -std::numeric_limits<double>::infinity(); +const BaseFloat kLogZeroBaseFloat = -std::numeric_limits<BaseFloat>::infinity(); + +// Returns a random integer between 0 and RAND_MAX, inclusive +int Rand(struct RandomState* state=NULL); + +// State for thread-safe random number generator +struct RandomState { + RandomState(); + unsigned seed; +}; + +// Returns a random integer between min and max inclusive. +int32 RandInt(int32 min, int32 max, struct RandomState* state=NULL); + +bool WithProb(BaseFloat prob, struct RandomState* state=NULL); // Returns true with probability "prob", +// with 0 <= prob <= 1 [we check this]. +// Internally calls Rand(). This function is carefully implemented so +// that it should work even if prob is very small. + +/// Returns a random number strictly between 0 and 1. +inline float RandUniform(struct RandomState* state = NULL) { + return static_cast<float>((Rand(state) + 1.0) / (RAND_MAX+2.0)); +} + +inline float RandGauss(struct RandomState* state = NULL) { + return static_cast<float>(sqrtf (-2 * logf(RandUniform(state))) + * cosf(2*M_PI*RandUniform(state))); +} + +// Returns poisson-distributed random number. Uses Knuth's algorithm. +// Take care: this takes time proportinal +// to lambda. Faster algorithms exist but are more complex. +int32 RandPoisson(float lambda, struct RandomState* state=NULL); + +// Returns a pair of gaussian random numbers. Uses Box-Muller transform +void RandGauss2(float *a, float *b, RandomState *state = NULL); +void RandGauss2(double *a, double *b, RandomState *state = NULL); + +// Also see Vector<float,double>::RandCategorical(). + +// This is a randomized pruning mechanism that preserves expectations, +// that we typically use to prune posteriors. +template<class Float> +inline Float RandPrune(Float post, BaseFloat prune_thresh, struct RandomState* state=NULL) { + KALDI_ASSERT(prune_thresh >= 0.0); + if (post == 0.0 || std::abs(post) >= prune_thresh) + return post; + return (post >= 0 ? 1.0 : -1.0) * + (RandUniform(state) <= fabs(post)/prune_thresh ? prune_thresh : 0.0); +} + +static const double kMinLogDiffDouble = std::log(DBL_EPSILON); // negative! +static const float kMinLogDiffFloat = std::log(FLT_EPSILON); // negative! + +inline double LogAdd(double x, double y) { + double diff; + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffDouble) { + double res; +#ifdef _MSC_VER + res = x + log(1.0 + exp(diff)); +#else + res = x + log1p(exp(diff)); +#endif + return res; + } else { + return x; // return the larger one. + } +} + + +inline float LogAdd(float x, float y) { + float diff; + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffFloat) { + float res; +#ifdef _MSC_VER + res = x + logf(1.0 + expf(diff)); +#else + res = x + log1pf(expf(diff)); +#endif + return res; + } else { + return x; // return the larger one. + } +} + + +// returns exp(x) - exp(y). +inline double LogSub(double x, double y) { + if (y >= x) { // Throws exception if y>=x. + if (y == x) + return kLogZeroDouble; + else + KALDI_ERR << "Cannot subtract a larger from a smaller number."; + } + + double diff = y - x; // Will be negative. + double res = x + log(1.0 - exp(diff)); + + // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision + if (KALDI_ISNAN(res)) + return kLogZeroDouble; + return res; +} + + +// returns exp(x) - exp(y). +inline float LogSub(float x, float y) { + if (y >= x) { // Throws exception if y>=x. + if (y == x) + return kLogZeroDouble; + else + KALDI_ERR << "Cannot subtract a larger from a smaller number."; + } + + float diff = y - x; // Will be negative. + float res = x + logf(1.0 - expf(diff)); + + // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision + if (KALDI_ISNAN(res)) + return kLogZeroFloat; + return res; +} + +/// return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)). +static inline bool ApproxEqual(float a, float b, + float relative_tolerance = 0.001) { + // a==b handles infinities. + if (a==b) return true; + float diff = std::abs(a-b); + if (diff == std::numeric_limits<float>::infinity() + || diff != diff) return false; // diff is +inf or nan. + return (diff <= relative_tolerance*(std::abs(a)+std::abs(b))); +} + +/// assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b)) +static inline void AssertEqual(float a, float b, + float relative_tolerance = 0.001) { + // a==b handles infinities. + KALDI_ASSERT(ApproxEqual(a, b, relative_tolerance)); +} + + +// RoundUpToNearestPowerOfTwo does the obvious thing. It crashes if n <= 0. +int32 RoundUpToNearestPowerOfTwo(int32 n); + +template<class I> I Gcd(I m, I n) { + if (m == 0 || n == 0) { + if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. + KALDI_ERR << "Undefined GCD since m = 0, n = 0."; + } + return (m == 0 ? (n > 0 ? n : -n) : ( m > 0 ? m : -m)); + // return absolute value of whichever is nonzero + } + // could use compile-time assertion + // but involves messing with complex template stuff. + KALDI_ASSERT(std::numeric_limits<I>::is_integer); + while (1) { + m %= n; + if (m == 0) return (n > 0 ? n : -n); + n %= m; + if (n == 0) return (m > 0 ? m : -m); + } +} + +/// Returns the least common multiple of two integers. Will +/// crash unless the inputs are positive. +template<class I> I Lcm(I m, I n) { + KALDI_ASSERT(m > 0 && n > 0); + I gcd = Gcd(m, n); + return gcd * (m/gcd) * (n/gcd); +} + + +template<class I> void Factorize(I m, std::vector<I> *factors) { + // Splits a number into its prime factors, in sorted order from + // least to greatest, with duplication. A very inefficient + // algorithm, which is mainly intended for use in the + // mixed-radix FFT computation (where we assume most factors + // are small). + KALDI_ASSERT(factors != NULL); + KALDI_ASSERT(m >= 1); // Doesn't work for zero or negative numbers. + factors->clear(); + I small_factors[10] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29 }; + + // First try small factors. + for (I i = 0; i < 10; i++) { + if (m == 1) return; // We're done. + while (m % small_factors[i] == 0) { + m /= small_factors[i]; + factors->push_back(small_factors[i]); + } + } + // Next try all odd numbers starting from 31. + for (I j = 31;; j += 2) { + if (m == 1) return; + while (m % j == 0) { + m /= j; + factors->push_back(j); + } + } +} + +inline double Hypot(double x, double y) { return hypot(x, y); } + +inline float Hypot(float x, float y) { return hypotf(x, y); } + +#if !defined(_MSC_VER) || (_MSC_VER >= 1800) +inline double Log1p(double x) { return log1p(x); } + +inline float Log1p(float x) { return log1pf(x); } +#else +inline double Log1p(double x) { + const double cutoff = 1.0e-08; + if (x < cutoff) + return x - 2 * x * x; + else + return log(1.0 + x); +} + +inline float Log1p(float x) { + const float cutoff = 1.0e-07; + if (x < cutoff) + return x - 2 * x * x; + else + return log(1.0 + x); +} +#endif + +inline double Exp(double x) { return exp(x); } + +#ifndef KALDI_NO_EXPF +inline float Exp(float x) { return expf(x); } +#else +inline float Exp(float x) { return exp(x); } +#endif + +inline double Log(double x) { return log(x); } + +inline float Log(float x) { return logf(x); } + + +} // namespace kaldi + + +#endif // KALDI_BASE_KALDI_MATH_H_ diff --git a/kaldi_io/src/kaldi/base/kaldi-types.h b/kaldi_io/src/kaldi/base/kaldi-types.h new file mode 100644 index 0000000..04354b2 --- /dev/null +++ b/kaldi_io/src/kaldi/base/kaldi-types.h @@ -0,0 +1,64 @@ +// base/kaldi-types.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_TYPES_H_ +#define KALDI_BASE_KALDI_TYPES_H_ 1 + +namespace kaldi { +// TYPEDEFS .................................................................. +#if (KALDI_DOUBLEPRECISION != 0) +typedef double BaseFloat; +#else +typedef float BaseFloat; +#endif +} + +#ifdef _MSC_VER +namespace kaldi { +typedef unsigned __int16 uint16; +typedef unsigned __int32 uint32; +typedef __int16 int16; +typedef __int32 int32; +typedef __int64 int64; +typedef unsigned __int64 uint64; +typedef float float32; +typedef double double64; +} +#include <basetsd.h> +#define ssize_t SSIZE_T + +#else +// we can do this a different way if some platform +// we find in the future lacks stdint.h +#include <stdint.h> + +namespace kaldi { +typedef uint16_t uint16; +typedef uint32_t uint32; +typedef uint64_t uint64; +typedef int16_t int16; +typedef int32_t int32; +typedef int64_t int64; +typedef float float32; +typedef double double64; +} // end namespace kaldi +#endif + +#endif // KALDI_BASE_KALDI_TYPES_H_ diff --git a/kaldi_io/src/kaldi/base/kaldi-utils.h b/kaldi_io/src/kaldi/base/kaldi-utils.h new file mode 100644 index 0000000..1b2c893 --- /dev/null +++ b/kaldi_io/src/kaldi/base/kaldi-utils.h @@ -0,0 +1,157 @@ +// base/kaldi-utils.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; +// Saarland University; Karel Vesely; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_UTILS_H_ +#define KALDI_BASE_KALDI_UTILS_H_ 1 + +#include <limits> +#include <string> + +#if defined(_MSC_VER) +# define WIN32_LEAN_AND_MEAN +# define NOMINMAX +# include <windows.h> +#endif + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4056 4305 4800 4267 4996 4756 4661) +#define __restrict__ +#endif + +#ifdef HAVE_POSIX_MEMALIGN +# define KALDI_MEMALIGN(align, size, pp_orig) \ + (!posix_memalign(pp_orig, align, size) ? *(pp_orig) : NULL) +# define KALDI_MEMALIGN_FREE(x) free(x) +#elif defined(HAVE_MEMALIGN) + /* Some systems have memalign() but no declaration for it */ + void * memalign(size_t align, size_t size); +# define KALDI_MEMALIGN(align, size, pp_orig) \ + (*(pp_orig) = memalign(align, size)) +# define KALDI_MEMALIGN_FREE(x) free(x) +#elif defined(_MSC_VER) +# define KALDI_MEMALIGN(align, size, pp_orig) \ + (*(pp_orig) = _aligned_malloc(size, align)) +# define KALDI_MEMALIGN_FREE(x) _aligned_free(x) +#else +#error Manual memory alignment is no longer supported +#endif + +#ifdef __ICC +#pragma warning(disable: 383) // ICPC remark we don't want. +#pragma warning(disable: 810) // ICPC remark we don't want. +#pragma warning(disable: 981) // ICPC remark we don't want. +#pragma warning(disable: 1418) // ICPC remark we don't want. +#pragma warning(disable: 444) // ICPC remark we don't want. +#pragma warning(disable: 869) // ICPC remark we don't want. +#pragma warning(disable: 1287) // ICPC remark we don't want. +#pragma warning(disable: 279) // ICPC remark we don't want. +#pragma warning(disable: 981) // ICPC remark we don't want. +#endif + + +namespace kaldi { + + +// CharToString prints the character in a human-readable form, for debugging. +std::string CharToString(const char &c); + + +inline int MachineIsLittleEndian() { + int check = 1; + return (*reinterpret_cast<char*>(&check) != 0); +} + +// This function kaldi::Sleep() provides a portable way to sleep for a possibly fractional +// number of seconds. On Windows it's only accurate to microseconds. +void Sleep(float seconds); + +} + +#define KALDI_SWAP8(a) { \ + int t = ((char*)&a)[0]; ((char*)&a)[0]=((char*)&a)[7]; ((char*)&a)[7]=t;\ + t = ((char*)&a)[1]; ((char*)&a)[1]=((char*)&a)[6]; ((char*)&a)[6]=t;\ + t = ((char*)&a)[2]; ((char*)&a)[2]=((char*)&a)[5]; ((char*)&a)[5]=t;\ + t = ((char*)&a)[3]; ((char*)&a)[3]=((char*)&a)[4]; ((char*)&a)[4]=t;} +#define KALDI_SWAP4(a) { \ + int t = ((char*)&a)[0]; ((char*)&a)[0]=((char*)&a)[3]; ((char*)&a)[3]=t;\ + t = ((char*)&a)[1]; ((char*)&a)[1]=((char*)&a)[2]; ((char*)&a)[2]=t;} +#define KALDI_SWAP2(a) { \ + int t = ((char*)&a)[0]; ((char*)&a)[0]=((char*)&a)[1]; ((char*)&a)[1]=t;} + + +// Makes copy constructor and operator= private. Same as in compat.h of OpenFst +// toolkit. If using VS, for which this results in compilation errors, we +// do it differently. + +#if defined(_MSC_VER) +#define KALDI_DISALLOW_COPY_AND_ASSIGN(type) \ + void operator = (const type&) +#else +#define KALDI_DISALLOW_COPY_AND_ASSIGN(type) \ + type(const type&); \ + void operator = (const type&) +#endif + +template<bool B> class KaldiCompileTimeAssert { }; +template<> class KaldiCompileTimeAssert<true> { + public: + static inline void Check() { } +}; + +#define KALDI_COMPILE_TIME_ASSERT(b) KaldiCompileTimeAssert<(b)>::Check() + +#define KALDI_ASSERT_IS_INTEGER_TYPE(I) \ + KaldiCompileTimeAssert<std::numeric_limits<I>::is_specialized \ + && std::numeric_limits<I>::is_integer>::Check() + +#define KALDI_ASSERT_IS_FLOATING_TYPE(F) \ + KaldiCompileTimeAssert<std::numeric_limits<F>::is_specialized \ + && !std::numeric_limits<F>::is_integer>::Check() + +#ifdef _MSC_VER +#include <stdio.h> +#define unlink _unlink +#else +#include <unistd.h> +#endif + + +#ifdef _MSC_VER +#define KALDI_STRCASECMP _stricmp +#else +#define KALDI_STRCASECMP strcasecmp +#endif +#ifdef _MSC_VER +# define KALDI_STRTOLL(cur_cstr, end_cstr) _strtoi64(cur_cstr, end_cstr, 10); +#else +# define KALDI_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10); +#endif + +#define KALDI_STRTOD(cur_cstr, end_cstr) strtod(cur_cstr, end_cstr) + +#ifdef _MSC_VER +# define KALDI_STRTOF(cur_cstr, end_cstr) \ + static_cast<float>(strtod(cur_cstr, end_cstr)); +#else +# define KALDI_STRTOF(cur_cstr, end_cstr) strtof(cur_cstr, end_cstr); +#endif + +#endif // KALDI_BASE_KALDI_UTILS_H_ + diff --git a/kaldi_io/src/kaldi/base/timer.h b/kaldi_io/src/kaldi/base/timer.h new file mode 100644 index 0000000..d93a461 --- /dev/null +++ b/kaldi_io/src/kaldi/base/timer.h @@ -0,0 +1,83 @@ +// base/timer.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_BASE_TIMER_H_ +#define KALDI_BASE_TIMER_H_ + +#include "base/kaldi-utils.h" +// Note: Sleep(float secs) is included in base/kaldi-utils.h. + + +#if defined(_MSC_VER) || defined(MINGW) + +namespace kaldi +{ + +class Timer { + public: + Timer() { Reset(); } + void Reset() { + QueryPerformanceCounter(&time_start_); + } + double Elapsed() { + LARGE_INTEGER time_end; + LARGE_INTEGER freq; + QueryPerformanceCounter(&time_end); + if (QueryPerformanceFrequency(&freq) == 0) return 0.0; // Hardware does not support this. + return ((double)time_end.QuadPart - (double)time_start_.QuadPart) / + ((double)freq.QuadPart); + } + private: + LARGE_INTEGER time_start_; +}; +} + +#else + +# include <sys/time.h> +# include <unistd.h> +namespace kaldi +{ +class Timer +{ + public: + Timer() { Reset(); } + + void Reset() { gettimeofday(&this->time_start_, &time_zone_); } + + /// Returns time in seconds. + double Elapsed() { + struct timeval time_end; + gettimeofday(&time_end, &time_zone_); + double t1, t2; + t1 = (double)time_start_.tv_sec + + (double)time_start_.tv_usec/(1000*1000); + t2 = (double)time_end.tv_sec + (double)time_end.tv_usec/(1000*1000); + return t2-t1; + } + + private: + struct timeval time_start_; + struct timezone time_zone_; +}; +} + +#endif + + +#endif diff --git a/kaldi_io/src/kaldi/hmm/hmm-topology.h b/kaldi_io/src/kaldi/hmm/hmm-topology.h new file mode 100644 index 0000000..53ca427 --- /dev/null +++ b/kaldi_io/src/kaldi/hmm/hmm-topology.h @@ -0,0 +1,172 @@ +// hmm/hmm-topology.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_HMM_TOPOLOGY_H_ +#define KALDI_HMM_HMM_TOPOLOGY_H_ + +#include "base/kaldi-common.h" +#include "tree/context-dep.h" +#include "util/const-integer-set.h" + + +namespace kaldi { + + +/// \addtogroup hmm_group +/// @{ + +/* + // The following would be the text form for the "normal" HMM topology. + // Note that the first state is the start state, and the final state, + // which must have no output transitions and must be nonemitting, has + // an exit probability of one (no other state can have nonzero exit + // probability; you can treat the transition probability to the final + // state as an exit probability). + // Note also that it's valid to omit the "<PdfClass>" entry of the <State>, which + // will mean we won't have a pdf on that state [non-emitting state]. This is equivalent + // to setting the <PdfClass> to -1. We do this normally just for the final state. + // The Topology object can have multiple <TopologyEntry> blocks. + // This is useful if there are multiple types of topology in the system. + + <Topology> + <TopologyEntry> + <ForPhones> 1 2 3 4 5 6 7 8 </ForPhones> + <State> 0 <PdfClass> 0 + <Transition> 0 0.5 + <Transition> 1 0.5 + </State> + <State> 1 <PdfClass> 1 + <Transition> 1 0.5 + <Transition> 2 0.5 + </State> + <State> 2 <PdfClass> 2 + <Transition> 2 0.5 + <Transition> 3 0.5 + <Final> 0.5 + </State> + <State> 3 + </State> + </TopologyEntry> + </Topology> +*/ + +// kNoPdf is used where pdf_class or pdf would be used, to indicate, +// none is there. Mainly useful in skippable models, but also used +// for end states. +// A caveat with nonemitting states is that their out-transitions +// are not trainable, due to technical issues with the way +// we decided to accumulate the stats. Any transitions arising from (*) +// HMM states with "kNoPdf" as the label are second-class transitions, +// They do not have "transition-states" or "transition-ids" associated +// with them. They are used to create the FST version of the +// HMMs, where they lead to epsilon arcs. +// (*) "arising from" is a bit of a technical term here, due to the way +// (if reorder == true), we put the transition-id associated with the +// outward arcs of the state, on the input transition to the state. + +/// A constant used in the HmmTopology class as the \ref pdf_class "pdf-class" +/// kNoPdf, which is used when a HMM-state is nonemitting (has no associated +/// PDF). + +static const int32 kNoPdf = -1; + +/// A class for storing topology information for phones. See \ref hmm for context. +/// This object is sometimes accessed in a file by itself, but more often +/// as a class member of the Transition class (this is for convenience to reduce +/// the number of files programs have to access). + +class HmmTopology { + public: + /// A structure defined inside HmmTopology to represent a HMM state. + struct HmmState { + /// The \ref pdf_class pdf-class, typically 0, 1 or 2 (the same as the HMM-state index), + /// but may be different to enable us to hardwire sharing of state, and may be + /// equal to \ref kNoPdf == -1 in order to specify nonemitting states (unusual). + int32 pdf_class; + + /// A list of transitions. The first member of each pair is the index of + /// the next HmmState, and the second is the default transition probability + /// (before training). + std::vector<std::pair<int32, BaseFloat> > transitions; + + explicit HmmState(int32 p): pdf_class(p) { } + + bool operator == (const HmmState &other) const { + return (pdf_class == other.pdf_class && transitions == other.transitions); + } + + HmmState(): pdf_class(-1) { } + }; + + /// TopologyEntry is a typedef that represents the topology of + /// a single (prototype) state. + typedef std::vector<HmmState> TopologyEntry; + + void Read(std::istream &is, bool binary); + void Write(std::ostream &os, bool binary) const; + + // Checks that the object is valid, and throw exception otherwise. + void Check(); + + + /// Returns the topology entry (i.e. vector of HmmState) for this phone; + /// will throw exception if phone not covered by the topology. + const TopologyEntry &TopologyForPhone(int32 phone) const; + + /// Returns the number of \ref pdf_class "pdf-classes" for this phone; + /// throws exception if phone not covered by this topology. + int32 NumPdfClasses(int32 phone) const; + + /// Returns a reference to a sorted, unique list of phones covered by + /// the topology (these phones will be positive integers, and usually + /// contiguous and starting from one but the toolkit doesn't assume + /// they are contiguous). + const std::vector<int32> &GetPhones() const { return phones_; }; + + /// Outputs a vector of int32, indexed by phone, that gives the + /// number of \ref pdf_class pdf-classes for the phones; this is + /// used by tree-building code such as BuildTree(). + void GetPhoneToNumPdfClasses(std::vector<int32> *phone2num_pdf_classes) const; + + HmmTopology() {} + + bool operator == (const HmmTopology &other) const { + return phones_ == other.phones_ && phone2idx_ == other.phone2idx_ + && entries_ == other.entries_; + } + // Allow default assignment operator and copy constructor. + private: + std::vector<int32> phones_; // list of all phones we have topology for. Sorted, uniq. no epsilon (zero) phone. + std::vector<int32> phone2idx_; // map from phones to indexes into the entries vector (or -1 for not present). + std::vector<TopologyEntry> entries_; +}; + + +/// This function returns a HmmTopology object giving a normal 3-state topology, +/// covering all phones in the list "phones". This is mainly of use in testing +/// code. +HmmTopology GetDefaultTopology(const std::vector<int32> &phones); + +/// @} end "addtogroup hmm_group" + + +} // end namespace kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/hmm/hmm-utils.h b/kaldi_io/src/kaldi/hmm/hmm-utils.h new file mode 100644 index 0000000..240f706 --- /dev/null +++ b/kaldi_io/src/kaldi/hmm/hmm-utils.h @@ -0,0 +1,295 @@ +// hmm/hmm-utils.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_HMM_UTILS_H_ +#define KALDI_HMM_HMM_UTILS_H_ + +#include "hmm/hmm-topology.h" +#include "hmm/transition-model.h" +#include "lat/kaldi-lattice.h" + +namespace kaldi { + + +/// \defgroup hmm_group_graph Classes and functions for creating FSTs from HMMs +/// \ingroup hmm_group +/// @{ + +/// Configuration class for the GetHTransducer() function; see +/// \ref hmm_graph_config for context. +struct HTransducerConfig { + /// Transition log-prob scale, see \ref hmm_scale. + /// Note this doesn't apply to self-loops; GetHTransducer() does + /// not include self-loops. + BaseFloat transition_scale; + + /// if true, we are constructing time-reversed FST: phone-seqs in ilabel_info + /// are backwards, and we want to output a backwards version of the HMM + /// corresponding to each phone. If reverse == true, + bool reverse; + + /// This variable is only looked at if reverse == true. If reverse == true + /// and push_weights == true, then we push the weights in the reversed FSTs we create for each + /// phone HMM. This is only safe if the HMMs are probabilistic (i.e. not discriminatively + bool push_weights; + + /// delta used if we do push_weights [only relevant if reverse == true + /// and push_weights == true]. + BaseFloat push_delta; + + HTransducerConfig(): + transition_scale(1.0), + reverse(false), + push_weights(true), + push_delta(0.001) + { } + + // Note-- this Register registers the easy-to-register options + // but not the "sym_type" which is an enum and should be handled + // separately in main(). + void Register (OptionsItf *po) { + po->Register("transition-scale", &transition_scale, + "Scale of transition probs (relative to LM)"); + po->Register("reverse", &reverse, + "Set true to build time-reversed FST."); + po->Register("push-weights", &push_weights, + "Push weights (only applicable if reverse == true)"); + po->Register("push-delta", &push_delta, + "Delta used in pushing weights (only applicable if " + "reverse && push-weights"); + } +}; + + +struct HmmCacheHash { + int operator () (const std::pair<int32, std::vector<int32> >&p) const { + VectorHasher<int32> v; + int32 prime = 103049; + return prime*p.first + v(p.second); + } +}; + +/// HmmCacheType is a map from (central-phone, sequence of pdf-ids) to FST, used +/// as cache in GetHmmAsFst, as an optimization. +typedef unordered_map<std::pair<int32, std::vector<int32> >, + fst::VectorFst<fst::StdArc>*, + HmmCacheHash> HmmCacheType; + + +/// Called by GetHTransducer() and probably will not need to be called directly; +/// it creates the FST corresponding to the phone. Does not include self-loops; +/// you have to call AddSelfLoops() for that. Result owned by caller. +/// Returns an acceptor (i.e. ilabels, olabels identical) with transition-ids +/// as the symbols. +/// For documentation in context, see \ref hmm_graph_get_hmm_as_fst +/// @param context_window A vector representing the phonetic context; see +/// \ref tree_window "here" for explanation. +/// @param ctx_dep The object that contains the phonetic decision-tree +/// @param trans_model The transition-model object, which provides +/// the mappings to transition-ids and also the transition +/// probabilities. +/// @param config Configuration object, see \ref HTransducerConfig. +/// @param cache Object used as a lookaside buffer to save computation; +/// if it finds that the object it needs is already there, it will +/// just return a pointer value from "cache"-- not that this means +/// you have to be careful not to delete things twice. + +fst::VectorFst<fst::StdArc> *GetHmmAsFst( + std::vector<int32> context_window, + const ContextDependencyInterface &ctx_dep, + const TransitionModel &trans_model, + const HTransducerConfig &config, + HmmCacheType *cache = NULL); + +/// Included mainly as a form of documentation, not used in any other code +/// currently. Creates the FST with self-loops, and with fewer options. +fst::VectorFst<fst::StdArc>* +GetHmmAsFstSimple(std::vector<int32> context_window, + const ContextDependencyInterface &ctx_dep, + const TransitionModel &trans_model, + BaseFloat prob_scale); + + +/** + * Returns the H tranducer; result owned by caller. + * See \ref hmm_graph_get_h_transducer. The H transducer has on the + * input transition-ids, and also possibly some disambiguation symbols, which + * will be put in disambig_syms. The output side contains the identifiers that + * are indexes into "ilabel_info" (these represent phones-in-context or + * disambiguation symbols). The ilabel_info vector allows GetHTransducer to map + * from symbols to phones-in-context (i.e. phonetic context windows). Any + * singleton symbols in the ilabel_info vector which are not phones, will be + * treated as disambiguation symbols. [Not all recipes use these]. The output + * "disambig_syms_left" will be set to a list of the disambiguation symbols on + * the input of the transducer (i.e. same symbol type as whatever is on the + * input of the transducer + */ +fst::VectorFst<fst::StdArc>* +GetHTransducer (const std::vector<std::vector<int32> > &ilabel_info, + const ContextDependencyInterface &ctx_dep, + const TransitionModel &trans_model, + const HTransducerConfig &config, + std::vector<int32> *disambig_syms_left); + +/** + * GetIlabelMapping produces a mapping that's similar to HTK's logical-to-physical + * model mapping (i.e. the xwrd.clustered.mlist files). It groups together + * "logical HMMs" (i.e. in our world, phonetic context windows) that share the + * same sequence of transition-ids. This can be used in an + * optional graph-creation step that produces a remapped form of CLG that can be + * more productively determinized and minimized. This is used in the command-line program + * make-ilabel-transducer.cc. + * @param ilabel_info_old [in] The original \ref tree_ilabel "ilabel_info" vector + * @param ctx_dep [in] The tree + * @param trans_model [in] The transition-model object + * @param old2new_map [out] The output; this vector, which is of size equal to the + * number of new labels, is a mapping to the old labels such that we could + * create a vector ilabel_info_new such that + * ilabel_info_new[i] == ilabel_info_old[old2new_map[i]] + */ +void GetIlabelMapping (const std::vector<std::vector<int32> > &ilabel_info_old, + const ContextDependencyInterface &ctx_dep, + const TransitionModel &trans_model, + std::vector<int32> *old2new_map); + + + +/** + * For context, see \ref hmm_graph_add_self_loops. Expands an FST that has been + * built without self-loops, and adds the self-loops (it also needs to modify + * the probability of the non-self-loop ones, as the graph without self-loops + * was created in such a way that it was stochastic). Note that the + * disambig_syms will be empty in some recipes (e.g. if you already removed + * the disambiguation symbols). + * @param trans_model [in] Transition model + * @param disambig_syms [in] Sorted, uniq list of disambiguation symbols, required + * if the graph contains disambiguation symbols but only needed for sanity checks. + * @param self_loop_scale [in] Transition-probability scale for self-loops; c.f. + * \ref hmm_scale + * @param reorder [in] If true, reorders the transitions (see \ref hmm_reorder). + * @param fst [in, out] The FST to be modified. + */ +void AddSelfLoops(const TransitionModel &trans_model, + const std::vector<int32> &disambig_syms, // used as a check only. + BaseFloat self_loop_scale, + bool reorder, // true->dan-style, false->lukas-style. + fst::VectorFst<fst::StdArc> *fst); + +/** + * Adds transition-probs, with the supplied + * scales (see \ref hmm_scale), to the graph. + * Useful if you want to create a graph without transition probs, then possibly + * train the model (including the transition probs) but keep the graph fixed, + * and add back in the transition probs. It assumes the fst has transition-ids + * on it. It is not an error if the FST has no states (nothing will be done). + * @param trans_model [in] The transition model + * @param disambig_syms [in] A list of disambiguation symbols, required if the + * graph has disambiguation symbols on its input but only + * used for checks. + * @param transition_scale [in] A scale on transition-probabilities apart from + * those involving self-loops; see \ref hmm_scale. + * @param self_loop_scale [in] A scale on self-loop transition probabilities; + * see \ref hmm_scale. + * @param fst [in, out] The FST to be modified. + */ +void AddTransitionProbs(const TransitionModel &trans_model, + const std::vector<int32> &disambig_syms, + BaseFloat transition_scale, + BaseFloat self_loop_scale, + fst::VectorFst<fst::StdArc> *fst); + +/** + This is as AddSelfLoops(), but operates on a Lattice, where + it affects the graph part of the weight (the first element + of the pair). */ +void AddTransitionProbs(const TransitionModel &trans_model, + BaseFloat transition_scale, + BaseFloat self_loop_scale, + Lattice *lat); + + +/// Returns a transducer from pdfs plus one (input) to transition-ids (output). +/// Currenly of use only for testing. +fst::VectorFst<fst::StdArc>* +GetPdfToTransitionIdTransducer(const TransitionModel &trans_model); + +/// Converts all transition-ids in the FST to pdfs plus one. +/// Placeholder: not implemented yet! +void ConvertTransitionIdsToPdfs(const TransitionModel &trans_model, + const std::vector<int32> &disambig_syms, + fst::VectorFst<fst::StdArc> *fst); + +/// @} end "defgroup hmm_group_graph" + +/// \addtogroup hmm_group +/// @{ + +/// SplitToPhones splits up the TransitionIds in "alignment" into their +/// individual phones (one vector per instance of a phone). At output, +/// the sum of the sizes of the vectors in split_alignment will be the same +/// as the corresponding sum for "alignment". The function returns +/// true on success. If the alignment appears to be incomplete, e.g. +/// not ending at the end-state of a phone, it will still break it up into +/// phones but it will return false. For more serious errors it will +/// die or throw an exception. +/// This function works out by itself whether the graph was created +/// with "reordering" (dan-style graph), and just does the right thing. + +bool SplitToPhones(const TransitionModel &trans_model, + const std::vector<int32> &alignment, + std::vector<std::vector<int32> > *split_alignment); + +/// ConvertAlignment converts an alignment that was created using one +/// model, to another model. They must use a compatible topology (so we +/// know the state alignments of the new model). +/// It returns false if it could not be split to phones (probably +/// because the alignment was partial), but for other kinds of +/// error that are more likely a coding error, it will throw +/// an exception. +bool ConvertAlignment(const TransitionModel &old_trans_model, + const TransitionModel &new_trans_model, + const ContextDependencyInterface &new_ctx_dep, + const std::vector<int32> &old_alignment, + const std::vector<int32> *phone_map, // may be NULL + std::vector<int32> *new_alignment); + +// ConvertPhnxToProns is only needed in bin/phones-to-prons.cc and +// isn't closely related with HMMs, but we put it here as there isn't +// any other obvious place for it and it needs to be tested. +// This function takes a phone-sequence with word-start and word-end +// markers in it, and a word-sequence, and outputs the pronunciations +// "prons"... the format of "prons" is, each element is a vector, +// where the first element is the word (or zero meaning no word, e.g. +// for optional silence introduced by the lexicon), and the remaining +// elements are the phones in the word's pronunciation. +// It returns false if it encounters a problem of some kind, e.g. +// if the phone-sequence doesn't seem to have the right number of +// words in it. +bool ConvertPhnxToProns(const std::vector<int32> &phnx, + const std::vector<int32> &words, + int32 word_start_sym, + int32 word_end_sym, + std::vector<std::vector<int32> > *prons); + +/// @} end "addtogroup hmm_group" + +} // end namespace kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/hmm/posterior.h b/kaldi_io/src/kaldi/hmm/posterior.h new file mode 100644 index 0000000..be73be9 --- /dev/null +++ b/kaldi_io/src/kaldi/hmm/posterior.h @@ -0,0 +1,214 @@ +// hmm/posterior.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013-2014 Johns Hopkins University (author: Daniel Povey) +// 2014 Guoguo Chen + + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_POSTERIOR_H_ +#define KALDI_HMM_POSTERIOR_H_ + +#include "base/kaldi-common.h" +#include "tree/context-dep.h" +#include "util/const-integer-set.h" +#include "util/kaldi-table.h" +#include "hmm/transition-model.h" + + +namespace kaldi { + + +/// \addtogroup posterior_group +/// @{ + +/// Posterior is a typedef for storing acoustic-state (actually, transition-id) +/// posteriors over an utterance. The "int32" is a transition-id, and the BaseFloat +/// is a probability (typically between zero and one). +typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior; + +/// GaussPost is a typedef for storing Gaussian-level posteriors for an utterance. +/// the "int32" is a transition-id, and the Vector<BaseFloat> is a vector of +/// Gaussian posteriors. +/// WARNING: We changed "int32" from transition-id to pdf-id, and the change is +/// applied for all programs using GaussPost. This is for efficiency purpose. We +/// also changed the name slightly from GauPost to GaussPost to reduce the +/// chance that the change will go un-noticed in downstream code. +typedef std::vector<std::vector<std::pair<int32, Vector<BaseFloat> > > > GaussPost; + + +// PosteriorHolder is a holder for Posterior, which is +// std::vector<std::vector<std::pair<int32, BaseFloat> > > +// This is used for storing posteriors of transition id's for an +// utterance. +class PosteriorHolder { + public: + typedef Posterior T; + + PosteriorHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t); + + void Clear() { Posterior tmp; std::swap(tmp, t_); } + + // Reads into the holder. + bool Read(std::istream &is); + + // Kaldi objects always have the stream open in binary mode for + // reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(PosteriorHolder); + T t_; +}; + + +// GaussPostHolder is a holder for GaussPost, which is +// std::vector<std::vector<std::pair<int32, Vector<BaseFloat> > > > +// This is used for storing posteriors of transition id's for an +// utterance. +class GaussPostHolder { + public: + typedef GaussPost T; + + GaussPostHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t); + + void Clear() { GaussPost tmp; std::swap(tmp, t_); } + + // Reads into the holder. + bool Read(std::istream &is); + + // Kaldi objects always have the stream open in binary mode for + // reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(GaussPostHolder); + T t_; +}; + + +// Posterior is a typedef: vector<vector<pair<int32, BaseFloat> > >, +// representing posteriors over (typically) transition-ids for an +// utterance. +typedef TableWriter<PosteriorHolder> PosteriorWriter; +typedef SequentialTableReader<PosteriorHolder> SequentialPosteriorReader; +typedef RandomAccessTableReader<PosteriorHolder> RandomAccessPosteriorReader; + + +// typedef std::vector<std::vector<std::pair<int32, Vector<BaseFloat> > > > GaussPost; +typedef TableWriter<GaussPostHolder> GaussPostWriter; +typedef SequentialTableReader<GaussPostHolder> SequentialGaussPostReader; +typedef RandomAccessTableReader<GaussPostHolder> RandomAccessGaussPostReader; + + +/// Scales the BaseFloat (weight) element in the posterior entries. +void ScalePosterior(BaseFloat scale, Posterior *post); + +/// Returns the total of all the weights in "post". +BaseFloat TotalPosterior(const Posterior &post); + +/// Returns true if the two lists of pairs have no common .first element. +bool PosteriorEntriesAreDisjoint( + const std::vector<std::pair<int32, BaseFloat> > &post_elem1, + const std::vector<std::pair<int32, BaseFloat> > &post_elem2); + + +/// Merge two sets of posteriors, which must have the same length. If "merge" +/// is true, it will make a common entry whenever there are duplicated entries, +/// adding up the weights. If "drop_frames" is true, for frames where the +/// two sets of posteriors were originally disjoint, makes no entries for that +/// frame (relates to frame dropping, or drop_frames, see Vesely et al, ICASSP +/// 2013). Returns the number of frames for which the two posteriors were +/// disjoint (i.e. no common transition-ids or whatever index we are using). +int32 MergePosteriors(const Posterior &post1, + const Posterior &post2, + bool merge, + bool drop_frames, + Posterior *post); + +/// Given a vector of log-likelihoods (typically of Gaussians in a GMM +/// but could be of pdf-ids), a number gselect >= 1 and a minimum posterior +/// 0 <= min_post < 1, it gets the posterior for each element of log-likes +/// by applying Softmax(), then prunes the posteriors using "gselect" and +/// "min_post" (keeping at least one), and outputs the result into +/// "post_entry", sorted from greatest to least posterior. +/// Returns the total log-likelihood (the output of calling ApplySoftMax() +/// on a copy of log_likes). +BaseFloat VectorToPosteriorEntry( + const VectorBase<BaseFloat> &log_likes, + int32 num_gselect, + BaseFloat min_post, + std::vector<std::pair<int32, BaseFloat> > *post_entry); + +/// Convert an alignment to a posterior (with a scale of 1.0 on +/// each entry). +void AlignmentToPosterior(const std::vector<int32> &ali, + Posterior *post); + +/// Sorts posterior entries so that transition-ids with same pdf-id are next to +/// each other. +void SortPosteriorByPdfs(const TransitionModel &tmodel, + Posterior *post); + + +/// Converts a posterior over transition-ids to be a posterior +/// over pdf-ids. +void ConvertPosteriorToPdfs(const TransitionModel &tmodel, + const Posterior &post_in, + Posterior *post_out); + +/// Converts a posterior over transition-ids to be a posterior +/// over phones. +void ConvertPosteriorToPhones(const TransitionModel &tmodel, + const Posterior &post_in, + Posterior *post_out); + +/// Weight any silence phones in the posterior (i.e. any phones +/// in the set "silence_set" by scale "silence_scale". +/// The interface was changed in Feb 2014 to do the modification +/// "in-place" rather than having separate input and output. +void WeightSilencePost(const TransitionModel &trans_model, + const ConstIntegerSet<int32> &silence_set, + BaseFloat silence_scale, + Posterior *post); + +/// This is similar to WeightSilencePost, except that on each frame it +/// works out the amount by which the overall posterior would be reduced, +/// and scales down everything on that frame by the same amount. It +/// has the effect that frames that are mostly silence get down-weighted. +/// The interface was changed in Feb 2014 to do the modification +/// "in-place" rather than having separate input and output. +void WeightSilencePostDistributed(const TransitionModel &trans_model, + const ConstIntegerSet<int32> &silence_set, + BaseFloat silence_scale, + Posterior *post); + +/// @} end "addtogroup posterior_group" + + +} // end namespace kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/hmm/transition-model.h b/kaldi_io/src/kaldi/hmm/transition-model.h new file mode 100644 index 0000000..ccc4f11 --- /dev/null +++ b/kaldi_io/src/kaldi/hmm/transition-model.h @@ -0,0 +1,345 @@ +// hmm/transition-model.h + +// Copyright 2009-2012 Microsoft Corporation +// Johns Hopkins University (author: Guoguo Chen) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_TRANSITION_MODEL_H_ +#define KALDI_HMM_TRANSITION_MODEL_H_ + +#include "base/kaldi-common.h" +#include "tree/context-dep.h" +#include "util/const-integer-set.h" +#include "fst/fst-decl.h" // forward declarations. +#include "hmm/hmm-topology.h" +#include "itf/options-itf.h" + +namespace kaldi { + +/// \addtogroup hmm_group +/// @{ + +// The class TransitionModel is a repository for the transition probabilities. +// It also handles certain integer mappings. +// The basic model is as follows. Each phone has a HMM topology defined in +// hmm-topology.h. Each HMM-state of each of these phones has a number of +// transitions (and final-probs) out of it. Each HMM-state defined in the +// HmmTopology class has an associated "pdf_class". This gets replaced with +// an actual pdf-id via the tree. The transition model associates the +// transition probs with the (phone, HMM-state, pdf-id). We associate with +// each such triple a transition-state. Each +// transition-state has a number of associated probabilities to estimate; +// this depends on the number of transitions/final-probs in the topology for +// that (phone, HMM-state). Each probability has an associated transition-index. +// We associate with each (transition-state, transition-index) a unique transition-id. +// Each individual probability estimated by the transition-model is asociated with a +// transition-id. +// +// List of the various types of quantity referred to here and what they mean: +// phone: a phone index (1, 2, 3 ...) +// HMM-state: a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h) +// pdf-id: a number output by the Compute function of ContextDependency (it +// indexes pdf's). Zero-based. +// transition-state: the states for which we estimate transition probabilities for transitions +// out of them. In some topologies, will map one-to-one with pdf-ids. +// One-based, since it appears on FSTs. +// transition-index: identifier of a transition (or final-prob) in the HMM. Indexes the +// "transitions" vector in HmmTopology::HmmState. [if it is out of range, +// equal to transitions.size(), it refers to the final-prob.] +// Zero-based. +// transition-id: identifier of a unique parameter of the TransitionModel. +// Associated with a (transition-state, transition-index) pair. +// One-based, since it appears on FSTs. +// +// List of the possible mappings TransitionModel can do: +// (phone, HMM-state, pdf-id) -> transition-state +// (transition-state, transition-index) -> transition-id +// Reverse mappings: +// transition-id -> transition-state +// transition-id -> transition-index +// transition-state -> phone +// transition-state -> HMM-state +// transition-state -> pdf-id +// +// The main things the TransitionModel object can do are: +// Get initialized (need ContextDependency and HmmTopology objects). +// Read/write. +// Update [given a vector of counts indexed by transition-id]. +// Do the various integer mappings mentioned above. +// Get the probability (or log-probability) associated with a particular transition-id. + + +// Note: this was previously called TransitionUpdateConfig. +struct MleTransitionUpdateConfig { + BaseFloat floor; + BaseFloat mincount; + bool share_for_pdfs; // If true, share all transition parameters that have the same pdf. + MleTransitionUpdateConfig(BaseFloat floor = 0.01, + BaseFloat mincount = 5.0, + bool share_for_pdfs = false): + floor(floor), mincount(mincount), share_for_pdfs(share_for_pdfs) {} + + void Register (OptionsItf *po) { + po->Register("transition-floor", &floor, + "Floor for transition probabilities"); + po->Register("transition-min-count", &mincount, + "Minimum count required to update transitions from a state"); + po->Register("share-for-pdfs", &share_for_pdfs, + "If true, share all transition parameters where the states " + "have the same pdf."); + } +}; + +struct MapTransitionUpdateConfig { + BaseFloat tau; + bool share_for_pdfs; // If true, share all transition parameters that have the same pdf. + MapTransitionUpdateConfig(): tau(5.0), share_for_pdfs(false) { } + + void Register (OptionsItf *po) { + po->Register("transition-tau", &tau, "Tau value for MAP estimation of transition " + "probabilities."); + po->Register("share-for-pdfs", &share_for_pdfs, + "If true, share all transition parameters where the states " + "have the same pdf."); + } +}; + +class TransitionModel { + + public: + /// Initialize the object [e.g. at the start of training]. + /// The class keeps a copy of the HmmTopology object, but not + /// the ContextDependency object. + TransitionModel(const ContextDependency &ctx_dep, + const HmmTopology &hmm_topo); + + + /// Constructor that takes no arguments: typically used prior to calling Read. + TransitionModel() { } + + void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. + void Write(std::ostream &os, bool binary) const; + + + /// return reference to HMM-topology object. + const HmmTopology &GetTopo() const { return topo_; } + + /// \name Integer mapping functions + /// @{ + + int32 TripleToTransitionState(int32 phone, int32 hmm_state, int32 pdf) const; + int32 PairToTransitionId(int32 trans_state, int32 trans_index) const; + int32 TransitionIdToTransitionState(int32 trans_id) const; + int32 TransitionIdToTransitionIndex(int32 trans_id) const; + int32 TransitionStateToPhone(int32 trans_state) const; + int32 TransitionStateToHmmState(int32 trans_state) const; + int32 TransitionStateToPdf(int32 trans_state) const; + int32 SelfLoopOf(int32 trans_state) const; // returns the self-loop transition-id, or zero if + // this state doesn't have a self-loop. + + inline int32 TransitionIdToPdf(int32 trans_id) const; + int32 TransitionIdToPhone(int32 trans_id) const; + int32 TransitionIdToPdfClass(int32 trans_id) const; + int32 TransitionIdToHmmState(int32 trans_id) const; + + /// @} + + bool IsFinal(int32 trans_id) const; // returns true if this trans_id goes to the final state + // (which is bound to be nonemitting). + bool IsSelfLoop(int32 trans_id) const; // return true if this trans_id corresponds to a self-loop. + + /// Returns the total number of transition-ids (note, these are one-based). + inline int32 NumTransitionIds() const { return id2state_.size()-1; } + + /// Returns the number of transition-indices for a particular transition-state. + /// Note: "Indices" is the plural of "index". Index is not the same as "id", + /// here. A transition-index is a zero-based offset into the transitions + /// out of a particular transition state. + int32 NumTransitionIndices(int32 trans_state) const; + + /// Returns the total number of transition-states (note, these are one-based). + int32 NumTransitionStates() const { return triples_.size(); } + + // NumPdfs() actually returns the highest-numbered pdf we ever saw, plus one. + // In normal cases this should equal the number of pdfs in the system, but if you + // initialized this object with fewer than all the phones, and it happens that + // an unseen phone has the highest-numbered pdf, this might be different. + int32 NumPdfs() const { return num_pdfs_; } + + // This loops over the triples and finds the highest phone index present. If + // the FST symbol table for the phones is created in the expected way, i.e.: + // starting from 1 (<eps> is 0) and numbered contiguously till the last phone, + // this will be the total number of phones. + int32 NumPhones() const; + + /// Returns a sorted, unique list of phones. + const std::vector<int32> &GetPhones() const { return topo_.GetPhones(); } + + // Transition-parameter-getting functions: + BaseFloat GetTransitionProb(int32 trans_id) const; + BaseFloat GetTransitionLogProb(int32 trans_id) const; + + // The following functions are more specialized functions for getting + // transition probabilities, that are provided for convenience. + + /// Returns the log-probability of a particular non-self-loop transition + /// after subtracting the probability mass of the self-loop and renormalizing; + /// will crash if called on a self-loop. Specifically: + /// for non-self-loops it returns the log of that prob divided by (1 minus + /// self-loop-prob-for-that-state). + BaseFloat GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const; + + /// Returns the log-prob of the non-self-loop probability + /// mass for this transition state. (you can get the self-loop prob, if a self-loop + /// exists, by calling GetTransitionLogProb(SelfLoopOf(trans_state)). + BaseFloat GetNonSelfLoopLogProb(int32 trans_state) const; + + /// Does Maximum Likelihood estimation. The stats are counts/weights, indexed + /// by transition-id. This was previously called Update(). + void MleUpdate(const Vector<double> &stats, + const MleTransitionUpdateConfig &cfg, + BaseFloat *objf_impr_out, + BaseFloat *count_out); + + /// Does Maximum A Posteriori (MAP) estimation. The stats are counts/weights, + /// indexed by transition-id. + void MapUpdate(const Vector<double> &stats, + const MapTransitionUpdateConfig &cfg, + BaseFloat *objf_impr_out, + BaseFloat *count_out); + + /// Print will print the transition model in a human-readable way, for purposes of human + /// inspection. The "occs" are optional (they are indexed by pdf-id). + void Print(std::ostream &os, + const std::vector<std::string> &phone_names, + const Vector<double> *occs = NULL); + + + void InitStats(Vector<double> *stats) const { stats->Resize(NumTransitionIds()+1); } + + void Accumulate(BaseFloat prob, int32 trans_id, Vector<double> *stats) const { + KALDI_ASSERT(trans_id <= NumTransitionIds()); + (*stats)(trans_id) += prob; + // This is trivial and doesn't require class members, but leaves us more open + // to design changes than doing it manually. + } + + /// returns true if all the integer class members are identical (but does not + /// compare the transition probabilities. + bool Compatible(const TransitionModel &other) const; + + private: + void MleUpdateShared(const Vector<double> &stats, + const MleTransitionUpdateConfig &cfg, + BaseFloat *objf_impr_out, BaseFloat *count_out); + void MapUpdateShared(const Vector<double> &stats, + const MapTransitionUpdateConfig &cfg, + BaseFloat *objf_impr_out, BaseFloat *count_out); + void ComputeTriples(const ContextDependency &ctx_dep); // called from constructor. initializes triples_. + void ComputeDerived(); // called from constructor and Read function: computes state2id_ and id2state_. + void ComputeDerivedOfProbs(); // computes quantities derived from log-probs (currently just + // non_self_loop_log_probs_; called whenever log-probs change. + void InitializeProbs(); // called from constructor. + void Check() const; + + struct Triple { + int32 phone; + int32 hmm_state; + int32 pdf; + Triple() { } + Triple(int32 phone, int32 hmm_state, int32 pdf): + phone(phone), hmm_state(hmm_state), pdf(pdf) { } + bool operator < (const Triple &other) const { + if (phone < other.phone) return true; + else if (phone > other.phone) return false; + else if (hmm_state < other.hmm_state) return true; + else if (hmm_state > other.hmm_state) return false; + else return pdf < other.pdf; + } + bool operator == (const Triple &other) const { + return (phone == other.phone && hmm_state == other.hmm_state + && pdf == other.pdf); + } + }; + + HmmTopology topo_; + + /// Triples indexed by transition state minus one; + /// the triples are in sorted order which allows us to do the reverse mapping from + /// triple to transition state + std::vector<Triple> triples_; + + /// Gives the first transition_id of each transition-state; indexed by + /// the transition-state. Array indexed 1..num-transition-states+1 (the last one + /// is needed so we can know the num-transitions of the last transition-state. + std::vector<int32> state2id_; + + /// For each transition-id, the corresponding transition + /// state (indexed by transition-id). + std::vector<int32> id2state_; + + /// For each transition-id, the corresponding log-prob. Indexed by transition-id. + Vector<BaseFloat> log_probs_; + + /// For each transition-state, the log of (1 - self-loop-prob). Indexed by + /// transition-state. + Vector<BaseFloat> non_self_loop_log_probs_; + + /// This is actually one plus the highest-numbered pdf we ever got back from the + /// tree (but the tree numbers pdfs contiguously from zero so this is the number + /// of pdfs). + int32 num_pdfs_; + + + DISALLOW_COPY_AND_ASSIGN(TransitionModel); + +}; + +inline int32 TransitionModel::TransitionIdToPdf(int32 trans_id) const { + // If a lot of time is spent here we may create an extra array + // to handle this. + KALDI_ASSERT(static_cast<size_t>(trans_id) < id2state_.size() && + "Likely graph/model mismatch (graph built from wrong model?)"); + int32 trans_state = id2state_[trans_id]; + return triples_[trans_state-1].pdf; +} + +/// Works out which pdfs might correspond to the given phones. Will return true +/// if these pdfs correspond *just* to these phones, false if these pdfs are also +/// used by other phones. +/// @param trans_model [in] Transition-model used to work out this information +/// @param phones [in] A sorted, uniq vector that represents a set of phones +/// @param pdfs [out] Will be set to a sorted, uniq list of pdf-ids that correspond +/// to one of this set of phones. +/// @return Returns true if all of the pdfs output to "pdfs" correspond to phones from +/// just this set (false if they may be shared with phones outside this set). +bool GetPdfsForPhones(const TransitionModel &trans_model, + const std::vector<int32> &phones, + std::vector<int32> *pdfs); + +/// Works out which phones might correspond to the given pdfs. Similar to the +/// above GetPdfsForPhones(, ,) +bool GetPhonesForPdfs(const TransitionModel &trans_model, + const std::vector<int32> &pdfs, + std::vector<int32> *phones); +/// @} + + +} // end namespace kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/hmm/tree-accu.h b/kaldi_io/src/kaldi/hmm/tree-accu.h new file mode 100644 index 0000000..d571762 --- /dev/null +++ b/kaldi_io/src/kaldi/hmm/tree-accu.h @@ -0,0 +1,69 @@ +// hmm/tree-accu.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_HMM_TREE_ACCU_H_ +#define KALDI_HMM_TREE_ACCU_H_ + +#include <cctype> // For isspace. +#include <limits> +#include "base/kaldi-common.h" +#include "hmm/transition-model.h" +#include "tree/clusterable-classes.h" +#include "tree/build-tree-questions.h" // needed for this typedef: +// typedef std::vector<std::pair<EventVector, Clusterable*> > BuildTreeStatsType; + +namespace kaldi { + +/// \ingroup tree_group_top +/// @{ + + +/// Accumulates the stats needed for training context-dependency trees (in the +/// "normal" way). It adds to 'stats' the stats obtained from this file. Any +/// new GaussClusterable* pointers in "stats" will be allocated with "new". + +void AccumulateTreeStats(const TransitionModel &trans_model, + BaseFloat var_floor, + int N, // context window size. + int P, // central position. + const std::vector<int32> &ci_phones, // sorted + const std::vector<int32> &alignment, + const Matrix<BaseFloat> &features, + const std::vector<int32> *phone_map, // or NULL + std::map<EventType, GaussClusterable*> *stats); + + + +/*** Read a mapping from one phone set to another. The phone map file has lines + of the form <old-phone> <new-phone>, where both entries are integers, usually + nonzero (but this is not enforced). This program will crash if the input is + invalid, e.g. there are multiple inconsistent entries for the same old phone. + The output vector "phone_map" will be indexed by old-phone and will contain + the corresponding new-phone, or -1 for any entry that was not defined. */ + +void ReadPhoneMap(std::string phone_map_rxfilename, + std::vector<int32> *phone_map); + + + +/// @} + +} // end namespace kaldi. + +#endif diff --git a/kaldi_io/src/kaldi/itf/clusterable-itf.h b/kaldi_io/src/kaldi/itf/clusterable-itf.h new file mode 100644 index 0000000..7ef9ae0 --- /dev/null +++ b/kaldi_io/src/kaldi/itf/clusterable-itf.h @@ -0,0 +1,97 @@ +// itf/clusterable-itf.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_ITF_CLUSTERABLE_ITF_H_ +#define KALDI_ITF_CLUSTERABLE_ITF_H_ 1 + +#include <string> +#include "base/kaldi-common.h" + +namespace kaldi { + + +/** \addtogroup clustering_group + @{ + A virtual class for clusterable objects; see \ref clustering for an + explanation if its function. +*/ + + + +class Clusterable { + public: + /// \name Functions that must be overridden + /// @{ + + /// Return a copy of this object. + virtual Clusterable *Copy() const = 0; + /// Return the objective function associated with the stats + /// [assuming ML estimation] + virtual BaseFloat Objf() const = 0; + /// Return the normalizer (typically, count) associated with the stats + virtual BaseFloat Normalizer() const = 0; + /// Set stats to empty. + virtual void SetZero() = 0; + /// Add other stats. + virtual void Add(const Clusterable &other) = 0; + /// Subtract other stats. + virtual void Sub(const Clusterable &other) = 0; + /// Scale the stats by a positive number f [not mandatory to supply this]. + virtual void Scale(BaseFloat f) { + KALDI_ERR << "This Clusterable object does not implement Scale()."; + } + + /// Return a string that describes the inherited type. + virtual std::string Type() const = 0; + + /// Write data to stream. + virtual void Write(std::ostream &os, bool binary) const = 0; + + /// Read data from a stream and return the corresponding object (const + /// function; it's a class member because we need access to the vtable + /// so generic code can read derived types). + virtual Clusterable* ReadNew(std::istream &os, bool binary) const = 0; + + virtual ~Clusterable() {} + + /// @} + + /// \name Functions that have default implementations + /// @{ + + // These functions have default implementations (but may be overridden for + // speed). Implementatons in tree/clusterable-classes.cc + + /// Return the objective function of the combined object this + other. + virtual BaseFloat ObjfPlus(const Clusterable &other) const; + /// Return the objective function of the subtracted object this - other. + virtual BaseFloat ObjfMinus(const Clusterable &other) const; + /// Return the objective function decrease from merging the two + /// clusters, negated to be a positive number (or zero). + virtual BaseFloat Distance(const Clusterable &other) const; + /// @} + +}; +/// @} end of "ingroup clustering_group" + +} // end namespace kaldi + +#endif // KALDI_ITF_CLUSTERABLE_ITF_H_ + diff --git a/kaldi_io/src/kaldi/itf/context-dep-itf.h b/kaldi_io/src/kaldi/itf/context-dep-itf.h new file mode 100644 index 0000000..6a0bd0f --- /dev/null +++ b/kaldi_io/src/kaldi/itf/context-dep-itf.h @@ -0,0 +1,80 @@ +// itf/context-dep-itf.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_ITF_CONTEXT_DEP_ITF_H_ +#define KALDI_ITF_CONTEXT_DEP_ITF_H_ +#include "base/kaldi-common.h" + +namespace kaldi { +/// @ingroup tree_group +/// @{ + +/// context-dep-itf.h provides a link between +/// the tree-building code in ../tree/, and the FST code in ../fstext/ +/// (particularly, ../fstext/context-dep.h). It is an abstract +/// interface that describes an object that can map from a +/// phone-in-context to a sequence of integer leaf-ids. +class ContextDependencyInterface { + public: + /// ContextWidth() returns the value N (e.g. 3 for triphone models) that says how many phones + /// are considered for computing context. + virtual int ContextWidth() const = 0; + + /// Central position P of the phone context, in 0-based numbering, e.g. P = 1 for typical + /// triphone system. We have to see if we can do without this function. + virtual int CentralPosition() const = 0; + + /// The "new" Compute interface. For typical topologies, + /// pdf_class would be 0, 1, 2. + /// Returns success or failure; outputs the pdf-id. + /// + /// "Compute" is the main function of this interface, that takes a + /// sequence of N phones (and it must be N phones), possibly + /// including epsilons (symbol id zero) but only at positions other + /// than P [these represent unknown phone context due to end or + /// begin of sequence]. We do not insist that Compute must always + /// output (into stateseq) a nonempty sequence of states, but we + /// anticipate that stateseq will alyway be nonempty at output in + /// typical use cases. "Compute" returns false if expansion somehow + /// failed. Normally the calling code should raise an exception if + /// this happens. We can define a different interface later in + /// order to handle other kinds of information-- the underlying + /// data-structures from event-map.h are very flexible. + virtual bool Compute(const std::vector<int32> &phoneseq, int32 pdf_class, + int32 *pdf_id) const = 0; + + + + /// NumPdfs() returns the number of acoustic pdfs (they are numbered 0.. NumPdfs()-1). + virtual int32 NumPdfs() const = 0; + + virtual ~ContextDependencyInterface() {}; + ContextDependencyInterface() {} + + /// Returns pointer to new object which is copy of current one. + virtual ContextDependencyInterface *Copy() const = 0; + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(ContextDependencyInterface); +}; +/// @} +} // namespace Kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/itf/decodable-itf.h b/kaldi_io/src/kaldi/itf/decodable-itf.h new file mode 100644 index 0000000..ba4d765 --- /dev/null +++ b/kaldi_io/src/kaldi/itf/decodable-itf.h @@ -0,0 +1,123 @@ +// itf/decodable-itf.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Mirko Hannemann; Go Vivace Inc.; +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_ITF_DECODABLE_ITF_H_ +#define KALDI_ITF_DECODABLE_ITF_H_ 1 +#include "base/kaldi-common.h" + +namespace kaldi { +/// @ingroup Interfaces +/// @{ + + +/** + DecodableInterface provides a link between the (acoustic-modeling and + feature-processing) code and the decoder. The idea is to make this + interface as small as possible, and to make it as agnostic as possible about + the form of the acoustic model (e.g. don't assume the probabilities are a + function of just a vector of floats), and about the decoder (e.g. don't + assume it accesses frames in strict left-to-right order). For normal + models, without on-line operation, the "decodable" sub-class will just be a + wrapper around a matrix of features and an acoustic model, and it will + answer the question 'what is the acoustic likelihood for this index and this + frame?'. + + For online decoding, where the features are coming in in real time, it is + important to understand the IsLastFrame() and NumFramesReady() functions. + There are two ways these are used: the old online-decoding code, in ../online/, + and the new online-decoding code, in ../online2/. In the old online-decoding + code, the decoder would do: + \code{.cc} + for (int frame = 0; !decodable.IsLastFrame(frame); frame++) { + // Process this frame + } + \endcode + and the the call to IsLastFrame would block if the features had not arrived yet. + The decodable object would have to know when to terminate the decoding. This + online-decoding mode is still supported, it is what happens when you call, for + example, LatticeFasterDecoder::Decode(). + + We realized that this "blocking" mode of decoding is not very convenient + because it forces the program to be multi-threaded and makes it complex to + control endpointing. In the "new" decoding code, you don't call (for example) + LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(), + and then each time you get more features, you provide them to the decodable + object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does + something like this: + \code{.cc} + while (num_frames_decoded_ < decodable.NumFramesReady()) { + // Decode one more frame [increments num_frames_decoded_] + } + \endcode + So the decodable object never has IsLastFrame() called. For decoding where + you are starting with a matrix of features, the NumFramesReady() function will + always just return the number of frames in the file, and IsLastFrame() will + return true for the last frame. + + For truly online decoding, the "old" online decodable objects in ../online/ have a + "blocking" IsLastFrame() and will crash if you call NumFramesReady(). + The "new" online decodable objects in ../online2/ return the number of frames + currently accessible if you call NumFramesReady(). You will likely not need + to call IsLastFrame(), but we implement it to only return true for the last + frame of the file once we've decided to terminate decoding. +*/ + +class DecodableInterface { + public: + /// Returns the log likelihood, which will be negated in the decoder. + /// The "frame" starts from zero. You should verify that IsLastFrame(frame-1) + /// returns false before calling this. + virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; + + /// Returns true if this is the last frame. Frames are zero-based, so the + /// first frame is zero. IsLastFrame(-1) will return false, unless the file + /// is empty (which is a case that I'm not sure all the code will handle, so + /// be careful). Caution: the behavior of this function in an online setting + /// is being changed somewhat. In future it may return false in cases where + /// we haven't yet decided to terminate decoding, but later true if we decide + /// to terminate decoding. The plan in future is to rely more on + /// NumFramesReady(), and in future, IsLastFrame() would always return false + /// in an online-decoding setting, and would only return true in a + /// decoding-from-matrix setting where we want to allow the last delta or LDA + /// features to be flushed out for compatibility with the baseline setup. + virtual bool IsLastFrame(int32 frame) const = 0; + + /// The call NumFramesReady() will return the number of frames currently available + /// for this decodable object. This is for use in setups where you don't want the + /// decoder to block while waiting for input. This is newly added as of Jan 2014, + /// and I hope, going forward, to rely on this mechanism more than IsLastFrame to + /// know when to stop decoding. + virtual int32 NumFramesReady() const { + KALDI_ERR << "NumFramesReady() not implemented for this decodable type."; + return -1; + } + + /// Returns the number of states in the acoustic model + /// (they will be indexed one-based, i.e. from 1 to NumIndices(); + /// this is for compatibility with OpenFst. + virtual int32 NumIndices() const = 0; + + virtual ~DecodableInterface() {} +}; +/// @} +} // namespace Kaldi + +#endif // KALDI_ITF_DECODABLE_ITF_H_ diff --git a/kaldi_io/src/kaldi/itf/online-feature-itf.h b/kaldi_io/src/kaldi/itf/online-feature-itf.h new file mode 100644 index 0000000..dafcd8a --- /dev/null +++ b/kaldi_io/src/kaldi/itf/online-feature-itf.h @@ -0,0 +1,105 @@ +// itf/online-feature-itf.h + +// Copyright 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_ITF_ONLINE_FEATURE_ITF_H_ +#define KALDI_ITF_ONLINE_FEATURE_ITF_H_ 1 +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { +/// @ingroup Interfaces +/// @{ + +/** + OnlineFeatureInterface is an interface for online feature processing (it is + also usable in the offline setting, but currently we're not using it for + that). This is for use in the online2/ directory, and it supersedes the + interface in ../online/online-feat-input.h. We have a slighty different + model that puts more control in the hands of the calling thread, and won't + involve waiting on semaphores in the decoding thread. + + This interface only specifies how the object *outputs* the features. + How it obtains the features, e.g. from a previous object or objects of type + OnlineFeatureInterface, is not specified in the interface and you will + likely define new constructors or methods in the derived type to do that. + + You should appreciate that this interface is designed to allow random + access to features, as long as they are ready. That is, the user + can call GetFrame for any frame less than NumFramesReady(), and when + implementing a child class you must not make assumptions about the + order in which the user makes these calls. +*/ + +class OnlineFeatureInterface { + public: + virtual int32 Dim() const = 0; /// returns the feature dimension. + + /// Returns the total number of frames, since the start of the utterance, that + /// are now available. In an online-decoding context, this will likely + /// increase with time as more data becomes available. + virtual int32 NumFramesReady() const = 0; + + /// Returns true if this is the last frame. Frame indices are zero-based, so the + /// first frame is zero. IsLastFrame(-1) will return false, unless the file + /// is empty (which is a case that I'm not sure all the code will handle, so + /// be careful). This function may return false for some frame if + /// we haven't yet decided to terminate decoding, but later true if we decide + /// to terminate decoding. This function exists mainly to correctly handle + /// end effects in feature extraction, and is not a mechanism to determine how + /// many frames are in the decodable object (as it used to be, and for backward + /// compatibility, still is, in the Decodable interface). + virtual bool IsLastFrame(int32 frame) const = 0; + + /// Gets the feature vector for this frame. Before calling this for a given + /// frame, it is assumed that you called NumFramesReady() and it returned a + /// number greater than "frame". Otherwise this call will likely crash with + /// an assert failure. This function is not declared const, in case there is + /// some kind of caching going on, but most of the time it shouldn't modify + /// the class. + virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat) = 0; + + /// Virtual destructor. Note: constructors that take another member of + /// type OnlineFeatureInterface are not expected to take ownership of + /// that pointer; the caller needs to keep track of that manually. + virtual ~OnlineFeatureInterface() { } +}; + + +/// Add a virtual class for "source" features such as MFCC or PLP or pitch +/// features. +class OnlineBaseFeature: public OnlineFeatureInterface { + public: + /// This would be called from the application, when you get more wave data. + /// Note: the sampling_rate is typically only provided so the code can assert + /// that it matches the sampling rate expected in the options. + virtual void AcceptWaveform(BaseFloat sampling_rate, + const VectorBase<BaseFloat> &waveform) = 0; + + /// InputFinished() tells the class you won't be providing any + /// more waveform. This will help flush out the last few frames + /// of delta or LDA features (it will typically affect the return value + /// of IsLastFrame. + virtual void InputFinished() = 0; +}; + + +/// @} +} // namespace Kaldi + +#endif // KALDI_ITF_ONLINE_FEATURE_ITF_H_ diff --git a/kaldi_io/src/kaldi/itf/optimizable-itf.h b/kaldi_io/src/kaldi/itf/optimizable-itf.h new file mode 100644 index 0000000..1b8f54b --- /dev/null +++ b/kaldi_io/src/kaldi/itf/optimizable-itf.h @@ -0,0 +1,51 @@ +// itf/optimizable-itf.h + +// Copyright 2009-2011 Go Vivace Inc.; Microsoft Corporation; Georg Stemmer + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_ITF_OPTIMIZABLE_ITF_H_ +#define KALDI_ITF_OPTIMIZABLE_ITF_H_ + +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { +/// @ingroup Interfaces +/// @{ + +/// OptimizableInterface provides +/// a virtual class for optimizable objects. +/// E.g. a class that computed a likelihood function and +/// its gradient using some parameter +/// that has to be optimized on data +/// could inherit from it. +template<class Real> +class OptimizableInterface { + public: + /// computes gradient for a parameter params and returns it + /// in gradient_out + virtual void ComputeGradient(const Vector<Real> ¶ms, + Vector<Real> *gradient_out) = 0; + /// computes the function value for a parameter params + /// and returns it + virtual Real ComputeValue(const Vector<Real> ¶ms) = 0; + + virtual ~OptimizableInterface() {} +}; +/// @} end of "Interfaces" +} // end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/itf/options-itf.h b/kaldi_io/src/kaldi/itf/options-itf.h new file mode 100644 index 0000000..204f46d --- /dev/null +++ b/kaldi_io/src/kaldi/itf/options-itf.h @@ -0,0 +1,49 @@ +// itf/options-itf.h + +// Copyright 2013 Tanel Alumae, Tallinn University of Technology + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_ITF_OPTIONS_ITF_H_ +#define KALDI_ITF_OPTIONS_ITF_H_ 1 +#include "base/kaldi-common.h" + +namespace kaldi { + +class OptionsItf { + public: + + virtual void Register(const std::string &name, + bool *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + int32 *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + uint32 *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + float *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + double *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + std::string *ptr, const std::string &doc) = 0; + + virtual ~OptionsItf() {} +}; + +} // namespace Kaldi + +#endif // KALDI_ITF_OPTIONS_ITF_H_ + + diff --git a/kaldi_io/src/kaldi/matrix/cblas-wrappers.h b/kaldi_io/src/kaldi/matrix/cblas-wrappers.h new file mode 100644 index 0000000..ebec0a3 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/cblas-wrappers.h @@ -0,0 +1,491 @@ +// matrix/cblas-wrappers.h + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey); +// Haihua Xu; Wei Shi + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_CBLAS_WRAPPERS_H_ +#define KALDI_MATRIX_CBLAS_WRAPPERS_H_ 1 + + +#include <limits> +#include "matrix/sp-matrix.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/matrix-functions.h" + +// Do not include this file directly. It is to be included +// by .cc files in this directory. + +namespace kaldi { + + +inline void cblas_Xcopy(const int N, const float *X, const int incX, float *Y, + const int incY) { + cblas_scopy(N, X, incX, Y, incY); +} + +inline void cblas_Xcopy(const int N, const double *X, const int incX, double *Y, + const int incY) { + cblas_dcopy(N, X, incX, Y, incY); +} + + +inline float cblas_Xasum(const int N, const float *X, const int incX) { + return cblas_sasum(N, X, incX); +} + +inline double cblas_Xasum(const int N, const double *X, const int incX) { + return cblas_dasum(N, X, incX); +} + +inline void cblas_Xrot(const int N, float *X, const int incX, float *Y, + const int incY, const float c, const float s) { + cblas_srot(N, X, incX, Y, incY, c, s); +} +inline void cblas_Xrot(const int N, double *X, const int incX, double *Y, + const int incY, const double c, const double s) { + cblas_drot(N, X, incX, Y, incY, c, s); +} +inline float cblas_Xdot(const int N, const float *const X, + const int incX, const float *const Y, + const int incY) { + return cblas_sdot(N, X, incX, Y, incY); +} +inline double cblas_Xdot(const int N, const double *const X, + const int incX, const double *const Y, + const int incY) { + return cblas_ddot(N, X, incX, Y, incY); +} +inline void cblas_Xaxpy(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY) { + cblas_saxpy(N, alpha, X, incX, Y, incY); +} +inline void cblas_Xaxpy(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY) { + cblas_daxpy(N, alpha, X, incX, Y, incY); +} +inline void cblas_Xscal(const int N, const float alpha, float *data, + const int inc) { + cblas_sscal(N, alpha, data, inc); +} +inline void cblas_Xscal(const int N, const double alpha, double *data, + const int inc) { + cblas_dscal(N, alpha, data, inc); +} +inline void cblas_Xspmv(const float alpha, const int num_rows, const float *Mdata, + const float *v, const int v_inc, + const float beta, float *y, const int y_inc) { + cblas_sspmv(CblasRowMajor, CblasLower, num_rows, alpha, Mdata, v, v_inc, beta, y, y_inc); +} +inline void cblas_Xspmv(const double alpha, const int num_rows, const double *Mdata, + const double *v, const int v_inc, + const double beta, double *y, const int y_inc) { + cblas_dspmv(CblasRowMajor, CblasLower, num_rows, alpha, Mdata, v, v_inc, beta, y, y_inc); +} +inline void cblas_Xtpmv(MatrixTransposeType trans, const float *Mdata, + const int num_rows, float *y, const int y_inc) { + cblas_stpmv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} +inline void cblas_Xtpmv(MatrixTransposeType trans, const double *Mdata, + const int num_rows, double *y, const int y_inc) { + cblas_dtpmv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} + + +inline void cblas_Xtpsv(MatrixTransposeType trans, const float *Mdata, + const int num_rows, float *y, const int y_inc) { + cblas_stpsv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} +inline void cblas_Xtpsv(MatrixTransposeType trans, const double *Mdata, + const int num_rows, double *y, const int y_inc) { + cblas_dtpsv(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} + +// x = alpha * M * y + beta * x +inline void cblas_Xspmv(MatrixIndexT dim, float alpha, const float *Mdata, + const float *ydata, MatrixIndexT ystride, + float beta, float *xdata, MatrixIndexT xstride) { + cblas_sspmv(CblasRowMajor, CblasLower, dim, alpha, Mdata, + ydata, ystride, beta, xdata, xstride); +} +inline void cblas_Xspmv(MatrixIndexT dim, double alpha, const double *Mdata, + const double *ydata, MatrixIndexT ystride, + double beta, double *xdata, MatrixIndexT xstride) { + cblas_dspmv(CblasRowMajor, CblasLower, dim, alpha, Mdata, + ydata, ystride, beta, xdata, xstride); +} + +// Implements A += alpha * (x y' + y x'); A is symmetric matrix. +inline void cblas_Xspr2(MatrixIndexT dim, float alpha, const float *Xdata, + MatrixIndexT incX, const float *Ydata, MatrixIndexT incY, + float *Adata) { + cblas_sspr2(CblasRowMajor, CblasLower, dim, alpha, Xdata, + incX, Ydata, incY, Adata); +} +inline void cblas_Xspr2(MatrixIndexT dim, double alpha, const double *Xdata, + MatrixIndexT incX, const double *Ydata, MatrixIndexT incY, + double *Adata) { + cblas_dspr2(CblasRowMajor, CblasLower, dim, alpha, Xdata, + incX, Ydata, incY, Adata); +} + +// Implements A += alpha * (x x'); A is symmetric matrix. +inline void cblas_Xspr(MatrixIndexT dim, float alpha, const float *Xdata, + MatrixIndexT incX, float *Adata) { + cblas_sspr(CblasRowMajor, CblasLower, dim, alpha, Xdata, incX, Adata); +} +inline void cblas_Xspr(MatrixIndexT dim, double alpha, const double *Xdata, + MatrixIndexT incX, double *Adata) { + cblas_dspr(CblasRowMajor, CblasLower, dim, alpha, Xdata, incX, Adata); +} + +// sgemv,dgemv: y = alpha M x + beta y. +inline void cblas_Xgemv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, float alpha, const float *Mdata, + MatrixIndexT stride, const float *xdata, + MatrixIndexT incX, float beta, float *ydata, MatrixIndexT incY) { + cblas_sgemv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows, + num_cols, alpha, Mdata, stride, xdata, incX, beta, ydata, incY); +} +inline void cblas_Xgemv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, double alpha, const double *Mdata, + MatrixIndexT stride, const double *xdata, + MatrixIndexT incX, double beta, double *ydata, MatrixIndexT incY) { + cblas_dgemv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows, + num_cols, alpha, Mdata, stride, xdata, incX, beta, ydata, incY); +} + +// sgbmv, dgmmv: y = alpha M x + + beta * y. +inline void cblas_Xgbmv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, MatrixIndexT num_below, + MatrixIndexT num_above, float alpha, const float *Mdata, + MatrixIndexT stride, const float *xdata, + MatrixIndexT incX, float beta, float *ydata, MatrixIndexT incY) { + cblas_sgbmv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows, + num_cols, num_below, num_above, alpha, Mdata, stride, xdata, + incX, beta, ydata, incY); +} +inline void cblas_Xgbmv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, MatrixIndexT num_below, + MatrixIndexT num_above, double alpha, const double *Mdata, + MatrixIndexT stride, const double *xdata, + MatrixIndexT incX, double beta, double *ydata, MatrixIndexT incY) { + cblas_dgbmv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), num_rows, + num_cols, num_below, num_above, alpha, Mdata, stride, xdata, + incX, beta, ydata, incY); +} + + +template<typename Real> +inline void Xgemv_sparsevec(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, Real alpha, const Real *Mdata, + MatrixIndexT stride, const Real *xdata, + MatrixIndexT incX, Real beta, Real *ydata, + MatrixIndexT incY) { + if (trans == kNoTrans) { + if (beta != 1.0) cblas_Xscal(num_rows, beta, ydata, incY); + for (MatrixIndexT i = 0; i < num_cols; i++) { + Real x_i = xdata[i * incX]; + if (x_i == 0.0) continue; + // Add to ydata, the i'th column of M, times alpha * x_i + cblas_Xaxpy(num_rows, x_i * alpha, Mdata + i, stride, ydata, incY); + } + } else { + if (beta != 1.0) cblas_Xscal(num_cols, beta, ydata, incY); + for (MatrixIndexT i = 0; i < num_rows; i++) { + Real x_i = xdata[i * incX]; + if (x_i == 0.0) continue; + // Add to ydata, the i'th row of M, times alpha * x_i + cblas_Xaxpy(num_cols, x_i * alpha, + Mdata + (i * stride), 1, ydata, incY); + } + } +} + +inline void cblas_Xgemm(const float alpha, + MatrixTransposeType transA, + const float *Adata, + MatrixIndexT a_num_rows, MatrixIndexT a_num_cols, MatrixIndexT a_stride, + MatrixTransposeType transB, + const float *Bdata, MatrixIndexT b_stride, + const float beta, + float *Mdata, + MatrixIndexT num_rows, MatrixIndexT num_cols,MatrixIndexT stride) { + cblas_sgemm(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(transA), + static_cast<CBLAS_TRANSPOSE>(transB), + num_rows, num_cols, transA == kNoTrans ? a_num_cols : a_num_rows, + alpha, Adata, a_stride, Bdata, b_stride, + beta, Mdata, stride); +} +inline void cblas_Xgemm(const double alpha, + MatrixTransposeType transA, + const double *Adata, + MatrixIndexT a_num_rows, MatrixIndexT a_num_cols, MatrixIndexT a_stride, + MatrixTransposeType transB, + const double *Bdata, MatrixIndexT b_stride, + const double beta, + double *Mdata, + MatrixIndexT num_rows, MatrixIndexT num_cols,MatrixIndexT stride) { + cblas_dgemm(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(transA), + static_cast<CBLAS_TRANSPOSE>(transB), + num_rows, num_cols, transA == kNoTrans ? a_num_cols : a_num_rows, + alpha, Adata, a_stride, Bdata, b_stride, + beta, Mdata, stride); +} + + +inline void cblas_Xsymm(const float alpha, + MatrixIndexT sz, + const float *Adata,MatrixIndexT a_stride, + const float *Bdata,MatrixIndexT b_stride, + const float beta, + float *Mdata, MatrixIndexT stride) { + cblas_ssymm(CblasRowMajor, CblasLeft, CblasLower, sz, sz, alpha, Adata, + a_stride, Bdata, b_stride, beta, Mdata, stride); +} +inline void cblas_Xsymm(const double alpha, + MatrixIndexT sz, + const double *Adata,MatrixIndexT a_stride, + const double *Bdata,MatrixIndexT b_stride, + const double beta, + double *Mdata, MatrixIndexT stride) { + cblas_dsymm(CblasRowMajor, CblasLeft, CblasLower, sz, sz, alpha, Adata, + a_stride, Bdata, b_stride, beta, Mdata, stride); +} +// ger: M += alpha x y^T. +inline void cblas_Xger(MatrixIndexT num_rows, MatrixIndexT num_cols, float alpha, + const float *xdata, MatrixIndexT incX, const float *ydata, + MatrixIndexT incY, float *Mdata, MatrixIndexT stride) { + cblas_sger(CblasRowMajor, num_rows, num_cols, alpha, xdata, 1, ydata, 1, + Mdata, stride); +} +inline void cblas_Xger(MatrixIndexT num_rows, MatrixIndexT num_cols, double alpha, + const double *xdata, MatrixIndexT incX, const double *ydata, + MatrixIndexT incY, double *Mdata, MatrixIndexT stride) { + cblas_dger(CblasRowMajor, num_rows, num_cols, alpha, xdata, 1, ydata, 1, + Mdata, stride); +} + +// syrk: symmetric rank-k update. +// if trans==kNoTrans, then C = alpha A A^T + beta C +// else C = alpha A^T A + beta C. +// note: dim_c is dim(C), other_dim_a is the "other" dimension of A, i.e. +// num-cols(A) if kNoTrans, or num-rows(A) if kTrans. +// We only need the row-major and lower-triangular option of this, and this +// is hard-coded. +inline void cblas_Xsyrk ( + const MatrixTransposeType trans, const MatrixIndexT dim_c, + const MatrixIndexT other_dim_a, const float alpha, const float *A, + const MatrixIndexT a_stride, const float beta, float *C, + const MatrixIndexT c_stride) { + cblas_ssyrk(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + dim_c, other_dim_a, alpha, A, a_stride, beta, C, c_stride); +} + +inline void cblas_Xsyrk( + const MatrixTransposeType trans, const MatrixIndexT dim_c, + const MatrixIndexT other_dim_a, const double alpha, const double *A, + const MatrixIndexT a_stride, const double beta, double *C, + const MatrixIndexT c_stride) { + cblas_dsyrk(CblasRowMajor, CblasLower, static_cast<CBLAS_TRANSPOSE>(trans), + dim_c, other_dim_a, alpha, A, a_stride, beta, C, c_stride); +} + +/// matrix-vector multiply using a banded matrix; we always call this +/// with b = 1 meaning we're multiplying by a diagonal matrix. This is used for +/// elementwise multiplication. We miss some of the arguments out of this +/// wrapper. +inline void cblas_Xsbmv1( + const MatrixIndexT dim, + const double *A, + const double alpha, + const double *x, + const double beta, + double *y) { + cblas_dsbmv(CblasRowMajor, CblasLower, dim, 0, alpha, A, + 1, x, 1, beta, y, 1); +} + +inline void cblas_Xsbmv1( + const MatrixIndexT dim, + const float *A, + const float alpha, + const float *x, + const float beta, + float *y) { + cblas_ssbmv(CblasRowMajor, CblasLower, dim, 0, alpha, A, + 1, x, 1, beta, y, 1); +} + + +/// This is not really a wrapper for CBLAS as CBLAS does not have this; in future we could +/// extend this somehow. +inline void mul_elements( + const MatrixIndexT dim, + const double *a, + double *b) { // does b *= a, elementwise. + double c1, c2, c3, c4; + MatrixIndexT i; + for (i = 0; i + 4 <= dim; i += 4) { + c1 = a[i] * b[i]; + c2 = a[i+1] * b[i+1]; + c3 = a[i+2] * b[i+2]; + c4 = a[i+3] * b[i+3]; + b[i] = c1; + b[i+1] = c2; + b[i+2] = c3; + b[i+3] = c4; + } + for (; i < dim; i++) + b[i] *= a[i]; +} + +inline void mul_elements( + const MatrixIndexT dim, + const float *a, + float *b) { // does b *= a, elementwise. + float c1, c2, c3, c4; + MatrixIndexT i; + for (i = 0; i + 4 <= dim; i += 4) { + c1 = a[i] * b[i]; + c2 = a[i+1] * b[i+1]; + c3 = a[i+2] * b[i+2]; + c4 = a[i+3] * b[i+3]; + b[i] = c1; + b[i+1] = c2; + b[i+2] = c3; + b[i+3] = c4; + } + for (; i < dim; i++) + b[i] *= a[i]; +} + + + +// add clapack here +#if !defined(HAVE_ATLAS) +inline void clapack_Xtptri(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *result) { + stptri_(const_cast<char *>("U"), const_cast<char *>("N"), num_rows, Mdata, result); +} +inline void clapack_Xtptri(KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *result) { + dtptri_(const_cast<char *>("U"), const_cast<char *>("N"), num_rows, Mdata, result); +} +// +inline void clapack_Xgetrf2(KaldiBlasInt *num_rows, KaldiBlasInt *num_cols, + float *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot, + KaldiBlasInt *result) { + sgetrf_(num_rows, num_cols, Mdata, stride, pivot, result); +} +inline void clapack_Xgetrf2(KaldiBlasInt *num_rows, KaldiBlasInt *num_cols, + double *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot, + KaldiBlasInt *result) { + dgetrf_(num_rows, num_cols, Mdata, stride, pivot, result); +} + +// +inline void clapack_Xgetri2(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *stride, + KaldiBlasInt *pivot, float *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + sgetri_(num_rows, Mdata, stride, pivot, p_work, l_work, result); +} +inline void clapack_Xgetri2(KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *stride, + KaldiBlasInt *pivot, double *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + dgetri_(num_rows, Mdata, stride, pivot, p_work, l_work, result); +} +// +inline void clapack_Xgesvd(char *v, char *u, KaldiBlasInt *num_cols, + KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *stride, + float *sv, float *Vdata, KaldiBlasInt *vstride, + float *Udata, KaldiBlasInt *ustride, float *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + sgesvd_(v, u, + num_cols, num_rows, Mdata, stride, + sv, Vdata, vstride, Udata, ustride, + p_work, l_work, result); +} +inline void clapack_Xgesvd(char *v, char *u, KaldiBlasInt *num_cols, + KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *stride, + double *sv, double *Vdata, KaldiBlasInt *vstride, + double *Udata, KaldiBlasInt *ustride, double *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + dgesvd_(v, u, + num_cols, num_rows, Mdata, stride, + sv, Vdata, vstride, Udata, ustride, + p_work, l_work, result); +} +// +void inline clapack_Xsptri(KaldiBlasInt *num_rows, float *Mdata, + KaldiBlasInt *ipiv, float *work, KaldiBlasInt *result) { + ssptri_(const_cast<char *>("U"), num_rows, Mdata, ipiv, work, result); +} +void inline clapack_Xsptri(KaldiBlasInt *num_rows, double *Mdata, + KaldiBlasInt *ipiv, double *work, KaldiBlasInt *result) { + dsptri_(const_cast<char *>("U"), num_rows, Mdata, ipiv, work, result); +} +// +void inline clapack_Xsptrf(KaldiBlasInt *num_rows, float *Mdata, + KaldiBlasInt *ipiv, KaldiBlasInt *result) { + ssptrf_(const_cast<char *>("U"), num_rows, Mdata, ipiv, result); +} +void inline clapack_Xsptrf(KaldiBlasInt *num_rows, double *Mdata, + KaldiBlasInt *ipiv, KaldiBlasInt *result) { + dsptrf_(const_cast<char *>("U"), num_rows, Mdata, ipiv, result); +} +#else +inline void clapack_Xgetrf(MatrixIndexT num_rows, MatrixIndexT num_cols, + float *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_sgetrf(CblasColMajor, num_rows, num_cols, + Mdata, stride, pivot); +} + +inline void clapack_Xgetrf(MatrixIndexT num_rows, MatrixIndexT num_cols, + double *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_dgetrf(CblasColMajor, num_rows, num_cols, + Mdata, stride, pivot); +} +// +inline int clapack_Xtrtri(int num_rows, float *Mdata, MatrixIndexT stride) { + return clapack_strtri(CblasColMajor, CblasUpper, CblasNonUnit, num_rows, + Mdata, stride); +} + +inline int clapack_Xtrtri(int num_rows, double *Mdata, MatrixIndexT stride) { + return clapack_dtrtri(CblasColMajor, CblasUpper, CblasNonUnit, num_rows, + Mdata, stride); +} +// +inline void clapack_Xgetri(MatrixIndexT num_rows, float *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_sgetri(CblasColMajor, num_rows, Mdata, stride, pivot); +} +inline void clapack_Xgetri(MatrixIndexT num_rows, double *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_dgetri(CblasColMajor, num_rows, Mdata, stride, pivot); +} +#endif + +} +// namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/matrix/compressed-matrix.h b/kaldi_io/src/kaldi/matrix/compressed-matrix.h new file mode 100644 index 0000000..746cab3 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/compressed-matrix.h @@ -0,0 +1,179 @@ +// matrix/compressed-matrix.h + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) +// Frantisek Skala, Wei Shi + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_COMPRESSED_MATRIX_H_ +#define KALDI_MATRIX_COMPRESSED_MATRIX_H_ 1 + +#include "kaldi-matrix.h" + +namespace kaldi { + +/// \addtogroup matrix_group +/// @{ + +/// This class does lossy compression of a matrix. It only +/// supports copying to-from a KaldiMatrix. For large matrices, +/// each element is compressed into about one byte, but there +/// is a little overhead on top of that (globally, and also per +/// column). + +/// The basic idea is for each column (in the normal configuration) +/// we work out the values at the 0th, 25th, 50th and 100th percentiles +/// and store them as 16-bit integers; we then encode each value in +/// the column as a single byte, in 3 separate ranges with different +/// linear encodings (0-25th, 25-50th, 50th-100th). +/// If the matrix has 8 rows or fewer, we simply store all values as +/// uint16. + +class CompressedMatrix { + public: + CompressedMatrix(): data_(NULL) { } + + ~CompressedMatrix() { Destroy(); } + + template<typename Real> + CompressedMatrix(const MatrixBase<Real> &mat): data_(NULL) { CopyFromMat(mat); } + + /// Initializer that can be used to select part of an existing + /// CompressedMatrix without un-compressing and re-compressing (note: unlike + /// similar initializers for class Matrix, it doesn't point to the same memory + /// location). + CompressedMatrix(const CompressedMatrix &mat, + const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols); + + void *Data() const { return this->data_; } + + /// This will resize *this and copy the contents of mat to *this. + template<typename Real> + void CopyFromMat(const MatrixBase<Real> &mat); + + CompressedMatrix(const CompressedMatrix &mat); + + CompressedMatrix &operator = (const CompressedMatrix &mat); // assignment operator. + + template<typename Real> + CompressedMatrix &operator = (const MatrixBase<Real> &mat); // assignment operator. + + /// Copies contents to matrix. Note: mat must have the correct size, + /// CopyToMat no longer attempts to resize it. + template<typename Real> + void CopyToMat(MatrixBase<Real> *mat) const; + + void Write(std::ostream &os, bool binary) const; + + void Read(std::istream &is, bool binary); + + /// Returns number of rows (or zero for emtpy matrix). + inline MatrixIndexT NumRows() const { return (data_ == NULL) ? 0 : + (*reinterpret_cast<GlobalHeader*>(data_)).num_rows; } + + /// Returns number of columns (or zero for emtpy matrix). + inline MatrixIndexT NumCols() const { return (data_ == NULL) ? 0 : + (*reinterpret_cast<GlobalHeader*>(data_)).num_cols; } + + /// Copies row #row of the matrix into vector v. + /// Note: v must have same size as #cols. + template<typename Real> + void CopyRowToVec(MatrixIndexT row, VectorBase<Real> *v) const; + + /// Copies column #col of the matrix into vector v. + /// Note: v must have same size as #rows. + template<typename Real> + void CopyColToVec(MatrixIndexT col, VectorBase<Real> *v) const; + + /// Copies submatrix of compressed matrix into matrix dest. + /// Submatrix starts at row row_offset and column column_offset and its size + /// is defined by size of provided matrix dest + template<typename Real> + void CopyToMat(int32 row_offset, + int32 column_offset, + MatrixBase<Real> *dest) const; + + void Swap(CompressedMatrix *other) { std::swap(data_, other->data_); } + + friend class Matrix<float>; + friend class Matrix<double>; + private: + + // allocates data using new [], ensures byte alignment + // sufficient for float. + static void *AllocateData(int32 num_bytes); + + // the "format" will be 1 for the original format where each column has a + // PerColHeader, and 2 for the format now used for matrices with 8 or fewer + // rows, where everything is represented as 16-bit integers. + struct GlobalHeader { + int32 format; + float min_value; + float range; + int32 num_rows; + int32 num_cols; + }; + + static MatrixIndexT DataSize(const GlobalHeader &header); + + struct PerColHeader { + uint16 percentile_0; + uint16 percentile_25; + uint16 percentile_75; + uint16 percentile_100; + }; + + template<typename Real> + static void CompressColumn(const GlobalHeader &global_header, + const Real *data, MatrixIndexT stride, + int32 num_rows, PerColHeader *header, + unsigned char *byte_data); + template<typename Real> + static void ComputeColHeader(const GlobalHeader &global_header, + const Real *data, MatrixIndexT stride, + int32 num_rows, PerColHeader *header); + + static inline uint16 FloatToUint16(const GlobalHeader &global_header, + float value); + + static inline float Uint16ToFloat(const GlobalHeader &global_header, + uint16 value); + static inline unsigned char FloatToChar(float p0, float p25, + float p75, float p100, + float value); + static inline float CharToFloat(float p0, float p25, + float p75, float p100, + unsigned char value); + + void Destroy(); + + void *data_; // first GlobalHeader, then PerColHeader (repeated), then + // the byte data for each column (repeated). Note: don't intersperse + // the byte data with the PerColHeaders, because of alignment issues. + +}; + + +/// @} end of \addtogroup matrix_group + + +} // namespace kaldi + + +#endif // KALDI_MATRIX_COMPRESSED_MATRIX_H_ diff --git a/kaldi_io/src/kaldi/matrix/jama-eig.h b/kaldi_io/src/kaldi/matrix/jama-eig.h new file mode 100644 index 0000000..c7278bc --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/jama-eig.h @@ -0,0 +1,924 @@ +// matrix/jama-eig.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// This file consists of a port and modification of materials from +// JAMA: A Java Matrix Package +// under the following notice: This software is a cooperative product of +// The MathWorks and the National Institute of Standards and Technology (NIST) +// which has been released to the public. This notice and the original code are +// available at http://math.nist.gov/javanumerics/jama/domain.notice + + + +#ifndef KALDI_MATRIX_JAMA_EIG_H_ +#define KALDI_MATRIX_JAMA_EIG_H_ 1 + +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +// This class is not to be used externally. See the Eig function in the Matrix +// class in kaldi-matrix.h. This is the external interface. + +template<typename Real> class EigenvalueDecomposition { + // This class is based on the EigenvalueDecomposition class from the JAMA + // library (version 1.0.2). + public: + EigenvalueDecomposition(const MatrixBase<Real> &A); + + ~EigenvalueDecomposition(); // free memory. + + void GetV(MatrixBase<Real> *V_out) { // V is what we call P externally; it's the matrix of + // eigenvectors. + KALDI_ASSERT(V_out->NumRows() == static_cast<MatrixIndexT>(n_) + && V_out->NumCols() == static_cast<MatrixIndexT>(n_)); + for (int i = 0; i < n_; i++) + for (int j = 0; j < n_; j++) + (*V_out)(i, j) = V(i, j); // V(i, j) is member function. + } + void GetRealEigenvalues(VectorBase<Real> *r_out) { + // returns real part of eigenvalues. + KALDI_ASSERT(r_out->Dim() == static_cast<MatrixIndexT>(n_)); + for (int i = 0; i < n_; i++) + (*r_out)(i) = d_[i]; + } + void GetImagEigenvalues(VectorBase<Real> *i_out) { + // returns imaginary part of eigenvalues. + KALDI_ASSERT(i_out->Dim() == static_cast<MatrixIndexT>(n_)); + for (int i = 0; i < n_; i++) + (*i_out)(i) = e_[i]; + } + private: + + inline Real &H(int r, int c) { return H_[r*n_ + c]; } + inline Real &V(int r, int c) { return V_[r*n_ + c]; } + + // complex division + inline static void cdiv(Real xr, Real xi, Real yr, Real yi, Real *cdivr, Real *cdivi) { + Real r, d; + if (std::abs(yr) > std::abs(yi)) { + r = yi/yr; + d = yr + r*yi; + *cdivr = (xr + r*xi)/d; + *cdivi = (xi - r*xr)/d; + } else { + r = yr/yi; + d = yi + r*yr; + *cdivr = (r*xr + xi)/d; + *cdivi = (r*xi - xr)/d; + } + } + + // Nonsymmetric reduction from Hessenberg to real Schur form. + void Hqr2 (); + + + int n_; // matrix dimension. + + Real *d_, *e_; // real and imaginary parts of eigenvalues. + Real *V_; // the eigenvectors (P in our external notation) + Real *H_; // the nonsymmetric Hessenberg form. + Real *ort_; // working storage for nonsymmetric algorithm. + + // Symmetric Householder reduction to tridiagonal form. + void Tred2 (); + + // Symmetric tridiagonal QL algorithm. + void Tql2 (); + + // Nonsymmetric reduction to Hessenberg form. + void Orthes (); + +}; + +template class EigenvalueDecomposition<float>; // force instantiation. +template class EigenvalueDecomposition<double>; // force instantiation. + +template<typename Real> void EigenvalueDecomposition<Real>::Tred2() { + // This is derived from the Algol procedures tred2 by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + for (int j = 0; j < n_; j++) { + d_[j] = V(n_-1, j); + } + + // Householder reduction to tridiagonal form. + + for (int i = n_-1; i > 0; i--) { + + // Scale to avoid under/overflow. + + Real scale = 0.0; + Real h = 0.0; + for (int k = 0; k < i; k++) { + scale = scale + std::abs(d_[k]); + } + if (scale == 0.0) { + e_[i] = d_[i-1]; + for (int j = 0; j < i; j++) { + d_[j] = V(i-1, j); + V(i, j) = 0.0; + V(j, i) = 0.0; + } + } else { + + // Generate Householder vector. + + for (int k = 0; k < i; k++) { + d_[k] /= scale; + h += d_[k] * d_[k]; + } + Real f = d_[i-1]; + Real g = std::sqrt(h); + if (f > 0) { + g = -g; + } + e_[i] = scale * g; + h = h - f * g; + d_[i-1] = f - g; + for (int j = 0; j < i; j++) { + e_[j] = 0.0; + } + + // Apply similarity transformation to remaining columns. + + for (int j = 0; j < i; j++) { + f = d_[j]; + V(j, i) = f; + g =e_[j] + V(j, j) * f; + for (int k = j+1; k <= i-1; k++) { + g += V(k, j) * d_[k]; + e_[k] += V(k, j) * f; + } + e_[j] = g; + } + f = 0.0; + for (int j = 0; j < i; j++) { + e_[j] /= h; + f += e_[j] * d_[j]; + } + Real hh = f / (h + h); + for (int j = 0; j < i; j++) { + e_[j] -= hh * d_[j]; + } + for (int j = 0; j < i; j++) { + f = d_[j]; + g = e_[j]; + for (int k = j; k <= i-1; k++) { + V(k, j) -= (f * e_[k] + g * d_[k]); + } + d_[j] = V(i-1, j); + V(i, j) = 0.0; + } + } + d_[i] = h; + } + + // Accumulate transformations. + + for (int i = 0; i < n_-1; i++) { + V(n_-1, i) = V(i, i); + V(i, i) = 1.0; + Real h = d_[i+1]; + if (h != 0.0) { + for (int k = 0; k <= i; k++) { + d_[k] = V(k, i+1) / h; + } + for (int j = 0; j <= i; j++) { + Real g = 0.0; + for (int k = 0; k <= i; k++) { + g += V(k, i+1) * V(k, j); + } + for (int k = 0; k <= i; k++) { + V(k, j) -= g * d_[k]; + } + } + } + for (int k = 0; k <= i; k++) { + V(k, i+1) = 0.0; + } + } + for (int j = 0; j < n_; j++) { + d_[j] = V(n_-1, j); + V(n_-1, j) = 0.0; + } + V(n_-1, n_-1) = 1.0; + e_[0] = 0.0; +} + +template<typename Real> void EigenvalueDecomposition<Real>::Tql2() { + // This is derived from the Algol procedures tql2, by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + for (int i = 1; i < n_; i++) { + e_[i-1] = e_[i]; + } + e_[n_-1] = 0.0; + + Real f = 0.0; + Real tst1 = 0.0; + Real eps = std::numeric_limits<Real>::epsilon(); + for (int l = 0; l < n_; l++) { + + // Find small subdiagonal element + + tst1 = std::max(tst1, std::abs(d_[l]) + std::abs(e_[l])); + int m = l; + while (m < n_) { + if (std::abs(e_[m]) <= eps*tst1) { + break; + } + m++; + } + + // If m == l, d_[l] is an eigenvalue, + // otherwise, iterate. + + if (m > l) { + int iter = 0; + do { + iter = iter + 1; // (Could check iteration count here.) + + // Compute implicit shift + + Real g = d_[l]; + Real p = (d_[l+1] - g) / (2.0 *e_[l]); + Real r = Hypot(p, static_cast<Real>(1.0)); // This is a Kaldi version of hypot that works with templates. + if (p < 0) { + r = -r; + } + d_[l] =e_[l] / (p + r); + d_[l+1] =e_[l] * (p + r); + Real dl1 = d_[l+1]; + Real h = g - d_[l]; + for (int i = l+2; i < n_; i++) { + d_[i] -= h; + } + f = f + h; + + // Implicit QL transformation. + + p = d_[m]; + Real c = 1.0; + Real c2 = c; + Real c3 = c; + Real el1 =e_[l+1]; + Real s = 0.0; + Real s2 = 0.0; + for (int i = m-1; i >= l; i--) { + c3 = c2; + c2 = c; + s2 = s; + g = c *e_[i]; + h = c * p; + r = Hypot(p, e_[i]); // This is a Kaldi version of Hypot that works with templates. + e_[i+1] = s * r; + s =e_[i] / r; + c = p / r; + p = c * d_[i] - s * g; + d_[i+1] = h + s * (c * g + s * d_[i]); + + // Accumulate transformation. + + for (int k = 0; k < n_; k++) { + h = V(k, i+1); + V(k, i+1) = s * V(k, i) + c * h; + V(k, i) = c * V(k, i) - s * h; + } + } + p = -s * s2 * c3 * el1 *e_[l] / dl1; + e_[l] = s * p; + d_[l] = c * p; + + // Check for convergence. + + } while (std::abs(e_[l]) > eps*tst1); + } + d_[l] = d_[l] + f; + e_[l] = 0.0; + } + + // Sort eigenvalues and corresponding vectors. + + for (int i = 0; i < n_-1; i++) { + int k = i; + Real p = d_[i]; + for (int j = i+1; j < n_; j++) { + if (d_[j] < p) { + k = j; + p = d_[j]; + } + } + if (k != i) { + d_[k] = d_[i]; + d_[i] = p; + for (int j = 0; j < n_; j++) { + p = V(j, i); + V(j, i) = V(j, k); + V(j, k) = p; + } + } + } +} + +template<typename Real> +void EigenvalueDecomposition<Real>::Orthes() { + + // This is derived from the Algol procedures orthes and ortran, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutines in EISPACK. + + int low = 0; + int high = n_-1; + + for (int m = low+1; m <= high-1; m++) { + + // Scale column. + + Real scale = 0.0; + for (int i = m; i <= high; i++) { + scale = scale + std::abs(H(i, m-1)); + } + if (scale != 0.0) { + + // Compute Householder transformation. + + Real h = 0.0; + for (int i = high; i >= m; i--) { + ort_[i] = H(i, m-1)/scale; + h += ort_[i] * ort_[i]; + } + Real g = std::sqrt(h); + if (ort_[m] > 0) { + g = -g; + } + h = h - ort_[m] * g; + ort_[m] = ort_[m] - g; + + // Apply Householder similarity transformation + // H = (I-u*u'/h)*H*(I-u*u')/h) + + for (int j = m; j < n_; j++) { + Real f = 0.0; + for (int i = high; i >= m; i--) { + f += ort_[i]*H(i, j); + } + f = f/h; + for (int i = m; i <= high; i++) { + H(i, j) -= f*ort_[i]; + } + } + + for (int i = 0; i <= high; i++) { + Real f = 0.0; + for (int j = high; j >= m; j--) { + f += ort_[j]*H(i, j); + } + f = f/h; + for (int j = m; j <= high; j++) { + H(i, j) -= f*ort_[j]; + } + } + ort_[m] = scale*ort_[m]; + H(m, m-1) = scale*g; + } + } + + // Accumulate transformations (Algol's ortran). + + for (int i = 0; i < n_; i++) { + for (int j = 0; j < n_; j++) { + V(i, j) = (i == j ? 1.0 : 0.0); + } + } + + for (int m = high-1; m >= low+1; m--) { + if (H(m, m-1) != 0.0) { + for (int i = m+1; i <= high; i++) { + ort_[i] = H(i, m-1); + } + for (int j = m; j <= high; j++) { + Real g = 0.0; + for (int i = m; i <= high; i++) { + g += ort_[i] * V(i, j); + } + // Double division avoids possible underflow + g = (g / ort_[m]) / H(m, m-1); + for (int i = m; i <= high; i++) { + V(i, j) += g * ort_[i]; + } + } + } + } +} + +template<typename Real> void EigenvalueDecomposition<Real>::Hqr2() { + // This is derived from the Algol procedure hqr2, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + int nn = n_; + int n = nn-1; + int low = 0; + int high = nn-1; + Real eps = std::numeric_limits<Real>::epsilon(); + Real exshift = 0.0; + Real p = 0, q = 0, r = 0, s = 0, z=0, t, w, x, y; + + // Store roots isolated by balanc and compute matrix norm + + Real norm = 0.0; + for (int i = 0; i < nn; i++) { + if (i < low || i > high) { + d_[i] = H(i, i); + e_[i] = 0.0; + } + for (int j = std::max(i-1, 0); j < nn; j++) { + norm = norm + std::abs(H(i, j)); + } + } + + // Outer loop over eigenvalue index + + int iter = 0; + while (n >= low) { + + // Look for single small sub-diagonal element + + int l = n; + while (l > low) { + s = std::abs(H(l-1, l-1)) + std::abs(H(l, l)); + if (s == 0.0) { + s = norm; + } + if (std::abs(H(l, l-1)) < eps * s) { + break; + } + l--; + } + + // Check for convergence + // One root found + + if (l == n) { + H(n, n) = H(n, n) + exshift; + d_[n] = H(n, n); + e_[n] = 0.0; + n--; + iter = 0; + + // Two roots found + + } else if (l == n-1) { + w = H(n, n-1) * H(n-1, n); + p = (H(n-1, n-1) - H(n, n)) / 2.0; + q = p * p + w; + z = std::sqrt(std::abs(q)); + H(n, n) = H(n, n) + exshift; + H(n-1, n-1) = H(n-1, n-1) + exshift; + x = H(n, n); + + // Real pair + + if (q >= 0) { + if (p >= 0) { + z = p + z; + } else { + z = p - z; + } + d_[n-1] = x + z; + d_[n] = d_[n-1]; + if (z != 0.0) { + d_[n] = x - w / z; + } + e_[n-1] = 0.0; + e_[n] = 0.0; + x = H(n, n-1); + s = std::abs(x) + std::abs(z); + p = x / s; + q = z / s; + r = std::sqrt(p * p+q * q); + p = p / r; + q = q / r; + + // Row modification + + for (int j = n-1; j < nn; j++) { + z = H(n-1, j); + H(n-1, j) = q * z + p * H(n, j); + H(n, j) = q * H(n, j) - p * z; + } + + // Column modification + + for (int i = 0; i <= n; i++) { + z = H(i, n-1); + H(i, n-1) = q * z + p * H(i, n); + H(i, n) = q * H(i, n) - p * z; + } + + // Accumulate transformations + + for (int i = low; i <= high; i++) { + z = V(i, n-1); + V(i, n-1) = q * z + p * V(i, n); + V(i, n) = q * V(i, n) - p * z; + } + + // Complex pair + + } else { + d_[n-1] = x + p; + d_[n] = x + p; + e_[n-1] = z; + e_[n] = -z; + } + n = n - 2; + iter = 0; + + // No convergence yet + + } else { + + // Form shift + + x = H(n, n); + y = 0.0; + w = 0.0; + if (l < n) { + y = H(n-1, n-1); + w = H(n, n-1) * H(n-1, n); + } + + // Wilkinson's original ad hoc shift + + if (iter == 10) { + exshift += x; + for (int i = low; i <= n; i++) { + H(i, i) -= x; + } + s = std::abs(H(n, n-1)) + std::abs(H(n-1, n-2)); + x = y = 0.75 * s; + w = -0.4375 * s * s; + } + + // MATLAB's new ad hoc shift + + if (iter == 30) { + s = (y - x) / 2.0; + s = s * s + w; + if (s > 0) { + s = std::sqrt(s); + if (y < x) { + s = -s; + } + s = x - w / ((y - x) / 2.0 + s); + for (int i = low; i <= n; i++) { + H(i, i) -= s; + } + exshift += s; + x = y = w = 0.964; + } + } + + iter = iter + 1; // (Could check iteration count here.) + + // Look for two consecutive small sub-diagonal elements + + int m = n-2; + while (m >= l) { + z = H(m, m); + r = x - z; + s = y - z; + p = (r * s - w) / H(m+1, m) + H(m, m+1); + q = H(m+1, m+1) - z - r - s; + r = H(m+2, m+1); + s = std::abs(p) + std::abs(q) + std::abs(r); + p = p / s; + q = q / s; + r = r / s; + if (m == l) { + break; + } + if (std::abs(H(m, m-1)) * (std::abs(q) + std::abs(r)) < + eps * (std::abs(p) * (std::abs(H(m-1, m-1)) + std::abs(z) + + std::abs(H(m+1, m+1))))) { + break; + } + m--; + } + + for (int i = m+2; i <= n; i++) { + H(i, i-2) = 0.0; + if (i > m+2) { + H(i, i-3) = 0.0; + } + } + + // Double QR step involving rows l:n and columns m:n + + for (int k = m; k <= n-1; k++) { + bool notlast = (k != n-1); + if (k != m) { + p = H(k, k-1); + q = H(k+1, k-1); + r = (notlast ? H(k+2, k-1) : 0.0); + x = std::abs(p) + std::abs(q) + std::abs(r); + if (x != 0.0) { + p = p / x; + q = q / x; + r = r / x; + } + } + if (x == 0.0) { + break; + } + s = std::sqrt(p * p + q * q + r * r); + if (p < 0) { + s = -s; + } + if (s != 0) { + if (k != m) { + H(k, k-1) = -s * x; + } else if (l != m) { + H(k, k-1) = -H(k, k-1); + } + p = p + s; + x = p / s; + y = q / s; + z = r / s; + q = q / p; + r = r / p; + + // Row modification + + for (int j = k; j < nn; j++) { + p = H(k, j) + q * H(k+1, j); + if (notlast) { + p = p + r * H(k+2, j); + H(k+2, j) = H(k+2, j) - p * z; + } + H(k, j) = H(k, j) - p * x; + H(k+1, j) = H(k+1, j) - p * y; + } + + // Column modification + + for (int i = 0; i <= std::min(n, k+3); i++) { + p = x * H(i, k) + y * H(i, k+1); + if (notlast) { + p = p + z * H(i, k+2); + H(i, k+2) = H(i, k+2) - p * r; + } + H(i, k) = H(i, k) - p; + H(i, k+1) = H(i, k+1) - p * q; + } + + // Accumulate transformations + + for (int i = low; i <= high; i++) { + p = x * V(i, k) + y * V(i, k+1); + if (notlast) { + p = p + z * V(i, k+2); + V(i, k+2) = V(i, k+2) - p * r; + } + V(i, k) = V(i, k) - p; + V(i, k+1) = V(i, k+1) - p * q; + } + } // (s != 0) + } // k loop + } // check convergence + } // while (n >= low) + + // Backsubstitute to find vectors of upper triangular form + + if (norm == 0.0) { + return; + } + + for (n = nn-1; n >= 0; n--) { + p = d_[n]; + q = e_[n]; + + // Real vector + + if (q == 0) { + int l = n; + H(n, n) = 1.0; + for (int i = n-1; i >= 0; i--) { + w = H(i, i) - p; + r = 0.0; + for (int j = l; j <= n; j++) { + r = r + H(i, j) * H(j, n); + } + if (e_[i] < 0.0) { + z = w; + s = r; + } else { + l = i; + if (e_[i] == 0.0) { + if (w != 0.0) { + H(i, n) = -r / w; + } else { + H(i, n) = -r / (eps * norm); + } + + // Solve real equations + + } else { + x = H(i, i+1); + y = H(i+1, i); + q = (d_[i] - p) * (d_[i] - p) +e_[i] *e_[i]; + t = (x * s - z * r) / q; + H(i, n) = t; + if (std::abs(x) > std::abs(z)) { + H(i+1, n) = (-r - w * t) / x; + } else { + H(i+1, n) = (-s - y * t) / z; + } + } + + // Overflow control + + t = std::abs(H(i, n)); + if ((eps * t) * t > 1) { + for (int j = i; j <= n; j++) { + H(j, n) = H(j, n) / t; + } + } + } + } + + // Complex vector + + } else if (q < 0) { + int l = n-1; + + // Last vector component imaginary so matrix is triangular + + if (std::abs(H(n, n-1)) > std::abs(H(n-1, n))) { + H(n-1, n-1) = q / H(n, n-1); + H(n-1, n) = -(H(n, n) - p) / H(n, n-1); + } else { + Real cdivr, cdivi; + cdiv(0.0, -H(n-1, n), H(n-1, n-1)-p, q, &cdivr, &cdivi); + H(n-1, n-1) = cdivr; + H(n-1, n) = cdivi; + } + H(n, n-1) = 0.0; + H(n, n) = 1.0; + for (int i = n-2; i >= 0; i--) { + Real ra, sa, vr, vi; + ra = 0.0; + sa = 0.0; + for (int j = l; j <= n; j++) { + ra = ra + H(i, j) * H(j, n-1); + sa = sa + H(i, j) * H(j, n); + } + w = H(i, i) - p; + + if (e_[i] < 0.0) { + z = w; + r = ra; + s = sa; + } else { + l = i; + if (e_[i] == 0) { + Real cdivr, cdivi; + cdiv(-ra, -sa, w, q, &cdivr, &cdivi); + H(i, n-1) = cdivr; + H(i, n) = cdivi; + } else { + Real cdivr, cdivi; + // Solve complex equations + + x = H(i, i+1); + y = H(i+1, i); + vr = (d_[i] - p) * (d_[i] - p) +e_[i] *e_[i] - q * q; + vi = (d_[i] - p) * 2.0 * q; + if (vr == 0.0 && vi == 0.0) { + vr = eps * norm * (std::abs(w) + std::abs(q) + + std::abs(x) + std::abs(y) + std::abs(z)); + } + cdiv(x*r-z*ra+q*sa, x*s-z*sa-q*ra, vr, vi, &cdivr, &cdivi); + H(i, n-1) = cdivr; + H(i, n) = cdivi; + if (std::abs(x) > (std::abs(z) + std::abs(q))) { + H(i+1, n-1) = (-ra - w * H(i, n-1) + q * H(i, n)) / x; + H(i+1, n) = (-sa - w * H(i, n) - q * H(i, n-1)) / x; + } else { + cdiv(-r-y*H(i, n-1), -s-y*H(i, n), z, q, &cdivr, &cdivi); + H(i+1, n-1) = cdivr; + H(i+1, n) = cdivi; + } + } + + // Overflow control + + t = std::max(std::abs(H(i, n-1)), std::abs(H(i, n))); + if ((eps * t) * t > 1) { + for (int j = i; j <= n; j++) { + H(j, n-1) = H(j, n-1) / t; + H(j, n) = H(j, n) / t; + } + } + } + } + } + } + + // Vectors of isolated roots + + for (int i = 0; i < nn; i++) { + if (i < low || i > high) { + for (int j = i; j < nn; j++) { + V(i, j) = H(i, j); + } + } + } + + // Back transformation to get eigenvectors of original matrix + + for (int j = nn-1; j >= low; j--) { + for (int i = low; i <= high; i++) { + z = 0.0; + for (int k = low; k <= std::min(j, high); k++) { + z = z + V(i, k) * H(k, j); + } + V(i, j) = z; + } + } +} + +template<typename Real> +EigenvalueDecomposition<Real>::EigenvalueDecomposition(const MatrixBase<Real> &A) { + KALDI_ASSERT(A.NumCols() == A.NumRows() && A.NumCols() >= 1); + n_ = A.NumRows(); + V_ = new Real[n_*n_]; + d_ = new Real[n_]; + e_ = new Real[n_]; + H_ = NULL; + ort_ = NULL; + if (A.IsSymmetric(0.0)) { + + for (int i = 0; i < n_; i++) + for (int j = 0; j < n_; j++) + V(i, j) = A(i, j); // Note that V(i, j) is a member function; A(i, j) is an operator + // of the matrix A. + // Tridiagonalize. + Tred2(); + + // Diagonalize. + Tql2(); + } else { + H_ = new Real[n_*n_]; + ort_ = new Real[n_]; + for (int i = 0; i < n_; i++) + for (int j = 0; j < n_; j++) + H(i, j) = A(i, j); // as before: H is member function, A(i, j) is operator of matrix. + + // Reduce to Hessenberg form. + Orthes(); + + // Reduce Hessenberg to real Schur form. + Hqr2(); + } +} + +template<typename Real> +EigenvalueDecomposition<Real>::~EigenvalueDecomposition() { + delete [] d_; + delete [] e_; + delete [] V_; + if (H_) delete [] H_; + if (ort_) delete [] ort_; +} + +// see function MatrixBase<Real>::Eig in kaldi-matrix.cc + + +} // namespace kaldi + +#endif // KALDI_MATRIX_JAMA_EIG_H_ diff --git a/kaldi_io/src/kaldi/matrix/jama-svd.h b/kaldi_io/src/kaldi/matrix/jama-svd.h new file mode 100644 index 0000000..8304dac --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/jama-svd.h @@ -0,0 +1,531 @@ +// matrix/jama-svd.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// This file consists of a port and modification of materials from +// JAMA: A Java Matrix Package +// under the following notice: This software is a cooperative product of +// The MathWorks and the National Institute of Standards and Technology (NIST) +// which has been released to the public. This notice and the original code are +// available at http://math.nist.gov/javanumerics/jama/domain.notice + + +#ifndef KALDI_MATRIX_JAMA_SVD_H_ +#define KALDI_MATRIX_JAMA_SVD_H_ 1 + + +#include "matrix/kaldi-matrix.h" +#include "matrix/sp-matrix.h" +#include "matrix/cblas-wrappers.h" + +namespace kaldi { + +#if defined(HAVE_ATLAS) || defined(USE_KALDI_SVD) +// using ATLAS as our math library, which doesn't have SVD -> need +// to implement it. + +// This routine is a modified form of jama_svd.h which is part of the TNT distribution. +// (originally comes from JAMA). + +/** Singular Value Decomposition. + * <P> + * For an m-by-n matrix A with m >= n, the singular value decomposition is + * an m-by-n orthogonal matrix U, an n-by-n diagonal matrix S, and + * an n-by-n orthogonal matrix V so that A = U*S*V'. + * <P> + * The singular values, sigma[k] = S(k, k), are ordered so that + * sigma[0] >= sigma[1] >= ... >= sigma[n-1]. + * <P> + * The singular value decompostion always exists, so the constructor will + * never fail. The matrix condition number and the effective numerical + * rank can be computed from this decomposition. + + * <p> + * (Adapted from JAMA, a Java Matrix Library, developed by jointly + * by the Mathworks and NIST; see http://math.nist.gov/javanumerics/jama). + */ + + +template<typename Real> +bool MatrixBase<Real>::JamaSvd(VectorBase<Real> *s_in, + MatrixBase<Real> *U_in, + MatrixBase<Real> *V_in) { // Destructive! + KALDI_ASSERT(s_in != NULL && U_in != this && V_in != this); + int wantu = (U_in != NULL), wantv = (V_in != NULL); + Matrix<Real> Utmp, Vtmp; + MatrixBase<Real> &U = (U_in ? *U_in : Utmp), &V = (V_in ? *V_in : Vtmp); + VectorBase<Real> &s = *s_in; + + int m = num_rows_, n = num_cols_; + KALDI_ASSERT(m>=n && m != 0 && n != 0); + if (wantu) KALDI_ASSERT((int)U.num_rows_ == m && (int)U.num_cols_ == n); + if (wantv) KALDI_ASSERT((int)V.num_rows_ == n && (int)V.num_cols_ == n); + KALDI_ASSERT((int)s.Dim() == n); // n<=m so n is min. + + int nu = n; + U.SetZero(); // make sure all zero. + Vector<Real> e(n); + Vector<Real> work(m); + MatrixBase<Real> &A(*this); + Real *adata = A.Data(), *workdata = work.Data(), *edata = e.Data(), + *udata = U.Data(), *vdata = V.Data(); + int astride = static_cast<int>(A.Stride()), + ustride = static_cast<int>(U.Stride()), + vstride = static_cast<int>(V.Stride()); + int i = 0, j = 0, k = 0; + + // Reduce A to bidiagonal form, storing the diagonal elements + // in s and the super-diagonal elements in e. + + int nct = std::min(m-1, n); + int nrt = std::max(0, std::min(n-2, m)); + for (k = 0; k < std::max(nct, nrt); k++) { + if (k < nct) { + + // Compute the transformation for the k-th column and + // place the k-th diagonal in s(k). + // Compute 2-norm of k-th column without under/overflow. + s(k) = 0; + for (i = k; i < m; i++) { + s(k) = hypot(s(k), A(i, k)); + } + if (s(k) != 0.0) { + if (A(k, k) < 0.0) { + s(k) = -s(k); + } + for (i = k; i < m; i++) { + A(i, k) /= s(k); + } + A(k, k) += 1.0; + } + s(k) = -s(k); + } + for (j = k+1; j < n; j++) { + if ((k < nct) && (s(k) != 0.0)) { + + // Apply the transformation. + + Real t = cblas_Xdot(m - k, adata + astride*k + k, astride, + adata + astride*k + j, astride); + /*for (i = k; i < m; i++) { + t += adata[i*astride + k]*adata[i*astride + j]; // A(i, k)*A(i, j); // 3 + }*/ + t = -t/A(k, k); + cblas_Xaxpy(m - k, t, adata + k*astride + k, astride, + adata + k*astride + j, astride); + /*for (i = k; i < m; i++) { + adata[i*astride + j] += t*adata[i*astride + k]; // A(i, j) += t*A(i, k); // 5 + }*/ + } + + // Place the k-th row of A into e for the + // subsequent calculation of the row transformation. + + e(j) = A(k, j); + } + if (wantu & (k < nct)) { + + // Place the transformation in U for subsequent back + // multiplication. + + for (i = k; i < m; i++) { + U(i, k) = A(i, k); + } + } + if (k < nrt) { + + // Compute the k-th row transformation and place the + // k-th super-diagonal in e(k). + // Compute 2-norm without under/overflow. + e(k) = 0; + for (i = k+1; i < n; i++) { + e(k) = hypot(e(k), e(i)); + } + if (e(k) != 0.0) { + if (e(k+1) < 0.0) { + e(k) = -e(k); + } + for (i = k+1; i < n; i++) { + e(i) /= e(k); + } + e(k+1) += 1.0; + } + e(k) = -e(k); + if ((k+1 < m) & (e(k) != 0.0)) { + + // Apply the transformation. + + for (i = k+1; i < m; i++) { + work(i) = 0.0; + } + for (j = k+1; j < n; j++) { + for (i = k+1; i < m; i++) { + workdata[i] += edata[j] * adata[i*astride + j]; // work(i) += e(j)*A(i, j); // 5 + } + } + for (j = k+1; j < n; j++) { + Real t(-e(j)/e(k+1)); + cblas_Xaxpy(m - (k+1), t, workdata + (k+1), 1, + adata + (k+1)*astride + j, astride); + /* + for (i = k+1; i < m; i++) { + adata[i*astride + j] += t*workdata[i]; // A(i, j) += t*work(i); // 5 + }*/ + } + } + if (wantv) { + + // Place the transformation in V for subsequent + // back multiplication. + + for (i = k+1; i < n; i++) { + V(i, k) = e(i); + } + } + } + } + + // Set up the final bidiagonal matrix or order p. + + int p = std::min(n, m+1); + if (nct < n) { + s(nct) = A(nct, nct); + } + if (m < p) { + s(p-1) = 0.0; + } + if (nrt+1 < p) { + e(nrt) = A(nrt, p-1); + } + e(p-1) = 0.0; + + // If required, generate U. + + if (wantu) { + for (j = nct; j < nu; j++) { + for (i = 0; i < m; i++) { + U(i, j) = 0.0; + } + U(j, j) = 1.0; + } + for (k = nct-1; k >= 0; k--) { + if (s(k) != 0.0) { + for (j = k+1; j < nu; j++) { + Real t = cblas_Xdot(m - k, udata + k*ustride + k, ustride, udata + k*ustride + j, ustride); + //for (i = k; i < m; i++) { + // t += udata[i*ustride + k]*udata[i*ustride + j]; // t += U(i, k)*U(i, j); // 8 + // } + t = -t/U(k, k); + cblas_Xaxpy(m - k, t, udata + ustride*k + k, ustride, + udata + k*ustride + j, ustride); + /*for (i = k; i < m; i++) { + udata[i*ustride + j] += t*udata[i*ustride + k]; // U(i, j) += t*U(i, k); // 4 + }*/ + } + for (i = k; i < m; i++ ) { + U(i, k) = -U(i, k); + } + U(k, k) = 1.0 + U(k, k); + for (i = 0; i < k-1; i++) { + U(i, k) = 0.0; + } + } else { + for (i = 0; i < m; i++) { + U(i, k) = 0.0; + } + U(k, k) = 1.0; + } + } + } + + // If required, generate V. + + if (wantv) { + for (k = n-1; k >= 0; k--) { + if ((k < nrt) & (e(k) != 0.0)) { + for (j = k+1; j < nu; j++) { + Real t = cblas_Xdot(n - (k+1), vdata + (k+1)*vstride + k, vstride, + vdata + (k+1)*vstride + j, vstride); + /*Real t (0.0); + for (i = k+1; i < n; i++) { + t += vdata[i*vstride + k]*vdata[i*vstride + j]; // t += V(i, k)*V(i, j); // 7 + }*/ + t = -t/V(k+1, k); + cblas_Xaxpy(n - (k+1), t, vdata + (k+1)*vstride + k, vstride, + vdata + (k+1)*vstride + j, vstride); + /*for (i = k+1; i < n; i++) { + vdata[i*vstride + j] += t*vdata[i*vstride + k]; // V(i, j) += t*V(i, k); // 7 + }*/ + } + } + for (i = 0; i < n; i++) { + V(i, k) = 0.0; + } + V(k, k) = 1.0; + } + } + + // Main iteration loop for the singular values. + + int pp = p-1; + int iter = 0; + // note: -52.0 is from Jama code; the -23 is the extension + // to float, because mantissa length in (double, float) + // is (52, 23) bits respectively. + Real eps(pow(2.0, sizeof(Real) == 4 ? -23.0 : -52.0)); + // Note: the -966 was taken from Jama code, but the -120 is a guess + // of how to extend this to float... the exponent in double goes + // from -1022 .. 1023, and in float from -126..127. I'm not sure + // what the significance of 966 is, so -120 just represents a number + // that's a bit less negative than -126. If we get convergence + // failure in float only, this may mean that we have to make the + // -120 value less negative. + Real tiny(pow(2.0, sizeof(Real) == 4 ? -120.0: -966.0 )); + + while (p > 0) { + int k = 0; + int kase = 0; + + if (iter == 500 || iter == 750) { + KALDI_WARN << "Svd taking a long time: making convergence criterion less exact."; + eps = pow(static_cast<Real>(0.8), eps); + tiny = pow(static_cast<Real>(0.8), tiny); + } + if (iter > 1000) { + KALDI_WARN << "Svd not converging on matrix of size " << m << " by " <<n; + return false; + } + + // This section of the program inspects for + // negligible elements in the s and e arrays. On + // completion the variables kase and k are set as follows. + + // kase = 1 if s(p) and e(k-1) are negligible and k < p + // kase = 2 if s(k) is negligible and k < p + // kase = 3 if e(k-1) is negligible, k < p, and + // s(k), ..., s(p) are not negligible (qr step). + // kase = 4 if e(p-1) is negligible (convergence). + + for (k = p-2; k >= -1; k--) { + if (k == -1) { + break; + } + if (std::abs(e(k)) <= + tiny + eps*(std::abs(s(k)) + std::abs(s(k+1)))) { + e(k) = 0.0; + break; + } + } + if (k == p-2) { + kase = 4; + } else { + int ks; + for (ks = p-1; ks >= k; ks--) { + if (ks == k) { + break; + } + Real t( (ks != p ? std::abs(e(ks)) : 0.) + + (ks != k+1 ? std::abs(e(ks-1)) : 0.)); + if (std::abs(s(ks)) <= tiny + eps*t) { + s(ks) = 0.0; + break; + } + } + if (ks == k) { + kase = 3; + } else if (ks == p-1) { + kase = 1; + } else { + kase = 2; + k = ks; + } + } + k++; + + // Perform the task indicated by kase. + + switch (kase) { + + // Deflate negligible s(p). + + case 1: { + Real f(e(p-2)); + e(p-2) = 0.0; + for (j = p-2; j >= k; j--) { + Real t( hypot(s(j), f)); + Real cs(s(j)/t); + Real sn(f/t); + s(j) = t; + if (j != k) { + f = -sn*e(j-1); + e(j-1) = cs*e(j-1); + } + if (wantv) { + for (i = 0; i < n; i++) { + t = cs*V(i, j) + sn*V(i, p-1); + V(i, p-1) = -sn*V(i, j) + cs*V(i, p-1); + V(i, j) = t; + } + } + } + } + break; + + // Split at negligible s(k). + + case 2: { + Real f(e(k-1)); + e(k-1) = 0.0; + for (j = k; j < p; j++) { + Real t(hypot(s(j), f)); + Real cs( s(j)/t); + Real sn(f/t); + s(j) = t; + f = -sn*e(j); + e(j) = cs*e(j); + if (wantu) { + for (i = 0; i < m; i++) { + t = cs*U(i, j) + sn*U(i, k-1); + U(i, k-1) = -sn*U(i, j) + cs*U(i, k-1); + U(i, j) = t; + } + } + } + } + break; + + // Perform one qr step. + + case 3: { + + // Calculate the shift. + + Real scale = std::max(std::max(std::max(std::max( + std::abs(s(p-1)), std::abs(s(p-2))), std::abs(e(p-2))), + std::abs(s(k))), std::abs(e(k))); + Real sp = s(p-1)/scale; + Real spm1 = s(p-2)/scale; + Real epm1 = e(p-2)/scale; + Real sk = s(k)/scale; + Real ek = e(k)/scale; + Real b = ((spm1 + sp)*(spm1 - sp) + epm1*epm1)/2.0; + Real c = (sp*epm1)*(sp*epm1); + Real shift = 0.0; + if ((b != 0.0) || (c != 0.0)) { + shift = std::sqrt(b*b + c); + if (b < 0.0) { + shift = -shift; + } + shift = c/(b + shift); + } + Real f = (sk + sp)*(sk - sp) + shift; + Real g = sk*ek; + + // Chase zeros. + + for (j = k; j < p-1; j++) { + Real t = hypot(f, g); + Real cs = f/t; + Real sn = g/t; + if (j != k) { + e(j-1) = t; + } + f = cs*s(j) + sn*e(j); + e(j) = cs*e(j) - sn*s(j); + g = sn*s(j+1); + s(j+1) = cs*s(j+1); + if (wantv) { + cblas_Xrot(n, vdata + j, vstride, vdata + j+1, vstride, cs, sn); + /*for (i = 0; i < n; i++) { + t = cs*vdata[i*vstride + j] + sn*vdata[i*vstride + j+1]; // t = cs*V(i, j) + sn*V(i, j+1); // 13 + vdata[i*vstride + j+1] = -sn*vdata[i*vstride + j] + cs*vdata[i*vstride + j+1]; // V(i, j+1) = -sn*V(i, j) + cs*V(i, j+1); // 5 + vdata[i*vstride + j] = t; // V(i, j) = t; // 4 + }*/ + } + t = hypot(f, g); + cs = f/t; + sn = g/t; + s(j) = t; + f = cs*e(j) + sn*s(j+1); + s(j+1) = -sn*e(j) + cs*s(j+1); + g = sn*e(j+1); + e(j+1) = cs*e(j+1); + if (wantu && (j < m-1)) { + cblas_Xrot(m, udata + j, ustride, udata + j+1, ustride, cs, sn); + /*for (i = 0; i < m; i++) { + t = cs*udata[i*ustride + j] + sn*udata[i*ustride + j+1]; // t = cs*U(i, j) + sn*U(i, j+1); // 7 + udata[i*ustride + j+1] = -sn*udata[i*ustride + j] +cs*udata[i*ustride + j+1]; // U(i, j+1) = -sn*U(i, j) + cs*U(i, j+1); // 8 + udata[i*ustride + j] = t; // U(i, j) = t; // 1 + }*/ + } + } + e(p-2) = f; + iter = iter + 1; + } + break; + + // Convergence. + + case 4: { + + // Make the singular values positive. + + if (s(k) <= 0.0) { + s(k) = (s(k) < 0.0 ? -s(k) : 0.0); + if (wantv) { + for (i = 0; i <= pp; i++) { + V(i, k) = -V(i, k); + } + } + } + + // Order the singular values. + + while (k < pp) { + if (s(k) >= s(k+1)) { + break; + } + Real t = s(k); + s(k) = s(k+1); + s(k+1) = t; + if (wantv && (k < n-1)) { + for (i = 0; i < n; i++) { + t = V(i, k+1); V(i, k+1) = V(i, k); V(i, k) = t; + } + } + if (wantu && (k < m-1)) { + for (i = 0; i < m; i++) { + t = U(i, k+1); U(i, k+1) = U(i, k); U(i, k) = t; + } + } + k++; + } + iter = 0; + p--; + } + break; + } + } + return true; +} + +#endif // defined(HAVE_ATLAS) || defined(USE_KALDI_SVD) + +} // namespace kaldi + +#endif // KALDI_MATRIX_JAMA_SVD_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-blas.h b/kaldi_io/src/kaldi/matrix/kaldi-blas.h new file mode 100644 index 0000000..5d25ab8 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-blas.h @@ -0,0 +1,132 @@ +// matrix/kaldi-blas.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_KALDI_BLAS_H_ +#define KALDI_MATRIX_KALDI_BLAS_H_ + +// This file handles the #includes for BLAS, LAPACK and so on. +// It manipulates the declarations into a common format that kaldi can handle. +// However, the kaldi code will check whether HAVE_ATLAS is defined as that +// code is called a bit differently from CLAPACK that comes from other sources. + +// There are three alternatives: +// (i) you have ATLAS, which includes the ATLAS implementation of CBLAS +// plus a subset of CLAPACK (but with clapack_ in the function declarations). +// In this case, define HAVE_ATLAS and make sure the relevant directories are +// in the include path. + +// (ii) you have CBLAS (some implementation thereof) plus CLAPACK. +// In this case, define HAVE_CLAPACK. +// [Since CLAPACK depends on BLAS, the presence of BLAS is implicit]. + +// (iii) you have the MKL library, which includes CLAPACK and CBLAS. + +// Note that if we are using ATLAS, no Svd implementation is supplied, +// so we define HAVE_Svd to be zero and this directs our implementation to +// supply its own "by hand" implementation which is based on TNT code. + + + + +#if (defined(HAVE_CLAPACK) && (defined(HAVE_ATLAS) || defined(HAVE_MKL))) \ + || (defined(HAVE_ATLAS) && defined(HAVE_MKL)) +#error "Do not define more than one of HAVE_CLAPACK, HAVE_ATLAS and HAVE_MKL" +#endif + +#ifdef HAVE_ATLAS + extern "C" { + #include <cblas.h> + #include <clapack.h> + } +#elif defined(HAVE_CLAPACK) + #ifdef __APPLE__ + #ifndef __has_extension + #define __has_extension(x) 0 + #endif + #define vImage_Utilities_h + #define vImage_CVUtilities_h + #include <Accelerate/Accelerate.h> + typedef __CLPK_integer integer; + typedef __CLPK_logical logical; + typedef __CLPK_real real; + typedef __CLPK_doublereal doublereal; + typedef __CLPK_complex complex; + typedef __CLPK_doublecomplex doublecomplex; + typedef __CLPK_ftnlen ftnlen; + #else + extern "C" { + // May be in /usr/[local]/include if installed; else this uses the one + // from the tools/CLAPACK_include directory. + #include <cblas.h> + #include <f2c.h> + #include <clapack.h> + + // get rid of macros from f2c.h -- these are dangerous. + #undef abs + #undef dabs + #undef min + #undef max + #undef dmin + #undef dmax + #undef bit_test + #undef bit_clear + #undef bit_set + } + #endif +#elif defined(HAVE_MKL) + extern "C" { + #include <mkl.h> + } +#elif defined(HAVE_OPENBLAS) + // getting cblas.h and lapacke.h from <openblas-install-dir>/. + // putting in "" not <> to search -I before system libraries. + #include "cblas.h" + #include "lapacke.h" + #undef I + #undef complex + // get rid of macros from f2c.h -- these are dangerous. + #undef abs + #undef dabs + #undef min + #undef max + #undef dmin + #undef dmax + #undef bit_test + #undef bit_clear + #undef bit_set +#else + #error "You need to define (using the preprocessor) either HAVE_CLAPACK or HAVE_ATLAS or HAVE_MKL (but not more than one)" +#endif + +#ifdef HAVE_OPENBLAS +typedef int KaldiBlasInt; // try int. +#endif +#ifdef HAVE_CLAPACK +typedef integer KaldiBlasInt; +#endif +#ifdef HAVE_MKL +typedef MKL_INT KaldiBlasInt; +#endif + +#ifdef HAVE_ATLAS +// in this case there is no need for KaldiBlasInt-- this typedef is only needed +// for Svd code which is not included in ATLAS (we re-implement it). +#endif + + +#endif // KALDI_MATRIX_KALDI_BLAS_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-gpsr.h b/kaldi_io/src/kaldi/matrix/kaldi-gpsr.h new file mode 100644 index 0000000..c294bdd --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-gpsr.h @@ -0,0 +1,166 @@ +// matrix/kaldi-gpsr.h + +// Copyright 2012 Arnab Ghoshal + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_GPSR_H_ +#define KALDI_MATRIX_KALDI_GPSR_H_ + +#include <string> +#include <vector> + +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" +#include "itf/options-itf.h" + +namespace kaldi { + +/// This is an implementation of the GPSR algorithm. See, Figueiredo, Nowak and +/// Wright, "Gradient Projection for Sparse Reconstruction: Application to +/// Compressed Sensing and Other Inverse Problems," IEEE Journal of Selected +/// Topics in Signal Processing, vol. 1, no. 4, pp. 586-597, 2007. +/// http://dx.doi.org/10.1109/JSTSP.2007.910281 + +/// The GPSR algorithm, described in Figueiredo, et al., 2007, solves: +/// \f[ \min_x 0.5 * ||y - Ax||_2^2 + \tau ||x||_1, \f] +/// where \f$ x \in R^n, y \in R^k \f$, and \f$ A \in R^{n \times k} \f$. +/// In this implementation, we solve: +/// \f[ \min_x 0.5 * x^T H x - g^T x + \tau ||x||_1, \f] +/// which is the more natural form in which such problems arise in our case. +/// Here, \f$ H = A^T A \in R^{n \times n} \f$ and \f$ g = A^T y \in R^n \f$. + + +/** \struct GpsrConfig + * Configuration variables needed in the GPSR algorithm. + */ +struct GpsrConfig { + bool use_gpsr_bb; ///< Use the Barzilai-Borwein gradient projection method + + /// The following options are common to both the basic & Barzilai-Borwein + /// versions of GPSR + double stop_thresh; ///< Stopping threshold + int32 max_iters; ///< Maximum number of iterations + double gpsr_tau; ///< Regularization scale + double alpha_min; ///< Minimum step size in the feasible direction + double alpha_max; ///< Maximum step size in the feasible direction + double max_sparsity; ///< Maximum percentage of dimensions set to 0 + double tau_reduction; ///< Multiply tau by this if max_sparsity reached + + /// The following options are for the backtracking line search in basic GPSR. + /// Step size reduction factor in backtracking line search. 0 < beta < 1 + double gpsr_beta; + /// Improvement factor in backtracking line search, i.e. the new objective + /// function must be less than the old one by mu times the gradient in the + /// direction of the change in x. 0 < mu < 1 + double gpsr_mu; + int32 max_iters_backtrak; ///< Max iterations for backtracking line search + + bool debias; ///< Do debiasing, i.e. unconstrained optimization at the end + double stop_thresh_debias; ///< Stopping threshold for debiasing stage + int32 max_iters_debias; ///< Maximum number of iterations for debiasing stage + + GpsrConfig() { + use_gpsr_bb = true; + + stop_thresh = 0.005; + max_iters = 100; + gpsr_tau = 10; + alpha_min = 1.0e-10; + alpha_max = 1.0e+20; + max_sparsity = 0.9; + tau_reduction = 0.8; + + gpsr_beta = 0.5; + gpsr_mu = 0.1; + max_iters_backtrak = 50; + + debias = false; + stop_thresh_debias = 0.001; + max_iters_debias = 50; + } + + void Register(OptionsItf *po); +}; + +inline void GpsrConfig::Register(OptionsItf *po) { + std::string module = "GpsrConfig: "; + po->Register("use-gpsr-bb", &use_gpsr_bb, module+ + "Use the Barzilai-Borwein gradient projection method."); + + po->Register("stop-thresh", &stop_thresh, module+ + "Stopping threshold for GPSR."); + po->Register("max-iters", &max_iters, module+ + "Maximum number of iterations of GPSR."); + po->Register("gpsr-tau", &gpsr_tau, module+ + "Regularization scale for GPSR."); + po->Register("alpha-min", &alpha_min, module+ + "Minimum step size in feasible direction."); + po->Register("alpha-max", &alpha_max, module+ + "Maximum step size in feasible direction."); + po->Register("max-sparsity", &max_sparsity, module+ + "Maximum percentage of dimensions set to 0."); + po->Register("tau-reduction", &tau_reduction, module+ + "Multiply tau by this if maximum sparsity is reached."); + + po->Register("gpsr-beta", &gpsr_beta, module+ + "Step size reduction factor in backtracking line search (0<beta<1)."); + po->Register("gpsr-mu", &gpsr_mu, module+ + "Improvement factor in backtracking line search (0<mu<1)."); + po->Register("max-iters-backtrack", &max_iters_backtrak, module+ + "Maximum number of iterations of backtracking line search."); + + po->Register("debias", &debias, module+ + "Do final debiasing step."); + po->Register("stop-thresh-debias", &stop_thresh_debias, module+ + "Stopping threshold for debiaisng step."); + po->Register("max-iters-debias", &max_iters_debias, module+ + "Maximum number of iterations of debiasing."); +} + +/// Solves a quadratic program in \f$ x \f$, with L_1 regularization: +/// \f[ \min_x 0.5 * x^T H x - g^T x + \tau ||x||_1. \f] +/// This is similar to SolveQuadraticProblem() in sp-matrix.h with an added +/// L_1 term. +template<typename Real> +Real Gpsr(const GpsrConfig &opts, const SpMatrix<Real> &H, + const Vector<Real> &g, Vector<Real> *x, + const char *debug_str = "[unknown]") { + if (opts.use_gpsr_bb) + return GpsrBB(opts, H, g, x, debug_str); + else + return GpsrBasic(opts, H, g, x, debug_str); +} + +/// This is the basic GPSR algorithm, where the step size is determined by a +/// backtracking line search. The line search is called "Armijo rule along the +/// projection arc" in Bertsekas, Nonlinear Programming, 2nd ed. page 230. +template<typename Real> +Real GpsrBasic(const GpsrConfig &opts, const SpMatrix<Real> &H, + const Vector<Real> &g, Vector<Real> *x, + const char *debug_str = "[unknown]"); + +/// This is the paper calls the Barzilai-Borwein variant. This is a constrained +/// Netwon's method where the Hessian is approximated by scaled identity matrix +template<typename Real> +Real GpsrBB(const GpsrConfig &opts, const SpMatrix<Real> &H, + const Vector<Real> &g, Vector<Real> *x, + const char *debug_str = "[unknown]"); + + +} // namespace kaldi + +#endif // KALDI_MATRIX_KALDI_GPSR_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-matrix-inl.h b/kaldi_io/src/kaldi/matrix/kaldi-matrix-inl.h new file mode 100644 index 0000000..8bc4749 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-matrix-inl.h @@ -0,0 +1,62 @@ +// matrix/kaldi-matrix-inl.h + +// Copyright 2009-2011 Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_MATRIX_INL_H_ +#define KALDI_MATRIX_KALDI_MATRIX_INL_H_ 1 + +#include "matrix/kaldi-vector.h" + +namespace kaldi { + +/// Empty constructor +template<typename Real> +Matrix<Real>::Matrix(): MatrixBase<Real>(NULL, 0, 0, 0) { } + + +template<> +template<> +void MatrixBase<float>::AddVecVec(const float alpha, const VectorBase<float> &ra, const VectorBase<float> &rb); + +template<> +template<> +void MatrixBase<double>::AddVecVec(const double alpha, const VectorBase<double> &ra, const VectorBase<double> &rb); + +template<typename Real> +inline std::ostream & operator << (std::ostream & os, const MatrixBase<Real> & M) { + M.Write(os, false); + return os; +} + +template<typename Real> +inline std::istream & operator >> (std::istream & is, Matrix<Real> & M) { + M.Read(is, false); + return is; +} + + +template<typename Real> +inline std::istream & operator >> (std::istream & is, MatrixBase<Real> & M) { + M.Read(is, false); + return is; +} + +}// namespace kaldi + + +#endif // KALDI_MATRIX_KALDI_MATRIX_INL_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-matrix.h b/kaldi_io/src/kaldi/matrix/kaldi-matrix.h new file mode 100644 index 0000000..e6829e0 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-matrix.h @@ -0,0 +1,983 @@ +// matrix/kaldi-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University; Petr Schwarz; Yanmin Qian; +// Karel Vesely; Go Vivace Inc.; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_MATRIX_H_ +#define KALDI_MATRIX_KALDI_MATRIX_H_ 1 + +#include "matrix/matrix-common.h" + +namespace kaldi { + +/// @{ \addtogroup matrix_funcs_scalar + +/// We need to declare this here as it will be a friend function. +/// tr(A B), or tr(A B^T). +template<typename Real> +Real TraceMatMat(const MatrixBase<Real> &A, const MatrixBase<Real> &B, + MatrixTransposeType trans = kNoTrans); +/// @} + +/// \addtogroup matrix_group +/// @{ + +/// Base class which provides matrix operations not involving resizing +/// or allocation. Classes Matrix and SubMatrix inherit from it and take care +/// of allocation and resizing. +template<typename Real> +class MatrixBase { + public: + // so this child can access protected members of other instances. + friend class Matrix<Real>; + // friend declarations for CUDA matrices (see ../cudamatrix/) + friend class CuMatrixBase<Real>; + friend class CuMatrix<Real>; + friend class CuSubMatrix<Real>; + friend class CuPackedMatrix<Real>; + + friend class PackedMatrix<Real>; + + /// Returns number of rows (or zero for emtpy matrix). + inline MatrixIndexT NumRows() const { return num_rows_; } + + /// Returns number of columns (or zero for emtpy matrix). + inline MatrixIndexT NumCols() const { return num_cols_; } + + /// Stride (distance in memory between each row). Will be >= NumCols. + inline MatrixIndexT Stride() const { return stride_; } + + /// Returns size in bytes of the data held by the matrix. + size_t SizeInBytes() const { + return static_cast<size_t>(num_rows_) * static_cast<size_t>(stride_) * + sizeof(Real); + } + + /// Gives pointer to raw data (const). + inline const Real* Data() const { + return data_; + } + + /// Gives pointer to raw data (non-const). + inline Real* Data() { return data_; } + + /// Returns pointer to data for one row (non-const) + inline Real* RowData(MatrixIndexT i) { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(num_rows_)); + return data_ + i * stride_; + } + + /// Returns pointer to data for one row (const) + inline const Real* RowData(MatrixIndexT i) const { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(num_rows_)); + return data_ + i * stride_; + } + + /// Indexing operator, non-const + /// (only checks sizes if compiled with -DKALDI_PARANOID) + inline Real& operator() (MatrixIndexT r, MatrixIndexT c) { + KALDI_PARANOID_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(num_rows_) && + static_cast<UnsignedMatrixIndexT>(c) < + static_cast<UnsignedMatrixIndexT>(num_cols_)); + return *(data_ + r * stride_ + c); + } + /// Indexing operator, provided for ease of debugging (gdb doesn't work + /// with parenthesis operator). + Real &Index (MatrixIndexT r, MatrixIndexT c) { return (*this)(r, c); } + + /// Indexing operator, const + /// (only checks sizes if compiled with -DKALDI_PARANOID) + inline const Real operator() (MatrixIndexT r, MatrixIndexT c) const { + KALDI_PARANOID_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(num_rows_) && + static_cast<UnsignedMatrixIndexT>(c) < + static_cast<UnsignedMatrixIndexT>(num_cols_)); + return *(data_ + r * stride_ + c); + } + + /* Basic setting-to-special values functions. */ + + /// Sets matrix to zero. + void SetZero(); + /// Sets all elements to a specific value. + void Set(Real); + /// Sets to zero, except ones along diagonal [for non-square matrices too] + void SetUnit(); + /// Sets to random values of a normal distribution + void SetRandn(); + /// Sets to numbers uniformly distributed on (0, 1) + void SetRandUniform(); + + /* Copying functions. These do not resize the matrix! */ + + + /// Copy given matrix. (no resize is done). + template<typename OtherReal> + void CopyFromMat(const MatrixBase<OtherReal> & M, + MatrixTransposeType trans = kNoTrans); + + /// Copy from compressed matrix. + void CopyFromMat(const CompressedMatrix &M); + + /// Copy given spmatrix. (no resize is done). + template<typename OtherReal> + void CopyFromSp(const SpMatrix<OtherReal> &M); + + /// Copy given tpmatrix. (no resize is done). + template<typename OtherReal> + void CopyFromTp(const TpMatrix<OtherReal> &M, + MatrixTransposeType trans = kNoTrans); + + /// Copy from CUDA matrix. Implemented in ../cudamatrix/cu-matrix.h + template<typename OtherReal> + void CopyFromMat(const CuMatrixBase<OtherReal> &M, + MatrixTransposeType trans = kNoTrans); + + /// Inverse of vec() operator. Copies vector into matrix, row-by-row. + /// Note that rv.Dim() must either equal NumRows()*NumCols() or + /// NumCols()-- this has two modes of operation. + void CopyRowsFromVec(const VectorBase<Real> &v); + + /// This version of CopyRowsFromVec is implemented in ../cudamatrix/cu-vector.cc + void CopyRowsFromVec(const CuVectorBase<Real> &v); + + template<typename OtherReal> + void CopyRowsFromVec(const VectorBase<OtherReal> &v); + + /// Copies vector into matrix, column-by-column. + /// Note that rv.Dim() must either equal NumRows()*NumCols() or NumRows(); + /// this has two modes of operation. + void CopyColsFromVec(const VectorBase<Real> &v); + + /// Copy vector into specific column of matrix. + void CopyColFromVec(const VectorBase<Real> &v, const MatrixIndexT col); + /// Copy vector into specific row of matrix. + void CopyRowFromVec(const VectorBase<Real> &v, const MatrixIndexT row); + /// Copy vector into diagonal of matrix. + void CopyDiagFromVec(const VectorBase<Real> &v); + + /* Accessing of sub-parts of the matrix. */ + + /// Return specific row of matrix [const]. + inline const SubVector<Real> Row(MatrixIndexT i) const { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(num_rows_)); + return SubVector<Real>(data_ + (i * stride_), NumCols()); + } + + /// Return specific row of matrix. + inline SubVector<Real> Row(MatrixIndexT i) { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(num_rows_)); + return SubVector<Real>(data_ + (i * stride_), NumCols()); + } + + /// Return a sub-part of matrix. + inline SubMatrix<Real> Range(const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix<Real>(*this, row_offset, num_rows, + col_offset, num_cols); + } + inline SubMatrix<Real> RowRange(const MatrixIndexT row_offset, + const MatrixIndexT num_rows) const { + return SubMatrix<Real>(*this, row_offset, num_rows, 0, num_cols_); + } + inline SubMatrix<Real> ColRange(const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix<Real>(*this, 0, num_rows_, col_offset, num_cols); + } + + /* Various special functions. */ + /// Returns sum of all elements in matrix. + Real Sum() const; + /// Returns trace of matrix. + Real Trace(bool check_square = true) const; + // If check_square = true, will crash if matrix is not square. + + /// Returns maximum element of matrix. + Real Max() const; + /// Returns minimum element of matrix. + Real Min() const; + + /// Element by element multiplication with a given matrix. + void MulElements(const MatrixBase<Real> &A); + + /// Divide each element by the corresponding element of a given matrix. + void DivElements(const MatrixBase<Real> &A); + + /// Multiply each element with a scalar value. + void Scale(Real alpha); + + /// Set, element-by-element, *this = max(*this, A) + void Max(const MatrixBase<Real> &A); + + /// Equivalent to (*this) = (*this) * diag(scale). Scaling + /// each column by a scalar taken from that dimension of the vector. + void MulColsVec(const VectorBase<Real> &scale); + + /// Equivalent to (*this) = diag(scale) * (*this). Scaling + /// each row by a scalar taken from that dimension of the vector. + void MulRowsVec(const VectorBase<Real> &scale); + + /// Divide each row into src.NumCols() equal groups, and then scale i'th row's + /// j'th group of elements by src(i, j). Requires src.NumRows() == + /// this->NumRows() and this->NumCols() % src.NumCols() == 0. + void MulRowsGroupMat(const MatrixBase<Real> &src); + + /// Returns logdet of matrix. + Real LogDet(Real *det_sign = NULL) const; + + /// matrix inverse. + /// if inverse_needed = false, will fill matrix with garbage. + /// (only useful if logdet wanted). + void Invert(Real *log_det = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + /// matrix inverse [double]. + /// if inverse_needed = false, will fill matrix with garbage + /// (only useful if logdet wanted). + /// Does inversion in double precision even if matrix was not double. + void InvertDouble(Real *LogDet = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + + /// Inverts all the elements of the matrix + void InvertElements(); + + /// Transpose the matrix. This one is only + /// applicable to square matrices (the one in the + /// Matrix child class works also for non-square. + void Transpose(); + + /// Copies column r from column indices[r] of src. + /// As a special case, if indexes[i] == -1, sets column i to zero + /// indices.size() must equal this->NumCols(), + /// all elements of "reorder" must be in [-1, src.NumCols()-1], + /// and src.NumRows() must equal this.NumRows() + void CopyCols(const MatrixBase<Real> &src, + const std::vector<MatrixIndexT> &indices); + + /// Copies row r from row indices[r] of src. + /// As a special case, if indexes[i] == -1, sets row i to zero + /// "reorder".size() must equal this->NumRows(), + /// all elements of "reorder" must be in [-1, src.NumRows()-1], + /// and src.NumCols() must equal this.NumCols() + void CopyRows(const MatrixBase<Real> &src, + const std::vector<MatrixIndexT> &indices); + + /// Applies floor to all matrix elements + void ApplyFloor(Real floor_val); + + /// Applies floor to all matrix elements + void ApplyCeiling(Real ceiling_val); + + /// Calculates log of all the matrix elemnts + void ApplyLog(); + + /// Exponentiate each of the elements. + void ApplyExp(); + + /// Applies power to all matrix elements + void ApplyPow(Real power); + + /// Apply power to the absolute value of each element. + /// Include the sign of the input element if include_sign == true. + /// If the power is negative and the input to the power is zero, + /// The output will be set zero. + void ApplyPowAbs(Real power, bool include_sign=false); + + /// Applies the Heaviside step function (x > 0 ? 1 : 0) to all matrix elements + /// Note: in general you can make different choices for x = 0, but for now + /// please leave it as it (i.e. returning zero) because it affects the + /// RectifiedLinearComponent in the neural net code. + void ApplyHeaviside(); + + /// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D + /// P^{-1}. Be careful: the relationship of D to the eigenvalues we output is + /// slightly complicated, due to the need for P to be real. In the symmetric + /// case D is diagonal and real, but in + /// the non-symmetric case there may be complex-conjugate pairs of eigenvalues. + /// In this case, for the equation (*this) = P D P^{-1} to hold, D must actually + /// be block diagonal, with 2x2 blocks corresponding to any such pairs. If a + /// pair is lambda +- i*mu, D will have a corresponding 2x2 block + /// [lambda, mu; -mu, lambda]. + /// Note that if the input matrix (*this) is non-invertible, P may not be invertible + /// so in this case instead of the equation (*this) = P D P^{-1} holding, we have + /// instead (*this) P = P D. + /// + /// The non-member function CreateEigenvalueMatrix creates D from eigs_real and eigs_imag. + void Eig(MatrixBase<Real> *P, + VectorBase<Real> *eigs_real, + VectorBase<Real> *eigs_imag) const; + + /// The Power method attempts to take the matrix to a power using a method that + /// works in general for fractional and negative powers. The input matrix must + /// be invertible and have reasonable condition (or we don't guarantee the + /// results. The method is based on the eigenvalue decomposition. It will + /// return false and leave the matrix unchanged, if at entry the matrix had + /// real negative eigenvalues (or if it had zero eigenvalues and the power was + /// negative). + bool Power(Real pow); + + /** Singular value decomposition + Major limitations: + For nonsquare matrices, we assume m>=n (NumRows >= NumCols), and we return + the "skinny" Svd, i.e. the matrix in the middle is diagonal, and the + one on the left is rectangular. + + In Svd, *this = U*diag(S)*Vt. + Null pointers for U and/or Vt at input mean we do not want that output. We + expect that S.Dim() == m, U is either NULL or m by n, + and v is either NULL or n by n. + The singular values are not sorted (use SortSvd for that). */ + void DestructiveSvd(VectorBase<Real> *s, MatrixBase<Real> *U, + MatrixBase<Real> *Vt); // Destroys calling matrix. + + /// Compute SVD (*this) = U diag(s) Vt. Note that the V in the call is already + /// transposed; the normal formulation is U diag(s) V^T. + /// Null pointers for U or V mean we don't want that output (this saves + /// compute). The singular values are not sorted (use SortSvd for that). + void Svd(VectorBase<Real> *s, MatrixBase<Real> *U, + MatrixBase<Real> *Vt) const; + /// Compute SVD but only retain the singular values. + void Svd(VectorBase<Real> *s) const { Svd(s, NULL, NULL); } + + + /// Returns smallest singular value. + Real MinSingularValue() const { + Vector<Real> tmp(std::min(NumRows(), NumCols())); + Svd(&tmp); + return tmp.Min(); + } + + void TestUninitialized() const; // This function is designed so that if any element + // if the matrix is uninitialized memory, valgrind will complain. + + /// Returns condition number by computing Svd. Works even if cols > rows. + /// Returns infinity if all singular values are zero. + Real Cond() const; + + /// Returns true if matrix is Symmetric. + bool IsSymmetric(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is Diagonal. + bool IsDiagonal(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if the matrix is all zeros, except for ones on diagonal. (it + /// does not have to be square). More specifically, this function returns + /// false if for any i, j, (*this)(i, j) differs by more than cutoff from the + /// expression (i == j ? 1 : 0). + bool IsUnit(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-05) const; // replace magic number + + /// Frobenius norm, which is the sqrt of sum of square elements. Same as Schatten 2-norm, + /// or just "2-norm". + Real FrobeniusNorm() const; + + /// Returns true if ((*this)-other).FrobeniusNorm() + /// <= tol * (*this).FrobeniusNorm(). + bool ApproxEqual(const MatrixBase<Real> &other, float tol = 0.01) const; + + /// Tests for exact equality. It's usually preferable to use ApproxEqual. + bool Equal(const MatrixBase<Real> &other) const; + + /// largest absolute value. + Real LargestAbsElem() const; // largest absolute value. + + /// Returns log(sum(exp())) without exp overflow + /// If prune > 0.0, it uses a pruning beam, discarding + /// terms less than (max - prune). Note: in future + /// we may change this so that if prune = 0.0, it takes + /// the max, so use -1 if you don't want to prune. + Real LogSumExp(Real prune = -1.0) const; + + /// Apply soft-max to the collection of all elements of the + /// matrix and return normalizer (log sum of exponentials). + Real ApplySoftMax(); + + /// Set each element to the sigmoid of the corresponding element of "src". + void Sigmoid(const MatrixBase<Real> &src); + + /// Set each element to y = log(1 + exp(x)) + void SoftHinge(const MatrixBase<Real> &src); + + /// Apply the function y(i) = (sum_{j = i*G}^{(i+1)*G-1} x_j^(power))^(1 / p). + /// Requires src.NumRows() == this->NumRows() and src.NumCols() % this->NumCols() == 0. + void GroupPnorm(const MatrixBase<Real> &src, Real power); + + + /// Calculate derivatives for the GroupPnorm function above... + /// if "input" is the input to the GroupPnorm function above (i.e. the "src" variable), + /// and "output" is the result of the computation (i.e. the "this" of that function + /// call), and *this has the same dimension as "input", then it sets each element + /// of *this to the derivative d(output-elem)/d(input-elem) for each element of "input", where + /// "output-elem" is whichever element of output depends on that input element. + void GroupPnormDeriv(const MatrixBase<Real> &input, const MatrixBase<Real> &output, + Real power); + + + /// Set each element to the tanh of the corresponding element of "src". + void Tanh(const MatrixBase<Real> &src); + + // Function used in backpropagating derivatives of the sigmoid function: + // element-by-element, set *this = diff * value * (1.0 - value). + void DiffSigmoid(const MatrixBase<Real> &value, + const MatrixBase<Real> &diff); + + // Function used in backpropagating derivatives of the tanh function: + // element-by-element, set *this = diff * (1.0 - value^2). + void DiffTanh(const MatrixBase<Real> &value, + const MatrixBase<Real> &diff); + + /** Uses Svd to compute the eigenvalue decomposition of a symmetric positive + * semi-definite matrix: (*this) = rP * diag(rS) * rP^T, with rP an + * orthogonal matrix so rP^{-1} = rP^T. Throws exception if input was not + * positive semi-definite (check_thresh controls how stringent the check is; + * set it to 2 to ensure it won't ever complain, but it will zero out negative + * dimensions in your matrix. + */ + void SymPosSemiDefEig(VectorBase<Real> *s, MatrixBase<Real> *P, + Real check_thresh = 0.001); + + friend Real kaldi::TraceMatMat<Real>(const MatrixBase<Real> &A, + const MatrixBase<Real> &B, MatrixTransposeType trans); // tr (A B) + + // so it can get around const restrictions on the pointer to data_. + friend class SubMatrix<Real>; + + /// Add a scalar to each element + void Add(const Real alpha); + + /// Add a scalar to each diagonal element. + void AddToDiag(const Real alpha); + + /// *this += alpha * a * b^T + template<typename OtherReal> + void AddVecVec(const Real alpha, const VectorBase<OtherReal> &a, + const VectorBase<OtherReal> &b); + + /// [each row of *this] += alpha * v + template<typename OtherReal> + void AddVecToRows(const Real alpha, const VectorBase<OtherReal> &v); + + /// [each col of *this] += alpha * v + template<typename OtherReal> + void AddVecToCols(const Real alpha, const VectorBase<OtherReal> &v); + + /// *this += alpha * M [or M^T] + void AddMat(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transA = kNoTrans); + + /// *this = beta * *this + alpha * M M^T, for symmetric matrices. It only + /// updates the lower triangle of *this. It will leave the matrix asymmetric; + /// if you need it symmetric as a regular matrix, do CopyLowerToUpper(). + void SymAddMat2(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transA, Real beta); + + /// *this = beta * *this + alpha * diag(v) * M [or M^T]. + /// The same as adding M but scaling each row M_i by v(i). + void AddDiagVecMat(const Real alpha, VectorBase<Real> &v, + const MatrixBase<Real> &M, MatrixTransposeType transM, + Real beta = 1.0); + + /// *this = beta * *this + alpha * M [or M^T] * diag(v) + /// The same as adding M but scaling each column M_j by v(j). + void AddMatDiagVec(const Real alpha, + const MatrixBase<Real> &M, MatrixTransposeType transM, + VectorBase<Real> &v, + Real beta = 1.0); + + /// *this = beta * *this + alpha * A .* B (.* element by element multiplication) + void AddMatMatElements(const Real alpha, + const MatrixBase<Real>& A, + const MatrixBase<Real>& B, + const Real beta); + + /// *this += alpha * S + template<typename OtherReal> + void AddSp(const Real alpha, const SpMatrix<OtherReal> &S); + + void AddMatMat(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const Real beta); + + /// *this = a * b / c (by element; when c = 0, *this = a) + void AddMatMatDivMat(const MatrixBase<Real>& A, + const MatrixBase<Real>& B, + const MatrixBase<Real>& C); + + /// A version of AddMatMat specialized for when the second argument + /// contains a lot of zeroes. + void AddMatSmat(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const Real beta); + + /// A version of AddMatMat specialized for when the first argument + /// contains a lot of zeroes. + void AddSmatMat(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const Real beta); + + /// this <-- beta*this + alpha*A*B*C. + void AddMatMatMat(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const MatrixBase<Real>& C, MatrixTransposeType transC, + const Real beta); + + /// this <-- beta*this + alpha*SpA*B. + // This and the routines below are really + // stubs that need to be made more efficient. + void AddSpMat(const Real alpha, + const SpMatrix<Real>& A, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const Real beta) { + Matrix<Real> M(A); + return AddMatMat(alpha, M, kNoTrans, B, transB, beta); + } + /// this <-- beta*this + alpha*A*B. + void AddTpMat(const Real alpha, + const TpMatrix<Real>& A, MatrixTransposeType transA, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const Real beta) { + Matrix<Real> M(A); + return AddMatMat(alpha, M, transA, B, transB, beta); + } + /// this <-- beta*this + alpha*A*B. + void AddMatSp(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const SpMatrix<Real>& B, + const Real beta) { + Matrix<Real> M(B); + return AddMatMat(alpha, A, transA, M, kNoTrans, beta); + } + /// this <-- beta*this + alpha*A*B*C. + void AddSpMatSp(const Real alpha, + const SpMatrix<Real> &A, + const MatrixBase<Real>& B, MatrixTransposeType transB, + const SpMatrix<Real>& C, + const Real beta) { + Matrix<Real> M(A), N(C); + return AddMatMatMat(alpha, M, kNoTrans, B, transB, N, kNoTrans, beta); + } + /// this <-- beta*this + alpha*A*B. + void AddMatTp(const Real alpha, + const MatrixBase<Real>& A, MatrixTransposeType transA, + const TpMatrix<Real>& B, MatrixTransposeType transB, + const Real beta) { + Matrix<Real> M(B); + return AddMatMat(alpha, A, transA, M, transB, beta); + } + + /// this <-- beta*this + alpha*A*B. + void AddTpTp(const Real alpha, + const TpMatrix<Real>& A, MatrixTransposeType transA, + const TpMatrix<Real>& B, MatrixTransposeType transB, + const Real beta) { + Matrix<Real> M(A), N(B); + return AddMatMat(alpha, M, transA, N, transB, beta); + } + + /// this <-- beta*this + alpha*A*B. + // This one is more efficient, not like the others above. + void AddSpSp(const Real alpha, + const SpMatrix<Real>& A, const SpMatrix<Real>& B, + const Real beta); + + /// Copy lower triangle to upper triangle (symmetrize) + void CopyLowerToUpper(); + + /// Copy upper triangle to lower triangle (symmetrize) + void CopyUpperToLower(); + + /// This function orthogonalizes the rows of a matrix using the Gram-Schmidt + /// process. It is only applicable if NumRows() <= NumCols(). It will use + /// random number generation to fill in rows with something nonzero, in cases + /// where the original matrix was of deficient row rank. + void OrthogonalizeRows(); + + /// stream read. + /// Use instead of stream<<*this, if you want to add to existing contents. + // Will throw exception on failure. + void Read(std::istream & in, bool binary, bool add = false); + /// write to stream. + void Write(std::ostream & out, bool binary) const; + + // Below is internal methods for Svd, user does not have to know about this. +#if !defined(HAVE_ATLAS) && !defined(USE_KALDI_SVD) + // protected: + // Should be protected but used directly in testing routine. + // destroys *this! + void LapackGesvd(VectorBase<Real> *s, MatrixBase<Real> *U, + MatrixBase<Real> *Vt); +#else + protected: + // destroys *this! + bool JamaSvd(VectorBase<Real> *s, MatrixBase<Real> *U, + MatrixBase<Real> *V); + +#endif + protected: + + /// Initializer, callable only from child. + explicit MatrixBase(Real *data, MatrixIndexT cols, MatrixIndexT rows, MatrixIndexT stride) : + data_(data), num_cols_(cols), num_rows_(rows), stride_(stride) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + /// Initializer, callable only from child. + /// Empty initializer, for un-initialized matrix. + explicit MatrixBase(): data_(NULL) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + // Make sure pointers to MatrixBase cannot be deleted. + ~MatrixBase() { } + + /// A workaround that allows SubMatrix to get a pointer to non-const data + /// for const Matrix. Unfortunately C++ does not allow us to declare a + /// "public const" inheritance or anything like that, so it would require + /// a lot of work to make the SubMatrix class totally const-correct-- + /// we would have to override many of the Matrix functions. + inline Real* Data_workaround() const { + return data_; + } + + /// data memory area + Real* data_; + + /// these atributes store the real matrix size as it is stored in memory + /// including memalignment + MatrixIndexT num_cols_; /// < Number of columns + MatrixIndexT num_rows_; /// < Number of rows + /** True number of columns for the internal matrix. This number may differ + * from num_cols_ as memory alignment might be used. */ + MatrixIndexT stride_; + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(MatrixBase); +}; + +/// A class for storing matrices. +template<typename Real> +class Matrix : public MatrixBase<Real> { + public: + + /// Empty constructor. + Matrix(); + + /// Basic constructor. Sets to zero by default. + /// if set_zero == false, memory contents are undefined. + Matrix(const MatrixIndexT r, const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero): + MatrixBase<Real>() { Resize(r, c, resize_type); } + + /// Copy constructor from CUDA matrix + /// This is defined in ../cudamatrix/cu-matrix.h + template<typename OtherReal> + explicit Matrix(const CuMatrixBase<OtherReal> &cu, + MatrixTransposeType trans = kNoTrans); + + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Matrix<Real> *other); + + /// Defined in ../cudamatrix/cu-matrix.cc + void Swap(CuMatrix<Real> *mat); + + /// Constructor from any MatrixBase. Can also copy with transpose. + /// Allocates new memory. + explicit Matrix(const MatrixBase<Real> & M, + MatrixTransposeType trans = kNoTrans); + + /// Same as above, but need to avoid default copy constructor. + Matrix(const Matrix<Real> & M); // (cannot make explicit) + + /// Copy constructor: as above, but from another type. + template<typename OtherReal> + explicit Matrix(const MatrixBase<OtherReal> & M, + MatrixTransposeType trans = kNoTrans); + + /// Copy constructor taking SpMatrix... + /// It is symmetric, so no option for transpose, and NumRows == Cols + template<typename OtherReal> + explicit Matrix(const SpMatrix<OtherReal> & M) : MatrixBase<Real>() { + Resize(M.NumRows(), M.NumRows(), kUndefined); + this->CopyFromSp(M); + } + + /// Constructor from CompressedMatrix + explicit Matrix(const CompressedMatrix &C); + + /// Copy constructor taking TpMatrix... + template <typename OtherReal> + explicit Matrix(const TpMatrix<OtherReal> & M, + MatrixTransposeType trans = kNoTrans) : MatrixBase<Real>() { + if (trans == kNoTrans) { + Resize(M.NumRows(), M.NumCols(), kUndefined); + this->CopyFromTp(M); + } else { + Resize(M.NumCols(), M.NumRows(), kUndefined); + this->CopyFromTp(M, kTrans); + } + } + + /// read from stream. + // Unlike one in base, allows resizing. + void Read(std::istream & in, bool binary, bool add = false); + + /// Remove a specified row. + void RemoveRow(MatrixIndexT i); + + /// Transpose the matrix. Works for non-square + /// matrices as well as square ones. + void Transpose(); + + /// Distructor to free matrices. + ~Matrix() { Destroy(); } + + /// Sets matrix to a specified size (zero is OK as long as both r and c are + /// zero). The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero); + + /// Assignment operator that takes MatrixBase. + Matrix<Real> &operator = (const MatrixBase<Real> &other) { + if (MatrixBase<Real>::NumRows() != other.NumRows() || + MatrixBase<Real>::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase<Real>::CopyFromMat(other); + return *this; + } + + /// Assignment operator. Needed for inclusion in std::vector. + Matrix<Real> &operator = (const Matrix<Real> &other) { + if (MatrixBase<Real>::NumRows() != other.NumRows() || + MatrixBase<Real>::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase<Real>::CopyFromMat(other); + return *this; + } + + + private: + /// Deallocates memory and sets to empty matrix (dimension 0, 0). + void Destroy(); + + /// Init assumes the current class contents are invalid (i.e. junk or have + /// already been freed), and it sets the matrix to newly allocated memory with + /// the specified number of rows and columns. r == c == 0 is acceptable. The data + /// memory contents will be undefined. + void Init(const MatrixIndexT r, + const MatrixIndexT c); + +}; +/// @} end "addtogroup matrix_group" + +/// \addtogroup matrix_funcs_io +/// @{ + +/// A structure containing the HTK header. +/// [TODO: change the style of the variables to Kaldi-compliant] +struct HtkHeader { + /// Number of samples. + int32 mNSamples; + /// Sample period. + int32 mSamplePeriod; + /// Sample size + int16 mSampleSize; + /// Sample kind. + uint16 mSampleKind; +}; + +// Read HTK formatted features from file into matrix. +template<typename Real> +bool ReadHtk(std::istream &is, Matrix<Real> *M, HtkHeader *header_ptr); + +// Write (HTK format) features to file from matrix. +template<typename Real> +bool WriteHtk(std::ostream &os, const MatrixBase<Real> &M, HtkHeader htk_hdr); + +// Write (CMUSphinx format) features to file from matrix. +template<typename Real> +bool WriteSphinx(std::ostream &os, const MatrixBase<Real> &M); + +/// @} end of "addtogroup matrix_funcs_io" + +/** + Sub-matrix representation. + Can work with sub-parts of a matrix using this class. + Note that SubMatrix is not very const-correct-- it allows you to + change the contents of a const Matrix. Be careful! +*/ + +template<typename Real> +class SubMatrix : public MatrixBase<Real> { + public: + // Initialize a SubMatrix from part of a matrix; this is + // a bit like A(b:c, d:e) in Matlab. + // This initializer is against the proper semantics of "const", since + // SubMatrix can change its contents. It would be hard to implement + // a "const-safe" version of this class. + SubMatrix(const MatrixBase<Real>& T, + const MatrixIndexT ro, // row offset, 0 < ro < NumRows() + const MatrixIndexT r, // number of rows, r > 0 + const MatrixIndexT co, // column offset, 0 < co < NumCols() + const MatrixIndexT c); // number of columns, c > 0 + + // This initializer is mostly intended for use in CuMatrix and related + // classes. Be careful! + SubMatrix(Real *data, + MatrixIndexT num_rows, + MatrixIndexT num_cols, + MatrixIndexT stride); + + ~SubMatrix<Real>() {} + + /// This type of constructor is needed for Range() to work [in Matrix base + /// class]. Cannot make it explicit. + SubMatrix<Real> (const SubMatrix &other): + MatrixBase<Real> (other.data_, other.num_cols_, other.num_rows_, + other.stride_) {} + + private: + /// Disallow assignment. + SubMatrix<Real> &operator = (const SubMatrix<Real> &other); +}; +/// @} End of "addtogroup matrix_funcs_io". + +/// \addtogroup matrix_funcs_scalar +/// @{ + +// Some declarations. These are traces of products. + + +template<typename Real> +bool ApproxEqual(const MatrixBase<Real> &A, + const MatrixBase<Real> &B, Real tol = 0.01) { + return A.ApproxEqual(B, tol); +} + +template<typename Real> +inline void AssertEqual(const MatrixBase<Real> &A, const MatrixBase<Real> &B, + float tol = 0.01) { + KALDI_ASSERT(A.ApproxEqual(B, tol)); +} + +/// Returns trace of matrix. +template <typename Real> +double TraceMat(const MatrixBase<Real> &A) { return A.Trace(); } + + +/// Returns tr(A B C) +template <typename Real> +Real TraceMatMatMat(const MatrixBase<Real> &A, MatrixTransposeType transA, + const MatrixBase<Real> &B, MatrixTransposeType transB, + const MatrixBase<Real> &C, MatrixTransposeType transC); + +/// Returns tr(A B C D) +template <typename Real> +Real TraceMatMatMatMat(const MatrixBase<Real> &A, MatrixTransposeType transA, + const MatrixBase<Real> &B, MatrixTransposeType transB, + const MatrixBase<Real> &C, MatrixTransposeType transC, + const MatrixBase<Real> &D, MatrixTransposeType transD); + +/// @} end "addtogroup matrix_funcs_scalar" + + +/// \addtogroup matrix_funcs_misc +/// @{ + + +/// Function to ensure that SVD is sorted. This function is made as generic as +/// possible, to be applicable to other types of problems. s->Dim() should be +/// the same as U->NumCols(), and we sort s from greatest to least absolute +/// value (if sort_on_absolute_value == true) or greatest to least value +/// otherwise, moving the columns of U, if it exists, and the rows of Vt, if it +/// exists, around in the same way. Note: the "absolute value" part won't matter +/// if this is an actual SVD, since singular values are non-negative. +template<typename Real> void SortSvd(VectorBase<Real> *s, MatrixBase<Real> *U, + MatrixBase<Real>* Vt = NULL, + bool sort_on_absolute_value = true); + +/// Creates the eigenvalue matrix D that is part of the decomposition used Matrix::Eig. +/// D will be block-diagonal with blocks of size 1 (for real eigenvalues) or 2x2 +/// for complex pairs. If a complex pair is lambda +- i*mu, D will have a corresponding +/// 2x2 block [lambda, mu; -mu, lambda]. +/// This function will throw if any complex eigenvalues are not in complex conjugate +/// pairs (or the members of such pairs are not consecutively numbered). +template<typename Real> +void CreateEigenvalueMatrix(const VectorBase<Real> &real, const VectorBase<Real> &imag, + MatrixBase<Real> *D); + +/// The following function is used in Matrix::Power, and separately tested, so we +/// declare it here mainly for the testing code to see. It takes a complex value to +/// a power using a method that will work for noninteger powers (but will fail if the +/// complex value is real and negative). +template<typename Real> +bool AttemptComplexPower(Real *x_re, Real *x_im, Real power); + + + +/// @} end of addtogroup matrix_funcs_misc + +/// \addtogroup matrix_funcs_io +/// @{ +template<typename Real> +std::ostream & operator << (std::ostream & Out, const MatrixBase<Real> & M); + +template<typename Real> +std::istream & operator >> (std::istream & In, MatrixBase<Real> & M); + +// The Matrix read allows resizing, so we override the MatrixBase one. +template<typename Real> +std::istream & operator >> (std::istream & In, Matrix<Real> & M); + + +template<typename Real> +bool SameDim(const MatrixBase<Real> &M, const MatrixBase<Real> &N) { + return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols()); +} + +/// @} end of \addtogroup matrix_funcs_io + + +} // namespace kaldi + + + +// we need to include the implementation and some +// template specializations. +#include "matrix/kaldi-matrix-inl.h" + + +#endif // KALDI_MATRIX_KALDI_MATRIX_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-vector-inl.h b/kaldi_io/src/kaldi/matrix/kaldi-vector-inl.h new file mode 100644 index 0000000..c3a4f52 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-vector-inl.h @@ -0,0 +1,58 @@ +// matrix/kaldi-vector-inl.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; +// Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// This is an internal header file, included by other library headers. +// You should not attempt to use it directly. + +#ifndef KALDI_MATRIX_KALDI_VECTOR_INL_H_ +#define KALDI_MATRIX_KALDI_VECTOR_INL_H_ 1 + +namespace kaldi { + +template<typename Real> +std::ostream & operator << (std::ostream &os, const VectorBase<Real> &rv) { + rv.Write(os, false); + return os; +} + +template<typename Real> +std::istream &operator >> (std::istream &is, VectorBase<Real> &rv) { + rv.Read(is, false); + return is; +} + +template<typename Real> +std::istream &operator >> (std::istream &is, Vector<Real> &rv) { + rv.Read(is, false); + return is; +} + +template<> +template<> +void VectorBase<float>::AddVec(const float alpha, const VectorBase<float> &rv); + +template<> +template<> +void VectorBase<double>::AddVec<double>(const double alpha, + const VectorBase<double> &rv); + +} // namespace kaldi + +#endif // KALDI_MATRIX_KALDI_VECTOR_INL_H_ diff --git a/kaldi_io/src/kaldi/matrix/kaldi-vector.h b/kaldi_io/src/kaldi/matrix/kaldi-vector.h new file mode 100644 index 0000000..2b3395b --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/kaldi-vector.h @@ -0,0 +1,585 @@ +// matrix/kaldi-vector.h + +// Copyright 2009-2012 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University (Author: Arnab Ghoshal); +// Ariya Rastrow; Petr Schwarz; Yanmin Qian; +// Karel Vesely; Go Vivace Inc.; Arnab Ghoshal +// Wei Shi; + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_VECTOR_H_ +#define KALDI_MATRIX_KALDI_VECTOR_H_ 1 + +#include "matrix/matrix-common.h" + +namespace kaldi { + +/// \addtogroup matrix_group +/// @{ + +/// Provides a vector abstraction class. +/// This class provides a way to work with vectors in kaldi. +/// It encapsulates basic operations and memory optimizations. +template<typename Real> +class VectorBase { + public: + /// Set vector to all zeros. + void SetZero(); + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-06) const; // replace magic number + + /// Set all members of a vector to a specified value. + void Set(Real f); + + /// Set vector to random normally-distributed noise. + void SetRandn(); + + /// This function returns a random index into this vector, + /// chosen with probability proportional to the corresponding + /// element. Requires that this->Min() >= 0 and this->Sum() > 0. + MatrixIndexT RandCategorical() const; + + /// Returns the dimension of the vector. + inline MatrixIndexT Dim() const { return dim_; } + + /// Returns the size in memory of the vector, in bytes. + inline MatrixIndexT SizeInBytes() const { return (dim_*sizeof(Real)); } + + /// Returns a pointer to the start of the vector's data. + inline Real* Data() { return data_; } + + /// Returns a pointer to the start of the vector's data (const). + inline const Real* Data() const { return data_; } + + /// Indexing operator (const). + inline Real operator() (MatrixIndexT i) const { + KALDI_PARANOID_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(dim_)); + return *(data_ + i); + } + + /// Indexing operator (non-const). + inline Real & operator() (MatrixIndexT i) { + KALDI_PARANOID_ASSERT(static_cast<UnsignedMatrixIndexT>(i) < + static_cast<UnsignedMatrixIndexT>(dim_)); + return *(data_ + i); + } + + /** @brief Returns a sub-vector of a vector (a range of elements). + * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l) { + return SubVector<Real>(*this, o, l); + } + + /** @brief Returns a const sub-vector of a vector (a range of elements). + * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + const SubVector<Real> Range(const MatrixIndexT o, + const MatrixIndexT l) const { + return SubVector<Real>(*this, o, l); + } + + /// Copy data from another vector (must match own size). + void CopyFromVec(const VectorBase<Real> &v); + + /// Copy data from a SpMatrix or TpMatrix (must match own size). + template<typename OtherReal> + void CopyFromPacked(const PackedMatrix<OtherReal> &M); + + /// Copy data from another vector of different type (double vs. float) + template<typename OtherReal> + void CopyFromVec(const VectorBase<OtherReal> &v); + + /// Copy from CuVector. This is defined in ../cudamatrix/cu-vector.h + template<typename OtherReal> + void CopyFromVec(const CuVectorBase<OtherReal> &v); + + + /// Apply natural log to all elements. Throw if any element of + /// the vector is negative (but doesn't complain about zero; the + /// log will be -infinity + void ApplyLog(); + + /// Apply natural log to another vector and put result in *this. + void ApplyLogAndCopy(const VectorBase<Real> &v); + + /// Apply exponential to each value in vector. + void ApplyExp(); + + /// Take absolute value of each of the elements + void ApplyAbs(); + + /// Applies floor to all elements. Returns number of elements floored. + MatrixIndexT ApplyFloor(Real floor_val); + + /// Applies ceiling to all elements. Returns number of elements changed. + MatrixIndexT ApplyCeiling(Real ceil_val); + + /// Applies floor to all elements. Returns number of elements floored. + MatrixIndexT ApplyFloor(const VectorBase<Real> &floor_vec); + + /// Apply soft-max to vector and return normalizer (log sum of exponentials). + /// This is the same as: \f$ x(i) = exp(x(i)) / \sum_i exp(x(i)) \f$ + Real ApplySoftMax(); + + /// Sets each element of *this to the tanh of the corresponding element of "src". + void Tanh(const VectorBase<Real> &src); + + /// Sets each element of *this to the sigmoid function of the corresponding + /// element of "src". + void Sigmoid(const VectorBase<Real> &src); + + /// Take all elements of vector to a power. + void ApplyPow(Real power); + + /// Take the absolute value of all elements of a vector to a power. + /// Include the sign of the input element if include_sign == true. + /// If power is negative and the input value is zero, the output is set zero. + void ApplyPowAbs(Real power, bool include_sign=false); + + /// Compute the p-th norm of the vector. + Real Norm(Real p) const; + + /// Returns true if ((*this)-other).Norm(2.0) <= tol * (*this).Norm(2.0). + bool ApproxEqual(const VectorBase<Real> &other, float tol = 0.01) const; + + /// Invert all elements. + void InvertElements(); + + /// Add vector : *this = *this + alpha * rv (with casting between floats and + /// doubles) + template<typename OtherReal> + void AddVec(const Real alpha, const VectorBase<OtherReal> &v); + + /// Add vector : *this = *this + alpha * rv^2 [element-wise squaring]. + void AddVec2(const Real alpha, const VectorBase<Real> &v); + + /// Add vector : *this = *this + alpha * rv^2 [element-wise squaring], + /// with casting between floats and doubles. + template<typename OtherReal> + void AddVec2(const Real alpha, const VectorBase<OtherReal> &v); + + /// Add matrix times vector : this <-- beta*this + alpha*M*v. + /// Calls BLAS GEMV. + void AddMatVec(const Real alpha, const MatrixBase<Real> &M, + const MatrixTransposeType trans, const VectorBase<Real> &v, + const Real beta); // **beta previously defaulted to 0.0** + + /// This is as AddMatVec, except optimized for where v contains a lot + /// of zeros. + void AddMatSvec(const Real alpha, const MatrixBase<Real> &M, + const MatrixTransposeType trans, const VectorBase<Real> &v, + const Real beta); // **beta previously defaulted to 0.0** + + + /// Add symmetric positive definite matrix times vector: + /// this <-- beta*this + alpha*M*v. Calls BLAS SPMV. + void AddSpVec(const Real alpha, const SpMatrix<Real> &M, + const VectorBase<Real> &v, const Real beta); // **beta previously defaulted to 0.0** + + /// Add triangular matrix times vector: this <-- beta*this + alpha*M*v. + /// Works even if rv == *this. + void AddTpVec(const Real alpha, const TpMatrix<Real> &M, + const MatrixTransposeType trans, const VectorBase<Real> &v, + const Real beta); // **beta previously defaulted to 0.0** + + /// Set each element to y = (x == orig ? changed : x). + void ReplaceValue(Real orig, Real changed); + + /// Multipy element-by-element by another vector. + void MulElements(const VectorBase<Real> &v); + /// Multipy element-by-element by another vector of different type. + template<typename OtherReal> + void MulElements(const VectorBase<OtherReal> &v); + + /// Divide element-by-element by a vector. + void DivElements(const VectorBase<Real> &v); + /// Divide element-by-element by a vector of different type. + template<typename OtherReal> + void DivElements(const VectorBase<OtherReal> &v); + + /// Add a constant to each element of a vector. + void Add(Real c); + + /// Add element-by-element product of vectlrs: + // this <-- alpha * v .* r + beta*this . + void AddVecVec(Real alpha, const VectorBase<Real> &v, + const VectorBase<Real> &r, Real beta); + + /// Add element-by-element quotient of two vectors. + /// this <---- alpha*v/r + beta*this + void AddVecDivVec(Real alpha, const VectorBase<Real> &v, + const VectorBase<Real> &r, Real beta); + + /// Multiplies all elements by this constant. + void Scale(Real alpha); + + /// Multiplies this vector by lower-triangular marix: *this <-- *this *M + void MulTp(const TpMatrix<Real> &M, const MatrixTransposeType trans); + + /// If trans == kNoTrans, solves M x = b, where b is the value of *this at input + /// and x is the value of *this at output. + /// If trans == kTrans, solves M' x = b. + /// Does not test for M being singular or near-singular, so test it before + /// calling this routine. + void Solve(const TpMatrix<Real> &M, const MatrixTransposeType trans); + + /// Performs a row stack of the matrix M + void CopyRowsFromMat(const MatrixBase<Real> &M); + template<typename OtherReal> + void CopyRowsFromMat(const MatrixBase<OtherReal> &M); + + /// The following is implemented in ../cudamatrix/cu-matrix.cc + void CopyRowsFromMat(const CuMatrixBase<Real> &M); + + /// Performs a column stack of the matrix M + void CopyColsFromMat(const MatrixBase<Real> &M); + + /// Extracts a row of the matrix M. Could also do this with + /// this->Copy(M[row]). + void CopyRowFromMat(const MatrixBase<Real> &M, MatrixIndexT row); + /// Extracts a row of the matrix M with type conversion. + template<typename OtherReal> + void CopyRowFromMat(const MatrixBase<OtherReal> &M, MatrixIndexT row); + + /// Extracts a row of the symmetric matrix S. + template<typename OtherReal> + void CopyRowFromSp(const SpMatrix<OtherReal> &S, MatrixIndexT row); + + /// Extracts a column of the matrix M. + template<typename OtherReal> + void CopyColFromMat(const MatrixBase<OtherReal> &M , MatrixIndexT col); + + /// Extracts the diagonal of the matrix M. + void CopyDiagFromMat(const MatrixBase<Real> &M); + + /// Extracts the diagonal of a packed matrix M; works for Sp or Tp. + void CopyDiagFromPacked(const PackedMatrix<Real> &M); + + + /// Extracts the diagonal of a symmetric matrix. + inline void CopyDiagFromSp(const SpMatrix<Real> &M) { CopyDiagFromPacked(M); } + + /// Extracts the diagonal of a triangular matrix. + inline void CopyDiagFromTp(const TpMatrix<Real> &M) { CopyDiagFromPacked(M); } + + /// Returns the maximum value of any element, or -infinity for the empty vector. + Real Max() const; + + /// Returns the maximum value of any element, and the associated index. + /// Error if vector is empty. + Real Max(MatrixIndexT *index) const; + + /// Returns the minimum value of any element, or +infinity for the empty vector. + Real Min() const; + + /// Returns the minimum value of any element, and the associated index. + /// Error if vector is empty. + Real Min(MatrixIndexT *index) const; + + /// Returns sum of the elements + Real Sum() const; + + /// Returns sum of the logs of the elements. More efficient than + /// just taking log of each. Will return NaN if any elements are + /// negative. + Real SumLog() const; + + /// Does *this = alpha * (sum of rows of M) + beta * *this. + void AddRowSumMat(Real alpha, const MatrixBase<Real> &M, Real beta = 1.0); + + /// Does *this = alpha * (sum of columns of M) + beta * *this. + void AddColSumMat(Real alpha, const MatrixBase<Real> &M, Real beta = 1.0); + + /// Add the diagonal of a matrix times itself: + /// *this = diag(M M^T) + beta * *this (if trans == kNoTrans), or + /// *this = diag(M^T M) + beta * *this (if trans == kTrans). + void AddDiagMat2(Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType trans = kNoTrans, Real beta = 1.0); + + /// Add the diagonal of a matrix product: *this = diag(M N), assuming the + /// "trans" arguments are both kNoTrans; for transpose arguments, it behaves + /// as you would expect. + void AddDiagMatMat(Real alpha, const MatrixBase<Real> &M, MatrixTransposeType transM, + const MatrixBase<Real> &N, MatrixTransposeType transN, + Real beta = 1.0); + + /// Returns log(sum(exp())) without exp overflow + /// If prune > 0.0, ignores terms less than the max - prune. + /// [Note: in future, if prune = 0.0, it will take the max. + /// For now, use -1 if you don't want it to prune.] + Real LogSumExp(Real prune = -1.0) const; + + /// Reads from C++ stream (option to add to existing contents). + /// Throws exception on failure + void Read(std::istream & in, bool binary, bool add = false); + + /// Writes to C++ stream (option to write in binary). + void Write(std::ostream &Out, bool binary) const; + + friend class VectorBase<double>; + friend class VectorBase<float>; + friend class CuVectorBase<Real>; + friend class CuVector<Real>; + protected: + /// Destructor; does not deallocate memory, this is handled by child classes. + /// This destructor is protected so this object so this object can only be + /// deleted via a child. + ~VectorBase() {} + + /// Empty initializer, corresponds to vector of zero size. + explicit VectorBase(): data_(NULL), dim_(0) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + +// Took this out since it is not currently used, and it is possible to create +// objects where the allocated memory is not the same size as dim_ : Arnab +// /// Initializer from a pointer and a size; keeps the pointer internally +// /// (ownership or non-ownership depends on the child class). +// explicit VectorBase(Real* data, MatrixIndexT dim) +// : data_(data), dim_(dim) {} + + // Arnab : made this protected since it is unsafe too. + /// Load data into the vector: sz must match own size. + void CopyFromPtr(const Real* Data, MatrixIndexT sz); + + /// data memory area + Real* data_; + /// dimension of vector + MatrixIndexT dim_; + KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase); +}; // class VectorBase + +/** @brief A class representing a vector. + * + * This class provides a way to work with vectors in kaldi. + * It encapsulates basic operations and memory optimizations. */ +template<typename Real> +class Vector: public VectorBase<Real> { + public: + /// Constructor that takes no arguments. Initializes to empty. + Vector(): VectorBase<Real>() {} + + /// Constructor with specific size. Sets to all-zero by default + /// if set_zero == false, memory contents are undefined. + explicit Vector(const MatrixIndexT s, + MatrixResizeType resize_type = kSetZero) + : VectorBase<Real>() { Resize(s, resize_type); } + + /// Copy constructor from CUDA vector + /// This is defined in ../cudamatrix/cu-vector.h + template<typename OtherReal> + explicit Vector(const CuVectorBase<OtherReal> &cu); + + /// Copy constructor. The need for this is controversial. + Vector(const Vector<Real> &v) : VectorBase<Real>() { // (cannot be explicit) + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Copy-constructor from base-class, needed to copy from SubVector. + explicit Vector(const VectorBase<Real> &v) : VectorBase<Real>() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Type conversion constructor. + template<typename OtherReal> + explicit Vector(const VectorBase<OtherReal> &v): VectorBase<Real>() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + +// Took this out since it is unsafe : Arnab +// /// Constructor from a pointer and a size; copies the data to a location +// /// it owns. +// Vector(const Real* Data, const MatrixIndexT s): VectorBase<Real>() { +// Resize(s); + // CopyFromPtr(Data, s); +// } + + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Vector<Real> *other); + + /// Destructor. Deallocates memory. + ~Vector() { Destroy(); } + + /// Read function using C++ streams. Can also add to existing contents + /// of matrix. + void Read(std::istream & in, bool binary, bool add = false); + + /// Set vector to a specified size (can be zero). + /// The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero); + + /// Remove one element and shifts later elements down. + void RemoveElement(MatrixIndexT i); + + /// Assignment operator, protected so it can only be used by std::vector + Vector<Real> &operator = (const Vector<Real> &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + + /// Assignment operator that takes VectorBase. + Vector<Real> &operator = (const VectorBase<Real> &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + private: + /// Init assumes the current contents of the class are invalid (i.e. junk or + /// has already been freed), and it sets the vector to newly allocated memory + /// with the specified dimension. dim == 0 is acceptable. The memory contents + /// pointed to by data_ will be undefined. + void Init(const MatrixIndexT dim); + + /// Destroy function, called internally. + void Destroy(); + +}; + + +/// Represents a non-allocating general vector which can be defined +/// as a sub-vector of higher-level vector [or as the row of a matrix]. +template<typename Real> +class SubVector : public VectorBase<Real> { + public: + /// Constructor from a Vector or SubVector. + /// SubVectors are not const-safe and it's very hard to make them + /// so for now we just give up. This function contains const_cast. + SubVector(const VectorBase<Real> &t, const MatrixIndexT origin, + const MatrixIndexT length) : VectorBase<Real>() { + // following assert equiv to origin>=0 && length>=0 && + // origin+length <= rt.dim_ + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(origin)+ + static_cast<UnsignedMatrixIndexT>(length) <= + static_cast<UnsignedMatrixIndexT>(t.Dim())); + VectorBase<Real>::data_ = const_cast<Real*> (t.Data()+origin); + VectorBase<Real>::dim_ = length; + } + + /// This constructor initializes the vector to point at the contents + /// of this packed matrix (SpMatrix or TpMatrix). + SubVector(const PackedMatrix<Real> &M) { + VectorBase<Real>::data_ = const_cast<Real*> (M.Data()); + VectorBase<Real>::dim_ = (M.NumRows()*(M.NumRows()+1))/2; + } + + /// Copy constructor + SubVector(const SubVector &other) : VectorBase<Real> () { + // this copy constructor needed for Range() to work in base class. + VectorBase<Real>::data_ = other.data_; + VectorBase<Real>::dim_ = other.dim_; + } + + /// Constructor from a pointer to memory and a length. Keeps a pointer + /// to the data but does not take ownership (will never delete). + SubVector(Real *data, MatrixIndexT length) : VectorBase<Real> () { + VectorBase<Real>::data_ = data; + VectorBase<Real>::dim_ = length; + } + + + /// This operation does not preserve const-ness, so be careful. + SubVector(const MatrixBase<Real> &matrix, MatrixIndexT row) { + VectorBase<Real>::data_ = const_cast<Real*>(matrix.RowData(row)); + VectorBase<Real>::dim_ = matrix.NumCols(); + } + + ~SubVector() {} ///< Destructor (does nothing; no pointers are owned here). + + private: + /// Disallow assignment operator. + SubVector & operator = (const SubVector &other) {} +}; + +/// @} end of "addtogroup matrix_group" +/// \addtogroup matrix_funcs_io +/// @{ +/// Output to a C++ stream. Non-binary by default (use Write for +/// binary output). +template<typename Real> +std::ostream & operator << (std::ostream & out, const VectorBase<Real> & v); + +/// Input from a C++ stream. Will automatically read text or +/// binary data from the stream. +template<typename Real> +std::istream & operator >> (std::istream & in, VectorBase<Real> & v); + +/// Input from a C++ stream. Will automatically read text or +/// binary data from the stream. +template<typename Real> +std::istream & operator >> (std::istream & in, Vector<Real> & v); +/// @} end of \addtogroup matrix_funcs_io + +/// \addtogroup matrix_funcs_scalar +/// @{ + + +template<typename Real> +bool ApproxEqual(const VectorBase<Real> &a, + const VectorBase<Real> &b, Real tol = 0.01) { + return a.ApproxEqual(b, tol); +} + +template<typename Real> +inline void AssertEqual(VectorBase<Real> &a, VectorBase<Real> &b, + float tol = 0.01) { + KALDI_ASSERT(a.ApproxEqual(b, tol)); +} + + +/// Returns dot product between v1 and v2. +template<typename Real> +Real VecVec(const VectorBase<Real> &v1, const VectorBase<Real> &v2); + +template<typename Real, typename OtherReal> +Real VecVec(const VectorBase<Real> &v1, const VectorBase<OtherReal> &v2); + + +/// Returns \f$ v_1^T M v_2 \f$ . +/// Not as efficient as it could be where v1 == v2. +template<typename Real> +Real VecMatVec(const VectorBase<Real> &v1, const MatrixBase<Real> &M, + const VectorBase<Real> &v2); + +/// @} End of "addtogroup matrix_funcs_scalar" + + +} // namespace kaldi + +// we need to include the implementation +#include "matrix/kaldi-vector-inl.h" + + + +#endif // KALDI_MATRIX_KALDI_VECTOR_H_ + diff --git a/kaldi_io/src/kaldi/matrix/matrix-common.h b/kaldi_io/src/kaldi/matrix/matrix-common.h new file mode 100644 index 0000000..d202b2e --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/matrix-common.h @@ -0,0 +1,100 @@ +// matrix/matrix-common.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_MATRIX_COMMON_H_ +#define KALDI_MATRIX_MATRIX_COMMON_H_ + +// This file contains some #includes, forward declarations +// and typedefs that are needed by all the main header +// files in this directory. + +#include "base/kaldi-common.h" +#include "matrix/kaldi-blas.h" + +namespace kaldi { +typedef enum { + kTrans = CblasTrans, + kNoTrans = CblasNoTrans +} MatrixTransposeType; + +typedef enum { + kSetZero, + kUndefined, + kCopyData +} MatrixResizeType; + +typedef enum { + kTakeLower, + kTakeUpper, + kTakeMean, + kTakeMeanAndCheck +} SpCopyType; + +template<typename Real> class VectorBase; +template<typename Real> class Vector; +template<typename Real> class SubVector; +template<typename Real> class MatrixBase; +template<typename Real> class SubMatrix; +template<typename Real> class Matrix; +template<typename Real> class SpMatrix; +template<typename Real> class TpMatrix; +template<typename Real> class PackedMatrix; + +// these are classes that won't be defined in this +// directory; they're mostly needed for friend declarations. +template<typename Real> class CuMatrixBase; +template<typename Real> class CuSubMatrix; +template<typename Real> class CuMatrix; +template<typename Real> class CuVectorBase; +template<typename Real> class CuSubVector; +template<typename Real> class CuVector; +template<typename Real> class CuPackedMatrix; +template<typename Real> class CuSpMatrix; +template<typename Real> class CuTpMatrix; + +class CompressedMatrix; + +/// This class provides a way for switching between double and float types. +template<typename T> class OtherReal { }; // useful in reading+writing routines + // to switch double and float. +/// A specialized class for switching from float to double. +template<> class OtherReal<float> { + public: + typedef double Real; +}; +/// A specialized class for switching from double to float. +template<> class OtherReal<double> { + public: + typedef float Real; +}; + + +typedef int32 MatrixIndexT; +typedef int32 SignedMatrixIndexT; +typedef uint32 UnsignedMatrixIndexT; + +// If you want to use size_t for the index type, do as follows instead: +//typedef size_t MatrixIndexT; +//typedef ssize_t SignedMatrixIndexT; +//typedef size_t UnsignedMatrixIndexT; + +} + + + +#endif // KALDI_MATRIX_MATRIX_COMMON_H_ diff --git a/kaldi_io/src/kaldi/matrix/matrix-functions-inl.h b/kaldi_io/src/kaldi/matrix/matrix-functions-inl.h new file mode 100644 index 0000000..9fac851 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/matrix-functions-inl.h @@ -0,0 +1,56 @@ +// matrix/matrix-functions-inl.h + +// Copyright 2009-2011 Microsoft Corporation +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ +#define KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ + +namespace kaldi { + +//! ComplexMul implements, inline, the complex multiplication b *= a. +template<typename Real> inline void ComplexMul(const Real &a_re, const Real &a_im, + Real *b_re, Real *b_im) { + Real tmp_re = (*b_re * a_re) - (*b_im * a_im); + *b_im = *b_re * a_im + *b_im * a_re; + *b_re = tmp_re; +} + +template<typename Real> inline void ComplexAddProduct(const Real &a_re, const Real &a_im, + const Real &b_re, const Real &b_im, + Real *c_re, Real *c_im) { + *c_re += b_re*a_re - b_im*a_im; + *c_im += b_re*a_im + b_im*a_re; +} + + +template<typename Real> inline void ComplexImExp(Real x, Real *a_re, Real *a_im) { + *a_re = std::cos(x); + *a_im = std::sin(x); +} + + +} // end namespace kaldi + + +#endif // KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ + diff --git a/kaldi_io/src/kaldi/matrix/matrix-functions.h b/kaldi_io/src/kaldi/matrix/matrix-functions.h new file mode 100644 index 0000000..b70ca56 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/matrix-functions.h @@ -0,0 +1,235 @@ +// matrix/matrix-functions.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc.; Jan Silovsky; +// Yanmin Qian; 1991 Henrique (Rico) Malvar (*) +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_MATRIX_FUNCTIONS_H_ +#define KALDI_MATRIX_MATRIX_FUNCTIONS_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// @addtogroup matrix_funcs_misc +/// @{ + +/** The function ComplexFft does an Fft on the vector argument v. + v is a vector of even dimension, interpreted for both input + and output as a vector of complex numbers i.e. + \f[ v = ( re_0, im_0, re_1, im_1, ... ) \f] + The dimension of v must be a power of 2. + + If "forward == true" this routine does the Discrete Fourier Transform + (DFT), i.e.: + \f[ vout[m] \leftarrow \sum_{n = 0}^{N-1} vin[i] exp( -2pi m n / N ) \f] + + If "backward" it does the Inverse Discrete Fourier Transform (IDFT) + *WITHOUT THE FACTOR 1/N*, + i.e.: + \f[ vout[m] <-- \sum_{n = 0}^{N-1} vin[i] exp( 2pi m n / N ) \f] + [note the sign difference on the 2 pi for the backward one.] + + Note that this is the definition of the FT given in most texts, but + it differs from the Numerical Recipes version in which the forward + and backward algorithms are flipped. + + Note that you would have to multiply by 1/N after the IDFT to get + back to where you started from. We don't do this because + in some contexts, the transform is made symmetric by multiplying + by sqrt(N) in both passes. The user can do this by themselves. + + See also SplitRadixComplexFft, declared in srfft.h, which is more efficient + but only works if the length of the input is a power of 2. + */ +template<typename Real> void ComplexFft (VectorBase<Real> *v, bool forward, Vector<Real> *tmp_work = NULL); + +/// ComplexFt is the same as ComplexFft but it implements the Fourier +/// transform in an inefficient way. It is mainly included for testing purposes. +/// See comment for ComplexFft to describe the input and outputs and what it does. +template<typename Real> void ComplexFt (const VectorBase<Real> &in, + VectorBase<Real> *out, bool forward); + +/// RealFft is a fourier transform of real inputs. Internally it uses +/// ComplexFft. The input dimension N must be even. If forward == true, +/// it transforms from a sequence of N real points to its complex fourier +/// transform; otherwise it goes in the reverse direction. If you call it +/// in the forward and then reverse direction and multiply by 1.0/N, you +/// will get back the original data. +/// The interpretation of the complex-FFT data is as follows: the array +/// is a sequence of complex numbers C_n of length N/2 with (real, im) format, +/// i.e. [real0, real_{N/2}, real1, im1, real2, im2, real3, im3, ...]. +/// See also SplitRadixRealFft, declared in srfft.h, which is more efficient +/// but only works if the length of the input is a power of 2. + +template<typename Real> void RealFft (VectorBase<Real> *v, bool forward); + + +/// RealFt has the same input and output format as RealFft above, but it is +/// an inefficient implementation included for testing purposes. +template<typename Real> void RealFftInefficient (VectorBase<Real> *v, bool forward); + +/// ComputeDctMatrix computes a matrix corresponding to the DCT, such that +/// M * v equals the DCT of vector v. M must be square at input. +/// This is the type = III DCT with normalization, corresponding to the +/// following equations, where x is the signal and X is the DCT: +/// X_0 = 1/sqrt(2*N) \sum_{n = 0}^{N-1} x_n +/// X_k = 1/sqrt(N) \sum_{n = 0}^{N-1} x_n cos( \pi/N (n + 1/2) k ) +/// This matrix's transpose is its own inverse, so transposing this +/// matrix will give the inverse DCT. +/// Caution: the type III DCT is generally known as the "inverse DCT" (with the +/// type II being the actual DCT), so this function is somewhatd mis-named. It +/// was probably done this way for HTK compatibility. We don't change it +/// because it was this way from the start and changing it would affect the +/// feature generation. + +template<typename Real> void ComputeDctMatrix(Matrix<Real> *M); + + +/// ComplexMul implements, inline, the complex multiplication b *= a. +template<typename Real> inline void ComplexMul(const Real &a_re, const Real &a_im, + Real *b_re, Real *b_im); + +/// ComplexMul implements, inline, the complex operation c += (a * b). +template<typename Real> inline void ComplexAddProduct(const Real &a_re, const Real &a_im, + const Real &b_re, const Real &b_im, + Real *c_re, Real *c_im); + + +/// ComplexImExp implements a <-- exp(i x), inline. +template<typename Real> inline void ComplexImExp(Real x, Real *a_re, Real *a_im); + + +// This class allows you to compute the matrix exponential function +// B = I + A + 1/2! A^2 + 1/3! A^3 + ... +// This method is most accurate where the result is of the same order of +// magnitude as the unit matrix (it will typically not work well when +// the answer has almost-zero eigenvalues or is close to zero). +// It also provides a function that allows you do back-propagate the +// derivative of a scalar function through this calculation. +// The +template<typename Real> +class MatrixExponential { + public: + MatrixExponential() { } + + void Compute(const MatrixBase<Real> &M, MatrixBase<Real> *X); // does *X = exp(M) + + // Version for symmetric matrices (it just copies to full matrix). + void Compute(const SpMatrix<Real> &M, SpMatrix<Real> *X); // does *X = exp(M) + + void Backprop(const MatrixBase<Real> &hX, MatrixBase<Real> *hM) const; // Propagates + // the gradient of a scalar function f backwards through this operation, i.e.: + // if the parameter dX represents df/dX (with no transpose, so element i, j of dX + // is the derivative of f w.r.t. E(i, j)), it sets dM to df/dM, again with no + // transpose (of course, only the part thereof that comes through the effect of + // A on B). This applies to the values of A and E that were called most recently + // with Compute(). + + // Version for symmetric matrices (it just copies to full matrix). + void Backprop(const SpMatrix<Real> &hX, SpMatrix<Real> *hM) const; + + private: + void Clear(); + + static MatrixIndexT ComputeN(const MatrixBase<Real> &M); + + // This is intended for matrices P with small norms: compute B_0 = exp(P) - I. + // Keeps adding terms in the Taylor series till there is no further + // change in the result. Stores some of the powers of A in powers_, + // and the number of terms K as K_. + void ComputeTaylor(const MatrixBase<Real> &P, MatrixBase<Real> *B0); + + // Backprop through the Taylor-series computation above. + // note: hX is \hat{X} in the math; hM is \hat{M} in the math. + void BackpropTaylor(const MatrixBase<Real> &hX, + MatrixBase<Real> *hM) const; + + Matrix<Real> P_; // Equals M * 2^(-N_) + std::vector<Matrix<Real> > B_; // B_[0] = exp(P_) - I, + // B_[k] = 2 B_[k-1] + B_[k-1]^2 [k > 0], + // ( = exp(P_)^k - I ) + // goes from 0..N_ [size N_+1]. + + std::vector<Matrix<Real> > powers_; // powers (>1) of P_ stored here, + // up to all but the last one used in the Taylor expansion (this is the + // last one we need in the backprop). The index is the power minus 2. + + MatrixIndexT N_; // Power N_ >=0 such that P_ = A * 2^(-N_), + // we choose it so that P_ has a sufficiently small norm + // that the Taylor series will converge fast. +}; + + +/** + ComputePCA does a PCA computation, using either outer products + or inner products, whichever is more efficient. Let D be + the dimension of the data points, N be the number of data + points, and G be the PCA dimension we want to retain. We assume + G <= N and G <= D. + + @param X [in] An N x D matrix. Each row of X is a point x_i. + @param U [out] A G x D matrix. Each row of U is a basis element u_i. + @param A [out] An N x D matrix, or NULL. Each row of A is a set of coefficients + in the basis for a point x_i, so A(i, g) is the coefficient of u_i + in x_i. + @param print_eigs [in] If true, prints out diagnostic information about the + eigenvalues. + @param exact [in] If true, does the exact computation; if false, does + a much faster (but almost exact) computation based on the Lanczos + method. +*/ + +template<typename Real> +void ComputePca(const MatrixBase<Real> &X, + MatrixBase<Real> *U, + MatrixBase<Real> *A, + bool print_eigs = false, + bool exact = true); + + + +// This function does: *plus += max(0, a b^T), +// *minus += max(0, -(a b^T)). +template<typename Real> +void AddOuterProductPlusMinus(Real alpha, + const VectorBase<Real> &a, + const VectorBase<Real> &b, + MatrixBase<Real> *plus, + MatrixBase<Real> *minus); + +template<typename Real1, typename Real2> +inline void AssertSameDim(const MatrixBase<Real1> &mat1, const MatrixBase<Real2> &mat2) { + KALDI_ASSERT(mat1.NumRows() == mat2.NumRows() + && mat1.NumCols() == mat2.NumCols()); +} + + +/// @} end of "addtogroup matrix_funcs_misc" + +} // end namespace kaldi + +#include "matrix/matrix-functions-inl.h" + + +#endif diff --git a/kaldi_io/src/kaldi/matrix/matrix-lib.h b/kaldi_io/src/kaldi/matrix/matrix-lib.h new file mode 100644 index 0000000..39acec5 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/matrix-lib.h @@ -0,0 +1,37 @@ +// matrix/matrix-lib.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// Include everything from this directory. +// These files include other stuff that we need. +#ifndef KALDI_MATRIX_MATRIX_LIB_H_ +#define KALDI_MATRIX_MATRIX_LIB_H_ + +#include "matrix/cblas-wrappers.h" +#include "base/kaldi-common.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/sp-matrix.h" +#include "matrix/tp-matrix.h" +#include "matrix/matrix-functions.h" +#include "matrix/srfft.h" +#include "matrix/compressed-matrix.h" +#include "matrix/optimization.h" + +#endif + diff --git a/kaldi_io/src/kaldi/matrix/optimization.h b/kaldi_io/src/kaldi/matrix/optimization.h new file mode 100644 index 0000000..66309ac --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/optimization.h @@ -0,0 +1,248 @@ +// matrix/optimization.h + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_OPTIMIZATION_H_ +#define KALDI_MATRIX_OPTIMIZATION_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + + +/// @addtogroup matrix_optimization +/// @{ + +struct LinearCgdOptions { + int32 max_iters; // Maximum number of iters (if >= 0). + BaseFloat max_error; // Maximum 2-norm of the residual A x - b (convergence + // test) + // Every time the residual 2-norm decreases by this recompute_residual_factor + // since the last time it was computed from scratch, recompute it from + // scratch. This helps to keep the computed residual accurate even in the + // presence of roundoff. + BaseFloat recompute_residual_factor; + + LinearCgdOptions(): max_iters(-1), + max_error(0.0), + recompute_residual_factor(0.01) { } +}; + +/* + This function uses linear conjugate gradient descent to approximately solve + the system A x = b. The value of x at entry corresponds to the initial guess + of x. The algorithm continues until the number of iterations equals b.Dim(), + or until the 2-norm of (A x - b) is <= max_error, or until the number of + iterations equals max_iter, whichever happens sooner. It is a requirement + that A be positive definite. + It returns the number of iterations that were actually executed (this is + useful for testing purposes). +*/ +template<typename Real> +int32 LinearCgd(const LinearCgdOptions &opts, + const SpMatrix<Real> &A, const VectorBase<Real> &b, + VectorBase<Real> *x); + + + + + + +/** + This is an implementation of L-BFGS. It pushes responsibility for + determining when to stop, onto the user. There is no call-back here: + everything is done via calls to the class itself (see the example in + matrix-lib-test.cc). This does not implement constrained L-BFGS, but it will + handle constrained problems correctly as long as the function approaches + +infinity (or -infinity for maximization problems) when it gets close to the + bound of the constraint. In these types of problems, you just let the + function value be +infinity for minimization problems, or -infinity for + maximization problems, outside these bounds). +*/ + +struct LbfgsOptions { + bool minimize; // if true, we're minimizing, else maximizing. + int m; // m is the number of stored vectors L-BFGS keeps. + float first_step_learning_rate; // The very first step of L-BFGS is + // like gradient descent. If you want to configure the size of that step, + // you can do it using this variable. + float first_step_length; // If this variable is >0.0, it overrides + // first_step_learning_rate; on the first step we choose an approximate + // Hessian that is the multiple of the identity that would generate this + // step-length, or 1.0 if the gradient is zero. + float first_step_impr; // If this variable is >0.0, it overrides + // first_step_learning_rate; on the first step we choose an approximate + // Hessian that is the multiple of the identity that would generate this + // amount of objective function improvement (assuming the "real" objf + // was linear). + float c1; // A constant in Armijo rule = Wolfe condition i) + float c2; // A constant in Wolfe condition ii) + float d; // An amount > 1.0 (default 2.0) that we initially multiply or + // divide the step length by, in the line search. + int max_line_search_iters; // after this many iters we restart L-BFGS. + int avg_step_length; // number of iters to avg step length over, in + // RecentStepLength(). + + LbfgsOptions (bool minimize = true): + minimize(minimize), + m(10), + first_step_learning_rate(1.0), + first_step_length(0.0), + first_step_impr(0.0), + c1(1.0e-04), + c2(0.9), + d(2.0), + max_line_search_iters(50), + avg_step_length(4) { } +}; + +template<typename Real> +class OptimizeLbfgs { + public: + /// Initializer takes the starting value of x. + OptimizeLbfgs(const VectorBase<Real> &x, + const LbfgsOptions &opts); + + /// This returns the value of the variable x that has the best objective + /// function so far, and the corresponding objective function value if + /// requested. This would typically be called only at the end. + const VectorBase<Real>& GetValue(Real *objf_value = NULL) const; + + /// This returns the value at which the function wants us + /// to compute the objective function and gradient. + const VectorBase<Real>& GetProposedValue() const { return new_x_; } + + /// Returns the average magnitude of the last n steps (but not + /// more than the number we have stored). Before we have taken + /// any steps, returns +infinity. Note: if the most recent + /// step length was 0, it returns 0, regardless of the other + /// step lengths. This makes it suitable as a convergence test + /// (else we'd generate NaN's). + Real RecentStepLength() const; + + /// The user calls this function to provide the class with the + /// function and gradient info at the point GetProposedValue(). + /// If this point is outside the constraints you can set function_value + /// to {+infinity,-infinity} for {minimization,maximization} problems. + /// In this case the gradient, and also the second derivative (if you call + /// the second overloaded version of this function) will be ignored. + void DoStep(Real function_value, + const VectorBase<Real> &gradient); + + /// The user can call this version of DoStep() if it is desired to set some + /// kind of approximate Hessian on this iteration. Note: it is a prerequisite + /// that diag_approx_2nd_deriv must be strictly positive (minimizing), or + /// negative (maximizing). + void DoStep(Real function_value, + const VectorBase<Real> &gradient, + const VectorBase<Real> &diag_approx_2nd_deriv); + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(OptimizeLbfgs); + + + // The following variable says what stage of the computation we're at. + // Refer to Algorithm 7.5 (L-BFGS) of Nodecdal & Wright, "Numerical + // Optimization", 2nd edition. + // kBeforeStep means we're about to do + /// "compute p_k <-- - H_k \delta f_k" (i.e. Algorithm 7.4). + // kWithinStep means we're at some point within line search; note + // that line search is iterative so we can stay in this state more + // than one time on each iteration. + enum ComputationState { + kBeforeStep, + kWithinStep, // This means we're within the step-size computation, and + // have not yet done the 1st function evaluation. + }; + + inline MatrixIndexT Dim() { return x_.Dim(); } + inline MatrixIndexT M() { return opts_.m; } + SubVector<Real> Y(MatrixIndexT i) { + return SubVector<Real>(data_, (i % M()) * 2); // vector y_i + } + SubVector<Real> S(MatrixIndexT i) { + return SubVector<Real>(data_, (i % M()) * 2 + 1); // vector s_i + } + // The following are subroutines within DoStep(): + bool AcceptStep(Real function_value, + const VectorBase<Real> &gradient); + void Restart(const VectorBase<Real> &x, + Real function_value, + const VectorBase<Real> &gradient); + void ComputeNewDirection(Real function_value, + const VectorBase<Real> &gradient); + void ComputeHifNeeded(const VectorBase<Real> &gradient); + void StepSizeIteration(Real function_value, + const VectorBase<Real> &gradient); + void RecordStepLength(Real s); + + + LbfgsOptions opts_; + SignedMatrixIndexT k_; // Iteration number, starts from zero. Gets set back to zero + // when we restart. + + ComputationState computation_state_; + bool H_was_set_; // True if the user specified H_; if false, + // we'll use a heuristic to estimate it. + + + Vector<Real> x_; // current x. + Vector<Real> new_x_; // the x proposed in the line search. + Vector<Real> best_x_; // the x with the best objective function so far + // (either the same as x_ or something in the current line search.) + Vector<Real> deriv_; // The most recently evaluated derivative-- at x_k. + Vector<Real> temp_; + Real f_; // The function evaluated at x_k. + Real best_f_; // the best objective function so far. + Real d_; // a number d > 1.0, but during an iteration we may decrease this, when + // we switch between armijo and wolfe failures. + + int num_wolfe_i_failures_; // the num times we decreased step size. + int num_wolfe_ii_failures_; // the num times we increased step size. + enum { kWolfeI, kWolfeII, kNone } last_failure_type_; // last type of step-search + // failure on this iter. + + Vector<Real> H_; // Current inverse-Hessian estimate. May be computed by this class itself, + // or provided by user using 2nd form of SetGradientInfo(). + Matrix<Real> data_; // dimension (m*2) x dim. Even rows store + // gradients y_i, odd rows store steps s_i. + Vector<Real> rho_; // dimension m; rho_(m) = 1/(y_m^T s_m), Eq. 7.17. + + std::vector<Real> step_lengths_; // The step sizes we took on the last + // (up to m) iterations; these are not stored in a rotating buffer but + // are shifted by one each time (this is more convenient when we + // restart, as we keep this info past restarting). + + +}; + +/// @} + + +} // end namespace kaldi + + + +#endif + diff --git a/kaldi_io/src/kaldi/matrix/packed-matrix.h b/kaldi_io/src/kaldi/matrix/packed-matrix.h new file mode 100644 index 0000000..722d932 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/packed-matrix.h @@ -0,0 +1,197 @@ +// matrix/packed-matrix.h + +// Copyright 2009-2013 Ondrej Glembek; Lukas Burget; Microsoft Corporation; +// Saarland University; Yanmin Qian; +// Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_PACKED_MATRIX_H_ +#define KALDI_MATRIX_PACKED_MATRIX_H_ + +#include "matrix/matrix-common.h" +#include <algorithm> + +namespace kaldi { + +/// \addtogroup matrix_funcs_io +// we need to declare the friend << operator here +template<typename Real> +std::ostream & operator <<(std::ostream & out, const PackedMatrix<Real>& M); + + +/// \addtogroup matrix_group +/// @{ + +/// @brief Packed matrix: base class for triangular and symmetric matrices. +template<typename Real> class PackedMatrix { + friend class CuPackedMatrix<Real>; + public: + //friend class CuPackedMatrix<Real>; + + PackedMatrix() : data_(NULL), num_rows_(0) {} + + explicit PackedMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero): + data_(NULL) { Resize(r, resize_type); } + + explicit PackedMatrix(const PackedMatrix<Real> &orig) : data_(NULL) { + Resize(orig.num_rows_, kUndefined); + CopyFromPacked(orig); + } + + template<typename OtherReal> + explicit PackedMatrix(const PackedMatrix<OtherReal> &orig) : data_(NULL) { + Resize(orig.NumRows(), kUndefined); + CopyFromPacked(orig); + } + + void SetZero(); /// < Set to zero + void SetUnit(); /// < Set to unit matrix. + void SetRandn(); /// < Set to random values of a normal distribution + + Real Trace() const; + + // Needed for inclusion in std::vector + PackedMatrix<Real> & operator =(const PackedMatrix<Real> &other) { + Resize(other.NumRows()); + CopyFromPacked(other); + return *this; + } + + ~PackedMatrix() { + Destroy(); + } + + /// Set packed matrix to a specified size (can be zero). + /// The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero); + + void AddToDiag(const Real r); // Adds r to diaginal + + void ScaleDiag(const Real alpha); // Scales diagonal by alpha. + + void SetDiag(const Real alpha); // Sets diagonal to this value. + + template<typename OtherReal> + void CopyFromPacked(const PackedMatrix<OtherReal> &orig); + + /// CopyFromVec just interprets the vector as having the same layout + /// as the packed matrix. Must have the same dimension, i.e. + /// orig.Dim() == (NumRows()*(NumRows()+1)) / 2; + template<typename OtherReal> + void CopyFromVec(const SubVector<OtherReal> &orig); + + Real* Data() { return data_; } + const Real* Data() const { return data_; } + inline MatrixIndexT NumRows() const { return num_rows_; } + inline MatrixIndexT NumCols() const { return num_rows_; } + size_t SizeInBytes() const { + size_t nr = static_cast<size_t>(num_rows_); + return ((nr * (nr+1)) / 2) * sizeof(Real); + } + + //MatrixIndexT Stride() const { return stride_; } + + // This code is duplicated in child classes to avoid extra levels of calls. + Real operator() (MatrixIndexT r, MatrixIndexT c) const { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(num_rows_) && + static_cast<UnsignedMatrixIndexT>(c) < + static_cast<UnsignedMatrixIndexT>(num_rows_) + && c <= r); + return *(data_ + (r * (r + 1)) / 2 + c); + } + + // This code is duplicated in child classes to avoid extra levels of calls. + Real &operator() (MatrixIndexT r, MatrixIndexT c) { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(num_rows_) && + static_cast<UnsignedMatrixIndexT>(c) < + static_cast<UnsignedMatrixIndexT>(num_rows_) + && c <= r); + return *(data_ + (r * (r + 1)) / 2 + c); + } + + Real Max() const { + KALDI_ASSERT(num_rows_ > 0); + return * (std::max_element(data_, data_ + ((num_rows_*(num_rows_+1))/2) )); + } + + Real Min() const { + KALDI_ASSERT(num_rows_ > 0); + return * (std::min_element(data_, data_ + ((num_rows_*(num_rows_+1))/2) )); + } + + void Scale(Real c); + + friend std::ostream & operator << <> (std::ostream & out, + const PackedMatrix<Real> &m); + // Use instead of stream<<*this, if you want to add to existing contents. + // Will throw exception on failure. + void Read(std::istream &in, bool binary, bool add = false); + + void Write(std::ostream &out, bool binary) const; + + void Destroy(); + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(PackedMatrix<Real> *other); + void Swap(Matrix<Real> *other); + + + protected: + // Will only be called from this class or derived classes. + void AddPacked(const Real alpha, const PackedMatrix<Real>& M); + Real *data_; + MatrixIndexT num_rows_; + //MatrixIndexT stride_; + private: + /// Init assumes the current contents of the class are is invalid (i.e. junk or + /// has already been freed), and it sets the matrixd to newly allocated memory + /// with the specified dimension. dim == 0 is acceptable. The memory contents + /// pointed to by data_ will be undefined. + void Init(MatrixIndexT dim); + +}; +/// @} end "addtogroup matrix_group" + + +/// \addtogroup matrix_funcs_io +/// @{ + +template<typename Real> +std::ostream & operator << (std::ostream & os, const PackedMatrix<Real>& M) { + M.Write(os, false); + return os; +} + +template<typename Real> +std::istream & operator >> (std::istream &is, PackedMatrix<Real> &M) { + M.Read(is, false); + return is; +} + +/// @} + +} // namespace kaldi + +#endif + diff --git a/kaldi_io/src/kaldi/matrix/sp-matrix-inl.h b/kaldi_io/src/kaldi/matrix/sp-matrix-inl.h new file mode 100644 index 0000000..1579592 --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/sp-matrix-inl.h @@ -0,0 +1,42 @@ +// matrix/sp-matrix-inl.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_SP_MATRIX_INL_H_ +#define KALDI_MATRIX_SP_MATRIX_INL_H_ + +#include "matrix/tp-matrix.h" + +namespace kaldi { + +// All the lines in this file seem to be declaring template specializations. +// These tell the compiler that we'll implement the templated function +// separately for the different template arguments (float, double). + +template<> +double SolveQuadraticProblem(const SpMatrix<double> &H, const VectorBase<double> &g, + const SolverOptions &opts, VectorBase<double> *x); + +template<> +float SolveQuadraticProblem(const SpMatrix<float> &H, const VectorBase<float> &g, + const SolverOptions &opts, VectorBase<float> *x); + +} // namespace kaldi + + +#endif // KALDI_MATRIX_SP_MATRIX_INL_H_ diff --git a/kaldi_io/src/kaldi/matrix/sp-matrix.h b/kaldi_io/src/kaldi/matrix/sp-matrix.h new file mode 100644 index 0000000..209d24a --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/sp-matrix.h @@ -0,0 +1,524 @@ +// matrix/sp-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University; Ariya Rastrow; Yanmin Qian; +// Jan Silovsky + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_SP_MATRIX_H_ +#define KALDI_MATRIX_SP_MATRIX_H_ + +#include <algorithm> +#include <vector> + +#include "matrix/packed-matrix.h" + +namespace kaldi { + + +/// \addtogroup matrix_group +/// @{ +template<typename Real> class SpMatrix; + + +/** + * @brief Packed symetric matrix class +*/ +template<typename Real> +class SpMatrix : public PackedMatrix<Real> { + friend class CuSpMatrix<Real>; + public: + // so it can use our assignment operator. + friend class std::vector<Matrix<Real> >; + + SpMatrix(): PackedMatrix<Real>() {} + + /// Copy constructor from CUDA version of SpMatrix + /// This is defined in ../cudamatrix/cu-sp-matrix.h + + explicit SpMatrix(const CuSpMatrix<Real> &cu); + + explicit SpMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero) + : PackedMatrix<Real>(r, resize_type) {} + + SpMatrix(const SpMatrix<Real> &orig) + : PackedMatrix<Real>(orig) {} + + template<typename OtherReal> + explicit SpMatrix(const SpMatrix<OtherReal> &orig) + : PackedMatrix<Real>(orig) {} + +#ifdef KALDI_PARANOID + explicit SpMatrix(const MatrixBase<Real> & orig, + SpCopyType copy_type = kTakeMeanAndCheck) + : PackedMatrix<Real>(orig.NumRows(), kUndefined) { + CopyFromMat(orig, copy_type); + } +#else + explicit SpMatrix(const MatrixBase<Real> & orig, + SpCopyType copy_type = kTakeMean) + : PackedMatrix<Real>(orig.NumRows(), kUndefined) { + CopyFromMat(orig, copy_type); + } +#endif + + /// Shallow swap. + void Swap(SpMatrix *other); + + inline void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero) { + PackedMatrix<Real>::Resize(nRows, resize_type); + } + + void CopyFromSp(const SpMatrix<Real> &other) { + PackedMatrix<Real>::CopyFromPacked(other); + } + + template<typename OtherReal> + void CopyFromSp(const SpMatrix<OtherReal> &other) { + PackedMatrix<Real>::CopyFromPacked(other); + } + +#ifdef KALDI_PARANOID + void CopyFromMat(const MatrixBase<Real> &orig, + SpCopyType copy_type = kTakeMeanAndCheck); +#else // different default arg if non-paranoid mode. + void CopyFromMat(const MatrixBase<Real> &orig, + SpCopyType copy_type = kTakeMean); +#endif + + inline Real operator() (MatrixIndexT r, MatrixIndexT c) const { + // if column is less than row, then swap these as matrix is stored + // as upper-triangular... only allowed for const matrix object. + if (static_cast<UnsignedMatrixIndexT>(c) > + static_cast<UnsignedMatrixIndexT>(r)) + std::swap(c, r); + // c<=r now so don't have to check c. + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(this->num_rows_)); + return *(this->data_ + (r*(r+1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + + inline Real &operator() (MatrixIndexT r, MatrixIndexT c) { + if (static_cast<UnsignedMatrixIndexT>(c) > + static_cast<UnsignedMatrixIndexT>(r)) + std::swap(c, r); + // c<=r now so don't have to check c. + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(this->num_rows_)); + return *(this->data_ + (r * (r + 1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + + using PackedMatrix<Real>::operator =; + using PackedMatrix<Real>::Scale; + + /// matrix inverse. + /// if inverse_needed = false, will fill matrix with garbage. + /// (only useful if logdet wanted). + void Invert(Real *logdet = NULL, Real *det_sign= NULL, + bool inverse_needed = true); + + // Below routine does inversion in double precision, + // even for single-precision object. + void InvertDouble(Real *logdet = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + + /// Returns maximum ratio of singular values. + inline Real Cond() const { + Matrix<Real> tmp(*this); + return tmp.Cond(); + } + + /// Takes matrix to a fraction power via Svd. + /// Will throw exception if matrix is not positive semidefinite + /// (to within a tolerance) + void ApplyPow(Real exponent); + + /// This is the version of SVD that we implement for symmetric positive + /// definite matrices. This exists for historical reasons; right now its + /// internal implementation is the same as Eig(). It computes the eigenvalue + /// decomposition (*this) = P * diag(s) * P^T with P orthogonal. Will throw + /// exception if input is not positive semidefinite to within a tolerance. + void SymPosSemiDefEig(VectorBase<Real> *s, MatrixBase<Real> *P, + Real tolerance = 0.001) const; + + /// Solves the symmetric eigenvalue problem: at end we should have (*this) = P + /// * diag(s) * P^T. We solve the problem using the symmetric QR method. + /// P may be NULL. + /// Implemented in qr.cc. + /// If you need the eigenvalues sorted, the function SortSvd declared in + /// kaldi-matrix is suitable. + void Eig(VectorBase<Real> *s, MatrixBase<Real> *P = NULL) const; + + /// This function gives you, approximately, the largest eigenvalues of the + /// symmetric matrix and the corresponding eigenvectors. (largest meaning, + /// further from zero). It does this by doing a SVD within the Krylov + /// subspace generated by this matrix and a random vector. This is + /// a form of the Lanczos method with complete reorthogonalization, followed + /// by SVD within a smaller dimension ("lanczos_dim"). + /// + /// If *this is m by m, s should be of dimension n and P should be of + /// dimension m by n, with n <= m. The *columns* of P are the approximate + /// eigenvectors; P * diag(s) * P^T would be a low-rank reconstruction of + /// *this. The columns of P will be orthogonal, and the elements of s will be + /// the eigenvalues of *this projected into that subspace, but beyond that + /// there are no exact guarantees. (This is because the convergence of this + /// method is statistical). Note: it only makes sense to use this + /// method if you are in very high dimension and n is substantially smaller + /// than m: for example, if you want the 100 top eigenvalues of a 10k by 10k + /// matrix. This function calls Rand() to initialize the lanczos + /// iterations and also for restarting. + /// If lanczos_dim is zero, it will default to the greater of: + /// s->Dim() + 50 or s->Dim() + s->Dim()/2, but not more than this->Dim(). + /// If lanczos_dim == this->Dim(), you might as well just call the function + /// Eig() since the result will be the same, and Eig() would be faster; the + /// whole point of this function is to reduce the dimension of the SVD + /// computation. + void TopEigs(VectorBase<Real> *s, MatrixBase<Real> *P, + MatrixIndexT lanczos_dim = 0) const; + + + + /// Takes log of the matrix (does eigenvalue decomposition then takes + /// log of eigenvalues and reconstructs). Will throw of not +ve definite. + void Log(); + + + // Takes exponential of the matrix (equivalent to doing eigenvalue + // decomposition then taking exp of eigenvalues and reconstructing). + void Exp(); + + /// Returns the maximum of the absolute values of any of the + /// eigenvalues. + Real MaxAbsEig() const; + + void PrintEigs(const char *name) { + Vector<Real> s((*this).NumRows()); + Matrix<Real> P((*this).NumRows(), (*this).NumCols()); + SymPosSemiDefEig(&s, &P); + KALDI_LOG << "PrintEigs: " << name << ": " << s; + } + + bool IsPosDef() const; // returns true if Cholesky succeeds. + void AddSp(const Real alpha, const SpMatrix<Real> &Ma) { + this->AddPacked(alpha, Ma); + } + + /// Computes log determinant but only for +ve-def matrices + /// (it uses Cholesky). + /// If matrix is not +ve-def, it will throw an exception + /// was LogPDDeterminant() + Real LogPosDefDet() const; + + Real LogDet(Real *det_sign = NULL) const; + + /// rank-one update, this <-- this + alpha v v' + template<typename OtherReal> + void AddVec2(const Real alpha, const VectorBase<OtherReal> &v); + + /// rank-two update, this <-- this + alpha (v w' + w v'). + void AddVecVec(const Real alpha, const VectorBase<Real> &v, + const VectorBase<Real> &w); + + /// Does *this = beta * *thi + alpha * diag(v) * S * diag(v) + void AddVec2Sp(const Real alpha, const VectorBase<Real> &v, + const SpMatrix<Real> &S, const Real beta); + + /// diagonal update, this <-- this + diag(v) + template<typename OtherReal> + void AddDiagVec(const Real alpha, const VectorBase<OtherReal> &v); + + /// rank-N update: + /// if (transM == kNoTrans) + /// (*this) = beta*(*this) + alpha * M * M^T, + /// or (if transM == kTrans) + /// (*this) = beta*(*this) + alpha * M^T * M + /// Note: beta used to default to 0.0. + void AddMat2(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transM, const Real beta); + + /// Extension of rank-N update: + /// this <-- beta*this + alpha * M * A * M^T. + /// (*this) and A are allowed to be the same. + /// If transM == kTrans, then we do it as M^T * A * M. + void AddMat2Sp(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transM, const SpMatrix<Real> &A, + const Real beta = 0.0); + + /// This is a version of AddMat2Sp specialized for when M is fairly sparse. + /// This was required for making the raw-fMLLR code efficient. + void AddSmat2Sp(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transM, const SpMatrix<Real> &A, + const Real beta = 0.0); + + /// The following function does: + /// this <-- beta*this + alpha * T * A * T^T. + /// (*this) and A are allowed to be the same. + /// If transM == kTrans, then we do it as alpha * T^T * A * T. + /// Currently it just calls AddMat2Sp, but if needed we + /// can implement it more efficiently. + void AddTp2Sp(const Real alpha, const TpMatrix<Real> &T, + MatrixTransposeType transM, const SpMatrix<Real> &A, + const Real beta = 0.0); + + /// The following function does: + /// this <-- beta*this + alpha * T * T^T. + /// (*this) and A are allowed to be the same. + /// If transM == kTrans, then we do it as alpha * T^T * T + /// Currently it just calls AddMat2, but if needed we + /// can implement it more efficiently. + void AddTp2(const Real alpha, const TpMatrix<Real> &T, + MatrixTransposeType transM, const Real beta = 0.0); + + /// Extension of rank-N update: + /// this <-- beta*this + alpha * M * diag(v) * M^T. + /// if transM == kTrans, then + /// this <-- beta*this + alpha * M^T * diag(v) * M. + void AddMat2Vec(const Real alpha, const MatrixBase<Real> &M, + MatrixTransposeType transM, const VectorBase<Real> &v, + const Real beta = 0.0); + + + /// Floors this symmetric matrix to the matrix + /// alpha * Floor, where the matrix Floor is positive + /// definite. + /// It is floored in the sense that after flooring, + /// x^T (*this) x >= x^T (alpha*Floor) x. + /// This is accomplished using an Svd. It will crash + /// if Floor is not positive definite. Returns the number of + /// elements that were floored. + int ApplyFloor(const SpMatrix<Real> &Floor, Real alpha = 1.0, + bool verbose = false); + + /// Floor: Given a positive semidefinite matrix, floors the eigenvalues + /// to the specified quantity. A previous version of this function had + /// a tolerance which is now no longer needed since we have code to + /// do the symmetric eigenvalue decomposition and no longer use the SVD + /// code for that purose. + int ApplyFloor(Real floor); + + bool IsDiagonal(Real cutoff = 1.0e-05) const; + bool IsUnit(Real cutoff = 1.0e-05) const; + bool IsZero(Real cutoff = 1.0e-05) const; + bool IsTridiagonal(Real cutoff = 1.0e-05) const; + + /// sqrt of sum of square elements. + Real FrobeniusNorm() const; + + /// Returns true if ((*this)-other).FrobeniusNorm() <= + /// tol*(*this).FrobeniusNorma() + bool ApproxEqual(const SpMatrix<Real> &other, float tol = 0.01) const; + + // LimitCond: + // Limits the condition of symmetric positive semidefinite matrix to + // a specified value + // by flooring all eigenvalues to a positive number which is some multiple + // of the largest one (or zero if there are no positive eigenvalues). + // Takes the condition number we are willing to accept, and floors + // eigenvalues to the largest eigenvalue divided by this. + // Returns #eigs floored or already equal to the floor. + // Throws exception if input is not positive definite. + // returns #floored. + MatrixIndexT LimitCond(Real maxCond = 1.0e+5, bool invert = false); + + // as LimitCond but all done in double precision. // returns #floored. + MatrixIndexT LimitCondDouble(Real maxCond = 1.0e+5, bool invert = false) { + SpMatrix<double> dmat(*this); + MatrixIndexT ans = dmat.LimitCond(maxCond, invert); + (*this).CopyFromSp(dmat); + return ans; + } + Real Trace() const; + + /// Tridiagonalize the matrix with an orthogonal transformation. If + /// *this starts as S, produce T (and Q, if non-NULL) such that + /// T = Q A Q^T, i.e. S = Q^T T Q. Caution: this is the other way + /// round from most authors (it's more efficient in row-major indexing). + void Tridiagonalize(MatrixBase<Real> *Q); + + /// The symmetric QR algorithm. This will mostly be useful in internal code. + /// Typically, you will call this after Tridiagonalize(), on the same object. + /// When called, *this (call it A at this point) must be tridiagonal; at exit, + /// *this will be a diagonal matrix D that is similar to A via orthogonal + /// transformations. This algorithm right-multiplies Q by orthogonal + /// transformations. It turns *this from a tridiagonal into a diagonal matrix + /// while maintaining that (Q *this Q^T) has the same value at entry and exit. + /// At entry Q should probably be either NULL or orthogonal, but we don't check + /// this. + void Qr(MatrixBase<Real> *Q); + + private: + void EigInternal(VectorBase<Real> *s, MatrixBase<Real> *P, + Real tolerance, int recurse) const; +}; + +/// @} end of "addtogroup matrix_group" + +/// \addtogroup matrix_funcs_scalar +/// @{ + + +/// Returns tr(A B). +float TraceSpSp(const SpMatrix<float> &A, const SpMatrix<float> &B); +double TraceSpSp(const SpMatrix<double> &A, const SpMatrix<double> &B); + + +template<typename Real> +inline bool ApproxEqual(const SpMatrix<Real> &A, + const SpMatrix<Real> &B, Real tol = 0.01) { + return A.ApproxEqual(B, tol); +} + +template<typename Real> +inline void AssertEqual(const SpMatrix<Real> &A, + const SpMatrix<Real> &B, Real tol = 0.01) { + KALDI_ASSERT(ApproxEqual(A, B, tol)); +} + + + +/// Returns tr(A B). +template<typename Real, typename OtherReal> +Real TraceSpSp(const SpMatrix<Real> &A, const SpMatrix<OtherReal> &B); + + + +// TraceSpSpLower is the same as Trace(A B) except the lower-diagonal elements +// are counted only once not twice as they should be. It is useful in certain +// optimizations. +template<typename Real> +Real TraceSpSpLower(const SpMatrix<Real> &A, const SpMatrix<Real> &B); + + +/// Returns tr(A B). +/// No option to transpose B because would make no difference. +template<typename Real> +Real TraceSpMat(const SpMatrix<Real> &A, const MatrixBase<Real> &B); + +/// Returns tr(A B C) +/// (A and C may be transposed as specified by transA and transC). +template<typename Real> +Real TraceMatSpMat(const MatrixBase<Real> &A, MatrixTransposeType transA, + const SpMatrix<Real> &B, const MatrixBase<Real> &C, + MatrixTransposeType transC); + +/// Returns tr (A B C D) +/// (A and C may be transposed as specified by transA and transB). +template<typename Real> +Real TraceMatSpMatSp(const MatrixBase<Real> &A, MatrixTransposeType transA, + const SpMatrix<Real> &B, const MatrixBase<Real> &C, + MatrixTransposeType transC, const SpMatrix<Real> &D); + +/** Computes v1^T * M * v2. Not as efficient as it could be where v1 == v2 + * (but no suitable blas routines available). + */ + +/// Returns \f$ v_1^T M v_2 \f$ +/// Not as efficient as it could be where v1 == v2. +template<typename Real> +Real VecSpVec(const VectorBase<Real> &v1, const SpMatrix<Real> &M, + const VectorBase<Real> &v2); + + +/// @} \addtogroup matrix_funcs_scalar + +/// \addtogroup matrix_funcs_misc +/// @{ + + +/// This class describes the options for maximizing various quadratic objective +/// functions. It's mostly as described in the SGMM paper "the subspace +/// Gaussian mixture model -- a structured model for speech recognition", but +/// the diagonal_precondition option is newly added, to handle problems where +/// different dimensions have very different scaling (we recommend to use the +/// option but it's set false for back compatibility). +struct SolverOptions { + BaseFloat K; // maximum condition number + BaseFloat eps; + std::string name; + bool optimize_delta; + bool diagonal_precondition; + bool print_debug_output; + explicit SolverOptions(const std::string &name): + K(1.0e+4), eps(1.0e-40), name(name), + optimize_delta(true), diagonal_precondition(false), + print_debug_output(true) { } + SolverOptions(): K(1.0e+4), eps(1.0e-40), name("[unknown]"), + optimize_delta(true), diagonal_precondition(false), + print_debug_output(true) { } + void Check() const; +}; + + +/// Maximizes the auxiliary function +/// \f[ Q(x) = x.g - 0.5 x^T H x \f] +/// using a numerically stable method. Like a numerically stable version of +/// \f$ x := Q^{-1} g. \f$ +/// Assumes H positive semidefinite. +/// Returns the objective-function change. + +template<typename Real> +Real SolveQuadraticProblem(const SpMatrix<Real> &H, + const VectorBase<Real> &g, + const SolverOptions &opts, + VectorBase<Real> *x); + + + +/// Maximizes the auxiliary function : +/// \f[ Q(x) = tr(M^T P Y) - 0.5 tr(P M Q M^T) \f] +/// Like a numerically stable version of \f$ M := Y Q^{-1} \f$. +/// Assumes Q and P positive semidefinite, and matrix dimensions match +/// enough to make expressions meaningful. +/// This is mostly as described in the SGMM paper "the subspace Gaussian mixture +/// model -- a structured model for speech recognition", but the +/// diagonal_precondition option is newly added, to handle problems +/// where different dimensions have very different scaling (we recommend to use +/// the option but it's set false for back compatibility). +template<typename Real> +Real SolveQuadraticMatrixProblem(const SpMatrix<Real> &Q, + const MatrixBase<Real> &Y, + const SpMatrix<Real> &P, + const SolverOptions &opts, + MatrixBase<Real> *M); + +/// Maximizes the auxiliary function : +/// \f[ Q(M) = tr(M^T G) -0.5 tr(P_1 M Q_1 M^T) -0.5 tr(P_2 M Q_2 M^T). \f] +/// Encountered in matrix update with a prior. We also apply a limit on the +/// condition but it should be less frequently necessary, and can be set larger. +template<typename Real> +Real SolveDoubleQuadraticMatrixProblem(const MatrixBase<Real> &G, + const SpMatrix<Real> &P1, + const SpMatrix<Real> &P2, + const SpMatrix<Real> &Q1, + const SpMatrix<Real> &Q2, + const SolverOptions &opts, + MatrixBase<Real> *M); + + +/// @} End of "addtogroup matrix_funcs_misc" + +} // namespace kaldi + + +// Including the implementation (now actually just includes some +// template specializations). +#include "matrix/sp-matrix-inl.h" + + +#endif // KALDI_MATRIX_SP_MATRIX_H_ + diff --git a/kaldi_io/src/kaldi/matrix/srfft.h b/kaldi_io/src/kaldi/matrix/srfft.h new file mode 100644 index 0000000..c0d36af --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/srfft.h @@ -0,0 +1,132 @@ +// matrix/srfft.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. +// 2014 Daniel Povey +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// This file includes a modified version of code originally published in Malvar, +// H., "Signal processing with lapped transforms, " Artech House, Inc., 1992. The +// current copyright holder of the original code, Henrique S. Malvar, has given +// his permission for the release of this modified version under the Apache +// License v2.0. + +#ifndef KALDI_MATRIX_SRFFT_H_ +#define KALDI_MATRIX_SRFFT_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// @addtogroup matrix_funcs_misc +/// @{ + + +// This class is based on code by Henrique (Rico) Malvar, from his book +// "Signal Processing with Lapped Transforms" (1992). Copied with +// permission, optimized by Go Vivace Inc., and converted into C++ by +// Microsoft Corporation +// This is a more efficient way of doing the complex FFT than ComplexFft +// (declared in matrix-functios.h), but it only works for powers of 2. +// Note: in multi-threaded code, you would need to have one of these objects per +// thread, because multiple calls to Compute in parallel would not work. +template<typename Real> +class SplitRadixComplexFft { + public: + typedef MatrixIndexT Integer; + + // N is the number of complex points (must be a power of two, or this + // will crash). Note that the constructor does some work so it's best to + // initialize the object once and do the computation many times. + SplitRadixComplexFft(Integer N); + + // Does the FFT computation, given pointers to the real and + // imaginary parts. If "forward", do the forward FFT; else + // do the inverse FFT (without the 1/N factor). + // xr and xi are pointers to zero-based arrays of size N, + // containing the real and imaginary parts + // respectively. + void Compute(Real *xr, Real *xi, bool forward) const; + + // This version of Compute takes a single array of size N*2, + // containing [ r0 im0 r1 im1 ... ]. Otherwise its behavior is the + // same as the version above. + void Compute(Real *x, bool forward); + + + // This version of Compute is const; it operates on an array of size N*2 + // containing [ r0 im0 r1 im1 ... ], but it uses the argument "temp_buffer" as + // temporary storage instead of a class-member variable. It will allocate it if + // needed. + void Compute(Real *x, bool forward, std::vector<Real> *temp_buffer) const; + + ~SplitRadixComplexFft(); + + protected: + // temp_buffer_ is allocated only if someone calls Compute with only one Real* + // argument and we need a temporary buffer while creating interleaved data. + std::vector<Real> temp_buffer_; + private: + void ComputeTables(); + void ComputeRecursive(Real *xr, Real *xi, Integer logn) const; + void BitReversePermute(Real *x, Integer logn) const; + + Integer N_; + Integer logn_; // log(N) + + Integer *brseed_; + // brseed is Evans' seed table, ref: (Ref: D. M. W. + // Evans, "An improved digit-reversal permutation algorithm ...", + // IEEE Trans. ASSP, Aug. 1987, pp. 1120-1125). + Real **tab_; // Tables of butterfly coefficients. + + KALDI_DISALLOW_COPY_AND_ASSIGN(SplitRadixComplexFft); +}; + +template<typename Real> +class SplitRadixRealFft: private SplitRadixComplexFft<Real> { + public: + SplitRadixRealFft(MatrixIndexT N): // will fail unless N>=4 and N is a power of 2. + SplitRadixComplexFft<Real> (N/2), N_(N) { } + + /// If forward == true, this function transforms from a sequence of N real points to its complex fourier + /// transform; otherwise it goes in the reverse direction. If you call it + /// in the forward and then reverse direction and multiply by 1.0/N, you + /// will get back the original data. + /// The interpretation of the complex-FFT data is as follows: the array + /// is a sequence of complex numbers C_n of length N/2 with (real, im) format, + /// i.e. [real0, real_{N/2}, real1, im1, real2, im2, real3, im3, ...]. + void Compute(Real *x, bool forward); + + + /// This is as the other Compute() function, but it is a const version that + /// uses a user-supplied buffer. + void Compute(Real *x, bool forward, std::vector<Real> *temp_buffer) const; + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(SplitRadixRealFft); + int N_; +}; + + +/// @} end of "addtogroup matrix_funcs_misc" + +} // end namespace kaldi + + +#endif + diff --git a/kaldi_io/src/kaldi/matrix/tp-matrix.h b/kaldi_io/src/kaldi/matrix/tp-matrix.h new file mode 100644 index 0000000..f43e86c --- /dev/null +++ b/kaldi_io/src/kaldi/matrix/tp-matrix.h @@ -0,0 +1,131 @@ +// matrix/tp-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Lukas Burget; Microsoft Corporation; +// Saarland University; Yanmin Qian; Haihua Xu +// 2013 Johns Hopkins Universith (author: Daniel Povey) + + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_TP_MATRIX_H_ +#define KALDI_MATRIX_TP_MATRIX_H_ + + +#include "matrix/packed-matrix.h" + +namespace kaldi { +/// \addtogroup matrix_group +/// @{ + +template<typename Real> class TpMatrix; + +/// @brief Packed symetric matrix class +template<typename Real> +class TpMatrix : public PackedMatrix<Real> { + friend class CuTpMatrix<float>; + friend class CuTpMatrix<double>; + public: + TpMatrix() : PackedMatrix<Real>() {} + explicit TpMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero) + : PackedMatrix<Real>(r, resize_type) {} + TpMatrix(const TpMatrix<Real>& orig) : PackedMatrix<Real>(orig) {} + + /// Copy constructor from CUDA TpMatrix + /// This is defined in ../cudamatrix/cu-tp-matrix.cc + explicit TpMatrix(const CuTpMatrix<Real> &cu); + + + template<typename OtherReal> explicit TpMatrix(const TpMatrix<OtherReal>& orig) + : PackedMatrix<Real>(orig) {} + + Real operator() (MatrixIndexT r, MatrixIndexT c) const { + if (static_cast<UnsignedMatrixIndexT>(c) > + static_cast<UnsignedMatrixIndexT>(r)) { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(c) < + static_cast<UnsignedMatrixIndexT>(this->num_rows_)); + return 0; + } + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(this->num_rows_)); + // c<=r now so don't have to check c. + return *(this->data_ + (r*(r+1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + + Real &operator() (MatrixIndexT r, MatrixIndexT c) { + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(r) < + static_cast<UnsignedMatrixIndexT>(this->num_rows_)); + KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(c) <= + static_cast<UnsignedMatrixIndexT>(r) && + "you cannot access the upper triangle of TpMatrix using " + "a non-const matrix object."); + return *(this->data_ + (r*(r+1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + // Note: Cholesky may throw std::runtime_error + void Cholesky(const SpMatrix<Real>& orig); + + void Invert(); + + // Inverts in double precision. + void InvertDouble() { + TpMatrix<double> dmat(*this); + dmat.Invert(); + (*this).CopyFromTp(dmat); + } + + /// Shallow swap + void Swap(TpMatrix<Real> *other); + + /// Returns the determinant of the matrix (product of diagonals) + Real Determinant(); + + /// CopyFromMat copies the lower triangle of M into *this + /// (or the upper triangle, if Trans == kTrans). + void CopyFromMat(const MatrixBase<Real> &M, + MatrixTransposeType Trans = kNoTrans); + + /// This is implemented in ../cudamatrix/cu-tp-matrix.cc + void CopyFromMat(const CuTpMatrix<Real> &other); + + /// CopyFromTp copies another triangular matrix into this one. + void CopyFromTp(const TpMatrix<Real> &other) { + PackedMatrix<Real>::CopyFromPacked(other); + } + + template<typename OtherReal> void CopyFromTp(const TpMatrix<OtherReal> &other) { + PackedMatrix<Real>::CopyFromPacked(other); + } + + /// AddTp does *this += alpha * M. + void AddTp(const Real alpha, const TpMatrix<Real> &M) { + this->AddPacked(alpha, M); + } + + using PackedMatrix<Real>::operator =; + using PackedMatrix<Real>::Scale; + + void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero) { + PackedMatrix<Real>::Resize(nRows, resize_type); + } +}; + +/// @} end of "addtogroup matrix_group". + +} // namespace kaldi + + +#endif + diff --git a/kaldi_io/src/kaldi/tree/build-tree-questions.h b/kaldi_io/src/kaldi/tree/build-tree-questions.h new file mode 100644 index 0000000..a6bcfdd --- /dev/null +++ b/kaldi_io/src/kaldi/tree/build-tree-questions.h @@ -0,0 +1,133 @@ +// tree/build-tree-questions.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_BUILD_TREE_QUESTIONS_H_ +#define KALDI_TREE_BUILD_TREE_QUESTIONS_H_ + +#include "util/stl-utils.h" +#include "tree/context-dep.h" + +namespace kaldi { + + +/// \addtogroup tree_group +/// @{ +/// Typedef for statistics to build trees. +typedef std::vector<std::pair<EventType, Clusterable*> > BuildTreeStatsType; + +/// Typedef used when we get "all keys" from a set of stats-- used in specifying +/// which kinds of questions to ask. +typedef enum { kAllKeysInsistIdentical, kAllKeysIntersection, kAllKeysUnion } AllKeysType; + +/// @} + +/// \defgroup tree_group_questions Question sets for decision-tree clustering +/// See \ref tree_internals (and specifically \ref treei_func_questions) for context. +/// \ingroup tree_group +/// @{ + +/// QuestionsForKey is a class used to define the questions for a key, +/// and also options that allow us to refine the question during tree-building +/// (i.e. make a question specific to the location in the tree). +/// The Questions class handles aggregating these options for a set +/// of different keys. +struct QuestionsForKey { // Configuration class associated with a particular key + // (of type EventKeyType). It also contains the questions themselves. + std::vector<std::vector<EventValueType> > initial_questions; + RefineClustersOptions refine_opts; // if refine_opts.max_iter == 0, + // we just pick from the initial questions. + + QuestionsForKey(int32 num_iters = 5): refine_opts(num_iters, 2) { + // refine_cfg with 5 iters and top-n = 2 (this is no restriction because + // RefineClusters called with 2 clusters; would get set to that anyway as + // it's the only possible value for 2 clusters). User has to add questions. + // This config won't work as-is, as it has no questions. + } + + void Check() const { + for (size_t i = 0;i < initial_questions.size();i++) KALDI_ASSERT(IsSorted(initial_questions[i])); + } + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + // copy and assign allowed. +}; + +/// This class defines, for each EventKeyType, a set of initial questions that +/// it tries and also a number of iterations for which to refine the questions to increase +/// likelihood. It is perhaps a bit more than an options class, as it contains the +/// actual questions. +class Questions { // careful, this is a class. + public: + const QuestionsForKey &GetQuestionsOf(EventKeyType key) const { + std::map<EventKeyType, size_t>::const_iterator iter; + if ( (iter = key_idx_.find(key)) == key_idx_.end()) { + KALDI_ERR << "Questions: no options for key "<< key; + } + size_t idx = iter->second; + KALDI_ASSERT(idx < key_options_.size()); + key_options_[idx]->Check(); + return *(key_options_[idx]); + } + void SetQuestionsOf(EventKeyType key, const QuestionsForKey &options_of_key) { + options_of_key.Check(); + if (key_idx_.count(key) == 0) { + key_idx_[key] = key_options_.size(); + key_options_.push_back(new QuestionsForKey()); + *(key_options_.back()) = options_of_key; + } else { + size_t idx = key_idx_[key]; + KALDI_ASSERT(idx < key_options_.size()); + *(key_options_[idx]) = options_of_key; + } + } + void GetKeysWithQuestions(std::vector<EventKeyType> *keys_out) const { + KALDI_ASSERT(keys_out != NULL); + CopyMapKeysToVector(key_idx_, keys_out); + } + const bool HasQuestionsForKey(EventKeyType key) const { return (key_idx_.count(key) != 0); } + ~Questions() { kaldi::DeletePointers(&key_options_); } + + + /// Initializer with arguments. After using this you would have to set up the config for each key you + /// are going to use, or use InitRand(). + Questions() { } + + + /// InitRand attempts to generate "reasonable" random questions. Only + /// of use for debugging. This initializer creates a config that is + /// ready to use. + /// e.g. num_iters_refine = 0 means just use stated questions (if >1, will use + /// different questions at each split of the tree). + void InitRand(const BuildTreeStatsType &stats, int32 num_quest, int32 num_iters_refine, AllKeysType all_keys_type); + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + private: + std::vector<QuestionsForKey*> key_options_; + std::map<EventKeyType, size_t> key_idx_; + KALDI_DISALLOW_COPY_AND_ASSIGN(Questions); +}; + +/// @} + +}// end namespace kaldi + +#endif // KALDI_TREE_BUILD_TREE_QUESTIONS_H_ diff --git a/kaldi_io/src/kaldi/tree/build-tree-utils.h b/kaldi_io/src/kaldi/tree/build-tree-utils.h new file mode 100644 index 0000000..464fc6b --- /dev/null +++ b/kaldi_io/src/kaldi/tree/build-tree-utils.h @@ -0,0 +1,324 @@ +// tree/build-tree-utils.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_BUILD_TREE_UTILS_H_ +#define KALDI_TREE_BUILD_TREE_UTILS_H_ + +#include "tree/build-tree-questions.h" + +// build-tree-questions.h needed for this typedef: +// typedef std::vector<std::pair<EventType, Clusterable*> > BuildTreeStatsType; +// and for other #includes. + +namespace kaldi { + + +/// \defgroup tree_group_lower Low-level functions for manipulating statistics and event-maps +/// See \ref tree_internals and specifically \ref treei_func for context. +/// \ingroup tree_group +/// +/// @{ + + + +/// This frees the Clusterable* pointers in "stats", where non-NULL, and sets them to NULL. +/// Does not delete the pointer "stats" itself. +void DeleteBuildTreeStats(BuildTreeStatsType *stats); + +/// Writes BuildTreeStats object. This works even if pointers are NULL. +void WriteBuildTreeStats(std::ostream &os, bool binary, + const BuildTreeStatsType &stats); + +/// Reads BuildTreeStats object. The "example" argument must be of the same +/// type as the stats on disk, and is needed for access to the correct "Read" +/// function. It was organized this way for easier extensibility (so adding new +/// Clusterable derived classes isn't painful) +void ReadBuildTreeStats(std::istream &is, bool binary, + const Clusterable &example, BuildTreeStatsType *stats); + +/// Convenience function e.g. to work out possible values of the phones from just the stats. +/// Returns true if key was always defined inside the stats. +/// May be used with and == NULL to find out of key was always defined. +bool PossibleValues(EventKeyType key, const BuildTreeStatsType &stats, + std::vector<EventValueType> *ans); + + +/// Splits stats according to the EventMap, indexing them at output by the +/// leaf type. A utility function. NOTE-- pointers in stats_out point to +/// the same memory location as those in stats. No copying of Clusterable* +/// objects happens. Will add to stats in stats_out if non-empty at input. +/// This function may increase the size of vector stats_out as necessary +/// to accommodate stats, but will never decrease the size. +void SplitStatsByMap(const BuildTreeStatsType &stats_in, const EventMap &e, + std::vector<BuildTreeStatsType> *stats_out); + +/// SplitStatsByKey splits stats up according to the value of a particular key, +/// which must be always defined and nonnegative. Like MapStats. Pointers to +/// Clusterable* in stats_out are not newly allocated-- they are the same as the +/// ones in stats_in. Generally they will still be owned at stats_in (user can +/// decide where to allocate ownership). +void SplitStatsByKey(const BuildTreeStatsType &stats_in, EventKeyType key, + std::vector<BuildTreeStatsType> *stats_out); + + +/// Converts stats from a given context-window (N) and central-position (P) to a +/// different N and P, by possibly reducing context. This function does a job +/// that's quite specific to the "normal" stats format we use. See \ref +/// tree_window for background. This function may delete some keys and change +/// others, depending on the N and P values. It expects that at input, all keys +/// will either be -1 or lie between 0 and oldN-1. At output, keys will be +/// either -1 or between 0 and newN-1. +/// Returns false if we could not convert the stats (e.g. because newN is larger +/// than oldN). +bool ConvertStats(int32 oldN, int32 oldP, int32 newN, int32 newP, + BuildTreeStatsType *stats); + + +/// FilterStatsByKey filters the stats according the value of a specified key. +/// If include_if_present == true, it only outputs the stats whose key is in +/// "values"; otherwise it only outputs the stats whose key is not in "values". +/// At input, "values" must be sorted and unique, and all stats in "stats_in" +/// must have "key" defined. At output, pointers to Clusterable* in stats_out +/// are not newly allocated-- they are the same as the ones in stats_in. +void FilterStatsByKey(const BuildTreeStatsType &stats_in, + EventKeyType key, + std::vector<EventValueType> &values, + bool include_if_present, // true-> retain only if in "values", + // false-> retain only if not in "values". + BuildTreeStatsType *stats_out); + + +/// Sums stats, or returns NULL stats_in has no non-NULL stats. +/// Stats are newly allocated, owned by caller. +Clusterable *SumStats(const BuildTreeStatsType &stats_in); + +/// Sums the normalizer [typically, data-count] over the stats. +BaseFloat SumNormalizer(const BuildTreeStatsType &stats_in); + +/// Sums the objective function over the stats. +BaseFloat SumObjf(const BuildTreeStatsType &stats_in); + + +/// Sum a vector of stats. Leaves NULL as pointer if no stats available. +/// The pointers in stats_out are owned by caller. At output, there may be +/// NULLs in the vector stats_out. +void SumStatsVec(const std::vector<BuildTreeStatsType> &stats_in, std::vector<Clusterable*> *stats_out); + +/// Cluster the stats given the event map return the total objf given those clusters. +BaseFloat ObjfGivenMap(const BuildTreeStatsType &stats_in, const EventMap &e); + + +/// FindAllKeys puts in *keys the (sorted, unique) list of all key identities in the stats. +/// If type == kAllKeysInsistIdentical, it will insist that this set of keys is the same for all the +/// stats (else exception is thrown). +/// if type == kAllKeysIntersection, it will return the smallest common set of keys present in +/// the set of stats +/// if type== kAllKeysUnion (currently probably not so useful since maps will return "undefined" +/// if key is not present), it will return the union of all the keys present in the stats. +void FindAllKeys(const BuildTreeStatsType &stats, AllKeysType keys_type, + std::vector<EventKeyType> *keys); + + +/// @} + + +/** + \defgroup tree_group_intermediate Intermediate-level functions used in building the tree + These functions are are used in top-level tree-building code (\ref tree_group_top); see + \ref tree_internals for documentation. + \ingroup tree_group + @{ +*/ + + +/// Returns a tree with just one node. Used @ start of tree-building process. +/// Not really used in current recipes. +inline EventMap *TrivialTree(int32 *num_leaves) { + KALDI_ASSERT(*num_leaves == 0); // in envisaged usage. + return new ConstantEventMap( (*num_leaves)++ ); +} + +/// DoTableSplit does a complete split on this key (e.g. might correspond to central phone +/// (key = P-1), or HMM-state position (key == kPdfClass == -1). Stats used to work out possible +/// values of the event. "num_leaves" is used to allocate new leaves. All stats must have +/// this key defined, or this function will crash. +EventMap *DoTableSplit(const EventMap &orig, EventKeyType key, + const BuildTreeStatsType &stats, int32 *num_leaves); + + +/// DoTableSplitMultiple does a complete split on all the keys, in order from keys[0], +/// keys[1] +/// and so on. The stats are used to work out possible values corresponding to the key. +/// "num_leaves" is used to allocate new leaves. All stats must have +/// the keys defined, or this function will crash. +/// Returns a newly allocated event map. +EventMap *DoTableSplitMultiple(const EventMap &orig, + const std::vector<EventKeyType> &keys, + const BuildTreeStatsType &stats, + int32 *num_leaves); + + +/// "ClusterEventMapGetMapping" clusters the leaves of the EventMap, with "thresh" a delta-likelihood +/// threshold to control how many leaves we combine (might be the same as the delta-like +/// threshold used in splitting. +// The function returns the #leaves we combined. The same leaf-ids of the leaves being clustered +// will be used for the clustered leaves (but other than that there is no special rule which +// leaf-ids should be used at output). +// It outputs the mapping for leaves, in "mapping", which may be empty at the start +// but may also contain mappings for other parts of the tree, which must contain +// disjoint leaves from this part. This is so that Cluster can +// be called multiple times for sub-parts of the tree (with disjoint sets of leaves), +// e.g. if we want to avoid sharing across phones. Afterwards you can use Copy function +// of EventMap to apply the mapping, i.e. call e_in.Copy(mapping) to get the new map. +// Note that the application of Cluster creates gaps in the leaves. You should then +// call RenumberEventMap(e_in.Copy(mapping), num_leaves). +// *If you only want to cluster a subset of the leaves (e.g. just non-silence, or just +// a particular phone, do this by providing a set of "stats" that correspond to just +// this subset of leaves*. Leaves with no stats will not be clustered. +// See build-tree.cc for an example of usage. +int ClusterEventMapGetMapping(const EventMap &e_in, const BuildTreeStatsType &stats, + BaseFloat thresh, std::vector<EventMap*> *mapping); + +/// This is as ClusterEventMapGetMapping but a more convenient interface +/// that exposes less of the internals. It uses a bottom-up clustering to +/// combine the leaves, until the log-likelihood decrease from combinging two +/// leaves exceeds the threshold. +EventMap *ClusterEventMap(const EventMap &e_in, const BuildTreeStatsType &stats, + BaseFloat thresh, int32 *num_removed); + +/// This is as ClusterEventMap, but first splits the stats on the keys specified +/// in "keys" (e.g. typically keys = [ -1, P ]), and only clusters within the +/// classes defined by that splitting. +/// Note-- leaves will be non-consecutive at output, use RenumberEventMap. +EventMap *ClusterEventMapRestrictedByKeys(const EventMap &e_in, + const BuildTreeStatsType &stats, + BaseFloat thresh, + const std::vector<EventKeyType> &keys, + int32 *num_removed); + + +/// This version of ClusterEventMapRestricted restricts the clustering to only +/// allow things that "e_restrict" maps to the same value to be clustered +/// together. +EventMap *ClusterEventMapRestrictedByMap(const EventMap &e_in, + const BuildTreeStatsType &stats, + BaseFloat thresh, + const EventMap &e_restrict, + int32 *num_removed); + + +/// RenumberEventMap [intended to be used after calling ClusterEventMap] renumbers +/// an EventMap so its leaves are consecutive. +/// It puts the number of leaves in *num_leaves. If later you need the mapping of +/// the leaves, modify the function and add a new argument. +EventMap *RenumberEventMap(const EventMap &e_in, int32 *num_leaves); + +/// This function remaps the event-map leaves using this mapping, +/// indexed by the number at leaf. +EventMap *MapEventMapLeaves(const EventMap &e_in, + const std::vector<int32> &mapping); + + + +/// ShareEventMapLeaves performs a quite specific function that allows us to +/// generate trees where, for a certain list of phones, and for all states in +/// the phone, all the pdf's are shared. +/// Each element of "values" contains a list of phones (may be just one phone), +/// all states of which we want shared together). Typically at input, "key" will +/// equal P, the central-phone position, and "values" will contain just one +/// list containing the silence phone. +/// This function renumbers the event map leaves after doing the sharing, to +/// make the event-map leaves contiguous. +EventMap *ShareEventMapLeaves(const EventMap &e_in, EventKeyType key, + std::vector<std::vector<EventValueType> > &values, + int32 *num_leaves); + + + +/// Does a decision-tree split at the leaves of an EventMap. +/// @param orig [in] The EventMap whose leaves we want to split. [may be either a trivial or a +/// non-trivial one]. +/// @param stats [in] The statistics for splitting the tree; if you do not want a particular +/// subset of leaves to be split, make sure the stats corresponding to those leaves +/// are not present in "stats". +/// @param qcfg [in] Configuration class that contains initial questions (e.g. sets of phones) +/// for each key and says whether to refine these questions during tree building. +/// @param thresh [in] A log-likelihood threshold (e.g. 300) that can be used to +/// limit the number of leaves; you can use zero and set max_leaves instead. +/// @param max_leaves [in] Will stop leaves being split after they reach this number. +/// @param num_leaves [in,out] A pointer used to allocate leaves; always corresponds to the +/// current number of leaves (is incremented when this is increased). +/// @param objf_impr_out [out] If non-NULL, will be set to the objective improvement due to splitting +/// (not normalized by the number of frames). +/// @param smallest_split_change_out If non-NULL, will be set to the smallest objective-function +/// improvement that we got from splitting any leaf; useful to provide a threshold +/// for ClusterEventMap. +/// @return The EventMap after splitting is returned; pointer is owned by caller. +EventMap *SplitDecisionTree(const EventMap &orig, + const BuildTreeStatsType &stats, + Questions &qcfg, + BaseFloat thresh, + int32 max_leaves, // max_leaves<=0 -> no maximum. + int32 *num_leaves, + BaseFloat *objf_impr_out, + BaseFloat *smallest_split_change_out); + +/// CreateRandomQuestions will initialize a Questions randomly, in a reasonable +/// way [for testing purposes, or when hand-designed questions are not available]. +/// e.g. num_quest = 5 might be a reasonable value if num_iters > 0, or num_quest = 20 otherwise. +void CreateRandomQuestions(const BuildTreeStatsType &stats, int32 num_quest, Questions *cfg_out); + + +/// FindBestSplitForKey is a function used in DoDecisionTreeSplit. +/// It finds the best split for this key, given these stats. +/// It will return 0 if the key was not always defined for the stats. +BaseFloat FindBestSplitForKey(const BuildTreeStatsType &stats, + const Questions &qcfg, + EventKeyType key, + std::vector<EventValueType> *yes_set); + + +/// GetStubMap is used in tree-building functions to get the initial +/// to-states map, before the decision-tree-building process. It creates +/// a simple map that splits on groups of phones. For the set of phones in +/// phone_sets[i] it creates either: if share_roots[i] == true, a single +/// leaf node, or if share_roots[i] == false, separate root nodes for +/// each HMM-position (it goes up to the highest position for any +/// phone in the set, although it will warn if you share roots between +/// phones with different numbers of states, which is a weird thing to +/// do but should still work. If any phone is present +/// in "phone_sets" but "phone2num_pdf_classes" does not map it to a length, +/// it is an error. Note that the behaviour of the resulting map is +/// undefined for phones not present in "phone_sets". +/// At entry, this function should be called with (*num_leaves == 0). +/// It will number the leaves starting from (*num_leaves). + +EventMap *GetStubMap(int32 P, + const std::vector<std::vector<int32> > &phone_sets, + const std::vector<int32> &phone2num_pdf_classes, + const std::vector<bool> &share_roots, // indexed by index into phone_sets. + int32 *num_leaves); +/// Note: GetStubMap with P = 0 can be used to get a standard monophone system. + +/// @} + + +}// end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/tree/build-tree.h b/kaldi_io/src/kaldi/tree/build-tree.h new file mode 100644 index 0000000..37bb108 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/build-tree.h @@ -0,0 +1,250 @@ +// tree/build-tree.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_BUILD_TREE_H_ +#define KALDI_TREE_BUILD_TREE_H_ + +// The file build-tree.h contains outer-level routines used in tree-building +// and related tasks, that are directly called by the command-line tools. + +#include "tree/build-tree-utils.h" +#include "tree/context-dep.h" +namespace kaldi { + +/// \defgroup tree_group_top Top-level tree-building functions +/// See \ref tree_internals for context. +/// \ingroup tree_group +/// @{ + +// Note, in tree_group_top we also include AccumulateTreeStats, in +// ../hmm/tree-accu.h (it has some extra dependencies so we didn't +// want to include it here). + +/** + * BuildTree is the normal way to build a set of decision trees. + * The sets "phone_sets" dictate how we set up the roots of the decision trees. + * each set of phones phone_sets[i] has shared decision-tree roots, and if + * the corresponding variable share_roots[i] is true, the root will be shared + * for the different HMM-positions in the phone. All phones in "phone_sets" + * should be in the stats (use FixUnseenPhones to ensure this). + * if for any i, do_split[i] is false, we will not do any tree splitting for + * phones in that set. + * @param qopts [in] Questions options class, contains questions for each key + * (e.g. each phone position) + * @param phone_sets [in] Each element of phone_sets is a set of phones whose + * roots are shared together (prior to decision-tree splitting). + * @param phone2num_pdf_classes [in] A map from phones to the number of + * \ref pdf_class "pdf-classes" + * in the phone (this info is derived from the HmmTopology object) + * @param share_roots [in] A vector the same size as phone_sets; says for each + * phone set whether the root should be shared among all the + * pdf-classes or not. + * @param do_split [in] A vector the same size as phone_sets; says for each + * phone set whether decision-tree splitting should be done + * (generally true for non-silence phones). + * @param stats [in] The statistics used in tree-building. + * @param thresh [in] Threshold used in decision-tree splitting (e.g. 1000), + * or you may use 0 in which case max_leaves becomes the + * constraint. + * @param max_leaves [in] Maximum number of leaves it will create; set this + * to a large number if you want to just specify "thresh". + * @param cluster_thresh [in] Threshold for clustering leaves after decision-tree + * splitting (only within each phone-set); leaves will be combined + * if log-likelihood change is less than this. A value about equal + * to "thresh" is suitable + * if thresh != 0; otherwise, zero will mean no clustering is done, + * or a negative value (e.g. -1) sets it to the smallest likelihood + * change seen during the splitting algorithm; this typically causes + * about a 20% reduction in the number of leaves. + + * @param P [in] The central position of the phone context window, e.g. 1 for a + * triphone system. + * @return Returns a pointer to an EventMap object that is the tree. + +*/ + +EventMap *BuildTree(Questions &qopts, + const std::vector<std::vector<int32> > &phone_sets, + const std::vector<int32> &phone2num_pdf_classes, + const std::vector<bool> &share_roots, + const std::vector<bool> &do_split, + const BuildTreeStatsType &stats, + BaseFloat thresh, + int32 max_leaves, + BaseFloat cluster_thresh, // typically == thresh. If negative, use smallest split. + int32 P); + + +/** + * + * BuildTreeTwoLevel builds a two-level tree, useful for example in building tied mixture + * systems with multiple codebooks. It first builds a small tree by splitting to + * "max_leaves_first". It then splits at the leaves of "max_leaves_first" (think of this + * as creating multiple little trees at the leaves of the first tree), until the total + * number of leaves reaches "max_leaves_second". It then outputs the second tree, along + * with a mapping from the leaf-ids of the second tree to the leaf-ids of the first tree. + * Note that the interface is similar to BuildTree, and in fact it calls BuildTree + * internally. + * + * The sets "phone_sets" dictate how we set up the roots of the decision trees. + * each set of phones phone_sets[i] has shared decision-tree roots, and if + * the corresponding variable share_roots[i] is true, the root will be shared + * for the different HMM-positions in the phone. All phones in "phone_sets" + * should be in the stats (use FixUnseenPhones to ensure this). + * if for any i, do_split[i] is false, we will not do any tree splitting for + * phones in that set. + * + * @param qopts [in] Questions options class, contains questions for each key + * (e.g. each phone position) + * @param phone_sets [in] Each element of phone_sets is a set of phones whose + * roots are shared together (prior to decision-tree splitting). + * @param phone2num_pdf_classes [in] A map from phones to the number of + * \ref pdf_class "pdf-classes" + * in the phone (this info is derived from the HmmTopology object) + * @param share_roots [in] A vector the same size as phone_sets; says for each + * phone set whether the root should be shared among all the + * pdf-classes or not. + * @param do_split [in] A vector the same size as phone_sets; says for each + * phone set whether decision-tree splitting should be done + * (generally true for non-silence phones). + * @param stats [in] The statistics used in tree-building. + * @param max_leaves_first [in] Maximum number of leaves it will create in first + * level of decision tree. + * @param max_leaves_second [in] Maximum number of leaves it will create in second + * level of decision tree. Must be > max_leaves_first. + * @param cluster_leaves [in] Boolean value; if true, we post-cluster the leaves produced + * in the second level of decision-tree split; if false, we don't. + * The threshold for post-clustering is the log-like change of the last + * decision-tree split; this typically causes about a 20% reduction in + * the number of leaves. + * @param P [in] The central position of the phone context window, e.g. 1 for a + * triphone system. + * @param leaf_map [out] Will be set to be a mapping from the leaves of the + * "big" tree to the leaves of the "little" tree, which you can + * view as cluster centers. + * @return Returns a pointer to an EventMap object that is the (big) tree. + +*/ + +EventMap *BuildTreeTwoLevel(Questions &qopts, + const std::vector<std::vector<int32> > &phone_sets, + const std::vector<int32> &phone2num_pdf_classes, + const std::vector<bool> &share_roots, + const std::vector<bool> &do_split, + const BuildTreeStatsType &stats, + int32 max_leaves_first, + int32 max_leaves_second, + bool cluster_leaves, + int32 P, + std::vector<int32> *leaf_map); + + +/// GenRandStats generates random statistics of the form used by BuildTree. +/// It tries to do so in such a way that they mimic "real" stats. The event keys +/// and their corresponding values are: +/// - key == -1 == kPdfClass -> pdf-class, generally corresponds to +/// zero-based position in HMM (0, 1, 2 .. hmm_lengths[phone]-1) +/// - key == 0 -> phone-id of left-most context phone. +/// - key == 1 -> phone-id of one-from-left-most context phone. +/// - key == P-1 -> phone-id of central phone. +/// - key == N-1 -> phone-id of right-most context phone. +/// GenRandStats is useful only for testing but it serves to document the format of +/// stats used by BuildTreeDefault. +/// if is_ctx_dep[phone] is set to false, GenRandStats will not define the keys for +/// other than the P-1'th phone. + +/// @param dim [in] dimension of features. +/// @param num_stats [in] approximate number of separate phones-in-context wanted. +/// @param N [in] context-size (typically 3) +/// @param P [in] central-phone position in zero-based numbering (typically 1) +/// @param phone_ids [in] integer ids of phones +/// @param hmm_lengths [in] lengths of hmm for phone, indexed by phone. +/// @param is_ctx_dep [in] boolean array indexed by phone, saying whether each phone +/// is context dependent. +/// @param ensure_all_phones_covered [in] Boolean argument: if true, GenRandStats +/// ensures that every phone is seen at least once in the central position (P). +/// @param stats_out [out] The statistics that this routine outputs. + +void GenRandStats(int32 dim, int32 num_stats, int32 N, int32 P, + const std::vector<int32> &phone_ids, + const std::vector<int32> &hmm_lengths, + const std::vector<bool> &is_ctx_dep, + bool ensure_all_phones_covered, + BuildTreeStatsType *stats_out); + + +/// included here because it's used in some tree-building +/// calling code. Reads an OpenFst symbl table, +/// discards the symbols and outputs the integers +void ReadSymbolTableAsIntegers(std::string filename, + bool include_eps, + std::vector<int32> *syms); + + + +/** + * Outputs sets of phones that are reasonable for questions + * to ask in the tree-building algorithm. These are obtained by tree + * clustering of the phones; for each node in the tree, all the leaves + * accessible from that node form one of the sets of phones. + * @param stats [in] The statistics as used for normal tree-building. + * @param phone_sets_in [in] All the phones, pre-partitioned into sets. + * The output sets will be various unions of these sets. These sets + * will normally correspond to "real phones", in cases where the phones + * have stress and position markings. + * @param all_pdf_classes_in [in] All the \ref pdf_class "pdf-classes" + * that we consider for clustering. In the normal case this is the singleton + * set {1}, which means that we only consider the central hmm-position + * of the standard 3-state HMM, for clustering purposes. + * @param P [in] The central position in the phone context window; normally + * 1 for triphone system.s + * @param questions_out [out] The questions (sets of phones) are output to here. + **/ +void AutomaticallyObtainQuestions(BuildTreeStatsType &stats, + const std::vector<std::vector<int32> > &phone_sets_in, + const std::vector<int32> &all_pdf_classes_in, + int32 P, + std::vector<std::vector<int32> > *questions_out); + +/// This function clusters the phones (or some initially specified sets of phones) +/// into sets of phones, using a k-means algorithm. Useful, for example, in building +/// simple models for purposes of adaptation. + +void KMeansClusterPhones(BuildTreeStatsType &stats, + const std::vector<std::vector<int32> > &phone_sets_in, + const std::vector<int32> &all_pdf_classes_in, + int32 P, + int32 num_classes, + std::vector<std::vector<int32> > *sets_out); + +/// Reads the roots file (throws on error). Format is lines like: +/// "shared split 1 2 3 4", +/// "not-shared not-split 5", +/// and so on. The numbers are indexes of phones. +void ReadRootsFile(std::istream &is, + std::vector<std::vector<int32> > *phone_sets, + std::vector<bool> *is_shared_root, + std::vector<bool> *is_split_root); + + +/// @} + +}// end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/tree/cluster-utils.h b/kaldi_io/src/kaldi/tree/cluster-utils.h new file mode 100644 index 0000000..55583a2 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/cluster-utils.h @@ -0,0 +1,291 @@ +// tree/cluster-utils.h + +// Copyright 2012 Arnab Ghoshal +// Copyright 2009-2011 Microsoft Corporation; Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_CLUSTER_UTILS_H_ +#define KALDI_TREE_CLUSTER_UTILS_H_ + +#include <vector> +#include "matrix/matrix-lib.h" +#include "itf/clusterable-itf.h" + +namespace kaldi { + +/// \addtogroup clustering_group_simple +/// @{ + +/// Returns the total objective function after adding up all the +/// statistics in the vector (pointers may be NULL). +BaseFloat SumClusterableObjf(const std::vector<Clusterable*> &vec); + +/// Returns the total normalizer (usually count) of the cluster (pointers may be NULL). +BaseFloat SumClusterableNormalizer(const std::vector<Clusterable*> &vec); + +/// Sums stats (ptrs may be NULL). Returns NULL if no non-NULL stats present. +Clusterable *SumClusterable(const std::vector<Clusterable*> &vec); + +/** Fills in any (NULL) holes in "stats" vector, with empty stats, because + * certain algorithms require non-NULL stats. If "stats" nonempty, requires it + * to contain at least one non-NULL pointer that we can call Copy() on. + */ +void EnsureClusterableVectorNotNull(std::vector<Clusterable*> *stats); + + +/** Given stats and a vector "assignments" of the same size (that maps to + * cluster indices), sums the stats up into "clusters." It will add to any + * stats already present in "clusters" (although typically "clusters" will be + * empty when called), and it will extend with NULL pointers for any unseen + * indices. Call EnsureClusterableStatsNotNull afterwards if you want to ensure + * all non-NULL clusters. Pointer in "clusters" are owned by caller. Pointers in + * "stats" do not have to be non-NULL. + */ +void AddToClusters(const std::vector<Clusterable*> &stats, + const std::vector<int32> &assignments, + std::vector<Clusterable*> *clusters); + + +/// AddToClustersOptimized does the same as AddToClusters (it sums up the stats +/// within each cluster, except it uses the sum of all the stats ("total") to +/// optimize the computation for speed, if possible. This will generally only be +/// a significant speedup in the case where there are just two clusters, which +/// can happen in algorithms that are doing binary splits; the idea is that we +/// sum up all the stats in one cluster (the one with the fewest points in it), +/// and then subtract from the total. +void AddToClustersOptimized(const std::vector<Clusterable*> &stats, + const std::vector<int32> &assignments, + const Clusterable &total, + std::vector<Clusterable*> *clusters); + +/// @} end "addtogroup clustering_group_simple" + +/// \addtogroup clustering_group_algo +/// @{ + +// Note, in the algorithms below, it is assumed that the input "points" (which +// is std::vector<Clusterable*>) is all non-NULL. + +/** A bottom-up clustering algorithm. There are two parameters that control how + * many clusters we get: a "max_merge_thresh" which is a threshold for merging + * clusters, and a min_clust which puts a floor on the number of clusters we want. Set + * max_merge_thresh = large to use the min_clust only, or min_clust to 0 to use + * the max_merge_thresh only. + * + * The algorithm is: + * \code + * while (num-clusters > min_clust && smallest_merge_cost <= max_merge_thresh) + * merge the closest two clusters. + * \endcode + * + * @param points [in] Points to be clustered (may not contain NULL pointers) + * @param thresh [in] Threshold on cost change from merging clusters; clusters + * won't be merged if the cost is more than this + * @param min_clust [in] Minimum number of clusters desired; we'll stop merging + * after reaching this number. + * @param clusters_out [out] If non-NULL, will be set to a vector of size equal + * to the number of output clusters, containing the clustered + * statistics. Must be empty when called. + * @param assignments_out [out] If non-NULL, will be resized to the number of + * points, and each element is the index of the cluster that point + * was assigned to. + * @return Returns the total objf change relative to all clusters being separate, which is + * a negative. Note that this is not the same as what the other clustering algorithms return. + */ +BaseFloat ClusterBottomUp(const std::vector<Clusterable*> &points, + BaseFloat thresh, + int32 min_clust, + std::vector<Clusterable*> *clusters_out, + std::vector<int32> *assignments_out); + +/** This is a bottom-up clustering where the points are pre-clustered in a set + * of compartments, such that only points in the same compartment are clustered + * together. The compartment and pair of points with the smallest merge cost + * is selected and the points are clustered. The result stays in the same + * compartment. The code does not merge compartments, and hence assumes that + * the number of compartments is smaller than the 'min_clust' option. + * The clusters in "clusters_out" are newly allocated and owned by the caller. + */ +BaseFloat ClusterBottomUpCompartmentalized( + const std::vector< std::vector<Clusterable*> > &points, BaseFloat thresh, + int32 min_clust, std::vector< std::vector<Clusterable*> > *clusters_out, + std::vector< std::vector<int32> > *assignments_out); + + +struct RefineClustersOptions { + int32 num_iters; // must be >= 0. If zero, does nothing. + int32 top_n; // must be >= 2. + RefineClustersOptions() : num_iters(100), top_n(5) {} + RefineClustersOptions(int32 num_iters_in, int32 top_n_in) + : num_iters(num_iters_in), top_n(top_n_in) {} + // include Write and Read functions because this object gets written/read as + // part of the QuestionsForKeyOptions class. + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); +}; + +/** RefineClusters is mainly used internally by other clustering algorithms. + * + * It starts with a given assignment of points to clusters and + * keeps trying to improve it by moving points from cluster to cluster, up to + * a maximum number of iterations. + * + * "clusters" and "assignments" are both input and output variables, and so + * both MUST be non-NULL. + * + * "top_n" (>=2) is a pruning value: more is more exact, fewer is faster. The + * algorithm initially finds the "top_n" closest clusters to any given point, + * and from that point only consider move to those "top_n" clusters. Since + * RefineClusters is called multiple times from ClusterKMeans (for instance), + * this is not really a limitation. + */ +BaseFloat RefineClusters(const std::vector<Clusterable*> &points, + std::vector<Clusterable*> *clusters /*non-NULL*/, + std::vector<int32> *assignments /*non-NULL*/, + RefineClustersOptions cfg = RefineClustersOptions()); + +struct ClusterKMeansOptions { + RefineClustersOptions refine_cfg; + int32 num_iters; + int32 num_tries; // if >1, try whole procedure >once and pick best. + bool verbose; + ClusterKMeansOptions() + : refine_cfg(), num_iters(20), num_tries(2), verbose(true) {} +}; + +/** ClusterKMeans is a K-means-like clustering algorithm. It starts with + * pseudo-random initialization of points to clusters and uses RefineClusters + * to iteratively improve the cluster assignments. It does this for + * multiple iterations and picks the result with the best objective function. + * + * + * ClusterKMeans implicitly uses Rand(). It will not necessarily return + * the same value on different calls. Use sRand() if you want consistent + * results. + * The algorithm used in ClusterKMeans is a "k-means-like" algorithm that tries + * to be as efficient as possible. Firstly, since the algorithm it uses + * includes random initialization, it tries the whole thing cfg.num_tries times + * and picks the one with the best objective function. Each try, it does as + * follows: it randomly initializes points to clusters, and then for + * cfg.num_iters iterations it calls RefineClusters(). The options to + * RefineClusters() are given by cfg.refine_cfg. Calling RefineClusters once + * will always be at least as good as doing one iteration of reassigning points to + * clusters, but will generally be quite a bit better (without taking too + * much extra time). + * + * @param points [in] points to be clustered (must be all non-NULL). + * @param num_clust [in] number of clusters requested (it will always return exactly + * this many, or will fail if num_clust > points.size()). + * @param clusters_out [out] may be NULL; if non-NULL, should be empty when called. + * Will be set to a vector of statistics corresponding to the output clusters. + * @param assignments_out [out] may be NULL; if non-NULL, will be set to a vector of + * same size as "points", which says for each point which cluster + * it is assigned to. + * @param cfg [in] configuration class specifying options to the algorithm. + * @return Returns the objective function improvement versus everything being + * in the same cluster. + * + */ +BaseFloat ClusterKMeans(const std::vector<Clusterable*> &points, + int32 num_clust, // exact number of clusters + std::vector<Clusterable*> *clusters_out, // may be NULL + std::vector<int32> *assignments_out, // may be NULL + ClusterKMeansOptions cfg = ClusterKMeansOptions()); + +struct TreeClusterOptions { + ClusterKMeansOptions kmeans_cfg; + int32 branch_factor; + BaseFloat thresh; // Objf change: if >0, may be used to control number of leaves. + TreeClusterOptions() + : kmeans_cfg(), branch_factor(2), thresh(0) { + kmeans_cfg.verbose = false; + } +}; + +/** TreeCluster is a top-down clustering algorithm, using a binary tree (not + * necessarily balanced). Returns objf improvement versus having all points + * in one cluster. The algorithm is: + * - Initialize to 1 cluster (tree with 1 node). + * - Maintain, for each cluster, a "best-binary-split" (using ClusterKMeans + * to do so). Always split the highest scoring cluster, until we can do no + * more splits. + * + * @param points [in] Data points to be clustered + * @param max_clust [in] Maximum number of clusters (you will get exactly this number, + * if there are at least this many points, except if you set the + * cfg.thresh value nonzero, in which case that threshold may limit + * the number of clusters. + * @param clusters_out [out] If non-NULL, will be set to the a vector whose first + * (*num_leaves_out) elements are the leaf clusters, and whose + * subsequent elements are the nonleaf nodes in the tree, in + * topological order with the root node last. Must be empty vector + * when this function is called. + * @param assignments_out [out] If non-NULL, will be set to a vector to a vector the + * same size as "points", where assignments[i] is the leaf node index i + * to which the i'th point gets clustered. + * @param clust_assignments_out [out] If non-NULL, will be set to a vector the same size + * as clusters_out which says for each node (leaf or nonleaf), the + * index of its parent. For the root node (which is last), + * assignments_out[i] == i. For each i, assignments_out[i]>=i, i.e. + * any node's parent is higher numbered than itself. If you don't need + * this information, consider using instead the ClusterTopDown function. + * @param num_leaves_out [out] If non-NULL, will be set to the number of leaf nodes + * in the tree. + * @param cfg [in] Configuration object that controls clustering behavior. Most + * important value is "thresh", which provides an alternative mechanism + * [other than max_clust] to limit the number of leaves. + */ +BaseFloat TreeCluster(const std::vector<Clusterable*> &points, + int32 max_clust, // max number of leaf-level clusters. + std::vector<Clusterable*> *clusters_out, + std::vector<int32> *assignments_out, + std::vector<int32> *clust_assignments_out, + int32 *num_leaves_out, + TreeClusterOptions cfg = TreeClusterOptions()); + + +/** + * A clustering algorithm that internally uses TreeCluster, + * but does not give you the information about the structure of the tree. + * The "clusters_out" and "assignments_out" may be NULL if the outputs are not + * needed. + * + * @param points [in] points to be clustered (must be all non-NULL). + * @param max_clust [in] Maximum number of clusters (you will get exactly this number, + * if there are at least this many points, except if you set the + * cfg.thresh value nonzero, in which case that threshold may limit + * the number of clusters. + * @param clusters_out [out] may be NULL; if non-NULL, should be empty when called. + * Will be set to a vector of statistics corresponding to the output clusters. + * @param assignments_out [out] may be NULL; if non-NULL, will be set to a vector of + * same size as "points", which says for each point which cluster + * it is assigned to. + * @param cfg [in] Configuration object that controls clustering behavior. Most + * important value is "thresh", which provides an alternative mechanism + * [other than max_clust] to limit the number of leaves. +*/ +BaseFloat ClusterTopDown(const std::vector<Clusterable*> &points, + int32 max_clust, // max number of clusters. + std::vector<Clusterable*> *clusters_out, + std::vector<int32> *assignments_out, + TreeClusterOptions cfg = TreeClusterOptions()); + +/// @} end of "addtogroup clustering_group_algo" + +} // end namespace kaldi. + +#endif // KALDI_TREE_CLUSTER_UTILS_H_ diff --git a/kaldi_io/src/kaldi/tree/clusterable-classes.h b/kaldi_io/src/kaldi/tree/clusterable-classes.h new file mode 100644 index 0000000..817d0c6 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/clusterable-classes.h @@ -0,0 +1,158 @@ +// tree/clusterable-classes.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University +// 2014 Daniel Povey + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_CLUSTERABLE_CLASSES_H_ +#define KALDI_TREE_CLUSTERABLE_CLASSES_H_ 1 + +#include <string> +#include "itf/clusterable-itf.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { + +// Note: see sgmm/sgmm-clusterable.h for an SGMM-based clusterable +// class. We didn't include it here, to avoid adding an extra +// dependency to this directory. + +/// \addtogroup clustering_group +/// @{ + +/// ScalarClusterable clusters scalars with x^2 loss. +class ScalarClusterable: public Clusterable { + public: + ScalarClusterable(): x_(0), x2_(0), count_(0) {} + explicit ScalarClusterable(BaseFloat x): x_(x), x2_(x*x), count_(1) {} + virtual std::string Type() const { return "scalar"; } + virtual BaseFloat Objf() const; + virtual void SetZero() { count_ = x_ = x2_ = 0.0; } + virtual void Add(const Clusterable &other_in); + virtual void Sub(const Clusterable &other_in); + virtual Clusterable* Copy() const; + virtual BaseFloat Normalizer() const { + return static_cast<BaseFloat>(count_); + } + + // Function to write data to stream. Will organize input later [more complex] + virtual void Write(std::ostream &os, bool binary) const; + virtual Clusterable* ReadNew(std::istream &is, bool binary) const; + + std::string Info(); // For debugging. + BaseFloat Mean() { return (count_ != 0 ? x_/count_ : 0.0); } + private: + BaseFloat x_; + BaseFloat x2_; + BaseFloat count_; + + void Read(std::istream &is, bool binary); +}; + + +/// GaussClusterable wraps Gaussian statistics in a form accessible +/// to generic clustering algorithms. +class GaussClusterable: public Clusterable { + public: + GaussClusterable(): count_(0.0), var_floor_(0.0) {} + GaussClusterable(int32 dim, BaseFloat var_floor): + count_(0.0), stats_(2, dim), var_floor_(var_floor) {} + + GaussClusterable(const Vector<BaseFloat> &x_stats, + const Vector<BaseFloat> &x2_stats, + BaseFloat var_floor, BaseFloat count); + + virtual std::string Type() const { return "gauss"; } + void AddStats(const VectorBase<BaseFloat> &vec, BaseFloat weight = 1.0); + virtual BaseFloat Objf() const; + virtual void SetZero(); + virtual void Add(const Clusterable &other_in); + virtual void Sub(const Clusterable &other_in); + virtual BaseFloat Normalizer() const { return count_; } + virtual Clusterable *Copy() const; + virtual void Scale(BaseFloat f); + virtual void Write(std::ostream &os, bool binary) const; + virtual Clusterable *ReadNew(std::istream &is, bool binary) const; + virtual ~GaussClusterable() {} + + BaseFloat count() const { return count_; } + // The next two functions are not const-correct, because of SubVector. + SubVector<double> x_stats() const { return stats_.Row(0); } + SubVector<double> x2_stats() const { return stats_.Row(1); } + private: + double count_; + Matrix<double> stats_; // two rows: sum, then sum-squared. + double var_floor_; // should be common for all objects created. + + void Read(std::istream &is, bool binary); +}; + +/// @} end of "addtogroup clustering_group" + +inline void GaussClusterable::SetZero() { + count_ = 0; + stats_.SetZero(); +} + +inline GaussClusterable::GaussClusterable(const Vector<BaseFloat> &x_stats, + const Vector<BaseFloat> &x2_stats, + BaseFloat var_floor, BaseFloat count): + count_(count), stats_(2, x_stats.Dim()), var_floor_(var_floor) { + stats_.Row(0).CopyFromVec(x_stats); + stats_.Row(1).CopyFromVec(x2_stats); +} + + +/// VectorClusterable wraps vectors in a form accessible to generic clustering +/// algorithms. Each vector is associated with a weight; these could be 1.0. +/// The objective function (to be maximized) is the negated sum of squared +/// distances from the cluster center to each vector, times that vector's +/// weight. +class VectorClusterable: public Clusterable { + public: + VectorClusterable(): weight_(0.0), sumsq_(0.0) {} + + VectorClusterable(const Vector<BaseFloat> &vector, + BaseFloat weight); + + virtual std::string Type() const { return "vector"; } + // Objf is negated weighted sum of squared distances. + virtual BaseFloat Objf() const; + virtual void SetZero() { weight_ = 0.0; sumsq_ = 0.0; stats_.Set(0.0); } + virtual void Add(const Clusterable &other_in); + virtual void Sub(const Clusterable &other_in); + virtual BaseFloat Normalizer() const { return weight_; } + virtual Clusterable *Copy() const; + virtual void Scale(BaseFloat f); + virtual void Write(std::ostream &os, bool binary) const; + virtual Clusterable *ReadNew(std::istream &is, bool binary) const; + virtual ~VectorClusterable() {} + + private: + double weight_; // sum of weights of the source vectors. Never negative. + Vector<double> stats_; // Equals the weighted sum of the source vectors. + double sumsq_; // Equals the sum over all sources, of weight_ * vec.vec, + // where vec = stats_ / weight_. Used in computing + // the objective function. + void Read(std::istream &is, bool binary); +}; + + + +} // end namespace kaldi. + +#endif // KALDI_TREE_CLUSTERABLE_CLASSES_H_ diff --git a/kaldi_io/src/kaldi/tree/context-dep.h b/kaldi_io/src/kaldi/tree/context-dep.h new file mode 100644 index 0000000..307fcd4 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/context-dep.h @@ -0,0 +1,166 @@ +// tree/context-dep.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_CONTEXT_DEP_H_ +#define KALDI_TREE_CONTEXT_DEP_H_ + +#include "itf/context-dep-itf.h" +#include "tree/event-map.h" +#include "matrix/matrix-lib.h" +#include "tree/cluster-utils.h" + +/* + This header provides the declarations for the class ContextDependency, which inherits + from the interface class "ContextDependencyInterface" in itf/context-dep-itf.h. + This is basically a wrapper around an EventMap. The EventMap + (tree/event-map.h) declares most of the internals of the class, and the building routines are + in build-tree.h which uses build-tree-utils.h, which uses cluster-utils.h . */ + + +namespace kaldi { + +static const EventKeyType kPdfClass = -1; // The "name" to which we assign the +// pdf-class (generally corresponds ot position in the HMM, zero-based); +// must not be used for any other event. I.e. the value corresponding to +// this key is the pdf-class (see hmm-topology.h for explanation of what this is). + + +/* ContextDependency is quite a generic decision tree. + + It does not actually do very much-- all the magic is in the EventMap object. + All this class does is to encode the phone context as a sequence of events, and + pass this to the EventMap object to turn into what it will interpret as a + vector of pdfs. + + Different versions of the ContextDependency class that are written in the future may + have slightly different interfaces and pass more stuff in as events, to the + EventMap object. + + In order to separate the process of training decision trees from the process + of actually using them, we do not put any training code into the ContextDependency class. + */ +class ContextDependency: public ContextDependencyInterface { + public: + virtual int32 ContextWidth() const { return N_; } + virtual int32 CentralPosition() const { return P_; } + + + /// returns success or failure; outputs pdf to pdf_id + virtual bool Compute(const std::vector<int32> &phoneseq, + int32 pdf_class, int32 *pdf_id) const; + + virtual int32 NumPdfs() const { + // this routine could be simplified to return to_pdf_->MaxResult()+1. we're a + // bit more paranoid than that. + if (!to_pdf_) return 0; + EventAnswerType max_result = to_pdf_->MaxResult(); + if (max_result < 0 ) return 0; + else return (int32) max_result+1; + } + virtual ContextDependencyInterface *Copy() const { + return new ContextDependency(N_, P_, to_pdf_->Copy()); + } + + /// Read context-dependency object from disk; throws on error + void Read (std::istream &is, bool binary); + + // Constructor with no arguments; will normally be called + // prior to Read() + ContextDependency(): N_(0), P_(0), to_pdf_(NULL) { } + + // Constructor takes ownership of pointers. + ContextDependency(int32 N, int32 P, + EventMap *to_pdf): + N_(N), P_(P), to_pdf_(to_pdf) { } + void Write (std::ostream &os, bool binary) const; + + ~ContextDependency() { if (to_pdf_ != NULL) delete to_pdf_; } + + const EventMap &ToPdfMap() const { return *to_pdf_; } + + /// GetPdfInfo returns a vector indexed by pdf-id, saying for each pdf which + /// pairs of (phone, pdf-class) it can correspond to. (Usually just one). + /// c.f. hmm/hmm-topology.h for meaning of pdf-class. + + void GetPdfInfo(const std::vector<int32> &phones, // list of phones + const std::vector<int32> &num_pdf_classes, // indexed by phone, + std::vector<std::vector<std::pair<int32, int32> > > *pdf_info) + const; + + private: + int32 N_; // + int32 P_; + EventMap *to_pdf_; // owned here. + + KALDI_DISALLOW_COPY_AND_ASSIGN(ContextDependency); +}; + +/// GenRandContextDependency is mainly of use for debugging. Phones must be sorted and uniq +/// on input. +/// @param phones [in] A vector of phone id's [must be sorted and uniq]. +/// @param ensure_all_covered [in] boolean argument; if true, GenRandContextDependency +/// generates a context-dependency object that "works" for all phones [no gaps]. +/// @param num_pdf_classes [out] outputs a vector indexed by phone, of the number +/// of pdf classes (e.g. states) for that phone. +/// @return Returns the a context dependency object. +ContextDependency *GenRandContextDependency(const std::vector<int32> &phones, + bool ensure_all_covered, + std::vector<int32> *num_pdf_classes); + +/// GenRandContextDependencyLarge is like GenRandContextDependency but generates a larger tree +/// with specified N and P for use in "one-time" larger-scale tests. +ContextDependency *GenRandContextDependencyLarge(const std::vector<int32> &phones, + int N, int P, + bool ensure_all_covered, + std::vector<int32> *num_pdf_classes); + +// MonophoneContextDependency() returns a new ContextDependency object that +// corresponds to a monophone system. +// The map phone2num_pdf_classes maps from the phone id to the number of +// pdf-classes we have for that phone (e.g. 3, so the pdf-classes would be +// 0, 1, 2). + +ContextDependency* +MonophoneContextDependency(const std::vector<int32> phones, + const std::vector<int32> phone2num_pdf_classes); + +// MonophoneContextDependencyShared is as MonophoneContextDependency but lets +// you define classes of phones which share pdfs (e.g. different stress-markers of a single +// phone.) Each element of phone_classes is a set of phones that are in that class. +ContextDependency* +MonophoneContextDependencyShared(const std::vector<std::vector<int32> > phone_classes, + const std::vector<int32> phone2num_pdf_classes); + + +// Important note: +// Statistics for training decision trees will be of type: +// std::vector<std::pair<EventType, Clusterable*> > +// We don't make this a typedef as it doesn't add clarity. +// they will be sorted and unique on the EventType member, which +// itself is sorted and unique on the name (see event-map.h). + +// See build-tree.h for functions relating to actually building the decision trees. + + + + +} // namespace Kaldi + + +#endif diff --git a/kaldi_io/src/kaldi/tree/event-map.h b/kaldi_io/src/kaldi/tree/event-map.h new file mode 100644 index 0000000..07fcc2b --- /dev/null +++ b/kaldi_io/src/kaldi/tree/event-map.h @@ -0,0 +1,365 @@ +// tree/event-map.h + +// Copyright 2009-2011 Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_EVENT_MAP_H_ +#define KALDI_TREE_EVENT_MAP_H_ + +#include <vector> +#include <map> +#include <algorithm> +#include "base/kaldi-common.h" +#include "util/stl-utils.h" +#include "util/const-integer-set.h" + +namespace kaldi { + +/// \defgroup event_map_group Event maps +/// \ingroup tree_group +/// See \ref tree_internals for overview, and specifically \ref treei_event_map. + + +// Note RE negative values: some of this code will not work if things of type +// EventValueType are negative. In particular, TableEventMap can't be used if +// things of EventValueType are negative, and additionally TableEventMap won't +// be efficient if things of EventValueType take on extremely large values. The +// EventKeyType can be negative though. + +/// Things of type EventKeyType can take any value. The code does not assume they are contiguous. +/// So values like -1, 1000000 and the like are acceptable. +typedef int32 EventKeyType; + +/// Given current code, things of type EventValueType should generally be nonnegative and in a +/// reasonably small range (e.g. not one million), as we sometimes construct vectors of the size: +/// [largest value we saw for this key]. This deficiency may be fixed in future [would require +/// modifying TableEventMap] +typedef int32 EventValueType; + +/// As far as the event-map code itself is concerned, things of type EventAnswerType may take +/// any value except kNoAnswer (== -1). However, some specific uses of EventMap (e.g. in +/// build-tree-utils.h) assume these quantities are nonnegative. +typedef int32 EventAnswerType; + +typedef std::vector<std::pair<EventKeyType, EventValueType> > EventType; +// It is required to be sorted and have unique keys-- i.e. functions assume this when called +// with this type. + +inline std::pair<EventKeyType, EventValueType> MakeEventPair (EventKeyType k, EventValueType v) { + return std::pair<EventKeyType, EventValueType>(k, v); +} + +void WriteEventType(std::ostream &os, bool binary, const EventType &vec); +void ReadEventType(std::istream &is, bool binary, EventType *vec); + +std::string EventTypeToString(const EventType &evec); // so we can print events out in error messages. + +struct EventMapVectorHash { // Hashing object for EventMapVector. Works for both pointers and references. + // Not used in event-map.{h, cc} + size_t operator () (const EventType &vec); + size_t operator () (const EventType *ptr) { return (*this)(*ptr); } +}; +struct EventMapVectorEqual { // Equality object for EventType pointers-- test equality of underlying vector. + // Not used in event-map.{h, cc} + size_t operator () (const EventType *p1, const EventType *p2) { return (*p1 == *p2); } +}; + + +/// A class that is capable of representing a generic mapping from +/// EventType (which is a vector of (key, value) pairs) to +/// EventAnswerType which is just an integer. See \ref tree_internals +/// for overview. +class EventMap { + public: + static void Check(const EventType &event); // will crash if not sorted and unique on key. + static bool Lookup(const EventType &event, EventKeyType key, EventValueType *ans); + + // Maps events to the answer type. input must be sorted. + virtual bool Map(const EventType &event, EventAnswerType *ans) const = 0; + + // MultiMap maps a partially specified set of events to the set of answers it might + // map to. It appends these to "ans". "ans" is + // **not guaranteed unique at output** if the + // tree contains duplicate answers at leaves -- you should sort & uniq afterwards. + // e.g.: SortAndUniq(ans). + virtual void MultiMap(const EventType &event, std::vector<EventAnswerType> *ans) const = 0; + + // GetChildren() returns the EventMaps that are immediate children of this + // EventMap (if they exist), by putting them in *out. Useful for + // determining the structure of the event map. + virtual void GetChildren(std::vector<EventMap*> *out) const = 0; + + // This Copy() does a deep copy of the event map. + // If new_leaves is nonempty when it reaches a leaf with value l s.t. new_leaves[l] != NULL, + // it replaces it with a copy of that EventMap. This makes it possible to extend and modify + // It's the way we do splits of trees, and clustering of trees. Think about this carefully, because + // the EventMap structure does not support modification of an existing tree. Do not be tempted + // to do this differently, because other kinds of mechanisms would get very messy and unextensible. + // Copy() is the only mechanism to modify a tree. It's similar to a kind of function composition. + // Copy() does not take ownership of the pointers in new_leaves (it uses the Copy() function of those + // EventMaps). + virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const = 0; + + EventMap *Copy() const { std::vector<EventMap*> new_leaves; return Copy(new_leaves); } + + // The function MapValues() is intended to be used to map phone-sets between + // different integer representations. For all the keys in the set + // "keys_to_map", it will map the corresponding values using the map + // "value_map". Note: these values are the values in the key->value pairs of + // the EventMap, which really correspond to phones in the usual case; they are + // not the "answers" of the EventMap which correspond to clustered states. In + // case multiple values are mapped to the same value, it will try to deal with + // it gracefully where it can, but will crash if, for example, this would + // cause problems with the TableEventMap. It will also crash if any values + // used for keys in "keys_to_map" are not mapped by "value_map". This + // function is not currently used. + virtual EventMap *MapValues( + const unordered_set<EventKeyType> &keys_to_map, + const unordered_map<EventValueType,EventValueType> &value_map) const = 0; + + // The function Prune() is like Copy(), except it removes parts of the tree + // that return only -1 (it will return NULL if this EventMap returns only -1). + // This is a mechanism to remove parts of the tree-- you would first use the + // Copy() function with a vector of EventMap*, and for the parts you don't + // want, you'd put a ConstantEventMap with -1; you'd then call + // Prune() on the result. This function is not currently used. + virtual EventMap *Prune() const = 0; + + virtual EventAnswerType MaxResult() const { // child classes may override this for efficiency; here is basic version. + // returns -1 if nothing found. + std::vector<EventAnswerType> tmp; EventType empty_event; + MultiMap(empty_event, &tmp); + if (tmp.empty()) { + KALDI_WARN << "EventMap::MaxResult(), empty result"; + return std::numeric_limits<EventAnswerType>::min(); + } + else { return * std::max_element(tmp.begin(), tmp.end()); } + } + + /// Write to stream. + virtual void Write(std::ostream &os, bool binary) = 0; + + virtual ~EventMap() {} + + /// a Write function that takes care of NULL pointers. + static void Write(std::ostream &os, bool binary, EventMap *emap); + /// a Read function that reads an arbitrary EventMap; also + /// works for NULL pointers. + static EventMap *Read(std::istream &is, bool binary); +}; + + +class ConstantEventMap: public EventMap { + public: + virtual bool Map(const EventType &event, EventAnswerType *ans) const { + *ans = answer_; + return true; + } + + virtual void MultiMap(const EventType &, + std::vector<EventAnswerType> *ans) const { + ans->push_back(answer_); + } + + virtual void GetChildren(std::vector<EventMap*> *out) const { out->clear(); } + + virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const { + if (answer_ < 0 || answer_ >= (EventAnswerType)new_leaves.size() || + new_leaves[answer_] == NULL) + return new ConstantEventMap(answer_); + else return new_leaves[answer_]->Copy(); + } + + virtual EventMap *MapValues( + const unordered_set<EventKeyType> &keys_to_map, + const unordered_map<EventValueType,EventValueType> &value_map) const { + return new ConstantEventMap(answer_); + } + + virtual EventMap *Prune() const { + return (answer_ == -1 ? NULL : new ConstantEventMap(answer_)); + } + + explicit ConstantEventMap(EventAnswerType answer): answer_(answer) { } + + virtual void Write(std::ostream &os, bool binary); + static ConstantEventMap *Read(std::istream &is, bool binary); + private: + EventAnswerType answer_; + KALDI_DISALLOW_COPY_AND_ASSIGN(ConstantEventMap); +}; + +class TableEventMap: public EventMap { + public: + + virtual bool Map(const EventType &event, EventAnswerType *ans) const { + EventValueType tmp; *ans = -1; // means no answer + if (Lookup(event, key_, &tmp) && tmp >= 0 + && tmp < (EventValueType)table_.size() && table_[tmp] != NULL) { + return table_[tmp]->Map(event, ans); + } + return false; + } + + virtual void GetChildren(std::vector<EventMap*> *out) const { + out->clear(); + for (size_t i = 0; i<table_.size(); i++) + if (table_[i] != NULL) out->push_back(table_[i]); + } + + virtual void MultiMap(const EventType &event, std::vector<EventAnswerType> *ans) const { + EventValueType tmp; + if (Lookup(event, key_, &tmp)) { + if (tmp >= 0 && tmp < (EventValueType)table_.size() && table_[tmp] != NULL) + return table_[tmp]->MultiMap(event, ans); + // else no answers. + } else { // all answers are possible if no such key. + for (size_t i = 0;i < table_.size();i++) + if (table_[i] != NULL) table_[i]->MultiMap(event, ans); // append. + } + } + + virtual EventMap *Prune() const; + + virtual EventMap *MapValues( + const unordered_set<EventKeyType> &keys_to_map, + const unordered_map<EventValueType,EventValueType> &value_map) const; + + /// Takes ownership of pointers. + explicit TableEventMap(EventKeyType key, const std::vector<EventMap*> &table): key_(key), table_(table) {} + /// Takes ownership of pointers. + explicit TableEventMap(EventKeyType key, const std::map<EventValueType, EventMap*> &map_in); + /// This initializer creates a ConstantEventMap for each value in the map. + explicit TableEventMap(EventKeyType key, const std::map<EventValueType, EventAnswerType> &map_in); + + virtual void Write(std::ostream &os, bool binary); + static TableEventMap *Read(std::istream &is, bool binary); + + virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const { + std::vector<EventMap*> new_table_(table_.size(), NULL); + for (size_t i = 0;i<table_.size();i++) if (table_[i]) new_table_[i]=table_[i]->Copy(new_leaves); + return new TableEventMap(key_, new_table_); + } + virtual ~TableEventMap() { + DeletePointers(&table_); + } + private: + EventKeyType key_; + std::vector<EventMap*> table_; + KALDI_DISALLOW_COPY_AND_ASSIGN(TableEventMap); +}; + + + + +class SplitEventMap: public EventMap { // A decision tree [non-leaf] node. + public: + + virtual bool Map(const EventType &event, EventAnswerType *ans) const { + EventValueType value; + if (Lookup(event, key_, &value)) { + // if (std::binary_search(yes_set_.begin(), yes_set_.end(), value)) { + if (yes_set_.count(value)) { + return yes_->Map(event, ans); + } + return no_->Map(event, ans); + } + return false; + } + + virtual void MultiMap(const EventType &event, std::vector<EventAnswerType> *ans) const { + EventValueType tmp; + if (Lookup(event, key_, &tmp)) { + if (std::binary_search(yes_set_.begin(), yes_set_.end(), tmp)) + yes_->MultiMap(event, ans); + else + no_->MultiMap(event, ans); + } else { // both yes and no contribute. + yes_->MultiMap(event, ans); + no_->MultiMap(event, ans); + } + } + + virtual void GetChildren(std::vector<EventMap*> *out) const { + out->clear(); + out->push_back(yes_); + out->push_back(no_); + } + + virtual EventMap *Copy(const std::vector<EventMap*> &new_leaves) const { + return new SplitEventMap(key_, yes_set_, yes_->Copy(new_leaves), no_->Copy(new_leaves)); + } + + virtual void Write(std::ostream &os, bool binary); + static SplitEventMap *Read(std::istream &is, bool binary); + + virtual EventMap *Prune() const; + + virtual EventMap *MapValues( + const unordered_set<EventKeyType> &keys_to_map, + const unordered_map<EventValueType,EventValueType> &value_map) const; + + virtual ~SplitEventMap() { Destroy(); } + + /// This constructor takes ownership of the "yes" and "no" arguments. + SplitEventMap(EventKeyType key, const std::vector<EventValueType> &yes_set, + EventMap *yes, EventMap *no): key_(key), yes_set_(yes_set), yes_(yes), no_(no) { + KALDI_PARANOID_ASSERT(IsSorted(yes_set)); + KALDI_ASSERT(yes_ != NULL && no_ != NULL); + } + + + private: + /// This constructor used in the Copy() function. + SplitEventMap(EventKeyType key, const ConstIntegerSet<EventValueType> &yes_set, + EventMap *yes, EventMap *no): key_(key), yes_set_(yes_set), yes_(yes), no_(no) { + KALDI_ASSERT(yes_ != NULL && no_ != NULL); + } + void Destroy() { + delete yes_; delete no_; + } + EventKeyType key_; + // std::vector<EventValueType> yes_set_; + ConstIntegerSet<EventValueType> yes_set_; // more efficient Map function. + EventMap *yes_; // owned here. + EventMap *no_; // owned here. + SplitEventMap &operator = (const SplitEventMap &other); // Disallow. +}; + +/** + This function gets the tree structure of the EventMap "map" in a convenient form. + If "map" corresponds to a tree structure (not necessarily binary) with leaves + uniquely numbered from 0 to num_leaves-1, then the function will return true, + output "num_leaves", and set "parent" to a vector of size equal to the number of + nodes in the tree (nonleaf and leaf), where each index corresponds to a node + and the leaf indices correspond to the values returned by the EventMap from + that leaf; for an index i, parent[i] equals the parent of that node in the tree + structure, where parent[i] > i, except for the last (root) node where parent[i] == i. + If the EventMap does not have this structure (e.g. if multiple different leaf nodes share + the same number), then it will return false. +*/ + +bool GetTreeStructure(const EventMap &map, + int32 *num_leaves, + std::vector<int32> *parents); + + +/// @} end "addtogroup event_map_group" + +} + +#endif diff --git a/kaldi_io/src/kaldi/tree/tree-renderer.h b/kaldi_io/src/kaldi/tree/tree-renderer.h new file mode 100644 index 0000000..5e0b0d8 --- /dev/null +++ b/kaldi_io/src/kaldi/tree/tree-renderer.h @@ -0,0 +1,84 @@ +// tree/tree-renderer.h + +// Copyright 2012 Vassil Panayotov + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_TREE_TREE_RENDERER_H_ +#define KALDI_TREE_TREE_RENDERER_H_ + +#include "base/kaldi-common.h" +#include "tree/event-map.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "fst/fstlib.h" + +namespace kaldi { + +// Parses a decision tree file and outputs its description in GraphViz format +class TreeRenderer { + public: + const static int32 kEdgeWidth; // normal width of the edges and state contours + const static int32 kEdgeWidthQuery; // edge and state width when in query + const static std::string kEdgeColor; // normal color for states and edges + const static std::string kEdgeColorQuery; // edge and state color when in query + + TreeRenderer(std::istream &is, bool binary, std::ostream &os, + fst::SymbolTable &phone_syms, bool use_tooltips) + : phone_syms_(phone_syms), is_(is), out_(os), binary_(binary), + N_(-1), use_tooltips_(use_tooltips), next_id_(0) {} + + // Renders the tree and if the "query" parameter is not NULL + // a distinctly colored trace corresponding to the event. + void Render(const EventType *query); + + private: + // Looks-up the next token from the stream and invokes + // the appropriate render method to visualize it + void RenderSubTree(const EventType *query, int32 id); + + // Renders a leaf node (constant event map) + void RenderConstant(const EventType *query, int32 id); + + // Renders a split event map node and the edges to the nodes + // representing YES and NO sets + void RenderSplit(const EventType *query, int32 id); + + // Renders a table event map node and the edges to its (non-null) children + void RenderTable(const EventType *query, int32 id); + + // Makes a comma-separated string from the elements of a set of identifiers + // If the identifiers represent phones, their symbolic representations are used + std::string MakeEdgeLabel(const EventKeyType &key, + const ConstIntegerSet<EventValueType> &intset); + + // Writes the GraphViz representation of a non-leaf node to the out stream + // A question about a phone from the context window or about pdf-class + // is used as a label. + void RenderNonLeaf(int32 id, const EventKeyType &key, bool in_query); + + fst::SymbolTable &phone_syms_; // phone symbols to be used as edge labels + std::istream &is_; // the stream from which the tree is read + std::ostream &out_; // the GraphViz representation is written to this stream + bool binary_; // is the input stream binary? + int32 N_, P_; // context-width and central position + bool use_tooltips_; // use tooltips(useful in e.g. SVG) instead of labels + int32 next_id_; // the first unused GraphViz node ID +}; + +} // namespace kaldi + +#endif // KALDI_TREE_TREE_RENDERER_H_ diff --git a/kaldi_io/src/kaldi/util/basic-filebuf.h b/kaldi_io/src/kaldi/util/basic-filebuf.h new file mode 100644 index 0000000..cf2e079 --- /dev/null +++ b/kaldi_io/src/kaldi/util/basic-filebuf.h @@ -0,0 +1,1065 @@ +/////////////////////////////////////////////////////////////////////////////// +// This is a modified version of the std::basic_filebuf from libc++ +// (http://libcxx.llvm.org/). +// It allows one to create basic_filebuf from an existing FILE* handle or file +// descriptor. +// +// This file is dual licensed under the MIT and the University of Illinois Open +// Source License licenses. See LICENSE.TXT for details (included at the +// bottom). +/////////////////////////////////////////////////////////////////////////////// +#ifndef KALDI_UTIL_BASIC_FILEBUF_H_ +#define KALDI_UTIL_BASIC_FILEBUF_H_ + +/////////////////////////////////////////////////////////////////////////////// +#include <fstream> +#include <cstdio> +#include <cstring> + +/////////////////////////////////////////////////////////////////////////////// +namespace kaldi +{ + +/////////////////////////////////////////////////////////////////////////////// +template <typename CharT, typename Traits = std::char_traits<CharT> > +class basic_filebuf : public std::basic_streambuf<CharT, Traits> +{ +public: + typedef CharT char_type; + typedef Traits traits_type; + typedef typename traits_type::int_type int_type; + typedef typename traits_type::pos_type pos_type; + typedef typename traits_type::off_type off_type; + typedef typename traits_type::state_type state_type; + + basic_filebuf(); + basic_filebuf(basic_filebuf&& rhs); + virtual ~basic_filebuf(); + + basic_filebuf& operator=(basic_filebuf&& rhs); + void swap(basic_filebuf& rhs); + + bool is_open() const; + basic_filebuf* open(const char* s, std::ios_base::openmode mode); + basic_filebuf* open(const std::string& s, std::ios_base::openmode mode); + basic_filebuf* open(int fd, std::ios_base::openmode mode); + basic_filebuf* open(FILE* f, std::ios_base::openmode mode); + basic_filebuf* close(); + + FILE* file() { return this->_M_file; } + int fd() { return fileno(this->_M_file); } + +protected: + int_type underflow() override; + int_type pbackfail(int_type c = traits_type::eof()) override; + int_type overflow (int_type c = traits_type::eof()) override; + std::basic_streambuf<char_type, traits_type>* setbuf(char_type* s, std::streamsize n) override; + pos_type seekoff(off_type off, std::ios_base::seekdir way, + std::ios_base::openmode wch = std::ios_base::in | std::ios_base::out) override; + pos_type seekpos(pos_type sp, + std::ios_base::openmode wch = std::ios_base::in | std::ios_base::out) override; + int sync() override; + void imbue(const std::locale& loc) override; + +protected: + char* _M_extbuf; + const char* _M_extbufnext; + const char* _M_extbufend; + char _M_extbuf_min[8]; + size_t _M_ebs; + char_type* _M_intbuf; + size_t _M_ibs; + FILE* _M_file; + const std::codecvt<char_type, char, state_type>* _M_cv; + state_type _M_st; + state_type _M_st_last; + std::ios_base::openmode _M_om; + std::ios_base::openmode _M_cm; + bool _M_owns_eb; + bool _M_owns_ib; + bool _M_always_noconv; + + const char* _M_get_mode(std::ios_base::openmode mode); + bool _M_read_mode(); + void _M_write_mode(); +}; + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>::basic_filebuf() + : _M_extbuf(nullptr), + _M_extbufnext(nullptr), + _M_extbufend(nullptr), + _M_ebs(0), + _M_intbuf(nullptr), + _M_ibs(0), + _M_file(nullptr), + _M_cv(nullptr), + _M_st(), + _M_st_last(), + _M_om(std::ios_base::openmode(0)), + _M_cm(std::ios_base::openmode(0)), + _M_owns_eb(false), + _M_owns_ib(false), + _M_always_noconv(false) +{ + if (std::has_facet<std::codecvt<char_type, char, state_type> >(this->getloc())) + { + _M_cv = &std::use_facet<std::codecvt<char_type, char, state_type> >(this->getloc()); + _M_always_noconv = _M_cv->always_noconv(); + } + setbuf(0, 4096); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>::basic_filebuf(basic_filebuf&& rhs) + : std::basic_streambuf<CharT, Traits>(rhs) +{ + if (rhs._M_extbuf == rhs._M_extbuf_min) + { + _M_extbuf = _M_extbuf_min; + _M_extbufnext = _M_extbuf + (rhs._M_extbufnext - rhs._M_extbuf); + _M_extbufend = _M_extbuf + (rhs._M_extbufend - rhs._M_extbuf); + } + else + { + _M_extbuf = rhs._M_extbuf; + _M_extbufnext = rhs._M_extbufnext; + _M_extbufend = rhs._M_extbufend; + } + _M_ebs = rhs._M_ebs; + _M_intbuf = rhs._M_intbuf; + _M_ibs = rhs._M_ibs; + _M_file = rhs._M_file; + _M_cv = rhs._M_cv; + _M_st = rhs._M_st; + _M_st_last = rhs._M_st_last; + _M_om = rhs._M_om; + _M_cm = rhs._M_cm; + _M_owns_eb = rhs._M_owns_eb; + _M_owns_ib = rhs._M_owns_ib; + _M_always_noconv = rhs._M_always_noconv; + if (rhs.pbase()) + { + if (rhs.pbase() == rhs._M_intbuf) + this->setp(_M_intbuf, _M_intbuf + (rhs. epptr() - rhs.pbase())); + else + this->setp((char_type*)_M_extbuf, + (char_type*)_M_extbuf + (rhs. epptr() - rhs.pbase())); + this->pbump(rhs. pptr() - rhs.pbase()); + } + else if (rhs.eback()) + { + if (rhs.eback() == rhs._M_intbuf) + this->setg(_M_intbuf, _M_intbuf + (rhs.gptr() - rhs.eback()), + _M_intbuf + (rhs.egptr() - rhs.eback())); + else + this->setg((char_type*)_M_extbuf, + (char_type*)_M_extbuf + (rhs.gptr() - rhs.eback()), + (char_type*)_M_extbuf + (rhs.egptr() - rhs.eback())); + } + rhs._M_extbuf = nullptr; + rhs._M_extbufnext = nullptr; + rhs._M_extbufend = nullptr; + rhs._M_ebs = 0; + rhs._M_intbuf = nullptr; + rhs._M_ibs = 0; + rhs._M_file = nullptr; + rhs._M_st = state_type(); + rhs._M_st_last = state_type(); + rhs._M_om = std::ios_base::openmode(0); + rhs._M_cm = std::ios_base::openmode(0); + rhs._M_owns_eb = false; + rhs._M_owns_ib = false; + rhs.setg(0, 0, 0); + rhs.setp(0, 0); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +inline +basic_filebuf<CharT, Traits>& +basic_filebuf<CharT, Traits>::operator=(basic_filebuf&& rhs) +{ + close(); + swap(rhs); + return *this; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>::~basic_filebuf() +{ + // try + // { + // close(); + // } + // catch (...) + // { + // } + if (_M_owns_eb) + delete [] _M_extbuf; + if (_M_owns_ib) + delete [] _M_intbuf; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +void +basic_filebuf<CharT, Traits>::swap(basic_filebuf& rhs) +{ + std::basic_streambuf<char_type, traits_type>::swap(rhs); + if (_M_extbuf != _M_extbuf_min && rhs._M_extbuf != rhs._M_extbuf_min) + { + std::swap(_M_extbuf, rhs._M_extbuf); + std::swap(_M_extbufnext, rhs._M_extbufnext); + std::swap(_M_extbufend, rhs._M_extbufend); + } + else + { + ptrdiff_t ln = _M_extbufnext - _M_extbuf; + ptrdiff_t le = _M_extbufend - _M_extbuf; + ptrdiff_t rn = rhs._M_extbufnext - rhs._M_extbuf; + ptrdiff_t re = rhs._M_extbufend - rhs._M_extbuf; + if (_M_extbuf == _M_extbuf_min && rhs._M_extbuf != rhs._M_extbuf_min) + { + _M_extbuf = rhs._M_extbuf; + rhs._M_extbuf = rhs._M_extbuf_min; + } + else if (_M_extbuf != _M_extbuf_min && rhs._M_extbuf == rhs._M_extbuf_min) + { + rhs._M_extbuf = _M_extbuf; + _M_extbuf = _M_extbuf_min; + } + _M_extbufnext = _M_extbuf + rn; + _M_extbufend = _M_extbuf + re; + rhs._M_extbufnext = rhs._M_extbuf + ln; + rhs._M_extbufend = rhs._M_extbuf + le; + } + std::swap(_M_ebs, rhs._M_ebs); + std::swap(_M_intbuf, rhs._M_intbuf); + std::swap(_M_ibs, rhs._M_ibs); + std::swap(_M_file, rhs._M_file); + std::swap(_M_cv, rhs._M_cv); + std::swap(_M_st, rhs._M_st); + std::swap(_M_st_last, rhs._M_st_last); + std::swap(_M_om, rhs._M_om); + std::swap(_M_cm, rhs._M_cm); + std::swap(_M_owns_eb, rhs._M_owns_eb); + std::swap(_M_owns_ib, rhs._M_owns_ib); + std::swap(_M_always_noconv, rhs._M_always_noconv); + if (this->eback() == (char_type*)rhs._M_extbuf_min) + { + ptrdiff_t n = this->gptr() - this->eback(); + ptrdiff_t e = this->egptr() - this->eback(); + this->setg((char_type*)_M_extbuf_min, + (char_type*)_M_extbuf_min + n, + (char_type*)_M_extbuf_min + e); + } + else if (this->pbase() == (char_type*)rhs._M_extbuf_min) + { + ptrdiff_t n = this->pptr() - this->pbase(); + ptrdiff_t e = this->epptr() - this->pbase(); + this->setp((char_type*)_M_extbuf_min, + (char_type*)_M_extbuf_min + e); + this->pbump(n); + } + if (rhs.eback() == (char_type*)_M_extbuf_min) + { + ptrdiff_t n = rhs.gptr() - rhs.eback(); + ptrdiff_t e = rhs.egptr() - rhs.eback(); + rhs.setg((char_type*)rhs._M_extbuf_min, + (char_type*)rhs._M_extbuf_min + n, + (char_type*)rhs._M_extbuf_min + e); + } + else if (rhs.pbase() == (char_type*)_M_extbuf_min) + { + ptrdiff_t n = rhs.pptr() - rhs.pbase(); + ptrdiff_t e = rhs.epptr() - rhs.pbase(); + rhs.setp((char_type*)rhs._M_extbuf_min, + (char_type*)rhs._M_extbuf_min + e); + rhs.pbump(n); + } +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +inline +void +swap(basic_filebuf<CharT, Traits>& x, basic_filebuf<CharT, Traits>& y) +{ + x.swap(y); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +inline +bool +basic_filebuf<CharT, Traits>::is_open() const +{ + return _M_file != nullptr; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +const char* basic_filebuf<CharT, Traits>::_M_get_mode(std::ios_base::openmode mode) +{ + switch ((mode & ~std::ios_base::ate) | 0) + { + case std::ios_base::out: + case std::ios_base::out | std::ios_base::trunc: + return "w"; + case std::ios_base::out | std::ios_base::app: + case std::ios_base::app: + return "a"; + break; + case std::ios_base::in: + return "r"; + case std::ios_base::in | std::ios_base::out: + return "r+"; + case std::ios_base::in | std::ios_base::out | std::ios_base::trunc: + return "w+"; + case std::ios_base::in | std::ios_base::out | std::ios_base::app: + case std::ios_base::in | std::ios_base::app: + return "a+"; + case std::ios_base::out | std::ios_base::binary: + case std::ios_base::out | std::ios_base::trunc | std::ios_base::binary: + return "wb"; + case std::ios_base::out | std::ios_base::app | std::ios_base::binary: + case std::ios_base::app | std::ios_base::binary: + return "ab"; + case std::ios_base::in | std::ios_base::binary: + return "rb"; + case std::ios_base::in | std::ios_base::out | std::ios_base::binary: + return "r+b"; + case std::ios_base::in | std::ios_base::out | std::ios_base::trunc | std::ios_base::binary: + return "w+b"; + case std::ios_base::in | std::ios_base::out | std::ios_base::app | std::ios_base::binary: + case std::ios_base::in | std::ios_base::app | std::ios_base::binary: + return "a+b"; + default: + return nullptr; + } +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::open(const char* s, std::ios_base::openmode mode) +{ + basic_filebuf<CharT, Traits>* rt = nullptr; + if (_M_file == nullptr) + { + const char* md= _M_get_mode(mode); + if (md) + { + _M_file = fopen(s, md); + if (_M_file) + { + rt = this; + _M_om = mode; + if (mode & std::ios_base::ate) + { + if (fseek(_M_file, 0, SEEK_END)) + { + fclose(_M_file); + _M_file = nullptr; + rt = nullptr; + } + } + } + } + } + return rt; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +inline +basic_filebuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::open(const std::string& s, std::ios_base::openmode mode) +{ + return open(s.c_str(), mode); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::open(int fd, std::ios_base::openmode mode) +{ + const char* md= this->_M_get_mode(mode); + if (md) + { + this->_M_file= fdopen(fd, md); + this->_M_om = mode; + return this; + } + else return nullptr; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::open(FILE* f, std::ios_base::openmode mode) +{ + this->_M_file = f; + this->_M_om = mode; + return this; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +basic_filebuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::close() +{ + basic_filebuf<CharT, Traits>* rt = nullptr; + if (_M_file) + { + rt = this; + std::unique_ptr<FILE, int(*)(FILE*)> h(_M_file, fclose); + if (sync()) + rt = nullptr; + if (fclose(h.release()) == 0) + _M_file = nullptr; + else + rt = nullptr; + } + return rt; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +typename basic_filebuf<CharT, Traits>::int_type +basic_filebuf<CharT, Traits>::underflow() +{ + if (_M_file == nullptr) + return traits_type::eof(); + bool initial = _M_read_mode(); + char_type buf; + if (this->gptr() == nullptr) + this->setg(&buf, &buf+1, &buf+1); + const size_t unget_sz = initial ? 0 : std::min<size_t>((this->egptr() - this->eback()) / 2, 4); + int_type c = traits_type::eof(); + if (this->gptr() == this->egptr()) + { + memmove(this->eback(), this->egptr() - unget_sz, unget_sz * sizeof(char_type)); + if (_M_always_noconv) + { + size_t nmemb = static_cast<size_t>(this->egptr() - this->eback() - unget_sz); + nmemb = fread(this->eback() + unget_sz, 1, nmemb, _M_file); + if (nmemb != 0) + { + this->setg(this->eback(), + this->eback() + unget_sz, + this->eback() + unget_sz + nmemb); + c = traits_type::to_int_type(*this->gptr()); + } + } + else + { + memmove(_M_extbuf, _M_extbufnext, _M_extbufend - _M_extbufnext); + _M_extbufnext = _M_extbuf + (_M_extbufend - _M_extbufnext); + _M_extbufend = _M_extbuf + (_M_extbuf == _M_extbuf_min ? sizeof(_M_extbuf_min) : _M_ebs); + size_t nmemb = std::min(static_cast<size_t>(_M_ibs - unget_sz), + static_cast<size_t>(_M_extbufend - _M_extbufnext)); + std::codecvt_base::result r; + _M_st_last = _M_st; + size_t nr = fread((void*)_M_extbufnext, 1, nmemb, _M_file); + if (nr != 0) + { + if (!_M_cv) + throw std::bad_cast(); + _M_extbufend = _M_extbufnext + nr; + char_type* inext; + r = _M_cv->in(_M_st, _M_extbuf, _M_extbufend, _M_extbufnext, + this->eback() + unget_sz, + this->eback() + _M_ibs, inext); + if (r == std::codecvt_base::noconv) + { + this->setg((char_type*)_M_extbuf, (char_type*)_M_extbuf, (char_type*)_M_extbufend); + c = traits_type::to_int_type(*this->gptr()); + } + else if (inext != this->eback() + unget_sz) + { + this->setg(this->eback(), this->eback() + unget_sz, inext); + c = traits_type::to_int_type(*this->gptr()); + } + } + } + } + else + c = traits_type::to_int_type(*this->gptr()); + if (this->eback() == &buf) + this->setg(0, 0, 0); + return c; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +typename basic_filebuf<CharT, Traits>::int_type +basic_filebuf<CharT, Traits>::pbackfail(int_type c) +{ + if (_M_file && this->eback() < this->gptr()) + { + if (traits_type::eq_int_type(c, traits_type::eof())) + { + this->gbump(-1); + return traits_type::not_eof(c); + } + if ((_M_om & std::ios_base::out) || + traits_type::eq(traits_type::to_char_type(c), this->gptr()[-1])) + { + this->gbump(-1); + *this->gptr() = traits_type::to_char_type(c); + return c; + } + } + return traits_type::eof(); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +typename basic_filebuf<CharT, Traits>::int_type +basic_filebuf<CharT, Traits>::overflow(int_type c) +{ + if (_M_file == nullptr) + return traits_type::eof(); + _M_write_mode(); + char_type buf; + char_type* pb_save = this->pbase(); + char_type* epb_save = this->epptr(); + if (!traits_type::eq_int_type(c, traits_type::eof())) + { + if (this->pptr() == nullptr) + this->setp(&buf, &buf+1); + *this->pptr() = traits_type::to_char_type(c); + this->pbump(1); + } + if (this->pptr() != this->pbase()) + { + if (_M_always_noconv) + { + size_t nmemb = static_cast<size_t>(this->pptr() - this->pbase()); + if (fwrite(this->pbase(), sizeof(char_type), nmemb, _M_file) != nmemb) + return traits_type::eof(); + } + else + { + char* extbe = _M_extbuf; + std::codecvt_base::result r; + do + { + if (!_M_cv) + throw std::bad_cast(); + const char_type* e; + r = _M_cv->out(_M_st, this->pbase(), this->pptr(), e, + _M_extbuf, _M_extbuf + _M_ebs, extbe); + if (e == this->pbase()) + return traits_type::eof(); + if (r == std::codecvt_base::noconv) + { + size_t nmemb = static_cast<size_t>(this->pptr() - this->pbase()); + if (fwrite(this->pbase(), 1, nmemb, _M_file) != nmemb) + return traits_type::eof(); + } + else if (r == std::codecvt_base::ok || r == std::codecvt_base::partial) + { + size_t nmemb = static_cast<size_t>(extbe - _M_extbuf); + if (fwrite(_M_extbuf, 1, nmemb, _M_file) != nmemb) + return traits_type::eof(); + if (r == std::codecvt_base::partial) + { + this->setp((char_type*)e, this->pptr()); + this->pbump(this->epptr() - this->pbase()); + } + } + else + return traits_type::eof(); + } while (r == std::codecvt_base::partial); + } + this->setp(pb_save, epb_save); + } + return traits_type::not_eof(c); +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +std::basic_streambuf<CharT, Traits>* +basic_filebuf<CharT, Traits>::setbuf(char_type* s, std::streamsize n) +{ + this->setg(0, 0, 0); + this->setp(0, 0); + if (_M_owns_eb) + delete [] _M_extbuf; + if (_M_owns_ib) + delete [] _M_intbuf; + _M_ebs = n; + if (_M_ebs > sizeof(_M_extbuf_min)) + { + if (_M_always_noconv && s) + { + _M_extbuf = (char*)s; + _M_owns_eb = false; + } + else + { + _M_extbuf = new char[_M_ebs]; + _M_owns_eb = true; + } + } + else + { + _M_extbuf = _M_extbuf_min; + _M_ebs = sizeof(_M_extbuf_min); + _M_owns_eb = false; + } + if (!_M_always_noconv) + { + _M_ibs = std::max<std::streamsize>(n, sizeof(_M_extbuf_min)); + if (s && _M_ibs >= sizeof(_M_extbuf_min)) + { + _M_intbuf = s; + _M_owns_ib = false; + } + else + { + _M_intbuf = new char_type[_M_ibs]; + _M_owns_ib = true; + } + } + else + { + _M_ibs = 0; + _M_intbuf = 0; + _M_owns_ib = false; + } + return this; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +typename basic_filebuf<CharT, Traits>::pos_type +basic_filebuf<CharT, Traits>::seekoff(off_type off, std::ios_base::seekdir way, + std::ios_base::openmode) +{ + if (!_M_cv) + throw std::bad_cast(); + int width = _M_cv->encoding(); + if (_M_file == nullptr || (width <= 0 && off != 0) || sync()) + return pos_type(off_type(-1)); + // width > 0 || off == 0 + int whence; + switch (way) + { + case std::ios_base::beg: + whence = SEEK_SET; + break; + case std::ios_base::cur: + whence = SEEK_CUR; + break; + case std::ios_base::end: + whence = SEEK_END; + break; + default: + return pos_type(off_type(-1)); + } +#if _WIN32 + if (fseek(_M_file, width > 0 ? width * off : 0, whence)) + return pos_type(off_type(-1)); + pos_type r = ftell(_M_file); +#else + if (fseeko(_M_file, width > 0 ? width * off : 0, whence)) + return pos_type(off_type(-1)); + pos_type r = ftello(_M_file); +#endif + r.state(_M_st); + return r; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +typename basic_filebuf<CharT, Traits>::pos_type +basic_filebuf<CharT, Traits>::seekpos(pos_type sp, std::ios_base::openmode) +{ + if (_M_file == nullptr || sync()) + return pos_type(off_type(-1)); +#if _WIN32 + if (fseek(_M_file, sp, SEEK_SET)) + return pos_type(off_type(-1)); +#else + if (fseeko(_M_file, sp, SEEK_SET)) + return pos_type(off_type(-1)); +#endif + _M_st = sp.state(); + return sp; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +int +basic_filebuf<CharT, Traits>::sync() +{ + if (_M_file == nullptr) + return 0; + if (!_M_cv) + throw std::bad_cast(); + if (_M_cm & std::ios_base::out) + { + if (this->pptr() != this->pbase()) + if (overflow() == traits_type::eof()) + return -1; + std::codecvt_base::result r; + do + { + char* extbe; + r = _M_cv->unshift(_M_st, _M_extbuf, _M_extbuf + _M_ebs, extbe); + size_t nmemb = static_cast<size_t>(extbe - _M_extbuf); + if (fwrite(_M_extbuf, 1, nmemb, _M_file) != nmemb) + return -1; + } while (r == std::codecvt_base::partial); + if (r == std::codecvt_base::error) + return -1; + if (fflush(_M_file)) + return -1; + } + else if (_M_cm & std::ios_base::in) + { + off_type c; + state_type state = _M_st_last; + bool update_st = false; + if (_M_always_noconv) + c = this->egptr() - this->gptr(); + else + { + int width = _M_cv->encoding(); + c = _M_extbufend - _M_extbufnext; + if (width > 0) + c += width * (this->egptr() - this->gptr()); + else + { + if (this->gptr() != this->egptr()) + { + const int off = _M_cv->length(state, _M_extbuf, + _M_extbufnext, + this->gptr() - this->eback()); + c += _M_extbufnext - _M_extbuf - off; + update_st = true; + } + } + } +#if _WIN32 + if (fseek(_M_file_, -c, SEEK_CUR)) + return -1; +#else + if (fseeko(_M_file, -c, SEEK_CUR)) + return -1; +#endif + if (update_st) + _M_st = state; + _M_extbufnext = _M_extbufend = _M_extbuf; + this->setg(0, 0, 0); + _M_cm = std::ios_base::openmode(0); + } + return 0; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +void +basic_filebuf<CharT, Traits>::imbue(const std::locale& loc) +{ + sync(); + _M_cv = &std::use_facet<std::codecvt<char_type, char, state_type> >(loc); + bool old_anc = _M_always_noconv; + _M_always_noconv = _M_cv->always_noconv(); + if (old_anc != _M_always_noconv) + { + this->setg(0, 0, 0); + this->setp(0, 0); + // invariant, char_type is char, else we couldn't get here + if (_M_always_noconv) // need to dump _M_intbuf + { + if (_M_owns_eb) + delete [] _M_extbuf; + _M_owns_eb = _M_owns_ib; + _M_ebs = _M_ibs; + _M_extbuf = (char*)_M_intbuf; + _M_ibs = 0; + _M_intbuf = nullptr; + _M_owns_ib = false; + } + else // need to obtain an _M_intbuf. + { // If _M_extbuf is user-supplied, use it, else new _M_intbuf + if (!_M_owns_eb && _M_extbuf != _M_extbuf_min) + { + _M_ibs = _M_ebs; + _M_intbuf = (char_type*)_M_extbuf; + _M_owns_ib = false; + _M_extbuf = new char[_M_ebs]; + _M_owns_eb = true; + } + else + { + _M_ibs = _M_ebs; + _M_intbuf = new char_type[_M_ibs]; + _M_owns_ib = true; + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +bool +basic_filebuf<CharT, Traits>::_M_read_mode() +{ + if (!(_M_cm & std::ios_base::in)) + { + this->setp(0, 0); + if (_M_always_noconv) + this->setg((char_type*)_M_extbuf, + (char_type*)_M_extbuf + _M_ebs, + (char_type*)_M_extbuf + _M_ebs); + else + this->setg(_M_intbuf, _M_intbuf + _M_ibs, _M_intbuf + _M_ibs); + _M_cm = std::ios_base::in; + return true; + } + return false; +} + +/////////////////////////////////////////////////////////////////////////////// +template <class CharT, class Traits> +void +basic_filebuf<CharT, Traits>::_M_write_mode() +{ + if (!(_M_cm & std::ios_base::out)) + { + this->setg(0, 0, 0); + if (_M_ebs > sizeof(_M_extbuf_min)) + { + if (_M_always_noconv) + this->setp((char_type*)_M_extbuf, + (char_type*)_M_extbuf + (_M_ebs - 1)); + else + this->setp(_M_intbuf, _M_intbuf + (_M_ibs - 1)); + } + else + this->setp(0, 0); + _M_cm = std::ios_base::out; + } +} + +/////////////////////////////////////////////////////////////////////////////// +} + +/////////////////////////////////////////////////////////////////////////////// +#endif // KALDI_UTIL_BASIC_FILEBUF_H_ + +/////////////////////////////////////////////////////////////////////////////// + +/* + * ============================================================================ + * libc++ License + * ============================================================================ + * + * The libc++ library is dual licensed under both the University of Illinois + * "BSD-Like" license and the MIT license. As a user of this code you may + * choose to use it under either license. As a contributor, you agree to allow + * your code to be used under both. + * + * Full text of the relevant licenses is included below. + * + * ============================================================================ + * + * University of Illinois/NCSA + * Open Source License + * + * Copyright (c) 2009-2014 by the contributors listed in CREDITS.TXT (included below) + * + * All rights reserved. + * + * Developed by: + * + * LLVM Team + * + * University of Illinois at Urbana-Champaign + * + * http://llvm.org + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal with + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + * of the Software, and to permit persons to whom the Software is furnished to do + * so, subject to the following conditions: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimers. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimers in the + * documentation and/or other materials provided with the distribution. + * + * * Neither the names of the LLVM Team, University of Illinois at + * Urbana-Champaign, nor the names of its contributors may be used to + * endorse or promote products derived from this Software without specific + * prior written permission. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE + * SOFTWARE. + * + * ============================================================================== + * + * Copyright (c) 2009-2014 by the contributors listed in CREDITS.TXT (included below) + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ============================================================================== + * + * This file is a partial list of people who have contributed to the LLVM/libc++ + * project. If you have contributed a patch or made some other contribution to + * LLVM/libc++, please submit a patch to this file to add yourself, and it will be + * done! + * + * The list is sorted by surname and formatted to allow easy grepping and + * beautification by scripts. The fields are: name (N), email (E), web-address + * (W), PGP key ID and fingerprint (P), description (D), and snail-mail address + * (S). + * + * N: Saleem Abdulrasool + * E: [email protected] + * D: Minor patches and Linux fixes. + * + * N: Dimitry Andric + * E: [email protected] + * D: Visibility fixes, minor FreeBSD portability patches. + * + * N: Holger Arnold + * E: [email protected] + * D: Minor fix. + * + * N: Ruben Van Boxem + * E: vanboxem dot ruben at gmail dot com + * D: Initial Windows patches. + * + * N: David Chisnall + * E: theraven at theravensnest dot org + * D: FreeBSD and Solaris ports, libcxxrt support, some atomics work. + * + * N: Marshall Clow + * E: [email protected] + * E: [email protected] + * D: C++14 support, patches and bug fixes. + * + * N: Bill Fisher + * E: [email protected] + * D: Regex bug fixes. + * + * N: Matthew Dempsky + * E: [email protected] + * D: Minor patches and bug fixes. + * + * N: Google Inc. + * D: Copyright owner and contributor of the CityHash algorithm + * + * N: Howard Hinnant + * E: [email protected] + * D: Architect and primary author of libc++ + * + * N: Hyeon-bin Jeong + * E: [email protected] + * D: Minor patches and bug fixes. + * + * N: Argyrios Kyrtzidis + * E: [email protected] + * D: Bug fixes. + * + * N: Bruce Mitchener, Jr. + * E: [email protected] + * D: Emscripten-related changes. + * + * N: Michel Morin + * E: [email protected] + * D: Minor patches to is_convertible. + * + * N: Andrew Morrow + * E: [email protected] + * D: Minor patches and Linux fixes. + * + * N: Arvid Picciani + * E: aep at exys dot org + * D: Minor patches and musl port. + * + * N: Bjorn Reese + * E: [email protected] + * D: Initial regex prototype + * + * N: Nico Rieck + * E: [email protected] + * D: Windows fixes + * + * N: Jonathan Sauer + * D: Minor patches, mostly related to constexpr + * + * N: Craig Silverstein + * E: [email protected] + * D: Implemented Cityhash as the string hash function on 64-bit machines + * + * N: Richard Smith + * D: Minor patches. + * + * N: Joerg Sonnenberger + * E: [email protected] + * D: NetBSD port. + * + * N: Stephan Tolksdorf + * E: [email protected] + * D: Minor <atomic> fix + * + * N: Michael van der Westhuizen + * E: r1mikey at gmail dot com + * + * N: Klaas de Vries + * E: klaas at klaasgaaf dot nl + * D: Minor bug fix. + * + * N: Zhang Xiongpang + * E: [email protected] + * D: Minor patches and bug fixes. + * + * N: Xing Xue + * E: [email protected] + * D: AIX port + * + * N: Zhihao Yuan + * E: [email protected] + * D: Standard compatibility fixes. + * + * N: Jeffrey Yasskin + * E: [email protected] + * E: [email protected] + * D: Linux fixes. + */ diff --git a/kaldi_io/src/kaldi/util/common-utils.h b/kaldi_io/src/kaldi/util/common-utils.h new file mode 100644 index 0000000..9d39f9d --- /dev/null +++ b/kaldi_io/src/kaldi/util/common-utils.h @@ -0,0 +1,31 @@ +// util/common-utils.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_COMMON_UTILS_H_ +#define KALDI_UTIL_COMMON_UTILS_H_ + +#include "base/kaldi-common.h" +#include "util/parse-options.h" +#include "util/kaldi-io.h" +#include "util/simple-io-funcs.h" +#include "util/kaldi-holder.h" +#include "util/kaldi-table.h" +#include "util/table-types.h" +#include "util/text-utils.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/const-integer-set-inl.h b/kaldi_io/src/kaldi/util/const-integer-set-inl.h new file mode 100644 index 0000000..8f92ab2 --- /dev/null +++ b/kaldi_io/src/kaldi/util/const-integer-set-inl.h @@ -0,0 +1,88 @@ +// util/const-integer-set-inl.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_CONST_INTEGER_SET_INL_H_ +#define KALDI_UTIL_CONST_INTEGER_SET_INL_H_ + +// Do not include this file directly. It is included by const-integer-set.h + + +namespace kaldi { + +template<class I> +void ConstIntegerSet<I>::InitInternal() { + KALDI_ASSERT_IS_INTEGER_TYPE(I); + quick_set_.clear(); // just in case we previously had data. + if (slow_set_.size() == 0) { + lowest_member_=(I) 1; + highest_member_=(I) 0; + contiguous_ = false; + quick_ = false; + } else { + lowest_member_ = slow_set_.front(); + highest_member_ = slow_set_.back(); + size_t range = highest_member_ + 1 - lowest_member_; + if (range == slow_set_.size()) { + contiguous_ = true; + quick_=false; + } else { + contiguous_ = false; + if (range < slow_set_.size() * 8 * sizeof(I)) { // If it would be more compact to store as bool + // (assuming 1 bit per element)... + quick_set_.resize(range, false); + for (size_t i = 0;i < slow_set_.size();i++) + quick_set_[slow_set_[i] - lowest_member_] = true; + quick_ = true; + } else { + quick_ = false; + } + } + } +} + +template<class I> +int ConstIntegerSet<I>::count(I i) const { + if (i < lowest_member_ || i > highest_member_) return 0; + else { + if (contiguous_) return true; + if (quick_) return (quick_set_[i-lowest_member_] ? 1 : 0); + else { + bool ans = std::binary_search(slow_set_.begin(), slow_set_.end(), i); + return (ans ? 1 : 0); + } + } +} + +template<class I> +void ConstIntegerSet<I>::Write(std::ostream &os, bool binary) const { + WriteIntegerVector(os, binary, slow_set_); +} + +template<class I> +void ConstIntegerSet<I>::Read(std::istream &is, bool binary) { + ReadIntegerVector(is, binary, &slow_set_); + InitInternal(); +} + + + +} // end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/util/const-integer-set.h b/kaldi_io/src/kaldi/util/const-integer-set.h new file mode 100644 index 0000000..ffdce4d --- /dev/null +++ b/kaldi_io/src/kaldi/util/const-integer-set.h @@ -0,0 +1,95 @@ +// util/const-integer-set.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_CONST_INTEGER_SET_H_ +#define KALDI_UTIL_CONST_INTEGER_SET_H_ +#include <vector> +#include <set> +#include <algorithm> +#include <limits> +#include <cassert> +#include "util/stl-utils.h" + + /* ConstIntegerSet is a way to efficiently test whether something is in a + supplied set of integers. It can be initialized from a vector or set, but + never changed after that. It either uses a sorted vector or an array of + bool, depending on the input. It behaves like a const version of an STL set, with + only a subset of the functionality, except all the member functions are + upper-case. + + Note that we could get rid of the member slow_set_, but we'd have to + do more work to implement an iterator type. This would save memory. + */ + +namespace kaldi { + +template<class I> class ConstIntegerSet { + public: + ConstIntegerSet(): lowest_member_(1), highest_member_(0) { } + + void Init(const std::vector<I> &input) { + slow_set_ = input; + SortAndUniq(&slow_set_); + InitInternal(); + } + + void Init(const std::set<I> &input) { + CopySetToVector(input, &slow_set_); + InitInternal(); + } + + explicit ConstIntegerSet(const std::vector<I> &input): slow_set_(input) { + SortAndUniq(&slow_set_); + InitInternal(); + } + explicit ConstIntegerSet(const std::set<I> &input) { + CopySetToVector(input, &slow_set_); + InitInternal(); + } + explicit ConstIntegerSet(const ConstIntegerSet<I> &other): slow_set_(other.slow_set_) { + InitInternal(); + } + + int count(I i) const; // returns 1 or 0. + + typedef typename std::vector<I>::const_iterator iterator; + iterator begin() const { return slow_set_.begin(); } + iterator end() const { return slow_set_.end(); } + size_t size() const { return slow_set_.size(); } + bool empty() const { return slow_set_.empty(); } + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + private: + I lowest_member_; + I highest_member_; + bool contiguous_; + bool quick_; + std::vector<bool> quick_set_; + std::vector<I> slow_set_; + void InitInternal(); +}; + +} // end namespace kaldi + +#include "const-integer-set-inl.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/edit-distance-inl.h b/kaldi_io/src/kaldi/util/edit-distance-inl.h new file mode 100644 index 0000000..ebbfb71 --- /dev/null +++ b/kaldi_io/src/kaldi/util/edit-distance-inl.h @@ -0,0 +1,189 @@ +// util/edit-distance-inl.h + +// Copyright 2009-2011 Microsoft Corporation; Haihua Xu; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_EDIT_DISTANCE_INL_H_ +#define KALDI_UTIL_EDIT_DISTANCE_INL_H_ +#include "util/stl-utils.h" + + +namespace kaldi { + +template<class T> +int32 LevenshteinEditDistance(const std::vector<T> &a, + const std::vector<T> &b) { + // Algorithm: + // write A and B for the sequences, with elements a_0 .. + // let |A| = M and |B| = N be the lengths, and have + // elements a_0 ... a_{M-1} and b_0 ... b_{N-1}. + // We are computing the recursion + // E(m, n) = min( E(m-1, n-1) + (1-delta(a_{m-1}, b_{n-1})), + // E(m-1, n), + // E(m, n-1) ). + // where E(m, n) is defined for m = 0..M and n = 0..N and out-of- + // bounds quantities are considered to be infinity (i.e. the + // recursion does not visit them). + + // We do this computation using a vector e of size N+1. + // The outer iterations range over m = 0..M. + + int M = a.size(), N = b.size(); + std::vector<int32> e(N+1); + std::vector<int32> e_tmp(N+1); + // initialize e. + for (size_t i = 0; i < e.size(); i++) + e[i] = i; + for (int32 m = 1; m <= M; m++) { + // computing E(m, .) from E(m-1, .) + // handle special case n = 0: + e_tmp[0] = e[0] + 1; + + for (int32 n = 1; n <= N; n++) { + int32 term1 = e[n-1] + (a[m-1] == b[n-1] ? 0 : 1); + int32 term2 = e[n] + 1; + int32 term3 = e_tmp[n-1] + 1; + e_tmp[n] = std::min(term1, std::min(term2, term3)); + } + e = e_tmp; + } + return e.back(); +} +// +struct error_stats{ + int32 ins_num; + int32 del_num; + int32 sub_num; + int32 total_cost; // minimum total cost to the current alignment. +}; +// Note that both hyp and ref should not contain noise word in +// the following implementation. + +template<class T> +int32 LevenshteinEditDistance(const std::vector<T> &ref, + const std::vector<T> &hyp, + int32 *ins, int32 *del, int32 *sub) { + // temp sequence to remember error type and stats. + std::vector<error_stats> e(ref.size()+1); + std::vector<error_stats> cur_e(ref.size()+1); + // initialize the first hypothesis aligned to the reference at each + // position:[hyp_index =0][ref_index] + for (size_t i =0; i < e.size(); i ++) { + e[i].ins_num = 0; + e[i].sub_num = 0; + e[i].del_num = i; + e[i].total_cost = i; + } + + // for other alignments + for (size_t hyp_index = 1; hyp_index <= hyp.size(); hyp_index ++) { + cur_e[0] = e[0]; + cur_e[0].ins_num ++; + cur_e[0].total_cost ++; + for (size_t ref_index = 1; ref_index <= ref.size(); ref_index ++) { + + int32 ins_err = e[ref_index].total_cost + 1; + int32 del_err = cur_e[ref_index-1].total_cost + 1; + int32 sub_err = e[ref_index-1].total_cost; + if (hyp[hyp_index-1] != ref[ref_index-1]) + sub_err ++; + + if (sub_err < ins_err && sub_err < del_err) { + cur_e[ref_index] =e[ref_index-1]; + if (hyp[hyp_index-1] != ref[ref_index-1]) + cur_e[ref_index].sub_num ++; // substitution error should be increased + cur_e[ref_index].total_cost = sub_err; + }else if (del_err < ins_err ) { + cur_e[ref_index] = cur_e[ref_index-1]; + cur_e[ref_index].total_cost = del_err; + cur_e[ref_index].del_num ++; // deletion number is increased. + }else{ + cur_e[ref_index] = e[ref_index]; + cur_e[ref_index].total_cost = ins_err; + cur_e[ref_index].ins_num ++; // insertion number is increased. + } + } + e = cur_e; // alternate for the next recursion. + } + size_t ref_index = e.size()-1; + *ins = e[ref_index].ins_num, *del = e[ref_index].del_num, *sub = e[ref_index].sub_num; + return e[ref_index].total_cost; +} + +template<class T> +int32 LevenshteinAlignment(const std::vector<T> &a, + const std::vector<T> &b, + T eps_symbol, + std::vector<std::pair<T, T> > *output) { + // Check inputs: + { + KALDI_ASSERT(output != NULL); + for (size_t i = 0; i < a.size(); i++) KALDI_ASSERT(a[i] != eps_symbol); + for (size_t i = 0; i < b.size(); i++) KALDI_ASSERT(b[i] != eps_symbol); + } + output->clear(); + // This is very memory-inefficiently implemented using a vector of vectors. + size_t M = a.size(), N = b.size(); + size_t m, n; + std::vector<std::vector<int32> > e(M+1); + for (m = 0; m <=M; m++) e[m].resize(N+1); + for (n = 0; n <= N; n++) + e[0][n] = n; + for (m = 1; m <= M; m++) { + e[m][0] = e[m-1][0] + 1; + for (n = 1; n <= N; n++) { + int32 sub_or_ok = e[m-1][n-1] + (a[m-1] == b[n-1] ? 0 : 1); + int32 del = e[m-1][n] + 1; // assumes a == ref, b == hyp. + int32 ins = e[m][n-1] + 1; + e[m][n] = std::min(sub_or_ok, std::min(del, ins)); + } + } + // get time-reversed output first: trace back. + m = M; n = N; + while (m != 0 || n != 0) { + size_t last_m, last_n; + if (m == 0) { last_m = m; last_n = n-1; } + else if (n == 0) { last_m = m-1; last_n = n; } + else { + int32 sub_or_ok = e[m-1][n-1] + (a[m-1] == b[n-1] ? 0 : 1); + int32 del = e[m-1][n] + 1; // assumes a == ref, b == hyp. + int32 ins = e[m][n-1] + 1; + if (sub_or_ok <= std::min(del, ins)) { // choose sub_or_ok if all else equal. + last_m = m-1; last_n = n-1; + } else { + if (del <= ins) { // choose del over ins if equal. + last_m = m-1; last_n = n; + } else { + last_m = m; last_n = n-1; + } + } + } + T a_sym, b_sym; + a_sym = (last_m == m ? eps_symbol : a[last_m]); + b_sym = (last_n == n ? eps_symbol : b[last_n]); + output->push_back(std::make_pair(a_sym, b_sym)); + m = last_m; + n = last_n; + } + ReverseVector(output); + return e[M][N]; +} + + +} // end namespace kaldi + +#endif // KALDI_UTIL_EDIT_DISTANCE_INL_H_ diff --git a/kaldi_io/src/kaldi/util/edit-distance.h b/kaldi_io/src/kaldi/util/edit-distance.h new file mode 100644 index 0000000..6000622 --- /dev/null +++ b/kaldi_io/src/kaldi/util/edit-distance.h @@ -0,0 +1,63 @@ +// util/edit-distance.h + +// Copyright 2009-2011 Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_EDIT_DISTANCE_H_ +#define KALDI_UTIL_EDIT_DISTANCE_H_ +#include <vector> +#include <set> +#include <algorithm> +#include <limits> +#include <cassert> +#include "base/kaldi-types.h" + +namespace kaldi { + +// Compute the edit-distance between two strings. +template<class T> +int32 LevenshteinEditDistance(const std::vector<T> &a, + const std::vector<T> &b); + + +// edit distance calculation with conventional method. +// note: noise word must be filtered out from the hypothesis and reference sequence +// before the following procedure conducted. +template<class T> +int32 LevenshteinEditDistance(const std::vector<T> &ref, + const std::vector<T> &hyp, + int32 *ins, int32 *del, int32 *sub); + +// This version of the edit-distance computation outputs the alignment +// between the two. This is a vector of pairs of (symbol a, symbol b). +// The epsilon symbol (eps_symbol) must not occur in sequences a or b. +// Where one aligned to no symbol in the other (insertion or deletion), +// epsilon will be the corresponding member of the pair. +// It returns the edit-distance between the two strings. + +template<class T> +int32 LevenshteinAlignment(const std::vector<T> &a, + const std::vector<T> &b, + T eps_symbol, + std::vector<std::pair<T, T> > *output); + +} // end namespace kaldi + +#include "edit-distance-inl.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/hash-list-inl.h b/kaldi_io/src/kaldi/util/hash-list-inl.h new file mode 100644 index 0000000..19c2bb6 --- /dev/null +++ b/kaldi_io/src/kaldi/util/hash-list-inl.h @@ -0,0 +1,183 @@ +// util/hash-list-inl.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_HASH_LIST_INL_H_ +#define KALDI_UTIL_HASH_LIST_INL_H_ + +// Do not include this file directly. It is included by fast-hash.h + + +namespace kaldi { + +template<class I, class T> HashList<I, T>::HashList() { + list_head_ = NULL; + bucket_list_tail_ = static_cast<size_t>(-1); // invalid. + hash_size_ = 0; + freed_head_ = NULL; +} + +template<class I, class T> void HashList<I, T>::SetSize(size_t size) { + hash_size_ = size; + KALDI_ASSERT(list_head_ == NULL && bucket_list_tail_ == static_cast<size_t>(-1)); // make sure empty. + if (size > buckets_.size()) + buckets_.resize(size, HashBucket(0, NULL)); +} + +template<class I, class T> +typename HashList<I, T>::Elem* HashList<I, T>::Clear() { + // Clears the hashtable and gives ownership of the currently contained list to the + // user. + for (size_t cur_bucket = bucket_list_tail_; + cur_bucket != static_cast<size_t>(-1); + cur_bucket = buckets_[cur_bucket].prev_bucket) { + buckets_[cur_bucket].last_elem = NULL; // this is how we indicate "empty". + } + bucket_list_tail_ = static_cast<size_t>(-1); + Elem *ans = list_head_; + list_head_ = NULL; + return ans; +} + +template<class I, class T> +const typename HashList<I, T>::Elem* HashList<I, T>::GetList() const { + return list_head_; +} + +template<class I, class T> +inline void HashList<I, T>::Delete(Elem *e) { + e->tail = freed_head_; + freed_head_ = e; +} + +template<class I, class T> +inline typename HashList<I, T>::Elem* HashList<I, T>::Find(I key) { + size_t index = (static_cast<size_t>(key) % hash_size_); + HashBucket &bucket = buckets_[index]; + if (bucket.last_elem == NULL) { + return NULL; // empty bucket. + } else { + Elem *head = (bucket.prev_bucket == static_cast<size_t>(-1) ? + list_head_ : + buckets_[bucket.prev_bucket].last_elem->tail), + *tail = bucket.last_elem->tail; + for (Elem *e = head; e != tail; e = e->tail) + if (e->key == key) return e; + return NULL; // Not found. + } +} + +template<class I, class T> +inline typename HashList<I, T>::Elem* HashList<I, T>::New() { + if (freed_head_) { + Elem *ans = freed_head_; + freed_head_ = freed_head_->tail; + return ans; + } else { + Elem *tmp = new Elem[allocate_block_size_]; + for (size_t i = 0; i+1 < allocate_block_size_; i++) + tmp[i].tail = tmp+i+1; + tmp[allocate_block_size_-1].tail = NULL; + freed_head_ = tmp; + allocated_.push_back(tmp); + return this->New(); + } +} + +template<class I, class T> +HashList<I, T>::~HashList() { + // First test whether we had any memory leak within the + // HashList, i.e. things for which the user did not call Delete(). + size_t num_in_list = 0, num_allocated = 0; + for (Elem *e = freed_head_; e != NULL; e = e->tail) + num_in_list++; + for (size_t i = 0; i < allocated_.size(); i++) { + num_allocated += allocate_block_size_; + delete[] allocated_[i]; + } + if (num_in_list != num_allocated) { + KALDI_WARN << "Possible memory leak: " << num_in_list + << " != " << num_allocated + << ": you might have forgotten to call Delete on " + << "some Elems"; + } +} + + +template<class I, class T> +void HashList<I, T>::Insert(I key, T val) { + size_t index = (static_cast<size_t>(key) % hash_size_); + HashBucket &bucket = buckets_[index]; + Elem *elem = New(); + elem->key = key; + elem->val = val; + + if (bucket.last_elem == NULL) { // Unoccupied bucket. Insert at + // head of bucket list (which is tail of regular list, they go in + // opposite directions). + if (bucket_list_tail_ == static_cast<size_t>(-1)) { + // list was empty so this is the first elem. + KALDI_ASSERT(list_head_ == NULL); + list_head_ = elem; + } else { + // link in to the chain of Elems + buckets_[bucket_list_tail_].last_elem->tail = elem; + } + elem->tail = NULL; + bucket.last_elem = elem; + bucket.prev_bucket = bucket_list_tail_; + bucket_list_tail_ = index; + } else { + // Already-occupied bucket. Insert at tail of list of elements within + // the bucket. + elem->tail = bucket.last_elem->tail; + bucket.last_elem->tail = elem; + bucket.last_elem = elem; + } +} + +template<class I, class T> +void HashList<I, T>::InsertMore(I key, T val) { + size_t index = (static_cast<size_t>(key) % hash_size_); + HashBucket &bucket = buckets_[index]; + Elem *elem = New(); + elem->key = key; + elem->val = val; + + KALDI_ASSERT(bucket.last_elem != NULL); // we assume there is already one element + if (bucket.last_elem->key == key) { // standard behavior: add as last element + elem->tail = bucket.last_elem->tail; + bucket.last_elem->tail = elem; + bucket.last_elem = elem; + return; + } + Elem *e = (bucket.prev_bucket == static_cast<size_t>(-1) ? + list_head_ : buckets_[bucket.prev_bucket].last_elem->tail); + // find place to insert in linked list + while (e != bucket.last_elem->tail && e->key != key) e = e->tail; + KALDI_ASSERT(e->key == key); // not found? - should not happen + elem->tail = e->tail; + e->tail = elem; +} + + +} // end namespace kaldi + +#endif diff --git a/kaldi_io/src/kaldi/util/hash-list.h b/kaldi_io/src/kaldi/util/hash-list.h new file mode 100644 index 0000000..4524759 --- /dev/null +++ b/kaldi_io/src/kaldi/util/hash-list.h @@ -0,0 +1,140 @@ +// util/hash-list.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_HASH_LIST_H_ +#define KALDI_UTIL_HASH_LIST_H_ +#include <vector> +#include <set> +#include <algorithm> +#include <limits> +#include <cassert> +#include "util/stl-utils.h" + + +/* This header provides utilities for a structure that's used in a decoder (but + is quite generic in nature so we implement and test it separately). + Basically it's a singly-linked list, but implemented in such a way that we + can quickly search for elements in the list. We give it a slightly richer + interface than just a hash and a list. The idea is that we want to separate + the hash part and the list part: basically, in the decoder, we want to have a + single hash for the current frame and the next frame, because by the time we + need to access the hash for the next frame we no longer need the hash for the + previous frame. So we have an operation that clears the hash but leaves the + list structure intact. We also control memory management inside this object, + to avoid repeated new's/deletes. + + See hash-list-test.cc for an example of how to use this object. +*/ + + +namespace kaldi { + +template<class I, class T> class HashList { + + public: + struct Elem { + I key; + T val; + Elem *tail; + }; + + /// Constructor takes no arguments. Call SetSize to inform it of the likely size. + HashList(); + + /// Clears the hash and gives the head of the current list to the user; + /// ownership is transferred to the user (the user must call Delete() + /// for each element in the list, at his/her leisure). + Elem *Clear(); + + /// Gives the head of the current list to the user. Ownership retained in the + /// class. Caution: in December 2013 the return type was changed to const Elem* + /// and this function was made const. You may need to change some types of + /// local Elem* variables to const if this produces compilation errors. + const Elem *GetList() const; + + /// Think of this like delete(). It is to be called for each Elem in turn + /// after you "obtained ownership" by doing Clear(). This is not the opposite of + /// Insert, it is the opposite of New. It's really a memory operation. + inline void Delete(Elem *e); + + /// This should probably not be needed to be called directly by the user. Think of it as opposite + /// to Delete(); + inline Elem *New(); + + /// Find tries to find this element in the current list using the hashtable. + /// It returns NULL if not present. The Elem it returns is not owned by the user, + /// it is part of the internal list owned by this object, but the user is + /// free to modify the "val" element. + inline Elem *Find(I key); + + /// Insert inserts a new element into the hashtable/stored list. By calling this, + /// the user asserts that it is not already present (e.g. Find was called and + /// returned NULL). With current code, calling this if an element already exists will + /// result in duplicate elements in the structure, and Find() will find the + /// first one that was added. [but we don't guarantee this behavior]. + inline void Insert(I key, T val); + + /// Insert inserts another element with same key into the hashtable/stored list. + /// By calling this, the user asserts that one element with that key is already present. + /// We insert it that way, that all elements with the same key follow each other. + /// Find() will return the first one of the elements with the same key. + inline void InsertMore(I key, T val); + + /// SetSize tells the object how many hash buckets to allocate (should typically be + /// at least twice the number of objects we expect to go in the structure, for fastest + /// performance). It must be called while the hash is empty (e.g. after Clear() or + /// after initializing the object, but before adding anything to the hash. + void SetSize(size_t sz); + + /// Returns current number of hash buckets. + inline size_t Size() { return hash_size_; } + + ~HashList(); + private: + + struct HashBucket { + size_t prev_bucket; // index to next bucket (-1 if list tail). Note: list of buckets + // goes in opposite direction to list of Elems. + Elem *last_elem; // pointer to last element in this bucket (NULL if empty) + inline HashBucket(size_t i, Elem *e): prev_bucket(i), last_elem(e) {} + }; + + Elem *list_head_; // head of currently stored list. + size_t bucket_list_tail_; // tail of list of active hash buckets. + + size_t hash_size_; // number of hash buckets. + + std::vector<HashBucket> buckets_; + + Elem *freed_head_; // head of list of currently freed elements. [ready for allocation] + + std::vector<Elem*> allocated_; // list of allocated blocks. + + static const size_t allocate_block_size_ = 1024; // Number of Elements to allocate in one block. Must be + // largish so storing allocated_ doesn't become a problem. +}; + + +} // end namespace kaldi + +#include "hash-list-inl.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-holder-inl.h b/kaldi_io/src/kaldi/util/kaldi-holder-inl.h new file mode 100644 index 0000000..6a66e61 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-holder-inl.h @@ -0,0 +1,800 @@ +// util/kaldi-holder-inl.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_KALDI_HOLDER_INL_H_ +#define KALDI_UTIL_KALDI_HOLDER_INL_H_ + +#include <algorithm> +#include "util/kaldi-io.h" +#include "util/text-utils.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// \addtogroup holders +/// @{ + + +// KaldiObjectHolder is valid only for Kaldi objects with +// copy constructors, default constructors, and "normal" +// Kaldi Write and Read functions. E.g. it works for +// Matrix and Vector. +template<class KaldiType> class KaldiObjectHolder { + public: + typedef KaldiType T; + + KaldiObjectHolder(): t_(NULL) { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + t.Write(os, binary); + return os.good(); + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object: " << e.what(); + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; // Write failure. + } + } + + void Clear() { + if (t_) { + delete t_; + t_ = NULL; + } + } + + // Reads into the holder. + bool Read(std::istream &is) { + if (t_) delete t_; + t_ = new T; + // Don't want any existing state to complicate the read functioN: get new object. + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Reading Table object, failed reading binary header\n"; + return false; + } + try { + t_->Read(is, is_binary); + return true; + } catch (std::exception &e) { + KALDI_WARN << "Exception caught reading Table object "; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + delete t_; + t_ = NULL; + return false; + } + } + + // Kaldi objects always have the stream open in binary mode for + // reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { + // code error if !t_. + if (!t_) KALDI_ERR << "KaldiObjectHolder::Value() called wrongly."; + return *t_; + } + + ~KaldiObjectHolder() { if (t_) delete t_; } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(KaldiObjectHolder); + T *t_; +}; + + +// BasicHolder is valid for float, double, bool, and integer +// types. There will be a compile time error otherwise, because +// we make sure that the {Write, Read}BasicType functions do not +// get instantiated for other types. + +template<class BasicType> class BasicHolder { + public: + typedef BasicType T; + + BasicHolder(): t_(static_cast<T>(-1)) { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + WriteBasicType(os, binary, t); + if (!binary) os << '\n'; // Makes output format more readable and + // easier to manipulate. + return os.good(); + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object: " << e.what(); + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; // Write failure. + } + } + + void Clear() { } + + // Reads into the holder. + bool Read(std::istream &is) { + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Reading Table object [integer type], failed reading binary header\n"; + return false; + } + try { + int c; + if (!is_binary) { // This is to catch errors, the class would work without it.. + // Eat up any whitespace and make sure it's not newline. + while (isspace((c = is.peek())) && c != static_cast<int>('\n')) is.get(); + if (is.peek() == '\n') { + KALDI_WARN << "Found newline but expected basic type."; + return false; // This is just to catch a more- + // likely-than average type of error (empty line before the token), since + // ReadBasicType will eat it up. + } + } + + ReadBasicType(is, is_binary, &t_); + + if (!is_binary) { // This is to catch errors, the class would work without it.. + // make sure there is a newline. + while (isspace((c = is.peek())) && c != static_cast<int>('\n')) is.get(); + if (is.peek() != '\n') { + KALDI_WARN << "BasicHolder::Read, expected newline, got " + << CharToString(is.peek()) << ", position " << is.tellg(); + return false; + } + is.get(); // Consume the newline. + } + return true; + } catch (std::exception &e) { + KALDI_WARN << "Exception caught reading Table object"; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; + } + } + + // Objects read/written with the Kaldi I/O functions always have the stream + // open in binary mode for reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { + return t_; + } + + ~BasicHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicHolder); + + T t_; +}; + + +/// A Holder for a vector of basic types, e.g. +/// std::vector<int32>, std::vector<float>, and so on. +/// Note: a basic type is defined as a type for which ReadBasicType +/// and WriteBasicType are implemented, i.e. integer and floating +/// types, and bool. +template<class BasicType> class BasicVectorHolder { + public: + typedef std::vector<BasicType> T; + + BasicVectorHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + if (binary) { // need to write the size, in binary mode. + KALDI_ASSERT(static_cast<size_t>(static_cast<int32>(t.size())) == t.size()); + // Or this Write routine cannot handle such a large vector. + // use int32 because it's fixed size regardless of compilation. + // change to int64 (plus in Read function) if this becomes a problem. + WriteBasicType(os, binary, static_cast<int32>(t.size())); + for (typename std::vector<BasicType>::const_iterator iter = t.begin(); + iter != t.end(); ++iter) + WriteBasicType(os, binary, *iter); + + } else { + for (typename std::vector<BasicType>::const_iterator iter = t.begin(); + iter != t.end(); ++iter) + WriteBasicType(os, binary, *iter); + os << '\n'; // Makes output format more readable and + // easier to manipulate. In text mode, this function writes something like + // "1 2 3\n". + } + return os.good(); + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object (BasicVector). "; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; // Write failure. + } + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Reading Table object [integer type], failed reading binary header\n"; + return false; + } + if (!is_binary) { + // In text mode, we terminate with newline. + std::string line; + getline(is, line); // this will discard the \n, if present. + if (is.fail()) { + KALDI_WARN << "BasicVectorHolder::Read, error reading line " << (is.eof() ? "[eof]" : ""); + return false; // probably eof. fail in any case. + } + std::istringstream line_is(line); + try { + while (1) { + line_is >> std::ws; // eat up whitespace. + if (line_is.eof()) break; + BasicType bt; + ReadBasicType(line_is, false, &bt); + t_.push_back(bt); + } + return true; + } catch(std::exception &e) { + KALDI_WARN << "BasicVectorHolder::Read, could not interpret line: " << line; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; + } + } else { // binary mode. + size_t filepos = is.tellg(); + try { + int32 size; + ReadBasicType(is, true, &size); + t_.resize(size); + for (typename std::vector<BasicType>::iterator iter = t_.begin(); + iter != t_.end(); + ++iter) { + ReadBasicType(is, true, &(*iter)); + } + return true; + } catch (...) { + KALDI_WARN << "BasicVectorHolder::Read, read error or unexpected data at archive entry beginning at file position " << filepos; + return false; + } + } + } + + // Objects read/written with the Kaldi I/O functions always have the stream + // open in binary mode for reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + ~BasicVectorHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicVectorHolder); + T t_; +}; + + +/// BasicVectorVectorHolder is a Holder for a vector of vector of +/// a basic type, e.g. std::vector<std::vector<int32> >. +/// Note: a basic type is defined as a type for which ReadBasicType +/// and WriteBasicType are implemented, i.e. integer and floating +/// types, and bool. +template<class BasicType> class BasicVectorVectorHolder { + public: + typedef std::vector<std::vector<BasicType> > T; + + BasicVectorVectorHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + if (binary) { // need to write the size, in binary mode. + KALDI_ASSERT(static_cast<size_t>(static_cast<int32>(t.size())) == t.size()); + // Or this Write routine cannot handle such a large vector. + // use int32 because it's fixed size regardless of compilation. + // change to int64 (plus in Read function) if this becomes a problem. + WriteBasicType(os, binary, static_cast<int32>(t.size())); + for (typename std::vector<std::vector<BasicType> >::const_iterator iter = t.begin(); + iter != t.end(); ++iter) { + KALDI_ASSERT(static_cast<size_t>(static_cast<int32>(iter->size())) == iter->size()); + WriteBasicType(os, binary, static_cast<int32>(iter->size())); + for (typename std::vector<BasicType>::const_iterator iter2=iter->begin(); + iter2 != iter->end(); ++iter2) { + WriteBasicType(os, binary, *iter2); + } + } + } else { // text mode... + // In text mode, we write out something like (for integers): + // "1 2 3 ; 4 5 ; 6 ; ; 7 8 9 ;\n" + // where the semicolon is a terminator, not a separator + // (a separator would cause ambiguity between an + // empty list, and a list containing a single empty list). + for (typename std::vector<std::vector<BasicType> >::const_iterator iter = t.begin(); + iter != t.end(); + ++iter) { + for (typename std::vector<BasicType>::const_iterator iter2=iter->begin(); + iter2 != iter->end(); ++iter2) + WriteBasicType(os, binary, *iter2); + os << "; "; + } + os << '\n'; + } + return os.good(); + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object. "; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; // Write failure. + } + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Failed reading binary header\n"; + return false; + } + if (!is_binary) { + // In text mode, we terminate with newline. + try { // catching errors from ReadBasicType.. + std::vector<BasicType> v; // temporary vector + while (1) { + int i = is.peek(); + if (i == -1) { + KALDI_WARN << "Unexpected EOF"; + return false; + } else if (static_cast<char>(i) == '\n') { + if (!v.empty()) { + KALDI_WARN << "No semicolon before newline (wrong format)"; + return false; + } else { is.get(); return true; } + } else if (std::isspace(i)) { + is.get(); + } else if (static_cast<char>(i) == ';') { + t_.push_back(v); + v.clear(); + is.get(); + } else { // some object we want to read... + BasicType b; + ReadBasicType(is, false, &b); // throws on error. + v.push_back(b); + } + } + } catch(std::exception &e) { + KALDI_WARN << "BasicVectorVectorHolder::Read, read error"; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; + } + } else { // binary mode. + size_t filepos = is.tellg(); + try { + int32 size; + ReadBasicType(is, true, &size); + t_.resize(size); + for (typename std::vector<std::vector<BasicType> >::iterator iter = t_.begin(); + iter != t_.end(); + ++iter) { + int32 size2; + ReadBasicType(is, true, &size2); + iter->resize(size2); + for (typename std::vector<BasicType>::iterator iter2 = iter->begin(); + iter2 != iter->end(); + ++iter2) + ReadBasicType(is, true, &(*iter2)); + } + return true; + } catch (...) { + KALDI_WARN << "Read error or unexpected data at archive entry beginning at file position " << filepos; + return false; + } + } + } + + // Objects read/written with the Kaldi I/O functions always have the stream + // open in binary mode for reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + ~BasicVectorVectorHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicVectorVectorHolder); + T t_; +}; + + +/// BasicPairVectorHolder is a Holder for a vector of pairs of +/// a basic type, e.g. std::vector<std::pair<int32> >. +/// Note: a basic type is defined as a type for which ReadBasicType +/// and WriteBasicType are implemented, i.e. integer and floating +/// types, and bool. +template<class BasicType> class BasicPairVectorHolder { + public: + typedef std::vector<std::pair<BasicType, BasicType> > T; + + BasicPairVectorHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + if (binary) { // need to write the size, in binary mode. + KALDI_ASSERT(static_cast<size_t>(static_cast<int32>(t.size())) == t.size()); + // Or this Write routine cannot handle such a large vector. + // use int32 because it's fixed size regardless of compilation. + // change to int64 (plus in Read function) if this becomes a problem. + WriteBasicType(os, binary, static_cast<int32>(t.size())); + for (typename T::const_iterator iter = t.begin(); + iter != t.end(); ++iter) { + WriteBasicType(os, binary, iter->first); + WriteBasicType(os, binary, iter->second); + } + } else { // text mode... + // In text mode, we write out something like (for integers): + // "1 2 ; 4 5 ; 6 7 ; 8 9 \n" + // where the semicolon is a separator, not a terminator. + for (typename T::const_iterator iter = t.begin(); + iter != t.end();) { + WriteBasicType(os, binary, iter->first); + WriteBasicType(os, binary, iter->second); + ++iter; + if (iter != t.end()) + os << "; "; + } + os << '\n'; + } + return os.good(); + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object. "; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; // Write failure. + } + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Reading Table object [integer type], failed reading binary header\n"; + return false; + } + if (!is_binary) { + // In text mode, we terminate with newline. + try { // catching errors from ReadBasicType.. + std::vector<BasicType> v; // temporary vector + while (1) { + int i = is.peek(); + if (i == -1) { + KALDI_WARN << "Unexpected EOF"; + return false; + } else if (static_cast<char>(i) == '\n') { + if (t_.empty() && v.empty()) { + is.get(); + return true; + } else if (v.size() == 2) { + t_.push_back(std::make_pair(v[0], v[1])); + is.get(); + return true; + } else { + KALDI_WARN << "Unexpected newline, reading vector<pair<?> >; got " + << v.size() << " elements, expected 2."; + return false; + } + } else if (std::isspace(i)) { + is.get(); + } else if (static_cast<char>(i) == ';') { + if (v.size() != 2) { + KALDI_WARN << "Wrong input format, reading vector<pair<?> >; got " + << v.size() << " elements, expected 2."; + return false; + } + t_.push_back(std::make_pair(v[0], v[1])); + v.clear(); + is.get(); + } else { // some object we want to read... + BasicType b; + ReadBasicType(is, false, &b); // throws on error. + v.push_back(b); + } + } + } catch(std::exception &e) { + KALDI_WARN << "BasicPairVectorHolder::Read, read error"; + if (!IsKaldiError(e.what())) { std::cerr << e.what(); } + return false; + } + } else { // binary mode. + size_t filepos = is.tellg(); + try { + int32 size; + ReadBasicType(is, true, &size); + t_.resize(size); + for (typename T::iterator iter = t_.begin(); + iter != t_.end(); + ++iter) { + ReadBasicType(is, true, &(iter->first)); + ReadBasicType(is, true, &(iter->second)); + } + return true; + } catch (...) { + KALDI_WARN << "BasicVectorHolder::Read, read error or unexpected data at archive entry beginning at file position " << filepos; + return false; + } + } + } + + // Objects read/written with the Kaldi I/O functions always have the stream + // open in binary mode for reading. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + ~BasicPairVectorHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicPairVectorHolder); + T t_; +}; + + + + +// We define a Token as a nonempty, printable, whitespace-free std::string. +// The binary and text formats here are the same (newline-terminated) +// and as such we don't bother with the binary-mode headers. +class TokenHolder { + public: + typedef std::string T; + + TokenHolder() {} + + static bool Write(std::ostream &os, bool, const T &t) { // ignore binary-mode. + KALDI_ASSERT(IsToken(t)); + os << t << '\n'; + return os.good(); + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + is >> t_; + if (is.fail()) return false; + char c; + while (isspace(c = is.peek()) && c!= '\n') is.get(); + if (is.peek() != '\n') { + KALDI_ERR << "TokenHolder::Read, expected newline, got char " << CharToString(is.peek()) + << ", at stream pos " << is.tellg(); + return false; + } + is.get(); // get '\n' + return true; + } + + + // Since this is fundamentally a text format, read in text mode (would work + // fine either way, but doing it this way will exercise more of the code). + static bool IsReadInBinary() { return false; } + + const T &Value() const { return t_; } + + ~TokenHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(TokenHolder); + T t_; +}; + +// A Token is a nonempty, whitespace-free std::string. +// Class TokenVectorHolder is a Holder class for vectors of these. +class TokenVectorHolder { + public: + typedef std::vector<std::string> T; + + TokenVectorHolder() { } + + static bool Write(std::ostream &os, bool, const T &t) { // ignore binary-mode. + for (std::vector<std::string>::const_iterator iter = t.begin(); + iter != t.end(); + ++iter) { + KALDI_ASSERT(IsToken(*iter)); // make sure it's whitespace-free, printable and nonempty. + os << *iter << ' '; + } + os << '\n'; + return os.good(); + } + + void Clear() { t_.clear(); } + + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + + // there is no binary/non-binary mode. + + std::string line; + getline(is, line); // this will discard the \n, if present. + if (is.fail()) { + KALDI_WARN << "BasicVectorHolder::Read, error reading line " << (is.eof() ? "[eof]" : ""); + return false; // probably eof. fail in any case. + } + const char *white_chars = " \t\n\r\f\v"; + SplitStringToVector(line, white_chars, true, &t_); // true== omit empty strings e.g. + // between spaces. + return true; + } + + // Read in text format since it's basically a text-mode thing.. doesn't really matter, + // it would work either way since we ignore the extra '\r'. + static bool IsReadInBinary() { return false; } + + const T &Value() const { return t_; } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(TokenVectorHolder); + T t_; +}; + + +class HtkMatrixHolder { + public: + typedef std::pair<Matrix<BaseFloat>, HtkHeader> T; + + HtkMatrixHolder() {} + + static bool Write(std::ostream &os, bool binary, const T &t) { + if (!binary) + KALDI_ERR << "Non-binary HTK-format write not supported."; + bool ans = WriteHtk(os, t.first, t.second); + if (!ans) + KALDI_WARN << "Error detected writing HTK-format matrix."; + return ans; + } + + void Clear() { t_.first.Resize(0, 0); } + + // Reads into the holder. + bool Read(std::istream &is) { + bool ans = ReadHtk(is, &t_.first, &t_.second); + if (!ans) { + KALDI_WARN << "Error detected reading HTK-format matrix."; + return false; + } + return ans; + } + + // HTK-format matrices only read in binary. + static bool IsReadInBinary() { return true; } + + const T &Value() const { return t_; } + + + // No destructor. + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(HtkMatrixHolder); + T t_; +}; + +// SphinxMatrixHolder can be used to read and write feature files in +// CMU Sphinx format. 13-dimensional big-endian features are assumed. +// The ultimate reference is SphinxBase's source code (for example see +// feat_s2mfc_read() in src/libsphinxbase/feat/feat.c). +// We can't fully automate the detection of machine/feature file endianess +// mismatch here, because for this Sphinx relies on comparing the feature +// file's size with the number recorded in its header. We are working with +// streams, however(what happens if this is a Kaldi archive?). This should +// be no problem, because the usage help of Sphinx' "wave2feat" for example +// says that Sphinx features are always big endian. +// Note: the kFeatDim defaults to 13, see forward declaration in kaldi-holder.h +template<int kFeatDim> class SphinxMatrixHolder { + public: + typedef Matrix<BaseFloat> T; + + SphinxMatrixHolder() {} + + void Clear() { feats_.Resize(0, 0); } + + // Writes Sphinx-format features + static bool Write(std::ostream &os, bool binary, const T &m) { + if (!binary) { + KALDI_WARN << "SphinxMatrixHolder can't write Sphinx features in text "; + return false; + } + + int32 size = m.NumRows() * m.NumCols(); + if (MachineIsLittleEndian()) + KALDI_SWAP4(size); + os.write((char*) &size, sizeof(size)); // write the header + + for (MatrixIndexT i = 0; i < m.NumRows(); i++) { + float32 tmp[m.NumCols()]; + for (MatrixIndexT j = 0; j < m.NumCols(); j++) { + tmp[j] = static_cast<float32>(m(i, j)); + if (MachineIsLittleEndian()) + KALDI_SWAP4(tmp[j]); + } + os.write((char*) tmp, sizeof(tmp)); + } + + return true; + } + + // Reads the features into a Kaldi Matrix + bool Read(std::istream &is) { + int32 nmfcc; + + is.read((char*) &nmfcc, sizeof(nmfcc)); + if (MachineIsLittleEndian()) + KALDI_SWAP4(nmfcc); + KALDI_VLOG(2) << "#feats: " << nmfcc; + int32 nfvec = nmfcc / kFeatDim; + if ((nmfcc % kFeatDim) != 0) { + KALDI_WARN << "Sphinx feature count is inconsistent with vector length "; + return false; + } + + feats_.Resize(nfvec, kFeatDim); + for (MatrixIndexT i = 0; i < feats_.NumRows(); i++) { + if (sizeof(BaseFloat) == sizeof(float32)) { + is.read((char*) feats_.RowData(i), kFeatDim * sizeof(float32)); + if (!is.good()) { + KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + return false; + } + if (MachineIsLittleEndian()) { + for (MatrixIndexT j=0; j < kFeatDim; j++) + KALDI_SWAP4(feats_(i, j)); + } + } else { // KALDI_DOUBLEPRECISION=1 + float32 tmp[kFeatDim]; + is.read((char*) tmp, sizeof(tmp)); + if (!is.good()) { + KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + return false; + } + for (MatrixIndexT j=0; j < kFeatDim; j++) { + if (MachineIsLittleEndian()) + KALDI_SWAP4(tmp[j]); + feats_(i, j) = static_cast<BaseFloat>(tmp[j]); + } + } + } + + return true; + } + + // Only read in binary + static bool IsReadInBinary() { return true; } + + const T &Value() const { return feats_; } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(SphinxMatrixHolder); + T feats_; +}; + + +/// @} end "addtogroup holders" + +} // end namespace kaldi + + + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-holder.h b/kaldi_io/src/kaldi/util/kaldi-holder.h new file mode 100644 index 0000000..95f1183 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-holder.h @@ -0,0 +1,207 @@ +// util/kaldi-holder.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_KALDI_HOLDER_H_ +#define KALDI_UTIL_KALDI_HOLDER_H_ + +#include <algorithm> +#include "util/kaldi-io.h" +#include "util/text-utils.h" +#include "matrix/kaldi-vector.h" + +namespace kaldi { + + +// The Table class uses a Holder class to wrap objects, and make them behave +// in a "normalized" way w.r.t. reading and writing, so the Table class can +// be template-ized without too much trouble. Look below this +// comment (search for GenericHolder) to see what it looks like. +// +// Requirements of the holder class: +// +// They can only contain objects that can be read/written without external +// information; other objects cannot be stored in this type of archive. +// +// In terms of what functions it should have, see GenericHolder below. +// It is just for documentation. +// +// (1) Requirements of the Read and Write functions +// +// The Read and Write functions should have the property that in a longer +// file, if the Read function is started from where the Write function started +// writing, it should go to where the Write function stopped writing, in either +// text or binary mode (but it's OK if it doesn't eat up trailing space). +// +// [Desirable property: when writing in text mode the output should contain +// exactly one newline, at the end of the output; this makes it easier to manipulate] +// +// [Desirable property for classes: the output should just be a binary-mode +// header (if in binary mode and it's a Kaldi object, or no header +// othewise), and then the output of Object.Write(). This means that when +// written to individual files with the scp: type of wspecifier, we can read +// the individual files in the "normal" Kaldi way by reading the binary +// header and then the object.] +// +// +// The Write function takes a 'binary' argument. In general, each object will +// have two formats: text and binary. However, it's permitted to throw() if +// asked to read in the text format if there is none. The file will be open, if +// the file system has binary/text modes, in the corresponding mode. However, +// the object should have a file-mode in which it can read either text or binary +// output. It announces this via the static IsReadInBinary() function. This +// will generally be the binary mode and it means that where necessary, in text +// formats, we must ignore \r characters. +// +// Memory requirements: if it allocates memory, the destructor should +// free that memory. Copying and assignment of Holder objects may be +// disallowed as the Table code never does this. + + +/// GenericHolder serves to document the requirements of the Holder interface; +/// it's not intended to be used. +template<class SomeType> class GenericHolder { + public: + typedef SomeType T; + + /// Must have a constructor that takes no arguments. + GenericHolder() { } + + /// Write writes this object of type T. Possibly also writes a binary-mode + /// header so that the Read function knows which mode to read in (since the + /// Read function does not get this information). It's a static member so we + /// can write those not inside this class (can use this function with Value() + /// to write from this class). The Write method may throw if it cannot write + /// the object in the given (binary/non-binary) mode. The holder object can + /// assume the stream has been opened in the given mode (where relevant). The + /// object can write the data how it likes. + static bool Write(std::ostream &os, bool binary, const T &t); + + /// Reads into the holder. Must work out from the stream (which will be opened + /// on Windows in binary mode if the IsReadInBinary() function of this class + /// returns true, and text mode otherwise) whether the actual data is binary or + /// not (usually via reading the Kaldi binary-mode header). We put the + /// responsibility for reading the Kaldi binary-mode header in the Read + /// function (rather than making the binary mode an argument to this function), + /// so that for non-Kaldi binary files we don't have to write the header, which + /// would prevent the file being read by non-Kaldi programs (e.g. if we write + /// to individual files using an scp). + /// + /// Read must deallocate any existing data we have here, if applicable (must + /// not assume the object was newly constructed). + /// + /// Returns true on success. + bool Read(std::istream &is); + + /// IsReadInBinary() will return true if the object wants the file to be + /// opened in binary for reading (if the file system has binary/text modes), + /// and false otherwise. Static function. Kaldi objects always return true + /// as they always read in binary mode. Note that we must be able to read, in + /// this mode, objects written in both text and binary mode by Write (which + /// may mean ignoring "\r" characters). I doubt we will ever want this + /// function to return false. + static bool IsReadInBinary() { return true; } + + /// Returns the value of the object held here. Will only + /// ever be called if Read() has been previously called and it returned + /// true (so OK to throw exception if no object was read). + const T &Value() const { return t_; } // if t is a pointer, would return *t_; + + /// The Clear() function doesn't have to do anything. Its purpose is to + /// allow the object to free resources if they're no longer needed. + void Clear() { } + + /// If the object held pointers, the destructor would free them. + ~GenericHolder() { } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(GenericHolder); + T t_; // t_ may alternatively be of type T*. +}; + + +// See kaldi-holder-inl.h for examples of some actual Holder +// classes and templates. + + +// The following two typedefs should probably be in their own file, but they're +// here until there are enough of them to warrant their own header. + + +/// \addtogroup holders +/// @{ + +/// KaldiObjectHolder works for Kaldi objects that have the "standard" Read and Write +/// functions, and a copy constructor. +template<class KaldiType> class KaldiObjectHolder; + +/// BasicHolder is valid for float, double, bool, and integer +/// types. There will be a compile time error otherwise, because +/// we make sure that the {Write, Read}BasicType functions do not +/// get instantiated for other types. +template<class BasicType> class BasicHolder; + + +// A Holder for a vector of basic types, e.g. +// std::vector<int32>, std::vector<float>, and so on. +// Note: a basic type is defined as a type for which ReadBasicType +// and WriteBasicType are implemented, i.e. integer and floating +// types, and bool. +template<class BasicType> class BasicVectorHolder; + + +// A holder for vectors of vectors of basic types, e.g. +// std::vector<std::vector<int32> >, and so on. +// Note: a basic type is defined as a type for which ReadBasicType +// and WriteBasicType are implemented, i.e. integer and floating +// types, and bool. +template<class BasicType> class BasicVectorVectorHolder; + +// A holder for vectors of pairsof basic types, e.g. +// std::vector<std::vector<int32> >, and so on. +// Note: a basic type is defined as a type for which ReadBasicType +// and WriteBasicType are implemented, i.e. integer and floating +// types, and bool. Text format is (e.g. for integers), +// "1 12 ; 43 61 ; 17 8 \n" +template<class BasicType> class BasicPairVectorHolder; + +/// We define a Token (not a typedef, just a word) as a nonempty, printable, +/// whitespace-free std::string. The binary and text formats here are the same +/// (newline-terminated) and as such we don't bother with the binary-mode headers. +class TokenHolder; + +/// Class TokenVectorHolder is a Holder class for vectors of Tokens (T == std::string). +class TokenVectorHolder; + +/// A class for reading/writing HTK-format matrices. +/// T == std::pair<Matrix<BaseFloat>, HtkHeader> +class HtkMatrixHolder; + +/// A class for reading/writing Sphinx format matrices. +template<int kFeatDim=13> class SphinxMatrixHolder; + + +/// @} end "addtogroup holders" + + +} // end namespace kaldi + +#include "kaldi-holder-inl.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-io-inl.h b/kaldi_io/src/kaldi/util/kaldi-io-inl.h new file mode 100644 index 0000000..7df7505 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-io-inl.h @@ -0,0 +1,45 @@ +// util/kaldi-io-inl.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_KALDI_IO_INL_H_ +#define KALDI_UTIL_KALDI_IO_INL_H_ + + +namespace kaldi { + +bool Input::Open(const std::string &rxfilename, bool *binary) { + return OpenInternal(rxfilename, true, binary); +} + +bool Input::OpenTextMode(const std::string &rxfilename) { + return OpenInternal(rxfilename, false, NULL); +} + +bool Input::IsOpen() { + return impl_ != NULL; +} + +bool Output::IsOpen() { + return impl_ != NULL; +} + + +} // end namespace kaldi. + + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-io.h b/kaldi_io/src/kaldi/util/kaldi-io.h new file mode 100644 index 0000000..f2c7563 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-io.h @@ -0,0 +1,264 @@ +// util/kaldi-io.h + +// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_KALDI_IO_H_ +#define KALDI_UTIL_KALDI_IO_H_ + +#include <cctype> // For isspace. +#include <limits> +#include <string> +#include "base/kaldi-common.h" +#ifdef _MSC_VER +# include <fcntl.h> +# include <io.h> +#endif + + + +namespace kaldi { + +class OutputImplBase; // Forward decl; defined in a .cc file +class InputImplBase; // Forward decl; defined in a .cc file + +/// \addtogroup io_group +/// @{ + +// The Output and Input classes handle stream-opening for "extended" filenames +// that include actual files, standard-input/standard-output, pipes, and +// offsets into actual files. They also handle reading and writing the +// binary-mode headers for Kaldi files, where applicable. The classes have +// versions of the Open routines that throw and do not throw, depending whether +// the calling code wants to catch the errors or not; there are also versions +// that write (or do not write) the Kaldi binary-mode header that says if it's +// binary mode. Generally files that contain Kaldi objects will have the header +// on, so we know upon reading them whether they have the header. So you would +// use the OpenWithHeader routines for these (or the constructor); but other +// types of objects (e.g. FSTs) would have files without a header so you would +// use OpenNoHeader. + +// We now document the types of extended filenames that we use. +// +// A "wxfilename" is an extended filename for writing. It can take three forms: +// (1) Filename: e.g. "/some/filename", "./a/b/c", "c:\Users\dpovey\My Documents\\boo" +// (whatever the actual file-system interprets) +// (2) Standard output: "" or "-" +// (3) A pipe: e.g. "gunzip -c /tmp/abc.gz |" +// +// +// A "rxfilename" is an extended filename for reading. It can take four forms: +// (1) An actual filename, whatever the file-system can read, e.g. "/my/file". +// (2) Standard input: "" or "-" +// (3) A pipe: e.g. "| gzip -c > /tmp/abc.gz" +// (4) An offset into a file, e.g.: "/mnt/blah/data/1.ark:24871" +// [these are created by the Table and TableWriter classes; I may also write +// a program that creates them for arbitrary files] +// + + +// Typical usage: +// ... +// bool binary; +// MyObject.Write(Output(some_filename, binary).Stream(), binary); +// +// ... more extensive example: +// { +// Output ko(some_filename, binary); +// MyObject1.Write(ko.Stream(), binary); +// MyObject2.Write(ko.Stream(), binary); +// } + + + +enum OutputType { + kNoOutput, + kFileOutput, + kStandardOutput, + kPipeOutput +}; + +/// ClassifyWxfilename interprets filenames as follows: +/// - kNoOutput: invalid filenames (leading or trailing space, things that look +/// like wspecifiers and rspecifiers or like pipes to read from with leading |. +/// - kFileOutput: Normal filenames +/// - kStandardOutput: The empty string or "-", interpreted as standard output +/// - kPipeOutput: pipes, e.g. "gunzip -c some_file.gz |" +OutputType ClassifyWxfilename(const std::string &wxfilename); + +enum InputType { + kNoInput, + kFileInput, + kStandardInput, + kOffsetFileInput, + kPipeInput +}; + +/// ClassifyRxfilenames interprets filenames for reading as follows: +/// - kNoInput: invalid filenames (leading or trailing space, things that +/// look like wspecifiers and rspecifiers or pipes to write to +/// with trailing |. +/// - kFileInput: normal filenames +/// - kStandardInput: the empty string or "-" +/// - kPipeInput: e.g. "| gzip -c > blah.gz" +/// - kOffsetFileInput: offsets into files, e.g. /some/filename:12970 +InputType ClassifyRxfilename(const std::string &rxfilename); + + +class Output { + public: + // The normal constructor, provided for convenience. + // Equivalent to calling with default constructor then Open() + // with these arguments. + Output(const std::string &filename, bool binary, bool write_header = true); + + Output(): impl_(NULL) {}; + + /// This opens the stream, with the given mode (binary or text). It returns + /// true on success and false on failure. However, it will throw if something + /// was already open and could not be closed (to avoid this, call Close() + /// first. if write_header == true and binary == true, it writes the Kaldi + /// binary-mode header ('\0' then 'B'). You may call Open even if it is + /// already open; it will close the existing stream and reopen (however if + /// closing the old stream failed it will throw). + bool Open(const std::string &wxfilename, bool binary, bool write_header); + + inline bool IsOpen(); // return true if we have an open stream. Does not imply + // stream is good for writing. + + std::ostream &Stream(); // will throw if not open; else returns stream. + + // Close closes the stream. Calling Close is never necessary unless you + // want to avoid exceptions being thrown. There are times when calling + // Close will hurt efficiency (basically, when using offsets into files, + // and using the same Input object), + // but most of the time the user won't be doing this directly, it will + // be done in kaldi-table.{h, cc}, so you don't have to worry about it. + bool Close(); + + // This will throw if stream could not be closed (to check error status, + // call Close()). + ~Output(); + + private: + OutputImplBase *impl_; // non-NULL if open. + std::string filename_; + KALDI_DISALLOW_COPY_AND_ASSIGN(Output); +}; + + +// bool binary_in; +// Input ki(some_filename, &binary_in); +// MyObject.Read(ki, binary_in); +// +// ... more extensive example: +// +// { +// bool binary_in; +// Input ki(some_filename, &binary_in); +// MyObject1.Read(ki.Stream(), &binary_in); +// MyObject2.Write(ki.Stream(), &binary_in); +// } +// Note that to catch errors you need to use try.. catch. +// Input communicates errors by throwing exceptions. + + +// Input interprets four kinds of filenames: +// (1) Normal filenames +// (2) The empty string or "-", interpreted as standard output +// (3) Pipes, e.g. "| gzip -c > some_file.gz" +// (4) Offsets into [real] files, e.g. "/my/filename:12049" +// The last one has no correspondence in Output. + + +class Input { + public: + /// The normal constructor. Opens the stream in binary mode. + /// Equivalent to calling the default constructor followed by Open(); then, if + /// binary != NULL, it calls ReadHeader(), putting the output in "binary"; it + /// throws on error. + Input(const std::string &rxfilename, bool *contents_binary = NULL); + + Input(): impl_(NULL) {} + + // Open opens the stream for reading (the mode, where relevant, is binary; use + // OpenTextMode for text-mode, we made this a separate function rather than a + // boolean argument, to avoid confusion with Kaldi's text/binary distinction, + // since reading in the file system's text mode is unusual.) If + // contents_binary != NULL, it reads the binary-mode header and puts it in the + // "binary" variable. Returns true on success. If it returns false it will + // not be open. You may call Open even if it is already open; it will close + // the existing stream and reopen (however if closing the old stream failed it + // will throw). + inline bool Open(const std::string &rxfilename, bool *contents_binary = NULL); + + // As Open but (if the file system has text/binary modes) opens in text mode; + // you shouldn't ever have to use this as in Kaldi we read even text files in + // binary mode (and ignore the \r). + inline bool OpenTextMode(const std::string &rxfilename); + + // Return true if currently open for reading and Stream() will + // succeed. Does not guarantee that the stream is good. + inline bool IsOpen(); + + // It is never necessary or helpful to call Close, except if + // you are concerned about to many filehandles being open. + // Close does not throw. + void Close(); + + // Returns the underlying stream. Throws if !IsOpen() + std::istream &Stream(); + + // Destructor does not throw: input streams may legitimately fail so we + // don't worry about the status when we close them. + ~Input(); + private: + bool OpenInternal(const std::string &rxfilename, bool file_binary, bool *contents_binary); + InputImplBase *impl_; + KALDI_DISALLOW_COPY_AND_ASSIGN(Input); +}; + +template <class C> inline void ReadKaldiObject(const std::string &filename, + C *c) { + bool binary_in; + Input ki(filename, &binary_in); + c->Read(ki.Stream(), binary_in); +} + +template <class C> inline void WriteKaldiObject(const C &c, + const std::string &filename, + bool binary) { + Output ko(filename, binary); + c.Write(ko.Stream(), binary); +} + +/// PrintableRxfilename turns the rxfilename into a more human-readable +/// form for error reporting, i.e. it does quoting and escaping and +/// replaces "" or "-" with "standard input". +std::string PrintableRxfilename(std::string rxfilename); + +/// PrintableWxfilename turns the filename into a more human-readable +/// form for error reporting, i.e. it does quoting and escaping and +/// replaces "" or "-" with "standard output". +std::string PrintableWxfilename(std::string wxfilename); + +/// @} + +} // end namespace kaldi. + +#include "kaldi-io-inl.h" + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-pipebuf.h b/kaldi_io/src/kaldi/util/kaldi-pipebuf.h new file mode 100644 index 0000000..43e5a2e --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-pipebuf.h @@ -0,0 +1,90 @@ +// util/kaldi-pipebuf.h + +// Copyright 2009-2011 Ondrej Glembek + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +/** @file kaldi-pipebuf.h + * This is an Kaldi C++ Library header. + */ + +#ifndef KALDI_UTIL_KALDI_PIPEBUF_H_ +#define KALDI_UTIL_KALDI_PIPEBUF_H_ + +#if defined(_LIBCPP_VERSION) // libc++ +#include "basic-filebuf.h" +#else +#include <fstream> +#endif + +namespace kaldi +{ +// This class provides a way to initialize a filebuf with a FILE* pointer +// directly; it will not close the file pointer when it is deleted. +// The C++ standard does not allow implementations of C++ to provide +// this constructor within basic_filebuf, which makes it hard to deal +// with pipes using completely native C++. This is a workaround + +#ifdef _MSC_VER +#elif defined(_LIBCPP_VERSION) // libc++ +template<class CharType, class Traits = std::char_traits<CharType> > +class basic_pipebuf : public basic_filebuf<CharType, Traits> +{ + public: + typedef basic_pipebuf<CharType, Traits> ThisType; + + public: + basic_pipebuf(FILE *fptr, std::ios_base::openmode mode) + : basic_filebuf<CharType, Traits>() { + this->open(fptr, mode); + if (!this->is_open()) { + KALDI_WARN << "Error initializing pipebuf"; // probably indicates + // code error, if the fptr was good. + return; + } + } +}; // class basic_pipebuf +#else +template<class CharType, class Traits = std::char_traits<CharType> > +class basic_pipebuf : public std::basic_filebuf<CharType, Traits> +{ + public: + typedef basic_pipebuf<CharType, Traits> ThisType; + + public: + basic_pipebuf(FILE *fptr, std::ios_base::openmode mode) + : std::basic_filebuf<CharType, Traits>() { + this->_M_file.sys_open(fptr, mode); + if (!this->is_open()) { + KALDI_WARN << "Error initializing pipebuf"; // probably indicates + // code error, if the fptr was good. + return; + } + this->_M_mode = mode; + this->_M_buf_size = BUFSIZ; + this->_M_allocate_internal_buffer(); + this->_M_reading = false; + this->_M_writing = false; + this->_M_set_buffer(-1); + } +}; // class basic_pipebuf +#endif // _MSC_VER + +}; // namespace kaldi + +#endif // KALDI_UTIL_KALDI_PIPEBUF_H_ + diff --git a/kaldi_io/src/kaldi/util/kaldi-table-inl.h b/kaldi_io/src/kaldi/util/kaldi-table-inl.h new file mode 100644 index 0000000..6b73c88 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-table-inl.h @@ -0,0 +1,2246 @@ +// util/kaldi-table-inl.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_KALDI_TABLE_INL_H_ +#define KALDI_UTIL_KALDI_TABLE_INL_H_ + +#include <algorithm> +#include "util/kaldi-io.h" +#include "util/text-utils.h" +#include "util/stl-utils.h" // for StringHasher. + + +namespace kaldi { + +/// \addtogroup table_impl_types +/// @{ + +template<class Holder> class SequentialTableReaderImplBase { + public: + typedef typename Holder::T T; + // note that Open takes rxfilename not rspecifier. + virtual bool Open(const std::string &rxfilename) = 0; + virtual bool Done() const = 0; + virtual bool IsOpen() const = 0; + virtual std::string Key() = 0; + virtual const T &Value() = 0; + virtual void FreeCurrent() = 0; + virtual void Next() = 0; + virtual bool Close() = 0; + SequentialTableReaderImplBase() { } + virtual ~SequentialTableReaderImplBase() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(SequentialTableReaderImplBase); +}; + + +// This is the implementation for SequentialTableReader +// when it's actually a script file. +template<class Holder> class SequentialTableReaderScriptImpl: + public SequentialTableReaderImplBase<Holder> { + public: + typedef typename Holder::T T; + + SequentialTableReaderScriptImpl(): state_(kUninitialized) { } + + virtual bool Open(const std::string &rspecifier) { + if (state_ != kUninitialized) + if (! Close()) // call Close() yourself to suppress this exception. + KALDI_ERR << "TableReader::Open, error closing previous input: " + << "rspecifier was " << rspecifier_; + bool binary; + rspecifier_ = rspecifier; + RspecifierType rs = ClassifyRspecifier(rspecifier, &script_rxfilename_, + &opts_); + KALDI_ASSERT(rs == kScriptRspecifier); + if (!script_input_.Open(script_rxfilename_, &binary)) { // Failure on Open + KALDI_WARN << "Failed to open script file " + << PrintableRxfilename(script_rxfilename_); + state_ = kUninitialized; + return false; + } else { // Open succeeded. + if (binary) { // script file should not be binary file.. + state_ = kError; // bad script file. + script_input_.Close(); + return false; + } else { + state_ = kFileStart; + Next(); + if (state_ == kError) { + script_input_.Close(); + return false; + } + if (opts_.permissive) { // Next() will have preloaded. + KALDI_ASSERT(state_ == kLoadSucceeded || state_ == kEof); + } else { + KALDI_ASSERT(state_ == kHaveScpLine || state_ == kEof); + } + return true; // Success. + } + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kEof: case kError: case kHaveScpLine: case kLoadSucceeded: case kLoadFailed: return true; + case kUninitialized: return false; + default: KALDI_ERR << "IsOpen() called on invalid object."; // kFileStart is not valid + // state for user to call something on. + return false; + } + } + + virtual bool Done() const { + switch (state_) { + case kHaveScpLine: return false; + case kLoadSucceeded: case kLoadFailed: return false; + // These cases are because we want LoadCurrent() + // to be callable after Next() and to not change the Done() status [only Next() should change + // the Done() status]. + case kEof: case kError: return true; // Error condition, like Eof, counts as Done(); the destructor + // or Close() will inform the user of the error. + default: KALDI_ERR << "Done() called on TableReader object at the wrong time."; + return false; + } + } + + virtual std::string Key() { + // Valid to call this whenever Done() returns false. + switch (state_) { + case kHaveScpLine: case kLoadSucceeded: case kLoadFailed: break; + default: + // coding error. + KALDI_ERR << "Key() called on TableReader object at the wrong time."; + } + return key_; + } + const T &Value() { + StateType orig_state = state_; + if (state_ == kHaveScpLine) LoadCurrent(); // Takes + // state_ to kLoadSucceeded or kLoadFailed. + if (state_ == kLoadFailed) { // this can happen due to + // a file listed in an scp file not existing, or + // read failure, failure of a command, etc. + if (orig_state == kHaveScpLine) + KALDI_ERR << "TableReader: failed to load object from " + << PrintableRxfilename(data_rxfilename_) + << " (to suppress this error, add the permissive " + << "(p, ) option to the rspecifier."; + + else // orig_state_ was kLoadFailed, which only could have happened + // if the user called FreeCurrent(). + KALDI_ERR << "TableReader: you called Value() after FreeCurrent()."; + } else if (state_ != kLoadSucceeded) { + // This would be a coding error. + KALDI_ERR << "TableReader: Value() called at the wrong time."; + } + return holder_.Value(); + } + void FreeCurrent() { + if (state_ == kLoadSucceeded) { + holder_.Clear(); + state_ = kLoadFailed; + } else { + KALDI_WARN << "TableReader: FreeCurrent called at the wrong time."; + } + } + void Next() { + while (1) { + NextScpLine(); + if (Done()) return; + if (opts_.permissive) { + // Permissive mode means, when reading scp files, we treat keys whose scp entry + // cannot be read as nonexistent. This means trying to read. + if (LoadCurrent()) return; // Success. + // else try the next scp line. + } else { + return; // We go the next key; Value() will crash if we can't + // read the scp line. + } + } + } + + virtual bool Close() { + // Close() will succeed if the stream was not in an error + // state. To clean up, it also closes the Input objects if + // they're open. + if (script_input_.IsOpen()) + script_input_.Close(); + if (data_input_.IsOpen()) + data_input_.Close(); + if (state_ == kLoadSucceeded) + holder_.Clear(); + if (!this->IsOpen()) + KALDI_ERR << "Close() called on input that was not open."; + StateType old_state = state_; + state_ = kUninitialized; + if (old_state == kError) { + if (opts_.permissive) { + KALDI_WARN << "Close() called on scp file with read error, ignoring the " + "error because permissive mode specified."; + return true; + } else return false; // User will do something with the error status. + } else return true; + } + + virtual ~SequentialTableReaderScriptImpl() { + if (state_ == kError) + KALDI_ERR << "TableReader: reading script file failed: from scp " + << PrintableRxfilename(script_rxfilename_); + // If you don't want this exception to be thrown you can + // call Close() and check the status. + if (state_ == kLoadSucceeded) + holder_.Clear(); + } + private: + bool LoadCurrent() { + // Attempts to load object whose rxfilename is on the current scp line. + if (state_ != kHaveScpLine) + KALDI_ERR << "TableReader: LoadCurrent() called at the wrong time."; + bool ans; + // note, NULL means it doesn't read the binary-mode header + if (Holder::IsReadInBinary()) ans = data_input_.Open(data_rxfilename_, NULL); + else ans = data_input_.OpenTextMode(data_rxfilename_); + if (!ans) { + // May want to make this warning a VLOG at some point + KALDI_WARN << "TableReader: failed to open file " + << PrintableRxfilename(data_rxfilename_); + state_ = kLoadFailed; + return false; + } else { + if (holder_.Read(data_input_.Stream())) { + state_ = kLoadSucceeded; + return true; + } else { // holder_ will not contain data. + KALDI_WARN << "TableReader: failed to load object from " + << PrintableRxfilename(data_rxfilename_); + state_ = kLoadFailed; + return false; + } + } + } + + // Reads the next line in the script file. + void NextScpLine() { + switch (state_) { + case kLoadSucceeded: holder_.Clear(); break; + case kHaveScpLine: case kLoadFailed: case kFileStart: break; + default: + // No other states are valid to call Next() from. + KALDI_ERR << "Reading script file: Next called wrongly."; + } + std::string line; + if (getline(script_input_.Stream(), line)) { + SplitStringOnFirstSpace(line, &key_, &data_rxfilename_); + if (!key_.empty() && !data_rxfilename_.empty()) { + // Got a valid line. + state_ = kHaveScpLine; + } else { + // Got an invalid line. + state_ = kError; // we can't make sense of this + // scp file and will now die. + } + } else { + state_ = kEof; // nothing more in the scp file. + // Might as well close the input streams as don't need them. + script_input_.Close(); + if (data_input_.IsOpen()) + data_input_.Close(); + } + } + + + Input script_input_; // Input object for the .scp file + Input data_input_; // Input object for the entries in + // the script file. + Holder holder_; // Holds the object. + bool binary_; // Binary-mode archive. + std::string key_; + std::string rspecifier_; + std::string script_rxfilename_; // of the script file. + RspecifierOptions opts_; // options. + std::string data_rxfilename_; // of the file we're reading. + enum StateType { + // [The state of the reading process] [does holder_ [is script_inp_ + // have object] open] + kUninitialized, // Uninitialized or closed. no no + kEof, // We did Next() and found eof in script file. no no + kError, // Some other error no yes + kHaveScpLine, // Just called Open() or Next() and have a no yes + // line of the script file but no data. + kLoadSucceeded, // Called LoadCurrent() and it succeeded. yes yes + kLoadFailed, // Called LoadCurrent() and it failed, no yes + // or the user called FreeCurrent().. note, + // if when called by user we are in this state, + // it means the user called FreeCurrent(). + kFileStart, // [state we only use internally] no yes + } state_; + private: +}; + + +// This is the implementation for SequentialTableReader +// when it's an archive. Note that the archive format is: +// key1 [space] object1 key2 [space] +// object2 ... eof. +// "object1" is the output of the Holder::Write function and will +// typically contain a binary header (in binary mode) and then +// the output of object.Write(os, binary). +// The archive itself does not care whether it is in binary +// or text mode, for reading purposes. + +template<class Holder> class SequentialTableReaderArchiveImpl: + public SequentialTableReaderImplBase<Holder> { + public: + typedef typename Holder::T T; + + SequentialTableReaderArchiveImpl(): state_(kUninitialized) { } + + virtual bool Open(const std::string &rspecifier) { + if (state_ != kUninitialized) { + if (! Close()) { // call Close() yourself to suppress this exception. + if (opts_.permissive) + KALDI_WARN << "TableReader::Open, error closing previous input " + "(only warning, since permissive mode)."; + else + KALDI_ERR << "TableReader::Open, error closing previous input."; + } + } + rspecifier_ = rspecifier; + RspecifierType rs = ClassifyRspecifier(rspecifier, + &archive_rxfilename_, + &opts_); + KALDI_ASSERT(rs == kArchiveRspecifier); + + bool ans; + // NULL means don't expect binary-mode header + if (Holder::IsReadInBinary()) + ans = input_.Open(archive_rxfilename_, NULL); + else + ans = input_.OpenTextMode(archive_rxfilename_); + if (!ans) { // header. + KALDI_WARN << "TableReader: failed to open stream " + << PrintableRxfilename(archive_rxfilename_); + state_ = kUninitialized; // Failure on Open + return false; // User should print the error message. + } + state_ = kFileStart; + Next(); + if (state_ == kError) { + KALDI_WARN << "Error beginning to read archive file (wrong filename?): " + << PrintableRxfilename(archive_rxfilename_); + input_.Close(); + state_ = kUninitialized; + return false; + } + KALDI_ASSERT(state_ == kHaveObject || state_ == kEof); + return true; + } + + virtual void Next() { + switch (state_) { + case kHaveObject: + holder_.Clear(); break; + case kFileStart: case kFreedObject: + break; + default: + KALDI_ERR << "TableReader: Next() called wrongly."; + } + std::istream &is = input_.Stream(); + is.clear(); // Clear any fail bits that may have been set... just in case + // this happened in the Read function. + is >> key_; // This eats up any leading whitespace and gets the string. + if (is.eof()) { + state_ = kEof; + return; + } + if (is.fail()) { // This shouldn't really happen, barring file-system errors. + KALDI_WARN << "Error reading archive " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + int c; + if ((c = is.peek()) != ' ' && c != '\t' && c != '\n') { // We expect a space ' ' after the key. + // We also allow tab [which is consumed] and newline [which is not], just + // so we can read archives generated by scripts that may not be fully + // aware of how this format works. + KALDI_WARN << "Invalid archive file format: expected space after key " + << key_ << ", got character " + << CharToString(static_cast<char>(is.peek())) << ", reading " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + if (c != '\n') is.get(); // Consume the space or tab. + if (holder_.Read(is)) { + state_ = kHaveObject; + return; + } else { + KALDI_WARN << "Object read failed, reading archive " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kEof: case kError: case kHaveObject: case kFreedObject: return true; + case kUninitialized: return false; + default: KALDI_ERR << "IsOpen() called on invalid object."; // kFileStart is not valid + // state for user to call something on. + return false; + } + } + + virtual bool Done() const { + switch (state_) { + case kHaveObject: + return false; + case kEof: case kError: + return true; // Error-state counts as Done(), but destructor + // will fail (unless you check the status with Close()). + default: + KALDI_ERR << "Done() called on TableReader object at the wrong time."; + return false; + } + } + + virtual std::string Key() { + // Valid to call this whenever Done() returns false + switch (state_) { + case kHaveObject: break; // only valid case. + default: + // coding error. + KALDI_ERR << "Key() called on TableReader object at the wrong time."; + } + return key_; + } + const T &Value() { + switch (state_) { + case kHaveObject: + break; // only valid case. + default: + // coding error. + KALDI_ERR << "Value() called on TableReader object at the wrong time."; + } + return holder_.Value(); + } + virtual void FreeCurrent() { + if (state_ == kHaveObject) { + holder_.Clear(); + state_ = kFreedObject; + } else + KALDI_WARN << "TableReader: FreeCurernt called at the wrong time."; + } + + virtual bool Close() { + if (! this->IsOpen()) + KALDI_ERR << "Close() called on TableReader twice or otherwise wrongly."; + if (input_.IsOpen()) + input_.Close(); + if (state_ == kHaveObject) + holder_.Clear(); + bool ans; + if (opts_.permissive) { + ans = true; // always return success. + if (state_ == kError) + KALDI_WARN << "Error detected closing TableReader for archive " + << PrintableRxfilename(archive_rxfilename_) << " but ignoring " + << "it as permissive mode specified."; + } else + ans = (state_ != kError); // If error state, user should detect it. + state_ = kUninitialized; + return ans; + } + + virtual ~SequentialTableReaderArchiveImpl() { + if (state_ == kError) { + if (opts_.permissive) + KALDI_WARN << "Error detected closing TableReader for archive " + << PrintableRxfilename(archive_rxfilename_) << " but ignoring " + << "it as permissive mode specified."; + else + KALDI_ERR << "TableReader: error detected closing archive " + << PrintableRxfilename(archive_rxfilename_); + } + // If you don't want this exception to be thrown you can + // call Close() and check the status. + if (state_ == kHaveObject) + holder_.Clear(); + } + private: + Input input_; // Input object for the archive + Holder holder_; // Holds the object. + std::string key_; + std::string rspecifier_; + std::string archive_rxfilename_; + RspecifierOptions opts_; + enum { // [The state of the reading process] [does holder_ [is input_ + // have object] open] + kUninitialized, // Uninitialized or closed. no no + kFileStart, // [state we use internally: just opened.] no yes + kEof, // We did Next() and found eof in archive no no + kError, // Some other error no no + kHaveObject, // We read the key and the object after it. yes yes + kFreedObject, // The user called FreeCurrent(). no yes + } state_; +}; + + +template<class Holder> +SequentialTableReader<Holder>::SequentialTableReader(const std::string &rspecifier): impl_(NULL) { + if (rspecifier != "" && !Open(rspecifier)) + KALDI_ERR << "Error constructing TableReader: rspecifier is " << rspecifier; +} + +template<class Holder> +bool SequentialTableReader<Holder>::Open(const std::string &rspecifier) { + if (IsOpen()) + if (!Close()) + KALDI_ERR << "Could not close previously open object."; + // now impl_ will be NULL. + + RspecifierType wt = ClassifyRspecifier(rspecifier, NULL, NULL); + switch (wt) { + case kArchiveRspecifier: + impl_ = new SequentialTableReaderArchiveImpl<Holder>(); + break; + case kScriptRspecifier: + impl_ = new SequentialTableReaderScriptImpl<Holder>(); + break; + case kNoRspecifier: default: + KALDI_WARN << "Invalid rspecifier " << rspecifier; + return false; + } + if (!impl_->Open(rspecifier)) { + delete impl_; + impl_ = NULL; + return false; // sub-object will have printed warnings. + } + else return true; +} + +template<class Holder> +bool SequentialTableReader<Holder>::Close() { + CheckImpl(); + bool ans = impl_->Close(); + delete impl_; // We don't keep around empty impl_ objects. + impl_ = NULL; + return ans; +} + + +template<class Holder> +bool SequentialTableReader<Holder>::IsOpen() const { + return (impl_ != NULL); // Because we delete the object whenever + // that object is not open. Thus, the IsOpen functions of the + // Impl objects are not really needed. +} + +template<class Holder> +std::string SequentialTableReader<Holder>::Key() { + CheckImpl(); + return impl_->Key(); // this call may throw if called wrongly in other ways, + // e.g. eof. +} + + +template<class Holder> +void SequentialTableReader<Holder>::FreeCurrent() { + CheckImpl(); + impl_->FreeCurrent(); +} + + +template<class Holder> +const typename SequentialTableReader<Holder>::T & +SequentialTableReader<Holder>::Value() { + CheckImpl(); + return impl_->Value(); // This may throw (if LoadCurrent() returned false you are safe.). +} + + +template<class Holder> +void SequentialTableReader<Holder>::Next() { + CheckImpl(); + impl_->Next(); +} + +template<class Holder> +bool SequentialTableReader<Holder>::Done() { + CheckImpl(); + return impl_->Done(); +} + + +template<class Holder> +SequentialTableReader<Holder>::~SequentialTableReader() { + if (impl_) delete impl_; + // Destructor of impl_ may throw. +} + + + +template<class Holder> class TableWriterImplBase { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &wspecifier) = 0; + + // Write returns true on success, false on failure, but + // some errors may not be detected until we call Close(). + // It throws (via KALDI_ERR) if called wrongly. We could + // have just thrown on all errors, since this is what + // TableWriter does; it was designed this way because originally + // TableWriter::Write returned an exit status. + virtual bool Write(const std::string &key, const T &value) = 0; + + // Flush will flush any archive; it does not return error status, + // any errors will be reported on the next Write or Close. + virtual void Flush() = 0; + + virtual bool Close() = 0; + + virtual bool IsOpen() const = 0; + + // May throw on write error if Close was not called. + virtual ~TableWriterImplBase() { } + + TableWriterImplBase() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(TableWriterImplBase); +}; + + +// The implementation of TableWriter we use when writing directly +// to an archive with no associated scp. +template<class Holder> +class TableWriterArchiveImpl: public TableWriterImplBase<Holder> { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &wspecifier) { + switch (state_) { + case kUninitialized: + break; + case kWriteError: + KALDI_ERR << "TableWriter: opening stream, already open with write error."; + case kOpen: default: + if (!Close()) // throw because this error may not have been previously + // detected by the user. + KALDI_ERR << "TableWriter: opening stream, error closing previously open stream."; + } + wspecifier_ = wspecifier; + WspecifierType ws = ClassifyWspecifier(wspecifier, + &archive_wxfilename_, + NULL, + &opts_); + KALDI_ASSERT(ws == kArchiveWspecifier); // or wrongly called. + + if (output_.Open(archive_wxfilename_, opts_.binary, false)) { // false means no binary header. + state_ = kOpen; + return true; + } else { + // stream will not be open. User will report this error + // (we return bool), so don't bother printing anything. + state_ = kUninitialized; + return false; + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kUninitialized: return false; + case kOpen: case kWriteError: return true; + default: KALDI_ERR << "IsOpen() called on TableWriter in invalid state."; + } + return false; + } + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual bool Write(const std::string &key, const T &value) { + switch (state_) { + case kOpen: break; + case kWriteError: + // user should have known from the last + // call to Write that there was a problem. + KALDI_WARN << "TableWriter: attempting to write to invalid stream."; + return false; + case kUninitialized: default: + KALDI_ERR << "TableWriter: Write called on invalid stream"; + + } + // state is now kOpen or kWriteError. + if (!IsToken(key)) // e.g. empty string or has spaces... + KALDI_ERR << "TableWriter: using invalid key " << key; + output_.Stream() << key << ' '; + if (!Holder::Write(output_.Stream(), opts_.binary, value)) { + KALDI_WARN << "TableWriter: write failure to " + << PrintableWxfilename(archive_wxfilename_); + state_ = kWriteError; + return false; + } + if (state_ == kWriteError) return false; // Even if this Write seems to have + // succeeded, we fail because a previous Write failed and the archive may be + // corrupted and unreadable. + + if (opts_.flush) + Flush(); + return true; + } + + // Flush will flush any archive; it does not return error status, + // any errors will be reported on the next Write or Close. + virtual void Flush() { + switch (state_) { + case kWriteError: case kOpen: + output_.Stream().flush(); // Don't check error status. + return; + default: + KALDI_WARN << "TableWriter: Flush called on not-open writer."; + } + } + + virtual bool Close() { + if (!this->IsOpen() || !output_.IsOpen()) + KALDI_ERR << "TableWriter: Close called on a stream that was not open." << this->IsOpen() << ", " << output_.IsOpen(); + bool close_success = output_.Close(); + if (!close_success) { + KALDI_WARN << "TableWriter: error closing stream: wspecifier is " + << wspecifier_; + state_ = kUninitialized; + return false; + } + if (state_ == kWriteError) { + KALDI_WARN << "TableWriter: closing writer in error state: wspecifier is " + << wspecifier_; + state_ = kUninitialized; + return false; + } + state_ = kUninitialized; + return true; + } + + TableWriterArchiveImpl(): state_(kUninitialized) {} + + // May throw on write error if Close was not called. + virtual ~TableWriterArchiveImpl() { + if (!IsOpen()) return; + else if (!Close()) + KALDI_ERR << "At TableWriter destructor: Write failed or stream close " + << "failed: wspecifier is "<< wspecifier_; + } + + private: + Output output_; + WspecifierOptions opts_; + std::string wspecifier_; + std::string archive_wxfilename_; + enum { // is stream open? + kUninitialized, // no + kOpen, // yes + kWriteError, // yes + } state_; +}; + + + + +// The implementation of TableWriter we use when writing to +// individual files (more generally, wxfilenames) specified +// in an scp file that we read. + +// Note: the code for this class is similar to RandomAccessTableReaderScriptImpl; +// try to keep them in sync. + +template<class Holder> +class TableWriterScriptImpl: public TableWriterImplBase<Holder> { + public: + typedef typename Holder::T T; + + TableWriterScriptImpl(): last_found_(0), state_(kUninitialized) {} + + virtual bool Open(const std::string &wspecifier) { + switch (state_) { + case kReadScript: + KALDI_ERR << " Opening already open TableWriter: call Close first."; + case kUninitialized: case kNotReadScript: + break; + } + wspecifier_ = wspecifier; + WspecifierType ws = ClassifyWspecifier(wspecifier, + NULL, + &script_rxfilename_, + &opts_); + KALDI_ASSERT(ws == kScriptWspecifier); // or wrongly called. + KALDI_ASSERT(script_.empty()); // no way it could be nonempty at this point. + + if (! ReadScriptFile(script_rxfilename_, + true, // print any warnings + &script_)) { // error reading script file or invalid format + state_ = kNotReadScript; + return false; // no need to print further warnings. user gets the error. + } + std::sort(script_.begin(), script_.end()); + for (size_t i = 0; i+1 < script_.size(); i++) { + if (script_[i].first.compare(script_[i+1].first) >= 0) { + // script[i] not < script[i+1] in lexical order... + KALDI_WARN << "Script file " << PrintableRxfilename(script_rxfilename_) + << " contains duplicate key " << script_[i].first; + state_ = kNotReadScript; + return false; + } + } + state_ = kReadScript; + return true; + } + + virtual bool IsOpen() const { return (state_ == kReadScript); } + + virtual bool Close() { + if (!IsOpen()) + KALDI_ERR << "Close() called on TableWriter that was not open."; + state_ = kUninitialized; + last_found_ = 0; + script_.clear(); + return true; + } + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual bool Write(const std::string &key, const T &value) { + if (!IsOpen()) + KALDI_ERR << "TableWriter: Write called on invalid stream"; + + if (!IsToken(key)) // e.g. empty string or has spaces... + KALDI_ERR << "TableWriter: using invalid key " << key; + + std::string wxfilename; + if (!LookupFilename(key, &wxfilename)) { + if (opts_.permissive) { + return true; // In permissive mode, it's as if we're writing to /dev/null + // for missing keys. + } else { + KALDI_WARN << "TableWriter: script file " + << PrintableRxfilename(script_rxfilename_) + << " has no entry for key "<<key; + return false; + } + } + Output output; + if (!output.Open(wxfilename, opts_.binary, false)) { + // Open in the text/binary mode (on Windows) given by member var. "binary" + // (obtained from wspecifier), but do not put the binary-mode header (it + // will be written, if needed, by the Holder::Write function.) + KALDI_WARN << "TableWriter: failed to open stream: " + << PrintableWxfilename(wxfilename); + return false; + } + if (!Holder::Write(output.Stream(), opts_.binary, value) + || !output.Close()) { + KALDI_WARN << "TableWriter: failed to write data to " + << PrintableWxfilename(wxfilename); + return false; + } + return true; + } + + // Flush does nothing in this implementation, there is nothing to flush. + virtual void Flush() { } + + + virtual ~TableWriterScriptImpl() { + // Nothing to do in destructor. + } + + private: + // Note: this function is almost the same as in RandomAccessTableReaderScriptImpl. + bool LookupFilename(const std::string &key, std::string *wxfilename) { + // First, an optimization: if we're going consecutively, this will + // make the lookup very fast. + last_found_++; + if (last_found_ < script_.size() && script_[last_found_].first == key) { + *wxfilename = script_[last_found_].second; + return true; + } + std::pair<std::string, std::string> pr(key, ""); // Important that "" + // compares less than or equal to any string, so lower_bound points to the + // element that has the same key. + typedef typename std::vector<std::pair<std::string, std::string> >::const_iterator + IterType; + IterType iter = std::lower_bound(script_.begin(), script_.end(), pr); + if (iter != script_.end() && iter->first == key) { + last_found_ = iter - script_.begin(); + *wxfilename = iter->second; + return true; + } else { + return false; + } + } + + + WspecifierOptions opts_; + std::string wspecifier_; + std::string script_rxfilename_; + + // the script_ variable contains pairs of (key, filename), sorted using + // std::sort. This can be used with binary_search to look up filenames for + // writing. If this becomes inefficient we can use std::unordered_map (but I + // suspect this wouldn't be significantly faster & would use more memory). + // If memory becomes a problem here, the user should probably be passing + // only the relevant part of the scp file rather than expecting us to get too + // clever in the code. + std::vector<std::pair<std::string, std::string> > script_; + size_t last_found_; // This is for an optimization used in LookupFilename. + + enum { + kUninitialized, + kReadScript, + kNotReadScript, // read of script failed. + } state_; +}; + + +// The implementation of TableWriter we use when writing directly +// to an archive plus an associated scp. +template<class Holder> +class TableWriterBothImpl: public TableWriterImplBase<Holder> { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &wspecifier) { + switch (state_) { + case kUninitialized: + break; + case kWriteError: + KALDI_ERR << "TableWriter: opening stream, already open with write error."; + case kOpen: default: + if (!Close()) // throw because this error may not have been previously detected by user. + KALDI_ERR << "TableWriter: opening stream, error closing previously open stream."; + } + wspecifier_ = wspecifier; + WspecifierType ws = ClassifyWspecifier(wspecifier, + &archive_wxfilename_, + &script_wxfilename_, + &opts_); + KALDI_ASSERT(ws == kBothWspecifier); // or wrongly called. + if (ClassifyWxfilename(archive_wxfilename_) != kFileOutput) + KALDI_WARN << "When writing to both archive and script, the script file " + "will generally not be interpreted correctly unless the archive is " + "an actual file: wspecifier = " << wspecifier; + + if (!archive_output_.Open(archive_wxfilename_, opts_.binary, false)) { // false means no binary header. + state_ = kUninitialized; + return false; + } + if (!script_output_.Open(script_wxfilename_, false, false)) { // first false means text mode: + // script files always text-mode. second false means don't write header (doesn't matter + // for text mode). + archive_output_.Close(); // Don't care about status: error anyway. + state_ = kUninitialized; + return false; + } + state_ = kOpen; + return true; + } + + virtual bool IsOpen() const { + switch (state_) { + case kUninitialized: return false; + case kOpen: case kWriteError: return true; + default: KALDI_ERR << "IsOpen() called on TableWriter in invalid state."; + } + return false; + } + + void MakeFilename(typename std::ostream::pos_type streampos, std::string *output) const { + std::ostringstream ss; + ss << ':' << streampos; + KALDI_ASSERT(ss.str() != ":-1"); + *output = archive_wxfilename_ + ss.str(); + + // e.g. /some/file:12302. + // Note that we warned if archive_wxfilename_ is not an actual filename; + // the philosophy is we give the user rope and if they want to hang + // themselves, with it, fine. + } + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual bool Write(const std::string &key, const T &value) { + switch (state_) { + case kOpen: break; + case kWriteError: + // user should have known from the last + // call to Write that there was a problem. Warn about it. + KALDI_WARN << "TableWriter: writing to non-open TableWriter object."; + return false; + case kUninitialized: default: + KALDI_ERR << "TableWriter: Write called on invalid stream"; + } + // state is now kOpen or kWriteError. + if (!IsToken(key)) // e.g. empty string or has spaces... + KALDI_ERR << "TableWriter: using invalid key " << key; + std::ostream &archive_os = archive_output_.Stream(); + archive_os << key << ' '; + typename std::ostream::pos_type archive_os_pos = archive_os.tellp(); + // position at start of Write() to archive. We will record this in the script file. + std::string offset_rxfilename; // rxfilename with offset into the archive, + // e.g. some_archive_name.ark:431541423 + MakeFilename(archive_os_pos, &offset_rxfilename); + + // Write to the script file first. + // The idea is that we want to get all the information possible into the + // script file, to make it easier to unwind errors later. + std::ostream &script_os = script_output_.Stream(); + script_output_.Stream() << key << ' ' << offset_rxfilename << '\n'; + + if (!Holder::Write(archive_output_.Stream(), opts_.binary, value)) { + KALDI_WARN << "TableWriter: write failure to" + << PrintableWxfilename(archive_wxfilename_); + state_ = kWriteError; + return false; + } + + if (script_os.fail()) { + KALDI_WARN << "TableWriter: write failure to script file detected: " + << PrintableWxfilename(script_wxfilename_); + state_ = kWriteError; + return false; + } + + if (archive_os.fail()) { + KALDI_WARN << "TableWriter: write failure to archive file detected: " + << PrintableWxfilename(archive_wxfilename_); + state_ = kWriteError; + return false; + } + + if (state_ == kWriteError) return false; // Even if this Write seems to have + // succeeded, we fail because a previous Write failed and the archive may be + // corrupted and unreadable. + + if (opts_.flush) + Flush(); + return true; + } + + // Flush will flush any archive; it does not return error status, + // any errors will be reported on the next Write or Close. + virtual void Flush() { + switch (state_) { + case kWriteError: case kOpen: + archive_output_.Stream().flush(); // Don't check error status. + script_output_.Stream().flush(); // Don't check error status. + return; + default: + KALDI_WARN << "TableWriter: Flush called on not-open writer."; + } + } + + virtual bool Close() { + if (!this->IsOpen()) + KALDI_ERR << "TableWriter: Close called on a stream that was not open."; + bool close_success = true; + if (archive_output_.IsOpen()) + if (!archive_output_.Close()) close_success = false; + if (script_output_.IsOpen()) + if (!script_output_.Close()) close_success = false; + bool ans = close_success && (state_ != kWriteError); + state_ = kUninitialized; + return ans; + } + + TableWriterBothImpl(): state_(kUninitialized) {} + + // May throw on write error if Close() was not called. + // User can get the error status by calling Close(). + virtual ~TableWriterBothImpl() { + if (!IsOpen()) return; + else if (!Close()) + KALDI_ERR << "At TableWriter destructor: Write failed or stream close failed: " + << wspecifier_; + } + + private: + Output archive_output_; + Output script_output_; + WspecifierOptions opts_; + std::string archive_wxfilename_; + std::string script_wxfilename_; + std::string wspecifier_; + enum { // is stream open? + kUninitialized, // no + kOpen, // yes + kWriteError, // yes + } state_; +}; + + +template<class Holder> +TableWriter<Holder>::TableWriter(const std::string &wspecifier): impl_(NULL) { + if (wspecifier != "" && !Open(wspecifier)) { + KALDI_ERR << "TableWriter: failed to write to " + << wspecifier; + } +} + +template<class Holder> +bool TableWriter<Holder>::IsOpen() const { + return (impl_ != NULL); +} + + +template<class Holder> +bool TableWriter<Holder>::Open(const std::string &wspecifier) { + + if (IsOpen()) { + if (!Close()) // call Close() yourself to suppress this exception. + KALDI_ERR << "TableWriter::Open, failed to close previously open writer."; + } + KALDI_ASSERT(impl_ == NULL); + WspecifierType wtype = ClassifyWspecifier(wspecifier, NULL, NULL, NULL); + switch (wtype) { + case kBothWspecifier: + impl_ = new TableWriterBothImpl<Holder>(); + break; + case kArchiveWspecifier: + impl_ = new TableWriterArchiveImpl<Holder>(); + break; + case kScriptWspecifier: + impl_ = new TableWriterScriptImpl<Holder>(); + break; + case kNoWspecifier: default: + KALDI_WARN << "ClassifyWspecifier: invalid wspecifier " << wspecifier; + return false; + } + if (impl_->Open(wspecifier)) return true; + else { // The class will have printed a more specific warning. + delete impl_; + impl_ = NULL; + return false; + } +} + +template<class Holder> +void TableWriter<Holder>::Write(const std::string &key, + const T &value) const { + CheckImpl(); + if (!impl_->Write(key, value)) + KALDI_ERR << "Error in TableWriter::Write"; + // More specific warning will have + // been printed in the Write function. +} + +template<class Holder> +void TableWriter<Holder>::Flush() { + CheckImpl(); + impl_->Flush(); +} + +template<class Holder> +bool TableWriter<Holder>::Close() { + CheckImpl(); + bool ans = impl_->Close(); + delete impl_; // We don't keep around non-open impl_ objects [c.f. definition of IsOpen()] + impl_ = NULL; + return ans; +} + +template<class Holder> +TableWriter<Holder>::~TableWriter() { + if (IsOpen() && !Close()) { + KALDI_ERR << "Error closing TableWriter [in destructor]."; + } +} + + +// Types of RandomAccessTableReader: +// In principle, we would like to have four types of RandomAccessTableReader: +// the 4 combinations [scp, archive], [seekable, not-seekable], +// where if something is seekable we only store a file offset. However, +// it seems sufficient for now to only implement two of these, in both +// cases assuming it's not seekable so we never store file offsets and always +// store either the scp line or the data in the archive. The reasons are: +// (1) +// For scp files, storing the actual entry is not that much more expensive +// than storing the file offsets (since the entries are just filenames), and +// avoids a lot of fseek operations that might be expensive. +// (2) +// For archive files, there is no real reason, if you have the archive file +// on disk somewhere, why you wouldn't access it via its associated scp. +// [i.e. write it as ark, scp]. The main reason to read archives directly +// is if they are part of a pipe, and in this case it's not seekable, so +// we implement only this case. +// +// Note that we will rarely in practice have to keep in memory everything in +// the archive, as long as things are only read once from the archive (the +// "o, " or "once" option) and as long as we keep our keys in sorted order; to take +// advantage of this we need the "s, " (sorted) option, so we would read archives +// as e.g. "s, o, ark:-" (this is the rspecifier we would use if it was the +// standard input and these conditions held). + +template<class Holder> class RandomAccessTableReaderImplBase { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &rspecifier) = 0; + + virtual bool HasKey(const std::string &key) = 0; + + virtual const T &Value(const std::string &key) = 0; + + virtual bool Close() = 0; + + virtual ~RandomAccessTableReaderImplBase() {} +}; + + +// Implementation of RandomAccessTableReader for a script file; for simplicity we +// just read it in all in one go, as it's unlikely someone would generate this +// from a pipe. In principle we could read it on-demand as for the archives, but +// this would probably be overkill. + +// Note: the code for this this class is similar to TableWriterScriptImpl: +// try to keep them in sync. +template<class Holder> +class RandomAccessTableReaderScriptImpl: + public RandomAccessTableReaderImplBase<Holder> { + + public: + typedef typename Holder::T T; + + RandomAccessTableReaderScriptImpl(): last_found_(0), state_(kUninitialized) {} + + virtual bool Open(const std::string &rspecifier) { + switch (state_) { + case kNotHaveObject: case kHaveObject: case kGaveObject: + KALDI_ERR << " Opening already open RandomAccessTableReader: call Close first."; + case kUninitialized: case kNotReadScript: + break; + } + rspecifier_ = rspecifier; + RspecifierType rs = ClassifyRspecifier(rspecifier, + &script_rxfilename_, + &opts_); + KALDI_ASSERT(rs == kScriptRspecifier); // or wrongly called. + KALDI_ASSERT(script_.empty()); // no way it could be nonempty at this point. + + if (! ReadScriptFile(script_rxfilename_, + true, // print any warnings + &script_)) { // error reading script file or invalid format + state_ = kNotReadScript; + return false; // no need to print further warnings. user gets the error. + } + + rspecifier_ = rspecifier; + // If opts_.sorted, the user has asserted that the keys are already sorted. + // Although we could easily sort them, we want to let the user know of this + // mistake. This same mistake could have serious effects if used with an + // archive rather than a script. + if (!opts_.sorted) + std::sort(script_.begin(), script_.end()); + for (size_t i = 0; i+1 < script_.size(); i++) { + if (script_[i].first.compare(script_[i+1].first) >= 0) { + // script[i] not < script[i+1] in lexical order... + bool same = (script_[i].first == script_[i+1].first); + KALDI_WARN << "Script file " << PrintableRxfilename(script_rxfilename_) + << (same ? " contains duplicate key: " : + " is not sorted (remove s, option or add ns, option): key is ") + << script_[i].first; + state_ = kNotReadScript; + return false; + } + } + state_ = kNotHaveObject; + return true; + } + + virtual bool IsOpen() const { + return (state_ == kNotHaveObject || state_ == kHaveObject || + state_ == kGaveObject); + } + + virtual bool Close() { + if (!IsOpen()) + KALDI_ERR << "Close() called on RandomAccessTableReader that was not open."; + holder_.Clear(); + state_ = kUninitialized; + last_found_ = 0; + script_.clear(); + current_key_ = ""; + // This one cannot fail because any errors of a "global" + // nature would have been detected when we did Open(). + // With archives it's different. + return true; + } + + virtual bool HasKey(const std::string &key) { + bool preload = opts_.permissive; + // In permissive mode, we have to check that we can read + // the scp entry before we assert that the key is there. + return HasKeyInternal(key, preload); + } + + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual const T& Value(const std::string &key) { + + if (!IsOpen()) + KALDI_ERR << "Value() called on non-open object."; + + if (!((state_ == kHaveObject || state_ == kGaveObject) + && key == current_key_)) { // Not already stored... + bool has_key = HasKeyInternal(key, true); // preload. + if (!has_key) + KALDI_ERR << "Could not get item for key " << key + << ", rspecifier is " << rspecifier_ << "[to ignore this, " + << "add the p, (permissive) option to the rspecifier."; + KALDI_ASSERT(state_ == kHaveObject && key == current_key_); + } + + if (state_ == kHaveObject) { + state_ = kGaveObject; + if (opts_.once) MakeTombstone(key); // make sure that future lookups fail. + return holder_.Value(); + } else { // state_ == kGaveObject + if (opts_.once) + KALDI_ERR << "Value called twice for the same key and ,o (once) option " + << "is used: rspecifier is " << rspecifier_; + return holder_.Value(); + } + } + + virtual ~RandomAccessTableReaderScriptImpl() { + if (state_ == kHaveObject || state_ == kGaveObject) + holder_.Clear(); + } + + private: + // HasKeyInternal when called with preload == false just tells us whether the + // key is in the scp. With preload == true, which happens when the ,p + // (permissive) option is given in the rspecifier, it will also check that we + // can preload the object from disk (loading from the rxfilename in the scp), + // and only return true if we can. This function is called both from HasKey + // and from Value(). + virtual bool HasKeyInternal(const std::string &key, bool preload) { + switch (state_) { + case kUninitialized: case kNotReadScript: + KALDI_ERR << "HasKey called on RandomAccessTableReader object that is not open."; + case kHaveObject: case kGaveObject: + if (key == current_key_) + return true; + break; + default: break; + } + KALDI_ASSERT(IsToken(key)); + size_t key_pos = 0; // set to zero to suppress warning + bool ans = LookupKey(key, &key_pos); + if (!ans) return false; + else { + // First do a check regarding the "once" option. + if (opts_.once && script_[key_pos].second == "") { // A "tombstone"; user is asking about + // already-read key. + KALDI_ERR << "HasKey called on key whose value was already read, and " + " you specified the \"once\" option (o, ): try removing o, or adding no, :" + " rspecifier is " << rspecifier_; + } + if (!preload) + return true; // we have the key. + else { // preload specified, so we have to pre-load the object before returning true. + if (!input_.Open(script_[key_pos].second)) { + KALDI_WARN << "Error opening stream " + << PrintableRxfilename(script_[key_pos].second); + return false; + } else { + // Make sure holder empty. + if (state_ == kHaveObject || state_ == kGaveObject) + holder_.Clear(); + if (holder_.Read(input_.Stream())) { + state_ = kHaveObject; + current_key_ = key; + return true; + } else { + KALDI_WARN << "Error reading object from " + "stream " << PrintableRxfilename(script_[key_pos].second); + state_ = kNotHaveObject; + return false; + } + } + } + } + } + void MakeTombstone(const std::string &key) { + size_t offset; + if (!LookupKey(key, &offset)) + KALDI_ERR << "RandomAccessTableReader object in inconsistent state."; + else + script_[offset].second = ""; + } + bool LookupKey(const std::string &key, size_t *script_offset) { + // First, an optimization: if we're going consecutively, this will + // make the lookup very fast. Since we may call HasKey and then + // Value(), which both may look up the key, we test if either the + // current or next position are correct. + if (last_found_ < script_.size() && script_[last_found_].first == key) { + *script_offset = last_found_; + return true; + } + last_found_++; + if (last_found_ < script_.size() && script_[last_found_].first == key) { + *script_offset = last_found_; + return true; + } + std::pair<std::string, std::string> pr(key, ""); // Important that "" + // compares less than or equal to any string, so lower_bound points to the + // element that has the same key. + typedef typename std::vector<std::pair<std::string, std::string> >::const_iterator + IterType; + IterType iter = std::lower_bound(script_.begin(), script_.end(), pr); + if (iter != script_.end() && iter->first == key) { + last_found_ = *script_offset = iter - script_.begin(); + return true; + } else { + return false; + } + } + + + Input input_; // Use the same input_ object for reading each file, in case + // the scp specifies offsets in an archive (so we can keep the same file open). + RspecifierOptions opts_; + std::string rspecifier_; // rspecifier used to open it; used in debug messages + std::string script_rxfilename_; // filename of script. + + std::string current_key_; // Key of object in holder_ + Holder holder_; + + // the script_ variable contains pairs of (key, filename), sorted using + // std::sort. This can be used with binary_search to look up filenames for + // writing. If this becomes inefficient we can use std::unordered_map (but I + // suspect this wouldn't be significantly faster & would use more memory). + // If memory becomes a problem here, the user should probably be passing + // only the relevant part of the scp file rather than expecting us to get too + // clever in the code. + std::vector<std::pair<std::string, std::string> > script_; + size_t last_found_; // This is for an optimization used in FindFilename. + + enum { // [Do we have [Does holder_ + // script_ set up?] contain object?] + kUninitialized, // no no + kNotReadScript, // no no + kNotHaveObject, // yes no + kHaveObject, // yes yes + kGaveObject, // yes yes + // [kGaveObject is as kHaveObject but we note that the + // user has already read it; this is for checking that + // if "once" is specified, the user actually only reads + // it once. + } state_; + +}; + + + + +// This is the base-class (with some implemented functions) for the +// implementations of RandomAccessTableReader when it's an archive. This +// base-class handles opening the files, storing the state of the reading +// process, and loading objects. This is the only case in which we have +// an intermediate class in the hierarchy between the virtual ImplBase +// class and the actual Impl classes. +// The child classes vary in the assumptions regarding sorting, etc. + +template<class Holder> class RandomAccessTableReaderArchiveImplBase: + public RandomAccessTableReaderImplBase<Holder> { + public: + typedef typename Holder::T T; + + RandomAccessTableReaderArchiveImplBase(): holder_(NULL), state_(kUninitialized) { } + + virtual bool Open(const std::string &rspecifier) { + if (state_ != kUninitialized) { + if (! this->Close()) // call Close() yourself to suppress this exception. + KALDI_ERR << "TableReader::Open, error closing previous input."; + } + rspecifier_ = rspecifier; + RspecifierType rs = ClassifyRspecifier(rspecifier, &archive_rxfilename_, + &opts_); + KALDI_ASSERT(rs == kArchiveRspecifier); + + // NULL means don't expect binary-mode header + bool ans; + if (Holder::IsReadInBinary()) + ans = input_.Open(archive_rxfilename_, NULL); + else + ans = input_.OpenTextMode(archive_rxfilename_); + if (!ans) { // header. + KALDI_WARN << "TableReader: failed to open stream " + << PrintableRxfilename(archive_rxfilename_); + state_ = kUninitialized; // Failure on Open + return false; // User should print the error message. + } else { + state_ = kNoObject; + } + return true; + } + + // ReadNextObject() requires that the state be kNoObject, + // and it will try read the next object. If it succeeds, + // it sets the state to kHaveObject, and + // cur_key_ and holder_ have the key and value. If it fails, + // it sets the state to kError or kEof. + void ReadNextObject() { + if (state_ != kNoObject) + KALDI_ERR << "TableReader: ReadNextObject() called from wrong state."; // Code error + // somewhere in this class or a child class. + std::istream &is = input_.Stream(); + is.clear(); // Clear any fail bits that may have been set... just in case + // this happened in the Read function. + is >> cur_key_; // This eats up any leading whitespace and gets the string. + if (is.eof()) { + state_ = kEof; + return; + } + if (is.fail()) { // This shouldn't really happen, barring file-system errors. + KALDI_WARN << "Error reading archive: rspecifier is " << rspecifier_; + state_ = kError; + return; + } + int c; + if ((c = is.peek()) != ' ' && c != '\t' && c != '\n') { // We expect a space ' ' after the key. + // We also allow tab, just so we can read archives generated by scripts that may + // not be fully aware of how this format works. + KALDI_WARN << "Invalid archive file format: expected space after key " <<cur_key_ + <<", got character " + << CharToString(static_cast<char>(is.peek())) << ", reading archive " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + if (c != '\n') is.get(); // Consume the space or tab. + holder_ = new Holder; + if (holder_->Read(is)) { + state_ = kHaveObject; + return; + } else { + KALDI_WARN << "Object read failed, reading archive " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + delete holder_; + holder_ = NULL; + return; + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kEof: case kError: case kHaveObject: case kNoObject: return true; + case kUninitialized: return false; + default: KALDI_ERR << "IsOpen() called on invalid object."; + return false; + } + } + + // Called by the child-class virutal Close() functions; does the + // shared parts of the cleanup. + bool CloseInternal() { + if (! this->IsOpen()) + KALDI_ERR << "Close() called on TableReader twice or otherwise wrongly."; + if (input_.IsOpen()) + input_.Close(); + if (state_ == kHaveObject) { + KALDI_ASSERT(holder_ != NULL); + delete holder_; + holder_ = NULL; + } else KALDI_ASSERT(holder_ == NULL); + bool ans = (state_ != kError); + state_ = kUninitialized; + if (!ans && opts_.permissive) { + KALDI_WARN << "Error state detected closing reader. " + << "Ignoring it because you specified permissive mode."; + return true; + } + return ans; + } + + ~RandomAccessTableReaderArchiveImplBase() { + // The child class has the responsibility to call CloseInternal(). + KALDI_ASSERT(state_ == kUninitialized && holder_ == NULL); + } + private: + Input input_; // Input object for the archive + protected: + // The variables below are accessed by child classes. + + std::string cur_key_; // current key (if state == kHaveObject). + Holder *holder_; // Holds the object we just read (if state == kHaveObject) + + std::string rspecifier_; + std::string archive_rxfilename_; + RspecifierOptions opts_; + + enum { // [The state of the reading process] [does holder_ [is input_ + // have object] open] + kUninitialized, // Uninitialized or closed no no + kNoObject, // Do not have object in holder_ no yes + kHaveObject, // Have object in holder_ yes yes + kEof, // End of file no yes + kError, // Some kind of error-state in the reading. no yes + } state_; + +}; + + +// RandomAccessTableReaderDSortedArchiveImpl (DSorted for "doubly sorted") is the +// implementation for random-access reading of archives when both the archive, +// and the calling code, are in sorted order (i.e. we ask for the keys in sorted +// order). This is when the s and cs options are both given. It only ever has +// to keep one object in memory. It inherits from +// RandomAccessTableReaderArchiveImplBase which implements the common parts of +// RandomAccessTableReader that are used when it's an archive we're reading from. + +template<class Holder> class RandomAccessTableReaderDSortedArchiveImpl: + public RandomAccessTableReaderArchiveImplBase<Holder> { + using RandomAccessTableReaderArchiveImplBase<Holder>::kUninitialized; + using RandomAccessTableReaderArchiveImplBase<Holder>::kHaveObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kNoObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kEof; + using RandomAccessTableReaderArchiveImplBase<Holder>::kError; + using RandomAccessTableReaderArchiveImplBase<Holder>::state_; + using RandomAccessTableReaderArchiveImplBase<Holder>::opts_; + using RandomAccessTableReaderArchiveImplBase<Holder>::cur_key_; + using RandomAccessTableReaderArchiveImplBase<Holder>::holder_; + using RandomAccessTableReaderArchiveImplBase<Holder>::rspecifier_; + using RandomAccessTableReaderArchiveImplBase<Holder>::archive_rxfilename_; + using RandomAccessTableReaderArchiveImplBase<Holder>::ReadNextObject; + public: + typedef typename Holder::T T; + + RandomAccessTableReaderDSortedArchiveImpl() { } + + virtual bool Close() { + // We don't have anything additional to clean up, so just + // call generic base-class one. + return this->CloseInternal(); + } + + virtual bool HasKey(const std::string &key) { + return FindKeyInternal(key); + } + virtual const T & Value(const std::string &key) { + if (FindKeyInternal(key)) { + KALDI_ASSERT(this->state_ == kHaveObject && key == this->cur_key_ + && holder_ != NULL); + return this->holder_->Value(); + } else { + KALDI_ERR << "Value() called but no such key " << key + << " in archive " << PrintableRxfilename(archive_rxfilename_); + return *(const T*)NULL; // keep compiler happy. + } + } + + virtual ~RandomAccessTableReaderDSortedArchiveImpl() { + if (this->IsOpen()) + if (!Close()) // more specific warning will already have been printed. + // we are in some kind of error state & user did not find out by + // calling Close(). + KALDI_ERR << "Error closing RandomAccessTableReader: rspecifier is " + << rspecifier_; + } + private: + // FindKeyInternal tries to find the key by calling "ReadNextObject()" + // as many times as necessary till we get to it. It is called from + // both FindKey and Value(). + bool FindKeyInternal(const std::string &key) { + // First check that the user is calling us right: should be + // in sorted order. If not, error. + if (!last_requested_key_.empty()) { + if (key.compare(last_requested_key_) < 0) { // key < last_requested_key_ + KALDI_ERR << "You provided the \"cs\" option " + << "but are not calling with keys in sorted order: " + << key << " < " << last_requested_key_ << ": rspecifier is " + << rspecifier_; + } + } + // last_requested_key_ is just for debugging of order of calling. + last_requested_key_ = key; + + if (state_ == kNoObject) + ReadNextObject(); // This can only happen + // once, the first time someone calls HasKey() or Value(). We don't + // do it in the initializer to stop the program hanging too soon, + // if reading from a pipe. + + if (state_ == kEof || state_ == kError) return false; + + if (state_ == kUninitialized) + KALDI_ERR << "Trying to access a RandomAccessTableReader object that is not open."; + + std::string last_key_; // To check that + // the archive we're reading is in sorted order. + while (1) { + KALDI_ASSERT(state_ == kHaveObject); + int compare = key.compare(cur_key_); + if (compare == 0) { // key == key_ + return true; // we got it.. + } else if (compare < 0) { // key < cur_key_, so we already read past the + // place where we want to be. This implies that we will never find it + // [due to the sorting etc., this means it just isn't in the archive]. + return false; + } else { // compare > 0, key > cur_key_. We need to read further ahead. + last_key_ = cur_key_; + // read next object.. we have to set state to kNoObject first. + KALDI_ASSERT(holder_ != NULL); + delete holder_; + holder_ = NULL; + state_ = kNoObject; + ReadNextObject(); + if (state_ != kHaveObject) + return false; // eof or read error. + if (cur_key_.compare(last_key_) <= 0) { + KALDI_ERR << "You provided the \"s\" option " + << " (sorted order), but keys are out of order or duplicated: " + << last_key_ << " is followed by " << cur_key_ + << ": rspecifier is " << rspecifier_; + } + } + } + } + + /// Last string provided to HasKey() or Value(); + std::string last_requested_key_; + + +}; + +// RandomAccessTableReaderSortedArchiveImpl is for random-access reading of +// archives when the user specified the sorted (s) option but not the +// called-sorted (cs) options. +template<class Holder> class RandomAccessTableReaderSortedArchiveImpl: + public RandomAccessTableReaderArchiveImplBase<Holder> { + using RandomAccessTableReaderArchiveImplBase<Holder>::kUninitialized; + using RandomAccessTableReaderArchiveImplBase<Holder>::kHaveObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kNoObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kEof; + using RandomAccessTableReaderArchiveImplBase<Holder>::kError; + using RandomAccessTableReaderArchiveImplBase<Holder>::state_; + using RandomAccessTableReaderArchiveImplBase<Holder>::opts_; + using RandomAccessTableReaderArchiveImplBase<Holder>::cur_key_; + using RandomAccessTableReaderArchiveImplBase<Holder>::holder_; + using RandomAccessTableReaderArchiveImplBase<Holder>::rspecifier_; + using RandomAccessTableReaderArchiveImplBase<Holder>::archive_rxfilename_; + using RandomAccessTableReaderArchiveImplBase<Holder>::ReadNextObject; + + public: + typedef typename Holder::T T; + + RandomAccessTableReaderSortedArchiveImpl(): + last_found_index_(static_cast<size_t>(-1)), + pending_delete_(static_cast<size_t>(-1)) { } + + virtual bool Close() { + for (size_t i = 0; i < seen_pairs_.size(); i++) + if (seen_pairs_[i].second) + delete seen_pairs_[i].second; + seen_pairs_.clear(); + + pending_delete_ = static_cast<size_t>(-1); + last_found_index_ = static_cast<size_t>(-1); + + return this->CloseInternal(); + } + virtual bool HasKey(const std::string &key) { + HandlePendingDelete(); + size_t index; + bool ans = FindKeyInternal(key, &index); + if (ans && opts_.once && seen_pairs_[index].second == NULL) { + // Just do a check RE the once option. "&&opts_.once" is for + // efficiency since this can only happen in that case. + KALDI_ERR << "Error: HasKey called after Value() already called for " + << " that key, and once (o) option specified: rspecifier is " + << rspecifier_; + } + return ans; + } + virtual const T & Value(const std::string &key) { + HandlePendingDelete(); + size_t index; + if (FindKeyInternal(key, &index)) { + if (seen_pairs_[index].second == NULL) { // can happen if opts.once_ + KALDI_ERR << "Error: Value() called more than once for key " + << key << " and once (o) option specified: rspecifier is " + << rspecifier_; + } + if (opts_.once) + pending_delete_ = index; // mark this index to be deleted on next call. + return seen_pairs_[index].second->Value(); + } else { + KALDI_ERR << "Value() called but no such key " << key + << " in archive " << PrintableRxfilename(archive_rxfilename_); + return *(const T*)NULL; // keep compiler happy. + } + } + virtual ~RandomAccessTableReaderSortedArchiveImpl() { + if (this->IsOpen()) + if (!Close()) // more specific warning will already have been printed. + // we are in some kind of error state & user did not find out by + // calling Close(). + KALDI_ERR << "Error closing RandomAccessTableReader: rspecifier is " + << rspecifier_; + } + private: + void HandlePendingDelete() { + const size_t npos = static_cast<size_t>(-1); + if (pending_delete_ != npos) { + KALDI_ASSERT(pending_delete_ < seen_pairs_.size()); + KALDI_ASSERT(seen_pairs_[pending_delete_].second != NULL); + delete seen_pairs_[pending_delete_].second; + seen_pairs_[pending_delete_].second = NULL; + pending_delete_ = npos; + } + } + + // FindKeyInternal tries to find the key in the array "seen_pairs_". + // If it is not already there, it reads ahead as far as necessary + // to determine whether we have the key or not. On success it returns + // true and puts the index into the array seen_pairs_, into "index"; + // on failure it returns false. + // It will leave the state as either kNoObject, kEof or kError. + // FindKeyInternal does not do any checking about whether you are asking + // about a key that has been already given (with the "once" option). + // That is the user's responsibility. + + bool FindKeyInternal(const std::string &key, size_t *index) { + // First, an optimization in case the previous call was for the + // same key, and we found it. + if (last_found_index_ < seen_pairs_.size() + && seen_pairs_[last_found_index_].first == key) { + *index = last_found_index_; + return true; + } + + if (state_ == kUninitialized) + KALDI_ERR << "Trying to access a RandomAccessTableReader object that is not open."; + + // Step one is to see whether we have to read ahead for the object.. + // Note, the possible states right now are kNoObject, kEof or kError. + // We are never in the state kHaveObject except just after calling + // ReadNextObject(). + bool looped = false; + while (state_ == kNoObject && + (seen_pairs_.empty() || key.compare(seen_pairs_.back().first) > 0)) { + looped = true; + // Read this as: + // while ( the stream is potentially good for reading && + // ([got no keys] || key > most_recent_key) ) { ... + // Try to read a new object. + // Note that the keys in seen_pairs_ are ordered from least to greatest. + ReadNextObject(); + if (state_ == kHaveObject) { // Successfully read object. + if (!seen_pairs_.empty() && // This is just a check. + cur_key_.compare(seen_pairs_.back().first) <= 0) { + // read the expression above as: !( cur_key_ > previous_key). + // it means we are not in sorted order [the user specified that we + // are, or we would not be using this implementation]. + KALDI_ERR << "You provided the sorted (s) option but keys in archive " + << PrintableRxfilename(archive_rxfilename_) << " are not " + << "in sorted order: " << seen_pairs_.back().first + << " is followed by " << cur_key_; + } + KALDI_ASSERT(holder_ != NULL); + seen_pairs_.push_back(std::make_pair(cur_key_, holder_)); + holder_ = NULL; + state_ = kNoObject; + } + } + if (looped) { // We only need to check the last element of the seen_pairs_ array, + // since we would not have read more after getting "key". + if (!seen_pairs_.empty() && seen_pairs_.back().first == key) { + last_found_index_ = *index = seen_pairs_.size() - 1; + return true; + } else return false; + } + // Now we have do an actual binary search in the seen_pairs_ array. + std::pair<std::string, Holder*> pr(key, static_cast<Holder*>(NULL)); + typename std::vector<std::pair<std::string, Holder*> >::iterator + iter = std::lower_bound(seen_pairs_.begin(), seen_pairs_.end(), + pr, PairCompare()); + if (iter != seen_pairs_.end() && + key == iter->first) { + last_found_index_ = *index = (iter - seen_pairs_.begin()); + return true; + } else return false; + } + + // These are the pairs of (key, object) we have read. We keep all the keys we + // have read but the actual objects (if they are stored with pointers inside + // the Holder object) may be deallocated if once == true, and the Holder + // pointer set to NULL. + std::vector<std::pair<std::string, Holder*> > seen_pairs_; + size_t last_found_index_; // An optimization s.t. if FindKeyInternal called twice with + // same key (as it often will), it doesn't have to do the key search twice. + size_t pending_delete_; // If opts_.once == true, this is the index of + // element of seen_pairs_ that is pending deletion. + struct PairCompare { + // PairCompare is the Less-than operator for the pairs of(key, Holder). + // compares the keys. + inline bool operator() (const std::pair<std::string, Holder*> &pr1, + const std::pair<std::string, Holder*> &pr2) { + return (pr1.first.compare(pr2.first) < 0); + } + }; +}; + + + +// RandomAccessTableReaderUnsortedArchiveImpl is for random-access reading of +// archives when the user does not specify the sorted (s) option (in this case +// the called-sorted, or "cs" option, is ignored). This is the least efficient +// of the random access archive readers, in general, but it can be as efficient +// as the others, in speed, memory and latency, if the "once" option is specified +// and it happens that the keys of the archive are the same as the keys the code +// is called with (to HasKey() and Value()), and in the same order. However, if +// you ask it for a key that's not present it will have to read the archive till +// the end and store it all in memory. + +template<class Holder> class RandomAccessTableReaderUnsortedArchiveImpl: + public RandomAccessTableReaderArchiveImplBase<Holder> { + using RandomAccessTableReaderArchiveImplBase<Holder>::kUninitialized; + using RandomAccessTableReaderArchiveImplBase<Holder>::kHaveObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kNoObject; + using RandomAccessTableReaderArchiveImplBase<Holder>::kEof; + using RandomAccessTableReaderArchiveImplBase<Holder>::kError; + using RandomAccessTableReaderArchiveImplBase<Holder>::state_; + using RandomAccessTableReaderArchiveImplBase<Holder>::opts_; + using RandomAccessTableReaderArchiveImplBase<Holder>::cur_key_; + using RandomAccessTableReaderArchiveImplBase<Holder>::holder_; + using RandomAccessTableReaderArchiveImplBase<Holder>::rspecifier_; + using RandomAccessTableReaderArchiveImplBase<Holder>::archive_rxfilename_; + using RandomAccessTableReaderArchiveImplBase<Holder>::ReadNextObject; + + typedef typename Holder::T T; + + public: + RandomAccessTableReaderUnsortedArchiveImpl(): to_delete_iter_(map_.end()), + to_delete_iter_valid_(false) + { + map_.max_load_factor(0.5); // make it quite empty -> quite efficient. + // default seems to be 1. + } + + virtual bool Close() { + for (typename MapType::iterator iter = map_.begin(); + iter != map_.end(); + ++iter) { + if (iter->second) + delete iter->second; + } + map_.clear(); + first_deleted_string_ = ""; + to_delete_iter_valid_ = false; + return this->CloseInternal(); + } + + virtual bool HasKey(const std::string &key) { + HandlePendingDelete(); + return FindKeyInternal(key, NULL); + } + virtual const T & Value(const std::string &key) { + HandlePendingDelete(); + const T *ans_ptr = NULL; + if (FindKeyInternal(key, &ans_ptr)) + return *ans_ptr; + else + KALDI_ERR << "Value() called but no such key " << key + << " in archive " << PrintableRxfilename(archive_rxfilename_); + return *(const T*)NULL; // keep compiler happy. + } + virtual ~RandomAccessTableReaderUnsortedArchiveImpl() { + if (this->IsOpen()) + if (!Close()) // more specific warning will already have been printed. + // we are in some kind of error state & user did not find out by + // calling Close(). + KALDI_ERR << "Error closing RandomAccessTableReader: rspecifier is " + << rspecifier_; + } + private: + void HandlePendingDelete() { + if (to_delete_iter_valid_) { + to_delete_iter_valid_ = false; + delete to_delete_iter_->second; // Delete Holder object. + if (first_deleted_string_.length() == 0) + first_deleted_string_ = to_delete_iter_->first; + map_.erase(to_delete_iter_); // delete that element. + } + } + + // FindKeyInternal tries to find the key in the map "map_" + // If it is not already there, it reads ahead either until it finds the + // key, or until end of file. If called with value_ptr == NULL, + // it assumes it's called from HasKey() and just returns true or false + // and doesn't otherwise have side effects. If called with value_ptr != + // NULL, it assumes it's called from Value(). Thus, it will crash + // if it cannot find the key. If it can find it it puts its address in + // *value_ptr, and if opts_once == true it will mark that element of the + // map to be deleted. + + bool FindKeyInternal(const std::string &key, const T **value_ptr = NULL) { + typename MapType::iterator iter = map_.find(key); + if (iter != map_.end()) { // Found in the map... + if (value_ptr == NULL) { // called from HasKey + return true; // this is all we have to do. + } else { + *value_ptr = &(iter->second->Value()); + if (opts_.once) { // value won't be needed again, so mark + // for deletion. + to_delete_iter_ = iter; // pending delete. + KALDI_ASSERT(!to_delete_iter_valid_); + to_delete_iter_valid_ = true; + } + return true; + } + } + while (state_ == kNoObject) { + ReadNextObject(); + if (state_ == kHaveObject) { // Successfully read object. + state_ = kNoObject; // we are about to transfer ownership + // of the object in holder_ to map_. + // Insert it into map_. + std::pair<typename MapType::iterator, bool> pr = + map_.insert(typename MapType::value_type(cur_key_, holder_)); + + if (!pr.second) { // Was not inserted-- previous element w/ same key + delete holder_; // map was not changed, no ownership transferred. + holder_ = NULL; + KALDI_ERR << "Error in RandomAccessTableReader: duplicate key " + << cur_key_ << " in archive " << archive_rxfilename_; + } + holder_ = NULL; // ownership transferred to map_. + if (cur_key_ == key) { // the one we wanted.. + if (value_ptr == NULL) { // called from HasKey + return true; + } else { // called from Value() + *value_ptr = &(pr.first->second->Value()); // this gives us the + // Value() from the Holder in the map. + if (opts_.once) { // mark for deletion, as won't be needed again. + to_delete_iter_ = pr.first; + KALDI_ASSERT(!to_delete_iter_valid_); + to_delete_iter_valid_ = true; + } + return true; + } + } + } + } + if (opts_.once && key == first_deleted_string_) { + KALDI_ERR << "You specified the once (o) option but " + << "you are calling using key " << key + << " more than once: rspecifier is " << rspecifier_; + } + return false; // We read the entire archive (or got to error state) and didn't + // find it. + } + + typedef unordered_map<std::string, Holder*, StringHasher> MapType; + MapType map_; + + typename MapType::iterator to_delete_iter_; + bool to_delete_iter_valid_; + + std::string first_deleted_string_; // keep the first string we deleted + // from map_ (if opts_.once == true). It's for an inexact spot-check that the + // "once" option isn't being used incorrectly. + +}; + + + + + +template<class Holder> +RandomAccessTableReader<Holder>::RandomAccessTableReader(const std::string &rspecifier): + impl_(NULL) { + if (rspecifier != "" && !Open(rspecifier)) + KALDI_ERR << "Error opening RandomAccessTableReader object " + " (rspecifier is: " << rspecifier << ")"; +} + +template<class Holder> +bool RandomAccessTableReader<Holder>::Open(const std::string &rspecifier) { + if (IsOpen()) + KALDI_ERR << "Already open."; + RspecifierOptions opts; + RspecifierType rs = ClassifyRspecifier(rspecifier, NULL, &opts); + switch (rs) { + case kScriptRspecifier: + impl_ = new RandomAccessTableReaderScriptImpl<Holder>(); + break; + case kArchiveRspecifier: + if (opts.sorted) { + if (opts.called_sorted) // "doubly" sorted case. + impl_ = new RandomAccessTableReaderDSortedArchiveImpl<Holder>(); + else + impl_ = new RandomAccessTableReaderSortedArchiveImpl<Holder>(); + } else impl_ = new RandomAccessTableReaderUnsortedArchiveImpl<Holder>(); + break; + case kNoRspecifier: default: + KALDI_WARN << "Invalid rspecifier: " + << rspecifier; + return false; + } + if (impl_->Open(rspecifier)) + return true; + else { + // Warning will already have been printed. + delete impl_; + impl_ = NULL; + return false; + } +} + +template<class Holder> +bool RandomAccessTableReader<Holder>::HasKey(const std::string &key) { + CheckImpl(); + if (!IsToken(key)) + KALDI_ERR << "Invalid key \"" << key << '"'; + return impl_->HasKey(key); +} + + +template<class Holder> +const typename RandomAccessTableReader<Holder>::T& +RandomAccessTableReader<Holder>::Value(const std::string &key) { + CheckImpl(); + return impl_->Value(key); +} + +template<class Holder> +bool RandomAccessTableReader<Holder>::Close() { + CheckImpl(); + bool ans =impl_->Close(); + delete impl_; + impl_ = NULL; + return ans; +} + +template<class Holder> +RandomAccessTableReader<Holder>::~RandomAccessTableReader() { + if (IsOpen() && !Close()) // call Close() yourself to stop this being thrown. + KALDI_ERR << "failure detected in destructor."; +} + +template<class Holder> +void SequentialTableReader<Holder>::CheckImpl() const { + if (!impl_) { + KALDI_ERR << "Trying to use empty SequentialTableReader (perhaps you " + << "passed the empty string as an argument to a program?)"; + } +} + +template<class Holder> +void RandomAccessTableReader<Holder>::CheckImpl() const { + if (!impl_) { + KALDI_ERR << "Trying to use empty RandomAccessTableReader (perhaps you " + << "passed the empty string as an argument to a program?)"; + } +} + +template<class Holder> +void TableWriter<Holder>::CheckImpl() const { + if (!impl_) { + KALDI_ERR << "Trying to use empty TableWriter (perhaps you " + << "passed the empty string as an argument to a program?)"; + } +} + +template<class Holder> +RandomAccessTableReaderMapped<Holder>::RandomAccessTableReaderMapped( + const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename): + reader_(table_rxfilename), token_reader_(table_rxfilename.empty() ? "" : + utt2spk_rxfilename), + utt2spk_rxfilename_(utt2spk_rxfilename) { } + +template<class Holder> +bool RandomAccessTableReaderMapped<Holder>::Open( + const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename) { + if (reader_.IsOpen()) reader_.Close(); + if (token_reader_.IsOpen()) token_reader_.Close(); + KALDI_ASSERT(!table_rxfilename.empty()); + if (!reader_.Open(table_rxfilename)) return false; // will have printed + // warning internally, probably. + if (!utt2spk_rxfilename.empty()) { + if (!token_reader_.Open(utt2spk_rxfilename)) { + reader_.Close(); + return false; + } + } + return true; +} + + +template<class Holder> +bool RandomAccessTableReaderMapped<Holder>::HasKey(const std::string &utt) { + // We don't check IsOpen, we let the call go through to the member variable + // (reader_), which will crash with a more informative error message than + // we can give here, as we don't any longer know the rxfilename. + if (token_reader_.IsOpen()) { // We need to map the key from utt to spk. + if (!token_reader_.HasKey(utt)) + KALDI_ERR << "Attempting to read key " << utt << ", which is not present " + << "in utt2spk map or similar map being read from " + << PrintableRxfilename(utt2spk_rxfilename_); + const std::string &spk = token_reader_.Value(utt); + return reader_.HasKey(spk); + } else { + return reader_.HasKey(utt); + } +} + +template<class Holder> +const typename Holder::T& RandomAccessTableReaderMapped<Holder>::Value( + const std::string &utt) { + if (token_reader_.IsOpen()) { // We need to map the key from utt to spk. + if (!token_reader_.HasKey(utt)) + KALDI_ERR << "Attempting to read key " << utt << ", which is not present " + << "in utt2spk map or similar map being read from " + << PrintableRxfilename(utt2spk_rxfilename_); + const std::string &spk = token_reader_.Value(utt); + return reader_.Value(spk); + } else { + return reader_.Value(utt); + } +} + + + +/// @} + +} // end namespace kaldi + + + +#endif diff --git a/kaldi_io/src/kaldi/util/kaldi-table.h b/kaldi_io/src/kaldi/util/kaldi-table.h new file mode 100644 index 0000000..6f6cb98 --- /dev/null +++ b/kaldi_io/src/kaldi/util/kaldi-table.h @@ -0,0 +1,459 @@ +// util/kaldi-table.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_KALDI_TABLE_H_ +#define KALDI_UTIL_KALDI_TABLE_H_ + +#include <string> +#include <vector> +#include <utility> + +#include "base/kaldi-common.h" +#include "util/kaldi-holder.h" + +namespace kaldi { + +// Forward declarations +template<class Holder> class RandomAccessTableReaderImplBase; +template<class Holder> class SequentialTableReaderImplBase; +template<class Holder> class TableWriterImplBase; + +/// \addtogroup table_group +/// @{ + +// This header defines the Table classes (RandomAccessTableReader, +// SequentialTableReader and TableWriter) and explains what the Holder classes, +// which the Table class requires as a template argument, are like. It also +// explains the "rspecifier" and "wspecifier" concepts (these are strings that +// explain how to read/write objects via archives or scp files. A table is +// conceptually a collection of objects of a particular type T indexed by keys +// of type std::string (these Keys additionally have an order within each table). +// The Table classes are templated on a type (call it Holder) such that Holder::T +// is a typedef equal to T. + +// see kaldi-holder.h for detail on the Holder classes. + +typedef std::vector<std::string> KeyList; + +// Documentation for "wspecifier" +// "wspecifier" describes how we write a set of objects indexed by keys. +// The basic, unadorned wspecifiers are as follows: +// +// ark:wxfilename +// scp:rxfilename +// ark,scp:filename,wxfilename +// ark,scp:filename,wxfilename +// +// +// We also allow the following modifiers: +// t means text mode. +// b means binary mode. +// f means flush the stream after writing each entry. +// (nf means don't flush, and isn't very useful as the default is to flush). +// p means permissive mode, when writing to an "scp" file only: will ignore +// missing scp entries, i.e. won't write anything for those files but will +// return success status). +// +// So the following are valid wspecifiers: +// ark,b,f:foo +// "ark,b,b:| gzip -c > foo" +// "ark,scp,t,nf:foo.ark,|gzip -c > foo.scp.gz" +// ark,b:- +// +// The meanings of rxfilename and wxfilename are as described in +// kaldi-stream.h (they are filenames but include pipes, stdin/stdout +// and so on; filename is a regular filename. +// + +// The ark:wxfilename type of wspecifier instructs the class to +// write directly to an archive. For small objects (e.g. lists of ints), +// the text archive format will generally be human readable with one line +// per entry in the archive. +// +// The type "scp:xfilename" refers to an scp file which should +// already exist on disk, and tells us where to write the data for +// each key (usually an actual file); each line of the scp file +// would be: +// key xfilename +// +// The type ark,scp:filename,wxfilename means +// we write both an archive and an scp file that specifies offsets into the +// archive, with lines like: +// key filename:12407 +// where the number is the byte offset into the file. +// In this case we restrict the archive-filename to be an actual filename, +// as we can't see a situtation where an extended filename would make sense +// for this (we can't fseek() in pipes). + +enum WspecifierType { + kNoWspecifier, + kArchiveWspecifier, + kScriptWspecifier, + kBothWspecifier +}; + +struct WspecifierOptions { + bool binary; + bool flush; + bool permissive; // will ignore absent scp entries. + WspecifierOptions(): binary(true), flush(false), permissive(false) { } +}; + +// ClassifyWspecifier returns the type of the wspecifier string, +// and (if pointers are non-NULL) outputs the extra information +// about the options, and the script and archive +// filenames. +WspecifierType ClassifyWspecifier(const std::string &wspecifier, + std::string *archive_wxfilename, + std::string *script_wxfilename, + WspecifierOptions *opts); + +// ReadScriptFile reads an .scp file in its entirety, and appends it +// (in order as it was in the scp file) in script_out_, which contains +// pairs of (key, xfilename). The .scp +// file format is: on each line, key xfilename +// where xfilename means rxfilename or wxfilename, and may contain internal spaces +// (we trim away any leading or trailing space). The key is space-free. +// ReadScriptFile returns true if the format was valid (empty files +// are valid). +// If 'print_warnings', it will print out warning messages that explain what kind +// of error there was. +bool ReadScriptFile(const std::string &rxfilename, + bool print_warnings, + std::vector<std::pair<std::string, std::string> > *script_out); + +// This version of ReadScriptFile works from an istream. +bool ReadScriptFile(std::istream &is, + bool print_warnings, + std::vector<std::pair<std::string, std::string> > *script_out); + +// Writes, for each entry in script, the first element, then ' ', then the second +// element then '\n'. Checks that the keys (first elements of pairs) are valid +// tokens (nonempty, no whitespace), and the values (second elements of pairs) +// are newline-free and contain no leading or trailing space. Returns true on +// success. +bool WriteScriptFile(const std::string &wxfilename, + const std::vector<std::pair<std::string, std::string> > &script); + +// This version writes to an ostream. +bool WriteScriptFile(std::ostream &os, + const std::vector<std::pair<std::string, std::string> > &script); + +// Documentation for "rspecifier" +// "rspecifier" describes how we read a set of objects indexed by keys. +// The possibilities are: +// +// ark:rxfilename +// scp:rxfilename +// +// We also allow various modifiers: +// o means the program will only ask for each key once, which enables +// the reader to discard already-asked-for values. +// s means the keys are sorted on input (means we don't have to read till +// eof if someone asked for a key that wasn't there). +// cs means that it is called in sorted order (we are generally asserting this +// based on knowledge of how the program works). +// p means "permissive", and causes it to skip over keys whose corresponding +// scp-file entries cannot be read. [and to ignore errors in archives and +// script files, and just consider the "good" entries]. +// We allow the negation of the options above, as in no, ns, np, +// but these aren't currently very useful (just equivalent to omitting the +// corresponding option). +// [any of the above options can be prefixed by n to negate them, e.g. no, ns, +// ncs, np; but these aren't currently useful as you could just omit the option]. +// +// b is ignored [for scripting convenience] +// t is ignored [for scripting convenience] +// +// +// So for instance the following would be a valid rspecifier: +// +// "o, s, p, ark:gunzip -c foo.gz|" + +struct RspecifierOptions { + // These options only make a difference for the RandomAccessTableReader class. + bool once; // we assert that the program will only ask for each key once. + bool sorted; // we assert that the keys are sorted. + bool called_sorted; // we assert that the (HasKey(), Value() functions will + // also be called in sorted order. [this implies "once" but not vice versa]. + bool permissive; // If "permissive", when reading from scp files it treats + // scp files that can't be read as if the corresponding key were not there. + // For archive files it will suppress errors getting thrown if the archive + + // is corrupted and can't be read to the end. + + RspecifierOptions(): once(false), sorted(false), + called_sorted(false), permissive(false) { } +}; + +enum RspecifierType { + kNoRspecifier, + kArchiveRspecifier, + kScriptRspecifier +}; + +RspecifierType ClassifyRspecifier(const std::string &rspecifier, std::string *rxfilename, + RspecifierOptions *opts); + +// Class Table<Holder> is useful when you want the entire set of +// objects in memory. NOT IMPLEMENTED YET. +// It is the least scalable way of accessing data in Tables. +// The *TableReader and TableWriter classes are more scalable. + + +/// Allows random access to a collection +/// of objects in an archive or script file; see \ref io_sec_tables. +template<class Holder> +class RandomAccessTableReader { + public: + typedef typename Holder::T T; + + RandomAccessTableReader(): impl_(NULL) { } + + // This constructor equivalent to default constructor + "open", but + // throws on error. + RandomAccessTableReader(const std::string &rspecifier); + + // Opens the table. + bool Open(const std::string &rspecifier); + + // Returns true if table is open. + bool IsOpen() const { return (impl_ != NULL); } + + // Close() will close the table [throws if it was not open], + // and returns true on success (false if we were reading an + // archive and we discovered an error in the archive). + bool Close(); + + // Says if it has this key. + // If you are using the "permissive" (p) read option, + // it will return false for keys whose corresponding entry + // in the scp file cannot be read. + + bool HasKey(const std::string &key); + + // Value() may throw if you are reading an scp file, you + // do not have the "permissive" (p) option, and an entry + // in the scp file cannot be read. Typically you won't + // want to catch this error. + const T &Value(const std::string &key); + + ~RandomAccessTableReader(); + + // Allow copy-constructor only for non-opened readers (needed for inclusion in + // stl vector) + RandomAccessTableReader(const RandomAccessTableReader<Holder> &other): + impl_(NULL) { KALDI_ASSERT(other.impl_ == NULL); } + private: + // Disallow assignment. + RandomAccessTableReader &operator=(const RandomAccessTableReader<Holder>&); + void CheckImpl() const; // Checks that impl_ is non-NULL; prints an error + // message and dies (with KALDI_ERR) if NULL. + RandomAccessTableReaderImplBase<Holder> *impl_; +}; + + + +/// A templated class for reading objects sequentially from an archive or script +/// file; see \ref io_sec_tables. +template<class Holder> +class SequentialTableReader { + public: + typedef typename Holder::T T; + + SequentialTableReader(): impl_(NULL) { } + + // This constructor equivalent to default constructor + "open", but + // throws on error. + SequentialTableReader(const std::string &rspecifier); + + // Opens the table. Returns exit status; but does throw if previously + // open stream was in error state. Call Close to stop this [anyway, + // calling Open more than once is not recommended.] + bool Open(const std::string &rspecifier); + + // Returns true if we're done. It will also return true if there's some kind + // of error and we can't read any more; in this case, you can detect the + // error by calling Close and checking the return status; otherwise + // the destructor will throw. + inline bool Done(); + + // Only valid to call Key() if Done() returned false. + inline std::string Key(); + + // FreeCurrent() is provided as an optimization to save memory, for large + // objects. It instructs the class to deallocate the current value. The + // reference Value() will/ be invalidated by this. + + void FreeCurrent(); + + // Return reference to the current value. + // The reference is valid till next call to this object. + // If will throw if you are reading an scp file, did not + // specify the "permissive" (p) option and the file cannot + // be read. [The permissive option makes it behave as if that + // key does not even exist, if the corresponding file cannot be + // read.] You probably wouldn't want to catch this exception; + // the user can just specify the p option in the rspecifier. + const T &Value(); + + // Next goes to the next key. It will not throw; any error will + // result in Done() returning true, and then the destructor will + // throw unless you call Close(). + void Next(); + + // Returns true if table is open for reading (does not imply + // stream is in good state). + bool IsOpen() const; + + // Close() will return false (failure) if Done() became true + // because of an error/ condition rather than because we are + // really done [e.g. because of an error or early termination + // in the archive]. + // If there is an error and you don't call Close(), the destructor + // will fail. + // Close() + bool Close(); + + // The destructor may throw. This is the desired behaviour, as it's the way we + // signal the error to the user (to detect it, call Close(). The issue is that + // otherwise the user has no way to tell whether Done() returned true because + // we reached the end of the archive or script, or because there was an error + // that prevented further reading. + ~SequentialTableReader(); + + // Allow copy-constructor only for non-opened readers (needed for inclusion in + // stl vector) + SequentialTableReader(const SequentialTableReader<Holder> &other): + impl_(NULL) { KALDI_ASSERT(other.impl_ == NULL); } + private: + // Disallow assignment. + SequentialTableReader &operator = (const SequentialTableReader<Holder>&); + void CheckImpl() const; // Checks that impl_ is non-NULL; prints an error + // message and dies (with KALDI_ERR) if NULL. + SequentialTableReaderImplBase<Holder> *impl_; +}; + + +/// A templated class for writing objects to an +/// archive or script file; see \ref io_sec_tables. +template<class Holder> +class TableWriter { + public: + typedef typename Holder::T T; + + TableWriter(): impl_(NULL) { } + + // This constructor equivalent to default constructor + // + "open", but throws on error. See docs for + // wspecifier above. + TableWriter(const std::string &wspecifier); + + // Opens the table. See docs for wspecifier above. + // If it returns true, it is open. + bool Open(const std::string &wspecifier); + + // Returns true if open for writing. + bool IsOpen() const; + + // Write the object. Throws std::runtime_error on error (via the + // KALDI_ERR macro) + inline void Write(const std::string &key, const T &value) const; + + + // Flush will flush any archive; it does not return error status + // or throw, any errors will be reported on the next Write or Close. + // Useful if we may be writing to a command in a pipe and want + // to ensure good CPU utilization. + void Flush(); + + // Close() is not necessary to call, as the destructor + // closes it; it's mainly useful if you want to handle + // error states because the destructor will throw on + // error if you do not call Close(). + bool Close(); + + ~TableWriter(); + + // Allow copy-constructor only for non-opened writers (needed for inclusion in + // stl vector) + TableWriter(const TableWriter &other): impl_(NULL) { + KALDI_ASSERT(other.impl_ == NULL); + } + private: + TableWriter &operator = (const TableWriter&); // Disallow assignment. + void CheckImpl() const; // Checks that impl_ is non-NULL; prints an error + // message and dies (with KALDI_ERR) if NULL. + TableWriterImplBase<Holder> *impl_; +}; + + +/// This class is for when you are reading something in random access, but +/// it may actually be stored per-speaker (or something similar) but the +/// keys you're using are per utterance. So you also provide an "rxfilename" +/// for a file containing lines like +/// utt1 spk1 +/// utt2 spk1 +/// utt3 spk1 +/// and so on. Note: this is optional; if it is an empty string, we just won't +/// do the mapping. Also, "table_rxfilename" may be the empty string (as for +/// a regular table), in which case the table just won't be opened. +/// We provide only the most frequently used of the functions of RandomAccessTableReader. + +template<class Holder> +class RandomAccessTableReaderMapped { + public: + typedef typename Holder::T T; + /// Note: "utt2spk_rxfilename" will in the normal case be an rxfilename + /// for an utterance to speaker map, but this code is general; it accepts + /// a generic map. + RandomAccessTableReaderMapped(const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename); + + RandomAccessTableReaderMapped() {}; + + /// Note: when calling Open, utt2spk_rxfilename may be empty. + bool Open(const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename); + + bool HasKey(const std::string &key); + const T &Value(const std::string &key); + inline bool IsOpen() const { return reader_.IsOpen(); } + inline bool Close() { return reader_.Close(); } + + + + // The default copy-constructor will do what we want: it will crash + // for already-opened readers, by calling the member-variable copy-constructors. + private: + // Disallow assignment. + RandomAccessTableReaderMapped &operator=(const RandomAccessTableReaderMapped<Holder>&); + RandomAccessTableReader<Holder> reader_; + RandomAccessTableReader<TokenHolder> token_reader_; + std::string utt2spk_rxfilename_; // Used only in diagnostic messages. +}; + + +/// @} end "addtogroup table_group" +} // end namespace kaldi + +#include "kaldi-table-inl.h" + +#endif // KALDI_UTIL_KALDI_TABLE_H_ diff --git a/kaldi_io/src/kaldi/util/parse-options.h b/kaldi_io/src/kaldi/util/parse-options.h new file mode 100644 index 0000000..f563b54 --- /dev/null +++ b/kaldi_io/src/kaldi/util/parse-options.h @@ -0,0 +1,264 @@ +// util/parse-options.h + +// Copyright 2009-2011 Karel Vesely; Microsoft Corporation; +// Saarland University (Author: Arnab Ghoshal); +// Copyright 2012-2013 Frantisek Skala; Arnab Ghoshal + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_PARSE_OPTIONS_H_ +#define KALDI_UTIL_PARSE_OPTIONS_H_ + +#include <map> +#include <string> +#include <vector> + +#include "base/kaldi-common.h" +#include "itf/options-itf.h" + +namespace kaldi { + +/// The class ParseOptions is for parsing command-line options; see +/// \ref parse_options for more documentation. +class ParseOptions : public OptionsItf { + public: + explicit ParseOptions(const char *usage) : + print_args_(true), help_(false), usage_(usage), argc_(0), argv_(NULL), + prefix_(""), other_parser_(NULL) { +#ifndef _MSC_VER // This is just a convenient place to set the stderr to line + setlinebuf(stderr); // buffering mode, since it's called at program start. +#endif // This helps ensure different programs' output is not mixed up. + RegisterStandard("config", &config_, "Configuration file to read (this " + "option may be repeated)"); + RegisterStandard("print-args", &print_args_, + "Print the command line arguments (to stderr)"); + RegisterStandard("help", &help_, "Print out usage message"); + RegisterStandard("verbose", &g_kaldi_verbose_level, + "Verbose level (higher->more logging)"); + } + + /** + This is a constructor for the special case where some options are + registered with a prefix to avoid conflicts. The object thus created will + only be used temporarily to register an options class with the original + options parser (which is passed as the *other pointer) using the given + prefix. It should not be used for any other purpose, and the prefix must + not be the empty string. It seems to be the least bad way of implementing + options with prefixes at this point. + Example of usage is: + ParseOptions po; // original ParseOptions object + ParseOptions po_mfcc("mfcc", &po); // object with prefix. + MfccOptions mfcc_opts; + mfcc_opts.Register(&po_mfcc); + The options will now get registered as, e.g., --mfcc.frame-shift=10.0 + instead of just --frame-shift=10.0 + */ + ParseOptions(const std::string &prefix, OptionsItf *other); + + ~ParseOptions() {} + + // Methods from the interface + void Register(const std::string &name, + bool *ptr, const std::string &doc); + void Register(const std::string &name, + int32 *ptr, const std::string &doc); + void Register(const std::string &name, + uint32 *ptr, const std::string &doc); + void Register(const std::string &name, + float *ptr, const std::string &doc); + void Register(const std::string &name, + double *ptr, const std::string &doc); + void Register(const std::string &name, + std::string *ptr, const std::string &doc); + + /// If called after registering an option and before calling + /// Read(), disables that option from being used. Will crash + /// at runtime if that option had not been registered. + void DisableOption(const std::string &name); + + /// This one is used for registering standard parameters of all the programs + template<typename T> + void RegisterStandard(const std::string &name, + T *ptr, const std::string &doc); + + /** + Parses the command line options and fills the ParseOptions-registered + variables. This must be called after all the variables were registered!!! + + Initially the variables have implicit values, + then the config file values are set-up, + finally the command line vaues given. + Returns the first position in argv that was not used. + [typically not useful: use NumParams() and GetParam(). ] + */ + int Read(int argc, const char *const *argv); + + /// Prints the usage documentation [provided in the constructor]. + void PrintUsage(bool print_command_line = false); + /// Prints the actual configuration of all the registered variables + void PrintConfig(std::ostream &os); + + /// Reads the options values from a config file. Must be called after + /// registering all options. This is usually used internally after the + /// standard --config option is used, but it may also be called from a + /// program. + void ReadConfigFile(const std::string &filename); + + /// Number of positional parameters (c.f. argc-1). + int NumArgs() const; + + /// Returns one of the positional parameters; 1-based indexing for argc/argv + /// compatibility. Will crash if param is not >=1 and <=NumArgs(). + std::string GetArg(int param) const; + + std::string GetOptArg(int param) const { + return (param <= NumArgs() ? GetArg(param) : ""); + } + + /// The following function will return a possibly quoted and escaped + /// version of "str", according to the current shell. Currently + /// this is just hardwired to bash. It's useful for debug output. + static std::string Escape(const std::string &str); + + private: + /// Template to register various variable types, + /// used for program-specific parameters + template<typename T> + void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc); + + // Following functions do just the datatype-specific part of the job + /// Register boolean variable + void RegisterSpecific(const std::string &name, const std::string &idx, + bool *b, const std::string &doc, bool is_standard); + /// Register int32 variable + void RegisterSpecific(const std::string &name, const std::string &idx, + int32 *i, const std::string &doc, bool is_standard); + /// Register unsinged int32 variable + void RegisterSpecific(const std::string &name, const std::string &idx, + uint32 *u, + const std::string &doc, bool is_standard); + /// Register float variable + void RegisterSpecific(const std::string &name, const std::string &idx, + float *f, const std::string &doc, bool is_standard); + /// Register double variable [useful as we change BaseFloat type]. + void RegisterSpecific(const std::string &name, const std::string &idx, + double *f, const std::string &doc, bool is_standard); + /// Register string variable + void RegisterSpecific(const std::string &name, const std::string &idx, + std::string *s, const std::string &doc, + bool is_standard); + + /// Does the actual job for both kinds of parameters + /// Does the common part of the job for all datatypes, + /// then calls RegisterSpecific + template<typename T> + void RegisterCommon(const std::string &name, + T *ptr, const std::string &doc, bool is_standard); + + /// SplitLongArg parses an argument of the form --a=b, --a=, or --a, + /// and sets "has_equal_sign" to true if an equals-sign was parsed.. + /// this is needed in order to correctly allow --x for a boolean option + /// x, and --y= for a string option y, and to disallow --x= and --y. + void SplitLongArg(std::string in, std::string *key, std::string *value, + bool *has_equal_sign); + + void NormalizeArgName(std::string *str); + + /// Set option with name "key" to "value"; will crash if can't do it. + /// "has_equal_sign" is used to allow --x for a boolean option x, + /// and --y=, for a string option y. + bool SetOption(const std::string &key, const std::string &value, + bool has_equal_sign); + + bool ToBool(std::string str); + int32 ToInt(std::string str); + uint32 ToUInt(std::string str); + float ToFloat(std::string str); + double ToDouble(std::string str); + + // maps for option variables + std::map<std::string, bool*> bool_map_; + std::map<std::string, int32*> int_map_; + std::map<std::string, uint32*> uint_map_; + std::map<std::string, float*> float_map_; + std::map<std::string, double*> double_map_; + std::map<std::string, std::string*> string_map_; + + /** + Structure for options' documentation + */ + struct DocInfo { + DocInfo() {} + DocInfo(const std::string &name, const std::string &usemsg) + : name_(name), use_msg_(usemsg), is_standard_(false) {} + DocInfo(const std::string &name, const std::string &usemsg, + bool is_standard) + : name_(name), use_msg_(usemsg), is_standard_(is_standard) {} + + std::string name_; + std::string use_msg_; + bool is_standard_; + }; + typedef std::map<std::string, DocInfo> DocMapType; + DocMapType doc_map_; ///< map for the documentation + + bool print_args_; ///< variable for the implicit --print-args parameter + bool help_; ///< variable for the implicit --help parameter + std::string config_; ///< variable for the implicit --config parameter + std::vector<std::string> positional_args_; + const char *usage_; + int argc_; + const char *const *argv_; + + /// These members are not normally used. They are only used when the object + /// is constructed with a prefix + std::string prefix_; + OptionsItf *other_parser_; +}; + +/// This template is provided for convenience in reading config classes from +/// files; this is not the standard way to read configuration options, but may +/// occasionally be needed. This function assumes the config has a function +/// "void Register(OptionsItf *po)" which it can call to register the +/// ParseOptions object. +template<class C> void ReadConfigFromFile(const std::string config_filename, + C *c) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << config_filename << "'"; + ParseOptions po(usage_str.str().c_str()); + c->Register(&po); + po.ReadConfigFile(config_filename); +} + +/// This variant of the template ReadConfigFromFile is for if you need to read +/// two config classes from the same file. +template<class C1, class C2> void ReadConfigsFromFile(const std::string config_filename, + C1 *c1, C2 *c2) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << config_filename << "'"; + ParseOptions po(usage_str.str().c_str()); + c1->Register(&po); + c2->Register(&po); + po.ReadConfigFile(config_filename); +} + + + +} // namespace kaldi + +#endif // KALDI_UTIL_PARSE_OPTIONS_H_ diff --git a/kaldi_io/src/kaldi/util/simple-io-funcs.h b/kaldi_io/src/kaldi/util/simple-io-funcs.h new file mode 100644 index 0000000..56573e4 --- /dev/null +++ b/kaldi_io/src/kaldi/util/simple-io-funcs.h @@ -0,0 +1,56 @@ +// util/simple-io-funcs.h + +// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_SIMPLE_IO_FUNCS_H_ +#define KALDI_UTIL_SIMPLE_IO_FUNCS_H_ + +#include "kaldi-io.h" + +// This header contains some utilities for reading some common, simple text formats: +// integers in files, one per line, and integers in files, possibly multiple per line. +// these are not really fully native Kaldi formats; they are mostly for small files that +// might be generated by scripts, and can be read all at one time. +// for longer files of this type, we would probably use the Table code. + +namespace kaldi { + +/// WriteToList attempts to write this list of integers, one per line, +/// to the given file, in text format. +/// returns true if succeeded. +bool WriteIntegerVectorSimple(std::string wxfilename, const std::vector<int32> &v); + +/// ReadFromList attempts to read this list of integers, one per line, +/// from the given file, in text format. +/// returns true if succeeded. +bool ReadIntegerVectorSimple(std::string rxfilename, std::vector<int32> *v); + +// This is a file format like: +// 1 2 +// 3 +// +// 4 5 6 +// etc. +bool WriteIntegerVectorVectorSimple(std::string wxfilename, const std::vector<std::vector<int32> > &v); + +bool ReadIntegerVectorVectorSimple(std::string rxfilename, std::vector<std::vector<int32> > *v); + + +} // end namespace kaldi. + + +#endif diff --git a/kaldi_io/src/kaldi/util/simple-options.h b/kaldi_io/src/kaldi/util/simple-options.h new file mode 100644 index 0000000..58816af --- /dev/null +++ b/kaldi_io/src/kaldi/util/simple-options.h @@ -0,0 +1,112 @@ +// util/simple-options.hh + +// Copyright 2013 Tanel Alumae, Tallinn University of Technology + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_SIMPLE_OPTIONS_H_ +#define KALDI_UTIL_SIMPLE_OPTIONS_H_ + +#include <map> +#include <string> +#include <vector> + +#include "base/kaldi-common.h" +#include "itf/options-itf.h" + +namespace kaldi { + + +/// The class SimpleOptions is an implementation of OptionsItf that allows +/// setting and getting option values programmatically, i.e., via getter +/// and setter methods. It doesn't provide any command line parsing functionality. +/// The class ParseOptions should be used for command-line options. +class SimpleOptions : public OptionsItf { + public: + SimpleOptions() { + } + + virtual ~SimpleOptions() { + } + + // Methods from the interface + void Register(const std::string &name, bool *ptr, const std::string &doc); + void Register(const std::string &name, int32 *ptr, const std::string &doc); + void Register(const std::string &name, uint32 *ptr, const std::string &doc); + void Register(const std::string &name, float *ptr, const std::string &doc); + void Register(const std::string &name, double *ptr, const std::string &doc); + void Register(const std::string &name, std::string *ptr, + const std::string &doc); + + // set option with the specified key, return true if successful + bool SetOption(const std::string &key, const bool &value); + bool SetOption(const std::string &key, const int32 &value); + bool SetOption(const std::string &key, const uint32 &value); + bool SetOption(const std::string &key, const float &value); + bool SetOption(const std::string &key, const double &value); + bool SetOption(const std::string &key, const std::string &value); + bool SetOption(const std::string &key, const char* value); + + // get option with the specified key and put to 'value', + // return true if successful + bool GetOption(const std::string &key, bool *value); + bool GetOption(const std::string &key, int32 *value); + bool GetOption(const std::string &key, uint32 *value); + bool GetOption(const std::string &key, float *value); + bool GetOption(const std::string &key, double *value); + bool GetOption(const std::string &key, std::string *value); + + enum OptionType { + kBool, + kInt32, + kUint32, + kFloat, + kDouble, + kString + }; + + struct OptionInfo { + OptionInfo(const std::string &doc, OptionType type) : + doc(doc), type(type) { + } + std::string doc; + OptionType type; + }; + + std::vector<std::pair<std::string, OptionInfo> > GetOptionInfoList(); + + /* + * Puts the type of the option with name 'key' in the argument 'type'. + * Return true if such option is found, false otherwise. + */ + bool GetOptionType(const std::string &key, OptionType *type); + + private: + + std::vector<std::pair<std::string, OptionInfo> > option_info_list_; + + // maps for option variables + std::map<std::string, bool*> bool_map_; + std::map<std::string, int32*> int_map_; + std::map<std::string, uint32*> uint_map_; + std::map<std::string, float*> float_map_; + std::map<std::string, double*> double_map_; + std::map<std::string, std::string*> string_map_; +}; + +} // namespace kaldi + +#endif // KALDI_UTIL_SIMPLE_OPTIONS_H_ diff --git a/kaldi_io/src/kaldi/util/stl-utils.h b/kaldi_io/src/kaldi/util/stl-utils.h new file mode 100644 index 0000000..12526ff --- /dev/null +++ b/kaldi_io/src/kaldi/util/stl-utils.h @@ -0,0 +1,327 @@ +// util/stl-utils.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_STL_UTILS_H_ +#define KALDI_UTIL_STL_UTILS_H_ + +#include <algorithm> +#include <map> +#include <set> +#include <string> +#include <vector> +#include "base/kaldi-common.h" + +#ifdef _MSC_VER +#include <unordered_map> +#include <unordered_set> +using std::unordered_map; +using std::unordered_set; +#elif __cplusplus > 199711L || defined(__GXX_EXPERIMENTAL_CXX0X__) +#include <unordered_map> +#include <unordered_set> +using std::unordered_map; +using std::unordered_set; +#else +#include <tr1/unordered_map> +#include <tr1/unordered_set> +using std::tr1::unordered_map; +using std::tr1::unordered_set; +#endif + + +namespace kaldi { + +/// Sorts and uniq's (removes duplicates) from a vector. +template<typename T> +inline void SortAndUniq(std::vector<T> *vec) { + std::sort(vec->begin(), vec->end()); + vec->erase(std::unique(vec->begin(), vec->end()), vec->end()); +} + + +/// Returns true if the vector is sorted. +template<typename T> +inline bool IsSorted(const std::vector<T> &vec) { + typename std::vector<T>::const_iterator iter = vec.begin(), end = vec.end(); + if (iter == end) return true; + while (1) { + typename std::vector<T>::const_iterator next_iter = iter; + ++next_iter; + if (next_iter == end) return true; // end of loop and nothing out of order + if (*next_iter < *iter) return false; + iter = next_iter; + } +} + + +/// Returns true if the vector is sorted and contains each element +/// only once. +template<typename T> +inline bool IsSortedAndUniq(const std::vector<T> &vec) { + typename std::vector<T>::const_iterator iter = vec.begin(), end = vec.end(); + if (iter == end) return true; + while (1) { + typename std::vector<T>::const_iterator next_iter = iter; + ++next_iter; + if (next_iter == end) return true; // end of loop and nothing out of order + if (*next_iter <= *iter) return false; + iter = next_iter; + } +} + + +/// Removes duplicate elements from a sorted list. +template<typename T> +inline void Uniq(std::vector<T> *vec) { // must be already sorted. + KALDI_PARANOID_ASSERT(IsSorted(*vec)); + KALDI_ASSERT(vec); + vec->erase(std::unique(vec->begin(), vec->end()), vec->end()); +} + +/// Copies the elements of a set to a vector. +template<class T> +void CopySetToVector(const std::set<T> &s, std::vector<T> *v) { + // adds members of s to v, in sorted order from lowest to highest + // (because the set was in sorted order). + KALDI_ASSERT(v != NULL); + v->resize(s.size()); + typename std::set<T>::const_iterator siter = s.begin(), send = s.end(); + typename std::vector<T>::iterator viter = v->begin(); + for (; siter != send; ++siter, ++viter) { + *viter = *siter; + } +} + +template<class T> +void CopySetToVector(const unordered_set<T> &s, std::vector<T> *v) { + // adds members of s to v, in sorted order from lowest to highest + // (because the set was in sorted order). + KALDI_ASSERT(v != NULL); + v->resize(s.size()); + typename unordered_set<T>::const_iterator siter = s.begin(), send = s.end(); + typename std::vector<T>::iterator viter = v->begin(); + for (; siter != send; ++siter, ++viter) { + *viter = *siter; + } +} + + +/// Copies the (key, value) pairs in a map to a vector of pairs. +template<class A, class B> +void CopyMapToVector(const std::map<A, B> &m, + std::vector<std::pair<A, B> > *v) { + KALDI_ASSERT(v != NULL); + v->resize(m.size()); + typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end(); + typename std::vector<std::pair<A, B> >::iterator viter = v->begin(); + for (; miter != mend; ++miter, ++viter) { + *viter = std::make_pair(miter->first, miter->second); + // do it like this because of const casting. + } +} + +/// Copies the keys in a map to a vector. +template<class A, class B> +void CopyMapKeysToVector(const std::map<A, B> &m, std::vector<A> *v) { + KALDI_ASSERT(v != NULL); + v->resize(m.size()); + typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end(); + typename std::vector<A>::iterator viter = v->begin(); + for (; miter != mend; ++miter, ++viter) { + *viter = miter->first; + } +} + +/// Copies the values in a map to a vector. +template<class A, class B> +void CopyMapValuesToVector(const std::map<A, B> &m, std::vector<B> *v) { + KALDI_ASSERT(v != NULL); + v->resize(m.size()); + typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end(); + typename std::vector<B>::iterator viter = v->begin(); + for (; miter != mend; ++miter, ++viter) { + *viter = miter->second; + } +} + +/// Copies the keys in a map to a set. +template<class A, class B> +void CopyMapKeysToSet(const std::map<A, B> &m, std::set<A> *s) { + KALDI_ASSERT(s != NULL); + s->clear(); + typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end(); + for (; miter != mend; ++miter) { + s->insert(s->end(), miter->first); + } +} + +/// Copies the values in a map to a set. +template<class A, class B> +void CopyMapValuesToSet(const std::map<A, B> &m, std::set<B> *s) { + KALDI_ASSERT(s != NULL); + s->clear(); + typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end(); + for (; miter != mend; ++miter) + s->insert(s->end(), miter->second); +} + + +/// Copies the contents of a vector to a set. +template<class A> +void CopyVectorToSet(const std::vector<A> &v, std::set<A> *s) { + KALDI_ASSERT(s != NULL); + s->clear(); + typename std::vector<A>::const_iterator iter = v.begin(), end = v.end(); + for (; iter != end; ++iter) + s->insert(s->end(), *iter); + // s->end() is a hint in case v was sorted. will work regardless. +} + +/// Deletes any non-NULL pointers in the vector v, and sets +/// the corresponding entries of v to NULL +template<class A> +void DeletePointers(std::vector<A*> *v) { + KALDI_ASSERT(v != NULL); + typename std::vector<A*>::iterator iter = v->begin(), end = v->end(); + for (; iter != end; ++iter) { + if (*iter != NULL) { + delete *iter; + *iter = NULL; // set to NULL for extra safety. + } + } +} + +/// Returns true if the vector of pointers contains NULL pointers. +template<class A> +bool ContainsNullPointers(const std::vector<A*> &v) { + typename std::vector<A*>::const_iterator iter = v.begin(), end = v.end(); + for (; iter != end; ++iter) + if (*iter == static_cast<A*> (NULL)) return true; + return false; +} + +/// Copies the contents a vector of one type to a vector +/// of another type. +template<typename A, typename B> +void CopyVectorToVector(const std::vector<A> &vec_in, std::vector<B> *vec_out) { + KALDI_ASSERT(vec_out != NULL); + vec_out->resize(vec_in.size()); + for (size_t i = 0; i < vec_in.size(); i++) + (*vec_out)[i] = static_cast<B> (vec_in[i]); +} + +/// A hashing function-object for vectors. +template<typename Int> +struct VectorHasher { // hashing function for vector<Int>. + size_t operator()(const std::vector<Int> &x) const { + size_t ans = 0; + typename std::vector<Int>::const_iterator iter = x.begin(), end = x.end(); + for (; iter != end; ++iter) { + ans *= kPrime; + ans += *iter; + } + return ans; + } + VectorHasher() { // Check we're instantiated with an integer type. + KALDI_ASSERT_IS_INTEGER_TYPE(Int); + } + private: + static const int kPrime = 7853; +}; + +/// A hashing function-object for pairs of ints +template<typename Int> +struct PairHasher { // hashing function for pair<int> + size_t operator()(const std::pair<Int,Int> &x) const { + return x.first + x.second * kPrime; + } + PairHasher() { // Check we're instantiated with an integer type. + KALDI_ASSERT_IS_INTEGER_TYPE(Int); + } + private: + static const int kPrime = 7853; +}; + + +/// A hashing function object for strings. +struct StringHasher { // hashing function for std::string + size_t operator()(const std::string &str) const { + size_t ans = 0, len = str.length(); + const char *c = str.c_str(), *end = c + len; + for (; c != end; c++) { + ans *= kPrime; + ans += *c; + } + return ans; + } + private: + static const int kPrime = 7853; +}; + +/// Reverses the contents of a vector. +template<typename T> +inline void ReverseVector(std::vector<T> *vec) { + KALDI_ASSERT(vec != NULL); + size_t sz = vec->size(); + for (size_t i = 0; i < sz/2; i++) + std::swap( (*vec)[i], (*vec)[sz-1-i]); +} + + +/// Comparator object for pairs that compares only the first pair. +template<class A, class B> +struct CompareFirstMemberOfPair { + inline bool operator() (const std::pair<A, B> &p1, + const std::pair<A, B> &p2) { + return p1.first < p2.first; + } +}; + +/// For a vector of pair<I, F> where I is an integer and F a floating-point or +/// integer type, this function sorts a vector of type vector<pair<I, F> > on +/// the I value and then merges elements with equal I values, summing these over +/// the F component and then removing any F component with zero value. This +/// is for where the vector of pairs represents a map from the integer to float +/// component, with an "adding" type of semantics for combining the elements. +template<typename I, typename F> +inline void MergePairVectorSumming(std::vector<std::pair<I, F> > *vec) { + KALDI_ASSERT_IS_INTEGER_TYPE(I); + CompareFirstMemberOfPair<I, F> c; + std::sort(vec->begin(), vec->end(), c); // sort on 1st element. + typename std::vector<std::pair<I, F> >::iterator out = vec->begin(), + in = vec->begin(), end = vec->end(); + while (in < end) { + // We reach this point only at the first element of + // each stretch of identical .first elements. + *out = *in; + ++in; + while (in < end && in->first == out->first) { + out->second += in->second; // this is the merge operation. + ++in; + } + if (out->second != static_cast<F>(0)) // Don't keep zero elements. + out++; + } + vec->erase(out, end); +} + +} // namespace kaldi + +#endif // KALDI_UTIL_STL_UTILS_H_ + diff --git a/kaldi_io/src/kaldi/util/table-types.h b/kaldi_io/src/kaldi/util/table-types.h new file mode 100644 index 0000000..313d1aa --- /dev/null +++ b/kaldi_io/src/kaldi/util/table-types.h @@ -0,0 +1,137 @@ +// util/table-types.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_TABLE_TYPES_H_ +#define KALDI_UTIL_TABLE_TYPES_H_ +#include "base/kaldi-common.h" +#include "util/kaldi-table.h" +#include "util/kaldi-holder.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { + +// This header defines typedefs that are specific instantiations of +// the Table types. + +/// \addtogroup table_types +/// @{ + +typedef TableWriter<KaldiObjectHolder<Matrix<BaseFloat> > > BaseFloatMatrixWriter; +typedef SequentialTableReader<KaldiObjectHolder<Matrix<BaseFloat> > > SequentialBaseFloatMatrixReader; +typedef RandomAccessTableReader<KaldiObjectHolder<Matrix<BaseFloat> > > RandomAccessBaseFloatMatrixReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<Matrix<BaseFloat> > > RandomAccessBaseFloatMatrixReaderMapped; + +typedef TableWriter<KaldiObjectHolder<Matrix<double> > > DoubleMatrixWriter; +typedef SequentialTableReader<KaldiObjectHolder<Matrix<double> > > SequentialDoubleMatrixReader; +typedef RandomAccessTableReader<KaldiObjectHolder<Matrix<double> > > RandomAccessDoubleMatrixReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<Matrix<double> > > RandomAccessDoubleMatrixReaderMapped; + +typedef TableWriter<KaldiObjectHolder<CompressedMatrix> > CompressedMatrixWriter; + +typedef TableWriter<KaldiObjectHolder<Vector<BaseFloat> > > BaseFloatVectorWriter; +typedef SequentialTableReader<KaldiObjectHolder<Vector<BaseFloat> > > SequentialBaseFloatVectorReader; +typedef RandomAccessTableReader<KaldiObjectHolder<Vector<BaseFloat> > > RandomAccessBaseFloatVectorReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<Vector<BaseFloat> > > RandomAccessBaseFloatVectorReaderMapped; + +typedef TableWriter<KaldiObjectHolder<Vector<double> > > DoubleVectorWriter; +typedef SequentialTableReader<KaldiObjectHolder<Vector<double> > > SequentialDoubleVectorReader; +typedef RandomAccessTableReader<KaldiObjectHolder<Vector<double> > > RandomAccessDoubleVectorReader; + +typedef TableWriter<KaldiObjectHolder<CuMatrix<BaseFloat> > > BaseFloatCuMatrixWriter; +typedef SequentialTableReader<KaldiObjectHolder<CuMatrix<BaseFloat> > > SequentialBaseFloatCuMatrixReader; +typedef RandomAccessTableReader<KaldiObjectHolder<CuMatrix<BaseFloat> > > RandomAccessBaseFloatCuMatrixReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<CuMatrix<BaseFloat> > > RandomAccessBaseFloatCuMatrixReaderMapped; + +typedef TableWriter<KaldiObjectHolder<CuMatrix<double> > > DoubleCuMatrixWriter; +typedef SequentialTableReader<KaldiObjectHolder<CuMatrix<double> > > SequentialDoubleCuMatrixReader; +typedef RandomAccessTableReader<KaldiObjectHolder<CuMatrix<double> > > RandomAccessDoubleCuMatrixReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<CuMatrix<double> > > RandomAccessDoubleCuMatrixReaderMapped; + +typedef TableWriter<KaldiObjectHolder<CuVector<BaseFloat> > > BaseFloatCuVectorWriter; +typedef SequentialTableReader<KaldiObjectHolder<CuVector<BaseFloat> > > SequentialBaseFloatCuVectorReader; +typedef RandomAccessTableReader<KaldiObjectHolder<CuVector<BaseFloat> > > RandomAccessBaseFloatCuVectorReader; +typedef RandomAccessTableReaderMapped<KaldiObjectHolder<CuVector<BaseFloat> > > RandomAccessBaseFloatCuVectorReaderMapped; + +typedef TableWriter<KaldiObjectHolder<CuVector<double> > > DoubleCuVectorWriter; +typedef SequentialTableReader<KaldiObjectHolder<CuVector<double> > > SequentialDoubleCuVectorReader; +typedef RandomAccessTableReader<KaldiObjectHolder<CuVector<double> > > RandomAccessDoubleCuVectorReader; + + +typedef TableWriter<BasicHolder<int32> > Int32Writer; +typedef SequentialTableReader<BasicHolder<int32> > SequentialInt32Reader; +typedef RandomAccessTableReader<BasicHolder<int32> > RandomAccessInt32Reader; + +typedef TableWriter<BasicVectorHolder<int32> > Int32VectorWriter; +typedef SequentialTableReader<BasicVectorHolder<int32> > SequentialInt32VectorReader; +typedef RandomAccessTableReader<BasicVectorHolder<int32> > RandomAccessInt32VectorReader; + +typedef TableWriter<BasicVectorVectorHolder<int32> > Int32VectorVectorWriter; +typedef SequentialTableReader<BasicVectorVectorHolder<int32> > SequentialInt32VectorVectorReader; +typedef RandomAccessTableReader<BasicVectorVectorHolder<int32> > RandomAccessInt32VectorVectorReader; + +typedef TableWriter<BasicPairVectorHolder<int32> > Int32PairVectorWriter; +typedef SequentialTableReader<BasicPairVectorHolder<int32> > SequentialInt32PairVectorReader; +typedef RandomAccessTableReader<BasicPairVectorHolder<int32> > RandomAccessInt32PairVectorReader; + +typedef TableWriter<BasicPairVectorHolder<BaseFloat> > BaseFloatPairVectorWriter; +typedef SequentialTableReader<BasicPairVectorHolder<BaseFloat> > SequentialBaseFloatPairVectorReader; +typedef RandomAccessTableReader<BasicPairVectorHolder<BaseFloat> > RandomAccessBaseFloatPairVectorReader; + +typedef TableWriter<BasicHolder<BaseFloat> > BaseFloatWriter; +typedef SequentialTableReader<BasicHolder<BaseFloat> > SequentialBaseFloatReader; +typedef RandomAccessTableReader<BasicHolder<BaseFloat> > RandomAccessBaseFloatReader; +typedef RandomAccessTableReaderMapped<BasicHolder<BaseFloat> > RandomAccessBaseFloatReaderMapped; + +typedef TableWriter<BasicHolder<double> > DoubleWriter; +typedef SequentialTableReader<BasicHolder<double> > SequentialDoubleReader; +typedef RandomAccessTableReader<BasicHolder<double> > RandomAccessDoubleReader; + +typedef TableWriter<BasicHolder<bool> > BoolWriter; +typedef SequentialTableReader<BasicHolder<bool> > SequentialBoolReader; +typedef RandomAccessTableReader<BasicHolder<bool> > RandomAccessBoolReader; + + + +/// TokenWriter is a writer specialized for std::string where the strings +/// are nonempty and whitespace-free. T == std::string +typedef TableWriter<TokenHolder> TokenWriter; +typedef SequentialTableReader<TokenHolder> SequentialTokenReader; +typedef RandomAccessTableReader<TokenHolder> RandomAccessTokenReader; + + +/// TokenVectorWriter is a writer specialized for sequences of +/// std::string where the strings are nonempty and whitespace-free. +/// T == std::vector<std::string> +typedef TableWriter<TokenVectorHolder> TokenVectorWriter; +// Ditto for SequentialTokenVectorReader. +typedef SequentialTableReader<TokenVectorHolder> SequentialTokenVectorReader; +typedef RandomAccessTableReader<TokenVectorHolder> RandomAccessTokenVectorReader; + + +/// @} + +// Note: for FST reader/writer, see ../fstext/fstext-utils.h +// [not done yet]. + +} // end namespace kaldi + + + +#endif diff --git a/kaldi_io/src/kaldi/util/text-utils.h b/kaldi_io/src/kaldi/util/text-utils.h new file mode 100644 index 0000000..1d85c47 --- /dev/null +++ b/kaldi_io/src/kaldi/util/text-utils.h @@ -0,0 +1,169 @@ +// util/text-utils.h + +// Copyright 2009-2011 Saarland University; Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_TEXT_UTILS_H_ +#define KALDI_UTIL_TEXT_UTILS_H_ + +#include <algorithm> +#include <map> +#include <set> +#include <string> +#include <vector> +#include <errno.h> + +#include "base/kaldi-common.h" + +namespace kaldi { + +/// Split a string using any of the single character delimiters. +/// If omit_empty_strings == true, the output will contain any +/// nonempty strings after splitting on any of the +/// characters in the delimiter. If omit_empty_strings == false, +/// the output will contain n+1 strings if there are n characters +/// in the set "delim" within the input string. In this case +/// the empty string is split to a single empty string. +void SplitStringToVector(const std::string &full, const char *delim, + bool omit_empty_strings, + std::vector<std::string> *out); + +/// Joins the elements of a vector of strings into a single string using +/// "delim" as the delimiter. If omit_empty_strings == true, any empty strings +/// in the vector are skipped. A vector of empty strings results in an empty +/// string on the output. +void JoinVectorToString(const std::vector<std::string> &vec_in, + const char *delim, bool omit_empty_strings, + std::string *str_out); + + +/// Split a string (e.g. 1:2:3) into a vector of integers. +/// The delimiting char may be any character in "delim". +/// returns true on success, false on failure. +/// If omit_empty_strings == true, 1::2:3: will become +/// { 1, 2, 3 }. Otherwise it would be rejected. +/// Regardless of the value of omit_empty_strings, +/// the empty string is successfully parsed as an empty +/// vector of integers +template<class I> +bool SplitStringToIntegers(const std::string &full, + const char *delim, + bool omit_empty_strings, // typically false [but + // should probably be true + // if "delim" is spaces]. + std::vector<I> *out) { + KALDI_ASSERT(out != NULL); + KALDI_ASSERT_IS_INTEGER_TYPE(I); + if ( *(full.c_str()) == '\0') { + out->clear(); + return true; + } + std::vector<std::string> split; + SplitStringToVector(full, delim, omit_empty_strings, &split); + out->resize(split.size()); + for (size_t i = 0; i < split.size(); i++) { + const char *this_str = split[i].c_str(); + char *end = NULL; + long long int j = 0; + j = KALDI_STRTOLL(this_str, &end); + if (end == this_str || *end != '\0') { + out->clear(); + return false; + } else { + I jI = static_cast<I>(j); + if (static_cast<long long int>(jI) != j) { + // output type cannot fit this integer. + out->clear(); + return false; + } + (*out)[i] = jI; + } + } + return true; +} + +// This is defined for F = float and double. +template<class F> +bool SplitStringToFloats(const std::string &full, + const char *delim, + bool omit_empty_strings, // typically false + std::vector<F> *out); + + +/// Converts a string into an integer via strtoll and returns false if there was +/// any kind of problem (i.e. the string was not an integer or contained extra +/// non-whitespace junk, or the integer was too large to fit into the type it is +/// being converted into). Only sets *out if everything was OK and it returns +/// true. +template<class Int> +bool ConvertStringToInteger(const std::string &str, + Int *out) { + KALDI_ASSERT_IS_INTEGER_TYPE(Int); + const char *this_str = str.c_str(); + char *end = NULL; + errno = 0; + long long int i = KALDI_STRTOLL(this_str, &end); + if (end != this_str) + while (isspace(*end)) end++; + if (end == this_str || *end != '\0' || errno != 0) + return false; + Int iInt = static_cast<Int>(i); + if (static_cast<long long int>(iInt) != i || (i<0 && !std::numeric_limits<Int>::is_signed)) { + return false; + } + *out = iInt; + return true; +} + + +/// ConvertStringToReal converts a string into either float or double via strtod, +/// and returns false if there was any kind of problem (i.e. the string was not a +/// floating point number or contained extra non-whitespace junk. +/// Be careful- this function will successfully read inf's or nan's. +bool ConvertStringToReal(const std::string &str, + double *out); +bool ConvertStringToReal(const std::string &str, + float *out); + + +/// Removes the beginning and trailing whitespaces from a string +void Trim(std::string *str); + + +/// Removes leading and trailing white space from the string, then splits on the +/// first section of whitespace found (if present), putting the part before the +/// whitespace in "first" and the rest in "rest". If there is no such space, +/// everything that remains after removing leading and trailing whitespace goes +/// in "first". +void SplitStringOnFirstSpace(const std::string &line, + std::string *first, + std::string *rest); + + +/// Returns true if "token" is nonempty, and all characters are +/// printable and whitespace-free. +bool IsToken(const std::string &token); + + +/// Returns true if "line" is free of \n characters and unprintable +/// characters, and does not contain leading or trailing whitespace. +bool IsLine(const std::string &line); + + +} // namespace kaldi + +#endif // KALDI_UTIL_TEXT_UTILS_H_ diff --git a/kaldi_io/src/kaldi/util/timer.h b/kaldi_io/src/kaldi/util/timer.h new file mode 100644 index 0000000..e3ee8d5 --- /dev/null +++ b/kaldi_io/src/kaldi/util/timer.h @@ -0,0 +1,27 @@ +// util/timer.h + +// Copyright 2014 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// We are temporarily leaving this file to forward #includes to +// base-timer.h. Its use is deprecated; you should directrly +// #include base/timer.h +#ifndef KALDI_UTIL_TIMER_H_ +#define KALDI_UTIL_TIMER_H_ +#pragma message warning: please do not include util/timer.h, include base/timer.h (it has been moved) +#include "base/timer.h" +#endif diff --git a/kaldi_io/src/test.c b/kaldi_io/src/test.c new file mode 100644 index 0000000..e92b4c9 --- /dev/null +++ b/kaldi_io/src/test.c @@ -0,0 +1,48 @@ +/********************************************************************************* +* File Name : test.c +* Created By : YIMMON, [email protected] +* Creation Date : [2015-08-05 17:39] +* Last Modified : [2015-08-06 14:28] +* Description : +**********************************************************************************/ + +#include "cwrapper_kaldi.h" +#include <stdio.h> + +char feature_rspecifier[] = {"ark:/slfs6/users/ymz09/kaldi/src/featbin/copy-feats scp:/slfs6/users/ymz09/swb_ivec/train_bp.scp ark:- |"}; + +void print_nerv_matrix(Matrix *mat) { + int n = mat->nrow; + int m = mat->ncol; + int i, j; + size_t stride = mat->stride; + for (i = 0; i < n; i++) + { + float *nerv_row = (float *)((char *)mat->data.f + i * stride); + for (j = 0; j < m; j++) + printf("%.8f ", nerv_row[j]); + puts(""); + } +} + +int main(int argc, char *argv[]) +{ + Matrix *mat; + KaldiFeatureRepo *repo = kaldi_feature_repo_new(feature_rspecifier); + + mat = kaldi_feature_repo_read_utterance(repo, NULL, 1); + printf("1st uttrance\n"); + print_nerv_matrix(mat); + + kaldi_feature_repo_next(repo); + + mat = kaldi_feature_repo_read_utterance(repo, NULL, 1); + printf("2nd uttrance\n"); + print_nerv_matrix(mat); + + printf("is end: %d\n", kaldi_feature_repo_is_end(repo)); + + kaldi_feature_repo_destroy(repo); + + return 0; +} diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_altivec.h b/kaldi_io/src/tools/ATLAS/include/atlas_altivec.h new file mode 100644 index 0000000..a772448 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_altivec.h @@ -0,0 +1,27 @@ +#ifndef ATLAS_ALTIVEC_H + #define ATLAS_ALTIVEC_H + +#ifdef ATL_AltiVec + #ifdef ATL_AVgcc + #include <altivec.h> + + #define VECTOR_INIT(v0_,v1_,v2_,v3_) (vector float) {v0_,v1_,v2_,v3_} + #define VECTOR_INITI(v0_,v1_,v2_,v3_) (vector int) {v0_,v1_,v2_,v3_} + #else + #define VECTOR_INIT(v0_,v1_,v2_,v3_) (vector float)(v0_,v1_,v2_,v3_) + #define VECTOR_INITI(v0_,v1_,v2_,v3_) (vector int)(v0_,v1_,v2_,v3_) + #define VECTOR_INITL(v0_,v1_,v2_,v3_) (vector long)(v0_,v1_,v2_,v3_) + #endif + #define ATL_GetCtrl(stride, count, size) \ + (int)((stride) | ((count)<<16) | ((size)<<24)) + #define ATL_pfavR(ptr, cwrd, stream) \ + vec_dst((vector float *)(ptr), (cwrd), (stream)) + #define ATL_pfavW(ptr, cwrd, stream) \ + vec_dstst((vector float *)(ptr), (cwrd), (stream)) +#else + #define ATL_GetCtrl(stride, count, size) + #define ATL_pfavR(ptr, cwrd, stream) + #define ATL_pfavW(ptr, cwrd, stream) +#endif + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_asm.h b/kaldi_io/src/tools/ATLAS/include/atlas_asm.h new file mode 100644 index 0000000..4c4fa86 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_asm.h @@ -0,0 +1,411 @@ +#ifndef ATLAS_ASM_H + #define ATLAS_ASM_H + +#ifndef Mjoin + #define Mjoin(pre, nam) my_join(pre, nam) + #define my_join(pre, nam) pre ## nam +#endif + +#if defined(ATL_OS_WinNT) || defined(ATL_OS_Win9x) || defined(ATL_OS_OSX) + #define ATL_asmdecor(nam) Mjoin(_,nam) +#elif defined(ATL_OS_AIX) && defined(ATL_GAS_PPC) + #define ATL_asmdecor(nam) Mjoin(.,nam) +#elif !defined(ATL_OS_OSX) && defined(ATL_GAS_PPC) && defined(ATL_USE64BITS) + #define ATL_asmdecor(nam) Mjoin(.,nam) +#else + #define ATL_asmdecor(nam) nam +#endif + +#ifdef ATL_GAS_PARISC + #ifdef ATL_OS_HPUX + #define ATL_HPUX_PARISC + #else + #define ATL_LINUX_PARISC + #endif +#endif + +#ifdef ATL_GAS_PPC + #ifdef ATL_OS_OSX + #define ATL_AS_OSX_PPC + #elif defined(ATL_OS_AIX) + #define ATL_AS_AIX_PPC + #else + #define ATL_GAS_LINUX_PPC + #endif +#endif + +#if defined(ATL_GAS_LINUX_PPC) || defined(ATL_AS_AIX_PPC) + + #define r0 0 + #define f0 0 + #define r1 1 + #define f1 1 + #define r2 2 + #define f2 2 + #define r3 3 + #define f3 3 + #define r4 4 + #define f4 4 + #define r5 5 + #define f5 5 + #define r6 6 + #define f6 6 + #define r7 7 + #define f7 7 + #define r8 8 + #define f8 8 + #define r9 9 + #define f9 9 + #define r10 10 + #define f10 10 + #define r11 11 + #define f11 11 + #define r12 12 + #define f12 12 + #define r13 13 + #define f13 13 + #define r14 14 + #define f14 14 + #define r15 15 + #define f15 15 + #define r16 16 + #define f16 16 + #define r17 17 + #define f17 17 + #define r18 18 + #define f18 18 + #define r19 19 + #define f19 19 + #define r20 20 + #define f20 20 + #define r21 21 + #define f21 21 + #define r22 22 + #define f22 22 + #define r23 23 + #define f23 23 + #define r24 24 + #define f24 24 + #define r25 25 + #define f25 25 + #define r26 26 + #define f26 26 + #define r27 27 + #define f27 27 + #define r28 28 + #define f28 28 + #define r29 29 + #define f29 29 + #define r30 30 + #define f30 30 + #define r31 31 + #define f31 31 + #define cr0 0 + #define cr1 1 + #define cr2 2 + #define cr3 3 + #define cr4 4 + #define cr5 5 + #define cr6 6 + #define cr7 7 + +#endif + +#ifdef ATL_OS_OSX + #define ALIGN2 .align 1 + #define ALIGN4 .align 2 + #define ALIGN8 .align 3 + #define ALIGN16 .align 4 + #define ALIGN32 .align 5 + #define ALIGN64 .align 6 + #define ALIGN128 .align 7 + #define global globl +#else + #define ALIGN2 .align 2 + #define ALIGN4 .align 4 + #define ALIGN8 .align 8 + #define ALIGN16 .align 16 + #define ALIGN32 .align 32 + #define ALIGN64 .align 64 + #define ALIGN128 .align 128 +#endif + +#if defined(ATL_SSE1) && !defined(ATL_3DNow) + #define prefetchw prefetchnta +#endif +/* + * Solaris doesn't allow division in integer expressions in assembly, but + * many x86 kernels need to do $MB/mu; we work around this insanity with + * this kludge + */ +#if defined(ATL_DIV_NUM) && defined(ATL_DIV_DEN) + #if (ATL_DIV_NUM/ATL_DIV_DEN) == 0 + #define ATL_DivAns 0 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 1 + #define ATL_DivAns 1 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 2 + #define ATL_DivAns 2 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 3 + #define ATL_DivAns 3 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 4 + #define ATL_DivAns 4 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 5 + #define ATL_DivAns 5 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 6 + #define ATL_DivAns 6 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 7 + #define ATL_DivAns 7 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 8 + #define ATL_DivAns 8 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 9 + #define ATL_DivAns 9 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 10 + #define ATL_DivAns 10 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 11 + #define ATL_DivAns 11 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 12 + #define ATL_DivAns 12 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 13 + #define ATL_DivAns 13 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 14 + #define ATL_DivAns 14 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 15 + #define ATL_DivAns 15 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 16 + #define ATL_DivAns 16 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 17 + #define ATL_DivAns 17 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 18 + #define ATL_DivAns 18 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 19 + #define ATL_DivAns 19 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 20 + #define ATL_DivAns 20 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 21 + #define ATL_DivAns 21 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 22 + #define ATL_DivAns 22 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 23 + #define ATL_DivAns 23 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 24 + #define ATL_DivAns 24 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 25 + #define ATL_DivAns 25 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 26 + #define ATL_DivAns 26 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 27 + #define ATL_DivAns 27 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 28 + #define ATL_DivAns 28 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 29 + #define ATL_DivAns 29 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 30 + #define ATL_DivAns 30 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 31 + #define ATL_DivAns 31 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 32 + #define ATL_DivAns 32 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 33 + #define ATL_DivAns 33 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 34 + #define ATL_DivAns 34 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 35 + #define ATL_DivAns 35 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 36 + #define ATL_DivAns 36 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 37 + #define ATL_DivAns 37 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 38 + #define ATL_DivAns 38 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 39 + #define ATL_DivAns 39 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 40 + #define ATL_DivAns 40 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 41 + #define ATL_DivAns 41 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 42 + #define ATL_DivAns 42 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 43 + #define ATL_DivAns 43 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 44 + #define ATL_DivAns 44 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 45 + #define ATL_DivAns 45 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 46 + #define ATL_DivAns 46 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 47 + #define ATL_DivAns 47 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 48 + #define ATL_DivAns 48 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 49 + #define ATL_DivAns 49 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 50 + #define ATL_DivAns 50 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 51 + #define ATL_DivAns 51 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 52 + #define ATL_DivAns 52 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 53 + #define ATL_DivAns 53 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 54 + #define ATL_DivAns 54 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 55 + #define ATL_DivAns 55 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 56 + #define ATL_DivAns 56 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 57 + #define ATL_DivAns 57 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 58 + #define ATL_DivAns 58 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 59 + #define ATL_DivAns 59 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 60 + #define ATL_DivAns 60 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 61 + #define ATL_DivAns 61 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 62 + #define ATL_DivAns 62 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 63 + #define ATL_DivAns 63 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 64 + #define ATL_DivAns 64 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 65 + #define ATL_DivAns 65 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 66 + #define ATL_DivAns 66 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 67 + #define ATL_DivAns 67 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 68 + #define ATL_DivAns 68 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 69 + #define ATL_DivAns 69 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 70 + #define ATL_DivAns 70 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 71 + #define ATL_DivAns 71 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 72 + #define ATL_DivAns 72 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 73 + #define ATL_DivAns 73 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 74 + #define ATL_DivAns 74 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 75 + #define ATL_DivAns 75 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 76 + #define ATL_DivAns 76 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 77 + #define ATL_DivAns 77 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 78 + #define ATL_DivAns 78 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 79 + #define ATL_DivAns 79 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 80 + #define ATL_DivAns 80 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 81 + #define ATL_DivAns 81 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 82 + #define ATL_DivAns 82 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 83 + #define ATL_DivAns 83 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 84 + #define ATL_DivAns 84 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 85 + #define ATL_DivAns 85 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 86 + #define ATL_DivAns 86 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 87 + #define ATL_DivAns 87 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 88 + #define ATL_DivAns 88 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 89 + #define ATL_DivAns 89 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 90 + #define ATL_DivAns 90 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 91 + #define ATL_DivAns 91 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 92 + #define ATL_DivAns 92 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 93 + #define ATL_DivAns 93 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 94 + #define ATL_DivAns 94 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 95 + #define ATL_DivAns 95 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 96 + #define ATL_DivAns 96 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 97 + #define ATL_DivAns 97 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 98 + #define ATL_DivAns 98 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 99 + #define ATL_DivAns 99 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 100 + #define ATL_DivAns 100 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 101 + #define ATL_DivAns 101 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 102 + #define ATL_DivAns 102 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 103 + #define ATL_DivAns 103 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 104 + #define ATL_DivAns 104 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 105 + #define ATL_DivAns 105 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 106 + #define ATL_DivAns 106 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 107 + #define ATL_DivAns 107 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 108 + #define ATL_DivAns 108 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 109 + #define ATL_DivAns 109 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 110 + #define ATL_DivAns 110 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 111 + #define ATL_DivAns 111 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 112 + #define ATL_DivAns 112 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 113 + #define ATL_DivAns 113 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 114 + #define ATL_DivAns 114 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 115 + #define ATL_DivAns 115 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 116 + #define ATL_DivAns 116 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 117 + #define ATL_DivAns 117 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 118 + #define ATL_DivAns 118 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 119 + #define ATL_DivAns 119 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 120 + #define ATL_DivAns 120 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 121 + #define ATL_DivAns 121 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 122 + #define ATL_DivAns 122 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 123 + #define ATL_DivAns 123 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 124 + #define ATL_DivAns 124 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 125 + #define ATL_DivAns 125 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 126 + #define ATL_DivAns 126 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 127 + #define ATL_DivAns 127 + #elif (ATL_DIV_NUM/ATL_DIV_DEN) == 128 + #define ATL_DivAns 128 + #endif +#endif + +/* + * For GNU/Linux, set no-execute bit for all ATLAS assembly + */ +#if defined(ATL_OS_Linux) && defined(__ELF__) && defined(__GNUC__) && \ + defined(ATL_SSE1) +.section .note.GNU-stack,"",%progbits +#endif + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_aux.h b/kaldi_io/src/tools/ATLAS/include/atlas_aux.h new file mode 100644 index 0000000..ce31eee --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_aux.h @@ -0,0 +1,785 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ +/* + * Header file for ATLAS's auxiliary routines + */ +#ifndef ATLAS_AUX_H +#define ATLAS_AUX_H +#include "atlas_misc.h" + +void ATL_xerbla(int p, char *rout, char *form, ...); +int ATL_lcm(const int M, const int N); +double ATL_walltime(); +double ATL_cputime(); + +/* + * Auxiliary routines that come in all four types + */ +void ATL_sgeadd(const int M, const int N, const float alpha, + const float *A, const int lda, const float beta, + float *C, const int ldc); +void ATL_sgemove(const int M, const int N, const float alpha, + const float *A, const int lda, float *C, const int ldc); +void ATL_sgemoveT(const int N, const int M, const float alpha, + const float *A, const int lda, float *C, const int ldc); +void ATL_ssyreflect(const enum ATLAS_UPLO Uplo, const int N, + float *C, const int ldc); +void ATL_sgecopy(const int M, const int N, const float *A, const int lda, + float *C, const int ldc); + +void ATL_sgescal(const int M, const int N, const float beta, + float *C, const int ldc); +void ATL_strscal + (const enum ATLAS_UPLO Uplo, const int M, const int N, const float alpha, + float *A, const int lda); +void ATL_shescal + (const enum ATLAS_UPLO Uplo, const int M, const int N, const float alpha, + float *A, const int lda); + +void ATL_sgezero(const int M, const int N, float *C, const int ldc); + +void ATL_szero(const int N, float *X, const int incX); +void ATL_sset(const int N, const float alpha, float *X, const int incX); +void ATL_sscal(const int N, const float alpha, float *X, const int incX); +void ATL_scopy(const int N, const float *X, const int incX, + float *Y, const int incY); +void ATL_scpsc(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY); +void ATL_saxpy(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY); +void ATL_saxpy_x1_y1(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY); +void ATL_saxpby(const int N, const float alpha, const float *X, + const int incX, const float beta, float *Y, const int incY); + +void ATL_sgeadd_a1_b1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float beta, float *C, const int ldc); +void ATL_saxpby_a1_b1 + (const int N, const float alpha, const float *X, const int incX, + const float beta, float *Y, const int incY); +void ATL_sgeadd_a0_b1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float beta, float *C, const int ldc); +void ATL_saxpby_a0_b1 + (const int N, const float alpha, const float *X, const int incX, + const float beta, float *Y, const int incY); +void ATL_sgeadd_aX_b1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float beta, float *C, const int ldc); +void ATL_saxpby_aX_b1 + (const int N, const float alpha, const float *X, const int incX, + const float beta, float *Y, const int incY); +void ATL_sgeadd_a1_b0 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float beta, float *C, const int ldc); +void ATL_saxpby_a1_b0 + (const int N, const float alpha, const float *X, const int incX, + const float beta, float *Y, const int incY); +void ATL_sgeadd_a0_b0 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float beta, float *C, const int ldc); +void ATL_saxpby_a0_b0 + (const int N, const float alpha, const float *X, const int incX, + const float beta, float *Y, const int incY); +void ATL_sgeadd_aX_b0 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float beta, float *C, const int ldc); +void ATL_saxpby_aX_b0 + (const int N, const float alpha, const float *X, const int incX, + const float beta, float *Y, const int incY); +void ATL_sgeadd_a1_bX + (const int M, const int N, const float alpha, const float *A, + const int lda, const float beta, float *C, const int ldc); +void ATL_saxpby_a1_bX + (const int N, const float alpha, const float *X, const int incX, + const float beta, float *Y, const int incY); +void ATL_sgeadd_a0_bX + (const int M, const int N, const float alpha, const float *A, + const int lda, const float beta, float *C, const int ldc); +void ATL_saxpby_a0_bX + (const int N, const float alpha, const float *X, const int incX, + const float beta, float *Y, const int incY); +void ATL_sgeadd_aX_bX + (const int M, const int N, const float alpha, const float *A, + const int lda, const float beta, float *C, const int ldc); +void ATL_saxpby_aX_bX + (const int N, const float alpha, const float *X, const int incX, + const float beta, float *Y, const int incY); + +void ATL_sgemove_a1 + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_sgemove_a0 + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_sgemove_aX + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); + +void ATL_sgescal_b1 + (const int M, const int N, const float beta, float *C, const int ldc); +void ATL_sgescal_b0 + (const int M, const int N, const float beta, float *C, const int ldc); +void ATL_sgescal_bX + (const int M, const int N, const float beta, float *C, const int ldc); + +void ATL_dgeadd(const int M, const int N, const double alpha, + const double *A, const int lda, const double beta, + double *C, const int ldc); +void ATL_dgemove(const int M, const int N, const double alpha, + const double *A, const int lda, double *C, const int ldc); +void ATL_dgemoveT(const int N, const int M, const double alpha, + const double *A, const int lda, double *C, const int ldc); +void ATL_dsyreflect(const enum ATLAS_UPLO Uplo, const int N, + double *C, const int ldc); +void ATL_dgecopy(const int M, const int N, const double *A, const int lda, + double *C, const int ldc); + +void ATL_dgescal(const int M, const int N, const double beta, + double *C, const int ldc); +void ATL_dtrscal + (const enum ATLAS_UPLO Uplo, const int M, const int N, const double alpha, + double *A, const int lda); +void ATL_dhescal + (const enum ATLAS_UPLO Uplo, const int M, const int N, const double alpha, + double *A, const int lda); + +void ATL_dgezero(const int M, const int N, double *C, const int ldc); + +void ATL_dzero(const int N, double *X, const int incX); +void ATL_dset(const int N, const double alpha, double *X, const int incX); +void ATL_dscal(const int N, const double alpha, double *X, const int incX); +void ATL_dcopy(const int N, const double *X, const int incX, + double *Y, const int incY); +void ATL_dcpsc(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY); +void ATL_daxpy(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY); +void ATL_daxpy_x1_y1(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY); +void ATL_daxpby(const int N, const double alpha, const double *X, + const int incX, const double beta, double *Y, const int incY); + +void ATL_dgeadd_a1_b1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double beta, double *C, const int ldc); +void ATL_daxpby_a1_b1 + (const int N, const double alpha, const double *X, const int incX, + const double beta, double *Y, const int incY); +void ATL_dgeadd_a0_b1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double beta, double *C, const int ldc); +void ATL_daxpby_a0_b1 + (const int N, const double alpha, const double *X, const int incX, + const double beta, double *Y, const int incY); +void ATL_dgeadd_aX_b1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double beta, double *C, const int ldc); +void ATL_daxpby_aX_b1 + (const int N, const double alpha, const double *X, const int incX, + const double beta, double *Y, const int incY); +void ATL_dgeadd_a1_b0 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double beta, double *C, const int ldc); +void ATL_daxpby_a1_b0 + (const int N, const double alpha, const double *X, const int incX, + const double beta, double *Y, const int incY); +void ATL_dgeadd_a0_b0 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double beta, double *C, const int ldc); +void ATL_daxpby_a0_b0 + (const int N, const double alpha, const double *X, const int incX, + const double beta, double *Y, const int incY); +void ATL_dgeadd_aX_b0 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double beta, double *C, const int ldc); +void ATL_daxpby_aX_b0 + (const int N, const double alpha, const double *X, const int incX, + const double beta, double *Y, const int incY); +void ATL_dgeadd_a1_bX + (const int M, const int N, const double alpha, const double *A, + const int lda, const double beta, double *C, const int ldc); +void ATL_daxpby_a1_bX + (const int N, const double alpha, const double *X, const int incX, + const double beta, double *Y, const int incY); +void ATL_dgeadd_a0_bX + (const int M, const int N, const double alpha, const double *A, + const int lda, const double beta, double *C, const int ldc); +void ATL_daxpby_a0_bX + (const int N, const double alpha, const double *X, const int incX, + const double beta, double *Y, const int incY); +void ATL_dgeadd_aX_bX + (const int M, const int N, const double alpha, const double *A, + const int lda, const double beta, double *C, const int ldc); +void ATL_daxpby_aX_bX + (const int N, const double alpha, const double *X, const int incX, + const double beta, double *Y, const int incY); + +void ATL_dgemove_a1 + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dgemove_a0 + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dgemove_aX + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); + +void ATL_dgescal_b1 + (const int M, const int N, const double beta, double *C, const int ldc); +void ATL_dgescal_b0 + (const int M, const int N, const double beta, double *C, const int ldc); +void ATL_dgescal_bX + (const int M, const int N, const double beta, double *C, const int ldc); + +void ATL_cgeadd(const int M, const int N, const float *alpha, + const float *A, const int lda, const float *beta, + float *C, const int ldc); +void ATL_cgemove(const int M, const int N, const float *alpha, + const float *A, const int lda, float *C, const int ldc); +void ATL_cgemoveT(const int N, const int M, const float *alpha, + const float *A, const int lda, float *C, const int ldc); +void ATL_csyreflect(const enum ATLAS_UPLO Uplo, const int N, + float *C, const int ldc); +void ATL_cgecopy(const int M, const int N, const float *A, const int lda, + float *C, const int ldc); + +void ATL_cgescal(const int M, const int N, const float *beta, + float *C, const int ldc); +void ATL_ctrscal + (const enum ATLAS_UPLO Uplo, const int M, const int N, const float *alpha, + float *A, const int lda); +void ATL_chescal + (const enum ATLAS_UPLO Uplo, const int M, const int N, const float alpha, + float *A, const int lda); + +void ATL_cgezero(const int M, const int N, float *C, const int ldc); + +void ATL_czero(const int N, float *X, const int incX); +void ATL_cset(const int N, const float *alpha, float *X, const int incX); +void ATL_cscal(const int N, const float *alpha, float *X, const int incX); +void ATL_ccopy(const int N, const float *X, const int incX, + float *Y, const int incY); +void ATL_ccpsc(const int N, const float *alpha, const float *X, + const int incX, float *Y, const int incY); +void ATL_caxpy(const int N, const float *alpha, const float *X, + const int incX, float *Y, const int incY); +void ATL_caxpy_x1_y1(const int N, const float *alpha, const float *X, + const int incX, float *Y, const int incY); +void ATL_caxpby(const int N, const float *alpha, const float *X, + const int incX, const float *beta, float *Y, const int incY); + +void ATL_cgeadd_a1_b1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_a1_b1 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_a0_b1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_a0_b1 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_aX_b1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_aX_b1 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_a1_b0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_a1_b0 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_a0_b0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_a0_b0 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_aX_b0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_aX_b0 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_a1_bX + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_a1_bX + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_a0_bX + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_a0_bX + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_aX_bX + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_aX_bX + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); + +void ATL_cgemove_a1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_cgemove_a0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_cgemove_aX + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); + +void ATL_cgescal_b1 + (const int M, const int N, const float *beta, float *C, const int ldc); +void ATL_cgescal_b0 + (const int M, const int N, const float *beta, float *C, const int ldc); +void ATL_cgescal_bX + (const int M, const int N, const float *beta, float *C, const int ldc); + +void ATL_zgeadd(const int M, const int N, const double *alpha, + const double *A, const int lda, const double *beta, + double *C, const int ldc); +void ATL_zgemove(const int M, const int N, const double *alpha, + const double *A, const int lda, double *C, const int ldc); +void ATL_zgemoveT(const int N, const int M, const double *alpha, + const double *A, const int lda, double *C, const int ldc); +void ATL_zsyreflect(const enum ATLAS_UPLO Uplo, const int N, + double *C, const int ldc); +void ATL_zgecopy(const int M, const int N, const double *A, const int lda, + double *C, const int ldc); + +void ATL_zgescal(const int M, const int N, const double *beta, + double *C, const int ldc); +void ATL_ztrscal + (const enum ATLAS_UPLO Uplo, const int M, const int N, const double *alpha, + double *A, const int lda); +void ATL_zhescal + (const enum ATLAS_UPLO Uplo, const int M, const int N, const double alpha, + double *A, const int lda); + +void ATL_zgezero(const int M, const int N, double *C, const int ldc); + +void ATL_zzero(const int N, double *X, const int incX); +void ATL_zset(const int N, const double *alpha, double *X, const int incX); +void ATL_zscal(const int N, const double *alpha, double *X, const int incX); +void ATL_zcopy(const int N, const double *X, const int incX, + double *Y, const int incY); +void ATL_zcpsc(const int N, const double *alpha, const double *X, + const int incX, double *Y, const int incY); +void ATL_zaxpy(const int N, const double *alpha, const double *X, + const int incX, double *Y, const int incY); +void ATL_zaxpy_x1_y1(const int N, const double *alpha, const double *X, + const int incX, double *Y, const int incY); +void ATL_zaxpby(const int N, const double *alpha, const double *X, + const int incX, const double *beta, double *Y, const int incY); + +void ATL_zgeadd_a1_b1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_a1_b1 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_a0_b1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_a0_b1 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_aX_b1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_aX_b1 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_a1_b0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_a1_b0 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_a0_b0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_a0_b0 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_aX_b0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_aX_b0 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_a1_bX + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_a1_bX + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_a0_bX + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_a0_bX + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_aX_bX + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_aX_bX + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); + +void ATL_zgemove_a1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_zgemove_a0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_zgemove_aX + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); + +void ATL_zgescal_b1 + (const int M, const int N, const double *beta, double *C, const int ldc); +void ATL_zgescal_b0 + (const int M, const int N, const double *beta, double *C, const int ldc); +void ATL_zgescal_bX + (const int M, const int N, const double *beta, double *C, const int ldc); + +/* + * Specialized complex auxiliary routines + */ + +void ATL_ccplxinvert + (const int N, float *X, const int incX, float *Y, const int incY); + +void ATL_chereflect(const enum ATLAS_UPLO Uplo, const int N, + float *C, const int ldc); +void ATL_cscalConj + (const int N, const float *alpha, float *X, const int incX); +void ATL_ccopyConj + (const int N, const float *X, const int incX, float *Y, const int incY); +void ATL_cmoveConj + (const int N, const float *alpha, const float *X, const int incX, + float *Y, const int incY); +void ATL_caxpyConj + (const int N, const float *alpha, const float *X, const int incX, + float *Y, const int incY); +void ATL_caxpyConj_x1_y1(const int N, const float *alpha, const float *X, + const int incX, float *Y, const int incY); +void ATL_caxpbyConj + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgemoveC(const int N, const int M, const float *alpha, + const float *A, const int lda, float *C, const int ldc); + +void ATL_cgeaddConj_aXi0_b1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_a1_b1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_a0_b1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aXi0_b1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aX_b1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aXi0_b0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_a1_b0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_a0_b0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aXi0_b0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aX_b0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aXi0_bXi0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_a1_bXi0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_a0_bXi0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aXi0_bXi0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aX_bXi0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aXi0_bX + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_a1_bX + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_a0_bX + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aXi0_bX + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_cgeaddConj_aX_bX + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_aXi0_b1 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_caxpby_aXi0_b1 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_aXi0_b1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_aXi0_b0 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_caxpby_aXi0_b0 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_aXi0_b0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_aXi0_bXi0 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_caxpby_aXi0_bXi0 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_aXi0_bXi0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_aXi0_bX + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_caxpby_aXi0_bX + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_aXi0_bX + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_a1_bXi0 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_a1_bXi0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_a0_bXi0 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_a0_bXi0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); +void ATL_caxpby_aX_bXi0 + (const int N, const float *alpha, const float *X, const int incX, + const float *beta, float *Y, const int incY); +void ATL_cgeadd_aX_bXi0 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *beta, float *C, const int ldc); + +void ATL_cgemove_aXi0 + (const int M, const int N, const float *alpha0, const float *A, + const int lda, float *C, const int ldc); + +void ATL_cgescal_bXi0 + (const int M, const int N, const float *beta, float *C, const int ldc); + +void ATL_zcplxinvert + (const int N, double *X, const int incX, double *Y, const int incY); + +void ATL_zhereflect(const enum ATLAS_UPLO Uplo, const int N, + double *C, const int ldc); +void ATL_zscalConj + (const int N, const double *alpha, double *X, const int incX); +void ATL_zcopyConj + (const int N, const double *X, const int incX, double *Y, const int incY); +void ATL_zmoveConj + (const int N, const double *alpha, const double *X, const int incX, + double *Y, const int incY); +void ATL_zaxpyConj + (const int N, const double *alpha, const double *X, const int incX, + double *Y, const int incY); +void ATL_zaxpyConj_x1_y1(const int N, const double *alpha, const double *X, + const int incX, double *Y, const int incY); +void ATL_zaxpbyConj + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgemoveC(const int N, const int M, const double *alpha, + const double *A, const int lda, double *C, const int ldc); + +void ATL_zgeaddConj_aXi0_b1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_a1_b1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_a0_b1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aXi0_b1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aX_b1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aXi0_b0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_a1_b0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_a0_b0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aXi0_b0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aX_b0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aXi0_bXi0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_a1_bXi0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_a0_bXi0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aXi0_bXi0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aX_bXi0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aXi0_bX + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_a1_bX + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_a0_bX + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aXi0_bX + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zgeaddConj_aX_bX + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_aXi0_b1 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zaxpby_aXi0_b1 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_aXi0_b1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_aXi0_b0 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zaxpby_aXi0_b0 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_aXi0_b0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_aXi0_bXi0 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zaxpby_aXi0_bXi0 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_aXi0_bXi0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_aXi0_bX + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zaxpby_aXi0_bX + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_aXi0_bX + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_a1_bXi0 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_a1_bXi0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_a0_bXi0 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_a0_bXi0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); +void ATL_zaxpby_aX_bXi0 + (const int N, const double *alpha, const double *X, const int incX, + const double *beta, double *Y, const int incY); +void ATL_zgeadd_aX_bXi0 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *beta, double *C, const int ldc); + +void ATL_zgemove_aXi0 + (const int M, const int N, const double *alpha0, const double *A, + const int lda, double *C, const int ldc); + +void ATL_zgescal_bXi0 + (const int M, const int N, const double *beta, double *C, const int ldc); + + +#if defined(ATL_USEPTHREADS) && !defined(ATL_flushcache) + #include "atlas_pthreads.h" + #define ATL_flushcache ATL_ptflushcache + #define ATL_PTCACHEMUL * ATL_NTHREADS +#else + #define ATL_PTCACHEMUL +#endif +double ATL_flushcache(int size); + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_cblascalias.h b/kaldi_io/src/tools/ATLAS/include/atlas_cblascalias.h new file mode 100644 index 0000000..267b176 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_cblascalias.h @@ -0,0 +1,40 @@ +#ifndef ATLAS_CBLASCALIAS_H + #define ATLAS_CBLASCALIAS_H + +#define cblas_dotc_sub cblas_cdotc_sub +#define cblas_dotu_sub cblas_cdotu_sub +#define cblas_axpy cblas_caxpy +#define cblas_copy cblas_ccopy +#define cblas_scal cblas_cscal +#define cblas_swap cblas_cswap +#define cblas_hpr2 cblas_chpr2 +#define cblas_her2 cblas_cher2 +#define cblas_hpr cblas_chpr +#define cblas_her cblas_cher +#define cblas_gerc cblas_cgerc +#define cblas_geru cblas_cgeru +#define cblas_tpsv cblas_ctpsv +#define cblas_tbsv cblas_ctbsv +#define cblas_trsv cblas_ctrsv +#define cblas_tpmv cblas_ctpmv +#define cblas_tbmv cblas_ctbmv +#define cblas_trmv cblas_ctrmv +#define cblas_hpmv cblas_chpmv +#define cblas_hbmv cblas_chbmv +#define cblas_hemv cblas_chemv +#define cblas_gbmv cblas_cgbmv +#define cblas_gemv cblas_cgemv +#define cblas_trsm cblas_ctrsm +#define cblas_trmm cblas_ctrmm +#define cblas_her2k cblas_cher2k +#define cblas_syr2k cblas_csyr2k +#define cblas_herk cblas_cherk +#define cblas_syrk cblas_csyrk +#define cblas_hemm cblas_chemm +#define cblas_symm cblas_csymm +#define cblas_gemm cblas_cgemm +#define cblas_iamax cblas_icamax +#define cblas_nrm2 cblas_scnrm2 +#define cblas_asum cblas_scasum + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_cblasdalias.h b/kaldi_io/src/tools/ATLAS/include/atlas_cblasdalias.h new file mode 100644 index 0000000..cfc6d10 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_cblasdalias.h @@ -0,0 +1,39 @@ +#ifndef ATLAS_CBLASDALIAS_H + #define ATLAS_CBLASDALIAS_H + +#define cblas_asum cblas_dasum +#define cblas_nrm2 cblas_dnrm2 +#define cblas_dot cblas_ddot +#define cblas_axpy cblas_daxpy +#define cblas_copy cblas_dcopy +#define cblas_scal cblas_dscal +#define cblas_swap cblas_dswap +#define cblas_rotm cblas_drotm +#define cblas_rot cblas_drot +#define cblas_rotmg cblas_drotmg +#define cblas_rotg cblas_drotg +#define cblas_spr2 cblas_dspr2 +#define cblas_syr2 cblas_dsyr2 +#define cblas_spr cblas_dspr +#define cblas_syr cblas_dsyr +#define cblas_ger cblas_dger +#define cblas_tpsv cblas_dtpsv +#define cblas_tbsv cblas_dtbsv +#define cblas_trsv cblas_dtrsv +#define cblas_tpmv cblas_dtpmv +#define cblas_tbmv cblas_dtbmv +#define cblas_trmv cblas_dtrmv +#define cblas_spmv cblas_dspmv +#define cblas_sbmv cblas_dsbmv +#define cblas_symv cblas_dsymv +#define cblas_gbmv cblas_dgbmv +#define cblas_gemv cblas_dgemv +#define cblas_trsm cblas_dtrsm +#define cblas_trmm cblas_dtrmm +#define cblas_syr2k cblas_dsyr2k +#define cblas_syrk cblas_dsyrk +#define cblas_symm cblas_dsymm +#define cblas_gemm cblas_dgemm +#define cblas_iamax cblas_idamax + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_cblassalias.h b/kaldi_io/src/tools/ATLAS/include/atlas_cblassalias.h new file mode 100644 index 0000000..090f9de --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_cblassalias.h @@ -0,0 +1,39 @@ +#ifndef ATLAS_CBLASSALIAS_H + #define ATLAS_CBLASSALIAS_H + +#define cblas_asum cblas_sasum +#define cblas_nrm2 cblas_snrm2 +#define cblas_dot cblas_sdot +#define cblas_axpy cblas_saxpy +#define cblas_copy cblas_scopy +#define cblas_scal cblas_sscal +#define cblas_swap cblas_sswap +#define cblas_rotm cblas_srotm +#define cblas_rot cblas_srot +#define cblas_rotmg cblas_srotmg +#define cblas_rotg cblas_srotg +#define cblas_spr2 cblas_sspr2 +#define cblas_syr2 cblas_ssyr2 +#define cblas_spr cblas_sspr +#define cblas_syr cblas_ssyr +#define cblas_ger cblas_sger +#define cblas_tpsv cblas_stpsv +#define cblas_tbsv cblas_stbsv +#define cblas_trsv cblas_strsv +#define cblas_tpmv cblas_stpmv +#define cblas_tbmv cblas_stbmv +#define cblas_trmv cblas_strmv +#define cblas_spmv cblas_sspmv +#define cblas_sbmv cblas_ssbmv +#define cblas_symv cblas_ssymv +#define cblas_gbmv cblas_sgbmv +#define cblas_gemv cblas_sgemv +#define cblas_trsm cblas_strsm +#define cblas_trmm cblas_strmm +#define cblas_syr2k cblas_ssyr2k +#define cblas_syrk cblas_ssyrk +#define cblas_symm cblas_ssymm +#define cblas_gemm cblas_sgemm +#define cblas_iamax cblas_isamax + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_cblastypealias.h b/kaldi_io/src/tools/ATLAS/include/atlas_cblastypealias.h new file mode 100644 index 0000000..0c3e82f --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_cblastypealias.h @@ -0,0 +1,9 @@ +#ifdef SREAL + #include "atlas_cblassalias.h" +#elif defined(DREAL) + #include "atlas_cblasdalias.h" +#elif defined(SCPLX) + #include "atlas_cblascalias.h" +#elif defined(DCPLX) + #include "atlas_cblaszalias.h" +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_cblaszalias.h b/kaldi_io/src/tools/ATLAS/include/atlas_cblaszalias.h new file mode 100644 index 0000000..ac01436 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_cblaszalias.h @@ -0,0 +1,40 @@ +#ifndef ATLAS_CBLASZALIAS_H + #define ATLAS_CBLASZALIAS_H + +#define cblas_dotc_sub cblas_zdotc_sub +#define cblas_dotu_sub cblas_zdotu_sub +#define cblas_axpy cblas_zaxpy +#define cblas_copy cblas_zcopy +#define cblas_scal cblas_zscal +#define cblas_swap cblas_zswap +#define cblas_hpr2 cblas_zhpr2 +#define cblas_her2 cblas_zher2 +#define cblas_hpr cblas_zhpr +#define cblas_her cblas_zher +#define cblas_gerc cblas_zgerc +#define cblas_geru cblas_zgeru +#define cblas_tpsv cblas_ztpsv +#define cblas_tbsv cblas_ztbsv +#define cblas_trsv cblas_ztrsv +#define cblas_tpmv cblas_ztpmv +#define cblas_tbmv cblas_ztbmv +#define cblas_trmv cblas_ztrmv +#define cblas_hpmv cblas_zhpmv +#define cblas_hbmv cblas_zhbmv +#define cblas_hemv cblas_zhemv +#define cblas_gbmv cblas_zgbmv +#define cblas_gemv cblas_zgemv +#define cblas_trsm cblas_ztrsm +#define cblas_trmm cblas_ztrmm +#define cblas_her2k cblas_zher2k +#define cblas_syr2k cblas_zsyr2k +#define cblas_herk cblas_zherk +#define cblas_syrk cblas_zsyrk +#define cblas_hemm cblas_zhemm +#define cblas_symm cblas_zsymm +#define cblas_gemm cblas_zgemm +#define cblas_iamax cblas_izamax +#define cblas_nrm2 cblas_dznrm2 +#define cblas_asum cblas_dzasum + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_enum.h b/kaldi_io/src/tools/ATLAS/include/atlas_enum.h new file mode 100644 index 0000000..3d638be --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_enum.h @@ -0,0 +1,55 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1997 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ +#ifndef ATLAS_ENUM_H + #define ATLAS_ENUM_H + + #define CBLAS_ENUM_ONLY + #include "cblas.h" + #undef CBLAS_ENUM_ONLY + + #define ATLAS_ORDER CBLAS_ORDER + #define AtlasRowMajor CblasRowMajor + #define AtlasColMajor CblasColMajor + #define ATLAS_TRANS CBLAS_TRANSPOSE + #define AtlasNoTrans CblasNoTrans + #define AtlasTrans CblasTrans + #define AtlasConjTrans CblasConjTrans + #define ATLAS_UPLO CBLAS_UPLO + #define AtlasUpper CblasUpper + #define AtlasLower CblasLower + #define ATLAS_DIAG CBLAS_DIAG + #define AtlasNonUnit CblasNonUnit + #define AtlasUnit CblasUnit + #define ATLAS_SIDE CBLAS_SIDE + #define AtlasLeft CblasLeft + #define AtlasRight CblasRight + +#endif + diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_f77.h b/kaldi_io/src/tools/ATLAS/include/atlas_f77.h new file mode 100644 index 0000000..1586fba --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_f77.h @@ -0,0 +1,83 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1997 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ +#ifndef ATLAS_F77_H +#define ATLAS_F77_H + + #ifndef ATL_F77_SUBROUTINE + #define ATL_F77_SUBROUTINE void + #endif + #ifndef F77_INTEGER + #define F77_INTEGER int + #else + #define ATL_FunkyInts + #endif + #if defined(CRAY) + #define UseTransChar 1 + #include <fortran.h> + #define F77_CHAR _fcd + #define ATL_F2C_TransChar(c) (*(_fcdtocp(c) )) + #define ATL_C2F_TransChar(c) (_cptofcd(&(c), 1)) + #elif defined(StringStructVal) + typedef struct {char *cp; F77_INTEGER len;} F77_CHAR; + #define ATL_F2C_TransChar(c) (*(c.cp)) + #define UseTransChar 2 + #elif defined(StringStructPtr) + typedef struct {char *cp; F77_INTEGER len;} F77_CHAR; + #define ATL_F2C_TransChar(c) (*(c->cp)) + #define UseTransChar 3 + #else + #define ATL_DeclareSlens + #define F77_CHAR char * + #define ATL_F2C_TransChar(c) (*(c)) + #define ATL_C2F_TransChar(c) (&(c)) + #define ATL_STRLEN_1 ,F77_INTEGER ATL_Slen1 + #define ATL_STRLEN_2 ,F77_INTEGER ATL_Slen1, F77_INTEGER ATL_Slen2 + #define ATL_STRLEN_3 ,F77_INTEGER ATL_Slen1, F77_INTEGER ATL_Slen2, \ + F77_INTEGER ATL_Slen3 + #define ATL_STRLEN_4 ,F77_INTEGER ATL_Slen1, F77_INTEGER ATL_Slen2, \ + F77_INTEGER ATL_Slen3, F77_INTEGER ATL_Slen4 + #define ATL_STRLEN_1_para ,ATL_Slen1 + #define ATL_STRLEN_2_para ,ATL_Slen1, ATL_Slen2 + #define ATL_STRLEN_3_para ,ATL_Slen1, ATL_Slen2, ATL_Slen3 + #define ATL_STRLEN_4_para ,ATL_Slen1, ATL_Slen2, ATL_Slen3, ATL_Slen4 + #endif + + #ifndef ATL_STRLEN_1 + #define ATL_STRLEN_1 + #define ATL_STRLEN_2 + #define ATL_STRLEN_3 + #define ATL_STRLEN_4 + #define ATL_STRLEN_1_para + #define ATL_STRLEN_2_para + #define ATL_STRLEN_3_para + #define ATL_STRLEN_4_para + #endif + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_f77blas.h b/kaldi_io/src/tools/ATLAS/include/atlas_f77blas.h new file mode 100644 index 0000000..a7c109d --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_f77blas.h @@ -0,0 +1,849 @@ +#ifndef ATLAS_F77_LVLS +#define ATLAS_F77_LVLS + +#include "atlas_f77.h" + +#if defined( StringSunStyle ) +#define F77_CHAR_DECL F77_CHAR /* input character*1 */ +#define F77_1_CHAR , F77_INTEGER +#define F77_2_CHAR F77_1_CHAR F77_1_CHAR +#define F77_3_CHAR F77_2_CHAR F77_1_CHAR +#define F77_4_CHAR F77_3_CHAR F77_1_CHAR +#elif defined( StringCrayStyle ) +#define F77_CHAR_DECL F77_CHAR /* input character*1 */ +#elif defined( StringStructVal ) +#define F77_CHAR_DECL F77_CHAR /* input character*1 */ +#elif defined( StringStructPtr ) +#define F77_CHAR_DECL F77_CHAR * /* input character*1 */ +#endif + +#ifndef F77_1_CHAR +#define F77_1_CHAR +#define F77_2_CHAR +#define F77_3_CHAR +#define F77_4_CHAR +#endif + +#ifndef F77_CHAR_DECL + #define F77_CHAR_DECL F77_CHAR * /* input character*1 */ +#endif + +#define F77_INT_DECL const F77_INTEGER * /* input integer */ + +#ifdef TREAL +#define F77_SIN_DECL const TYPE * /* input scalar */ +#define F77_SINOUT_DECL TYPE * /* input/output scalar */ +#define F77_RIN_DECL const TYPE * /* input real scalar */ +#define F77_RINOUT_DECL TYPE * /* input/output real scalar */ +#else +#define F77_SIN_DECL const TYPE * /* input scalar */ +#define F77_SINOUT_DECL TYPE * /* input/output scalar */ +#define F77_RIN_DECL const TYPE * /* input real scalar */ +#define F77_RINOUT_DECL TYPE * /* input/output real scalar */ +#endif + +#define F77_VIN_DECL const TYPE * /* input vector */ +#define F77_VINOUT_DECL TYPE * /* input/output matrix */ + +#define F77_MIN_DECL const TYPE * /* input matrix */ +#define F77_MINOUT_DECL TYPE * /* input/output matrix */ + +#if defined( CRAY ) +#define F77_VOID_FUN extern fortran void /* subroutine */ +#define F77_INT_FUN extern fortran int /* integer function */ +#define F77_TYPE_FUN extern fortran TYPE /* real function */ +#define F77_DBLE_FUN extern fortran double /* dble function */ +#else +#define F77_VOID_FUN extern void /* subroutine */ +#define F77_INT_FUN extern int /* integer function */ +#define F77_TYPE_FUN extern TYPE /* real function */ +#define F77_DBLE_FUN extern double /* dble function */ +#endif + +#if defined( NoChange ) +/* + * These defines set up the naming scheme required to have a FORTRAN + * routine called by a C routine with the following FORTRAN to C inter- + * face: + * + * FORTRAN DECLARATION C CALL + * SUBROUTINE DGEMM(...) dgemm(...) + */ +#if defined( SREAL ) + +#define F77rotg srotg +#define F77rotmg srotmg +#define F77nrm2 swrapnrm2 +#define F77asum swrapasum +#define F77amax isamax +#define F77scal sscal +#define F77axpy saxpy +#define F77axpby fatlas_saxpby +#define F77set fatlas_sset +#define F77copy scopy +#define F77swap sswap +#define F77rot srot +#define F77rotm srotm +#define F77dot swrapdot +#define F77dsdot dswrapdot +#define F77sdsdot sdswrapdot + +#define F77gemv sgemv +#define F77gbmv sgbmv +#define F77sbmv ssbmv +#define F77spmv sspmv +#define F77symv ssymv +#define F77tbmv stbmv +#define F77tpmv stpmv +#define F77trmv strmv +#define F77tbsv stbsv +#define F77tpsv stpsv +#define F77trsv strsv +#define F77ger sger +#define F77spr sspr +#define F77syr ssyr +#define F77spr2 sspr2 +#define F77syr2 ssyr2 + +#define F77gemm sgemm +#define F77symm ssymm +#define F77syrk ssyrk +#define F77syr2k ssyr2k +#define F77trmm strmm +#define F77trsm strsm + +#elif defined( DREAL ) + +#define F77rotg drotg +#define F77rotmg drotmg +#define F77nrm2 dwrapnrm2 +#define F77asum dwrapasum +#define F77amax idamax +#define F77scal dscal +#define F77axpy daxpy +#define F77axpby fatlas_daxpby +#define F77set fatlas_dset +#define F77copy dcopy +#define F77swap dswap +#define F77rot drot +#define F77rotm drotm +#define F77dot dwrapdot + +#define F77gemv dgemv +#define F77gbmv dgbmv +#define F77sbmv dsbmv +#define F77spmv dspmv +#define F77symv dsymv +#define F77tbmv dtbmv +#define F77tpmv dtpmv +#define F77trmv dtrmv +#define F77tbsv dtbsv +#define F77tpsv dtpsv +#define F77trsv dtrsv +#define F77ger dger +#define F77spr dspr +#define F77syr dsyr +#define F77spr2 dspr2 +#define F77syr2 dsyr2 + +#define F77gemm dgemm +#define F77symm dsymm +#define F77syrk dsyrk +#define F77syr2k dsyr2k +#define F77trmm dtrmm +#define F77trsm dtrsm + +#elif defined( SCPLX ) + +#define F77rotg crotg +#define F77nrm2 scwrapnrm2 +#define F77asum scwrapasum +#define F77amax icamax +#define F77scal cscal +#define F77rscal csscal +#define F77axpy caxpy +#define F77axpby fatlas_caxpby +#define F77set fatlas_cset +#define F77copy ccopy +#define F77swap cswap +#define F77rot csrot +#define F77dotc cwrapdotc +#define F77dotu cwrapdotu + +#define F77gbmv cgbmv +#define F77gemv cgemv +#define F77hbmv chbmv +#define F77hpmv chpmv +#define F77hemv chemv +#define F77tbmv ctbmv +#define F77tpmv ctpmv +#define F77trmv ctrmv +#define F77tbsv ctbsv +#define F77tpsv ctpsv +#define F77trsv ctrsv +#define F77gerc cgerc +#define F77geru cgeru +#define F77hpr chpr +#define F77her cher +#define F77hpr2 chpr2 +#define F77her2 cher2 + +#define F77gemm cgemm +#define F77hemm chemm +#define F77herk cherk +#define F77her2k cher2k +#define F77symm csymm +#define F77syrk csyrk +#define F77syr2k csyr2k +#define F77trmm ctrmm +#define F77trsm ctrsm + +#elif defined( DCPLX ) + +#define F77rotg zrotg +#define F77nrm2 dzwrapnrm2 +#define F77asum dzwrapasum +#define F77amax izamax +#define F77scal zscal +#define F77rscal zdscal +#define F77axpy zaxpy +#define F77axpby fatlas_zaxpby +#define F77set fatlas_zset +#define F77copy zcopy +#define F77swap zswap +#define F77rot zdrot +#define F77dotc zwrapdotc +#define F77dotu zwrapdotu + +#define F77gbmv zgbmv +#define F77gemv zgemv +#define F77hbmv zhbmv +#define F77hpmv zhpmv +#define F77hemv zhemv +#define F77tbmv ztbmv +#define F77tpmv ztpmv +#define F77trmv ztrmv +#define F77tbsv ztbsv +#define F77tpsv ztpsv +#define F77trsv ztrsv +#define F77gerc zgerc +#define F77geru zgeru +#define F77hpr zhpr +#define F77her zher +#define F77hpr2 zhpr2 +#define F77her2 zher2 + +#define F77gemm zgemm +#define F77hemm zhemm +#define F77herk zherk +#define F77her2k zher2k +#define F77symm zsymm +#define F77syrk zsyrk +#define F77syr2k zsyr2k +#define F77trmm ztrmm +#define F77trsm ztrsm + +#endif + +#elif defined( UpCase ) +/* + * These defines set up the naming scheme required to have a FORTRAN + * routine called by a C routine with the following FORTRAN to C inter- + * face: + * + * FORTRAN DECLARATION C CALL + * SUBROUTINE DGEMM(...) DGEMM(...) + */ +#if defined( SREAL ) + +#define F77rotg SROTG +#define F77rotmg SROTMG +#define F77nrm2 SWRAPNRM2 +#define F77asum SWRAPASUM +#define F77amax ISAMAX +#define F77scal SSCAL +#define F77axpy SAXPY +#define F77axpby FATLAS_SAXPBY +#define F77set FATLAS_SSET +#define F77copy SCOPY +#define F77swap SSWAP +#define F77rot SROT +#define F77rotm SROTM +#define F77dot SWRAPDOT +#define F77dsdot DSWRAPDOT +#define F77sdsdot SDSWRAPDOT + +#define F77gemv SGEMV +#define F77gbmv SGBMV +#define F77sbmv SSBMV +#define F77spmv SSPMV +#define F77symv SSYMV +#define F77tbmv STBMV +#define F77tpmv STPMV +#define F77trmv STRMV +#define F77tbsv STBSV +#define F77tpsv STPSV +#define F77trsv STRSV +#define F77ger SGER +#define F77spr SSPR +#define F77syr SSYR +#define F77spr2 SSPR2 +#define F77syr2 SSYR2 + +#define F77gemm SGEMM +#define F77symm SSYMM +#define F77syrk SSYRK +#define F77syr2k SSYR2K +#define F77trmm STRMM +#define F77trsm STRSM + +#elif defined( DREAL ) + +#define F77rotg DROTG +#define F77rotmg DROTMG +#define F77nrm2 DWRAPNRM2 +#define F77asum DWRAPASUM +#define F77amax IDAMAX +#define F77scal DSCAL +#define F77axpy DAXPY +#define F77axpby FATLAS_DAXPBY +#define F77set FATLAS_DSET +#define F77copy DCOPY +#define F77swap DSWAP +#define F77rot DROT +#define F77rotm DROTM +#define F77dot DWRAPDOT + +#define F77gemv DGEMV +#define F77gbmv DGBMV +#define F77sbmv DSBMV +#define F77spmv DSPMV +#define F77symv DSYMV +#define F77tbmv DTBMV +#define F77tpmv DTPMV +#define F77trmv DTRMV +#define F77tbsv DTBSV +#define F77tpsv DTPSV +#define F77trsv DTRSV +#define F77ger DGER +#define F77spr DSPR +#define F77syr DSYR +#define F77spr2 DSPR2 +#define F77syr2 DSYR2 + +#define F77gemm DGEMM +#define F77symm DSYMM +#define F77syrk DSYRK +#define F77syr2k DSYR2K +#define F77trmm DTRMM +#define F77trsm DTRSM + +#elif defined( SCPLX ) + +#define F77rotg CROTG +#define F77nrm2 SCWRAPNRM2 +#define F77asum SCWRAPASUM +#define F77amax ICAMAX +#define F77scal CSCAL +#define F77rscal CSSCAL +#define F77axpy CAXPY +#define F77axpby FATLAS_CAXPBY +#define F77set FATLAS_CSET +#define F77copy CCOPY +#define F77swap CSWAP +#define F77rot CSROT +#define F77dotc CWRAPDOTC +#define F77dotu CWRAPDOTU + +#define F77gbmv CGBMV +#define F77gemv CGEMV +#define F77hbmv CHBMV +#define F77hpmv CHPMV +#define F77hemv CHEMV +#define F77tbmv CTBMV +#define F77tpmv CTPMV +#define F77trmv CTRMV +#define F77tbsv CTBSV +#define F77tpsv CTPSV +#define F77trsv CTRSV +#define F77gerc CGERC +#define F77geru CGERU +#define F77hpr CHPR +#define F77her CHER +#define F77hpr2 CHPR2 +#define F77her2 CHER2 + +#define F77gemm CGEMM +#define F77hemm CHEMM +#define F77herk CHERK +#define F77her2k CHER2K +#define F77symm CSYMM +#define F77syrk CSYRK +#define F77syr2k CSYR2K +#define F77trmm CTRMM +#define F77trsm CTRSM + +#elif defined( DCPLX ) + +#define F77rotg ZROTG +#define F77nrm2 DZWRAPNRM2 +#define F77asum DZWRAPASUM +#define F77amax IZAMAX +#define F77scal ZSCAL +#define F77rscal ZDSCAL +#define F77axpy ZAXPY +#define F77axpby FATLAS_ZAXPBY +#define F77set FATLAS_ZSET +#define F77copy ZCOPY +#define F77swap ZSWAP +#define F77rot ZDROT +#define F77dotc ZWRAPDOTC +#define F77dotu ZWRAPDOTU + +#define F77gbmv ZGBMV +#define F77gemv ZGEMV +#define F77hbmv ZHBMV +#define F77hpmv ZHPMV +#define F77hemv ZHEMV +#define F77tbmv ZTBMV +#define F77tpmv ZTPMV +#define F77trmv ZTRMV +#define F77tbsv ZTBSV +#define F77tpsv ZTPSV +#define F77trsv ZTRSV +#define F77gerc ZGERC +#define F77geru ZGERU +#define F77hpr ZHPR +#define F77her ZHER +#define F77hpr2 ZHPR2 +#define F77her2 ZHER2 + +#define F77gemm ZGEMM +#define F77hemm ZHEMM +#define F77herk ZHERK +#define F77her2k ZHER2K +#define F77symm ZSYMM +#define F77syrk ZSYRK +#define F77syr2k ZSYR2K +#define F77trmm ZTRMM +#define F77trsm ZTRSM + +#endif + +#elif defined( Add_ ) || defined( Add__ ) +/* + * These defines set up the naming scheme required to have a FORTRAN + * routine called by a C routine with the following FORTRAN to C inter- + * face: + * + * FORTRAN DECLARATION C CALL + * SUBROUTINE DGEMM(...) dgemm_(...) + */ +#if defined( SREAL ) + +#define F77rotg srotg_ +#define F77rotmg srotmg_ +#define F77nrm2 swrapnrm2_ +#define F77asum swrapasum_ +#define F77amax isamax_ +#define F77scal sscal_ +#define F77axpy saxpy_ +#ifdef Add_ + #define F77axpby fatlas_saxpby_ + #define F77set fatlas_sset_ +#else + #define F77axpby fatlas_saxpby__ + #define F77set fatlas_sset__ +#endif +#define F77copy scopy_ +#define F77swap sswap_ +#define F77rot srot_ +#define F77rotm srotm_ +#define F77dot swrapdot_ +#define F77dsdot dswrapdot_ +#define F77sdsdot sdswrapdot_ + +#define F77gemv sgemv_ +#define F77gbmv sgbmv_ +#define F77sbmv ssbmv_ +#define F77spmv sspmv_ +#define F77symv ssymv_ +#define F77tbmv stbmv_ +#define F77tpmv stpmv_ +#define F77trmv strmv_ +#define F77tbsv stbsv_ +#define F77tpsv stpsv_ +#define F77trsv strsv_ +#define F77ger sger_ +#define F77spr sspr_ +#define F77syr ssyr_ +#define F77spr2 sspr2_ +#define F77syr2 ssyr2_ + +#define F77gemm sgemm_ +#define F77symm ssymm_ +#define F77syrk ssyrk_ +#define F77syr2k ssyr2k_ +#define F77trmm strmm_ +#define F77trsm strsm_ + +#elif defined( DREAL ) + +#define F77rotg drotg_ +#define F77rotmg drotmg_ +#define F77nrm2 dwrapnrm2_ +#define F77asum dwrapasum_ +#define F77amax idamax_ +#define F77scal dscal_ +#define F77axpy daxpy_ +#ifdef Add_ + #define F77axpby fatlas_daxpby_ + #define F77set fatlas_dset_ +#else + #define F77axpby fatlas_daxpby__ + #define F77set fatlas_dset__ +#endif +#define F77copy dcopy_ +#define F77swap dswap_ +#define F77rot drot_ +#define F77rotm drotm_ +#define F77dot dwrapdot_ + +#define F77gemv dgemv_ +#define F77gbmv dgbmv_ +#define F77sbmv dsbmv_ +#define F77spmv dspmv_ +#define F77symv dsymv_ +#define F77tbmv dtbmv_ +#define F77tpmv dtpmv_ +#define F77trmv dtrmv_ +#define F77tbsv dtbsv_ +#define F77tpsv dtpsv_ +#define F77trsv dtrsv_ +#define F77ger dger_ +#define F77spr dspr_ +#define F77syr dsyr_ +#define F77spr2 dspr2_ +#define F77syr2 dsyr2_ + +#define F77gemm dgemm_ +#define F77symm dsymm_ +#define F77syrk dsyrk_ +#define F77syr2k dsyr2k_ +#define F77trmm dtrmm_ +#define F77trsm dtrsm_ + +#elif defined( SCPLX ) + +#define F77rotg crotg_ +#define F77nrm2 scwrapnrm2_ +#define F77asum scwrapasum_ +#define F77amax icamax_ +#define F77scal cscal_ +#define F77rscal csscal_ +#define F77axpy caxpy_ +#ifdef Add_ + #define F77axpby fatlas_caxpby_ + #define F77set fatlas_cset_ +#else + #define F77axpby fatlas_caxpby__ + #define F77set fatlas_cset__ +#endif +#define F77copy ccopy_ +#define F77swap cswap_ +#define F77rot csrot_ +#define F77dotc cwrapdotc_ +#define F77dotu cwrapdotu_ + +#define F77gbmv cgbmv_ +#define F77gemv cgemv_ +#define F77hbmv chbmv_ +#define F77hpmv chpmv_ +#define F77hemv chemv_ +#define F77tbmv ctbmv_ +#define F77tpmv ctpmv_ +#define F77trmv ctrmv_ +#define F77tbsv ctbsv_ +#define F77tpsv ctpsv_ +#define F77trsv ctrsv_ +#define F77gerc cgerc_ +#define F77geru cgeru_ +#define F77hpr chpr_ +#define F77her cher_ +#define F77hpr2 chpr2_ +#define F77her2 cher2_ + +#define F77gemm cgemm_ +#define F77hemm chemm_ +#define F77herk cherk_ +#define F77her2k cher2k_ +#define F77symm csymm_ +#define F77syrk csyrk_ +#define F77syr2k csyr2k_ +#define F77trmm ctrmm_ +#define F77trsm ctrsm_ + +#elif defined( DCPLX ) + +#define F77rotg zrotg_ +#define F77nrm2 dzwrapnrm2_ +#define F77asum dzwrapasum_ +#define F77amax izamax_ +#define F77scal zscal_ +#define F77rscal zdscal_ +#define F77axpy zaxpy_ +#ifdef Add_ + #define F77axpby fatlas_zaxpby_ + #define F77set fatlas_zset_ +#else + #define F77axpby fatlas_zaxpby__ + #define F77set fatlas_zset__ +#endif +#define F77copy zcopy_ +#define F77swap zswap_ +#define F77rot zdrot_ +#define F77dotc zwrapdotc_ +#define F77dotu zwrapdotu_ + +#define F77gbmv zgbmv_ +#define F77gemv zgemv_ +#define F77hbmv zhbmv_ +#define F77hpmv zhpmv_ +#define F77hemv zhemv_ +#define F77tbmv ztbmv_ +#define F77tpmv ztpmv_ +#define F77trmv ztrmv_ +#define F77tbsv ztbsv_ +#define F77tpsv ztpsv_ +#define F77trsv ztrsv_ +#define F77gerc zgerc_ +#define F77geru zgeru_ +#define F77hpr zhpr_ +#define F77her zher_ +#define F77hpr2 zhpr2_ +#define F77her2 zher2_ + +#define F77gemm zgemm_ +#define F77hemm zhemm_ +#define F77herk zherk_ +#define F77her2k zher2k_ +#define F77symm zsymm_ +#define F77syrk zsyrk_ +#define F77syr2k zsyr2k_ +#define F77trmm ztrmm_ +#define F77trsm ztrsm_ + +#endif + +#endif + +#ifdef TREAL +F77_VOID_FUN F77rotg +( F77_SINOUT_DECL, F77_SINOUT_DECL, F77_SINOUT_DECL, F77_SINOUT_DECL ); +F77_VOID_FUN F77rotmg +( F77_SINOUT_DECL, F77_SINOUT_DECL, F77_SINOUT_DECL, F77_SIN_DECL, + F77_VINOUT_DECL ); +#else +F77_VOID_FUN F77rotg +( F77_SINOUT_DECL, F77_SIN_DECL, F77_SINOUT_DECL, F77_SINOUT_DECL ); +#endif +F77_VOID_FUN F77nrm2 +( F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_RINOUT_DECL ); +F77_VOID_FUN F77asum +( F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_RINOUT_DECL ); +F77_INT_FUN F77amax +( F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL ); +F77_VOID_FUN F77scal +( F77_INT_DECL, F77_SIN_DECL, F77_VINOUT_DECL, F77_INT_DECL ); +#ifdef TCPLX +F77_VOID_FUN F77rscal +( F77_INT_DECL, F77_RIN_DECL, F77_VINOUT_DECL, F77_INT_DECL ); +#endif +void F77set +( F77_INT_DECL, F77_SIN_DECL, F77_VINOUT_DECL, F77_INT_DECL ); +void F77axpby +( F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_VINOUT_DECL, F77_INT_DECL ); +F77_VOID_FUN F77axpy +( F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, F77_INT_DECL, + F77_VINOUT_DECL, F77_INT_DECL ); +F77_VOID_FUN F77copy +( F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_VINOUT_DECL, + F77_INT_DECL ); +F77_VOID_FUN F77swap +( F77_INT_DECL, F77_VINOUT_DECL, F77_INT_DECL, F77_VINOUT_DECL, + F77_INT_DECL ); +F77_VOID_FUN F77rot +( F77_INT_DECL, F77_VINOUT_DECL, F77_INT_DECL, F77_VINOUT_DECL, + F77_INT_DECL, F77_SIN_DECL, F77_SIN_DECL ); +#ifdef TREAL +F77_VOID_FUN F77rotm +( F77_INT_DECL, F77_VINOUT_DECL, F77_INT_DECL, F77_VINOUT_DECL, + F77_INT_DECL, F77_VIN_DECL ); +#endif +#ifdef TREAL +F77_VOID_FUN F77dot +( F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_SINOUT_DECL ); +#ifdef SREAL +F77_VOID_FUN F77dsdot +( F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_VIN_DECL, + F77_INT_DECL, double * ); +F77_VOID_FUN F77sdsdot +( F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, F77_INT_DECL, + F77_VIN_DECL, F77_INT_DECL, F77_SINOUT_DECL ); +#endif +#else +F77_VOID_FUN F77dotc +( F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_SINOUT_DECL ); +F77_VOID_FUN F77dotu +( F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_SINOUT_DECL ); +#endif + +F77_VOID_FUN F77gbmv +( F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, F77_INT_DECL, + F77_INT_DECL, F77_SIN_DECL, F77_MIN_DECL, F77_INT_DECL, + F77_VIN_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VINOUT_DECL, + F77_INT_DECL F77_1_CHAR ); +F77_VOID_FUN F77gemv +( F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_MIN_DECL, F77_INT_DECL, + F77_VIN_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VINOUT_DECL, + F77_INT_DECL F77_1_CHAR ); +#ifdef TREAL +F77_VOID_FUN F77ger +( F77_INT_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_MINOUT_DECL, + F77_INT_DECL ); +F77_VOID_FUN F77sbmv +( F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, F77_SIN_DECL, + F77_MIN_DECL, F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_VINOUT_DECL, F77_INT_DECL F77_1_CHAR ); +F77_VOID_FUN F77spmv +( F77_CHAR_DECL, F77_INT_DECL, F77_SIN_DECL, + F77_MIN_DECL, F77_VIN_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_VINOUT_DECL, F77_INT_DECL F77_1_CHAR ); +F77_VOID_FUN F77symv +( F77_CHAR_DECL, F77_INT_DECL, F77_SIN_DECL, + F77_MIN_DECL, F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_VINOUT_DECL, F77_INT_DECL F77_1_CHAR ); +F77_VOID_FUN F77spr +( F77_CHAR_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_MINOUT_DECL F77_1_CHAR ); +F77_VOID_FUN F77syr +( F77_CHAR_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_MINOUT_DECL, F77_INT_DECL F77_1_CHAR ); +F77_VOID_FUN F77spr2 +( F77_CHAR_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_MINOUT_DECL + F77_1_CHAR ); +F77_VOID_FUN F77syr2 +( F77_CHAR_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_MINOUT_DECL, + F77_INT_DECL F77_1_CHAR ); +#else +F77_VOID_FUN F77gerc +( F77_INT_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_MINOUT_DECL, + F77_INT_DECL ); +F77_VOID_FUN F77geru +( F77_INT_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_MINOUT_DECL, + F77_INT_DECL ); +F77_VOID_FUN F77hbmv +( F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, F77_SIN_DECL, + F77_MIN_DECL, F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_VINOUT_DECL, F77_INT_DECL F77_1_CHAR ); +F77_VOID_FUN F77hpmv +( F77_CHAR_DECL, F77_INT_DECL, F77_SIN_DECL, + F77_MIN_DECL, F77_VIN_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_VINOUT_DECL, F77_INT_DECL F77_1_CHAR ); +F77_VOID_FUN F77hemv +( F77_CHAR_DECL, F77_INT_DECL, F77_SIN_DECL, + F77_MIN_DECL, F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_VINOUT_DECL, F77_INT_DECL F77_1_CHAR ); +F77_VOID_FUN F77hpr +( F77_CHAR_DECL, F77_INT_DECL, F77_RIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_MINOUT_DECL F77_1_CHAR ); +F77_VOID_FUN F77her +( F77_CHAR_DECL, F77_INT_DECL, F77_RIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_MINOUT_DECL, F77_INT_DECL F77_1_CHAR ); +F77_VOID_FUN F77hpr2 +( F77_CHAR_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_MINOUT_DECL + F77_1_CHAR ); +F77_VOID_FUN F77her2 +( F77_CHAR_DECL, F77_INT_DECL, F77_SIN_DECL, F77_VIN_DECL, + F77_INT_DECL, F77_VIN_DECL, F77_INT_DECL, F77_MINOUT_DECL, + F77_INT_DECL F77_1_CHAR ); +#endif +F77_VOID_FUN F77tbmv +( F77_CHAR_DECL, F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, + F77_INT_DECL, F77_MIN_DECL, F77_INT_DECL, F77_VINOUT_DECL, + F77_INT_DECL F77_3_CHAR ); +F77_VOID_FUN F77tpmv +( F77_CHAR_DECL, F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, + F77_MIN_DECL, F77_VINOUT_DECL, + F77_INT_DECL F77_3_CHAR ); +F77_VOID_FUN F77trmv +( F77_CHAR_DECL, F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, + F77_MIN_DECL, F77_INT_DECL, F77_VINOUT_DECL, + F77_INT_DECL F77_3_CHAR ); +F77_VOID_FUN F77tbsv +( F77_CHAR_DECL, F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, + F77_INT_DECL, F77_MIN_DECL, F77_INT_DECL, F77_VINOUT_DECL, + F77_INT_DECL F77_3_CHAR ); +F77_VOID_FUN F77tpsv +( F77_CHAR_DECL, F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, + F77_MIN_DECL, F77_VINOUT_DECL, + F77_INT_DECL F77_3_CHAR ); +F77_VOID_FUN F77trsv +( F77_CHAR_DECL, F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, + F77_MIN_DECL, F77_INT_DECL, F77_VINOUT_DECL, + F77_INT_DECL F77_3_CHAR ); + +F77_VOID_FUN F77gemm +( F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, + F77_INT_DECL, F77_SIN_DECL, F77_MIN_DECL, F77_INT_DECL, + F77_MIN_DECL, F77_INT_DECL, F77_SIN_DECL, F77_MINOUT_DECL, + F77_INT_DECL F77_2_CHAR ); +F77_VOID_FUN F77hemm +( F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_MIN_DECL, F77_INT_DECL, F77_MIN_DECL, + F77_INT_DECL, F77_SIN_DECL, F77_MINOUT_DECL, F77_INT_DECL + F77_2_CHAR ); +F77_VOID_FUN F77her2k +( F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_MIN_DECL, F77_INT_DECL, F77_MIN_DECL, + F77_INT_DECL, F77_RIN_DECL, F77_MINOUT_DECL, F77_INT_DECL + F77_2_CHAR ); +F77_VOID_FUN F77herk +( F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, + F77_RIN_DECL, F77_MIN_DECL, F77_INT_DECL, F77_RIN_DECL, + F77_MINOUT_DECL, F77_INT_DECL F77_2_CHAR ); +F77_VOID_FUN F77symm +( F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_MIN_DECL, F77_INT_DECL, F77_MIN_DECL, + F77_INT_DECL, F77_SIN_DECL, F77_MINOUT_DECL, F77_INT_DECL + F77_2_CHAR ); +F77_VOID_FUN F77syr2k +( F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_MIN_DECL, F77_INT_DECL, F77_MIN_DECL, + F77_INT_DECL, F77_SIN_DECL, F77_MINOUT_DECL, F77_INT_DECL + F77_2_CHAR ); +F77_VOID_FUN F77syrk +( F77_CHAR_DECL, F77_CHAR_DECL, F77_INT_DECL, F77_INT_DECL, + F77_SIN_DECL, F77_MIN_DECL, F77_INT_DECL, F77_SIN_DECL, + F77_MINOUT_DECL, F77_INT_DECL F77_2_CHAR ); +F77_VOID_FUN F77trmm +( F77_CHAR_DECL, F77_CHAR_DECL, F77_CHAR_DECL, F77_CHAR_DECL, + F77_INT_DECL, F77_INT_DECL, F77_SIN_DECL, F77_MIN_DECL, + F77_INT_DECL, F77_MINOUT_DECL, F77_INT_DECL F77_4_CHAR ); +F77_VOID_FUN F77trsm +( F77_CHAR_DECL, F77_CHAR_DECL, F77_CHAR_DECL, F77_CHAR_DECL, + F77_INT_DECL, F77_INT_DECL, F77_SIN_DECL, F77_MIN_DECL, + F77_INT_DECL, F77_MINOUT_DECL, F77_INT_DECL F77_4_CHAR ); + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_f77wrap.h b/kaldi_io/src/tools/ATLAS/include/atlas_f77wrap.h new file mode 100644 index 0000000..db6099c --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_f77wrap.h @@ -0,0 +1,1088 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_F77WRAP_H +#define ATLAS_F77WRAP_H +/* + * ===================================================================== + * Include Files + * ===================================================================== + */ +#include "atlas_misc.h" +#include "atlas_f77.h" +/* + * ===================================================================== + * Multi-threaded/reference implementation function names re-definition + * ===================================================================== + * + * Uncomment the following definition macros to call the multi-threaded + * implementation or define those macros at compile time. + * + * #define USE_L1_PTHREADS + * #define USE_L2_PTHREADS + * #define USE_L3_PTHREADS + * + * Uncomment the following definition macros to call the reference im- + * plementation or define those macros at compile time. + * + * #define USE_L1_REFERENCE + * #define USE_L2_REFERENCE + * #define USE_L3_REFERENCE + * + * ===================================================================== + */ + +#ifdef ATL_USEPTHREADS +#define USE_L3_PTHREADS +#endif + +/* + * ===================================================================== + * ATLAS Levels 1, 2 and 3 Prototypes + * ===================================================================== + */ +#if defined( USE_L1_PTHREADS ) +#include "atlas_ptalias1.h" +#include "atlas_ptlevel1.h" +#elif defined( USE_L1_REFERENCE ) +#include "atlas_refalias1.h" +#include "atlas_reflevel1.h" +#else +#include "atlas_level1.h" +#endif + +#if defined( USE_L2_PTHREADS ) +#include "atlas_ptalias2.h" +#include "atlas_ptlevel2.h" +#elif defined( USE_L2_REFERENCE ) +#include "atlas_refalias2.h" +#include "atlas_reflevel2.h" +#else +#include "atlas_level2.h" +#endif + +#if defined( USE_L3_PTHREADS ) +#include "atlas_ptalias3.h" +#include "atlas_ptlevel3.h" +#elif defined( USE_L3_REFERENCE ) +#include "atlas_refalias3.h" +#include "atlas_reflevel3.h" +#else +#include "atlas_level3.h" +#endif +/* + * ===================================================================== + * #define macro constants + * ===================================================================== + */ +#define PATLF77WRAP Mjoin( ATL_F77wrap_, PRE ) + +#ifdef TREAL +#define ATLPUF77WRAP Mjoin( ATL_F77wrap_, PRE ) +#define ATLUPF77WRAP Mjoin( ATL_F77wrap_, PRE ) +#else +#define ATLPUF77WRAP Mjoin( Mjoin( ATL_F77wrap_, PRE ), UPR ) +#define ATLUPF77WRAP Mjoin( Mjoin( ATL_F77wrap_, UPR ), PRE ) +#endif + +#define F77_INOTRAN 111 +#define F77_ITRAN 112 +#define F77_ICOTRAN 113 + +#define F77_IUPPER 121 +#define F77_ILOWER 122 + +#define F77_INONUNIT 131 +#define F77_IUNIT 132 + +#define F77_ILEFT 141 +#define F77_IRIGHT 142 +/* + * ===================================================================== + * #define macro functions + * ===================================================================== + */ +#ifdef TREAL +#define V1N( n_, x_, incx_ ) \ + ( (*n_) > 0 ? (x_)+(1-(*n_))*(*incx_) : (x_) ) +#define VN1( n_, x_, incx_ ) \ + ( (*n_) > 0 ? (x_)+((*n_)-1)*(*incx_) : (x_) ) +#define W1N( n_, x_, incx_ ) \ + ( (*incx_) < 0 ? V1N( n_, x_, incx_ ) : (x_) ) +#else +#define V1N( n_, x_, incx_ ) \ + ( (*n_) > 0 ? (x_)+( ( (1-(*n_))*(*incx_) ) << 1 ) : (x_) ) +#define VN1( n_, x_, incx_ ) \ + ( (*n_) > 0 ? (x_)+( ( ((*n_)-1)*(*incx_) ) << 1 ) : (x_) ) +#define W1N( n_, x_, incx_ ) \ + ( (*incx_) < 0 ? V1N( n_, x_, incx_ ) : (x_) ) +#endif +/* + * ===================================================================== + * FORTRAN <-> C interface + * ===================================================================== + * + * These macros identifies how these wrappers will be called as follows: + * + * Add_: the FORTRAN compiler expects the name of C functions to be + * in all lower case and to have an underscore postfixed it (Suns, Intel + * compilers expect this). + * + * NoChange: the FORTRAN compiler expects the name of C functions to be + * in all lower case (IBM RS6K compilers do this). + * + * UpCase: the FORTRAN compiler expects the name of C functions to be + * in all upcase. (Cray compilers expect this). + * + * Add__: the FORTRAN compiler in use is f2c, a FORTRAN to C conver- + * ter. + */ +#if defined( Add_ ) +/* + * These defines set up the naming scheme required to have a FORTRAN + * routine calling a C routine. + * + * FORTRAN CALL C declaration + * CALL ATL_F77WRAP_SGEMM(...) void atl_f77wrap_sgemm_(...) + * + * This is the default. + */ +#if defined( SREAL ) + +#define ATL_F77wrap_srotg atl_f77wrap_srotg_ +#define ATL_F77wrap_srotmg atl_f77wrap_srotmg_ +#define ATL_F77wrap_snrm2 atl_f77wrap_snrm2_ +#define ATL_F77wrap_sasum atl_f77wrap_sasum_ +#define ATL_F77wrap_sscal atl_f77wrap_sscal_ +#define ATL_F77wrap_isamax atl_f77wrap_isamax_ +#define ATL_F77wrap_saxpy atl_f77wrap_saxpy_ +#define ATL_F77wrap_scopy atl_f77wrap_scopy_ +#define ATL_F77wrap_sswap atl_f77wrap_sswap_ +#define ATL_F77wrap_srot atl_f77wrap_srot_ +#define ATL_F77wrap_srotm atl_f77wrap_srotm_ +#define ATL_F77wrap_sdot atl_f77wrap_sdot_ +#define ATL_F77wrap_dsdot atl_f77wrap_dsdot_ +#define ATL_F77wrap_sdsdot atl_f77wrap_sdsdot_ + +#define ATL_F77wrap_sgbmv atl_f77wrap_sgbmv_ +#define ATL_F77wrap_sgemv atl_f77wrap_sgemv_ +#define ATL_F77wrap_sger atl_f77wrap_sger_ +#define ATL_F77wrap_ssbmv atl_f77wrap_ssbmv_ +#define ATL_F77wrap_sspmv atl_f77wrap_sspmv_ +#define ATL_F77wrap_ssymv atl_f77wrap_ssymv_ +#define ATL_F77wrap_sspr atl_f77wrap_sspr_ +#define ATL_F77wrap_ssyr atl_f77wrap_ssyr_ +#define ATL_F77wrap_sspr2 atl_f77wrap_sspr2_ +#define ATL_F77wrap_ssyr2 atl_f77wrap_ssyr2_ +#define ATL_F77wrap_stbmv atl_f77wrap_stbmv_ +#define ATL_F77wrap_stpmv atl_f77wrap_stpmv_ +#define ATL_F77wrap_strmv atl_f77wrap_strmv_ +#define ATL_F77wrap_stbsv atl_f77wrap_stbsv_ +#define ATL_F77wrap_stpsv atl_f77wrap_stpsv_ +#define ATL_F77wrap_strsv atl_f77wrap_strsv_ + +#define ATL_F77wrap_sgemm atl_f77wrap_sgemm_ +#define ATL_F77wrap_ssymm atl_f77wrap_ssymm_ +#define ATL_F77wrap_ssyrk atl_f77wrap_ssyrk_ +#define ATL_F77wrap_ssyr2k atl_f77wrap_ssyr2k_ +#define ATL_F77wrap_strmm atl_f77wrap_strmm_ +#define ATL_F77wrap_strsm atl_f77wrap_strsm_ + +#elif defined( DREAL ) + +#define ATL_F77wrap_drotg atl_f77wrap_drotg_ +#define ATL_F77wrap_drotmg atl_f77wrap_drotmg_ +#define ATL_F77wrap_dnrm2 atl_f77wrap_dnrm2_ +#define ATL_F77wrap_dasum atl_f77wrap_dasum_ +#define ATL_F77wrap_dscal atl_f77wrap_dscal_ +#define ATL_F77wrap_idamax atl_f77wrap_idamax_ +#define ATL_F77wrap_daxpy atl_f77wrap_daxpy_ +#define ATL_F77wrap_dcopy atl_f77wrap_dcopy_ +#define ATL_F77wrap_dswap atl_f77wrap_dswap_ +#define ATL_F77wrap_drot atl_f77wrap_drot_ +#define ATL_F77wrap_drotm atl_f77wrap_drotm_ +#define ATL_F77wrap_ddot atl_f77wrap_ddot_ + +#define ATL_F77wrap_dgbmv atl_f77wrap_dgbmv_ +#define ATL_F77wrap_dgemv atl_f77wrap_dgemv_ +#define ATL_F77wrap_dger atl_f77wrap_dger_ +#define ATL_F77wrap_dsbmv atl_f77wrap_dsbmv_ +#define ATL_F77wrap_dspmv atl_f77wrap_dspmv_ +#define ATL_F77wrap_dsymv atl_f77wrap_dsymv_ +#define ATL_F77wrap_dspr atl_f77wrap_dspr_ +#define ATL_F77wrap_dsyr atl_f77wrap_dsyr_ +#define ATL_F77wrap_dspr2 atl_f77wrap_dspr2_ +#define ATL_F77wrap_dsyr2 atl_f77wrap_dsyr2_ +#define ATL_F77wrap_dtbmv atl_f77wrap_dtbmv_ +#define ATL_F77wrap_dtpmv atl_f77wrap_dtpmv_ +#define ATL_F77wrap_dtrmv atl_f77wrap_dtrmv_ +#define ATL_F77wrap_dtbsv atl_f77wrap_dtbsv_ +#define ATL_F77wrap_dtpsv atl_f77wrap_dtpsv_ +#define ATL_F77wrap_dtrsv atl_f77wrap_dtrsv_ + +#define ATL_F77wrap_dgemm atl_f77wrap_dgemm_ +#define ATL_F77wrap_dsymm atl_f77wrap_dsymm_ +#define ATL_F77wrap_dsyrk atl_f77wrap_dsyrk_ +#define ATL_F77wrap_dsyr2k atl_f77wrap_dsyr2k_ +#define ATL_F77wrap_dtrmm atl_f77wrap_dtrmm_ +#define ATL_F77wrap_dtrsm atl_f77wrap_dtrsm_ + +#elif defined( SCPLX ) + +#define ATL_F77wrap_crotg atl_f77wrap_crotg_ +#define ATL_F77wrap_scnrm2 atl_f77wrap_scnrm2_ +#define ATL_F77wrap_scasum atl_f77wrap_scasum_ +#define ATL_F77wrap_cscal atl_f77wrap_cscal_ +#define ATL_F77wrap_csscal atl_f77wrap_csscal_ +#define ATL_F77wrap_icamax atl_f77wrap_icamax_ +#define ATL_F77wrap_caxpy atl_f77wrap_caxpy_ +#define ATL_F77wrap_ccopy atl_f77wrap_ccopy_ +#define ATL_F77wrap_cswap atl_f77wrap_cswap_ +#define ATL_F77wrap_csrot atl_f77wrap_csrot_ +#define ATL_F77wrap_cdotc atl_f77wrap_cdotc_ +#define ATL_F77wrap_cdotu atl_f77wrap_cdotu_ + +#define ATL_F77wrap_cgbmv atl_f77wrap_cgbmv_ +#define ATL_F77wrap_cgemv atl_f77wrap_cgemv_ +#define ATL_F77wrap_cgerc atl_f77wrap_cgerc_ +#define ATL_F77wrap_cgeru atl_f77wrap_cgeru_ +#define ATL_F77wrap_chbmv atl_f77wrap_chbmv_ +#define ATL_F77wrap_chpmv atl_f77wrap_chpmv_ +#define ATL_F77wrap_chemv atl_f77wrap_chemv_ +#define ATL_F77wrap_chpr atl_f77wrap_chpr_ +#define ATL_F77wrap_cher atl_f77wrap_cher_ +#define ATL_F77wrap_chpr2 atl_f77wrap_chpr2_ +#define ATL_F77wrap_cher2 atl_f77wrap_cher2_ +#define ATL_F77wrap_ctbmv atl_f77wrap_ctbmv_ +#define ATL_F77wrap_ctpmv atl_f77wrap_ctpmv_ +#define ATL_F77wrap_ctrmv atl_f77wrap_ctrmv_ +#define ATL_F77wrap_ctbsv atl_f77wrap_ctbsv_ +#define ATL_F77wrap_ctpsv atl_f77wrap_ctpsv_ +#define ATL_F77wrap_ctrsv atl_f77wrap_ctrsv_ + +#define ATL_F77wrap_cgemm atl_f77wrap_cgemm_ +#define ATL_F77wrap_chemm atl_f77wrap_chemm_ +#define ATL_F77wrap_cherk atl_f77wrap_cherk_ +#define ATL_F77wrap_cher2k atl_f77wrap_cher2k_ +#define ATL_F77wrap_csymm atl_f77wrap_csymm_ +#define ATL_F77wrap_csyrk atl_f77wrap_csyrk_ +#define ATL_F77wrap_csyr2k atl_f77wrap_csyr2k_ +#define ATL_F77wrap_ctrmm atl_f77wrap_ctrmm_ +#define ATL_F77wrap_ctrsm atl_f77wrap_ctrsm_ + +#elif defined( DCPLX ) + +#define ATL_F77wrap_zrotg atl_f77wrap_zrotg_ +#define ATL_F77wrap_dznrm2 atl_f77wrap_dznrm2_ +#define ATL_F77wrap_dzasum atl_f77wrap_dzasum_ +#define ATL_F77wrap_zscal atl_f77wrap_zscal_ +#define ATL_F77wrap_zdscal atl_f77wrap_zdscal_ +#define ATL_F77wrap_izamax atl_f77wrap_izamax_ +#define ATL_F77wrap_zaxpy atl_f77wrap_zaxpy_ +#define ATL_F77wrap_zcopy atl_f77wrap_zcopy_ +#define ATL_F77wrap_zswap atl_f77wrap_zswap_ +#define ATL_F77wrap_zdrot atl_f77wrap_zdrot_ +#define ATL_F77wrap_zdotc atl_f77wrap_zdotc_ +#define ATL_F77wrap_zdotu atl_f77wrap_zdotu_ + +#define ATL_F77wrap_zgbmv atl_f77wrap_zgbmv_ +#define ATL_F77wrap_zgemv atl_f77wrap_zgemv_ +#define ATL_F77wrap_zgerc atl_f77wrap_zgerc_ +#define ATL_F77wrap_zgeru atl_f77wrap_zgeru_ +#define ATL_F77wrap_zhbmv atl_f77wrap_zhbmv_ +#define ATL_F77wrap_zhpmv atl_f77wrap_zhpmv_ +#define ATL_F77wrap_zhemv atl_f77wrap_zhemv_ +#define ATL_F77wrap_zhpr atl_f77wrap_zhpr_ +#define ATL_F77wrap_zher atl_f77wrap_zher_ +#define ATL_F77wrap_zhpr2 atl_f77wrap_zhpr2_ +#define ATL_F77wrap_zher2 atl_f77wrap_zher2_ +#define ATL_F77wrap_ztbmv atl_f77wrap_ztbmv_ +#define ATL_F77wrap_ztpmv atl_f77wrap_ztpmv_ +#define ATL_F77wrap_ztrmv atl_f77wrap_ztrmv_ +#define ATL_F77wrap_ztbsv atl_f77wrap_ztbsv_ +#define ATL_F77wrap_ztpsv atl_f77wrap_ztpsv_ +#define ATL_F77wrap_ztrsv atl_f77wrap_ztrsv_ + +#define ATL_F77wrap_zgemm atl_f77wrap_zgemm_ +#define ATL_F77wrap_zhemm atl_f77wrap_zhemm_ +#define ATL_F77wrap_zherk atl_f77wrap_zherk_ +#define ATL_F77wrap_zher2k atl_f77wrap_zher2k_ +#define ATL_F77wrap_zsymm atl_f77wrap_zsymm_ +#define ATL_F77wrap_zsyrk atl_f77wrap_zsyrk_ +#define ATL_F77wrap_zsyr2k atl_f77wrap_zsyr2k_ +#define ATL_F77wrap_ztrmm atl_f77wrap_ztrmm_ +#define ATL_F77wrap_ztrsm atl_f77wrap_ztrsm_ + +#endif + +#elif defined( UpCase ) +/* + * These defines set up the naming scheme required to have a FORTRAN + * routine calling a C routine with the following interface: + * + * FORTRAN CALL C declaration + * CALL ATL_F77WRAP_SGEMM(...) void ATL_F77WRAP_SGEMM(...) + * + */ +#if defined( SREAL ) + +#define ATL_F77wrap_srotg ATL_F77WRAP_SROTG +#define ATL_F77wrap_srotmg ATL_F77WRAP_SROTMG +#define ATL_F77wrap_snrm2 ATL_F77WRAP_SNRM2 +#define ATL_F77wrap_sasum ATL_F77WRAP_SASUM +#define ATL_F77wrap_sscal ATL_F77WRAP_SSCAL +#define ATL_F77wrap_isamax ATL_F77WRAP_ISAMAX +#define ATL_F77wrap_saxpy ATL_F77WRAP_SAXPY +#define ATL_F77wrap_scopy ATL_F77WRAP_SCOPY +#define ATL_F77wrap_sswap ATL_F77WRAP_SSWAP +#define ATL_F77wrap_srot ATL_F77WRAP_SROT +#define ATL_F77wrap_srotm ATL_F77WRAP_SROTM +#define ATL_F77wrap_sdot ATL_F77WRAP_SDOT +#define ATL_F77wrap_dsdot ATL_F77WRAP_DSDOT +#define ATL_F77wrap_sdsdot ATL_F77WRAP_SDSDOT + +#define ATL_F77wrap_sgbmv ATL_F77WRAP_SGBMV +#define ATL_F77wrap_sgemv ATL_F77WRAP_SGEMV +#define ATL_F77wrap_sger ATL_F77WRAP_SGER +#define ATL_F77wrap_ssbmv ATL_F77WRAP_SSBMV +#define ATL_F77wrap_sspmv ATL_F77WRAP_SSPMV +#define ATL_F77wrap_ssymv ATL_F77WRAP_SSYMV +#define ATL_F77wrap_sspr ATL_F77WRAP_SSPR +#define ATL_F77wrap_ssyr ATL_F77WRAP_SSYR +#define ATL_F77wrap_sspr2 ATL_F77WRAP_SSPR2 +#define ATL_F77wrap_ssyr2 ATL_F77WRAP_SSYR2 +#define ATL_F77wrap_stbmv ATL_F77WRAP_STBMV +#define ATL_F77wrap_stpmv ATL_F77WRAP_STPMV +#define ATL_F77wrap_strmv ATL_F77WRAP_STRMV +#define ATL_F77wrap_stbsv ATL_F77WRAP_STBSV +#define ATL_F77wrap_stpsv ATL_F77WRAP_STPSV +#define ATL_F77wrap_strsv ATL_F77WRAP_STRSV + +#define ATL_F77wrap_sgemm ATL_F77WRAP_SGEMM +#define ATL_F77wrap_ssymm ATL_F77WRAP_SSYMM +#define ATL_F77wrap_ssyrk ATL_F77WRAP_SSYRK +#define ATL_F77wrap_ssyr2k ATL_F77WRAP_SSYR2K +#define ATL_F77wrap_strmm ATL_F77WRAP_STRMM +#define ATL_F77wrap_strsm ATL_F77WRAP_STRSM + +#elif defined( DREAL ) + +#define ATL_F77wrap_drotg ATL_F77WRAP_DROTG +#define ATL_F77wrap_drotmg ATL_F77WRAP_DROTMG +#define ATL_F77wrap_dnrm2 ATL_F77WRAP_DNRM2 +#define ATL_F77wrap_dasum ATL_F77WRAP_DASUM +#define ATL_F77wrap_dscal ATL_F77WRAP_DSCAL +#define ATL_F77wrap_idamax ATL_F77WRAP_IDAMAX +#define ATL_F77wrap_daxpy ATL_F77WRAP_DAXPY +#define ATL_F77wrap_dcopy ATL_F77WRAP_DCOPY +#define ATL_F77wrap_dswap ATL_F77WRAP_DSWAP +#define ATL_F77wrap_drot ATL_F77WRAP_DROT +#define ATL_F77wrap_drotm ATL_F77WRAP_DROTM +#define ATL_F77wrap_ddot ATL_F77WRAP_DDOT + +#define ATL_F77wrap_dgbmv ATL_F77WRAP_DGBMV +#define ATL_F77wrap_dgemv ATL_F77WRAP_DGEMV +#define ATL_F77wrap_dger ATL_F77WRAP_DGER +#define ATL_F77wrap_dsbmv ATL_F77WRAP_DSBMV +#define ATL_F77wrap_dspmv ATL_F77WRAP_DSPMV +#define ATL_F77wrap_dsymv ATL_F77WRAP_DSYMV +#define ATL_F77wrap_dspr ATL_F77WRAP_DSPR +#define ATL_F77wrap_dsyr ATL_F77WRAP_DSYR +#define ATL_F77wrap_dspr2 ATL_F77WRAP_DSPR2 +#define ATL_F77wrap_dsyr2 ATL_F77WRAP_DSYR2 +#define ATL_F77wrap_dtbmv ATL_F77WRAP_DTBMV +#define ATL_F77wrap_dtpmv ATL_F77WRAP_DTPMV +#define ATL_F77wrap_dtrmv ATL_F77WRAP_DTRMV +#define ATL_F77wrap_dtbsv ATL_F77WRAP_DTBSV +#define ATL_F77wrap_dtpsv ATL_F77WRAP_DTPSV +#define ATL_F77wrap_dtrsv ATL_F77WRAP_DTRSV + +#define ATL_F77wrap_dgemm ATL_F77WRAP_DGEMM +#define ATL_F77wrap_dsymm ATL_F77WRAP_DSYMM +#define ATL_F77wrap_dsyrk ATL_F77WRAP_DSYRK +#define ATL_F77wrap_dsyr2k ATL_F77WRAP_DSYR2K +#define ATL_F77wrap_dtrmm ATL_F77WRAP_DTRMM +#define ATL_F77wrap_dtrsm ATL_F77WRAP_DTRSM + +#elif defined( SCPLX ) + +#define ATL_F77wrap_crotg ATL_F77WRAP_CROTG +#define ATL_F77wrap_scnrm2 ATL_F77WRAP_SCNRM2 +#define ATL_F77wrap_scasum ATL_F77WRAP_SCASUM +#define ATL_F77wrap_cscal ATL_F77WRAP_CSCAL +#define ATL_F77wrap_csscal ATL_F77WRAP_CSSCAL +#define ATL_F77wrap_icamax ATL_F77WRAP_ICAMAX +#define ATL_F77wrap_caxpy ATL_F77WRAP_CAXPY +#define ATL_F77wrap_ccopy ATL_F77WRAP_CCOPY +#define ATL_F77wrap_cswap ATL_F77WRAP_CSWAP +#define ATL_F77wrap_csrot ATL_F77WRAP_CSROT +#define ATL_F77wrap_cdotc ATL_F77WRAP_CDOTC +#define ATL_F77wrap_cdotu ATL_F77WRAP_CDOTU + +#define ATL_F77wrap_cgbmv ATL_F77WRAP_CGBMV +#define ATL_F77wrap_cgemv ATL_F77WRAP_CGEMV +#define ATL_F77wrap_cgerc ATL_F77WRAP_CGERC +#define ATL_F77wrap_cgeru ATL_F77WRAP_CGERU +#define ATL_F77wrap_chbmv ATL_F77WRAP_CHBMV +#define ATL_F77wrap_chpmv ATL_F77WRAP_CHPMV +#define ATL_F77wrap_chemv ATL_F77WRAP_CHEMV +#define ATL_F77wrap_chpr ATL_F77WRAP_CHPR +#define ATL_F77wrap_cher ATL_F77WRAP_CHER +#define ATL_F77wrap_chpr2 ATL_F77WRAP_CHPR2 +#define ATL_F77wrap_cher2 ATL_F77WRAP_CHER2 +#define ATL_F77wrap_ctbmv ATL_F77WRAP_CTBMV +#define ATL_F77wrap_ctpmv ATL_F77WRAP_CTPMV +#define ATL_F77wrap_ctrmv ATL_F77WRAP_CTRMV +#define ATL_F77wrap_ctbsv ATL_F77WRAP_CTBSV +#define ATL_F77wrap_ctpsv ATL_F77WRAP_CTPSV +#define ATL_F77wrap_ctrsv ATL_F77WRAP_CTRSV + +#define ATL_F77wrap_cgemm ATL_F77WRAP_CGEMM +#define ATL_F77wrap_chemm ATL_F77WRAP_CHEMM +#define ATL_F77wrap_cherk ATL_F77WRAP_CHERK +#define ATL_F77wrap_cher2k ATL_F77WRAP_CHER2K +#define ATL_F77wrap_csymm ATL_F77WRAP_CSYMM +#define ATL_F77wrap_csyrk ATL_F77WRAP_CSYRK +#define ATL_F77wrap_csyr2k ATL_F77WRAP_CSYR2K +#define ATL_F77wrap_ctrmm ATL_F77WRAP_CTRMM +#define ATL_F77wrap_ctrsm ATL_F77WRAP_CTRSM + +#elif defined( DCPLX ) + +#define ATL_F77wrap_zrotg ATL_F77WRAP_ZROTG +#define ATL_F77wrap_dznrm2 ATL_F77WRAP_DZNRM2 +#define ATL_F77wrap_dzasum ATL_F77WRAP_DZASUM +#define ATL_F77wrap_zscal ATL_F77WRAP_ZSCAL +#define ATL_F77wrap_zdscal ATL_F77WRAP_ZDSCAL +#define ATL_F77wrap_izamax ATL_F77WRAP_IZAMAX +#define ATL_F77wrap_zaxpy ATL_F77WRAP_ZAXPY +#define ATL_F77wrap_zcopy ATL_F77WRAP_ZCOPY +#define ATL_F77wrap_zswap ATL_F77WRAP_ZSWAP +#define ATL_F77wrap_zdrot ATL_F77WRAP_ZDROT +#define ATL_F77wrap_zdotc ATL_F77WRAP_ZDOTC +#define ATL_F77wrap_zdotu ATL_F77WRAP_ZDOTU + +#define ATL_F77wrap_zgbmv ATL_F77WRAP_ZGBMV +#define ATL_F77wrap_zgemv ATL_F77WRAP_ZGEMV +#define ATL_F77wrap_zgerc ATL_F77WRAP_ZGERC +#define ATL_F77wrap_zgeru ATL_F77WRAP_ZGERU +#define ATL_F77wrap_zhbmv ATL_F77WRAP_ZHBMV +#define ATL_F77wrap_zhpmv ATL_F77WRAP_ZHPMV +#define ATL_F77wrap_zhemv ATL_F77WRAP_ZHEMV +#define ATL_F77wrap_zhpr ATL_F77WRAP_ZHPR +#define ATL_F77wrap_zher ATL_F77WRAP_ZHER +#define ATL_F77wrap_zhpr2 ATL_F77WRAP_ZHPR2 +#define ATL_F77wrap_zher2 ATL_F77WRAP_ZHER2 +#define ATL_F77wrap_ztbmv ATL_F77WRAP_ZTBMV +#define ATL_F77wrap_ztpmv ATL_F77WRAP_ZTPMV +#define ATL_F77wrap_ztrmv ATL_F77WRAP_ZTRMV +#define ATL_F77wrap_ztbsv ATL_F77WRAP_ZTBSV +#define ATL_F77wrap_ztpsv ATL_F77WRAP_ZTPSV +#define ATL_F77wrap_ztrsv ATL_F77WRAP_ZTRSV + +#define ATL_F77wrap_zgemm ATL_F77WRAP_ZGEMM +#define ATL_F77wrap_zhemm ATL_F77WRAP_ZHEMM +#define ATL_F77wrap_zherk ATL_F77WRAP_ZHERK +#define ATL_F77wrap_zher2k ATL_F77WRAP_ZHER2K +#define ATL_F77wrap_zsymm ATL_F77WRAP_ZSYMM +#define ATL_F77wrap_zsyrk ATL_F77WRAP_ZSYRK +#define ATL_F77wrap_zsyr2k ATL_F77WRAP_ZSYR2K +#define ATL_F77wrap_ztrmm ATL_F77WRAP_ZTRMM +#define ATL_F77wrap_ztrsm ATL_F77WRAP_ZTRSM + +#endif + +#elif defined( NoChange ) +/* + * These defines set up the naming scheme required to have a FORTRAN + * routine calling a C routine with the following interface: + * + * FORTRAN CALL C declaration + * CALL ATL_F77WRAP_SGEMM(...) void atl_f77wrap_sgemm(...) + */ +#if defined( SREAL ) + +#define ATL_F77wrap_srotg atl_f77wrap_srotg +#define ATL_F77wrap_srotmg atl_f77wrap_srotmg +#define ATL_F77wrap_snrm2 atl_f77wrap_snrm2 +#define ATL_F77wrap_sasum atl_f77wrap_sasum +#define ATL_F77wrap_sscal atl_f77wrap_sscal +#define ATL_F77wrap_isamax atl_f77wrap_isamax +#define ATL_F77wrap_saxpy atl_f77wrap_saxpy +#define ATL_F77wrap_scopy atl_f77wrap_scopy +#define ATL_F77wrap_sswap atl_f77wrap_sswap +#define ATL_F77wrap_srot atl_f77wrap_srot +#define ATL_F77wrap_srotm atl_f77wrap_srotm +#define ATL_F77wrap_sdot atl_f77wrap_sdot +#define ATL_F77wrap_dsdot atl_f77wrap_dsdot +#define ATL_F77wrap_sdsdot atl_f77wrap_sdsdot + +#define ATL_F77wrap_sgbmv atl_f77wrap_sgbmv +#define ATL_F77wrap_sgemv atl_f77wrap_sgemv +#define ATL_F77wrap_sger atl_f77wrap_sger +#define ATL_F77wrap_ssbmv atl_f77wrap_ssbmv +#define ATL_F77wrap_sspmv atl_f77wrap_sspmv +#define ATL_F77wrap_ssymv atl_f77wrap_ssymv +#define ATL_F77wrap_sspr atl_f77wrap_sspr +#define ATL_F77wrap_ssyr atl_f77wrap_ssyr +#define ATL_F77wrap_sspr2 atl_f77wrap_sspr2 +#define ATL_F77wrap_ssyr2 atl_f77wrap_ssyr2 +#define ATL_F77wrap_stbmv atl_f77wrap_stbmv +#define ATL_F77wrap_stpmv atl_f77wrap_stpmv +#define ATL_F77wrap_strmv atl_f77wrap_strmv +#define ATL_F77wrap_stbsv atl_f77wrap_stbsv +#define ATL_F77wrap_stpsv atl_f77wrap_stpsv +#define ATL_F77wrap_strsv atl_f77wrap_strsv + +#define ATL_F77wrap_sgemm atl_f77wrap_sgemm +#define ATL_F77wrap_ssymm atl_f77wrap_ssymm +#define ATL_F77wrap_ssyrk atl_f77wrap_ssyrk +#define ATL_F77wrap_ssyr2k atl_f77wrap_ssyr2k +#define ATL_F77wrap_strmm atl_f77wrap_strmm +#define ATL_F77wrap_strsm atl_f77wrap_strsm + +#elif defined( DREAL ) + +#define ATL_F77wrap_drotg atl_f77wrap_drotg +#define ATL_F77wrap_drotmg atl_f77wrap_drotmg +#define ATL_F77wrap_dnrm2 atl_f77wrap_dnrm2 +#define ATL_F77wrap_dasum atl_f77wrap_dasum +#define ATL_F77wrap_dscal atl_f77wrap_dscal +#define ATL_F77wrap_idamax atl_f77wrap_idamax +#define ATL_F77wrap_daxpy atl_f77wrap_daxpy +#define ATL_F77wrap_dcopy atl_f77wrap_dcopy +#define ATL_F77wrap_dswap atl_f77wrap_dswap +#define ATL_F77wrap_drot atl_f77wrap_drot +#define ATL_F77wrap_drotm atl_f77wrap_drotm +#define ATL_F77wrap_ddot atl_f77wrap_ddot + +#define ATL_F77wrap_dgbmv atl_f77wrap_dgbmv +#define ATL_F77wrap_dgemv atl_f77wrap_dgemv +#define ATL_F77wrap_dger atl_f77wrap_dger +#define ATL_F77wrap_dsbmv atl_f77wrap_dsbmv +#define ATL_F77wrap_dspmv atl_f77wrap_dspmv +#define ATL_F77wrap_dsymv atl_f77wrap_dsymv +#define ATL_F77wrap_dspr atl_f77wrap_dspr +#define ATL_F77wrap_dsyr atl_f77wrap_dsyr +#define ATL_F77wrap_dspr2 atl_f77wrap_dspr2 +#define ATL_F77wrap_dsyr2 atl_f77wrap_dsyr2 +#define ATL_F77wrap_dtbmv atl_f77wrap_dtbmv +#define ATL_F77wrap_dtpmv atl_f77wrap_dtpmv +#define ATL_F77wrap_dtrmv atl_f77wrap_dtrmv +#define ATL_F77wrap_dtbsv atl_f77wrap_dtbsv +#define ATL_F77wrap_dtpsv atl_f77wrap_dtpsv +#define ATL_F77wrap_dtrsv atl_f77wrap_dtrsv + +#define ATL_F77wrap_dgemm atl_f77wrap_dgemm +#define ATL_F77wrap_dsymm atl_f77wrap_dsymm +#define ATL_F77wrap_dsyrk atl_f77wrap_dsyrk +#define ATL_F77wrap_dsyr2k atl_f77wrap_dsyr2k +#define ATL_F77wrap_dtrmm atl_f77wrap_dtrmm +#define ATL_F77wrap_dtrsm atl_f77wrap_dtrsm + +#elif defined( SCPLX ) + +#define ATL_F77wrap_crotg atl_f77wrap_crotg +#define ATL_F77wrap_scnrm2 atl_f77wrap_scnrm2 +#define ATL_F77wrap_scasum atl_f77wrap_scasum +#define ATL_F77wrap_cscal atl_f77wrap_cscal +#define ATL_F77wrap_csscal atl_f77wrap_csscal +#define ATL_F77wrap_icamax atl_f77wrap_icamax +#define ATL_F77wrap_caxpy atl_f77wrap_caxpy +#define ATL_F77wrap_ccopy atl_f77wrap_ccopy +#define ATL_F77wrap_cswap atl_f77wrap_cswap +#define ATL_F77wrap_csrot atl_f77wrap_csrot +#define ATL_F77wrap_cdotc atl_f77wrap_cdotc +#define ATL_F77wrap_cdotu atl_f77wrap_cdotu + +#define ATL_F77wrap_cgbmv atl_f77wrap_cgbmv +#define ATL_F77wrap_cgemv atl_f77wrap_cgemv +#define ATL_F77wrap_cgerc atl_f77wrap_cgerc +#define ATL_F77wrap_cgeru atl_f77wrap_cgeru +#define ATL_F77wrap_chbmv atl_f77wrap_chbmv +#define ATL_F77wrap_chpmv atl_f77wrap_chpmv +#define ATL_F77wrap_chemv atl_f77wrap_chemv +#define ATL_F77wrap_chpr atl_f77wrap_chpr +#define ATL_F77wrap_cher atl_f77wrap_cher +#define ATL_F77wrap_chpr2 atl_f77wrap_chpr2 +#define ATL_F77wrap_cher2 atl_f77wrap_cher2 +#define ATL_F77wrap_ctbmv atl_f77wrap_ctbmv +#define ATL_F77wrap_ctpmv atl_f77wrap_ctpmv +#define ATL_F77wrap_ctrmv atl_f77wrap_ctrmv +#define ATL_F77wrap_ctbsv atl_f77wrap_ctbsv +#define ATL_F77wrap_ctpsv atl_f77wrap_ctpsv +#define ATL_F77wrap_ctrsv atl_f77wrap_ctrsv + +#define ATL_F77wrap_cgemm atl_f77wrap_cgemm +#define ATL_F77wrap_chemm atl_f77wrap_chemm +#define ATL_F77wrap_cherk atl_f77wrap_cherk +#define ATL_F77wrap_cher2k atl_f77wrap_cher2k +#define ATL_F77wrap_csymm atl_f77wrap_csymm +#define ATL_F77wrap_csyrk atl_f77wrap_csyrk +#define ATL_F77wrap_csyr2k atl_f77wrap_csyr2k +#define ATL_F77wrap_ctrmm atl_f77wrap_ctrmm +#define ATL_F77wrap_ctrsm atl_f77wrap_ctrsm + +#elif defined( DCPLX ) + +#define ATL_F77wrap_zrotg atl_f77wrap_zrotg +#define ATL_F77wrap_dznrm2 atl_f77wrap_dznrm2 +#define ATL_F77wrap_dzasum atl_f77wrap_dzasum +#define ATL_F77wrap_zscal atl_f77wrap_zscal +#define ATL_F77wrap_zdscal atl_f77wrap_zdscal +#define ATL_F77wrap_izamax atl_f77wrap_izamax +#define ATL_F77wrap_zaxpy atl_f77wrap_zaxpy +#define ATL_F77wrap_zcopy atl_f77wrap_zcopy +#define ATL_F77wrap_zswap atl_f77wrap_zswap +#define ATL_F77wrap_zdrot atl_f77wrap_zdrot +#define ATL_F77wrap_zdotc atl_f77wrap_zdotc +#define ATL_F77wrap_zdotu atl_f77wrap_zdotu + +#define ATL_F77wrap_zgbmv atl_f77wrap_zgbmv +#define ATL_F77wrap_zgemv atl_f77wrap_zgemv +#define ATL_F77wrap_zgerc atl_f77wrap_zgerc +#define ATL_F77wrap_zgeru atl_f77wrap_zgeru +#define ATL_F77wrap_zhbmv atl_f77wrap_zhbmv +#define ATL_F77wrap_zhpmv atl_f77wrap_zhpmv +#define ATL_F77wrap_zhemv atl_f77wrap_zhemv +#define ATL_F77wrap_zhpr atl_f77wrap_zhpr +#define ATL_F77wrap_zher atl_f77wrap_zher +#define ATL_F77wrap_zhpr2 atl_f77wrap_zhpr2 +#define ATL_F77wrap_zher2 atl_f77wrap_zher2 +#define ATL_F77wrap_ztbmv atl_f77wrap_ztbmv +#define ATL_F77wrap_ztpmv atl_f77wrap_ztpmv +#define ATL_F77wrap_ztrmv atl_f77wrap_ztrmv +#define ATL_F77wrap_ztbsv atl_f77wrap_ztbsv +#define ATL_F77wrap_ztpsv atl_f77wrap_ztpsv +#define ATL_F77wrap_ztrsv atl_f77wrap_ztrsv + +#define ATL_F77wrap_zgemm atl_f77wrap_zgemm +#define ATL_F77wrap_zhemm atl_f77wrap_zhemm +#define ATL_F77wrap_zherk atl_f77wrap_zherk +#define ATL_F77wrap_zher2k atl_f77wrap_zher2k +#define ATL_F77wrap_zsymm atl_f77wrap_zsymm +#define ATL_F77wrap_zsyrk atl_f77wrap_zsyrk +#define ATL_F77wrap_zsyr2k atl_f77wrap_zsyr2k +#define ATL_F77wrap_ztrmm atl_f77wrap_ztrmm +#define ATL_F77wrap_ztrsm atl_f77wrap_ztrsm + +#endif + +#elif defined( Add__ ) +/* + * These defines set up the naming scheme required to have a FORTRAN + * routine calling a C routine with the following interface: + * + * FORTRAN CALL C declaration + * CALL ATL_F77WRAP_SGEMM(...) void atl_f77wrap_sgemm__(...) + */ +#if defined( SREAL ) + +#define ATL_F77wrap_srotg atl_f77wrap_srotg__ +#define ATL_F77wrap_srotmg atl_f77wrap_srotmg__ +#define ATL_F77wrap_snrm2 atl_f77wrap_snrm2__ +#define ATL_F77wrap_sasum atl_f77wrap_sasum__ +#define ATL_F77wrap_sscal atl_f77wrap_sscal__ +#define ATL_F77wrap_isamax atl_f77wrap_isamax__ +#define ATL_F77wrap_saxpy atl_f77wrap_saxpy__ +#define ATL_F77wrap_scopy atl_f77wrap_scopy__ +#define ATL_F77wrap_sswap atl_f77wrap_sswap__ +#define ATL_F77wrap_srot atl_f77wrap_srot__ +#define ATL_F77wrap_srotm atl_f77wrap_srotm__ +#define ATL_F77wrap_sdot atl_f77wrap_sdot__ +#define ATL_F77wrap_dsdot atl_f77wrap_dsdot__ +#define ATL_F77wrap_sdsdot atl_f77wrap_sdsdot__ + +#define ATL_F77wrap_sgbmv atl_f77wrap_sgbmv__ +#define ATL_F77wrap_sgemv atl_f77wrap_sgemv__ +#define ATL_F77wrap_sger atl_f77wrap_sger__ +#define ATL_F77wrap_ssbmv atl_f77wrap_ssbmv__ +#define ATL_F77wrap_sspmv atl_f77wrap_sspmv__ +#define ATL_F77wrap_ssymv atl_f77wrap_ssymv__ +#define ATL_F77wrap_sspr atl_f77wrap_sspr__ +#define ATL_F77wrap_ssyr atl_f77wrap_ssyr__ +#define ATL_F77wrap_sspr2 atl_f77wrap_sspr2__ +#define ATL_F77wrap_ssyr2 atl_f77wrap_ssyr2__ +#define ATL_F77wrap_stbmv atl_f77wrap_stbmv__ +#define ATL_F77wrap_stpmv atl_f77wrap_stpmv__ +#define ATL_F77wrap_strmv atl_f77wrap_strmv__ +#define ATL_F77wrap_stbsv atl_f77wrap_stbsv__ +#define ATL_F77wrap_stpsv atl_f77wrap_stpsv__ +#define ATL_F77wrap_strsv atl_f77wrap_strsv__ + +#define ATL_F77wrap_sgemm atl_f77wrap_sgemm__ +#define ATL_F77wrap_ssymm atl_f77wrap_ssymm__ +#define ATL_F77wrap_ssyrk atl_f77wrap_ssyrk__ +#define ATL_F77wrap_ssyr2k atl_f77wrap_ssyr2k__ +#define ATL_F77wrap_strmm atl_f77wrap_strmm__ +#define ATL_F77wrap_strsm atl_f77wrap_strsm__ + +#elif defined( DREAL ) + +#define ATL_F77wrap_drotg atl_f77wrap_drotg__ +#define ATL_F77wrap_drotmg atl_f77wrap_drotmg__ +#define ATL_F77wrap_dnrm2 atl_f77wrap_dnrm2__ +#define ATL_F77wrap_dasum atl_f77wrap_dasum__ +#define ATL_F77wrap_dscal atl_f77wrap_dscal__ +#define ATL_F77wrap_idamax atl_f77wrap_idamax__ +#define ATL_F77wrap_daxpy atl_f77wrap_daxpy__ +#define ATL_F77wrap_dcopy atl_f77wrap_dcopy__ +#define ATL_F77wrap_dswap atl_f77wrap_dswap__ +#define ATL_F77wrap_drot atl_f77wrap_drot__ +#define ATL_F77wrap_drotm atl_f77wrap_drotm__ +#define ATL_F77wrap_ddot atl_f77wrap_ddot__ + +#define ATL_F77wrap_dgbmv atl_f77wrap_dgbmv__ +#define ATL_F77wrap_dgemv atl_f77wrap_dgemv__ +#define ATL_F77wrap_dger atl_f77wrap_dger__ +#define ATL_F77wrap_dsbmv atl_f77wrap_dsbmv__ +#define ATL_F77wrap_dspmv atl_f77wrap_dspmv__ +#define ATL_F77wrap_dsymv atl_f77wrap_dsymv__ +#define ATL_F77wrap_dspr atl_f77wrap_dspr__ +#define ATL_F77wrap_dsyr atl_f77wrap_dsyr__ +#define ATL_F77wrap_dspr2 atl_f77wrap_dspr2__ +#define ATL_F77wrap_dsyr2 atl_f77wrap_dsyr2__ +#define ATL_F77wrap_dtbmv atl_f77wrap_dtbmv__ +#define ATL_F77wrap_dtpmv atl_f77wrap_dtpmv__ +#define ATL_F77wrap_dtrmv atl_f77wrap_dtrmv__ +#define ATL_F77wrap_dtbsv atl_f77wrap_dtbsv__ +#define ATL_F77wrap_dtpsv atl_f77wrap_dtpsv__ +#define ATL_F77wrap_dtrsv atl_f77wrap_dtrsv__ + +#define ATL_F77wrap_dgemm atl_f77wrap_dgemm__ +#define ATL_F77wrap_dsymm atl_f77wrap_dsymm__ +#define ATL_F77wrap_dsyrk atl_f77wrap_dsyrk__ +#define ATL_F77wrap_dsyr2k atl_f77wrap_dsyr2k__ +#define ATL_F77wrap_dtrmm atl_f77wrap_dtrmm__ +#define ATL_F77wrap_dtrsm atl_f77wrap_dtrsm__ + +#elif defined( SCPLX ) + +#define ATL_F77wrap_crotg atl_f77wrap_crotg__ +#define ATL_F77wrap_scnrm2 atl_f77wrap_scnrm2__ +#define ATL_F77wrap_scasum atl_f77wrap_scasum__ +#define ATL_F77wrap_cscal atl_f77wrap_cscal__ +#define ATL_F77wrap_csscal atl_f77wrap_csscal__ +#define ATL_F77wrap_icamax atl_f77wrap_icamax__ +#define ATL_F77wrap_caxpy atl_f77wrap_caxpy__ +#define ATL_F77wrap_ccopy atl_f77wrap_ccopy__ +#define ATL_F77wrap_cswap atl_f77wrap_cswap__ +#define ATL_F77wrap_csrot atl_f77wrap_csrot__ +#define ATL_F77wrap_cdotc atl_f77wrap_cdotc__ +#define ATL_F77wrap_cdotu atl_f77wrap_cdotu__ + +#define ATL_F77wrap_cgbmv atl_f77wrap_cgbmv__ +#define ATL_F77wrap_cgemv atl_f77wrap_cgemv__ +#define ATL_F77wrap_cgerc atl_f77wrap_cgerc__ +#define ATL_F77wrap_cgeru atl_f77wrap_cgeru__ +#define ATL_F77wrap_chbmv atl_f77wrap_chbmv__ +#define ATL_F77wrap_chpmv atl_f77wrap_chpmv__ +#define ATL_F77wrap_chemv atl_f77wrap_chemv__ +#define ATL_F77wrap_chpr atl_f77wrap_chpr__ +#define ATL_F77wrap_cher atl_f77wrap_cher__ +#define ATL_F77wrap_chpr2 atl_f77wrap_chpr2__ +#define ATL_F77wrap_cher2 atl_f77wrap_cher2__ +#define ATL_F77wrap_ctbmv atl_f77wrap_ctbmv__ +#define ATL_F77wrap_ctpmv atl_f77wrap_ctpmv__ +#define ATL_F77wrap_ctrmv atl_f77wrap_ctrmv__ +#define ATL_F77wrap_ctbsv atl_f77wrap_ctbsv__ +#define ATL_F77wrap_ctpsv atl_f77wrap_ctpsv__ +#define ATL_F77wrap_ctrsv atl_f77wrap_ctrsv__ + +#define ATL_F77wrap_cgemm atl_f77wrap_cgemm__ +#define ATL_F77wrap_chemm atl_f77wrap_chemm__ +#define ATL_F77wrap_cherk atl_f77wrap_cherk__ +#define ATL_F77wrap_cher2k atl_f77wrap_cher2k__ +#define ATL_F77wrap_csymm atl_f77wrap_csymm__ +#define ATL_F77wrap_csyrk atl_f77wrap_csyrk__ +#define ATL_F77wrap_csyr2k atl_f77wrap_csyr2k__ +#define ATL_F77wrap_ctrmm atl_f77wrap_ctrmm__ +#define ATL_F77wrap_ctrsm atl_f77wrap_ctrsm__ + +#elif defined( DCPLX ) + +#define ATL_F77wrap_zrotg atl_f77wrap_zrotg__ +#define ATL_F77wrap_dznrm2 atl_f77wrap_dznrm2__ +#define ATL_F77wrap_dzasum atl_f77wrap_dzasum__ +#define ATL_F77wrap_zscal atl_f77wrap_zscal__ +#define ATL_F77wrap_zdscal atl_f77wrap_zdscal__ +#define ATL_F77wrap_izamax atl_f77wrap_izamax__ +#define ATL_F77wrap_zaxpy atl_f77wrap_zaxpy__ +#define ATL_F77wrap_zcopy atl_f77wrap_zcopy__ +#define ATL_F77wrap_zswap atl_f77wrap_zswap__ +#define ATL_F77wrap_zdrot atl_f77wrap_zdrot__ +#define ATL_F77wrap_zdotc atl_f77wrap_zdotc__ +#define ATL_F77wrap_zdotu atl_f77wrap_zdotu__ + +#define ATL_F77wrap_zgbmv atl_f77wrap_zgbmv__ +#define ATL_F77wrap_zgemv atl_f77wrap_zgemv__ +#define ATL_F77wrap_zgerc atl_f77wrap_zgerc__ +#define ATL_F77wrap_zgeru atl_f77wrap_zgeru__ +#define ATL_F77wrap_zhbmv atl_f77wrap_zhbmv__ +#define ATL_F77wrap_zhpmv atl_f77wrap_zhpmv__ +#define ATL_F77wrap_zhemv atl_f77wrap_zhemv__ +#define ATL_F77wrap_zhpr atl_f77wrap_zhpr__ +#define ATL_F77wrap_zher atl_f77wrap_zher__ +#define ATL_F77wrap_zhpr2 atl_f77wrap_zhpr2__ +#define ATL_F77wrap_zher2 atl_f77wrap_zher2__ +#define ATL_F77wrap_ztbmv atl_f77wrap_ztbmv__ +#define ATL_F77wrap_ztpmv atl_f77wrap_ztpmv__ +#define ATL_F77wrap_ztrmv atl_f77wrap_ztrmv__ +#define ATL_F77wrap_ztbsv atl_f77wrap_ztbsv__ +#define ATL_F77wrap_ztpsv atl_f77wrap_ztpsv__ +#define ATL_F77wrap_ztrsv atl_f77wrap_ztrsv__ + +#define ATL_F77wrap_zgemm atl_f77wrap_zgemm__ +#define ATL_F77wrap_zhemm atl_f77wrap_zhemm__ +#define ATL_F77wrap_zherk atl_f77wrap_zherk__ +#define ATL_F77wrap_zher2k atl_f77wrap_zher2k__ +#define ATL_F77wrap_zsymm atl_f77wrap_zsymm__ +#define ATL_F77wrap_zsyrk atl_f77wrap_zsyrk__ +#define ATL_F77wrap_zsyr2k atl_f77wrap_zsyr2k__ +#define ATL_F77wrap_ztrmm atl_f77wrap_ztrmm__ +#define ATL_F77wrap_ztrsm atl_f77wrap_ztrsm__ + +#endif + +#endif +/* + * ===================================================================== + * Prototypes for F77 interface wrappers ATLAS BLAS routines + * ===================================================================== + */ +void Mjoin( PATLF77WRAP, rotg ) +( TYPE *, TYPE *, TYPE *, TYPE * ); +#ifdef TREAL +void Mjoin( PATLF77WRAP, rotmg ) +( TYPE *, TYPE *, TYPE *, TYPE *, + TYPE * ); +#endif +void Mjoin( ATLUPF77WRAP, nrm2 ) +( F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE * ); +void Mjoin( ATLUPF77WRAP, asum ) +( F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE * ); +void Mjoin( PATLF77WRAP, scal ) +( F77_INTEGER *, TYPE *, TYPE *, F77_INTEGER * ); +#ifdef TCPLX +void Mjoin( ATLPUF77WRAP, scal ) +( F77_INTEGER *, TYPE *, TYPE *, F77_INTEGER * ); +#endif +void Mjoin( Mjoin( ATL_F77wrap_i, PRE ), amax ) +( F77_INTEGER *, TYPE *, F77_INTEGER *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, axpy ) +( F77_INTEGER *, TYPE *, TYPE *, F77_INTEGER *, + TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, copy ) +( F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER * ); +void Mjoin( PATLF77WRAP, swap ) +( F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER * ); +void Mjoin( ATLPUF77WRAP, rot ) +( F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER *, TYPE *, TYPE * ); +#ifdef TREAL +void Mjoin( PATLF77WRAP, rotm ) +( F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER *, TYPE * ); +#endif +#ifdef TREAL +void Mjoin( PATLF77WRAP, dot ) +( F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER *, TYPE * ); +#else +void Mjoin( PATLF77WRAP, dotc ) +( F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER *, TYPE * ); +void Mjoin( PATLF77WRAP, dotu ) +( F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER *, TYPE * ); +#endif +void ATL_F77wrap_dsdot +( F77_INTEGER *, float *, F77_INTEGER *, float *, + F77_INTEGER *, double * ); +void ATL_F77wrap_sdsdot +( F77_INTEGER *, float *, float *, F77_INTEGER *, + float *, F77_INTEGER *, float * ); + +void Mjoin( PATLF77WRAP, gbmv ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + F77_INTEGER *, TYPE *, TYPE *, F77_INTEGER *, + TYPE *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER * ); +void Mjoin( PATLF77WRAP, gemv ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, TYPE *, + TYPE *, F77_INTEGER *, TYPE *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER * ); +#ifdef TREAL +void Mjoin( PATLF77WRAP, ger ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER * ); +void Mjoin( PATLF77WRAP, sbmv ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, TYPE *, + TYPE *, F77_INTEGER *, TYPE *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, spmv ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + TYPE *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER * ); +void Mjoin( PATLF77WRAP, symv ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, spr ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE * ); +void Mjoin( PATLF77WRAP, syr ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, spr2 ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE * ); +void Mjoin( PATLF77WRAP, syr2 ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER * ); +#else +void Mjoin( PATLF77WRAP, gerc ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER * ); +void Mjoin( PATLF77WRAP, geru ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER * ); +void Mjoin( PATLF77WRAP, hbmv ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, TYPE *, + TYPE *, F77_INTEGER *, TYPE *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, hpmv ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + TYPE *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER * ); +void Mjoin( PATLF77WRAP, hemv ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, hpr ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE * ); +void Mjoin( PATLF77WRAP, her ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, hpr2 ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE * ); +void Mjoin( PATLF77WRAP, her2 ) +( F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER * ); +#endif +void Mjoin( PATLF77WRAP, tbmv ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER * ); +void Mjoin( PATLF77WRAP, tpmv ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, trmv ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + TYPE *, F77_INTEGER *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, tbsv ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + F77_INTEGER *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER * ); +void Mjoin( PATLF77WRAP, tpsv ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, trsv ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + TYPE *, F77_INTEGER *, TYPE *, F77_INTEGER * ); + +void Mjoin( PATLF77WRAP, gemm ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + F77_INTEGER *, TYPE *, TYPE *, F77_INTEGER *, + TYPE *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER * ); +#ifdef TCPLX +void Mjoin( PATLF77WRAP, hemm ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER *, TYPE *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, herk ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER *, TYPE *, + TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, her2k ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER *, TYPE *, TYPE *, F77_INTEGER * ); +#endif +void Mjoin( PATLF77WRAP, symm ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER *, TYPE *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, syrk ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER *, TYPE *, + TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, syr2k ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + TYPE *, TYPE *, F77_INTEGER *, TYPE *, + F77_INTEGER *, TYPE *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, trmm ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER * ); +void Mjoin( PATLF77WRAP, trsm ) +( F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, F77_INTEGER *, + F77_INTEGER *, F77_INTEGER *, TYPE *, TYPE *, + F77_INTEGER *, TYPE *, F77_INTEGER * ); + +#endif +/* + * End of atlas_f77wrap.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_fopen.h b/kaldi_io/src/tools/ATLAS/include/atlas_fopen.h new file mode 100644 index 0000000..aaed713 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_fopen.h @@ -0,0 +1,40 @@ +#ifndef ATLAS_FOPEN_H +#define ATLAS_FOPEN_H + +static int FileExists(const char *path) +{ + FILE *fp; + int iret=0; + fp = fopen(path, "r"); + if (fp) + { + fclose(fp); + iret = 1; + } + return(iret); +} + +#ifdef ATL_FOPENDELAY +static FILE *ATL_fopen(const char *path, const char *mode) +/* + * Overload fopen so it waits for NFS propogation upon first read failure + */ +{ + FILE *fp; + char ln[256]; + + fp = fopen(path, mode); + if (fp == NULL) + { + if (*mode == 'r') /* give NFS time to produce file */ + { + sprintf(ln, "make waitfile waitfile=%s\n", path); + if (system(ln) == 0) fp = fopen(path, mode); + } + } + return(fp); +} +#define fopen ATL_fopen +#endif + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_kern3.h b/kaldi_io/src/tools/ATLAS/include/atlas_kern3.h new file mode 100644 index 0000000..97e8bcc --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_kern3.h @@ -0,0 +1,110 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ +#ifndef ATLAS_KERN3_H +#define ATLAS_KERN3_H + +#include "atlas_misc.h" +#include Mstr(Mjoin(Mjoin(atlas_,PRE),NCmm.h)) +#include "atlas_lvl3.h" +#include "atlas_kernel3.h" +#include "atlas_reflevel3.h" +/* + * Gemm entry points + */ +#define CgemmNN Mjoin(PATL,gemmNN) +#define CgemmNT Mjoin(PATL,gemmNT) +#define CgemmTN Mjoin(PATL,gemmTN) +#define CgemmNC Mjoin(PATL,gemmNC) +#define CgemmCN Mjoin(PATL,gemmCN) + +#define CAgemmNN Mjoin(PATL,aliased_gemmNN) +#define CAgemmTN Mjoin(PATL,aliased_gemmTN) + +#ifdef Left_ + #define Side_ AtlasLeft + #define SideNM L +#elif defined(Right_) + #define Side_ AtlasRight + #define SideNM R +#endif + +#ifdef Upper_ + #define Uplo_ AtlasUpper + #define UploNM U +#elif defined(Lower_) + #define Uplo_ AtlasLower + #define UploNM L +#endif + +#ifdef UnitDiag_ + #define Unit_ AtlasUnit + #define UnitNM U +#elif defined(NonUnitDiag_) + #define Unit_ AtlasNonUnit + #define UnitNM N +#endif + +#ifdef Transpose_ + #define Trans_ AtlasTrans + #define TransNM T +#elif defined(Notranspose_) + #define Trans_ AtlasNoTrans + #define TransNM N +#elif defined(ConjTrans_) + #define Trans_ AtlasConjTrans + #define TransNM C +#endif + +#ifndef TRSM_Xover + #define TRSM_Xover NB +#endif +#ifndef TRMM_Xover + #define TRMM_Xover NB +#endif +#ifndef HER2K_Xover + #define HER2K_Xover NB +#endif +#ifndef SYR2K_Xover + #define SYR2K_Xover NB +#endif +#ifndef HERK_Xover + #define HERK_Xover NB +#endif +#ifndef SYRK_Xover + #define SYRK_Xover NB +#endif +#ifndef HEMM_Xover + #define HEMM_Xover NB +#endif +#ifndef SYMM_Xover + #define SYMM_Xover NB +#endif + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_kernel2.h b/kaldi_io/src/tools/ATLAS/include/atlas_kernel2.h new file mode 100644 index 0000000..4663def --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_kernel2.h @@ -0,0 +1,5408 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Contributor(s) : R. Clint Whaley + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_KERNEL2_H +#define ATLAS_KERNEL2_H +/* + * ===================================================================== + * Macro function definitions + * ===================================================================== + */ +#define ATL_GetPartSBMV ATL_GetPartSYMV +#define ATL_GetPartSPMV ATL_GetPartSYMV +#define ATL_GetPartP1 ATL_GetPartR1 + +#define MLpprev( n_, a_, lda_ ) \ + { a_ -= ( (((n_) * (lda_)) + (((n_)*((n_)+1)) >> 1)) SHIFT ); lda_ += (n_); } +#define MUpprev( n_, a_, lda_ ) \ + { a_ -= ( (((n_) * (lda_)) - (((n_)*((n_)-1)) >> 1)) SHIFT ); lda_ -= (n_); } +#define MLpnext( n_, a_, lda_ ) \ + { a_ += ( (((n_) * (lda_)) - (((n_)*((n_)-1)) >> 1)) SHIFT ); lda_ -= (n_); } +#define MUpnext( n_, a_, lda_ ) \ + { a_ += ( (((n_) * (lda_)) + (((n_)*((n_)+1)) >> 1)) SHIFT ); lda_ += (n_); } + +#define MLrprev( n_, a_, lda_ ) \ + { a_ -= ( ((n_) * ((lda_)+1)) SHIFT ); } +#define MUrprev( n_, a_, lda_ ) \ + { a_ -= ( ((n_) * ((lda_)+1)) SHIFT ); } +#define MLrnext( n_, a_, lda_ ) \ + { a_ += ( ((n_) * ((lda_)+1)) SHIFT ); } +#define MUrnext( n_, a_, lda_ ) \ + { a_ += ( ((n_) * ((lda_)+1)) SHIFT ); } +/* + * ===================================================================== + * Recursive Level 2 BLAS function prototypes + * ===================================================================== + */ +void ATL_strsvLTU +( + const int, + const float *, const int, + float * +); + +void ATL_strsvLNU +( + const int, + const float *, const int, + float * +); + +void ATL_strsvLTN +( + const int, + const float *, const int, + float * +); + +void ATL_strsvLNN +( + const int, + const float *, const int, + float * +); + +void ATL_strsvUTU +( + const int, + const float *, const int, + float * +); + +void ATL_strsvUNU +( + const int, + const float *, const int, + float * +); + +void ATL_strsvUTN +( + const int, + const float *, const int, + float * +); + +void ATL_strsvUNN +( + const int, + const float *, const int, + float * +); + +void ATL_strsvLT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_strsvLN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_strsvUT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_strsvUN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_stpsvLTU +( + const int, + const float *, const int, + float * +); + +void ATL_stpsvLNU +( + const int, + const float *, const int, + float * +); + +void ATL_stpsvLTN +( + const int, + const float *, const int, + float * +); + +void ATL_stpsvLNN +( + const int, + const float *, const int, + float * +); + +void ATL_stpsvUTU +( + const int, + const float *, const int, + float * +); + +void ATL_stpsvUNU +( + const int, + const float *, const int, + float * +); + +void ATL_stpsvUTN +( + const int, + const float *, const int, + float * +); + +void ATL_stpsvUNN +( + const int, + const float *, const int, + float * +); + +void ATL_stpsvLT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_stpsvLN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_stpsvUT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_stpsvUN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_stbsvLTU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvLNU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvLTN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvLNN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvUTU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvUNU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvUTN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvUNN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvLT +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvLN +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvUT +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_stbsvUN +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_strmvLTU +( + const int, + const float *, const int, + float * +); + +void ATL_strmvLNU +( + const int, + const float *, const int, + float * +); + +void ATL_strmvLTN +( + const int, + const float *, const int, + float * +); + +void ATL_strmvLNN +( + const int, + const float *, const int, + float * +); + +void ATL_strmvUTU +( + const int, + const float *, const int, + float * +); + +void ATL_strmvUNU +( + const int, + const float *, const int, + float * +); + +void ATL_strmvUTN +( + const int, + const float *, const int, + float * +); + +void ATL_strmvUNN +( + const int, + const float *, const int, + float * +); + +void ATL_strmvLT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_strmvLN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_strmvUT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_strmvUN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_stpmvLTU +( + const int, + const float *, const int, + float * +); + +void ATL_stpmvLNU +( + const int, + const float *, const int, + float * +); + +void ATL_stpmvLTN +( + const int, + const float *, const int, + float * +); + +void ATL_stpmvLNN +( + const int, + const float *, const int, + float * +); + +void ATL_stpmvUTU +( + const int, + const float *, const int, + float * +); + +void ATL_stpmvUNU +( + const int, + const float *, const int, + float * +); + +void ATL_stpmvUTN +( + const int, + const float *, const int, + float * +); + +void ATL_stpmvUNN +( + const int, + const float *, const int, + float * +); + +void ATL_stpmvLT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_stpmvLN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_stpmvUT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_stpmvUN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_stbmvLTU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvLNU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvLTN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvLNN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvUTU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvUNU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvUTN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvUNN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvLT +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvLN +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvUT +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_stbmvUN +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ssyr2U +( + const int, + const float *, + const float *, + float *, const int +); + +void ATL_ssyr2L +( + const int, + const float *, + const float *, + float *, const int +); + +void ATL_sspr2U +( + const int, + const float *, + const float *, + float *, const int +); + +void ATL_sspr2L +( + const int, + const float *, + const float *, + float *, const int +); + +void ATL_ssyrU +( + const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_ssyrL +( + const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_ssprU +( + const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_ssprL +( + const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_ssymvU +( + const int, + const float *, const int, + const float *, + const float, + float * +); + +void ATL_ssymvL +( + const int, + const float *, const int, + const float *, + const float, + float * +); + +void ATL_sspmvU +( + const int, + const float *, const int, + const float *, + const float, + float * +); + +void ATL_sspmvL +( + const int, + const float *, const int, + const float *, + const float, + float * +); + +void ATL_ssbmvU +( + const int, const int, + const float *, const int, + const float *, + const float, + float * +); + +void ATL_ssbmvL +( + const int, const int, + const float *, const int, + const float *, + const float, + float * +); + +void ATL_sgpmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgprU +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_sgprL +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_sgpr +( + const enum ATLAS_UPLO, + const int, const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_sgpr1U_a1_x1_yX +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_sgpr1L_a1_x1_yX +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_sgpmvUT_a1_x1_bX_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvUN_a1_x1_bX_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvUT_a1_x1_b1_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvUN_a1_x1_b1_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvUT_a1_x1_b0_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvUN_a1_x1_b0_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvLT_a1_x1_bX_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvLN_a1_x1_bX_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvLT_a1_x1_b1_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvLN_a1_x1_b1_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvLT_a1_x1_b0_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgpmvLN_a1_x1_b0_y1 +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgbmvT_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgbmvN_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgbmvT_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgbmvN_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgbmvT_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sgbmvN_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_dtrsvLTU +( + const int, + const double *, const int, + double * +); + +void ATL_dtrsvLNU +( + const int, + const double *, const int, + double * +); + +void ATL_dtrsvLTN +( + const int, + const double *, const int, + double * +); + +void ATL_dtrsvLNN +( + const int, + const double *, const int, + double * +); + +void ATL_dtrsvUTU +( + const int, + const double *, const int, + double * +); + +void ATL_dtrsvUNU +( + const int, + const double *, const int, + double * +); + +void ATL_dtrsvUTN +( + const int, + const double *, const int, + double * +); + +void ATL_dtrsvUNN +( + const int, + const double *, const int, + double * +); + +void ATL_dtrsvLT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtrsvLN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtrsvUT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtrsvUN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtpsvLTU +( + const int, + const double *, const int, + double * +); + +void ATL_dtpsvLNU +( + const int, + const double *, const int, + double * +); + +void ATL_dtpsvLTN +( + const int, + const double *, const int, + double * +); + +void ATL_dtpsvLNN +( + const int, + const double *, const int, + double * +); + +void ATL_dtpsvUTU +( + const int, + const double *, const int, + double * +); + +void ATL_dtpsvUNU +( + const int, + const double *, const int, + double * +); + +void ATL_dtpsvUTN +( + const int, + const double *, const int, + double * +); + +void ATL_dtpsvUNN +( + const int, + const double *, const int, + double * +); + +void ATL_dtpsvLT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtpsvLN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtpsvUT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtpsvUN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtbsvLTU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvLNU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvLTN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvLNN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvUTU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvUNU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvUTN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvUNN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvLT +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvLN +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvUT +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbsvUN +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_dtrmvLTU +( + const int, + const double *, const int, + double * +); + +void ATL_dtrmvLNU +( + const int, + const double *, const int, + double * +); + +void ATL_dtrmvLTN +( + const int, + const double *, const int, + double * +); + +void ATL_dtrmvLNN +( + const int, + const double *, const int, + double * +); + +void ATL_dtrmvUTU +( + const int, + const double *, const int, + double * +); + +void ATL_dtrmvUNU +( + const int, + const double *, const int, + double * +); + +void ATL_dtrmvUTN +( + const int, + const double *, const int, + double * +); + +void ATL_dtrmvUNN +( + const int, + const double *, const int, + double * +); + +void ATL_dtrmvLT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtrmvLN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtrmvUT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtrmvUN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtpmvLTU +( + const int, + const double *, const int, + double * +); + +void ATL_dtpmvLNU +( + const int, + const double *, const int, + double * +); + +void ATL_dtpmvLTN +( + const int, + const double *, const int, + double * +); + +void ATL_dtpmvLNN +( + const int, + const double *, const int, + double * +); + +void ATL_dtpmvUTU +( + const int, + const double *, const int, + double * +); + +void ATL_dtpmvUNU +( + const int, + const double *, const int, + double * +); + +void ATL_dtpmvUTN +( + const int, + const double *, const int, + double * +); + +void ATL_dtpmvUNN +( + const int, + const double *, const int, + double * +); + +void ATL_dtpmvLT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtpmvLN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtpmvUT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtpmvUN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_dtbmvLTU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvLNU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvLTN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvLNN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvUTU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvUNU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvUTN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvUNN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvLT +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvLN +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvUT +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_dtbmvUN +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_dsyr2U +( + const int, + const double *, + const double *, + double *, const int +); + +void ATL_dsyr2L +( + const int, + const double *, + const double *, + double *, const int +); + +void ATL_dspr2U +( + const int, + const double *, + const double *, + double *, const int +); + +void ATL_dspr2L +( + const int, + const double *, + const double *, + double *, const int +); + +void ATL_dsyrU +( + const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_dsyrL +( + const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_dsprU +( + const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_dsprL +( + const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_dsymvU +( + const int, + const double *, const int, + const double *, + const double, + double * +); + +void ATL_dsymvL +( + const int, + const double *, const int, + const double *, + const double, + double * +); + +void ATL_dspmvU +( + const int, + const double *, const int, + const double *, + const double, + double * +); + +void ATL_dspmvL +( + const int, + const double *, const int, + const double *, + const double, + double * +); + +void ATL_dsbmvU +( + const int, const int, + const double *, const int, + const double *, + const double, + double * +); + +void ATL_dsbmvL +( + const int, const int, + const double *, const int, + const double *, + const double, + double * +); + +void ATL_dgpmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgprU +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_dgprL +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_dgpr +( + const enum ATLAS_UPLO, + const int, const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_dgpr1U_a1_x1_yX +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_dgpr1L_a1_x1_yX +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_dgpmvUT_a1_x1_bX_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvUN_a1_x1_bX_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvUT_a1_x1_b1_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvUN_a1_x1_b1_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvUT_a1_x1_b0_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvUN_a1_x1_b0_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvLT_a1_x1_bX_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvLN_a1_x1_bX_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvLT_a1_x1_b1_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvLN_a1_x1_b1_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvLT_a1_x1_b0_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgpmvLN_a1_x1_b0_y1 +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgbmvT_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgbmvN_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgbmvT_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgbmvN_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgbmvT_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dgbmvN_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_ctrsvLHU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLCU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLTU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLNU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLHN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLCN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLTN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLNN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUHU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUCU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUTU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUNU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUHN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUCN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUTN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUNN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLH +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLC +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrsvLN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUH +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUC +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrsvUN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLHU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLCU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLTU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLNU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLHN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLCN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLTN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLNN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUHU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUCU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUTU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUNU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUHN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUCN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUTN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUNN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLH +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLC +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpsvLN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUH +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUC +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpsvUN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctbsvLHU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLCU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLTU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLNU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLHN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLCN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLTN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLNN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUHU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUCU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUTU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUNU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUHN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUCN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUTN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUNN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLH +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLC +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLT +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvLN +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUH +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUC +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUT +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbsvUN +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctrmvLHU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLCU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLTU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLNU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLHN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLCN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLTN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLNN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUHU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUCU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUTU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUNU +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUHN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUCN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUTN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUNN +( + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLH +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLC +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrmvLN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUH +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUC +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctrmvUN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLHU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLCU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLTU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLNU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLHN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLCN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLTN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLNN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUHU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUCU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUTU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUNU +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUHN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUCN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUTN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUNN +( + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLH +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLC +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpmvLN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUH +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUC +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUT +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctpmvUN +( + const enum ATLAS_DIAG, + const int, + const float *, const int, + float * +); + +void ATL_ctbmvLHU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLCU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLTU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLNU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLHN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLCN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLTN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLNN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUHU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUCU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUTU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUNU +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUHN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUCN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUTN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUNN +( + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLH +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLC +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLT +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvLN +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUH +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUC +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUT +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_ctbmvUN +( + const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float * +); + +void ATL_cher2U +( + const int, + const float *, + const float *, + float *, const int +); + +void ATL_cher2L +( + const int, + const float *, + const float *, + float *, const int +); + +void ATL_chpr2U +( + const int, + const float *, + const float *, + float *, const int +); + +void ATL_chpr2L +( + const int, + const float *, + const float *, + float *, const int +); + +void ATL_cherU +( + const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_cherL +( + const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_chprU +( + const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_chprL +( + const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_chemvU +( + const int, + const float *, const int, + const float *, + const float *, + float * +); + +void ATL_chemvL +( + const int, + const float *, const int, + const float *, + const float *, + float * +); + +void ATL_chpmvU +( + const int, + const float *, const int, + const float *, + const float *, + float * +); + +void ATL_chpmvL +( + const int, + const float *, const int, + const float *, + const float *, + float * +); + +void ATL_chbmvU +( + const int, const int, + const float *, const int, + const float *, + const float *, + float * +); + +void ATL_chbmvL +( + const int, const int, + const float *, const int, + const float *, + const float *, + float * +); + +void ATL_cgpmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpruU +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_cgpruL +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_cgpru +( + const enum ATLAS_UPLO, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_cgprcU +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_cgprcL +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_cgprc +( + const enum ATLAS_UPLO, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_cgpr1uU_a1_x1_yX +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_cgpr1uL_a1_x1_yX +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_cgpr1cU_a1_x1_yX +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_cgpr1cL_a1_x1_yX +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_cgpmvUNc_a1_x1_bX_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUC_a1_x1_bX_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUT_a1_x1_bX_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUN_a1_x1_bX_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUNc_a1_x1_b1_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUC_a1_x1_b1_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUT_a1_x1_b1_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUN_a1_x1_b1_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUNc_a1_x1_bXi0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUC_a1_x1_bXi0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUT_a1_x1_bXi0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUN_a1_x1_bXi0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUNc_a1_x1_b0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUC_a1_x1_b0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUT_a1_x1_b0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvUN_a1_x1_b0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLNc_a1_x1_bX_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLC_a1_x1_bX_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLT_a1_x1_bX_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLN_a1_x1_bX_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLNc_a1_x1_b1_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLC_a1_x1_b1_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLT_a1_x1_b1_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLN_a1_x1_b1_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLNc_a1_x1_bXi0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLC_a1_x1_bXi0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLT_a1_x1_bXi0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLN_a1_x1_bXi0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLNc_a1_x1_b0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLC_a1_x1_b0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLT_a1_x1_b0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgpmvLN_a1_x1_b0_y1 +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvNc_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvC_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvT_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvN_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvNc_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvC_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvT_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvN_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvNc_a1_x1_bXi0_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvC_a1_x1_bXi0_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvT_a1_x1_bXi0_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvN_a1_x1_bXi0_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvNc_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvC_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvT_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_cgbmvN_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_ztrsvLHU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLCU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLTU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLNU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLHN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLCN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLTN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLNN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUHU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUCU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUTU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUNU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUHN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUCN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUTN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUNN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLH +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLC +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrsvLN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUH +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUC +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrsvUN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLHU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLCU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLTU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLNU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLHN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLCN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLTN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLNN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUHU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUCU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUTU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUNU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUHN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUCN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUTN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUNN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLH +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLC +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpsvLN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUH +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUC +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpsvUN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztbsvLHU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLCU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLTU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLNU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLHN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLCN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLTN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLNN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUHU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUCU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUTU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUNU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUHN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUCN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUTN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUNN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLH +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLC +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLT +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvLN +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUH +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUC +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUT +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbsvUN +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztrmvLHU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLCU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLTU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLNU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLHN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLCN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLTN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLNN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUHU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUCU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUTU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUNU +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUHN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUCN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUTN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUNN +( + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLH +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLC +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrmvLN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUH +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUC +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztrmvUN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLHU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLCU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLTU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLNU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLHN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLCN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLTN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLNN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUHU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUCU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUTU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUNU +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUHN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUCN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUTN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUNN +( + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLH +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLC +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpmvLN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUH +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUC +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUT +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztpmvUN +( + const enum ATLAS_DIAG, + const int, + const double *, const int, + double * +); + +void ATL_ztbmvLHU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLCU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLTU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLNU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLHN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLCN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLTN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLNN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUHU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUCU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUTU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUNU +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUHN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUCN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUTN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUNN +( + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLH +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLC +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLT +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvLN +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUH +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUC +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUT +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_ztbmvUN +( + const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double * +); + +void ATL_zher2U +( + const int, + const double *, + const double *, + double *, const int +); + +void ATL_zher2L +( + const int, + const double *, + const double *, + double *, const int +); + +void ATL_zhpr2U +( + const int, + const double *, + const double *, + double *, const int +); + +void ATL_zhpr2L +( + const int, + const double *, + const double *, + double *, const int +); + +void ATL_zherU +( + const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zherL +( + const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zhprU +( + const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zhprL +( + const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zhemvU +( + const int, + const double *, const int, + const double *, + const double *, + double * +); + +void ATL_zhemvL +( + const int, + const double *, const int, + const double *, + const double *, + double * +); + +void ATL_zhpmvU +( + const int, + const double *, const int, + const double *, + const double *, + double * +); + +void ATL_zhpmvL +( + const int, + const double *, const int, + const double *, + const double *, + double * +); + +void ATL_zhbmvU +( + const int, const int, + const double *, const int, + const double *, + const double *, + double * +); + +void ATL_zhbmvL +( + const int, const int, + const double *, const int, + const double *, + const double *, + double * +); + +void ATL_zgpmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpruU +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zgpruL +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zgpru +( + const enum ATLAS_UPLO, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zgprcU +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zgprcL +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zgprc +( + const enum ATLAS_UPLO, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zgpr1uU_a1_x1_yX +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zgpr1uL_a1_x1_yX +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zgpr1cU_a1_x1_yX +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zgpr1cL_a1_x1_yX +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zgpmvUNc_a1_x1_bX_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUC_a1_x1_bX_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUT_a1_x1_bX_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUN_a1_x1_bX_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUNc_a1_x1_b1_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUC_a1_x1_b1_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUT_a1_x1_b1_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUN_a1_x1_b1_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUNc_a1_x1_bXi0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUC_a1_x1_bXi0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUT_a1_x1_bXi0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUN_a1_x1_bXi0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUNc_a1_x1_b0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUC_a1_x1_b0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUT_a1_x1_b0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvUN_a1_x1_b0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLNc_a1_x1_bX_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLC_a1_x1_bX_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLT_a1_x1_bX_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLN_a1_x1_bX_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLNc_a1_x1_b1_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLC_a1_x1_b1_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLT_a1_x1_b1_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLN_a1_x1_b1_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLNc_a1_x1_bXi0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLC_a1_x1_bXi0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLT_a1_x1_bXi0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLN_a1_x1_bXi0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLNc_a1_x1_b0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLC_a1_x1_b0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLT_a1_x1_b0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgpmvLN_a1_x1_b0_y1 +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvNc_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvC_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvT_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvN_a1_x1_bX_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvNc_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvC_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvT_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvN_a1_x1_b1_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvNc_a1_x1_bXi0_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvC_a1_x1_bXi0_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvT_a1_x1_bXi0_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvN_a1_x1_bXi0_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvNc_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvC_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvT_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zgbmvN_a1_x1_b0_y1 +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + + +#endif +/* + * End of atlas_kernel2.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_kernel3.h b/kaldi_io/src/tools/ATLAS/include/atlas_kernel3.h new file mode 100644 index 0000000..a929c2d --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_kernel3.h @@ -0,0 +1,1393 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ +#ifndef ATLAS_KERNEL3_H +#define ATLAS_KERNEL3_H + +/* + * Real level 3 kernels + */ +void ATL_ssymmRU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_ssymmLU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_ssymmRL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_ssymmLL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_strsmLLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmLLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmLLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmLLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmLLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmLLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmLLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmLLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmLUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmLUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmLUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmLUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmLUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmLUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmLUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmLUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmRLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmRLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmRLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmRLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmRLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmRLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmRLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmRLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmRUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmRUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmRUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmRUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmRUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmRUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strsmRUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_strmmRUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ssyrkLT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_ssyrkUT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_ssyrkLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_ssyrkUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +int ATL_ssyr2kLT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_ssyr2kUT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_ssyr2kLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_ssyr2kUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +void ATL_dsymmRU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_dsymmLU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_dsymmRL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_dsymmLL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_dtrsmLLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmLLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmLLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmLLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmLLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmLLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmLLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmLLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmLUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmLUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmLUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmLUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmLUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmLUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmLUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmLUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmRLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmRLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmRLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmRLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmRLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmRLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmRLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmRLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmRUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmRUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmRUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmRUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmRUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmRUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrsmRUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dtrmmRUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_dsyrkLT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_dsyrkUT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_dsyrkLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_dsyrkUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +int ATL_dsyr2kLT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_dsyr2kUT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_dsyr2kLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_dsyr2kUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); + +/* + * Complex level 3 kernels + */ +void ATL_chemmRU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_chemmLU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_chemmRL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_chemmLL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_csymmRU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_csymmLU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_csymmRL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_csymmLL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_ctrsmLLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLLCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLLCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLLCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLLCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLUCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLUCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmLUCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmLUCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRLCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRLCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRLCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRLCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRUCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRUCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrsmRUCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ctrmmRUCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_cherkLC + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_cherkUC + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_cherkLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_cherkUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_csyrkLT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_csyrkUT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_csyrkLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_csyrkUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +int ATL_cher2kLC + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_cher2kUC + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_cher2kLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_cher2kUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_csyr2kLT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_csyr2kUT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_csyr2kLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_csyr2kUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +void ATL_zhemmRU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_zhemmLU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_zhemmRL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_zhemmLL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_zsymmRU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_zsymmLU + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_zsymmRL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_zsymmLL + (const int M, const int N, const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, void *C, const int ldc); +void ATL_ztrsmLLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLLCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLLCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLLCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLLCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLUCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLUCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmLUCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmLUCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRLTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRLTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRLNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRLNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRLCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRLCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRLCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRLCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRUTN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRUTU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRUNN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRUNU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRUCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRUCN + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrsmRUCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_ztrmmRUCU + (const int M, const int N, const void *valpha, const void *A, const int lda, + void *C, const int ldc); +void ATL_zherkLC + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_zherkUC + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_zherkLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_zherkUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_zsyrkLT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_zsyrkUT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_zsyrkLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +void ATL_zsyrkUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *vbeta, void *C, const int ldc); +int ATL_zher2kLC + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_zher2kUC + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_zher2kLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_zher2kUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_zsyr2kLT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_zsyr2kUT + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_zsyr2kLN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); +int ATL_zsyr2kUN + (const int N, const int K, const void *valpha, const void *A, const int lda, + const void *B, const int ldb, const void *vbeta, void *C, const int ldc); + +/* + * Real level 3 kernel auxiliaries + */ +void ATL_ssycopyU_a0 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_ssycopyL_a0 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2L_N_a0 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2L_U_a0 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2U_N_a0 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2U_U_a0 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2L_N_a0 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2L_U_a0 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2U_N_a0 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2U_U_a0 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_ssycopyU_a1 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_ssycopyL_a1 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2L_N_a1 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2L_U_a1 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2U_N_a1 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2U_U_a1 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2L_N_a1 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2L_U_a1 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2U_N_a1 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2U_U_a1 + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_ssycopyU_aX + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_ssycopyL_aX + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2L_N_aX + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2L_U_aX + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2U_N_aX + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyU2U_U_aX + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2L_N_aX + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2L_U_aX + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2U_N_aX + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strcopyL2U_U_aX + (const int N, const float alpha, const float *A, const int lda, float *C); +void ATL_strinvertUU(const int N, float *A, const int lda); +void ATL_strinvertLU(const int N, float *A, const int lda); +void ATL_strinvertUN(const int N, float *A, const int lda); +void ATL_strinvertLN(const int N, float *A, const int lda); +void ATL_ssyr2k_putU_bX + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_ssyr2k_putL_bX + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_strputU_bX + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_strputL_bX + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_ssyr2k_putU_b1 + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_ssyr2k_putL_b1 + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_strputU_b1 + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_strputL_b1 + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_ssyr2k_putU_b0 + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_ssyr2k_putL_b0 + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_strputU_b0 + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_strputL_b0 + (const int N, const float *v, const float beta, float *A, const int lda); +void ATL_strsmKLLTN + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKLLTU + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKLLNN + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKLLNU + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKLUTN + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKLUTU + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKLUNN + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKLUNU + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKRLTN + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKRLTU + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKRLNN + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKRLNU + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKRUTN + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKRUTU + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKRUNN + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_strsmKRUNU + (const int M, const int N, const float alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_dsycopyU_a0 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dsycopyL_a0 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2L_N_a0 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2L_U_a0 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2U_N_a0 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2U_U_a0 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2L_N_a0 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2L_U_a0 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2U_N_a0 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2U_U_a0 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dsycopyU_a1 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dsycopyL_a1 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2L_N_a1 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2L_U_a1 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2U_N_a1 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2U_U_a1 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2L_N_a1 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2L_U_a1 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2U_N_a1 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2U_U_a1 + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dsycopyU_aX + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dsycopyL_aX + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2L_N_aX + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2L_U_aX + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2U_N_aX + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyU2U_U_aX + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2L_N_aX + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2L_U_aX + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2U_N_aX + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrcopyL2U_U_aX + (const int N, const double alpha, const double *A, const int lda, double *C); +void ATL_dtrinvertUU(const int N, double *A, const int lda); +void ATL_dtrinvertLU(const int N, double *A, const int lda); +void ATL_dtrinvertUN(const int N, double *A, const int lda); +void ATL_dtrinvertLN(const int N, double *A, const int lda); +void ATL_dsyr2k_putU_bX + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dsyr2k_putL_bX + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dtrputU_bX + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dtrputL_bX + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dsyr2k_putU_b1 + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dsyr2k_putL_b1 + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dtrputU_b1 + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dtrputL_b1 + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dsyr2k_putU_b0 + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dsyr2k_putL_b0 + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dtrputU_b0 + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dtrputL_b0 + (const int N, const double *v, const double beta, double *A, const int lda); +void ATL_dtrsmKLLTN + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKLLTU + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKLLNN + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKLLNU + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKLUTN + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKLUTU + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKLUNN + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKLUNU + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKRLTN + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKRLTU + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKRLNN + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKRLNU + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKRUTN + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKRUTU + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKRUNN + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_dtrsmKRUNU + (const int M, const int N, const double alpha, const double *A, + const int lda, double *C, const int ldc); + +/* + * Complex level 3 kernel auxiliaries + */ +void ATL_cCtrsmKL + (enum ATLAS_UPLO Uplo, enum ATLAS_TRANS Trans, enum ATLAS_DIAG Diag, + const int M, const int N, const float *alpha, const float *A, + const int lda, float *B, const int ldb); +void ATL_checopy + (const int N, const float *A, const int lda, float *C); +void ATL_csycopy + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyU2L_N + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyU2Lc_N + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyU2L_U + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyU2Lc_U + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyU2U_N + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyU2Uc_N + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyU2U_U + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyU2Uc_U + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyL2L_N + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyL2Lc_N + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyL2L_U + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyL2Lc_U + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyL2U_N + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyL2Uc_N + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyL2U_U + (const int N, const float *A, const int lda, float *C); +void ATL_ctrcopyL2Uc_U + (const int N, const float *A, const int lda, float *C); +void ATL_ctrmv_scalLNU_an1 + (const int N, const float *alpha, const float *A, const int lda, float *X); +void ATL_ctrmv_scalLNN_aX + (const int N, const float *alpha, const float *A, const int lda, float *X); +void ATL_ctrmv_scalUNU_an1 + (const int N, const float *alpha, const float *A, const int lda, float *X); +void ATL_ctrmv_scalUNN_aX + (const int N, const float *alpha, const float *A, const int lda, float *X); +void ATL_ctrinvertUU(const int N, float *A, const int lda); +void ATL_ctrinvertLU(const int N, float *A, const int lda); +void ATL_ctrinvertUN(const int N, float *A, const int lda); +void ATL_ctrinvertLN(const int N, float *A, const int lda); +void ATL_ctrputU_b0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_ctrputL_b0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_csyr2k_putU_b0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_csyr2k_putL_b0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_ctrputU_b1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_ctrputL_b1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_csyr2k_putU_b1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_csyr2k_putL_b1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_ctrputU_bX + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_ctrputL_bX + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_csyr2k_putU_bX + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_csyr2k_putL_bX + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_ctrputU_bXi0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_ctrputL_bXi0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_csyr2k_putU_bXi0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_csyr2k_putL_bXi0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_ctrputU_bn1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_ctrputL_bn1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_csyr2k_putU_bn1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_csyr2k_putL_bn1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cher2k_putU_b0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cher2k_putL_b0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cheputU_b0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cheputL_b0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cher2k_putU_b1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cher2k_putL_b1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cheputU_b1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cheputL_b1 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cher2k_putU_bXi0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cher2k_putL_bXi0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cheputU_bXi0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_cheputL_bXi0 + (const int N, const float *v, const float *beta, float *A, const int lda); +void ATL_ctrsm0LLTN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LLTU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LLNN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LLNU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LLCN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LLCU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LUTN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LUTU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LUNN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LUNU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LUCN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0LUCU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RLTN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RLTU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RLNN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RLNU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RLCN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RLCU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RUTN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RUTU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RUNN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RUNU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RUCN + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_ctrsm0RUCU + (const int M, const int N, const float *alpha, const float *A, + const int lda, float *C, const int ldc); +void ATL_zCtrsmKL + (enum ATLAS_UPLO Uplo, enum ATLAS_TRANS Trans, enum ATLAS_DIAG Diag, + const int M, const int N, const double *alpha, const double *A, + const int lda, double *B, const int ldb); +void ATL_zhecopy + (const int N, const double *A, const int lda, double *C); +void ATL_zsycopy + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyU2L_N + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyU2Lc_N + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyU2L_U + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyU2Lc_U + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyU2U_N + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyU2Uc_N + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyU2U_U + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyU2Uc_U + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyL2L_N + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyL2Lc_N + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyL2L_U + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyL2Lc_U + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyL2U_N + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyL2Uc_N + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyL2U_U + (const int N, const double *A, const int lda, double *C); +void ATL_ztrcopyL2Uc_U + (const int N, const double *A, const int lda, double *C); +void ATL_ztrmv_scalLNU_an1 + (const int N, const double *alpha, const double *A, const int lda, double *X); +void ATL_ztrmv_scalLNN_aX + (const int N, const double *alpha, const double *A, const int lda, double *X); +void ATL_ztrmv_scalUNU_an1 + (const int N, const double *alpha, const double *A, const int lda, double *X); +void ATL_ztrmv_scalUNN_aX + (const int N, const double *alpha, const double *A, const int lda, double *X); +void ATL_ztrinvertUU(const int N, double *A, const int lda); +void ATL_ztrinvertLU(const int N, double *A, const int lda); +void ATL_ztrinvertUN(const int N, double *A, const int lda); +void ATL_ztrinvertLN(const int N, double *A, const int lda); +void ATL_ztrputU_b0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_ztrputL_b0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zsyr2k_putU_b0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zsyr2k_putL_b0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_ztrputU_b1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_ztrputL_b1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zsyr2k_putU_b1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zsyr2k_putL_b1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_ztrputU_bX + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_ztrputL_bX + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zsyr2k_putU_bX + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zsyr2k_putL_bX + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_ztrputU_bXi0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_ztrputL_bXi0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zsyr2k_putU_bXi0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zsyr2k_putL_bXi0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_ztrputU_bn1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_ztrputL_bn1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zsyr2k_putU_bn1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zsyr2k_putL_bn1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zher2k_putU_b0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zher2k_putL_b0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zheputU_b0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zheputL_b0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zher2k_putU_b1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zher2k_putL_b1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zheputU_b1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zheputL_b1 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zher2k_putU_bXi0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zher2k_putL_bXi0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zheputU_bXi0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_zheputL_bXi0 + (const int N, const double *v, const double *beta, double *A, const int lda); +void ATL_ztrsm0LLTN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LLTU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LLNN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LLNU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LLCN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LLCU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LUTN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LUTU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LUNN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LUNU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LUCN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0LUCU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RLTN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RLTU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RLNN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RLNU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RLCN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RLCU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RUTN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RUTU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RUNN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RUNU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RUCN + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); +void ATL_ztrsm0RUCU + (const int M, const int N, const double *alpha, const double *A, + const int lda, double *C, const int ldc); + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_lapack.h b/kaldi_io/src/tools/ATLAS/include/atlas_lapack.h new file mode 100644 index 0000000..4b370b8 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_lapack.h @@ -0,0 +1,239 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ +#ifndef ATLAS_LAPACK_H + #define ATLAS_LAPACK_H + +#include "atlas_misc.h" +#include "cblas.h" + +#ifdef PATL + +#include "atlas_cblastypealias.h" +/* + * predefined type macro names + */ +#define ATL_getriR Mjoin(PATL,getriR) +#define ATL_getriC Mjoin(PATL,getriC) +#define ATL_getri Mjoin(PATL,getri) +#define ATL_lauumRL Mjoin(PATL,lauumRL) +#define ATL_lauumRU Mjoin(PATL,lauumRU) +#define ATL_lauumCL Mjoin(PATL,lauumCL) +#define ATL_lauumCU Mjoin(PATL,lauumCU) +#define ATL_lauum Mjoin(PATL,lauum) +#define ATL_trtriRL Mjoin(PATL,trtriRL) +#define ATL_trtriRU Mjoin(PATL,trtriRU) +#define ATL_trtriCL Mjoin(PATL,trtriCL) +#define ATL_trtriCU Mjoin(PATL,trtriCU) +#define ATL_trtri Mjoin(PATL,trtri) +#define ATL_potrfU Mjoin(PATL,potrfU) +#define ATL_potrfL Mjoin(PATL,potrfL) +#define ATL_potrs Mjoin(PATL,potrs) +#define ATL_potrf Mjoin(PATL,potrf) +#define ATL_getrfR Mjoin(PATL,getrfR) +#define ATL_getrfC Mjoin(PATL,getrfC) +#define ATL_getrs Mjoin(PATL,getrs) +#define ATL_getrf Mjoin(PATL,getrf) +#define ATL_laswp Mjoin(PATL,laswp) + +#endif + +int ATL_sgetri(const enum CBLAS_ORDER Order, const int N, TYPE *A, const int lda, + const int *ipiv, TYPE *wrk, int *lwrk); +int ATL_sgetriR(const int N, TYPE *A, const int lda, const int *ipiv, + TYPE *wrk, const int lwrk); +int ATL_sgetriC(const int N, TYPE *A, const int lda, const int *ipiv, + TYPE *wrk, const int lwrk); +void ATL_slauum(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, float *A, const int lda); +int ATL_spotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, float *A, const int lda); +void ATL_spotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const float *A, const int lda, + float *B, const int ldb); +int ATL_sgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + float *A, const int lda, int *ipiv); +void ATL_sgetrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const float *A, const int lda, + const int *ipiv, float *B, const int ldb); +void ATL_slaswp(const int N, float *A, const int lda0, const int K1, + const int K2, const int *ipiv, const int inci); +int ATL_sgetrfC(const int M, const int N, float *A, const int lda, + int *ipiv); +int ATL_sgetrfR(const int M, const int N, float *A, const int lda, + int *ipiv); +void ATL_slauumRU(const int N, float *A, const int lda); +void ATL_slauumRL(const int N, float *A, const int lda); +void ATL_slauumCU(const int N, float *A, const int lda); +void ATL_slauumCL(const int N, float *A, const int lda); +int ATL_spotrfU(const int N, float *A, const int lda); +int ATL_spotrfL(const int N, float *A, const int lda); +int ATL_strtri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_DIAG Diag, const int N, + float *A, const int lda); +int ATL_strtriRU(const enum CBLAS_DIAG Diag, const int N, float *A, + const int lda); +int ATL_strtriRL(const enum CBLAS_DIAG Diag, const int N, float *A, + const int lda); +int ATL_strtriCU(const enum CBLAS_DIAG Diag, const int N, float *A, + const int lda); +int ATL_strtriCL(const enum CBLAS_DIAG Diag, const int N, float *A, + const int lda); + +int ATL_dgetri(const enum CBLAS_ORDER Order, const int N, TYPE *A, const int lda, + const int *ipiv, TYPE *wrk, int *lwrk); +int ATL_dgetriR(const int N, TYPE *A, const int lda, const int *ipiv, + TYPE *wrk, const int lwrk); +int ATL_dgetriC(const int N, TYPE *A, const int lda, const int *ipiv, + TYPE *wrk, const int lwrk); +void ATL_dlauum(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, double *A, const int lda); +int ATL_dpotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, double *A, const int lda); +void ATL_dpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const double *A, const int lda, + double *B, const int ldb); +int ATL_dgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + double *A, const int lda, int *ipiv); +void ATL_dgetrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const double *A, const int lda, + const int *ipiv, double *B, const int ldb); +void ATL_dlaswp(const int N, double *A, const int lda0, const int K1, + const int K2, const int *ipiv, const int inci); +int ATL_dgetrfC(const int M, const int N, double *A, const int lda, + int *ipiv); +int ATL_dgetrfR(const int M, const int N, double *A, const int lda, + int *ipiv); +void ATL_dlauumRU(const int N, double *A, const int lda); +void ATL_dlauumRL(const int N, double *A, const int lda); +void ATL_dlauumCU(const int N, double *A, const int lda); +void ATL_dlauumCL(const int N, double *A, const int lda); +int ATL_dpotrfU(const int N, double *A, const int lda); +int ATL_dpotrfL(const int N, double *A, const int lda); +int ATL_dtrtri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_DIAG Diag, const int N, + double *A, const int lda); +int ATL_dtrtriRU(const enum CBLAS_DIAG Diag, const int N, double *A, + const int lda); +int ATL_dtrtriRL(const enum CBLAS_DIAG Diag, const int N, double *A, + const int lda); +int ATL_dtrtriCU(const enum CBLAS_DIAG Diag, const int N, double *A, + const int lda); +int ATL_dtrtriCL(const enum CBLAS_DIAG Diag, const int N, double *A, + const int lda); + +int ATL_cgetri(const enum CBLAS_ORDER Order, const int N, TYPE *A, const int lda, + const int *ipiv, TYPE *wrk, int *lwrk); +int ATL_cgetriR(const int N, TYPE *A, const int lda, const int *ipiv, + TYPE *wrk, const int lwrk); +int ATL_cgetriC(const int N, TYPE *A, const int lda, const int *ipiv, + TYPE *wrk, const int lwrk); +void ATL_clauum(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, float *A, const int lda); +int ATL_cpotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, float *A, const int lda); +void ATL_cpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const float *A, const int lda, + float *B, const int ldb); +int ATL_cgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + float *A, const int lda, int *ipiv); +void ATL_cgetrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const float *A, const int lda, + const int *ipiv, float *B, const int ldb); +void ATL_claswp(const int N, float *A, const int lda0, const int K1, + const int K2, const int *ipiv, const int inci); +int ATL_cgetrfC(const int M, const int N, float *A, const int lda, + int *ipiv); +int ATL_cgetrfR(const int M, const int N, float *A, const int lda, + int *ipiv); +void ATL_clauumRU(const int N, float *A, const int lda); +void ATL_clauumRL(const int N, float *A, const int lda); +void ATL_clauumCU(const int N, float *A, const int lda); +void ATL_clauumCL(const int N, float *A, const int lda); +int ATL_cpotrfRU(const int N, float *A, const int lda); +int ATL_cpotrfRL(const int N, float *A, const int lda); +int ATL_cpotrfU(const int N, float *A, const int lda); +int ATL_cpotrfL(const int N, float *A, const int lda); +int ATL_ctrtri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_DIAG Diag, const int N, + float *A, const int lda); +int ATL_ctrtriRU(const enum CBLAS_DIAG Diag, const int N, float *A, + const int lda); +int ATL_ctrtriRL(const enum CBLAS_DIAG Diag, const int N, float *A, + const int lda); +int ATL_ctrtriCU(const enum CBLAS_DIAG Diag, const int N, float *A, + const int lda); +int ATL_ctrtriCL(const enum CBLAS_DIAG Diag, const int N, float *A, + const int lda); + +int ATL_zgetri(const enum CBLAS_ORDER Order, const int N, TYPE *A, const int lda, + const int *ipiv, TYPE *wrk, int *lwrk); +int ATL_zgetriR(const int N, TYPE *A, const int lda, const int *ipiv, + TYPE *wrk, const int lwrk); +int ATL_zgetriC(const int N, TYPE *A, const int lda, const int *ipiv, + TYPE *wrk, const int lwrk); +void ATL_zlauum(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, double *A, const int lda); +int ATL_zpotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, double *A, const int lda); +void ATL_zpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const double *A, const int lda, + double *B, const int ldb); +int ATL_zgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + double *A, const int lda, int *ipiv); +void ATL_zgetrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const double *A, const int lda, + const int *ipiv, double *B, const int ldb); +void ATL_zlaswp(const int N, double *A, const int lda0, const int K1, + const int K2, const int *ipiv, const int inci); +int ATL_zgetrfC(const int M, const int N, double *A, const int lda, + int *ipiv); +int ATL_zgetrfR(const int M, const int N, double *A, const int lda, + int *ipiv); +void ATL_zlauumRU(const int N, double *A, const int lda); +void ATL_zlauumRL(const int N, double *A, const int lda); +void ATL_zlauumCU(const int N, double *A, const int lda); +void ATL_zlauumCL(const int N, double *A, const int lda); +int ATL_zpotrfRU(const int N, double *A, const int lda); +int ATL_zpotrfRL(const int N, double *A, const int lda); +int ATL_zpotrfU(const int N, double *A, const int lda); +int ATL_zpotrfL(const int N, double *A, const int lda); +int ATL_ztrtri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_DIAG Diag, const int N, + double *A, const int lda); +int ATL_ztrtriRU(const enum CBLAS_DIAG Diag, const int N, double *A, + const int lda); +int ATL_ztrtriRL(const enum CBLAS_DIAG Diag, const int N, double *A, + const int lda); +int ATL_ztrtriCU(const enum CBLAS_DIAG Diag, const int N, double *A, + const int lda); +int ATL_ztrtriCL(const enum CBLAS_DIAG Diag, const int N, double *A, + const int lda); + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_level1.h b/kaldi_io/src/tools/ATLAS/include/atlas_level1.h new file mode 100644 index 0000000..d4d61d8 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_level1.h @@ -0,0 +1,127 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +/* + * Prototypes ATLAS Level 1 functions not defined in atlas_aux.h + */ +#ifndef ATLAS_LEVEL1_H +#define ATLAS_LEVEL1_H + +/* + * Many level one blas routines actually taken care of by atlas auxiliary + */ +#include "atlas_aux.h" + +float ATL_sdsdot(const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY); +double ATL_dsdot(const int N, const float *X, const int incX, + const float *Y, const int incY); +/* + * Routines with all four types + */ +void ATL_sswap(const int N, float *X, const int incX, + float *Y, const int incY); +int ATL_isamax(const int N, const float *X, const int incX); + +void ATL_dswap(const int N, double *X, const int incX, + double *Y, const int incY); +int ATL_idamax(const int N, const double *X, const int incX); + +void ATL_cswap(const int N, float *X, const int incX, + float *Y, const int incY); +int ATL_icamax(const int N, const float *X, const int incX); + +void ATL_zswap(const int N, double *X, const int incX, + double *Y, const int incY); +int ATL_izamax(const int N, const double *X, const int incX); + +/* + * Routines with real types + */ +void ATL_srotg(float *a, float *b, float *c, float *s); +void ATL_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); +void ATL_srot(const int N, float *X, const int incX, + float *Y, const int incY, const float c, const float s); +void ATL_srotm(const int N, float *X, const int incX, + float *Y, const int incY, const float *P); +float ATL_sdot(const int N, const float *X, const int incX, + const float *Y, const int incY); +void ATL_sssq(const int N, const float *X, const int incX, + float *scal0, float *ssq0); +float ATL_snrm2(const int N, const float *X, const int incX); +float ATL_sasum(const int N, const float *X, const int incX); + +void ATL_drotg(double *a, double *b, double *c, double *s); +void ATL_drotmg(double *d1, double *d2, double *b1, const double b2, double *P); +void ATL_drot(const int N, double *X, const int incX, + double *Y, const int incY, const double c, const double s); +void ATL_drotm(const int N, double *X, const int incX, + double *Y, const int incY, const double *P); +double ATL_ddot(const int N, const double *X, const int incX, + const double *Y, const int incY); +void ATL_dssq(const int N, const double *X, const int incX, + double *scal0, double *ssq0); +double ATL_dnrm2(const int N, const double *X, const int incX); +double ATL_dasum(const int N, const double *X, const int incX); + +/* + * Routines with complex types + */ +void ATL_csrot(const int N, float *X, const int incX, + float *Y, const int incY, const float c, const float s); +void ATL_crotg(float *a, const float *b, float *c, float *s); +void ATL_cdotu_sub(const int N, const float *X, const int incX, + const float *Y, const int incY, float *dot); +void ATL_cdotc_sub(const int N, const float *X, const int incX, + const float *Y, const int incY, float *dot); +void ATL_cssq(const int N, const float *X, const int incX, + float *scal0, float *ssq0); +float ATL_scnrm2(const int N, const float *X, const int incX); +float ATL_scasum(const int N, const float *X, const int incX); + +void ATL_zdrot(const int N, double *X, const int incX, + double *Y, const int incY, const double c, const double s); +void ATL_zrotg(double *a, const double *b, double *c, double *s); +void ATL_zdotu_sub(const int N, const double *X, const int incX, + const double *Y, const int incY, double *dot); +void ATL_zdotc_sub(const int N, const double *X, const int incX, + const double *Y, const int incY, double *dot); +void ATL_zssq(const int N, const double *X, const int incX, + double *scal0, double *ssq0); +double ATL_dznrm2(const int N, const double *X, const int incX); +double ATL_dzasum(const int N, const double *X, const int incX); + + +#define ATL_casum ATL_scasum +#define ATL_zasum ATL_dzasum +#define ATL_cnrm2 ATL_scnrm2 +#define ATL_znrm2 ATL_dznrm2 + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_level2.h b/kaldi_io/src/tools/ATLAS/include/atlas_level2.h new file mode 100644 index 0000000..d05f6d5 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_level2.h @@ -0,0 +1,267 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +/* + * =========================================================================== + * Prototypes for level 2 BLAS + * =========================================================================== + */ +#ifndef ATLAS_LEVEL2_H +#define ATLAS_LEVEL2_H + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void ATL_sgemv(const enum ATLAS_TRANS TransA, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_sgbmv(const enum ATLAS_TRANS TransA, const int M, const int N, + const int KL, const int KU, const float alpha, + const float *A, const int lda, const float *X, + const int incX, const float beta, float *Y, const int incY); +void ATL_strmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const float *A, const int lda, float *X, const int incX); +void ATL_stbmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const int K, + const float *A, const int lda, float *X, const int incX); +void ATL_stpmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const float *Ap, + float *X, const int incX); +void ATL_strsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const float *A, const int lda, float *X, const int incX); +void ATL_stbsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const int K, + const float *A, const int lda, float *X, const int incX); +void ATL_stpsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const float *Ap, float *X, const int incX); + +void ATL_dgemv(const enum ATLAS_TRANS TransA, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dgbmv(const enum ATLAS_TRANS TransA, const int M, const int N, + const int KL, const int KU, const double alpha, + const double *A, const int lda, const double *X, + const int incX, const double beta, double *Y, const int incY); +void ATL_dtrmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const double *A, const int lda, double *X, const int incX); +void ATL_dtbmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const int K, + const double *A, const int lda, double *X, const int incX); +void ATL_dtpmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const double *Ap, + double *X, const int incX); +void ATL_dtrsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const double *A, const int lda, double *X, const int incX); +void ATL_dtbsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const int K, + const double *A, const int lda, double *X, const int incX); +void ATL_dtpsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const double *Ap, double *X, const int incX); + +void ATL_cgemv(const enum ATLAS_TRANS TransA, const int M, const int N, + const float *alpha, const float *A, const int lda, + const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgbmv(const enum ATLAS_TRANS TransA, const int M, const int N, + const int KL, const int KU, const float *alpha, + const float *A, const int lda, const float *X, + const int incX, const float *beta, float *Y, const int incY); +void ATL_ctrmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const float *A, const int lda, float *X, const int incX); +void ATL_ctbmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const int K, + const float *A, const int lda, float *X, const int incX); +void ATL_ctpmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const float *Ap, + float *X, const int incX); +void ATL_ctrsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const float *A, const int lda, float *X, const int incX); +void ATL_ctbsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const int K, + const float *A, const int lda, float *X, const int incX); +void ATL_ctpsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const float *Ap, float *X, const int incX); + +void ATL_zgemv(const enum ATLAS_TRANS TransA, const int M, const int N, + const double *alpha, const double *A, const int lda, + const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgbmv(const enum ATLAS_TRANS TransA, const int M, const int N, + const int KL, const int KU, const double *alpha, + const double *A, const int lda, const double *X, + const int incX, const double *beta, double *Y, const int incY); +void ATL_ztrmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const double *A, const int lda, double *X, const int incX); +void ATL_ztbmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const int K, + const double *A, const int lda, double *X, const int incX); +void ATL_ztpmv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const double *Ap, + double *X, const int incX); +void ATL_ztrsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const double *A, const int lda, double *X, const int incX); +void ATL_ztbsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, const int K, + const double *A, const int lda, double *X, const int incX); +void ATL_ztpsv(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS TransA, + const enum ATLAS_DIAG Diag, const int N, + const double *Ap, double *X, const int incX); + + +/* + * Routines with S and D prefixes only + */ +void ATL_ssymv(const enum ATLAS_UPLO Uplo, const int N, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_ssbmv(const enum ATLAS_UPLO Uplo, const int N, const int K, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_sspmv(const enum ATLAS_UPLO Uplo, const int N, const float alpha, + const float *Ap, const float *X, const int incX, + const float beta, float *Y, const int incY); +void ATL_sger(const int M, const int N, const float alpha, + const float *X, const int incX, const float *Y, const int incY, + float *A, const int lda); +void ATL_ssyr(const enum ATLAS_UPLO Uplo, const int N, const float alpha, + const float *X, const int incX, float *A, const int lda); +void ATL_sspr(const enum ATLAS_UPLO Uplo, const int N, const float alpha, + const float *X, const int incX, float *Ap); +void ATL_ssyr2(const enum ATLAS_UPLO Uplo, const int N, const float alpha, + const float *X, const int incX, const float *Y, const int incY, + float *A, const int lda); +void ATL_sspr2(const enum ATLAS_UPLO Uplo, const int N, const float alpha, + const float *X, const int incX, const float *Y, const int incY, + float *A); + +void ATL_dsymv(const enum ATLAS_UPLO Uplo, const int N, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dsbmv(const enum ATLAS_UPLO Uplo, const int N, const int K, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dspmv(const enum ATLAS_UPLO Uplo, const int N, const double alpha, + const double *Ap, const double *X, const int incX, + const double beta, double *Y, const int incY); +void ATL_dger(const int M, const int N, const double alpha, + const double *X, const int incX, const double *Y, const int incY, + double *A, const int lda); +void ATL_dsyr(const enum ATLAS_UPLO Uplo, const int N, const double alpha, + const double *X, const int incX, double *A, const int lda); +void ATL_dspr(const enum ATLAS_UPLO Uplo, const int N, const double alpha, + const double *X, const int incX, double *Ap); +void ATL_dsyr2(const enum ATLAS_UPLO Uplo, const int N, const double alpha, + const double *X, const int incX, const double *Y, const int incY, + double *A, const int lda); +void ATL_dspr2(const enum ATLAS_UPLO Uplo, const int N, const double alpha, + const double *X, const int incX, const double *Y, const int incY, + double *A); + + +/* + * Routines with C and Z prefixes only + */ +void ATL_chemv(const enum ATLAS_UPLO Uplo, const int N, + const float *alpha, const float *A, const int lda, + const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_chbmv(const enum ATLAS_UPLO Uplo, const int N, const int K, + const float *alpha, const float *A, const int lda, + const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_chpmv(const enum ATLAS_UPLO Uplo, const int N, + const float *alpha, const float *Ap, + const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgeru(const int M, const int N, const float *alpha, + const float *X, const int incX, const float *Y, const int incY, + float *A, const int lda); +void ATL_cgerc(const int M, const int N, const float *alpha, + const float *X, const int incX, const float *Y, const int incY, + float *A, const int lda); +void ATL_cher(const enum ATLAS_UPLO Uplo, const int N, const float alpha, + const float *X, const int incX, float *A, const int lda); +void ATL_chpr(const enum ATLAS_UPLO Uplo, const int N, const float alpha, + const float *X, const int incX, float *A); +void ATL_cher2(const enum ATLAS_UPLO Uplo, const int N, + const float *alpha, const float *X, const int incX, + const float *Y, const int incY, float *A, const int lda); +void ATL_chpr2(const enum ATLAS_UPLO Uplo, const int N, + const float *alpha, const float *X, const int incX, + const float *Y, const int incY, float *Ap); + +void ATL_zhemv(const enum ATLAS_UPLO Uplo, const int N, + const double *alpha, const double *A, const int lda, + const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zhbmv(const enum ATLAS_UPLO Uplo, const int N, const int K, + const double *alpha, const double *A, const int lda, + const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zhpmv(const enum ATLAS_UPLO Uplo, const int N, + const double *alpha, const double *Ap, + const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgeru(const int M, const int N, const double *alpha, + const double *X, const int incX, const double *Y, const int incY, + double *A, const int lda); +void ATL_zgerc(const int M, const int N, const double *alpha, + const double *X, const int incX, const double *Y, const int incY, + double *A, const int lda); +void ATL_zher(const enum ATLAS_UPLO Uplo, const int N, const double alpha, + const double *X, const int incX, double *A, const int lda); +void ATL_zhpr(const enum ATLAS_UPLO Uplo, const int N, const double alpha, + const double *X, const int incX, double *A); +void ATL_zher2(const enum ATLAS_UPLO Uplo, const int N, + const double *alpha, const double *X, const int incX, + const double *Y, const int incY, double *A, const int lda); +void ATL_zhpr2(const enum ATLAS_UPLO Uplo, const int N, + const double *alpha, const double *X, const int incX, + const double *Y, const int incY, double *Ap); + + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_level3.h b/kaldi_io/src/tools/ATLAS/include/atlas_level3.h new file mode 100644 index 0000000..023c63c --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_level3.h @@ -0,0 +1,181 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1997 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ +/* + * =========================================================================== + * Prototypes for level 3 BLAS + * =========================================================================== + */ +#ifndef ATLAS_LEVEL3_H +#define ATLAS_LEVEL3_H + + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +int ATL_sGetNB(void); +int ATL_sGetNCNB(void); +void ATL_sgemm(const enum ATLAS_TRANS TransA, const enum ATLAS_TRANS TransB, + const int M, const int N, const int K, const float alpha, + const float *A, const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc); +void ATL_ssymm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc); +void ATL_ssyrk(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const float alpha, + const float *A, const int lda, const float beta, + float *C, const int ldc); +void ATL_ssyr2k(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const float alpha, + const float *A, const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc); +void ATL_strmm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const enum ATLAS_TRANS TransA, const enum ATLAS_DIAG Diag, + const int M, const int N, const float alpha, + const float *A, const int lda, float *B, const int ldb); +void ATL_strsm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const enum ATLAS_TRANS TransA, const enum ATLAS_DIAG Diag, + const int M, const int N, const float alpha, + const float *A, const int lda, float *B, const int ldb); + +int ATL_dGetNB(void); +int ATL_dGetNCNB(void); +void ATL_dgemm(const enum ATLAS_TRANS TransA, const enum ATLAS_TRANS TransB, + const int M, const int N, const int K, const double alpha, + const double *A, const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc); +void ATL_dsymm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const int M, const int N, const double alpha, + const double *A, const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc); +void ATL_dsyrk(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const double alpha, + const double *A, const int lda, const double beta, + double *C, const int ldc); +void ATL_dsyr2k(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const double alpha, + const double *A, const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc); +void ATL_dtrmm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const enum ATLAS_TRANS TransA, const enum ATLAS_DIAG Diag, + const int M, const int N, const double alpha, + const double *A, const int lda, double *B, const int ldb); +void ATL_dtrsm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const enum ATLAS_TRANS TransA, const enum ATLAS_DIAG Diag, + const int M, const int N, const double alpha, + const double *A, const int lda, double *B, const int ldb); + +int ATL_cGetNB(void); +int ATL_cGetNCNB(void); +void ATL_cgemm(const enum ATLAS_TRANS TransA, const enum ATLAS_TRANS TransB, + const int M, const int N, const int K, const float *alpha, + const float *A, const int lda, const float *B, const int ldb, + const float *beta, float *C, const int ldc); +void ATL_csymm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const int M, const int N, const float *alpha, + const float *A, const int lda, const float *B, const int ldb, + const float *beta, float *C, const int ldc); +void ATL_csyrk(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const float *alpha, + const float *A, const int lda, const float *beta, + float *C, const int ldc); +void ATL_csyr2k(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const float *alpha, + const float *A, const int lda, const float *B, const int ldb, + const float *beta, float *C, const int ldc); +void ATL_ctrmm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const enum ATLAS_TRANS TransA, const enum ATLAS_DIAG Diag, + const int M, const int N, const float *alpha, + const float *A, const int lda, float *B, const int ldb); +void ATL_ctrsm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const enum ATLAS_TRANS TransA, const enum ATLAS_DIAG Diag, + const int M, const int N, const float *alpha, + const float *A, const int lda, float *B, const int ldb); + +int ATL_zGetNB(void); +int ATL_zGetNCNB(void); +void ATL_zgemm(const enum ATLAS_TRANS TransA, const enum ATLAS_TRANS TransB, + const int M, const int N, const int K, const double *alpha, + const double *A, const int lda, const double *B, const int ldb, + const double *beta, double *C, const int ldc); +void ATL_zsymm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const int M, const int N, const double *alpha, + const double *A, const int lda, const double *B, const int ldb, + const double *beta, double *C, const int ldc); +void ATL_zsyrk(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const double *alpha, + const double *A, const int lda, const double *beta, + double *C, const int ldc); +void ATL_zsyr2k(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const double *alpha, + const double *A, const int lda, const double *B, const int ldb, + const double *beta, double *C, const int ldc); +void ATL_ztrmm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const enum ATLAS_TRANS TransA, const enum ATLAS_DIAG Diag, + const int M, const int N, const double *alpha, + const double *A, const int lda, double *B, const int ldb); +void ATL_ztrsm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const enum ATLAS_TRANS TransA, const enum ATLAS_DIAG Diag, + const int M, const int N, const double *alpha, + const double *A, const int lda, double *B, const int ldb); + + +/* + * Routines with prefixes C and Z only + */ +void ATL_chemm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const int M, const int N, const float *alpha, + const float *A, const int lda, const float *B, const int ldb, + const float *beta, float *C, const int ldc); +void ATL_cherk(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const float alpha, + const float *A, const int lda, const float beta, + float *C, const int ldc); +void ATL_cher2k(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const float *alpha, + const float *A, const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc); + +void ATL_zhemm(const enum ATLAS_SIDE Side, const enum ATLAS_UPLO Uplo, + const int M, const int N, const double *alpha, + const double *A, const int lda, const double *B, const int ldb, + const double *beta, double *C, const int ldc); +void ATL_zherk(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const double alpha, + const double *A, const int lda, const double beta, + double *C, const int ldc); +void ATL_zher2k(const enum ATLAS_UPLO Uplo, const enum ATLAS_TRANS Trans, + const int N, const int K, const double *alpha, + const double *A, const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc); + + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_lvl2.h b/kaldi_io/src/tools/ATLAS/include/atlas_lvl2.h new file mode 100644 index 0000000..b09a021 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_lvl2.h @@ -0,0 +1,294 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include "atlas_level2.h" +#include "atlas_kernel2.h" +#ifndef ATLAS_LVL2_H +#define ATLAS_LVL2_H + +/* + * Real kernels + */ +void ATL_sger1_a1_x1_yX + (const int M, const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A, const int lda); +void ATL_sgemvS_a1_x1_bX_y1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_sgemvT_a1_x1_bX_y1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_sgemvN_a1_x1_bX_y1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_sgemvS_a1_x1_b1_y1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_sgemvT_a1_x1_b1_y1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_sgemvN_a1_x1_b1_y1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_sgemvS_a1_x1_b0_y1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_sgemvT_a1_x1_b0_y1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_sgemvN_a1_x1_b0_y1 + (const int M, const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, const float beta, + float *Y, const int incY); +void ATL_dger1_a1_x1_yX + (const int M, const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A, const int lda); +void ATL_dgemvS_a1_x1_bX_y1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dgemvT_a1_x1_bX_y1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dgemvN_a1_x1_bX_y1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dgemvS_a1_x1_b1_y1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dgemvT_a1_x1_b1_y1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dgemvN_a1_x1_b1_y1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dgemvS_a1_x1_b0_y1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dgemvT_a1_x1_b0_y1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, const double beta, + double *Y, const int incY); +void ATL_dgemvN_a1_x1_b0_y1 + (const int M, const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, const double beta, + double *Y, const int incY); + +/* + * Complex kernels + */ +void ATL_cger1u_a1_x1_yX + (const int M, const int N, const float *alpha, const float *X, + const int incX, const float *Y, const int incY, float *A, const int lda); +void ATL_cger1c_a1_x1_yX + (const int M, const int N, const float *alpha, const float *X, + const int incX, const float *Y, const int incY, float *A, const int lda); +void ATL_cgemvS_a1_x1_bXi0_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvC_a1_x1_bXi0_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvNc_a1_x1_bXi0_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvT_a1_x1_bXi0_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvN_a1_x1_bXi0_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvS_a1_x1_bX_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvC_a1_x1_bX_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvNc_a1_x1_bX_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvT_a1_x1_bX_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvN_a1_x1_bX_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvS_a1_x1_b1_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvC_a1_x1_b1_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvNc_a1_x1_b1_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvT_a1_x1_b1_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvN_a1_x1_b1_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvS_a1_x1_b0_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvC_a1_x1_b0_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvNc_a1_x1_b0_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvT_a1_x1_b0_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_cgemvN_a1_x1_b0_y1 + (const int M, const int N, const float *alpha, const float *A, + const int lda, const float *X, const int incX, const float *beta, + float *Y, const int incY); +void ATL_zger1u_a1_x1_yX + (const int M, const int N, const double *alpha, const double *X, + const int incX, const double *Y, const int incY, double *A, const int lda); +void ATL_zger1c_a1_x1_yX + (const int M, const int N, const double *alpha, const double *X, + const int incX, const double *Y, const int incY, double *A, const int lda); +void ATL_zgemvS_a1_x1_bXi0_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvC_a1_x1_bXi0_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvNc_a1_x1_bXi0_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvT_a1_x1_bXi0_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvN_a1_x1_bXi0_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvS_a1_x1_bX_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvC_a1_x1_bX_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvNc_a1_x1_bX_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvT_a1_x1_bX_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvN_a1_x1_bX_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvS_a1_x1_b1_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvC_a1_x1_b1_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvNc_a1_x1_b1_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvT_a1_x1_b1_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvN_a1_x1_b1_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvS_a1_x1_b0_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvC_a1_x1_b0_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvNc_a1_x1_b0_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvT_a1_x1_b0_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); +void ATL_zgemvN_a1_x1_b0_y1 + (const int M, const int N, const double *alpha, const double *A, + const int lda, const double *X, const int incX, const double *beta, + double *Y, const int incY); + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_lvl3.h b/kaldi_io/src/tools/ATLAS/include/atlas_lvl3.h new file mode 100644 index 0000000..eab93c0 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_lvl3.h @@ -0,0 +1,512 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1997 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef ATLAS_LVL3_H +#define ATLAS_LVL3_H + +#include "atlas_misc.h" +#include "atlas_f77.h" +#include "atlas_level3.h" +#if defined(SREAL) + #include "smm.h" + #include "sXover.h" +#elif defined(DREAL) + #include "dmm.h" + #include "dXover.h" +#elif defined(QREAL) + #include "qmm.h" + #include "qXover.h" +#elif defined(SCPLX) + #ifdef ATL_NCMM + #include "atlas_cNCmm.h" + #else + #include "cmm.h" + #endif + #include "cXover.h" +#elif defined(DCPLX) + #ifdef ATL_NCMM + #include "atlas_zNCmm.h" + #else + #include "zmm.h" + #endif + #include "zmm.h" + #include "zXover.h" +#endif +#ifndef ATL_3NB + #define ATL_3NB 3*NB + + #define NN_MNK_M NBNB*NB + #define NN_MNK_N NBNB*NB + #define NN_MNK_K NBNB*NB + #define NN_MNK_MN NBNB*NB + #define NN_MNK_GE NBNB*NB + + #define NT_MNK_M NBNB*NB + #define NT_MNK_N NBNB*NB + #define NT_MNK_K NBNB*NB + #define NT_MNK_MN NBNB*NB + #define NT_MNK_GE NBNB*NB + + #define TN_MNK_M NBNB*NB + #define TN_MNK_N NBNB*NB + #define TN_MNK_K NBNB*NB + #define TN_MNK_MN NBNB*NB + #define TN_MNK_GE NBNB*NB + + #define TT_MNK_M NBNB*NB + #define TT_MNK_N NBNB*NB + #define TT_MNK_K NBNB*NB + #define TT_MNK_MN NBNB*NB + #define TT_MNK_GE NBNB*NB +#endif + +#ifndef CN_MNK_M + #define CN_MNK_M TN_MNK_M + #define CN_MNK_N TN_MNK_N + #define CN_MNK_K TN_MNK_K + #define CN_MNK_MN TN_MNK_MN + #define CN_MNK_GE TN_MNK_GE +#endif +#ifndef NC_MNK_M + #define NC_MNK_M NT_MNK_M + #define NC_MNK_N NT_MNK_N + #define NC_MNK_K NT_MNK_K + #define NC_MNK_MN NT_MNK_MN + #define NC_MNK_GE NT_MNK_GE +#endif +#ifndef CT_MNK_M + #define CT_MNK_M TT_MNK_M + #define CT_MNK_N TT_MNK_N + #define CT_MNK_K TT_MNK_K + #define CT_MNK_MN TT_MNK_MN + #define CT_MNK_GE TT_MNK_GE +#endif +#ifndef TC_MNK_M + #define TC_MNK_M TT_MNK_M + #define TC_MNK_N TT_MNK_N + #define TC_MNK_K TT_MNK_K + #define TC_MNK_MN TT_MNK_MN + #define TC_MNK_GE TT_MNK_GE +#endif +#ifndef CC_MNK_M + #define CC_MNK_M TT_MNK_M + #define CC_MNK_N TT_MNK_N + #define CC_MNK_K TT_MNK_K + #define CC_MNK_MN TT_MNK_MN + #define CC_MNK_GE TT_MNK_GE +#endif + +#define CPAT Mjoin(C_ATL_, PRE); + +#ifndef ATL_MaxMalloc + #define ATL_MaxMalloc 67108864 +#endif + +typedef void (*MAT2BLK)(int, int, const TYPE*, int, TYPE*, const SCALAR); +typedef void (*MAT2BLK2)(const int, const int, const SCALAR, const TYPE*, + const int, TYPE*, const int); +typedef void (*MATSCAL)(const int, const int, const SCALAR, TYPE*, const int); +typedef void (*PUTBLK)(int, int, TYPE*, TYPE*, int, const SCALAR); +typedef void (*NBCLEANUP)(const TYPE*, const TYPE*, TYPE*, const int); +typedef int (*MMINTR)(const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const SCALAR, + const TYPE *, const int, const TYPE *, const int, + const SCALAR, TYPE *, const int); +typedef void (*NBMM0)(const int, const int, const int, const TYPE, + const TYPE*, const int, const TYPE*, const int, + const TYPE, TYPE*, const int); + +void ATL_xerbla(int p, char *rout, char *form, ...); +int Mjoin(PATL,GetNB)(void); +int Mjoin(PATL,GetNCNB)(void); + +void Mjoin(PATL, gescal_bX)(const int, const int, const SCALAR, TYPE*, + const int); +void Mjoin(PATL, gescal_bn1)(const int, const int, const SCALAR, TYPE*, + const int); +void Mjoin(PATL, gescal_b0)(const int, const int, const SCALAR, TYPE*, + const int); + +void Mjoin(PATL,pKBmm_bX)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,pNBmm_bX)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,pMBmm_bX)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,pKBmm_b1)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,pNBmm_b1)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,pMBmm_b1)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,pKBmm_b0)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,pNBmm_b0)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,pMBmm_b0)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,pKBmm)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); + +void Mjoin(PATL,MBJBmm)(const int N, const int K, const TYPE *A, const TYPE *B, + const TYPE beta, TYPE *C, const int ldc); +void Mjoin(PATL,IBJBmm)(int IB, int JB, int K, const TYPE *A, const TYPE *B, + const TYPE beta, TYPE *C, const int ldc); +void Mjoin(PATL,IBNBmm)(const int M, const int K, const TYPE *A, const TYPE *B, + const TYPE beta, TYPE *C, const int ldc); +#ifdef TCPLX + +void Mjoin(PATL,CNBmm_b0)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,CNBmm_b1)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL,CNBmm_bX)(const int M, const int N, const int K, + const TYPE alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const TYPE beta, + TYPE *C, const int ldc); +void Mjoin(PATL, gescal_bXi0)(const int, const int, const SCALAR, TYPE*, + const int); + +void Mjoin(PATL,row2blkT_aXi0) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,row2blkT2_aXi0) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blk_aXi0) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blk2_aXi0) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); + +void Mjoin(PATL,row2blkC_aX) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,row2blkC2_aX) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blkConj_aX) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blkConj2_aX) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,row2blkC_a1) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,row2blkC2_a1) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blkConj_a1) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blkConj2_a1) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,row2blkC_aXi0) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,row2blkC2_aXi0) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blkConj_aXi0) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blkConj2_aXi0) + (const int, const int, const TYPE*, const int, TYPE*, const SCALAR); + +void Mjoin(PATL,mmJIK2) + (int K, int nMb, int nNb, int nKb, int ib, int jb, int kb, + const SCALAR alpha, const TYPE *pA0, const TYPE *B, int ldb, TYPE *pB0, + int incB, MAT2BLK B2blk, const SCALAR beta, TYPE *C, int ldc, + MATSCAL gescal, NBMM0 NBmm0); + +void Mjoin(PATL,mmIJK2) + (int K, int nMb, int nNb, int nKb, int ib, int jb, int kb, + const SCALAR alpha, const TYPE *A, const int lda, TYPE *pA0, const int incA, + MAT2BLK A2blk, TYPE *pB0, const SCALAR beta, TYPE *C, int ldc, + MATSCAL gescal, NBMM0 NBmm0); + +#else /* real */ + +void Mjoin(PATL,putblk_bX)(int M, int N, TYPE *V, TYPE *C, int ldc, const SCALAR beta); +void Mjoin(PATL,putblk_bn1)(int M, int N, TYPE *V, TYPE *C, int ldc, const SCALAR beta); +void Mjoin(PATL,putblk_b1)(int M, int N, TYPE *V, TYPE *C, int ldc, const SCALAR beta); +void Mjoin(PATL,putblk_b0)(int M, int N, TYPE *V, TYPE *C, int ldc, const SCALAR beta); +void ATL_gereal2cplx(const int M, const int N, TYPE *alpha, TYPE *R, int ldr, + TYPE *I, int ldi, TYPE *beta, TYPE *C, int ldc); + +void NBmm_b1(const int M, const int N, const int K, const TYPE alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const TYPE beta, TYPE *C, const int ldc); +void NBmm_b0(const int M, const int N, const int K, const TYPE alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const TYPE beta, TYPE *C, const int ldc); +void NBmm_bX(const int M, const int N, const int K, const TYPE alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const TYPE beta, TYPE *C, const int ldc); +void Mjoin(PATL,mmJIK2)(int K, int nMb, int nNb, int nKb, int ib, int jb, + int kb, const SCALAR alpha, const TYPE *pA0, + const TYPE *B, int ldb, TYPE *pB0, int incB, + MAT2BLK B2blk, const SCALAR beta, TYPE *C, int ldc, + TYPE *pC, PUTBLK putblk, NBMM0 NBmm0); + +void Mjoin(PATL,mmIJK2)(int K, int nMb, int nNb, int nKb, int ib, int jb, + int kb, const SCALAR alpha, const TYPE *A, int lda, + TYPE *pA0, int incA, MAT2BLK A2blk, const TYPE *pB0, + const SCALAR beta, TYPE *C, int ldc, TYPE *pC, + PUTBLK putblk, NBMM0 NBmm0); + + +void Mjoin(PATL,aliased_gemm) + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,f77gemm) + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,gemm) + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,small_mm) + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,big_mm) + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +#endif + +#ifdef USERGEMM +int Mjoin(PATU,usergemm)(const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const SCALAR, + const TYPE*, const int, const TYPE*, + const int, const SCALAR, TYPE*, const int); +#endif +int Mjoin(PATL,NCmmJIK)(const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const SCALAR, + const TYPE*, const int, const TYPE*, + const int, const SCALAR, TYPE*, const int); +int Mjoin(PATL,NCmmIJK)(const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const SCALAR, + const TYPE*, const int, const TYPE*, + const int, const SCALAR, TYPE*, const int); +int Mjoin(PATL,NCmmJIK_c)(const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const SCALAR, + const TYPE*, const int, const TYPE*, + const int, const SCALAR, TYPE*, const int); +int Mjoin(PATL,NCmmIJK_c)(const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const SCALAR, + const TYPE*, const int, const TYPE*, + const int, const SCALAR, TYPE*, const int); + +void Mjoin(PATL,row2blkT2_aX)(int, int, const TYPE*, int, TYPE*, const SCALAR); +void Mjoin(PATL,row2blkT_aX)(int, int, const TYPE*, int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blk2_aX)(int, int, const TYPE*, int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blk_aX)(int, int, const TYPE*, int, TYPE*, const SCALAR); +void Mjoin(PATL,row2blkT2_an1)(int, int, const TYPE*, int, TYPE*, + const SCALAR); +void Mjoin(PATL,row2blkT_an1)(int, int, const TYPE*, int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blk2_an1)(int, int, const TYPE*, int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blk_an1)(int, int, const TYPE*, int, TYPE*, const SCALAR); +void Mjoin(PATL,row2blkT2_a1)(int, int, const TYPE*, int, TYPE*, const SCALAR); +void Mjoin(PATL,row2blkT_a1)(int, int, const TYPE*, int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blk2_a1)(int, int, const TYPE*, int, TYPE*, const SCALAR); +void Mjoin(PATL,col2blk_a1)(int, int, const TYPE*, int, TYPE*, const SCALAR); + +int Mjoin(PATL,mmJITcp)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, + const SCALAR alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const SCALAR beta, + TYPE *C, const int ldc); +int Mjoin(PATL,mmJIK)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, + const SCALAR alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const SCALAR beta, + TYPE *C, const int ldc); +int Mjoin(PATL,mmIJK)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, + const SCALAR alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const SCALAR beta, + TYPE *C, const int ldc); +int Mjoin(PATL,mmJKI)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, + const SCALAR alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const SCALAR beta, + TYPE *C, const int ldc); + +void Mjoin(PATL,mmK) + (int M, int m, int N, int n, int nblk, int kr, int KR, const SCALAR alphaA, + const SCALAR alphaB, const SCALAR beta, const TYPE *A, const int lda, + const int incA, TYPE *pA, const int incAW, const TYPE *B, const int ldb, + const int incB, TYPE *pB, const int incBW, TYPE *C, const int ldc, + MAT2BLK2 A2blk, MAT2BLK2 B2blk, NBMM0 NBmm0, NBMM0 NBmm1); + +int Mjoin(PATL,mmBPP)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, + const SCALAR alpha, const TYPE *A, const int lda, + const TYPE *B, const int ldb, const SCALAR beta, + TYPE *C, const int ldc); + + +void Mjoin(PATL,gemmTT) + (const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,aliased_gemmTT) + (const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,gemmTN) + (const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,aliased_gemmTN) + (const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,gemmNT) + (const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,aliased_gemmNT) + (const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,gemmNN) + (const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void Mjoin(PATL,aliased_gemmNN) + (const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); + + +void NCmmNNIJK_c + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmNTIJK_c + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmTNIJK_c + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmTTIJK_c + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmNNIJK + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmNTIJK + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmTNIJK + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmTTIJK + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmNNJIK_c + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmNTJIK_c + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmTNJIK_c + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmTTJIK_c + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmNNJIK + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmNTJIK + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmTNJIK + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); +void NCmmTTJIK + (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, TYPE *C, const int ldc); + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_misc.h b/kaldi_io/src/tools/ATLAS/include/atlas_misc.h new file mode 100644 index 0000000..88f754d --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_misc.h @@ -0,0 +1,416 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1997 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <stdio.h> +#include <stdlib.h> +#include "atlas_enum.h" + +#ifndef ATLAS_MISC_H +#define ATLAS_MISC_H +#include "atlas_type.h" +#ifdef ATL_PROFILE + extern int ATL_ProfGemmCameFrom; +#endif +/* + * Some useful macro functions + */ +#if (defined(PentiumCPS) || defined(ATL_USEPTHREADS)) && !defined(WALL) + #define WALL +#endif +#ifndef time00 + #if defined(WALL) + #define time00 ATL_walltime + #else + #define time00 ATL_cputime + #endif +#endif +#define Mabs(x) ( (x) >= 0 ? (x) : -(x) ) +#define Mmax(x, y) ( (x) > (y) ? (x) : (y) ) +#define Mmin(x, y) ( (x) > (y) ? (y) : (x) ) +#define Mlowcase(C) ( ((C) > 64 && (C) < 91) ? (C) | 32 : (C) ) +#define Mupcase(C) ( ((C) > 96 && (C) < 123) ? (C) & 0xDF : (C) ) +/* + * packed indexing functions (upper & lower) + */ + +#define Mjoin(pre, nam) my_join(pre, nam) +#define my_join(pre, nam) pre ## nam +#define Mstr2(m) # m +#define Mstr(m) Mstr2(m) + +#define ATL_assert(n_) \ +{ \ + if (!(n_)) \ + { \ + ATL_xerbla(0, __FILE__, "assertion %s failed, line %d of file %s\n", \ + Mstr(n_), __LINE__, __FILE__); \ + } \ +} + +/* + * Define some C99 features that we use when we know the compiler supports them + */ +#if defined(__STDC_VERSION__) && (__STDC_VERSION__/100 >= 1999) + #define INLINE inline + #define RESTRICT restrict +#else + #define INLINE + #define RESTRICT +#endif + +#if defined(SREAL) + #define EPS 5.0e-7 + #define TYPE float + #define PRE s + #define UPR s + #define PREU S + #define PATL ATL_s + #define PATU ATLU_s + #define UATL ATLU_s + #define CBLA cblas_s + #define PATLU ATL_s + #define ATL_rone 1.0f + #define ATL_rnone -1.0f + #define ATL_rzero 0.0f + #define ATL_typify(m_) Mjoin(m_,f) + #include "atlas_ssysinfo.h" +#elif defined(DREAL) + #define EPS 1.0e-15 + #define TYPE double + #define PRE d + #define UPR d + #define PREU D + #define PATL ATL_d + #define PATU ATLU_d + #define UATL ATLU_d + #define CBLA cblas_d + #define PATLU ATL_d + #define ATL_rone 1.0 + #define ATL_rnone -1.0 + #define ATL_rzero 0.0 + #define ATL_typify(m_) m_ + #include "atlas_dsysinfo.h" +#elif defined (QREAL) + #define EPS 1.9259299443872358530559779425849273E-34L + #define TYPE long double + #define PRE q + #define UPR q + #define PREU Q + #define PATL ATL_q + #define PATU ATLU_q + #define CBLA cblas_q +#elif defined(SCPLX) + #define EPS 5.0e-7 + #define TYPE float + #define PRE c + #define UPR s + #define PREU C + #define PATL ATL_c + #define PATLU ATL_s + #define PATU ATLU_c + #define UATL ATLU_s + #define ATL_rone 1.0f + #define ATL_rnone -1.0f + #define ATL_rzero 0.0f + #define ATL_typify(m_) Mjoin(m_,f) + #define CBLA cblas_c + #include "atlas_csysinfo.h" +#elif defined(DCPLX) + #define TYPE double + #define PRE z + #define UPR d + #define PREU Z + #define PATL ATL_z + #define PATLU ATL_d + #define PATU ATLU_z + #define UATL ATLU_d + #define EPS 1.0e-15 + #define ATL_rone 1.0 + #define ATL_rnone -1.0 + #define ATL_rzero 0.0 + #define ATL_typify(m_) m_ + #define CBLA cblas_z + #include "atlas_zsysinfo.h" +#endif + +#if defined (SREAL) || defined (DREAL) || defined (SCPLX) || defined (DCPLX) + #define ATL_sizeof Mjoin(PATL,size) + #define ATL_MulBySize Mjoin(PATL,MulBySize) + #define ATL_DivBySize Mjoin(PATL,DivBySize) +#endif + +#if ( defined(SREAL) || defined(DREAL) || defined(QREAL) ) + #define TREAL + #define SHIFT + #define SCALAR TYPE + #define SADD & + #define SVAL + #define SVVAL * + #define SCALAR_IS_ONE(M_scalar) ((M_scalar) == ATL_rone) + #define SCALAR_IS_NONE(M_scalar) ((M_scalar) == ATL_rnone) + #define SCALAR_IS_ZERO(M_scalar) ((M_scalar) == ATL_rzero) +#elif defined(SCPLX) || defined(DCPLX) + #define TCPLX +/* + * c = b*c + v; + */ + #define CMULT2(v, a, b, tmp) \ + { \ + tmp = *(a) * *(b) - *(a+1) * *(b+1); \ + *(b+1) = *(a) * *(b+1) + *(a+1) * *(b) + *(v+1); \ + *(b) = tmp + *v; \ + } + #define SHIFT << 1 + #define SCALAR TYPE * + #define SADD + #define SVAL * + #define SVVAL + #define SCALAR_IS_ONE(M_scalar) \ + ( (*(M_scalar) == ATL_rone) && ((M_scalar)[1] == ATL_rzero) ) + #define SCALAR_IS_NONE(M_scalar) \ + ( (*(M_scalar) == ATL_rnone) && ((M_scalar)[1] == ATL_rzero) ) + #define SCALAR_IS_ZERO(M_scalar) \ + ( (*(M_scalar) == ATL_rzero) && ((M_scalar)[1] == ATL_rzero) ) +#endif + +#if defined(ALPHA1) + #define ATL_MulByALPHA(x_) (x_) + #define NM _a1 +#elif defined (ALPHA0) + #define ATL_MulByALPHA(x_) ATL_rzero + #define NM _a0 +#elif defined (ALPHAN1) + #define ATL_MulByALPHA(x_) (-(x_)) + #define NM _an1 +#elif defined (ALPHAXI0) + #define ATL_MulByALPHA(x_) (ralpha*(x_)) + #define NM _aXi0 +#elif defined (ALPHA1C) + #define NM _a1c +#elif defined (ALPHAN1C) + #define NM _an1c +#elif defined (ALPHAXI0C) + #define NM _aXi0c +#elif defined (ALPHAXC) + #define NM _aXc +#elif defined (ALPHAX) + #define ATL_MulByALPHA(x_) (alpha*(x_)) + #define NM _aX +#endif + +#if defined(BETA1) + #define ATL_MulByBETA(x_) (x_) + #define MSTAT A[i] += v[i] + #define BNM _b1 +#elif defined(BETA1C) + #define BNM _b1c +#elif defined(BETAN1) + #define ATL_MulByBETA(x_) (-(x_)) + #define MSTAT A[i] = v[i] - A[i] + #define BNM _bn1 +#elif defined(BETAN1C) + #define BNM _bn1c +#elif defined(BETA0) + #define ATL_MulByBETA(x_) ATL_rzero + #define MSTAT A[i] = v[i] + #define BNM _b0 +#elif defined (BETAXI0) + #define BNM _bXi0 + #define ATL_MulByBETA(x_) (rbeta*(x_)) +#elif defined (BETAXI0C) + #define BNM _bXi0c +#elif defined (BETAX) + #define ATL_MulByBETA(x_) (beta*(x_)) + #define MSTAT A[i] = beta*A[i] + v[i] + #define BNM _bX +#elif defined (BETAXC) + #define BNM _bXc +#endif + +/* any alignment below this forces data copy in gemm */ +#ifndef ATL_MinMMAlign + #define ATL_MinMMAlign 16 +#endif +#if (ATL_MinMMAlign == 1 || ATL_MinMMAlign == 0) + #define ATL_DataIsMinAligned(ptr) 1 +#elif (ATL_MinMMAlign == 2) + #define ATL_DataIsMinAligned(ptr) \ + ( (((size_t) (ptr))>>1)<<1 == (size_t) (ptr) ) +#elif (ATL_MinMMAlign == 4) + #define ATL_DataIsMinAligned(ptr) \ + ( (((size_t) (ptr))>>2)<<2 == (size_t) (ptr) ) +#elif (ATL_MinMMAlign == 8) + #define ATL_DataIsMinAligned(ptr) \ + ( (((size_t) (ptr))>>3)<<3 == (size_t) (ptr) ) +#elif (ATL_MinMMAlign == 16) + #define ATL_DataIsMinAligned(ptr) \ + ( (((size_t) (ptr))>>4)<<4 == (size_t) (ptr) ) +#elif (ATL_MinMMAlign == 32) + #define ATL_DataIsMinAligned(ptr) \ + ( (((size_t) (ptr))>>5)<<5 == (size_t) (ptr) ) +#elif (ATL_MinMMAlign == 64) + #define ATL_DataIsMinAligned(ptr) \ + ( (((size_t) (ptr))>>6)<<6 == (size_t) (ptr) ) +#elif (ATL_MinMMAlign == 128) + #define ATL_DataIsMinAligned(ptr) \ + ( (((size_t) (ptr))>>7)<<7 == (size_t) (ptr) ) +#else + #define ATL_DataIsMinAligned(ptr) \ + ( (((size_t) (ptr))/ATL_MinMMAlign)*ATL_MinMMAlign == (size_t) (ptr) ) +#endif + +#define ATL_Cachelen 32 +#if (ATL_Cachelen == 4) + #define ATL_MulByCachelen(N_) ( (N_) << 2 ) + #define ATL_DivByCachelen(N_) ( (N_) >> 2 ) +#elif (ATL_Cachelen == 8) + #define ATL_MulByCachelen(N_) ( (N_) << 3 ) + #define ATL_DivByCachelen(N_) ( (N_) >> 3 ) +#elif (ATL_Cachelen == 16) + #define ATL_MulByCachelen(N_) ( (N_) << 4 ) + #define ATL_DivByCachelen(N_) ( (N_) >> 4 ) +#elif (ATL_Cachelen == 32) + #define ATL_MulByCachelen(N_) ( (N_) << 5 ) + #define ATL_DivByCachelen(N_) ( (N_) >> 5 ) +#elif (ATL_Cachelen == 64) + #define ATL_MulByCachelen(N_) ( (N_) << 6 ) + #define ATL_DivByCachelen(N_) ( (N_) >> 6 ) +#elif (ATL_Cachelen == 128) + #define ATL_MulByCachelen(N_) ( (N_) << 7 ) + #define ATL_DivByCachelen(N_) ( (N_) >> 7 ) +#elif (ATL_Cachelen == 256) + #define ATL_MulByCachelen(N_) ( (N_) << 8 ) + #define ATL_DivByCachelen(N_) ( (N_) >> 8 ) +#else + #define ATL_MulByCachelen(N_) ( (N_) * ATL_Cachelen ) + #define ATL_DivByCachelen(N_) ( (N_) / ATL_Cachelen ) +#endif + +#if (ATL_Cachelen < ATL_MinMMAlign) + Force a compilation error if our required alignment is at least the + minimum!!@^ +#endif + +#define ATL_AlignPtr(vp) \ + (void*) (ATL_Cachelen + ATL_MulByCachelen(ATL_DivByCachelen((size_t) (vp)))) + +#define ATL_FindPtrAdjust(vp, iadj_) \ +{ \ + (iadj_) = ((size_t)(vp))-ATL_MulByCachelen(ATL_DivByCachelen((size_t)(vp)));\ + if (iadj_) \ + { \ + if ( (iadj_) == ATL_MulBySize(ATL_DivBySize(iadj_)) ) \ + (iadj_) = ATL_DivBySize(iadj_); \ + else (iadj_) = 0; \ + }\ +} +#define ATL_FindMatAdjust(vp_, lda_, iadj_) \ +{ \ + if (ATL_MulByCachelen(ATL_DivByCachelen(ATL_MulBySize(lda_))) \ + == ATL_MulBySize(lda_)) \ + { \ + ATL_FindPtrAdjust(vp_, iadj_); \ + } \ + else (iadj_) = 0; \ +} + +#define ATL_sqrtLL(x, res) \ + asm ("fsqrt" : "=t" (res) : "0" (x)); + +/* + * Find N necessary for alignment. Written as function for optimization, + * declared static to encourage inlining + */ +static int ATL_AlignOffset +(const int N, /* max return value */ + const void *vp, /* pointer to be aligned */ + const int inc, /* size of each elt, in bytes */ + const int align) /* required alignment, in bytes */ +{ + const int p = align/inc; + const size_t k=(size_t)vp, j=k/inc; + int iret; + if (k == (j)*inc && p*inc == align) + { + iret = ((j+p-1) / p)*p - j; + if (iret <= N) return(iret); + } + return(N); +} + +/* + * Gcc links in crap that MSVC++ and DVF can't handle if you use stdout + * or stderr, so use this beautiful kludge to avoid this problem -- RCW + */ +#ifdef GCCWIN + +#include <stdarg.h> +static int WINFPRINTF(FILE *fpout, char *form, ...) +{ + int ierr=0; + va_list argptr; + + va_start(argptr, form); + if (fpout == NULL) ierr = vprintf(form, argptr); + else ierr = vfprintf(fpout, form, argptr); + va_end(argptr); + + return(ierr); +} + +#ifdef stdout + #undef stdout +#endif +#ifdef stderr + #undef stderr +#endif +#ifdef assert + #undef assert +#endif + +#define stdout NULL +#define stderr NULL +#define fprintf WINFPRINTF +#define assert WINASSERT +#define WINASSERT(n_) \ +{ \ + if (!(n_)) \ + { \ + printf("assertion %s failed, line %d of file %s\n", \ + Mstr(n_), __LINE__, __FILE__); \ + exit(1); \ + } \ +} + +#endif + +#include "atlas_aux.h" + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_mv.h b/kaldi_io/src/tools/ATLAS/include/atlas_mv.h new file mode 100644 index 0000000..f26da5f --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_mv.h @@ -0,0 +1,45 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef ATLAS_MV_H + #define ATLAS_MV_H + +#include "atlas_misc.h" +#if defined(SREAL) + #include "atlas_smv.h" +#elif defined(DREAL) + #include "atlas_dmv.h" +#elif defined(SCPLX) + #include "atlas_cmv.h" +#elif defined(DCPLX) + #include "atlas_zmv.h" +#endif + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_pkblas.h b/kaldi_io/src/tools/ATLAS/include/atlas_pkblas.h new file mode 100644 index 0000000..b9c7d82 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_pkblas.h @@ -0,0 +1,569 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 2003 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ +#ifndef ATLAS_PKBLAS_H +#define ATLAS_PKBLAS_H + +#include "atlas_misc.h" +#ifndef ATL_NOL3 +#include "atlas_lvl3.h" +#endif + +#define CBLAS_ENUM_ONLY +#include "cblas.h" +#undef CBLAS_ENUM_ONLY + +enum PACK_UPLO {PackUpper=121, PackLower=122, PackGen=123}; + +#define PACK_ORDER CBLAS_ORDER + #define PackRowMajor CblasRowMajor + #define PackColMajor CblasColMajor +#define PACK_TRANS CBLAS_TRANSPOSE + #define PackNoTrans CblasNoTrans + #define PackTrans CblasTrans + #define PackConjTrans CblasConjTrans + #define PackConj AtlasConj +#define PACK_DIAG CBLAS_DIAG + #define PackNonUnit CblasNonUnit + #define PackUnit CblasUnit +#define PACK_SIDE CBLAS_SIDE + #define PackLeft CblasLeft + #define PackRight CblasRight + +#ifndef ATL_pkMaxMalloc + #define ATL_pkMaxMalloc ATL_MaxMalloc +#endif + +#ifdef TCPLX + #define MindexPL(I_,J_,lda_) ( (((J_)*((lda_)+(lda_)-(J_)-1))) + (I_)+(I_) ) + #define MindexPU(I_,J_,lda_) ( ((((lda_)+(lda_)+(J_)-1)*(J_))) + (I_)+(I_) ) +#else + #define MindexPL(I_,J_,lda_) ( (((J_)*((lda_)+(lda_)-(J_)-1))>>1) + (I_) ) + #define MindexPU(I_,J_,lda_) ( ((((lda_)+(lda_)+(J_)-1)*(J_))>>1) + (I_) ) +#endif +#define MindexP(uplo_,I_,J_,lda_) \ + ( (uplo_) == PackUpper ? MindexPU(I_,J_,lda_) : \ + ( (uplo_) == PackLower ? MindexPL(I_,J_,lda_) : \ + (((J_)*(lda_)+(I_))SHIFT) ) ) +#define Mpld(uplo_,J_,lda_) (uplo_) == PackUpper ? (lda_)+(J_) : \ + ( (uplo_) == PackLower ? (lda_)-(J_) : (lda_) ) + + +void ATL_sgpmm(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum PACK_UPLO UB, const enum PACK_TRANS TB, + const enum PACK_UPLO UC, + const int M, const int N, const int K, const float alpha, + const float *A, const int IA, const int JA, const int lda, + const float *B, const int IB, const int JB, const int ldb, + const float beta, float *C, const int IC, const int JC, + const int ldc); +void ATL_sprankK(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, int R, + const SCALAR alpha, const TYPE *A, int lda, + const TYPE *B, int ldb, const SCALAR beta, + const enum PACK_UPLO UC, TYPE *C, int ldc); +int ATL_spmmJIKF(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, const enum PACK_UPLO UC, + TYPE *C, const int ldc); +int ATL_spmmJIK(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const float alpha, + const float *A, const int lda, const float *B, const int ldb, + const float beta, const enum PACK_UPLO UC, + float *C, const int ldc); +void ATL_spcol2blkF(const int M, const int N, const float alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_sprow2blkTF(const int M, const int N, const float alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_spcol2blk_a1(const int M, const int N, const float alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_spcol2blk_aX(const int M, const int N, const float alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_sprow2blkT_a1(const int M, const int N, const float alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_sprow2blkT_aX(const int M, const int N, const float alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_spputblk(const int M, const int N, const TYPE *V, TYPE *C, + int ldc, int ldcinc, const SCALAR beta); +void ATL_spputblk_diag + (const int M, const int N, const float *V, const enum ATLAS_UPLO UC, + float *C, int ldc, int ldcinc, const float alpha, const float beta); +void ATL_spputblk_aX + (const int M, const int N, const float *V, float *C, int ldc, int ldcinc, + const float alpha, const float beta); +void ATL_ssprk(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, const float alpha, + const float *A, const int IA, const int JA, const int lda, + const float beta, + float *C, const int IC, const int JC, const int ldc); +void ATL_shprk(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, const float alpha, + const float *A, const int IA, const int JA, const int lda, + const float beta, + float *C, const int IC, const int JC, const int ldc); +void ATL_shprk_rK(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, int R, const float alpha, + const float *A, int lda, const float beta, + float *C, const int ldc); +int ATL_sphk_kmm(const enum ATLAS_UPLO UC, const enum PACK_UPLO UA, + const enum ATLAS_TRANS TA, const int N, const int K, + const float alpha, const float *A, const int lda, + const float beta, const int CP, float *C, const int ldc); +void ATL_ssprk_rK(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, int R, const float alpha, + const float *A, int lda, const float beta, + float *C, const int ldc); +int ATL_sprk_kmm(const enum ATLAS_UPLO UC, const enum PACK_UPLO UA, + const enum ATLAS_TRANS TA, const int N, const int K, + const float alpha, const float *A, const int lda, + const float beta, const int CP, float *C, const int ldc); + +void ATL_dgpmm(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum PACK_UPLO UB, const enum PACK_TRANS TB, + const enum PACK_UPLO UC, + const int M, const int N, const int K, const double alpha, + const double *A, const int IA, const int JA, const int lda, + const double *B, const int IB, const int JB, const int ldb, + const double beta, double *C, const int IC, const int JC, + const int ldc); +void ATL_dprankK(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, int R, + const SCALAR alpha, const TYPE *A, int lda, + const TYPE *B, int ldb, const SCALAR beta, + const enum PACK_UPLO UC, TYPE *C, int ldc); +int ATL_dpmmJIKF(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, const enum PACK_UPLO UC, + TYPE *C, const int ldc); +int ATL_dpmmJIK(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const double alpha, + const double *A, const int lda, const double *B, const int ldb, + const double beta, const enum PACK_UPLO UC, + double *C, const int ldc); +void ATL_dpcol2blkF(const int M, const int N, const double alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_dprow2blkTF(const int M, const int N, const double alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_dpcol2blk_a1(const int M, const int N, const double alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_dpcol2blk_aX(const int M, const int N, const double alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_dprow2blkT_a1(const int M, const int N, const double alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_dprow2blkT_aX(const int M, const int N, const double alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_dpputblk(const int M, const int N, const TYPE *V, TYPE *C, + int ldc, int ldcinc, const SCALAR beta); +void ATL_dpputblk_diag + (const int M, const int N, const double *V, const enum ATLAS_UPLO UC, + double *C, int ldc, int ldcinc, const double alpha, const double beta); +void ATL_dpputblk_aX + (const int M, const int N, const double *V, double *C, int ldc, int ldcinc, + const double alpha, const double beta); +void ATL_dsprk(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, const double alpha, + const double *A, const int IA, const int JA, const int lda, + const double beta, + double *C, const int IC, const int JC, const int ldc); +void ATL_dhprk(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, const double alpha, + const double *A, const int IA, const int JA, const int lda, + const double beta, + double *C, const int IC, const int JC, const int ldc); +void ATL_dhprk_rK(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, int R, const double alpha, + const double *A, int lda, const double beta, + double *C, const int ldc); +int ATL_dphk_kmm(const enum ATLAS_UPLO UC, const enum PACK_UPLO UA, + const enum ATLAS_TRANS TA, const int N, const int K, + const double alpha, const double *A, const int lda, + const double beta, const int CP, double *C, const int ldc); +void ATL_dsprk_rK(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, int R, const double alpha, + const double *A, int lda, const double beta, + double *C, const int ldc); +int ATL_dprk_kmm(const enum ATLAS_UPLO UC, const enum PACK_UPLO UA, + const enum ATLAS_TRANS TA, const int N, const int K, + const double alpha, const double *A, const int lda, + const double beta, const int CP, double *C, const int ldc); + +void ATL_cgpmm(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum PACK_UPLO UB, const enum PACK_TRANS TB, + const enum PACK_UPLO UC, + const int M, const int N, const int K, const float* alpha, + const float *A, const int IA, const int JA, const int lda, + const float *B, const int IB, const int JB, const int ldb, + const float* beta, float *C, const int IC, const int JC, + const int ldc); +void ATL_cprankK(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, int R, + const SCALAR alpha, const TYPE *A, int lda, + const TYPE *B, int ldb, const SCALAR beta, + const enum PACK_UPLO UC, TYPE *C, int ldc); +int ATL_cpmmJIKF(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, const enum PACK_UPLO UC, + TYPE *C, const int ldc); +int ATL_cpmmJIK(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const float* alpha, + const float *A, const int lda, const float *B, const int ldb, + const float* beta, const enum PACK_UPLO UC, + float *C, const int ldc); +void ATL_cpcol2blkF(const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkTF(const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blk_a1(const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blk_aX(const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkT_a1(const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkT_aX(const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpputblk(const int M, const int N, const TYPE *V, TYPE *C, + int ldc, int ldcinc, const SCALAR beta); +void ATL_cpputblk_diag + (const int M, const int N, const float *V, const enum ATLAS_UPLO UC, + float *C, int ldc, int ldcinc, const float* alpha, const float* beta); +void ATL_cpputblk_aX + (const int M, const int N, const float *V, float *C, int ldc, int ldcinc, + const float* alpha, const float* beta); +void ATL_csprk(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, const float* alpha, + const float *A, const int IA, const int JA, const int lda, + const float* beta, + float *C, const int IC, const int JC, const int ldc); +void ATL_chprk(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, const float alpha, + const float *A, const int IA, const int JA, const int lda, + const float beta, + float *C, const int IC, const int JC, const int ldc); +void ATL_chprk_rK(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, int R, const float* alpha, + const float *A, int lda, const float* beta, + float *C, const int ldc); +int ATL_cphk_kmm(const enum ATLAS_UPLO UC, const enum PACK_UPLO UA, + const enum ATLAS_TRANS TA, const int N, const int K, + const float* alpha, const float *A, const int lda, + const float* beta, const int CP, float *C, const int ldc); +void ATL_csprk_rK(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, int R, const float* alpha, + const float *A, int lda, const float* beta, + float *C, const int ldc); +int ATL_cprk_kmm(const enum ATLAS_UPLO UC, const enum PACK_UPLO UA, + const enum ATLAS_TRANS TA, const int N, const int K, + const float* alpha, const float *A, const int lda, + const float* beta, const int CP, float *C, const int ldc); + +void ATL_zgpmm(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum PACK_UPLO UB, const enum PACK_TRANS TB, + const enum PACK_UPLO UC, + const int M, const int N, const int K, const double* alpha, + const double *A, const int IA, const int JA, const int lda, + const double *B, const int IB, const int JB, const int ldb, + const double* beta, double *C, const int IC, const int JC, + const int ldc); +void ATL_zprankK(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, int R, + const SCALAR alpha, const TYPE *A, int lda, + const TYPE *B, int ldb, const SCALAR beta, + const enum PACK_UPLO UC, TYPE *C, int ldc); +int ATL_zpmmJIKF(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const SCALAR alpha, + const TYPE *A, const int lda, const TYPE *B, const int ldb, + const SCALAR beta, const enum PACK_UPLO UC, + TYPE *C, const int ldc); +int ATL_zpmmJIK(const enum PACK_UPLO UA, const enum ATLAS_TRANS TA, + const enum PACK_UPLO UB, const enum ATLAS_TRANS TB, + const int M, const int N, const int K, const double* alpha, + const double *A, const int lda, const double *B, const int ldb, + const double* beta, const enum PACK_UPLO UC, + double *C, const int ldc); +void ATL_zpcol2blkF(const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkTF(const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blk_a1(const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blk_aX(const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkT_a1(const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkT_aX(const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpputblk(const int M, const int N, const TYPE *V, TYPE *C, + int ldc, int ldcinc, const SCALAR beta); +void ATL_zpputblk_diag + (const int M, const int N, const double *V, const enum ATLAS_UPLO UC, + double *C, int ldc, int ldcinc, const double* alpha, const double* beta); +void ATL_zpputblk_aX + (const int M, const int N, const double *V, double *C, int ldc, int ldcinc, + const double* alpha, const double* beta); +void ATL_zsprk(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, const double* alpha, + const double *A, const int IA, const int JA, const int lda, + const double* beta, + double *C, const int IC, const int JC, const int ldc); +void ATL_zhprk(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, const double alpha, + const double *A, const int IA, const int JA, const int lda, + const double beta, + double *C, const int IC, const int JC, const int ldc); +void ATL_zhprk_rK(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, int R, const double* alpha, + const double *A, int lda, const double* beta, + double *C, const int ldc); +int ATL_zphk_kmm(const enum ATLAS_UPLO UC, const enum PACK_UPLO UA, + const enum ATLAS_TRANS TA, const int N, const int K, + const double* alpha, const double *A, const int lda, + const double* beta, const int CP, double *C, const int ldc); +void ATL_zsprk_rK(const enum PACK_UPLO UA, const enum PACK_TRANS TA, + const enum ATLAS_UPLO UC, const int CP, + const int N, const int K, int R, const double* alpha, + const double *A, int lda, const double* beta, + double *C, const int ldc); +int ATL_zprk_kmm(const enum ATLAS_UPLO UC, const enum PACK_UPLO UA, + const enum ATLAS_TRANS TA, const int N, const int K, + const double* alpha, const double *A, const int lda, + const double* beta, const int CP, double *C, const int ldc); + +void ATL_cpcol2blk_aX_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkT_aX_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blk_a1_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkT_a1_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blkConjF + (const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blkConj_a1 + (const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blkConj_aX + (const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blk_aXi0 + (const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blkConj_aXi0 + (const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc,float*V); +void ATL_cprow2blkHF + (const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkH_a1 + (const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkH_aX + (const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkH_aXi0 + (const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkT_aXi0 + (const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blkConjF_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blkConj_a1_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blkConj_aX_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blk_aXi0_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cpcol2blkConj_aXi0_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc,float*V); +void ATL_cprow2blkHF_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkH_a1_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkH_aX_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkH_aXi0_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); +void ATL_cprow2blkT_aXi0_blk + (const int blk, const int M, const int N, const float* alpha, + const float *A, int lda, const int ldainc, float *V); + +void ATL_cprow2blkT_KB_aXi0 + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_cprow2blkT_KB_aX + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_cprow2blkT_KB_a1 + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_cprow2blkH_KB_aXi0 + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_cprow2blkH_KB_aX + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_cprow2blkH_KB_a1 + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_zpcol2blk_aX_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkT_aX_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blk_a1_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkT_a1_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blkConjF + (const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blkConj_a1 + (const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blkConj_aX + (const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blk_aXi0 + (const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blkConj_aXi0 + (const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc,double*V); +void ATL_zprow2blkHF + (const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkH_a1 + (const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkH_aX + (const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkH_aXi0 + (const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkT_aXi0 + (const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blkConjF_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blkConj_a1_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blkConj_aX_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blk_aXi0_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zpcol2blkConj_aXi0_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc,double*V); +void ATL_zprow2blkHF_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkH_a1_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkH_aX_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkH_aXi0_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); +void ATL_zprow2blkT_aXi0_blk + (const int blk, const int M, const int N, const double* alpha, + const double *A, int lda, const int ldainc, double *V); + +void ATL_zprow2blkT_KB_aXi0 + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_zprow2blkT_KB_aX + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_zprow2blkT_KB_a1 + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_zprow2blkH_KB_aXi0 + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_zprow2blkH_KB_aX + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); +void ATL_zprow2blkH_KB_a1 + (const int mb, const int nb, const SCALAR alpha, const TYPE *A, int lda, + const int ldainc, TYPE *V); + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_prefetch.h b/kaldi_io/src/tools/ATLAS/include/atlas_prefetch.h new file mode 100644 index 0000000..83ee2df --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_prefetch.h @@ -0,0 +1,197 @@ +#ifndef ATLAS_PREFETCH_H +#define ATLAS_PREFETCH_H +/* + * Altivec prefetch model not well utilized by SSE-like prefetch, so have + * special commands for it. + */ +#if defined(ATL_AltiVec) + #include "atlas_altivec.h" +#endif +/* + * + * ATL_pfl1R(mem) : fetch location mem to L1, with intent to read *only* + * ATL_pfl1W(mem) : fetch location mem to L1, with intent to read/write + * ATL_pfl1WO(mem) : fetch location mem to L1, with intent to write ONLY + */ + +#if defined(ATL_3DNow) + #ifdef __GNUC__ + #define ATL_pfl1R(mem) \ + __asm__ __volatile__ ("prefetch %0" : : "m" (*((char *)(mem)))) + #define ATL_pfl1W(mem) \ + __asm__ __volatile__ ("prefetchw %0" : : "m" (*((char *)(mem)))) + #define ATL_pfl1WO ATL_pfl1W + #define ATL_GOT_L1PREFETCH + #ifdef ATL_SSE1 + #define ATL_pfl2R(mem) \ + __asm__ __volatile__ ("prefetcht1 %0" : : "m" (*((char *)(mem)))) + #define ATL_pfl2W(mem) \ + __asm__ __volatile__ ("prefetcht1 %0" : : "m" (*((char *)(mem)))) + #define ATL_pfl2WO ATL_pfl2W + #define ATL_GOT_L2PREFETCH + #endif + #endif +#elif defined(ATL_SSE1) || defined (ATL_SSE2) /* SSE prefetch is available */ + #ifdef __GNUC__ + #define ATL_pfl1R(mem) \ + __asm__ __volatile__ ("prefetchnta %0" : : "m" (*((char *)(mem)))) + #define ATL_pfl1W(mem) \ + __asm__ __volatile__ ("prefetchnta %0" : : "m" (*((char *)(mem)))) + #define ATL_pfl1WO ATL_pfl1W + #define ATL_GOT_L1PREFETCH + + #define ATL_pfl2R(mem) \ + __asm__ __volatile__ ("prefetcht1 %0" : : "m" (*((char *)(mem)))) + #define ATL_pfl2W(mem) \ + __asm__ __volatile__ ("prefetcht1 %0" : : "m" (*((char *)(mem)))) + #define ATL_pfl2WO ATL_pfl2W + #define ATL_GOT_L2PREFETCH + #endif +#elif defined(__SUNPRO_C) && defined(__sparc) /* && __SUNPRO_CC > 0x600 */ + #include <sun_prefetch.h> + #define ATL_pfl1R(mem) sparc_prefetch_read_many((void*)(mem)) + #define ATL_pfl1W(mem) sparc_prefetch_write_many((void*)(mem)) + #define ATL_GOT_L1PREFETCH + #define ATL_pfl2R(mem) sparc_prefetch_read_many((void*)(mem)) + #define ATL_pfl2W(mem) sparc_prefetch_write_many((void*)(mem)) + #define ATL_GOT_L2PREFETCH +#elif defined(ATL_ARCH_21264) + #ifdef __GNUC__ + #define ATL_pfl1R(mem) \ + __asm__ __volatile__ ("ldt $f31, %0" : : "m" (*((char *)(mem)))) + #define ATL_pfl1W(mem) \ + __asm__ __volatile__ ("lds $f31, %0" : : "m" (*((char *)(mem)))) + #define ATL_pfl1WO(mem) \ + __asm__ __volatile__ ("wh64 %0" : : "m" (*((char *)(mem)))) + #define ATL_GOT_L1PREFETCH + #elif defined(__DECC) + #include "c_asm.h" + #define ATL_pfl1R(mem) asm ("ldt %f31,(%a0) ;", mem) + #define ATL_pfl1W(mem) asm ("lds %f31,(%a0) ;", mem) + #define ATL_pfl1WO(mem) asm ("wh64 (%a0) ;", mem) + #define ATL_GOT_L1PREFETCH + #endif +/* + * Note: SunUS5/10 seems to get no benefit from prefetch, so don't enable + */ +#elif defined(ATL_ARCH_USIV) || defined(ATL_ARCH_SunUSIII) || \ + defined(ATL_ARCH_SunUSII) || defined(ATL_ARCH_SunUSI) + #ifdef __GNUC__ + #define ATL_pfl1R(mem) \ + __asm__ __volatile__ ("prefetch %0,0" : : "m" (*((char *)(mem)))) + #define ATL_pfl1W(mem) \ + __asm__ __volatile__ ("prefetch %0,2" : : "m" (*((char *)(mem)))) + #define ATL_GOT_L1PREFETCH + #define ATL_pfl2R(mem) \ + __asm__ __volatile__ ("prefetch %0,3" : : "m" (*((char *)(mem)))) + #define ATL_pfl2W(mem) \ + __asm__ __volatile__ ("prefetch %0,2" : : "m" (*((char *)(mem)))) + #define ATL_GOT_L2PREFETCH + #endif +/* + * Gives gigantic slowdown on POWER4, so don't enable there, just use gcc + * builtin + */ +#elif defined(ATL_ARCH_PPCG5) || defined(ATL_ARCH_PPCG5) || \ + defined(ATL_ARCH_POWER5) + #if defined(__GNUC__) || defined(__IBM_GCC_ASM) + #define ATL_pfl1R(mem) \ + __asm__ __volatile__ ("dcbt 0, %0, 0" : : "r" ((mem))) + #define ATL_pfl1W(mem) \ + __asm__ __volatile__ ("dcbtst 0, %0" : : "r" ((mem))) + #define ATL_pfST(mem) \ + __asm__ __volatile__ ("dcbt 0, %0, 1" : : "r" ((mem))) + #define ATL_pfl1STi(mem, str) \ + __asm__ __volatile__ ("rlwinm %0, %0, 0, 0, 24\n\t" \ + "ori %0, %0, 96+%2\n\t" \ + "dcbt 0, %0, 8" \ + : "=r" (mem) \ + : "0" (mem), "i" (str)) + + #define ATL_GOT_L1PREFETCH + #define ATL_L1LS 128 + #endif +#elif defined(ATL_ARCH_IA64Itan) || defined(ATL_ARCH_IA64Itan2) +/* + * Have to use nt2, 'cause fpu ignored L1. + * NOTE: just let icc to prefetch, keep inst here for reference + */ + #if defined(__ECC) && 0 + #include "ia64intrin.h" + #define ATL_pfl1R(mem) __lfetch(2, (mem)) + #define ATL_pfl1W(mem) __lfetch_excl(2, (mem)) + #define ATL_GOT_L1PREFETCH + #elif defined(__GNUC__) && !defined(__ECC) + #define ATL_pfl1R(mem) \ + __asm__ (" lfetch.nt2 [%0]": : "r"((void *)(mem))) + #define ATL_pfl1W(mem) \ + __asm__ (" lfetch.excl [%0]": : "r"((void *)(mem))) + #define ATL_GOT_L1PREFETCH + #endif +#elif defined(ATL_ARCH_HPPA20) && defined(__GNUC__) + #define ATL_pfl1R(mem) \ + __asm__ __volatile__ ("ldw %0, %%r0" : : "m" (*((char *)(mem)))) + #define ATL_pfl1W(mem) \ + __asm__ __volatile__ ("ldd %0, %%r0" : : "m" (*((char *)(mem)))) + #define ATL_GOT_L1PREFETCH +#elif defined(ATL_AltiVec) && !defined(ATL_pfl1R) + #ifndef ATL_NoFakePF + /* 33619968 is ATL_GetCtrl(0, 1, 2), or fetch 1 32-byte block */ + #define ATL_pfl1R(mem) ATL_pfavR(mem, 33619968, 3) + #define ATL_pfl1W(mem) ATL_pfavW(mem, 33619968, 2) + #define ATL_GOT_L1PREFETCH + #endif +#elif defined(ATL_ARCH_MIPSICE9) && defined(__GNUC__) + #define ATL_pfl1R(mem) \ + __asm__ __volatile__ ("pref 6,%0" : : "m" (*((char *)(mem)))) + #define ATL_pfl1W(mem) \ + __asm__ __volatile__ ("pref 7,%0" : : "m" (*((char *)(mem)))) + #define ATL_GOT_L1PREFETCH + #define ATL_L1LS 32 + #define ATL_L2LS 64 +#elif defined(__GNUC__) /* last ditch, use gcc predefined func */ + #define ATL_pfl1R(mem) __builtin_prefetch(mem, 0, 3) + #define ATL_pfl1W(mem) __builtin_prefetch(mem, 1, 3) + #define ATL_GOT_L1PREFETCH +#endif +#if defined(ATL_pfl1W) && !defined(ATL_pfl1WO) + #define ATL_pfl1WO ATL_pfl1W +#endif + +#ifdef ATL_NOL1PREFETCH + #ifdef ATL_GOT_L1PREFETCH + #undef ATL_pfl1R + #undef ATL_pfl1W + #undef ATL_pfl1WO + #undef ATL_GOT_L1PREFETCH + #endif +#endif +#ifdef ATL_NOL2PREFETCH + #ifdef ATL_GOT_L2PREFETCH + #undef ATL_pfl2R + #undef ATL_pfl2W + #undef ATL_pfl2WO + #undef ATL_GOT_L2PREFETCH + #endif +#endif +#ifndef ATL_GOT_L1PREFETCH /* dummy calls cpp takes out of code */ + #define ATL_pfl1R(mem) + #define ATL_pfl1W(mem) + #define ATL_pfl1WO(mem) +#endif +#ifndef ATL_GOT_L2PREFETCH /* dummy calls cpp takes out of code */ + #define ATL_pfl2R(mem) + #define ATL_pfl2W(mem) +#endif + +/* + * Define Cache line sizes for L1 and L2 + */ +#ifndef ATL_L1LS + #define ATL_L1LS 64 +#endif +#ifndef ATL_L2LS + #define ATL_L2LS ATL_L1LS +#endif + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_ptalias1.h b/kaldi_io/src/tools/ATLAS/include/atlas_ptalias1.h new file mode 100644 index 0000000..2a45eda --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_ptalias1.h @@ -0,0 +1,60 @@ +#define ATLAS_PTALIAS1_H /* no threaded routs for Level 1 and 2 yet */ +#ifndef ATLAS_PTALIAS1_H +#define ATLAS_PTALIAS1_H +/* + * Real BLAS + */ + #define ATL_dsdot ATL_dsptdot + #define ATL_sdsdot ATL_sdsptdot + #define ATL_sasum ATL_sptasum + #define ATL_snrm2 ATL_sptnrm2 + #define ATL_sdot ATL_sptdot + #define ATL_saxpy ATL_sptaxpy + #define ATL_scopy ATL_sptcopy + #define ATL_sscal ATL_sptscal + #define ATL_sswap ATL_sptswap + #define ATL_srotm ATL_sptrotm + #define ATL_srot ATL_sptrot + #define ATL_srotmg ATL_sptrotmg + #define ATL_srotg ATL_sptrotg + #define ATL_isamax ATL_isptamax + + #define ATL_dasum ATL_dptasum + #define ATL_dnrm2 ATL_dptnrm2 + #define ATL_ddot ATL_dptdot + #define ATL_daxpy ATL_dptaxpy + #define ATL_dcopy ATL_dptcopy + #define ATL_dscal ATL_dptscal + #define ATL_dswap ATL_dptswap + #define ATL_drotm ATL_dptrotm + #define ATL_drot ATL_dptrot + #define ATL_drotmg ATL_dptrotmg + #define ATL_drotg ATL_dptrotg + #define ATL_idamax ATL_idptamax + +/* + * Complex BLAS + */ + #define ATL_cdotc_sub ATL_cptdotc_sub + #define ATL_cdotu_sub ATL_cptdotu_sub + #define ATL_caxpy ATL_cptaxpy + #define ATL_ccopy ATL_cptcopy + #define ATL_cscal ATL_cptscal + #define ATL_cswap ATL_cptswap + #define ATL_icamax ATL_icptamax + #define ATL_csscal ATL_csptscal + #define ATL_scnrm2 ATL_scptnrm2 + #define ATL_scasum ATL_scptasum + + #define ATL_zdotc_sub ATL_zptdotc_sub + #define ATL_zdotu_sub ATL_zptdotu_sub + #define ATL_zaxpy ATL_zptaxpy + #define ATL_zcopy ATL_zptcopy + #define ATL_zscal ATL_zptscal + #define ATL_zswap ATL_zptswap + #define ATL_izamax ATL_izptamax + #define ATL_zdscal ATL_zdptscal + #define ATL_dznrm2 ATL_dzptnrm2 + #define ATL_dzasum ATL_dzptasum + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_ptalias2.h b/kaldi_io/src/tools/ATLAS/include/atlas_ptalias2.h new file mode 100644 index 0000000..66b1e0e --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_ptalias2.h @@ -0,0 +1,80 @@ +#define ATLAS_PTALIAS2_H /* no threaded routs for Level 1 and 2 yet */ +#ifndef ATLAS_PTALIAS2_H +#define ATLAS_PTALIAS2_H +/* + * Real BLAS + */ + #define ATL_sspr2 ATL_sptspr2 + #define ATL_ssyr2 ATL_sptsyr2 + #define ATL_sspr ATL_sptspr + #define ATL_ssyr ATL_sptsyr + #define ATL_sger ATL_sptger + #define ATL_stpsv ATL_spttpsv + #define ATL_stbsv ATL_spttbsv + #define ATL_strsv ATL_spttrsv + #define ATL_stpmv ATL_spttpmv + #define ATL_stbmv ATL_spttbmv + #define ATL_strmv ATL_spttrmv + #define ATL_sspmv ATL_sptspmv + #define ATL_ssbmv ATL_sptsbmv + #define ATL_ssymv ATL_sptsymv + #define ATL_sgbmv ATL_sptgbmv + #define ATL_sgemv ATL_sptgemv + + #define ATL_dspr2 ATL_dptspr2 + #define ATL_dsyr2 ATL_dptsyr2 + #define ATL_dspr ATL_dptspr + #define ATL_dsyr ATL_dptsyr + #define ATL_dger ATL_dptger + #define ATL_dtpsv ATL_dpttpsv + #define ATL_dtbsv ATL_dpttbsv + #define ATL_dtrsv ATL_dpttrsv + #define ATL_dtpmv ATL_dpttpmv + #define ATL_dtbmv ATL_dpttbmv + #define ATL_dtrmv ATL_dpttrmv + #define ATL_dspmv ATL_dptspmv + #define ATL_dsbmv ATL_dptsbmv + #define ATL_dsymv ATL_dptsymv + #define ATL_dgbmv ATL_dptgbmv + #define ATL_dgemv ATL_dptgemv + +/* + * Complex BLAS + */ + #define ATL_chpr2 ATL_cpthpr2 + #define ATL_cher2 ATL_cpther2 + #define ATL_chpr ATL_cpthpr + #define ATL_cher ATL_cpther + #define ATL_cgerc ATL_cptgerc + #define ATL_cgeru ATL_cptgeru + #define ATL_ctpsv ATL_cpttpsv + #define ATL_ctbsv ATL_cpttbsv + #define ATL_ctrsv ATL_cpttrsv + #define ATL_ctpmv ATL_cpttpmv + #define ATL_ctbmv ATL_cpttbmv + #define ATL_ctrmv ATL_cpttrmv + #define ATL_chpmv ATL_cpthpmv + #define ATL_chbmv ATL_cpthbmv + #define ATL_chemv ATL_cpthemv + #define ATL_cgbmv ATL_cptgbmv + #define ATL_cgemv ATL_cptgemv + + #define ATL_zhpr2 ATL_zpthpr2 + #define ATL_zher2 ATL_zpther2 + #define ATL_zhpr ATL_zpthpr + #define ATL_zher ATL_zpther + #define ATL_zgerc ATL_zptgerc + #define ATL_zgeru ATL_zptgeru + #define ATL_ztpsv ATL_zpttpsv + #define ATL_ztbsv ATL_zpttbsv + #define ATL_ztrsv ATL_zpttrsv + #define ATL_ztpmv ATL_zpttpmv + #define ATL_ztbmv ATL_zpttbmv + #define ATL_ztrmv ATL_zpttrmv + #define ATL_zhpmv ATL_zpthpmv + #define ATL_zhbmv ATL_zpthbmv + #define ATL_zhemv ATL_zpthemv + #define ATL_zgbmv ATL_zptgbmv + #define ATL_zgemv ATL_zptgemv + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_ptalias3.h b/kaldi_io/src/tools/ATLAS/include/atlas_ptalias3.h new file mode 100644 index 0000000..2a25d23 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_ptalias3.h @@ -0,0 +1,43 @@ +#ifndef ATLAS_PTALIAS3_H +#define ATLAS_PTALIAS3_H +/* + * Real BLAS + */ + #define ATL_strsm ATL_spttrsm + #define ATL_strmm ATL_spttrmm + #define ATL_ssyr2k ATL_sptsyr2k + #define ATL_ssyrk ATL_sptsyrk + #define ATL_ssymm ATL_sptsymm + #define ATL_sgemm ATL_sptgemm + + #define ATL_dtrsm ATL_dpttrsm + #define ATL_dtrmm ATL_dpttrmm + #define ATL_dsyr2k ATL_dptsyr2k + #define ATL_dsyrk ATL_dptsyrk + #define ATL_dsymm ATL_dptsymm + #define ATL_dgemm ATL_dptgemm + +/* + * Complex BLAS + */ + #define ATL_ctrsm ATL_cpttrsm + #define ATL_ctrmm ATL_cpttrmm + #define ATL_cher2k ATL_cpther2k + #define ATL_csyr2k ATL_cptsyr2k + #define ATL_cherk ATL_cptherk + #define ATL_csyrk ATL_cptsyrk + #define ATL_chemm ATL_cpthemm + #define ATL_csymm ATL_cptsymm + #define ATL_cgemm ATL_cptgemm + + #define ATL_ztrsm ATL_zpttrsm + #define ATL_ztrmm ATL_zpttrmm + #define ATL_zher2k ATL_zpther2k + #define ATL_zsyr2k ATL_zptsyr2k + #define ATL_zherk ATL_zptherk + #define ATL_zsyrk ATL_zptsyrk + #define ATL_zhemm ATL_zpthemm + #define ATL_zsymm ATL_zptsymm + #define ATL_zgemm ATL_zptgemm + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_ptlevel3.h b/kaldi_io/src/tools/ATLAS/include/atlas_ptlevel3.h new file mode 100644 index 0000000..d1bded3 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_ptlevel3.h @@ -0,0 +1,284 @@ + +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_PTLEVEL3_H +#define ATLAS_PTLEVEL3_H +/* + * ===================================================================== + * Include files + * ===================================================================== + */ +#include "atlas_enum.h" +#include "atlas_pthreads.h" +/* + * ===================================================================== + * Prototypes for single precision real Level 3 multi-threaded ATLAS + * BLAS routines. + * ===================================================================== + */ +void ATL_sptgeadd +( const int, const int, const float, const float *, + const int, const float, float *, const int ); +void ATL_sptgezero +( const int, const int, float *, const int ); +void ATL_sptgescal +( const int, const int, const float, float *, + const int ); +void ATL_spttrscal +( const enum ATLAS_UPLO, const int, const int, + const float, float *, const int ); + +void ATL_sptgemm +( const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const float, + const float *, const int, const float *, const int, + const float, float *, const int ); +void ATL_sptsymm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const float, const float *, + const int, const float *, const int, const float, + float *, const int ); +void ATL_sptsyrk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float, const float *, + const int, const float, float *, const int ); +void ATL_sptsyr2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float, const float *, + const int, const float *, const int, const float, + float *, const int ); +void ATL_spttrmm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const float, const float *, + const int, float *, const int ); +void ATL_spttrsm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const float, const float *, + const int, float *, const int ); +/* + * ===================================================================== + * Prototypes for double precision real Level 3 multi-threaded ATLAS + * BLAS routines. + * ===================================================================== + */ +void ATL_dptgeadd +( const int, const int, const double, const double *, + const int, const double, double *, const int ); +void ATL_dptgezero +( const int, const int, double *, const int ); +void ATL_dptgescal +( const int, const int, const double, double *, + const int ); +void ATL_dpttrscal +( const enum ATLAS_UPLO, const int, const int, + const double, double *, const int ); + +void ATL_dptgemm +( const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const double, + const double *, const int, const double *, const int, + const double, double *, const int ); +void ATL_dptsymm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const double, const double *, + const int, const double *, const int, const double, + double *, const int ); +void ATL_dptsyrk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double, const double *, + const int, const double, double *, const int ); +void ATL_dptsyr2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double, const double *, + const int, const double *, const int, const double, + double *, const int ); +void ATL_dpttrmm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const double, const double *, + const int, double *, const int ); +void ATL_dpttrsm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const double, const double *, + const int, double *, const int ); +/* + * ===================================================================== + * Prototypes for single precision complex Level 3 multi-threaded ATLAS + * BLAS routines. + * ===================================================================== + */ +void ATL_cptgeadd +( const int, const int, const float *, const float *, + const int, const float *, float *, const int ); +void ATL_cptgezero +( const int, const int, float *, const int ); +void ATL_cptgescal +( const int, const int, const float *, float *, + const int ); +void ATL_cpttrscal +( const enum ATLAS_UPLO, const int, const int, + const float *, float *, const int ); +void ATL_cpthescal +( const enum ATLAS_UPLO, const int, const int, + const float, float *, const int ); + +void ATL_cptgemm +( const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const float *, + const float *, const int, const float *, const int, + const float *, float *, const int ); +void ATL_cptsymm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const float *, const float *, + const int, const float *, const int, const float *, + float *, const int ); +void ATL_cptsyrk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float *, const float *, + const int, const float *, float *, const int ); +void ATL_cptsyr2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float *, const float *, + const int, const float *, const int, const float *, + float *, const int ); +void ATL_cpttrmm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const float *, const float *, + const int, float *, const int ); +void ATL_cpttrsm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const float *, const float *, + const int, float *, const int ); +/* + * ===================================================================== + * Prototypes for double precision complex Level 3 multi-threaded ATLAS + * BLAS routines. + * ===================================================================== + */ +void ATL_zptgeadd +( const int, const int, const double *, const double *, + const int, const double *, double *, const int ); +void ATL_zptgezero +( const int, const int, double *, const int ); +void ATL_zptgescal +( const int, const int, const double *, double *, + const int ); +void ATL_zpttrscal +( const enum ATLAS_UPLO, const int, const int, + const double *, double *, const int ); +void ATL_zpthescal +( const enum ATLAS_UPLO, const int, const int, + const double, double *, const int ); + +void ATL_zptgemm +( const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const double *, + const double *, const int, const double *, const int, + const double *, double *, const int ); +void ATL_zptsymm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const double *, const double *, + const int, const double *, const int, const double *, + double *, const int ); +void ATL_zptsyrk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double *, const double *, + const int, const double *, double *, const int ); +void ATL_zptsyr2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double *, const double *, + const int, const double *, const int, const double *, + double *, const int ); +void ATL_zpttrmm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const double *, const double *, + const int, double *, const int ); +void ATL_zpttrsm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const double *, const double *, + const int, double *, const int ); + +void ATL_cpthemm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const float *, const float *, + const int, const float *, const int, const float *, + float *, const int ); +void ATL_cptherk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float, const float *, + const int, const float, float *, const int ); +void ATL_cpther2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float *, const float *, + const int, const float *, const int, const float, + float *, const int ); + +void ATL_zpthemm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const double *, const double *, + const int, const double *, const int, const double *, + double *, const int ); +void ATL_zptherk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double, const double *, + const int, const double, double *, const int ); +void ATL_zpther2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double *, const double *, + const int, const double *, const int, const double, + double *, const int ); + +#endif +/* + * End of atlas_ptlevel3.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_ptlvl3.h b/kaldi_io/src/tools/ATLAS/include/atlas_ptlvl3.h new file mode 100644 index 0000000..916afd0 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_ptlvl3.h @@ -0,0 +1,389 @@ + +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_PTLVL3_H +#define ATLAS_PTLVL3_H +/* + * ===================================================================== + * Include files + * ===================================================================== + */ +#include "atlas_ptmisc.h" +#include "atlas_level3.h" +#include "atlas_rblas3.h" +/* + * ===================================================================== + * macro constants + * ===================================================================== + */ +#ifdef TREAL +#define ATL_XOVER_L3_DEFAULT 8 /* number of NB x NB blocks */ +#else +#define ATL_XOVER_L3_DEFAULT 4 +#endif +/* + * ===================================================================== + * macro functions + * ===================================================================== + */ +#define Mpt3( a_, i_, siz_ ) ( ( (char*)(a_) + ( (i_) * (siz_) ) ) ) +#define Mvpt3( a_, i_, siz_ ) ( (void *)(Mpt3( (a_), (i_), (siz_) ))) +/* + * ===================================================================== + * typedef definitions + * ===================================================================== + */ +typedef PT_TREE_T (*PT_GEMM_FUN_T) +( + const unsigned int, pthread_attr_t *, + const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int +); + +typedef PT_TREE_T (*PT_TRMM_FUN_T) +( + const unsigned int, pthread_attr_t *, + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const void *, const void *, + const int, void *, const int +); + +typedef int (*PT_SYR2K_FUN_T) +( + const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_TRANS, const int, const int, + const void *, const void *, const int, const void *, + const int, const void *, void *, const int +); + + +typedef struct +{ + size_t size; + void * negone, * one, * zero; + PT_FUN_T geadd0, gemm0, symm0, hemm0, syrk0, syr2k0, + herk0, her2k0, trmm0, trsm0; + PT_GEMM_FUN_T ptgemm; + PT_TRMM_FUN_T pttrmm; + PT_SYR2K_FUN_T ptsyr2k0, pther2k0; +} PT_LVL3_TYPE_T; + +typedef struct +{ + const void * a, * al, * b, * be; + void * c; + enum ATLAS_TRANS ta, tb; + int k, la, lb, lc, m, n; +} PT_GEMM_ARGS_T; + +typedef struct +{ + const void * a, * al, * b, * be; + void * c; + enum ATLAS_SIDE si; + enum ATLAS_UPLO up; + int la, lb, lc, m, n; +} PT_SYMM_ARGS_T; + +typedef struct +{ + const void * a, * al, * be; + void * c; + enum ATLAS_UPLO up; + enum ATLAS_TRANS tr; + int l, la, lc, m, n, k; +} PT_SYRK_ARGS_T; + +typedef struct +{ + const void * a, * al, * ac, * b, * be; + void * c; + enum ATLAS_UPLO up; + enum ATLAS_TRANS tr; + int l, la, lb, lc, m, n, k; +} PT_SYR2K_ARGS_T; + +typedef struct +{ + const void * a, * al; + void * b; + enum ATLAS_SIDE si; + enum ATLAS_UPLO up; + enum ATLAS_TRANS tr; + enum ATLAS_DIAG di; + int la, lb, m, n; +} PT_TRMM_ARGS_T; + +typedef struct +{ + const void * a, * al; + void * b; + enum ATLAS_SIDE si; + enum ATLAS_UPLO up; + enum ATLAS_TRANS tr; + enum ATLAS_DIAG di; + int la, lb, m, n; +} PT_TRSM_ARGS_T; + +/* + * ===================================================================== + * Function prototypes + * ===================================================================== + */ +PT_TREE_T ATL_Sgemm +( const PT_LVL3_TYPE_T *, const unsigned int, + const unsigned int, pthread_attr_t *, const int, + const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +PT_TREE_T ATL_Ssymm +( const PT_LVL3_TYPE_T *, const unsigned int, + const unsigned int, pthread_attr_t *, const int, + const enum ATLAS_TRANS, const enum ATLAS_SIDE, + const enum ATLAS_UPLO, const int, const int, + const void *, const void *, const int, const void *, + const int, const void *, void *, const int ); +PT_TREE_T ATL_Ssyrk +( const PT_LVL3_TYPE_T *, const unsigned int, + const unsigned int, pthread_attr_t *, const int, + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_TRANS, const int, const int, + const int, const int, const void *, const void *, + const int, const void *, void *, const int ); +PT_TREE_T ATL_Ssyr2k +( const PT_LVL3_TYPE_T *, const unsigned int, + const unsigned int, pthread_attr_t *, const int, + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_TRANS, const int, const int, + const int, const int, const void *, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +PT_TREE_T ATL_Strmm +( const PT_LVL3_TYPE_T *, const unsigned int, + const unsigned int, pthread_attr_t *, const int, + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const void *, const void *, + const int, void *, const int ); +PT_TREE_T ATL_Strsm +( const PT_LVL3_TYPE_T *, const unsigned int, + const unsigned int, pthread_attr_t *, const int, + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const void *, const void *, + const int, void *, const int ); + +#if defined( TREAL ) || defined( TCPLX ) + +int Mjoin( PATL, GetNB ) ( void ); + +void Mjoin( PATL, ptl3settype ) ( PT_LVL3_TYPE_T * ); + +void Mjoin( PATL, gemmNN ) +( const int, const int, const int, const SCALAR, + const TYPE *, const int, const TYPE *, const int, + const SCALAR, TYPE *, const int ); +void Mjoin( PATL, gemmNT ) +( const int, const int, const int, const SCALAR, + const TYPE *, const int, const TYPE *, const int, + const SCALAR, TYPE *, const int ); +void Mjoin( PATL, gemmTN ) +( const int, const int, const int, const SCALAR, + const TYPE *, const int, const TYPE *, const int, + const SCALAR, TYPE *, const int ); + +#if defined( TCPLX ) +void Mjoin( PATL, gemmNC ) +( const int, const int, const int, const SCALAR, + const TYPE *, const int, const TYPE *, const int, + const SCALAR, TYPE *, const int ); +void Mjoin( PATL, gemmCN ) +( const int, const int, const int, const SCALAR, + const TYPE *, const int, const TYPE *, const int, + const SCALAR, TYPE *, const int ); +#endif + +PT_FUN_ARG_T Mjoin( PATL, ptgemm0 ) ( PT_FUN_ARG_T ); +PT_FUN_ARG_T Mjoin( PATL, ptsymm0 ) ( PT_FUN_ARG_T ); +PT_FUN_ARG_T Mjoin( PATL, ptsyr2k0 ) ( PT_FUN_ARG_T ); +PT_FUN_ARG_T Mjoin( PATL, ptsyrk0 ) ( PT_FUN_ARG_T ); +PT_FUN_ARG_T Mjoin( PATL, pttrmm0 ) ( PT_FUN_ARG_T ); +PT_FUN_ARG_T Mjoin( PATL, pttrsm0 ) ( PT_FUN_ARG_T ); + +#if defined( TCPLX ) +PT_FUN_ARG_T Mjoin( PATL, pthemm0 ) ( PT_FUN_ARG_T ); +PT_FUN_ARG_T Mjoin( PATL, pther2k0 ) ( PT_FUN_ARG_T ); +PT_FUN_ARG_T Mjoin( PATL, ptherk0 ) ( PT_FUN_ARG_T ); +#endif +/* + * ===================================================================== + * Prototypes for the Level 3 multi-threaded ATLAS BLAS routines + * ===================================================================== + */ +PT_TREE_T Mjoin( PATL, ptgemm_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +PT_TREE_T Mjoin( PATL, ptsymm_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const void *, const void *, + const int, const void *, const int, const void *, + void *, const int ); +PT_TREE_T Mjoin( PATL, ptsyr2k_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const void *, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +int Mjoin( PATL, ptsyr2k0_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_TRANS, const int, const int, + const void *, const void *, const int, const void *, + const int, const void *, void *, const int ); +PT_TREE_T Mjoin( PATL, ptsyrk_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const void *, const void *, + const int, const void *, void *, const int ); +PT_TREE_T Mjoin( PATL, pttrmm_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const void *, const void *, + const int, void *, const int ); +PT_TREE_T Mjoin( PATL, pttrsm_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const void *, const void *, + const int, void *, const int ); + +void Mjoin( PATL, ptgemm ) +( const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const SCALAR, + const TYPE *, const int, const TYPE *, const int, + const SCALAR, TYPE *, const int ); +void Mjoin( PATL, ptsymm ) +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const SCALAR, const TYPE *, + const int, const TYPE *, const int, const SCALAR, + TYPE *, const int ); +void Mjoin( PATL, ptsyr2k ) +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const SCALAR, const TYPE *, + const int, const TYPE *, const int, const SCALAR, + TYPE *, const int ); +void Mjoin( PATL, ptsyrk ) +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const SCALAR, const TYPE *, + const int, const SCALAR, TYPE *, const int ); +void Mjoin( PATL, pttrmm ) +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const SCALAR, const TYPE *, + const int, TYPE *, const int ); +void Mjoin( PATL, pttrsm ) +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const SCALAR, const TYPE *, + const int, TYPE *, const int ); + +#if defined( TCPLX ) +PT_TREE_T Mjoin( PATL, pthemm_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const void *, const void *, + const int, const void *, const int, const void *, + void *, const int ); +PT_TREE_T Mjoin( PATL, pther2k_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const void *, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +int Mjoin( PATL, pther2k0_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_TRANS, const int, const int, + const void *, const void *, const int, const void *, + const int, const void *, void *, const int ); +PT_TREE_T Mjoin( PATL, ptherk_nt ) +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const void *, const void *, + const int, const void *, void *, const int ); + +void Mjoin( PATL, pthemm ) +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const SCALAR, const TYPE *, + const int, const TYPE *, const int, const SCALAR, + TYPE *, const int ); +void Mjoin( PATL, pther2k ) +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const SCALAR, const TYPE *, + const int, const TYPE *, const int, const TYPE, + TYPE *, const int ); +void Mjoin( PATL, ptherk ) +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const TYPE, const TYPE *, + const int, const TYPE, TYPE *, const int ); +#endif + +#endif + +#endif +/* + * End of atlas_ptlvl3.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_ptmisc.h b/kaldi_io/src/tools/ATLAS/include/atlas_ptmisc.h new file mode 100644 index 0000000..4c3db23 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_ptmisc.h @@ -0,0 +1,410 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_PTMISC_H +#define ATLAS_PTMISC_H +/* + * ===================================================================== + * Include Files + * ===================================================================== + */ +#include <math.h> +#include <pthread.h> + +#include "atlas_misc.h" +#include "atlas_pthreads.h" +/* + * ===================================================================== + * #define macro constants + * ===================================================================== + * + * ATL_XOVER_MI_DEFAULT is the smallest number of NB-by-NB blocks for + * which threading is enabled, where NB is the value returned by the + * ATLAS function Mjoin( PATL, GetNB ). + */ +#ifdef TREAL +#define ATL_XOVER_MI_DEFAULT 8 /* number of NB x NB blocks */ +#else +#define ATL_XOVER_MI_DEFAULT 4 +#endif + +#define NOSPLIT 0 /* For convenience */ +#define SPLIT_M 1 +#define SPLIT_N 2 +#define SPLIT_K 3 + +/* + * ===================================================================== + * macro functions + * ===================================================================== + */ +#define Mptm( a_, i_, siz_ ) ( ( (char*)(a_) + ( (i_) * (siz_) ) ) ) +#define Mvptm( a_, i_, siz_ ) ( (void *)(Mptm( (a_), (i_), (siz_) ))) +/* + * ===================================================================== + * typedef definitions + * ===================================================================== + * + * Definition of the Binary (recursive) task tree: Each node of the tree + * mainly consist a node number, a reference counter to enforce depen- + * dencies, a argument structure and a function to be applied. + */ +typedef void * PT_DATA_T; +typedef void * PT_FUN_VAL_T; +typedef void * PT_FUN_ARG_T; +typedef PT_FUN_VAL_T (*PT_FUN_T) ( PT_FUN_ARG_T ); + +typedef struct PT_node_T +{ + pthread_t pid; + pthread_mutex_t mutex; + pthread_cond_t cond; + struct PT_node_T * left; + struct PT_node_T * right; + PT_DATA_T data; + PT_FUN_VAL_T * val; + PT_FUN_T fun; + PT_FUN_ARG_T arg; + unsigned int node; + unsigned int count; +} PT_NODE_T; + +typedef PT_NODE_T * PT_TREE_T; +typedef void (*PT_APPLY_FUN_T)( PT_TREE_T ); + +enum DIM_1DSPLIT_E +{ + Atlas1dSplit = 100, + Atlas1dNoSplit = 199 +}; + +enum DIM_TZSPLIT_E +{ + AtlasTzSplitMrow = 200, + AtlasTzSplitKrow = 201, + AtlasTzSplitKcol = 202, + AtlasTzSplitNcol = 203, + AtlasTzNoSplit = 299 +}; + +typedef enum DIM_1DSPLIT_E DIM_1DSPLIT_T; +typedef enum DIM_TZSPLIT_E DIM_TZSPLIT_T; + +/* + * Type definitions for some auxiliaries that have been multi-threaded + * as well. + */ +typedef struct +{ + size_t size; + PT_FUN_T fun; +} PT_MISC_TYPE_T; + +typedef struct +{ + const void * al, * be; + const void * a; + void * c; + int la, lc, m, n; +} PT_GEADD_ARGS_T; + +typedef struct +{ + void * a; + int la, m, n; +} PT_GEZERO_ARGS_T; + +typedef struct +{ + const void * al; + void * a; + int la, m, n; +} PT_GESCAL_ARGS_T; + +typedef struct +{ + enum ATLAS_UPLO up; + const void * al; + void * a; + int k, la, m, n; +} PT_TZSCAL_ARGS_T; + +/* + * ===================================================================== + * Function prototypes + * ===================================================================== + */ +int ATL_sGetNB ( void ); +int ATL_dGetNB ( void ); +int ATL_cGetNB ( void ); +int ATL_zGetNB ( void ); + +DIM_1DSPLIT_T ATL_1dsplit +( + const unsigned int, + const int, + const int, + unsigned int *, + unsigned int *, + int *, + int *, + double * +); + +DIM_TZSPLIT_T ATL_tzsplit +( + const enum ATLAS_UPLO, + const unsigned int, + const int, + const int, + const int, + const int, + unsigned int *, + unsigned int *, + int *, + int * +); +/* + * Task tree management + */ +PT_TREE_T ATL_init_node +( unsigned int, PT_TREE_T, PT_TREE_T, PT_DATA_T, + PT_FUN_VAL_T *, PT_FUN_T, PT_FUN_ARG_T ); + +void ATL_traverse_tree ( PT_TREE_T ); +void ATL_apply_tree ( PT_TREE_T, PT_APPLY_FUN_T ); +void ATL_free_tree ( PT_TREE_T ); +void ATL_free_node ( PT_TREE_T ); +void ATL_print_node_id ( PT_TREE_T ); + +void ATL_thread_init ( pthread_attr_t * ); +void ATL_thread_exit ( pthread_attr_t * ); +void ATL_wait_tree ( PT_TREE_T ); +void ATL_signal_tree ( PT_TREE_T ); +void ATL_thread_tree ( PT_TREE_T, pthread_attr_t * ); +void ATL_join_tree ( PT_TREE_T ); + +PT_TREE_T ATL_create_tree +( unsigned int *, const int, const int ); +/* + * Typeless auxiliary functions + */ +PT_TREE_T ATL_Sgeadd +( const PT_MISC_TYPE_T *, const unsigned int, + const unsigned int, pthread_attr_t *, const int, + const int, const int, const void *, const void *, + const int, const void *, void *, const int ); +PT_TREE_T ATL_Sgescal +( const PT_MISC_TYPE_T *, const unsigned int, + const unsigned int, pthread_attr_t *, const int, + const int, const int, const void *, void *, + const int ); +PT_TREE_T ATL_Sgezero +( const PT_MISC_TYPE_T *, const unsigned int, + const unsigned int, pthread_attr_t *, const int, + const int, const int, void *, const int ); +PT_TREE_T ATL_Stzscal +( const PT_MISC_TYPE_T *, const unsigned int, + const unsigned int, pthread_attr_t *, const int, + const enum ATLAS_UPLO, const int, const int, + const int, const void *, void *, const int ); +/* + * Single precision real auxiliary functions + */ +PT_FUN_ARG_T ATL_sptgeadd0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_sptgescal0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_sptgezero0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_spttzscal0 ( PT_FUN_ARG_T ); + +PT_TREE_T ATL_sptgeadd_nt +( const unsigned int, pthread_attr_t *, const int, + const int, const void *, const void *, const int, + const void *, void *, const int ); +PT_TREE_T ATL_sptgescal_nt +( const unsigned int, pthread_attr_t *, const int, + const int, const void *, void *, const int ); +PT_TREE_T ATL_sptgezero_nt +( const unsigned int, pthread_attr_t *, const int, + const int, void *, const int ); +PT_TREE_T ATL_spttrscal_nt +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const int, const int, + const void *, void *, const int ); + +void ATL_sptgeadd +( const int, const int, const float, const float *, + const int, const float, float *, const int ); +void ATL_sptgescal +( const int, const int, const float, float *, + const int ); +void ATL_sptgezero +( const int, const int, float *, const int ); +void ATL_spttrscal +( const enum ATLAS_UPLO, const int, const int, + const float, float *, const int ); + +/* + * Double precision real auxiliary functions + */ +PT_FUN_ARG_T ATL_dptgeadd0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_dptgescal0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_dptgezero0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_dpttzscal0 ( PT_FUN_ARG_T ); + +PT_TREE_T ATL_dptgeadd_nt +( const unsigned int, pthread_attr_t *, const int, + const int, const void *, const void *, const int, + const void *, void *, const int ); +PT_TREE_T ATL_dptgescal_nt +( const unsigned int, pthread_attr_t *, const int, + const int, const void *, void *, const int ); +PT_TREE_T ATL_dptgezero_nt +( const unsigned int, pthread_attr_t *, const int, + const int, void *, const int ); +PT_TREE_T ATL_dpttrscal_nt +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const int, const int, + const void *, void *, const int ); + +void ATL_dptgeadd +( const int, const int, const double, const double *, + const int, const double, double *, const int ); +void ATL_dptgescal +( const int, const int, const double, double *, + const int ); +void ATL_dptgezero +( const int, const int, double *, const int ); +void ATL_dpttrscal +( const enum ATLAS_UPLO, const int, const int, + const double, double *, const int ); +/* + * Single precision complex auxiliary functions + */ +PT_FUN_ARG_T ATL_cptgeadd0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_cptgescal0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_cptgezero0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_cpthescal0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_cpttzscal0 ( PT_FUN_ARG_T ); + +PT_TREE_T ATL_cptgeadd_nt +( const unsigned int, pthread_attr_t *, const int, + const int, const void *, const void *, const int, + const void *, void *, const int ); +PT_TREE_T ATL_cptgescal_nt +( const unsigned int, pthread_attr_t *, const int, + const int, const void *, void *, const int ); +PT_TREE_T ATL_cptgezero_nt +( const unsigned int, pthread_attr_t *, const int, + const int, void *, const int ); +PT_TREE_T ATL_cpttrscal_nt +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const int, const int, + const void *, void *, const int ); +PT_TREE_T ATL_cpthescal_nt +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const int, const int, + const void *, void *, const int ); + +void ATL_cptgeadd +( const int, const int, const float *, const float *, + const int, const float *, float *, const int ); +void ATL_cptgezero +( const int, const int, float *, const int ); +void ATL_cptgescal +( const int, const int, const float *, float *, + const int ); +void ATL_cpttrscal +( const enum ATLAS_UPLO, const int, const int, + const float *, float *, const int ); +void ATL_cpthescal +( const enum ATLAS_UPLO, const int, const int, + const float, float *, const int ); +/* + * Double precision complex auxiliary functions + */ +PT_FUN_ARG_T ATL_zptgeadd0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_zptgescal0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_zptgezero0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_zpthescal0 ( PT_FUN_ARG_T ); +PT_FUN_ARG_T ATL_zpttzscal0 ( PT_FUN_ARG_T ); + +PT_TREE_T ATL_zptgeadd_nt +( const unsigned int, pthread_attr_t *, const int, + const int, const void *, const void *, const int, + const void *, void *, const int ); +PT_TREE_T ATL_zptgescal_nt +( const unsigned int, pthread_attr_t *, const int, + const int, const void *, void *, const int ); +PT_TREE_T ATL_zptgezero_nt +( const unsigned int, pthread_attr_t *, const int, + const int, void *, const int ); +PT_TREE_T ATL_zpttrscal_nt +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const int, const int, + const void *, void *, const int ); +PT_TREE_T ATL_zpthescal_nt +( const unsigned int, pthread_attr_t *, + const enum ATLAS_UPLO, const int, const int, + const void *, void *, const int ); + +void ATL_zptgeadd +( const int, const int, const double *, const double *, + const int, const double *, double *, const int ); +void ATL_zptgezero +( const int, const int, double *, const int ); +void ATL_zptgescal +( const int, const int, const double *, double *, + const int ); +void ATL_zpttrscal +( const enum ATLAS_UPLO, const int, const int, + const double *, double *, const int ); +void ATL_zpthescal +( const enum ATLAS_UPLO, const int, const int, + const double, double *, const int ); + +#endif +/* + * End of atlas_ptmisc.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_r1.h b/kaldi_io/src/tools/ATLAS/include/atlas_r1.h new file mode 100644 index 0000000..dc49fe2 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_r1.h @@ -0,0 +1,39 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifdef SREAL + #include "atlas_sr1.h" +#elif defined(DREAL) + #include "atlas_dr1.h" +#elif defined(SCPLX) + #include "atlas_cr1.h" +#elif defined(DCPLX) + #include "atlas_zr1.h" +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_rblas3.h b/kaldi_io/src/tools/ATLAS/include/atlas_rblas3.h new file mode 100644 index 0000000..9ad27e7 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_rblas3.h @@ -0,0 +1,474 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Contributor(s) : R. Clint Whaley + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_RBLAS3_H +#define ATLAS_RBLAS3_H +/* + * ===================================================================== + * Include files + * ===================================================================== + */ +#include "atlas_misc.h" +/* + * ===================================================================== + * #define macros definitions + * ===================================================================== + */ +#define Mrc3( a_, i_, j_, lda_, siz_ ) \ + ( (void*) ( (char*)(a_) + ( ( (i_)+(j_)*(lda_) )*(siz_) ) ) ) +/* + * ===================================================================== + * #typedef definitions + * ===================================================================== + */ +typedef void (*KR3_FUN_GEMM_T) +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +typedef void (*KR3_FUN_HEMM_T) +( const int, const int, const void *, const void *, + const int, const void *, const int, const void *, + void *, const int ); +typedef int (*KR3_FUN_HER2K_T) +( const int, const int, const void *, const void *, + const int, const void *, const int, const void *, + void *, const int ); +typedef void (*KR3_FUN_HERK_T) +( const int, const int, const void *, const void *, + const int, const void *, void *, const int ); +typedef void (*KR3_FUN_SYMM_T) +( const int, const int, const void *, const void *, + const int, const void *, const int, const void *, + void *, const int ); +typedef int (*KR3_FUN_SYR2K_T) +( const int, const int, const void *, const void *, + const int, const void *, const int, const void *, + void *, const int ); +typedef void (*KR3_FUN_SYRK_T) +( const int, const int, const void *, const void *, + const int, const void *, void *, const int ); +typedef void (*KR3_FUN_TRMM_T) +( const int, const int, const void *, const void *, + const int, void *, const int ); +typedef void (*KR3_FUN_TRSM_T) +( const int, const int, const void *, const void *, + const int, void *, const int ); + +typedef struct +{ + size_t size; + void * one; + KR3_FUN_GEMM_T TgemmNN; + KR3_FUN_GEMM_T Tgemm; + KR3_FUN_SYMM_T Tsymm; +} RC3_SYMM_T; + +typedef struct +{ + size_t size; + void * one; + KR3_FUN_GEMM_T TgemmNN; + KR3_FUN_GEMM_T Tgemm; + KR3_FUN_HEMM_T Themm; +} RC3_HEMM_T; + +typedef struct +{ + size_t size; + KR3_FUN_GEMM_T Tgemm; + KR3_FUN_SYRK_T Tsyrk; +} RC3_SYRK_T; + +typedef struct +{ + size_t size; + KR3_FUN_GEMM_T Tgemm; + KR3_FUN_HERK_T Therk; +} RC3_HERK_T; + +typedef struct +{ + size_t size; + void * one; + KR3_FUN_GEMM_T Tgemm; + KR3_FUN_SYR2K_T Tsyr2k; +} RC3_SYR2K_T; + +typedef struct +{ + size_t size; + void * one; + KR3_FUN_GEMM_T Tgemm; + KR3_FUN_HER2K_T Ther2k; +} RC3_HER2K_T; + +typedef struct +{ + size_t size; + void * one; + KR3_FUN_GEMM_T Tgemm; + KR3_FUN_TRMM_T Ttrmm; +} RC3_TRMM_T; + +typedef struct +{ + size_t size; + void * one, * negone; + KR3_FUN_GEMM_T Tgemm; + KR3_FUN_TRSM_T Ttrsm; +} RC3_TRSM_T; + +typedef void (*RC3_FUN_HEMM_T) +( RC3_HEMM_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +typedef void (*RC3_FUN_HER2K_T) +( RC3_HER2K_T *, const int, const int, const void *, + const void *, const void *, const int, const void *, + const int, const void *, void *, const int, + const int ); +typedef void (*RC3_FUN_HERK_T) +( RC3_HERK_T *, const int, const int, const void *, + const void *, const int, const void *, void *, + const int, const int ); +typedef void (*RC3_FUN_SYMM_T) +( RC3_SYMM_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +typedef void (*RC3_FUN_SYR2K_T) +( RC3_SYR2K_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +typedef void (*RC3_FUN_SYRK_T) +( RC3_SYRK_T *, const int, const int, const void *, + const void *, const int, const void *, void *, + const int, const int ); +typedef void (*RC3_FUN_TRMM_T) +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +typedef void (*RC3_FUN_TRSM_T) +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +/* + * ===================================================================== + * Level 3 recursive BLAS internal function prototypes + * ===================================================================== + */ +void ATL_sgemmTN_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_sgemmNT_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_sgemmNN_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_dgemmTN_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_dgemmNT_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_dgemmNN_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_cgemmCN_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_cgemmNC_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_cgemmTN_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_cgemmNT_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_cgemmNN_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_zgemmCN_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_zgemmNC_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_zgemmTN_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_zgemmNT_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +void ATL_zgemmNN_RB +( const int, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int ); +/* + * ===================================================================== + * Recursive BLAS function prototypes + * ===================================================================== + */ +void ATL_rsymmRU +( RC3_SYMM_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +void ATL_rhemmRU +( RC3_HEMM_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +void ATL_rsymmRL +( RC3_SYMM_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +void ATL_rhemmRL +( RC3_HEMM_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +void ATL_rsymmLU +( RC3_SYMM_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +void ATL_rhemmLU +( RC3_HEMM_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +void ATL_rsymmLL +( RC3_SYMM_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +void ATL_rhemmLL +( RC3_HEMM_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); + +void ATL_rsyrkUT +( RC3_SYRK_T *, const int, const int, const void *, + const void *, const int, const void *, void *, + const int, const int ); +void ATL_rsyr2kUT +( RC3_SYR2K_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +void ATL_rsyrkUN +( RC3_SYRK_T *, const int, const int, const void *, + const void *, const int, const void *, void *, + const int, const int ); +void ATL_rsyr2kUN +( RC3_SYR2K_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +void ATL_rsyrkLT +( RC3_SYRK_T *, const int, const int, const void *, + const void *, const int, const void *, void *, + const int, const int ); +void ATL_rsyr2kLT +( RC3_SYR2K_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); +void ATL_rsyrkLN +( RC3_SYRK_T *, const int, const int, const void *, + const void *, const int, const void *, void *, + const int, const int ); +void ATL_rsyr2kLN +( RC3_SYR2K_T *, const int, const int, const void *, + const void *, const int, const void *, const int, + const void *, void *, const int, const int ); + +void ATL_rherkUC +( RC3_HERK_T *, const int, const int, const void *, + const void *, const int, const void *, void *, + const int, const int ); +void ATL_rher2kUC +( RC3_HER2K_T *, const int, const int, const void *, + const void *, const void *, const int, const void *, + const int, const void *, void *, const int, + const int ); +void ATL_rherkUN +( RC3_HERK_T *, const int, const int, const void *, + const void *, const int, const void *, void *, + const int, const int ); +void ATL_rher2kUN +( RC3_HER2K_T *, const int, const int, const void *, + const void *, const void *, const int, const void *, + const int, const void *, void *, const int, + const int ); +void ATL_rherkLC +( RC3_HERK_T *, const int, const int, const void *, + const void *, const int, const void *, void *, + const int, const int ); +void ATL_rher2kLC +( RC3_HER2K_T *, const int, const int, const void *, + const void *, const void *, const int, const void *, + const int, const void *, void *, const int, + const int ); +void ATL_rherkLN +( RC3_HERK_T *, const int, const int, const void *, + const void *, const int, const void *, void *, + const int, const int ); +void ATL_rher2kLN +( RC3_HER2K_T *, const int, const int, const void *, + const void *, const void *, const int, const void *, + const int, const void *, void *, const int, + const int ); + +void ATL_rtrmmRUC +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmRUC +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmRLC +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmRLC +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmRUT +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmRUT +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmRLT +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmRLT +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmRUN +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmRUN +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmRLN +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmRLN +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmLUC +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmLUC +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmLLC +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmLLC +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmLUT +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmLUT +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmLLT +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmLLT +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmLUN +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmLUN +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrmmLLN +( RC3_TRMM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); +void ATL_rtrsmLLN +( RC3_TRSM_T *, const int, const int, const void *, + const void *, const int, void *, const int, + const int ); + +#endif +/* + * End of atlas_rblas3.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_refalias1.h b/kaldi_io/src/tools/ATLAS/include/atlas_refalias1.h new file mode 100644 index 0000000..7dcac8a --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_refalias1.h @@ -0,0 +1,59 @@ +#ifndef ATLAS_REFALIAS1_H +#define ATLAS_REFALIAS1_H +/* + * Real BLAS + */ + #define ATL_dsdot ATL_dsrefdot + #define ATL_sdsdot ATL_sdsrefdot + #define ATL_sasum ATL_srefasum + #define ATL_snrm2 ATL_srefnrm2 + #define ATL_sdot ATL_srefdot + #define ATL_saxpy ATL_srefaxpy + #define ATL_scopy ATL_srefcopy + #define ATL_sscal ATL_srefscal + #define ATL_sswap ATL_srefswap + #define ATL_srotm ATL_srefrotm + #define ATL_srot ATL_srefrot + #define ATL_srotmg ATL_srefrotmg + #define ATL_srotg ATL_srefrotg + #define ATL_isamax ATL_isrefamax + + #define ATL_dasum ATL_drefasum + #define ATL_dnrm2 ATL_drefnrm2 + #define ATL_ddot ATL_drefdot + #define ATL_daxpy ATL_drefaxpy + #define ATL_dcopy ATL_drefcopy + #define ATL_dscal ATL_drefscal + #define ATL_dswap ATL_drefswap + #define ATL_drotm ATL_drefrotm + #define ATL_drot ATL_drefrot + #define ATL_drotmg ATL_drefrotmg + #define ATL_drotg ATL_drefrotg + #define ATL_idamax ATL_idrefamax + +/* + * Complex BLAS + */ + #define ATL_cdotc_sub ATL_crefdotc_sub + #define ATL_cdotu_sub ATL_crefdotu_sub + #define ATL_caxpy ATL_crefaxpy + #define ATL_ccopy ATL_crefcopy + #define ATL_cscal ATL_crefscal + #define ATL_cswap ATL_crefswap + #define ATL_icamax ATL_icrefamax + #define ATL_csscal ATL_csrefscal + #define ATL_scnrm2 ATL_screfnrm2 + #define ATL_scasum ATL_screfasum + + #define ATL_zdotc_sub ATL_zrefdotc_sub + #define ATL_zdotu_sub ATL_zrefdotu_sub + #define ATL_zaxpy ATL_zrefaxpy + #define ATL_zcopy ATL_zrefcopy + #define ATL_zscal ATL_zrefscal + #define ATL_zswap ATL_zrefswap + #define ATL_izamax ATL_izrefamax + #define ATL_zdscal ATL_zdrefscal + #define ATL_dznrm2 ATL_dzrefnrm2 + #define ATL_dzasum ATL_dzrefasum + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_refalias2.h b/kaldi_io/src/tools/ATLAS/include/atlas_refalias2.h new file mode 100644 index 0000000..5871491 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_refalias2.h @@ -0,0 +1,79 @@ +#ifndef ATLAS_REFALIAS2_H +#define ATLAS_REFALIAS2_H +/* + * Real BLAS + */ + #define ATL_sspr2 ATL_srefspr2 + #define ATL_ssyr2 ATL_srefsyr2 + #define ATL_sspr ATL_srefspr + #define ATL_ssyr ATL_srefsyr + #define ATL_sger ATL_srefger + #define ATL_stpsv ATL_sreftpsv + #define ATL_stbsv ATL_sreftbsv + #define ATL_strsv ATL_sreftrsv + #define ATL_stpmv ATL_sreftpmv + #define ATL_stbmv ATL_sreftbmv + #define ATL_strmv ATL_sreftrmv + #define ATL_sspmv ATL_srefspmv + #define ATL_ssbmv ATL_srefsbmv + #define ATL_ssymv ATL_srefsymv + #define ATL_sgbmv ATL_srefgbmv + #define ATL_sgemv ATL_srefgemv + + #define ATL_dspr2 ATL_drefspr2 + #define ATL_dsyr2 ATL_drefsyr2 + #define ATL_dspr ATL_drefspr + #define ATL_dsyr ATL_drefsyr + #define ATL_dger ATL_drefger + #define ATL_dtpsv ATL_dreftpsv + #define ATL_dtbsv ATL_dreftbsv + #define ATL_dtrsv ATL_dreftrsv + #define ATL_dtpmv ATL_dreftpmv + #define ATL_dtbmv ATL_dreftbmv + #define ATL_dtrmv ATL_dreftrmv + #define ATL_dspmv ATL_drefspmv + #define ATL_dsbmv ATL_drefsbmv + #define ATL_dsymv ATL_drefsymv + #define ATL_dgbmv ATL_drefgbmv + #define ATL_dgemv ATL_drefgemv + +/* + * Complex BLAS + */ + #define ATL_chpr2 ATL_crefhpr2 + #define ATL_cher2 ATL_crefher2 + #define ATL_chpr ATL_crefhpr + #define ATL_cher ATL_crefher + #define ATL_cgerc ATL_crefgerc + #define ATL_cgeru ATL_crefgeru + #define ATL_ctpsv ATL_creftpsv + #define ATL_ctbsv ATL_creftbsv + #define ATL_ctrsv ATL_creftrsv + #define ATL_ctpmv ATL_creftpmv + #define ATL_ctbmv ATL_creftbmv + #define ATL_ctrmv ATL_creftrmv + #define ATL_chpmv ATL_crefhpmv + #define ATL_chbmv ATL_crefhbmv + #define ATL_chemv ATL_crefhemv + #define ATL_cgbmv ATL_crefgbmv + #define ATL_cgemv ATL_crefgemv + + #define ATL_zhpr2 ATL_zrefhpr2 + #define ATL_zher2 ATL_zrefher2 + #define ATL_zhpr ATL_zrefhpr + #define ATL_zher ATL_zrefher + #define ATL_zgerc ATL_zrefgerc + #define ATL_zgeru ATL_zrefgeru + #define ATL_ztpsv ATL_zreftpsv + #define ATL_ztbsv ATL_zreftbsv + #define ATL_ztrsv ATL_zreftrsv + #define ATL_ztpmv ATL_zreftpmv + #define ATL_ztbmv ATL_zreftbmv + #define ATL_ztrmv ATL_zreftrmv + #define ATL_zhpmv ATL_zrefhpmv + #define ATL_zhbmv ATL_zrefhbmv + #define ATL_zhemv ATL_zrefhemv + #define ATL_zgbmv ATL_zrefgbmv + #define ATL_zgemv ATL_zrefgemv + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_refalias3.h b/kaldi_io/src/tools/ATLAS/include/atlas_refalias3.h new file mode 100644 index 0000000..f10e65c --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_refalias3.h @@ -0,0 +1,43 @@ +#ifndef ATLAS_REFALIAS3_H +#define ATLAS_REFALIAS3_H +/* + * Real BLAS + */ + #define ATL_strsm ATL_sreftrsm + #define ATL_strmm ATL_sreftrmm + #define ATL_ssyr2k ATL_srefsyr2k + #define ATL_ssyrk ATL_srefsyrk + #define ATL_ssymm ATL_srefsymm + #define ATL_sgemm ATL_srefgemm + + #define ATL_dtrsm ATL_dreftrsm + #define ATL_dtrmm ATL_dreftrmm + #define ATL_dsyr2k ATL_drefsyr2k + #define ATL_dsyrk ATL_drefsyrk + #define ATL_dsymm ATL_drefsymm + #define ATL_dgemm ATL_drefgemm + +/* + * Complex BLAS + */ + #define ATL_ctrsm ATL_creftrsm + #define ATL_ctrmm ATL_creftrmm + #define ATL_cher2k ATL_crefher2k + #define ATL_csyr2k ATL_crefsyr2k + #define ATL_cherk ATL_crefherk + #define ATL_csyrk ATL_crefsyrk + #define ATL_chemm ATL_crefhemm + #define ATL_csymm ATL_crefsymm + #define ATL_cgemm ATL_crefgemm + + #define ATL_ztrsm ATL_zreftrsm + #define ATL_ztrmm ATL_zreftrmm + #define ATL_zher2k ATL_zrefher2k + #define ATL_zsyr2k ATL_zrefsyr2k + #define ATL_zherk ATL_zrefherk + #define ATL_zsyrk ATL_zrefsyrk + #define ATL_zhemm ATL_zrefhemm + #define ATL_zsymm ATL_zrefsymm + #define ATL_zgemm ATL_zrefgemm + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_reflevel1.h b/kaldi_io/src/tools/ATLAS/include/atlas_reflevel1.h new file mode 100644 index 0000000..2f79ac8 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_reflevel1.h @@ -0,0 +1,421 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_REFLEVEL1_H +#define ATLAS_REFLEVEL1_H +/* + * ===================================================================== + * Prototypes for Level 1 Reference ATLAS BLAS routines + * ===================================================================== + */ +void ATL_srefrotg +( + float *, + float *, + float *, + float * +); + +void ATL_srefrotmg +( + float *, + float *, + float *, + const float, + float * +); + +float ATL_srefnrm2 +( + const int, + const float *, const int +); + +float ATL_srefasum +( + const int, + const float *, const int +); + +int ATL_isrefamax +( + const int, + const float *, const int +); + +void ATL_srefscal +( + const int, + const float, + float *, const int +); + +void ATL_srefswap +( + const int, + float *, const int, + float *, const int +); + +void ATL_srefcopy +( + const int, + const float *, const int, + float *, const int +); + +void ATL_srefaxpy +( + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_srefrot +( + const int, + float *, const int, + float *, const int, + const float, + const float +); + +void ATL_srefrotm +( + const int, + float *, const int, + float *, const int, + const float * +); + +float ATL_srefdot +( + const int, + const float *, const int, + const float *, const int +); + +float ATL_sdsrefdot +( + const int, + const float, + const float *, const int, + const float *, const int +); + +double ATL_dsrefdot +( + const int, + const float *, const int, + const float *, const int +); + +void ATL_drefrotg +( + double *, + double *, + double *, + double * +); + +void ATL_drefrotmg +( + double *, + double *, + double *, + const double, + double * +); + +double ATL_drefnrm2 +( + const int, + const double *, const int +); + +double ATL_drefasum +( + const int, + const double *, const int +); + +int ATL_idrefamax +( + const int, + const double *, const int +); + +void ATL_drefscal +( + const int, + const double, + double *, const int +); + +void ATL_drefswap +( + const int, + double *, const int, + double *, const int +); + +void ATL_drefcopy +( + const int, + const double *, const int, + double *, const int +); + +void ATL_drefaxpy +( + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_drefrot +( + const int, + double *, const int, + double *, const int, + const double, + const double +); + +void ATL_drefrotm +( + const int, + double *, const int, + double *, const int, + const double * +); + +double ATL_drefdot +( + const int, + const double *, const int, + const double *, const int +); + +void ATL_crefrotg +( + float *, + const float *, + float *, + float * +); + +float ATL_screfnrm2 +( + const int, + const float *, const int +); + +float ATL_screfasum +( + const int, + const float *, const int +); + +int ATL_icrefamax +( + const int, + const float *, const int +); + +void ATL_crefscal +( + const int, + const float *, + float *, const int +); + +void ATL_csrefscal +( + const int, + const float, + float *, const int +); + +void ATL_crefswap +( + const int, + float *, const int, + float *, const int +); + +void ATL_crefcopy +( + const int, + const float *, const int, + float *, const int +); + +void ATL_crefaxpy +( + const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_csrefrot +( + const int, + float *, const int, + float *, const int, + const float, + const float +); + +void ATL_crefdotc_sub +( + const int, + const float *, const int, + const float *, const int, + float * +); + +void ATL_crefdotu_sub +( + const int, + const float *, const int, + const float *, const int, + float * +); + +void ATL_zrefrotg +( + double *, + const double *, + double *, + double * +); + +double ATL_dzrefnrm2 +( + const int, + const double *, const int +); + +double ATL_dzrefasum +( + const int, + const double *, const int +); + +int ATL_izrefamax +( + const int, + const double *, const int +); + +void ATL_zrefscal +( + const int, + const double *, + double *, const int +); + +void ATL_zdrefscal +( + const int, + const double, + double *, const int +); + +void ATL_zrefswap +( + const int, + double *, const int, + double *, const int +); + +void ATL_zrefcopy +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zrefaxpy +( + const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zdrefrot +( + const int, + double *, const int, + double *, const int, + const double, + const double +); + +void ATL_zrefdotc_sub +( + const int, + const double *, const int, + const double *, const int, + double * +); + +void ATL_zrefdotu_sub +( + const int, + const double *, const int, + const double *, const int, + double * +); + +#endif +/* + * End of atlas_reflevel1.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_reflevel2.h b/kaldi_io/src/tools/ATLAS/include/atlas_reflevel2.h new file mode 100644 index 0000000..6158d17 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_reflevel2.h @@ -0,0 +1,788 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_REFLEVEL2_H +#define ATLAS_REFLEVEL2_H + +#include "atlas_enum.h" +/* + * ===================================================================== + * Prototypes for Level 2 Reference ATLAS BLAS routines + * ===================================================================== + */ +void ATL_srefgbmv +( + const enum ATLAS_TRANS, + const int, const int, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgpmv +( + const enum ATLAS_UPLO, + const enum ATLAS_TRANS, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgemv +( + const enum ATLAS_TRANS, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgpr +( + const enum ATLAS_UPLO, + const int, const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_srefger +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_srefsbmv +( + const enum ATLAS_UPLO, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefspmv +( + const enum ATLAS_UPLO, + const int, + const float, + const float *, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefspr +( + const enum ATLAS_UPLO, + const int, + const float, + const float *, const int, + float * +); + +void ATL_srefspr2 +( + const enum ATLAS_UPLO, + const int, + const float, + const float *, const int, + const float *, const int, + float * +); + +void ATL_srefsymv +( + const enum ATLAS_UPLO, + const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyr +( + const enum ATLAS_UPLO, + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_srefsyr2 +( + const enum ATLAS_UPLO, + const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const float *, + float *, const int +); + +void ATL_sreftpsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const float *, + float *, const int +); + +void ATL_sreftrmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const float *, const int, + float *, const int +); + +void ATL_drefgbmv +( + const enum ATLAS_TRANS, + const int, const int, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgpmv +( + const enum ATLAS_UPLO, + const enum ATLAS_TRANS, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgemv +( + const enum ATLAS_TRANS, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgpr +( + const enum ATLAS_UPLO, + const int, const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_drefger +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_drefsbmv +( + const enum ATLAS_UPLO, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefspmv +( + const enum ATLAS_UPLO, + const int, + const double, + const double *, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefspr +( + const enum ATLAS_UPLO, + const int, + const double, + const double *, const int, + double * +); + +void ATL_drefspr2 +( + const enum ATLAS_UPLO, + const int, + const double, + const double *, const int, + const double *, const int, + double * +); + +void ATL_drefsymv +( + const enum ATLAS_UPLO, + const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyr +( + const enum ATLAS_UPLO, + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_drefsyr2 +( + const enum ATLAS_UPLO, + const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const double *, + double *, const int +); + +void ATL_dreftpsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const double *, + double *, const int +); + +void ATL_dreftrmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const double *, const int, + double *, const int +); + +void ATL_crefgbmv +( + const enum ATLAS_TRANS, + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgpmv +( + const enum ATLAS_UPLO, + const enum ATLAS_TRANS, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemv +( + const enum ATLAS_TRANS, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgprc +( + const enum ATLAS_UPLO, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefgpru +( + const enum ATLAS_UPLO, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefgerc +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefgeru +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefhbmv +( + const enum ATLAS_UPLO, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhpmv +( + const enum ATLAS_UPLO, + const int, + const float *, + const float *, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhpr +( + const enum ATLAS_UPLO, + const int, + const float, + const float *, const int, + float * +); + +void ATL_crefhpr2 +( + const enum ATLAS_UPLO, + const int, + const float *, + const float *, const int, + const float *, const int, + float * +); + +void ATL_crefhemv +( + const enum ATLAS_UPLO, + const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefher +( + const enum ATLAS_UPLO, + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_crefher2 +( + const enum ATLAS_UPLO, + const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const float *, + float *, const int +); + +void ATL_creftpsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const float *, + float *, const int +); + +void ATL_creftrmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const float *, const int, + float *, const int +); + +void ATL_zrefgbmv +( + const enum ATLAS_TRANS, + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgpmv +( + const enum ATLAS_UPLO, + const enum ATLAS_TRANS, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemv +( + const enum ATLAS_TRANS, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgprc +( + const enum ATLAS_UPLO, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefgpru +( + const enum ATLAS_UPLO, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefgerc +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefgeru +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefhbmv +( + const enum ATLAS_UPLO, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhpmv +( + const enum ATLAS_UPLO, + const int, + const double *, + const double *, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhpr +( + const enum ATLAS_UPLO, + const int, + const double, + const double *, const int, + double * +); + +void ATL_zrefhpr2 +( + const enum ATLAS_UPLO, + const int, + const double *, + const double *, const int, + const double *, const int, + double * +); + +void ATL_zrefhemv +( + const enum ATLAS_UPLO, + const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefher +( + const enum ATLAS_UPLO, + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_zrefher2 +( + const enum ATLAS_UPLO, + const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const double *, + double *, const int +); + +void ATL_zreftpsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const double *, + double *, const int +); + +void ATL_zreftrmv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsv +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, + const double *, const int, + double *, const int +); + +#endif +/* + * End of atlas_reflevel2.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_reflevel3.h b/kaldi_io/src/tools/ATLAS/include/atlas_reflevel3.h new file mode 100644 index 0000000..eba976b --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_reflevel3.h @@ -0,0 +1,374 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_REFLEVEL3_H +#define ATLAS_REFLEVEL3_H + +#include "atlas_enum.h" +/* + * ===================================================================== + * Prototypes for Level 3 Reference ATLAS BLAS routines + * ===================================================================== + */ +void ATL_srefgemm +( + const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsymm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyrk +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const float, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyr2k +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sreftrmm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_drefgemm +( + const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsymm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyrk +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const double, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyr2k +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dreftrmm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_crefgemm +( + const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhemm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefherk +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const float, + const float *, const int, + const float, + float *, const int +); + +void ATL_crefher2k +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_crefsymm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsyrk +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const float *, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsyr2k +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_creftrmm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_zrefgemm +( + const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhemm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefherk +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const double, + const double *, const int, + const double, + double *, const int +); + +void ATL_zrefher2k +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_zrefsymm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsyrk +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const double *, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsyr2k +( + const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zreftrmm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsm +( + const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +#endif +/* + * End of atlas_reflevel3.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_reflvl2.h b/kaldi_io/src/tools/ATLAS/include/atlas_reflvl2.h new file mode 100644 index 0000000..c557f04 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_reflvl2.h @@ -0,0 +1,3184 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_REFLVL2_H +#define ATLAS_REFLVL2_H +/* + * ===================================================================== + * Prototypes for Level 2 Reference Internal ATLAS BLAS routines + * ===================================================================== + */ +void ATL_srefgbmvN +( + const int, const int, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgbmvT +( + const int, const int, + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgpmvUN +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgpmvUT +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgpmvLN +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgpmvLT +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgemvN +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgemvT +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgprL +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_srefgprU +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_srefsbmvL +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsbmvU +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefspmvL +( + const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefspmvU +( + const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsprL +( + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_srefsprU +( + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_srefspr2L +( + const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_srefspr2U +( + const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_srefsymvL +( + const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsymvU +( + const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyrL +( + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_srefsyrU +( + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_srefsyr2L +( + const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_srefsyr2U +( + const int, + const float, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbmvLNN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbmvLNU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbmvLTN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbmvLTU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbmvUNN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbmvUNU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbmvUTN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbmvUTU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpmvLNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpmvLNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpmvLTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpmvLTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpmvUNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpmvUNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpmvUTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpmvUTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrmvLNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrmvLNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrmvLTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrmvLTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrmvUNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrmvUNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrmvUTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrmvUTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbsvLNN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbsvLNU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbsvLTN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbsvLTU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbsvUNN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbsvUNU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbsvUTN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftbsvUTU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpsvLNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpsvLNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpsvLTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpsvLTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpsvUNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpsvUNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpsvUTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftpsvUTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrsvLNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrsvLNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrsvLTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrsvLTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrsvUNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrsvUNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrsvUTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_sreftrsvUTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_drefgbmvN +( + const int, const int, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgbmvT +( + const int, const int, + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgpmvUN +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgpmvUT +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgpmvLN +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgpmvLT +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgemvN +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgemvT +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgprL +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_drefgprU +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_drefsbmvL +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsbmvU +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefspmvL +( + const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefspmvU +( + const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsprL +( + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_drefsprU +( + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_drefspr2L +( + const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_drefspr2U +( + const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_drefsymvL +( + const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsymvU +( + const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyrL +( + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_drefsyrU +( + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_drefsyr2L +( + const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_drefsyr2U +( + const int, + const double, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbmvLNN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbmvLNU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbmvLTN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbmvLTU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbmvUNN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbmvUNU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbmvUTN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbmvUTU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpmvLNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpmvLNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpmvLTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpmvLTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpmvUNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpmvUNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpmvUTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpmvUTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrmvLNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrmvLNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrmvLTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrmvLTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrmvUNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrmvUNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrmvUTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrmvUTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbsvLNN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbsvLNU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbsvLTN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbsvLTU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbsvUNN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbsvUNU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbsvUTN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftbsvUTU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpsvLNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpsvLNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpsvLTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpsvLTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpsvUNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpsvUNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpsvUTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftpsvUTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrsvLNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrsvLNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrsvLTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrsvLTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrsvUNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrsvUNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrsvUTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_dreftrsvUTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_crefgbmvN +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgbmvT +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgbmvC +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgbmvH +( + const int, const int, + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgpmvUN +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgpmvUT +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgpmvUC +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgpmvUH +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgpmvLN +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgpmvLT +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgpmvLC +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgpmvLH +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemvN +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemvT +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemvC +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemvH +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgprcL +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefgprcU +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefgpruL +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefgpruU +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefhbmvL +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhbmvU +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhpmvL +( + const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhpmvU +( + const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhprL +( + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_crefhprU +( + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_crefhpr2L +( + const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefhpr2U +( + const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefhemvL +( + const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhemvU +( + const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefherL +( + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_crefherU +( + const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_crefher2L +( + const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_crefher2U +( + const int, + const float *, + const float *, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvLNN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvLNU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvLTN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvLTU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvLCN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvLCU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvLHN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvLHU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvUNN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvUNU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvUTN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvUTU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvUCN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvUCU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvUHN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbmvUHU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvLNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvLNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvLTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvLTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvLCN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvLCU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvLHN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvLHU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvUNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvUNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvUTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvUTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvUCN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvUCU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvUHN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpmvUHU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvLNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvLNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvLTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvLTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvLCN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvLCU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvLHN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvLHU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvUNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvUNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvUTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvUTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvUCN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvUCU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvUHN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrmvUHU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvLNN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvLNU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvLTN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvLTU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvLCN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvLCU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvLHN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvLHU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvUNN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvUNU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvUTN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvUTU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvUCN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvUCU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvUHN +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftbsvUHU +( + const int, const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvLNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvLNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvLTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvLTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvLCN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvLCU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvLHN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvLHU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvUNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvUNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvUTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvUTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvUCN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvUCU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvUHN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftpsvUHU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvLNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvLNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvLTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvLTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvLCN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvLCU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvLHN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvLHU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvUNN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvUNU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvUTN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvUTU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvUCN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvUCU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvUHN +( + const int, + const float *, const int, + float *, const int +); + +void ATL_creftrsvUHU +( + const int, + const float *, const int, + float *, const int +); + +void ATL_zrefgbmvN +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgbmvT +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgbmvC +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgbmvH +( + const int, const int, + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgpmvUN +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgpmvUT +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgpmvUC +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgpmvUH +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgpmvLN +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgpmvLT +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgpmvLC +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgpmvLH +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemvN +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemvT +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemvC +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemvH +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgprcL +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefgprcU +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefgpruL +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefgpruU +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefhbmvL +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhbmvU +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhpmvL +( + const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhpmvU +( + const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhprL +( + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_zrefhprU +( + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_zrefhpr2L +( + const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefhpr2U +( + const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefhemvL +( + const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhemvU +( + const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefherL +( + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_zrefherU +( + const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_zrefher2L +( + const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zrefher2U +( + const int, + const double *, + const double *, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvLNN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvLNU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvLTN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvLTU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvLCN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvLCU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvLHN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvLHU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvUNN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvUNU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvUTN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvUTU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvUCN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvUCU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvUHN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbmvUHU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvLNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvLNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvLTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvLTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvLCN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvLCU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvLHN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvLHU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvUNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvUNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvUTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvUTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvUCN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvUCU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvUHN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpmvUHU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvLNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvLNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvLTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvLTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvLCN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvLCU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvLHN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvLHU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvUNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvUNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvUTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvUTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvUCN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvUCU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvUHN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrmvUHU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvLNN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvLNU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvLTN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvLTU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvLCN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvLCU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvLHN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvLHU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvUNN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvUNU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvUTN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvUTU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvUCN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvUCU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvUHN +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftbsvUHU +( + const int, const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvLNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvLNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvLTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvLTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvLCN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvLCU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvLHN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvLHU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvUNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvUNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvUTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvUTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvUCN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvUCU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvUHN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftpsvUHU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvLNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvLNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvLTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvLTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvLCN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvLCU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvLHN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvLHU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvUNN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvUNU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvUTN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvUTU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvUCN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvUCU +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvUHN +( + const int, + const double *, const int, + double *, const int +); + +void ATL_zreftrsvUHU +( + const int, + const double *, const int, + double *, const int +); + +#endif +/* + * End of atlas_reflvl2.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_reflvl3.h b/kaldi_io/src/tools/ATLAS/include/atlas_reflvl3.h new file mode 100644 index 0000000..0451ff9 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_reflvl3.h @@ -0,0 +1,2292 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATLAS_REFLVL3_H +#define ATLAS_REFLVL3_H +/* + * ===================================================================== + * Prototypes for Level 3 Reference Internal ATLAS BLAS routines + * ===================================================================== + */ +void ATL_srefgemmNN +( + const int, const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgemmNT +( + const int, const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgemmTN +( + const int, const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefgemmTT +( + const int, const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsymmLL +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsymmLU +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsymmRL +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsymmRU +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyrkLN +( + const int, const int, + const float, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyrkLT +( + const int, const int, + const float, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyrkUN +( + const int, const int, + const float, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyrkUT +( + const int, const int, + const float, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyr2kLN +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyr2kLT +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyr2kUN +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_srefsyr2kUT +( + const int, const int, + const float, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_sreftrmmLLNN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmLLNU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmLLTN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmLLTU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmLUNN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmLUNU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmLUTN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmLUTU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmRLNN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmRLNU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmRLTN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmRLTU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmRUNN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmRUNU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmRUTN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrmmRUTU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmLLNN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmLLNU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmLLTN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmLLTU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmLUNN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmLUNU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmLUTN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmLUTU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmRLNN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmRLNU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmRLTN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmRLTU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmRUNN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmRUNU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmRUTN +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_sreftrsmRUTU +( + const int, const int, + const float, + const float *, const int, + float *, const int +); + +void ATL_drefgemmNN +( + const int, const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgemmNT +( + const int, const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgemmTN +( + const int, const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefgemmTT +( + const int, const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsymmLL +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsymmLU +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsymmRL +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsymmRU +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyrkLN +( + const int, const int, + const double, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyrkLT +( + const int, const int, + const double, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyrkUN +( + const int, const int, + const double, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyrkUT +( + const int, const int, + const double, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyr2kLN +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyr2kLT +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyr2kUN +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_drefsyr2kUT +( + const int, const int, + const double, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_dreftrmmLLNN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmLLNU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmLLTN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmLLTU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmLUNN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmLUNU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmLUTN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmLUTU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmRLNN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmRLNU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmRLTN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmRLTU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmRUNN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmRUNU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmRUTN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrmmRUTU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmLLNN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmLLNU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmLLTN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmLLTU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmLUNN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmLUNU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmLUTN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmLUTU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmRLNN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmRLNU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmRLTN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmRLTU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmRUNN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmRUNU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmRUTN +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_dreftrsmRUTU +( + const int, const int, + const double, + const double *, const int, + double *, const int +); + +void ATL_crefgemmNN +( + const int, const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemmNT +( + const int, const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemmNC +( + const int, const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemmTN +( + const int, const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemmTT +( + const int, const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemmTC +( + const int, const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemmCN +( + const int, const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemmCT +( + const int, const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefgemmCC +( + const int, const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhemmLL +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhemmLU +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhemmRL +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefhemmRU +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefherkLN +( + const int, const int, + const float, + const float *, const int, + const float, + float *, const int +); + +void ATL_crefherkLC +( + const int, const int, + const float, + const float *, const int, + const float, + float *, const int +); + +void ATL_crefherkUN +( + const int, const int, + const float, + const float *, const int, + const float, + float *, const int +); + +void ATL_crefherkUC +( + const int, const int, + const float, + const float *, const int, + const float, + float *, const int +); + +void ATL_crefher2kLN +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_crefher2kLC +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_crefher2kUN +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_crefher2kUC +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float, + float *, const int +); + +void ATL_crefsymmLL +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsymmLU +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsymmRL +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsymmRU +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsyrkLN +( + const int, const int, + const float *, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsyrkLT +( + const int, const int, + const float *, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsyrkUN +( + const int, const int, + const float *, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsyrkUT +( + const int, const int, + const float *, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsyr2kLN +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsyr2kLT +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsyr2kUN +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_crefsyr2kUT +( + const int, const int, + const float *, + const float *, const int, + const float *, const int, + const float *, + float *, const int +); + +void ATL_creftrmmLLNN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLLNU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLLTN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLLTU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLLCN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLLCU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLUNN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLUNU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLUTN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLUTU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLUCN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmLUCU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRLNN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRLNU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRLTN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRLTU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRLCN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRLCU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRUNN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRUNU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRUTN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRUTU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRUCN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrmmRUCU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLLNN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLLNU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLLTN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLLTU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLLCN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLLCU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLUNN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLUNU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLUTN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLUTU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLUCN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmLUCU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRLNN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRLNU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRLTN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRLTU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRLCN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRLCU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRUNN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRUNU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRUTN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRUTU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRUCN +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_creftrsmRUCU +( + const int, const int, + const float *, + const float *, const int, + float *, const int +); + +void ATL_zrefgemmNN +( + const int, const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemmNT +( + const int, const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemmNC +( + const int, const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemmTN +( + const int, const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemmTT +( + const int, const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemmTC +( + const int, const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemmCN +( + const int, const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemmCT +( + const int, const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefgemmCC +( + const int, const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhemmLL +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhemmLU +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhemmRL +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefhemmRU +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefherkLN +( + const int, const int, + const double, + const double *, const int, + const double, + double *, const int +); + +void ATL_zrefherkLC +( + const int, const int, + const double, + const double *, const int, + const double, + double *, const int +); + +void ATL_zrefherkUN +( + const int, const int, + const double, + const double *, const int, + const double, + double *, const int +); + +void ATL_zrefherkUC +( + const int, const int, + const double, + const double *, const int, + const double, + double *, const int +); + +void ATL_zrefher2kLN +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_zrefher2kLC +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_zrefher2kUN +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_zrefher2kUC +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double, + double *, const int +); + +void ATL_zrefsymmLL +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsymmLU +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsymmRL +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsymmRU +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsyrkLN +( + const int, const int, + const double *, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsyrkLT +( + const int, const int, + const double *, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsyrkUN +( + const int, const int, + const double *, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsyrkUT +( + const int, const int, + const double *, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsyr2kLN +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsyr2kLT +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsyr2kUN +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zrefsyr2kUT +( + const int, const int, + const double *, + const double *, const int, + const double *, const int, + const double *, + double *, const int +); + +void ATL_zreftrmmLLNN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLLNU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLLTN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLLTU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLLCN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLLCU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLUNN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLUNU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLUTN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLUTU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLUCN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmLUCU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRLNN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRLNU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRLTN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRLTU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRLCN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRLCU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRUNN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRUNU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRUTN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRUTU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRUCN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrmmRUCU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLLNN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLLNU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLLTN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLLTU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLLCN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLLCU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLUNN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLUNU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLUTN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLUTU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLUCN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmLUCU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRLNN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRLNU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRLTN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRLTU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRLCN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRLCU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRUNN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRUNU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRUTN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRUTU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRUCN +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +void ATL_zreftrsmRUCU +( + const int, const int, + const double *, + const double *, const int, + double *, const int +); + +#endif +/* + * End of atlas_reflvl3.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_refmisc.h b/kaldi_io/src/tools/ATLAS/include/atlas_refmisc.h new file mode 100644 index 0000000..d8b600e --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_refmisc.h @@ -0,0 +1,367 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. + * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +#ifndef ATL_REFMISC_H +#define ATL_REFMISC_H +/* + * ===================================================================== + * Include files + * ===================================================================== + */ +#include <math.h> +#include "atlas_enum.h" +/* + * ===================================================================== + * #define macro constants + * ===================================================================== + */ +#define ATL_sNONE (-1.0f) +#define ATL_sNTWO (-2.0f) +#define ATL_sONE ( 1.0f) +#define ATL_sZERO ( 0.0f) + +#define ATL_dNONE (-1.0) +#define ATL_dNTWO (-2.0) +#define ATL_dONE ( 1.0) +#define ATL_dZERO ( 0.0) +/* + * ===================================================================== + * # macro functions + * ===================================================================== + */ +#define Msabs( a_ ) ( ( (a_) < ATL_sZERO ) ? -(a_) : (a_) ) + +#define Mszero( a_r_, a_i_ ) \ + ( ( (a_r_) == ATL_sZERO ) && ( (a_i_) == ATL_sZERO ) ) + +#define Msone( a_r_, a_i_ ) \ + ( ( (a_r_) == ATL_sONE ) && ( (a_i_) == ATL_sZERO ) ) + +#define Msscl( a_r_, a_i_, c_r_, c_i_ ) \ + { \ + register float tmp_r_, tmp_i_; \ + tmp_r_ = (a_r_) * c_r_ - (a_i_) * c_i_; \ + tmp_i_ = (a_r_) * c_i_ + (a_i_) * c_r_; \ + c_r_ = tmp_r_; \ + c_i_ = tmp_i_; \ + } +/* + * Msdiv performs complex division in real arithmetic + * a_r_ + i * a_i_ = ( a_r_ + i * a_i_ ) / ( b_r_ + i * b_i_ ); + * The algorithm is due to Robert L. Smith and can be found in D. Knuth, + * The art of Computer Programming, Vol.2, p.195 + */ +#define Msdiv( b_r_, b_i_, a_r_, a_i_ ) \ + { \ + register float c_i_, c_r_, tmp1_, tmp2_; \ + if( Msabs( b_i_ ) < Msabs( b_r_ ) ) \ + { \ + tmp1_ = (b_i_) / (b_r_); \ + tmp2_ = (b_r_) + (b_i_) * tmp1_; \ + c_r_ = ( (a_r_) + (a_i_) * tmp1_ ) / tmp2_; \ + c_i_ = ( (a_i_) - (a_r_) * tmp1_ ) / tmp2_; \ + } \ + else \ + { \ + tmp1_ = (b_r_) / (b_i_); \ + tmp2_ = (b_i_) + (b_r_) * tmp1_; \ + c_r_ = ( (a_i_) + (a_r_) * tmp1_ ) / tmp2_; \ + c_i_ = ( -(a_r_) + (a_i_) * tmp1_ ) / tmp2_; \ + } \ + a_r_ = c_r_; \ + a_i_ = c_i_; \ + } + +#define Mdabs( a_ ) ( ( (a_) < ATL_dZERO ) ? -(a_) : (a_) ) + +#define Mdzero( a_r_, a_i_ ) \ + ( ( (a_r_) == ATL_dZERO ) && ( (a_i_) == ATL_dZERO ) ) + +#define Mdone( a_r_, a_i_ ) \ + ( ( (a_r_) == ATL_dONE ) && ( (a_i_) == ATL_dZERO ) ) + +#define Mdscl( a_r_, a_i_, c_r_, c_i_ ) \ + { \ + register double tmp_r_, tmp_i_; \ + tmp_r_ = (a_r_) * c_r_ - (a_i_) * c_i_; \ + tmp_i_ = (a_r_) * c_i_ + (a_i_) * c_r_; \ + c_r_ = tmp_r_; \ + c_i_ = tmp_i_; \ + } +/* + * Mddiv performs complex division in real arithmetic + * a_r_ + i * a_i_ = ( a_r_ + i * a_i_ ) / ( b_r_ + i * b_i_ ); + * The algorithm is due to Robert L. Smith and can be found in D. Knuth, + * The art of Computer Programming, Vol.2, p.195 + */ +#define Mddiv( b_r_, b_i_, a_r_, a_i_ ) \ + { \ + register double c_i_, c_r_, tmp1_, tmp2_; \ + if( Mdabs( b_i_ ) < Mdabs( b_r_ ) ) \ + { \ + tmp1_ = (b_i_) / (b_r_); \ + tmp2_ = (b_r_) + (b_i_) * tmp1_; \ + c_r_ = ( (a_r_) + (a_i_) * tmp1_ ) / tmp2_; \ + c_i_ = ( (a_i_) - (a_r_) * tmp1_ ) / tmp2_; \ + } \ + else \ + { \ + tmp1_ = (b_r_) / (b_i_); \ + tmp2_ = (b_i_) + (b_r_) * tmp1_; \ + c_r_ = ( (a_i_) + (a_r_) * tmp1_ ) / tmp2_; \ + c_i_ = ( -(a_r_) + (a_i_) * tmp1_ ) / tmp2_; \ + } \ + a_r_ = c_r_; \ + a_i_ = c_i_; \ + } + +#define Mmin( a_, b_ ) ( ( (a_) < (b_) ) ? (a_) : (b_) ) + +#define Mmax( a_, b_ ) ( ( (a_) > (b_) ) ? (a_) : (b_) ) + +#define Mmul( a_r_, a_i_, b_r_, b_i_, c_r_, c_i_ ) \ + { \ + c_r_ = (a_r_) * (b_r_) - (a_i_) * (b_i_); \ + c_i_ = (a_r_) * (b_i_) + (a_i_) * (b_r_); \ + } + +#define Mmla( a_r_, a_i_, b_r_, b_i_, c_r_, c_i_ ) \ + { \ + c_r_ += (a_r_) * (b_r_) - (a_i_) * (b_i_); \ + c_i_ += (a_r_) * (b_i_) + (a_i_) * (b_r_); \ + } + +#define Mmls( a_r_, a_i_, b_r_, b_i_, c_r_, c_i_ ) \ + { \ + c_r_ -= (a_r_) * (b_r_) - (a_i_) * (b_i_); \ + c_i_ -= (a_r_) * (b_i_) + (a_i_) * (b_r_); \ + } + +#define Mset( a_r_, a_i_, b_r_, b_i_ ) \ + { \ + b_r_ = (a_r_); \ + b_i_ = (a_i_); \ + } + +#define Mselscal( al_, a_ ) \ + { \ + if( (al_) == ATL_sZERO ) { (a_) = ATL_sZERO; } \ + else if( (al_) != ATL_sONE ) { (a_) *= (al_); } \ + } + +#define Mdelscal( al_, a_ ) \ + { \ + if( (al_) == ATL_dZERO ) { (a_) = ATL_dZERO; } \ + else if( (al_) != ATL_dONE ) { (a_) *= (al_); } \ + } + +#define Mcelscal( al_r_, al_i_, a_r_, a_i_ ) \ + { \ + if( Mszero( (al_r_), (al_i_) ) ) \ + { (a_r_) = (a_i_) = ATL_sZERO; } \ + else if( ! Msone( (al_r_), (al_i_) ) ) \ + { Msscl( (al_r_), (al_i_), (a_r_), (a_i_) ); } \ + } + +#define Mzelscal( al_r_, al_i_, a_r_, a_i_ ) \ + { \ + if( Mdzero( (al_r_), (al_i_) ) ) \ + { (a_r_) = (a_i_) = ATL_dZERO; } \ + else if( ! Mdone( (al_r_), (al_i_) ) ) \ + { Mdscl( (al_r_), (al_i_), (a_r_), (a_i_) ); } \ + } + +#define Msvscal( n_, al_, x_, incx_ ) \ + { \ + int i_, ix_; \ + if( (al_) == ATL_sZERO ) \ + { \ + for( i_ = 0, ix_ = 0; i_ < (n_); i_++, ix_ += (incx_) ) \ + { (x_)[ix_] = ATL_sZERO; } \ + } \ + else if( (al_) != ATL_sONE ) \ + { \ + for( i_ = 0, ix_ = 0; i_ < (n_); i_++, ix_ += (incx_) ) \ + { (x_)[ix_] *= (al_); } \ + } \ + } + +#define Mdvscal( n_, al_, x_, incx_ ) \ + { \ + int i_, ix_; \ + if( (al_) == ATL_dZERO ) \ + { \ + for( i_ = 0, ix_ = 0; i_ < (n_); i_++, ix_ += (incx_) ) \ + { (x_)[ix_] = ATL_dZERO; } \ + } \ + else if( (al_) != ATL_dONE ) \ + { \ + for( i_ = 0, ix_ = 0; i_ < (n_); i_++, ix_ += (incx_) ) \ + { (x_)[ix_] *= (al_); } \ + } \ + } + +#define Mcvscal( n_, al_, x_, incx_ ) \ + { \ + int i_, ix_, incx2_ = ( 2 * (incx_) ); \ + if( Mszero( (al_)[0], (al_)[1] ) ) \ + { \ + for( i_ = 0, ix_ = 0; i_ < (n_); i_++, ix_ += (incx2_) ) \ + { (x_)[ix_] = (x_)[ix_+1] = ATL_sZERO; } \ + } \ + else if( ! Msone( (al_)[0], (al_)[1] ) ) \ + { \ + for( i_ = 0, ix_ = 0; i_ < (n_); i_++, ix_ += (incx2_) ) \ + { Msscl( (al_)[0], (al_)[1], (x_)[ix_], (x_)[ix_+1] ); } \ + } \ + } + +#define Mzvscal( n_, al_, x_, incx_ ) \ + { \ + int i_, ix_, incx2_ = ( 2 * (incx_) ); \ + if( Mdzero( (al_)[0], (al_)[1] ) ) \ + { \ + for( i_ = 0, ix_ = 0; i_ < (n_); i_++, ix_ += (incx2_) ) \ + { (x_)[ix_] = (x_)[ix_+1] = ATL_dZERO; } \ + } \ + else if( ! Mdone( (al_)[0], (al_)[1] ) ) \ + { \ + for( i_ = 0, ix_ = 0; i_ < (n_); i_++, ix_ += (incx2_) ) \ + { Mdscl( (al_)[0], (al_)[1], (x_)[ix_], (x_)[ix_+1] ); } \ + } \ + } + +#define Msgescal( m_, n_, al_, a_, lda_ ) \ + { \ + int i_, iaij_, j_, jaj_; \ + if( (al_) == ATL_sZERO ) \ + { \ + for( j_ = 0, jaj_ = 0; j_ < (n_); j_++, jaj_ += (lda_) ) \ + { \ + for( i_ = 0, iaij_ = jaj_; i_ < (m_); i_++, iaij_ += 1 ) \ + { (a_)[iaij_] = ATL_sZERO; } \ + } \ + } \ + else if( (al_) != ATL_sONE ) \ + { \ + for( j_ = 0, jaj_ = 0; j_ < (n_); j_++, jaj_ += (lda_) ) \ + { \ + for( i_ = 0, iaij_ = jaj_; i_ < (m_); i_++, iaij_ += 1 ) \ + { (a_)[iaij_] *= (al_); } \ + } \ + } \ + } + +#define Mdgescal( m_, n_, al_, a_, lda_ ) \ + { \ + int i_, iaij_, j_, jaj_; \ + if( (al_) == ATL_dZERO ) \ + { \ + for( j_ = 0, jaj_ = 0; j_ < (n_); j_++, jaj_ += (lda_) ) \ + { \ + for( i_ = 0, iaij_ = jaj_; i_ < (m_); i_++, iaij_ += 1 ) \ + { (a_)[iaij_] = ATL_dZERO; } \ + } \ + } \ + else if( (al_) != ATL_dONE ) \ + { \ + for( j_ = 0, jaj_ = 0; j_ < (n_); j_++, jaj_ += (lda_) ) \ + { \ + for( i_ = 0, iaij_ = jaj_; i_ < (m_); i_++, iaij_ += 1 ) \ + { (a_)[iaij_] *= (al_); } \ + } \ + } \ + } + +#define Mcgescal( m_, n_, al_, a_, lda_ ) \ + { \ + int i_, iaij_, j_, jaj_, lda2_ = ( (lda_) << 1 ); \ + if( Mszero( (al_)[0], (al_)[1] ) ) \ + { \ + for( j_ = 0, jaj_ = 0; j_ < (n_); j_++, jaj_ += lda2_ ) \ + { \ + for( i_ = 0, iaij_ = jaj_; i_ < (m_); i_++, iaij_ += 2 ) \ + { (a_)[iaij_] = (a_)[iaij_+1] = ATL_sZERO; } \ + } \ + } \ + else if( ! Msone( (al_)[0], (al_)[1] ) ) \ + { \ + for( j_ = 0, jaj_ = 0; j_ < (n_); j_++, jaj_ += lda2_ ) \ + { \ + for( i_ = 0, iaij_ = jaj_; i_ < (m_); i_++, iaij_ += 2 ) \ + { \ + Msscl( (al_)[0], (al_)[1], (a_)[iaij_], (a_)[iaij_+1] ); \ + } \ + } \ + } \ + } + +#define Mzgescal( m_, n_, al_, a_, lda_ ) \ + { \ + int i_, iaij_, j_, jaj_, lda2_ = ( (lda_) << 1 ); \ + if( Mdzero( (al_)[0], (al_)[1] ) ) \ + { \ + for( j_ = 0, jaj_ = 0; j_ < (n_); j_++, jaj_ += lda2_ ) \ + { \ + for( i_ = 0, iaij_ = jaj_; i_ < (m_); i_++, iaij_ += 2 ) \ + { (a_)[iaij_] = (a_)[iaij_+1] = ATL_dZERO; } \ + } \ + } \ + else if( ! Mdone( (al_)[0], (al_)[1] ) ) \ + { \ + for( j_ = 0, jaj_ = 0; j_ < (n_); j_++, jaj_ += lda2_ ) \ + { \ + for( i_ = 0, iaij_ = jaj_; i_ < (m_); i_++, iaij_ += 2 ) \ + { \ + Mdscl( (al_)[0], (al_)[1], (a_)[iaij_], (a_)[iaij_+1] ); \ + } \ + } \ + } \ + } + +#endif +/* + * End of atlas_refmisc.h + */ diff --git a/kaldi_io/src/tools/ATLAS/include/atlas_tst.h b/kaldi_io/src/tools/ATLAS/include/atlas_tst.h new file mode 100644 index 0000000..1ea5f5e --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/atlas_tst.h @@ -0,0 +1,909 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Code contributers : R. Clint Whaley, Antoine P. Petitet + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef ATLAS_TST_H + #define ATLAS_TST_H + +#include "atlas_enum.h" + +double time00(); +#ifndef UseCRand + void ATL_srand(int iseed); + int ATL_rand(void); + #define dumb_seed(iseed_) ATL_srand(iseed_) + #define dumb_rand() ( 0.5 - ((double)ATL_rand())/(2147483648.0) ) +#else + #define dumb_seed(iseed_) srand(iseed_) + #ifndef RAND_MAX /* rather dangerous non-ansi workaround */ + #define RAND_MAX ((unsigned long)(1<<30)) + #endif + #define dumb_rand() ( 0.5 - ((double)rand())/((double)RAND_MAX) ) +#endif + +void ATL_ststsqtran(const int N, float *A, const int lda); +void ATL_sgeprint + (char *mat, const int M, const int N, const float *A, const int lda); + +float ATL_sgediffnrm1 + (const int M, const int N, const float *A, const int lda, + const float *B, const int ldb); +float ATL_shediffnrm + (const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, const int N, + const float *A0, const int ld0, const float *A1, const int ld1); +float ATL_sinfnrm(const int N, const float *X, const int incX); +float ATL_sgenrm1 + (const int M, const int N, const float *A, const int lda); +float ATL_strnrm1 + (const enum ATLAS_UPLO Upper, const enum ATLAS_DIAG Diag, const int N, + const float *A, const int lda); +float ATL_sgbnrm1 + (const int M, const int N, const int KL, const int KU, + const float *A, const int lda); +float ATL_stpnrm1 + (const enum ATLAS_UPLO UPLO, const enum ATLAS_DIAG DIAG, const int N, + const float *A); +float ATL_stbnrm1 + (const enum ATLAS_UPLO UPLO, const enum ATLAS_DIAG DIAG, + const int N, const int K, const float *A, const int LDA); +float ATL_ssynrm + (const enum ATLAS_UPLO UPLO, const int N, const float *A, const int LDA); +float ATL_shenrm + (const enum ATLAS_UPLO UPLO, const int N, const float *A, const int LDA); +float ATL_sspnrm + (const enum ATLAS_UPLO UPLO, const int N, const float *A); +float ATL_shpnrm + (const enum ATLAS_UPLO UPLO, const int N, const float *A); +float ATL_ssbnrm + (const enum ATLAS_UPLO UPLO, const int N, const int K, + const float *A, const int LDA); +float ATL_shbnrm + (const enum ATLAS_UPLO UPLO, const int N, const int K, + const float *A, const int LDA); + +void ATL_sgefillgap(const int M, const int N, float *A, const int lda0); +int ATL_sgechkgap(const int M0, const int N, float *A, const int lda0); +void ATL_strgen(const enum ATLAS_UPLO Uplo, const enum ATLAS_DIAG Diag, + const int N, float *A, const int lda, const int seed); +void ATL_sgegen(const int M0, const int N, float *A, const int lda, + const int seed); +float ATL_sepsilon(void); +void ATL_svdiff(const int N, const float *X, const int incX, + const float *Y, const int incY, float *Z, const int incZ); +void ATL_sgediff(const int M, const int N, const float *A, const int lda, + const float *B, const int ldb, float *C, const int ldc); +void ATL_dtstsqtran(const int N, double *A, const int lda); +void ATL_dgeprint + (char *mat, const int M, const int N, const double *A, const int lda); + +double ATL_dgediffnrm1 + (const int M, const int N, const double *A, const int lda, + const double *B, const int ldb); +double ATL_dhediffnrm + (const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, const int N, + const double *A0, const int ld0, const double *A1, const int ld1); +double ATL_dinfnrm(const int N, const double *X, const int incX); +double ATL_dgenrm1 + (const int M, const int N, const double *A, const int lda); +double ATL_dtrnrm1 + (const enum ATLAS_UPLO Upper, const enum ATLAS_DIAG Diag, const int N, + const double *A, const int lda); +double ATL_dgbnrm1 + (const int M, const int N, const int KL, const int KU, + const double *A, const int lda); +double ATL_dtpnrm1 + (const enum ATLAS_UPLO UPLO, const enum ATLAS_DIAG DIAG, const int N, + const double *A); +double ATL_dtbnrm1 + (const enum ATLAS_UPLO UPLO, const enum ATLAS_DIAG DIAG, + const int N, const int K, const double *A, const int LDA); +double ATL_dsynrm + (const enum ATLAS_UPLO UPLO, const int N, const double *A, const int LDA); +double ATL_dhenrm + (const enum ATLAS_UPLO UPLO, const int N, const double *A, const int LDA); +double ATL_dspnrm + (const enum ATLAS_UPLO UPLO, const int N, const double *A); +double ATL_dhpnrm + (const enum ATLAS_UPLO UPLO, const int N, const double *A); +double ATL_dsbnrm + (const enum ATLAS_UPLO UPLO, const int N, const int K, + const double *A, const int LDA); +double ATL_dhbnrm + (const enum ATLAS_UPLO UPLO, const int N, const int K, + const double *A, const int LDA); + +void ATL_dgefillgap(const int M, const int N, double *A, const int lda0); +int ATL_dgechkgap(const int M0, const int N, double *A, const int lda0); +void ATL_dtrgen(const enum ATLAS_UPLO Uplo, const enum ATLAS_DIAG Diag, + const int N, double *A, const int lda, const int seed); +void ATL_dgegen(const int M0, const int N, double *A, const int lda, + const int seed); +double ATL_depsilon(void); +void ATL_dvdiff(const int N, const double *X, const int incX, + const double *Y, const int incY, double *Z, const int incZ); +void ATL_dgediff(const int M, const int N, const double *A, const int lda, + const double *B, const int ldb, double *C, const int ldc); +void ATL_ctstsqtran(const int N, float *A, const int lda); +void ATL_cgeprint + (char *mat, const int M, const int N, const float *A, const int lda); + +float ATL_cgediffnrm1 + (const int M, const int N, const float *A, const int lda, + const float *B, const int ldb); +float ATL_chediffnrm + (const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, const int N, + const float *A0, const int ld0, const float *A1, const int ld1); +float ATL_cinfnrm(const int N, const float *X, const int incX); +float ATL_cgenrm1 + (const int M, const int N, const float *A, const int lda); +float ATL_ctrnrm1 + (const enum ATLAS_UPLO Upper, const enum ATLAS_DIAG Diag, const int N, + const float *A, const int lda); +float ATL_cgbnrm1 + (const int M, const int N, const int KL, const int KU, + const float *A, const int lda); +float ATL_ctpnrm1 + (const enum ATLAS_UPLO UPLO, const enum ATLAS_DIAG DIAG, const int N, + const float *A); +float ATL_ctbnrm1 + (const enum ATLAS_UPLO UPLO, const enum ATLAS_DIAG DIAG, + const int N, const int K, const float *A, const int LDA); +float ATL_csynrm + (const enum ATLAS_UPLO UPLO, const int N, const float *A, const int LDA); +float ATL_chenrm + (const enum ATLAS_UPLO UPLO, const int N, const float *A, const int LDA); +float ATL_cspnrm + (const enum ATLAS_UPLO UPLO, const int N, const float *A); +float ATL_chpnrm + (const enum ATLAS_UPLO UPLO, const int N, const float *A); +float ATL_csbnrm + (const enum ATLAS_UPLO UPLO, const int N, const int K, + const float *A, const int LDA); +float ATL_chbnrm + (const enum ATLAS_UPLO UPLO, const int N, const int K, + const float *A, const int LDA); + +void ATL_cgefillgap(const int M, const int N, float *A, const int lda0); +int ATL_cgechkgap(const int M0, const int N, float *A, const int lda0); +void ATL_ctrgen(const enum ATLAS_UPLO Uplo, const enum ATLAS_DIAG Diag, + const int N, float *A, const int lda, const int seed); +void ATL_cgegen(const int M0, const int N, float *A, const int lda, + const int seed); +float ATL_cepsilon(void); +void ATL_cvdiff(const int N, const float *X, const int incX, + const float *Y, const int incY, float *Z, const int incZ); +void ATL_cgediff(const int M, const int N, const float *A, const int lda, + const float *B, const int ldb, float *C, const int ldc); +void ATL_ztstsqtran(const int N, double *A, const int lda); +void ATL_zgeprint + (char *mat, const int M, const int N, const double *A, const int lda); + +double ATL_zgediffnrm1 + (const int M, const int N, const double *A, const int lda, + const double *B, const int ldb); +double ATL_zhediffnrm + (const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, const int N, + const double *A0, const int ld0, const double *A1, const int ld1); +double ATL_zinfnrm(const int N, const double *X, const int incX); +double ATL_zgenrm1 + (const int M, const int N, const double *A, const int lda); +double ATL_ztrnrm1 + (const enum ATLAS_UPLO Upper, const enum ATLAS_DIAG Diag, const int N, + const double *A, const int lda); +double ATL_zgbnrm1 + (const int M, const int N, const int KL, const int KU, + const double *A, const int lda); +double ATL_ztpnrm1 + (const enum ATLAS_UPLO UPLO, const enum ATLAS_DIAG DIAG, const int N, + const double *A); +double ATL_ztbnrm1 + (const enum ATLAS_UPLO UPLO, const enum ATLAS_DIAG DIAG, + const int N, const int K, const double *A, const int LDA); +double ATL_zsynrm + (const enum ATLAS_UPLO UPLO, const int N, const double *A, const int LDA); +double ATL_zhenrm + (const enum ATLAS_UPLO UPLO, const int N, const double *A, const int LDA); +double ATL_zspnrm + (const enum ATLAS_UPLO UPLO, const int N, const double *A); +double ATL_zhpnrm + (const enum ATLAS_UPLO UPLO, const int N, const double *A); +double ATL_zsbnrm + (const enum ATLAS_UPLO UPLO, const int N, const int K, + const double *A, const int LDA); +double ATL_zhbnrm + (const enum ATLAS_UPLO UPLO, const int N, const int K, + const double *A, const int LDA); + +void ATL_zgefillgap(const int M, const int N, double *A, const int lda0); +int ATL_zgechkgap(const int M0, const int N, double *A, const int lda0); +void ATL_ztrgen(const enum ATLAS_UPLO Uplo, const enum ATLAS_DIAG Diag, + const int N, double *A, const int lda, const int seed); +void ATL_zgegen(const int M0, const int N, double *A, const int lda, + const int seed); +double ATL_zepsilon(void); +void ATL_zvdiff(const int N, const double *X, const int incX, + const double *Y, const int incY, double *Z, const int incZ); +void ATL_zgediff(const int M, const int N, const double *A, const int lda, + const double *B, const int ldb, double *C, const int ldc); + +/* + * Wrappers so that C can call F77 LAPACK + */ +int ATL_sf77getri + (const enum ATLAS_ORDER, const int, float*, const int, int*, + float*, int*); +int ATL_sf77getrf + (const enum ATLAS_ORDER, const int, const int, float*, const int, int*); +int ATL_sf77potrf(const enum ATLAS_UPLO, const int, float*, const int); +int ATL_sf77lauum(const enum ATLAS_UPLO, const int, float*, const int); +int ATL_sf77trtri(const enum ATLAS_UPLO, const enum ATLAS_DIAG, const int, + float*, const int); +int ATL_sf77posv(const enum ATLAS_UPLO, const int, const int, float*, const int, float*, const int); +int ATL_sf77gesv(const int, const int, float*, const int, int*, float*, const int); +int ATL_sf77gels(const enum ATLAS_TRANS, const int, const int, const int, float*, const int, float*, const int); +int ATL_df77getri + (const enum ATLAS_ORDER, const int, double*, const int, int*, + double*, int*); +int ATL_df77getrf + (const enum ATLAS_ORDER, const int, const int, double*, const int, int*); +int ATL_df77potrf(const enum ATLAS_UPLO, const int, double*, const int); +int ATL_df77lauum(const enum ATLAS_UPLO, const int, double*, const int); +int ATL_df77trtri(const enum ATLAS_UPLO, const enum ATLAS_DIAG, const int, + double*, const int); +int ATL_df77posv(const enum ATLAS_UPLO, const int, const int, double*, const int, double*, const int); +int ATL_df77gesv(const int, const int, double*, const int, int*, double*, const int); +int ATL_df77gels(const enum ATLAS_TRANS, const int, const int, const int, double*, const int, double*, const int); +int ATL_cf77getri + (const enum ATLAS_ORDER, const int, float*, const int, int*, + float*, int*); +int ATL_cf77getrf + (const enum ATLAS_ORDER, const int, const int, float*, const int, int*); +int ATL_cf77potrf(const enum ATLAS_UPLO, const int, float*, const int); +int ATL_cf77lauum(const enum ATLAS_UPLO, const int, float*, const int); +int ATL_cf77trtri(const enum ATLAS_UPLO, const enum ATLAS_DIAG, const int, + float*, const int); +int ATL_cf77posv(const enum ATLAS_UPLO, const int, const int, float*, const int, float*, const int); +int ATL_cf77gesv(const int, const int, float*, const int, int*, float*, const int); +int ATL_cf77gels(const enum ATLAS_TRANS, const int, const int, const int, float*, const int, float*, const int); +int ATL_zf77getri + (const enum ATLAS_ORDER, const int, double*, const int, int*, + double*, int*); +int ATL_zf77getrf + (const enum ATLAS_ORDER, const int, const int, double*, const int, int*); +int ATL_zf77potrf(const enum ATLAS_UPLO, const int, double*, const int); +int ATL_zf77lauum(const enum ATLAS_UPLO, const int, double*, const int); +int ATL_zf77trtri(const enum ATLAS_UPLO, const enum ATLAS_DIAG, const int, + double*, const int); +int ATL_zf77posv(const enum ATLAS_UPLO, const int, const int, double*, const int, double*, const int); +int ATL_zf77gesv(const int, const int, double*, const int, int*, double*, const int); +int ATL_zf77gels(const enum ATLAS_TRANS, const int, const int, const int, double*, const int, double*, const int); +/* + * ===================================================================== + * Prototypes for C-callable F77 interface to the Level 1 BLAS routines + * ===================================================================== + */ +void ATL_sf77rotg +( float *, float *, float *, float * ); +void ATL_df77rotg +( double *, double *, double *, double * ); +void ATL_cf77rotg +( float *, const float *, float *, float * ); +void ATL_zf77rotg +( double *, const double *, double *, double * ); + +void ATL_sf77rotmg +( float *, float *, float *, const float, + float * ); +void ATL_df77rotmg +( double *, double *, double *, const double, + double * ); + +float ATL_sf77nrm2 +( const int, const float *, const int ); +double ATL_df77nrm2 +( const int, const double *, const int ); +float ATL_scf77nrm2 +( const int, const float *, const int ); +double ATL_dzf77nrm2 +( const int, const double *, const int ); + +float ATL_sf77asum +( const int, const float *, const int ); +double ATL_df77asum +( const int, const double *, const int ); +float ATL_scf77asum +( const int, const float *, const int ); +double ATL_dzf77asum +( const int, const double *, const int ); + +int ATL_isf77amax +( const int, const float *, const int ); +int ATL_idf77amax +( const int, const double *, const int ); +int ATL_icf77amax +( const int, const float *, const int ); +int ATL_izf77amax +( const int, const double *, const int ); + +void ATL_sf77scal +( const int, const float, float *, const int ); +void ATL_df77scal +( const int, const double, double *, const int ); +void ATL_cf77scal +( const int, const float *, float *, const int ); +void ATL_zf77scal +( const int, const double *, double *, const int ); +void ATL_csf77scal +( const int, const float, float *, const int ); +void ATL_zdf77scal +( const int, const double, double *, const int ); + +void ATL_sf77set(const int, const float, float*, const int); +void ATL_df77set(const int, const double, double*, const int); +void ATL_cf77set(const int, const float*, float*, const int); +void ATL_zf77set(const int, const double*, double*, const int); +void ATL_sf77axpby + (const int, const float, const float*, const int, const float, + float*, const int); +void ATL_df77axpby + (const int, const double, const double*, const int, const double, + double*, const int); +void ATL_cf77axpby + (const int, const float*, const float*, const int, const float*, + float*, const int); +void ATL_zf77axpby + (const int, const double*, const double*, const int, const double*, + double*, const int); + +void ATL_sf77axpy +( const int, const float, const float *, const int, + float *, const int ); +void ATL_df77axpy +( const int, const double, const double *, const int, + double *, const int ); +void ATL_cf77axpy +( const int, const float *, const float *, const int, + float *, const int ); +void ATL_zf77axpy +( const int, const double *, const double *, const int, + double *, const int ); + +void ATL_sf77copy +( const int, const float *, const int, float *, + const int ); +void ATL_df77copy +( const int, const double *, const int, double *, + const int ); +void ATL_cf77copy +( const int, const float *, const int, float *, + const int ); +void ATL_zf77copy +( const int, const double *, const int, double *, + const int ); + +void ATL_sf77swap +( const int, float *, const int, float *, + const int ); +void ATL_df77swap +( const int, double *, const int, double *, + const int ); +void ATL_cf77swap +( const int, float *, const int, float *, + const int ); +void ATL_zf77swap +( const int, double *, const int, double *, + const int ); + +void ATL_sf77rot +( const int, float *, const int, float *, + const int, const float, const float ); +void ATL_df77rot +( const int, double *, const int, double *, + const int, const double, const double ); +void ATL_csf77rot +( const int, float *, const int, float *, + const int, const float, const float ); +void ATL_zdf77rot +( const int, double *, const int, double *, + const int, const double, const double ); + +void ATL_sf77rotm +( const int, float *, const int, float *, + const int, const float * ); +void ATL_df77rotm +( const int, double *, const int, double *, + const int, const double * ); + +float ATL_sf77dot +( const int, const float *, const int, const float *, + const int ); +double ATL_df77dot +( const int, const double *, const int, const double *, + const int ); +void ATL_cf77dotu_sub +( const int, const float *, const int, const float *, + const int, float * ); +void ATL_cf77dotc_sub +( const int, const float *, const int, const float *, + const int, float * ); +void ATL_zf77dotu_sub +( const int, const double *, const int, const double *, + const int, double * ); +void ATL_zf77dotc_sub +( const int, const double *, const int, const double *, + const int, double * ); + +float ATL_sdsf77dot +( const int, const float, const float *, const int, + const float *, const int ); +double ATL_dsf77dot +( const int, const float *, const int, const float *, + const int ); +/* + * ===================================================================== + * Prototypes for C-callable F77 interface to the Level 2 BLAS routines + * ===================================================================== + */ +void ATL_sf77gemv +( const enum ATLAS_TRANS, const int, const int, + const float, const float *, const int, const float *, + const int, const float, float *, const int ); +void ATL_df77gemv +( const enum ATLAS_TRANS, const int, const int, + const double, const double *, const int, const double *, + const int, const double, double *, const int ); +void ATL_cf77gemv +( const enum ATLAS_TRANS, const int, const int, + const float *, const float *, const int, const float *, + const int, const float *, float *, const int ); +void ATL_zf77gemv +( const enum ATLAS_TRANS, const int, const int, + const double *, const double *, const int, const double *, + const int, const double *, double *, const int ); + +void ATL_sf77gbmv +( const enum ATLAS_TRANS, const int, const int, + const int, const int, const float, const float *, + const int, const float *, const int, const float, + float *, const int ); +void ATL_df77gbmv +( const enum ATLAS_TRANS, const int, const int, + const int, const int, const double, const double *, + const int, const double *, const int, const double, + double *, const int ); +void ATL_cf77gbmv +( const enum ATLAS_TRANS, const int, const int, + const int, const int, const float *, const float *, + const int, const float *, const int, const float *, + float *, const int ); +void ATL_zf77gbmv +( const enum ATLAS_TRANS, const int, const int, + const int, const int, const double *, const double *, + const int, const double *, const int, const double *, + double *, const int ); + +void ATL_sf77trmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const float *, + const int, float *, const int ); +void ATL_df77trmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const double *, + const int, double *, const int ); +void ATL_cf77trmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const float *, + const int, float *, const int ); +void ATL_zf77trmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const double *, + const int, double *, const int ); + +void ATL_sf77tbmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const int, + const float *, const int, float *, const int ); +void ATL_df77tbmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const int, + const double *, const int, double *, const int ); +void ATL_cf77tbmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const int, + const float *, const int, float *, const int ); +void ATL_zf77tbmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const int, + const double *, const int, double *, const int ); + +void ATL_sf77tpmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const float *, + float *, const int ); +void ATL_df77tpmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const double *, + double *, const int ); +void ATL_cf77tpmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const float *, + float *, const int ); +void ATL_zf77tpmv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const double *, + double *, const int ); + +void ATL_sf77trsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const float *, + const int, float *, const int ); +void ATL_df77trsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const double *, + const int, double *, const int ); +void ATL_cf77trsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const float *, + const int, float *, const int ); +void ATL_zf77trsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const double *, + const int, double *, const int ); + +void ATL_sf77tbsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const int, + const float *, const int, float *, const int ); +void ATL_df77tbsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const int, + const double *, const int, double *, const int ); +void ATL_cf77tbsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const int, + const float *, const int, float *, const int ); +void ATL_zf77tbsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const int, + const double *, const int, double *, const int ); + +void ATL_sf77tpsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const float *, + float *, const int ); +void ATL_df77tpsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const double *, + double *, const int ); +void ATL_cf77tpsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const float *, + float *, const int ); +void ATL_zf77tpsv +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const enum ATLAS_DIAG, const int, const double *, + double *, const int ); + +void ATL_sf77symv +( const enum ATLAS_UPLO, const int, const float, + const float *, const int, const float *, const int, + const float, float *, const int ); +void ATL_df77symv +( const enum ATLAS_UPLO, const int, const double, + const double *, const int, const double *, const int, + const double, double *, const int ); + +void ATL_cf77hemv +( const enum ATLAS_UPLO, const int, const float *, + const float *, const int, const float *, const int, + const float *, float *, const int ); +void ATL_zf77hemv +( const enum ATLAS_UPLO, const int, const double *, + const double *, const int, const double *, const int, + const double *, double *, const int ); + +void ATL_sf77sbmv +( const enum ATLAS_UPLO, const int, const int, + const float, const float *, const int, const float *, + const int, const float, float *, const int ); +void ATL_df77sbmv +( const enum ATLAS_UPLO, const int, const int, + const double, const double *, const int, const double *, + const int, const double, double *, const int ); +void ATL_cf77hbmv +( const enum ATLAS_UPLO, const int, const int, + const float *, const float *, const int, const float *, + const int, const float *, float *, const int ); +void ATL_zf77hbmv +( const enum ATLAS_UPLO, const int, const int, + const double *, const double *, const int, const double *, + const int, const double *, double *, const int ); + +void ATL_sf77spmv +( const enum ATLAS_UPLO, const int, const float, + const float *, const float *, const int, const float, + float *, const int ); +void ATL_df77spmv +( const enum ATLAS_UPLO, const int, const double, + const double *, const double *, const int, const double, + double *, const int ); +void ATL_cf77hpmv +( const enum ATLAS_UPLO, const int, const float *, + const float *, const float *, const int, const float *, + float *, const int ); +void ATL_zf77hpmv +( const enum ATLAS_UPLO, const int, const double *, + const double *, const double *, const int, const double *, + double *, const int ); + +void ATL_sf77ger +( const int, const int, const float, const float *, + const int, const float *, const int, float *, + const int ); +void ATL_df77ger +( const int, const int, const double, const double *, + const int, const double *, const int, double *, + const int ); +void ATL_cf77gerc +( const int, const int, const float *, const float *, + const int, const float *, const int, float *, + const int ); +void ATL_cf77geru +( const int, const int, const float *, const float *, + const int, const float *, const int, float *, + const int ); +void ATL_zf77gerc +( const int, const int, const double *, const double *, + const int, const double *, const int, double *, + const int ); +void ATL_zf77geru +( const int, const int, const double *, const double *, + const int, const double *, const int, double *, + const int ); + +void ATL_sf77syr +( const enum ATLAS_UPLO, const int, const float, + const float *, const int, float *, const int ); +void ATL_df77syr +( const enum ATLAS_UPLO, const int, const double, + const double *, const int, double *, const int ); +void ATL_cf77her +( const enum ATLAS_UPLO, const int, const float, + const float *, const int, float *, const int ); +void ATL_zf77her +( const enum ATLAS_UPLO, const int, const double, + const double *, const int, double *, const int ); + +void ATL_sf77spr +( const enum ATLAS_UPLO, const int, const float, + const float *, const int, float * ); +void ATL_df77spr +( const enum ATLAS_UPLO, const int, const double, + const double *, const int, double * ); +void ATL_cf77hpr +( const enum ATLAS_UPLO, const int, const float, + const float *, const int, float * ); +void ATL_zf77hpr +( const enum ATLAS_UPLO, const int, const double, + const double *, const int, double * ); + +void ATL_sf77syr2 +( const enum ATLAS_UPLO, const int, const float, + const float *, const int, const float *, const int, + float *, const int ); +void ATL_df77syr2 +( const enum ATLAS_UPLO, const int, const double, + const double *, const int, const double *, const int, + double *, const int ); +void ATL_cf77her2 +( const enum ATLAS_UPLO, const int, const float *, + const float *, const int, const float *, const int, + float *, const int ); +void ATL_zf77her2 +( const enum ATLAS_UPLO, const int, const double *, + const double *, const int, const double *, const int, + double *, const int ); + +void ATL_sf77spr2 +( const enum ATLAS_UPLO, const int, const float, + const float *, const int, const float *, const int, + float * ); +void ATL_df77spr2 +( const enum ATLAS_UPLO, const int, const double, + const double *, const int, const double *, const int, + double * ); +void ATL_cf77hpr2 +( const enum ATLAS_UPLO, const int, const float *, + const float *, const int, const float *, const int, + float * ); +void ATL_zf77hpr2 +( const enum ATLAS_UPLO, const int, const double *, + const double *, const int, const double *, const int, + double * ); +/* + * ===================================================================== + * Prototypes for C-callable F77 interface to the Level 3 BLAS routines + * ===================================================================== + */ +void ATL_sf77gemm +( const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const float, + const float *, const int, const float *, const int, + const float, float *, const int ); +void ATL_df77gemm +( const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const double, + const double *, const int, const double *, const int, + const double, double *, const int ); +void ATL_cf77gemm +( const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const float *, + const float *, const int, const float *, const int, + const float *, float *, const int ); +void ATL_zf77gemm +( const enum ATLAS_TRANS, const enum ATLAS_TRANS, + const int, const int, const int, const double *, + const double *, const int, const double *, const int, + const double *, double *, const int ); + +void ATL_cf77hemm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const float *, const float *, + const int, const float *, const int, const float *, + float *, const int ); +void ATL_zf77hemm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const double *, const double *, + const int, const double *, const int, const double *, + double *, const int ); + +void ATL_cf77herk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float, const float *, + const int, const float, float *, const int ); +void ATL_zf77herk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double, const double *, + const int, const double, double *, const int ); + +void ATL_cf77her2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float *, const float *, + const int, const float *, const int, const float, + float *, const int ); +void ATL_zf77her2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double *, const double *, + const int, const double *, const int, const double, + double *, const int ); + +void ATL_sf77symm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const float, const float *, + const int, const float *, const int, const float, + float *, const int ); +void ATL_df77symm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const double, const double *, + const int, const double *, const int, const double, + double *, const int ); +void ATL_cf77symm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const float *, const float *, + const int, const float *, const int, const float *, + float *, const int ); +void ATL_zf77symm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const int, const int, const double *, const double *, + const int, const double *, const int, const double *, + double *, const int ); + +void ATL_sf77syrk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float, const float *, + const int, const float, float *, const int ); +void ATL_df77syrk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double, const double *, + const int, const double, double *, const int ); +void ATL_cf77syrk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float *, const float *, + const int, const float *, float *, const int ); +void ATL_zf77syrk +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double *, const double *, + const int, const double *, double *, const int ); + +void ATL_sf77syr2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float, const float *, + const int, const float *, const int, const float, + float *, const int ); +void ATL_df77syr2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double, const double *, + const int, const double *, const int, const double, + double *, const int ); +void ATL_cf77syr2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const float *, const float *, + const int, const float *, const int, const float *, + float *, const int ); +void ATL_zf77syr2k +( const enum ATLAS_UPLO, const enum ATLAS_TRANS, + const int, const int, const double *, const double *, + const int, const double *, const int, const double *, + double *, const int ); + +void ATL_sf77trmm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const float, const float *, + const int, float *, const int ); +void ATL_df77trmm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const double, const double *, + const int, double *, const int ); +void ATL_cf77trmm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const float *, const float *, + const int, float *, const int ); +void ATL_zf77trmm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const double *, const double *, + const int, double *, const int ); + +void ATL_sf77trsm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const float, const float *, + const int, float *, const int ); +void ATL_df77trsm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const double, const double *, + const int, double *, const int ); +void ATL_cf77trsm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const float *, const float *, + const int, float *, const int ); +void ATL_zf77trsm +( const enum ATLAS_SIDE, const enum ATLAS_UPLO, + const enum ATLAS_TRANS, const enum ATLAS_DIAG, + const int, const int, const double *, const double *, + const int, double *, const int ); + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/cblas.h b/kaldi_io/src/tools/ATLAS/include/cblas.h new file mode 100644 index 0000000..4087ffb --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/cblas.h @@ -0,0 +1,596 @@ +#ifndef CBLAS_H + +#ifndef CBLAS_ENUM_DEFINED_H + #define CBLAS_ENUM_DEFINED_H + enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 }; + enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113, + AtlasConj=114}; + enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; + enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; + enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; +#endif + +#ifndef CBLAS_ENUM_ONLY +#define CBLAS_H +#define CBLAS_INDEX int + +int cblas_errprn(int ierr, int info, char *form, ...); + +/* + * =========================================================================== + * Prototypes for level 1 BLAS functions (complex are recast as routines) + * =========================================================================== + */ +float cblas_sdsdot(const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY); +double cblas_dsdot(const int N, const float *X, const int incX, const float *Y, + const int incY); +float cblas_sdot(const int N, const float *X, const int incX, + const float *Y, const int incY); +double cblas_ddot(const int N, const double *X, const int incX, + const double *Y, const int incY); +/* + * Functions having prefixes Z and C only + */ +void cblas_cdotu_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu); +void cblas_cdotc_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc); + +void cblas_zdotu_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu); +void cblas_zdotc_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc); + + +/* + * Functions having prefixes S D SC DZ + */ +float cblas_snrm2(const int N, const float *X, const int incX); +float cblas_sasum(const int N, const float *X, const int incX); + +double cblas_dnrm2(const int N, const double *X, const int incX); +double cblas_dasum(const int N, const double *X, const int incX); + +float cblas_scnrm2(const int N, const void *X, const int incX); +float cblas_scasum(const int N, const void *X, const int incX); + +double cblas_dznrm2(const int N, const void *X, const int incX); +double cblas_dzasum(const int N, const void *X, const int incX); + + +/* + * Functions having standard 4 prefixes (S D C Z) + */ +CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX); +CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX); +CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX); +CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX); + +/* + * =========================================================================== + * Prototypes for level 1 BLAS routines + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (s, d, c, z) + */ +void cblas_sswap(const int N, float *X, const int incX, + float *Y, const int incY); +void cblas_scopy(const int N, const float *X, const int incX, + float *Y, const int incY); +void cblas_saxpy(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY); +void catlas_saxpby(const int N, const float alpha, const float *X, + const int incX, const float beta, float *Y, const int incY); +void catlas_sset + (const int N, const float alpha, float *X, const int incX); + +void cblas_dswap(const int N, double *X, const int incX, + double *Y, const int incY); +void cblas_dcopy(const int N, const double *X, const int incX, + double *Y, const int incY); +void cblas_daxpy(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY); +void catlas_daxpby(const int N, const double alpha, const double *X, + const int incX, const double beta, double *Y, const int incY); +void catlas_dset + (const int N, const double alpha, double *X, const int incX); + +void cblas_cswap(const int N, void *X, const int incX, + void *Y, const int incY); +void cblas_ccopy(const int N, const void *X, const int incX, + void *Y, const int incY); +void cblas_caxpy(const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY); +void catlas_caxpby(const int N, const void *alpha, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void catlas_cset + (const int N, const void *alpha, void *X, const int incX); + +void cblas_zswap(const int N, void *X, const int incX, + void *Y, const int incY); +void cblas_zcopy(const int N, const void *X, const int incX, + void *Y, const int incY); +void cblas_zaxpy(const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY); +void catlas_zaxpby(const int N, const void *alpha, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void catlas_zset + (const int N, const void *alpha, void *X, const int incX); + + +/* + * Routines with S and D prefix only + */ +void cblas_srotg(float *a, float *b, float *c, float *s); +void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); +void cblas_srot(const int N, float *X, const int incX, + float *Y, const int incY, const float c, const float s); +void cblas_srotm(const int N, float *X, const int incX, + float *Y, const int incY, const float *P); + +void cblas_drotg(double *a, double *b, double *c, double *s); +void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P); +void cblas_drot(const int N, double *X, const int incX, + double *Y, const int incY, const double c, const double s); +void cblas_drotm(const int N, double *X, const int incX, + double *Y, const int incY, const double *P); + + +/* + * Routines with S D C Z CS and ZD prefixes + */ +void cblas_sscal(const int N, const float alpha, float *X, const int incX); +void cblas_dscal(const int N, const double alpha, double *X, const int incX); +void cblas_cscal(const int N, const void *alpha, void *X, const int incX); +void cblas_zscal(const int N, const void *alpha, void *X, const int incX); +void cblas_csscal(const int N, const float alpha, void *X, const int incX); +void cblas_zdscal(const int N, const double alpha, void *X, const int incX); + +/* + * Extra reference routines provided by ATLAS, but not mandated by the standard + */ +void cblas_crotg(void *a, void *b, void *c, void *s); +void cblas_zrotg(void *a, void *b, void *c, void *s); +void cblas_csrot(const int N, void *X, const int incX, void *Y, const int incY, + const float c, const float s); +void cblas_zdrot(const int N, void *X, const int incX, void *Y, const int incY, + const double c, const double s); + +/* + * =========================================================================== + * Prototypes for level 2 BLAS + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void cblas_sgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY); +void cblas_sgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const float alpha, + const float *A, const int lda, const float *X, + const int incX, const float beta, float *Y, const int incY); +void cblas_strmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *A, const int lda, + float *X, const int incX); +void cblas_stbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const float *A, const int lda, + float *X, const int incX); +void cblas_stpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *Ap, float *X, const int incX); +void cblas_strsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *A, const int lda, float *X, + const int incX); +void cblas_stbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const float *A, const int lda, + float *X, const int incX); +void cblas_stpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *Ap, float *X, const int incX); + +void cblas_dgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY); +void cblas_dgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const double alpha, + const double *A, const int lda, const double *X, + const int incX, const double beta, double *Y, const int incY); +void cblas_dtrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *A, const int lda, + double *X, const int incX); +void cblas_dtbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const double *A, const int lda, + double *X, const int incX); +void cblas_dtpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *Ap, double *X, const int incX); +void cblas_dtrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *A, const int lda, double *X, + const int incX); +void cblas_dtbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const double *A, const int lda, + double *X, const int incX); +void cblas_dtpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *Ap, double *X, const int incX); + +void cblas_cgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY); +void cblas_cgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const void *alpha, + const void *A, const int lda, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void cblas_ctrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, + void *X, const int incX); +void cblas_ctbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ctpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); +void cblas_ctrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, void *X, + const int incX); +void cblas_ctbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ctpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); + +void cblas_zgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY); +void cblas_zgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const void *alpha, + const void *A, const int lda, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void cblas_ztrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, + void *X, const int incX); +void cblas_ztbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ztpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); +void cblas_ztrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, void *X, + const int incX); +void cblas_ztbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ztpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); + + +/* + * Routines with S and D prefixes only + */ +void cblas_ssymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_ssbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const float alpha, const float *A, + const int lda, const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_sspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *Ap, + const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_sger(const enum CBLAS_ORDER Order, const int M, const int N, + const float alpha, const float *X, const int incX, + const float *Y, const int incY, float *A, const int lda); +void cblas_ssyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, float *A, const int lda); +void cblas_sspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, float *Ap); +void cblas_ssyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A, + const int lda); +void cblas_sspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A); + +void cblas_dsymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dsbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const double alpha, const double *A, + const int lda, const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *Ap, + const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dger(const enum CBLAS_ORDER Order, const int M, const int N, + const double alpha, const double *X, const int incX, + const double *Y, const int incY, double *A, const int lda); +void cblas_dsyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, double *A, const int lda); +void cblas_dspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, double *Ap); +void cblas_dsyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A, + const int lda); +void cblas_dspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A); + + +/* + * Routines with C and Z prefixes only + */ +void cblas_chemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_chbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_chpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *Ap, + const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_cgeru(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_cgerc(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_cher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const void *X, const int incX, + void *A, const int lda); +void cblas_chpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const void *X, + const int incX, void *A); +void cblas_cher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_chpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *Ap); + +void cblas_zhemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zhbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zhpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *Ap, + const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zgeru(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zgerc(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const void *X, const int incX, + void *A, const int lda); +void cblas_zhpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const void *X, + const int incX, void *A); +void cblas_zher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zhpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *Ap); + +/* + * =========================================================================== + * Prototypes for level 3 BLAS + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const float alpha, const float *A, + const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc); +void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, + float *C, const int ldc); +void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const float *A, const int lda, + const float beta, float *C, const int ldc); +void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, + float *C, const int ldc); +void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb); +void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb); + +void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const double alpha, const double *A, + const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc); +void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *B, const int ldb, const double beta, + double *C, const int ldc); +void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const double *A, const int lda, + const double beta, double *C, const int ldc); +void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const double *A, const int lda, + const double *B, const int ldb, const double beta, + double *C, const int ldc); +void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb); +void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb); + +void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc); +void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *beta, void *C, const int ldc); +void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); +void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); + +void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc); +void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *beta, void *C, const int ldc); +void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); +void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); + + +/* + * Routines with prefixes C and Z only + */ +void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const void *A, const int lda, + const float beta, void *C, const int ldc); +void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const float beta, + void *C, const int ldc); +void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const void *A, const int lda, + const double beta, void *C, const int ldc); +void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const double beta, + void *C, const int ldc); + +int cblas_errprn(int ierr, int info, char *form, ...); + +#endif /* end #ifdef CBLAS_ENUM_ONLY */ +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/cblas_test.h b/kaldi_io/src/tools/ATLAS/include/cblas_test.h new file mode 100644 index 0000000..b871a47 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/cblas_test.h @@ -0,0 +1,542 @@ +/* + * Added by R. Clint Whaley to make compatible with ATLAS + */ +#if defined(Add_) || defined(Add__) + #define ADD_ +#elif defined(NoChange) + #define NOCHANGE +#elif defined(UpCase) + #define UPCASE +#endif + +#ifdef ADD_ + #define F77_crotg crotgtest_ + #define F77_zrotg zrotgtest_ + #define F77_csrot csrottest_ + #define F77_zdrot zdrottest_ +#elif defined NOCHANGE + #define F77_crotg crotgtest + #define F77_zrotg zrotgtest + #define F77_csrot csrottest + #define F77_zdrot zdrottest +#elif defined UPCASE + #define F77_crotg CROTGTEST + #define F77_zrotg ZROTGTEST + #define F77_csrot CSROTTEST + #define F77_zdrot ZDROTTEST +#endif + + +/* + * cblas_test.h + * Written by Keita Teranishi + */ +#ifndef CBLAS_TEST_H +#define CBLAS_TEST_H +#include "cblas.h" + +#define TRUE 1 +#define PASSED 1 +#define TEST_ROW_MJR 1 + +#define FALSE 0 +#define FAILED 0 +#define TEST_COL_MJR 0 + +#define INVALID -1 +#define UNDEFINED -1 + +typedef struct { float real; float imag; } CBLAS_TEST_COMPLEX; +typedef struct { double real; double imag; } CBLAS_TEST_ZOMPLEX; + +#if defined(ADD_) + #define F77_xerbla xerbla_ +/* + * Level 1 BLAS + */ + #define F77_srotg srotgtest_ + #define F77_srotmg srotmgtest_ + #define F77_srot srottest_ + #define F77_srotm srotmtest_ + #define F77_drotg drotgtest_ + #define F77_drotmg drotmgtest_ + #define F77_drot drottest_ + #define F77_drotm drotmtest_ + #define F77_sswap sswaptest_ + #define F77_scopy scopytest_ + #define F77_saxpy saxpytest_ + #define F77_isamax isamaxtest_ + #define F77_dswap dswaptest_ + #define F77_dcopy dcopytest_ + #define F77_daxpy daxpytest_ + #define F77_idamax idamaxtest_ + #define F77_cswap cswaptest_ + #define F77_ccopy ccopytest_ + #define F77_caxpy caxpytest_ + #define F77_icamax icamaxtest_ + #define F77_zswap zswaptest_ + #define F77_zcopy zcopytest_ + #define F77_zaxpy zaxpytest_ + #define F77_izamax izamaxtest_ + #define F77_sdot sdottestsub_ + #define F77_ddot ddottestsub_ + #define F77_dsdot dsdottest_ + #define F77_sscal sscaltest_ + #define F77_dscal dscaltest_ + #define F77_cscal cscaltest_ + #define F77_zscal zscaltest_ + #define F77_csscal csscaltest_ + #define F77_zdscal zdscaltest_ + #define F77_cdotu cdotutest_ + #define F77_cdotc cdotctest_ + #define F77_zdotu zdotutest_ + #define F77_zdotc zdotctest_ + #define F77_snrm2 snrm2testsub_ + #define F77_sasum sasumtestsub_ + #define F77_dnrm2 dnrm2testsub_ + #define F77_dasum dasumtestsub_ + #define F77_scnrm2 scnrm2testsub_ + #define F77_scasum scasumtestsub_ + #define F77_dznrm2 dznrm2testsub_ + #define F77_dzasum dzasumtestsub_ + #define F77_sdsdot sdsdottest_ +/* + * Level 2 BLAS + */ + #define F77_s2chke cs2chke_ + #define F77_d2chke cd2chke_ + #define F77_c2chke cc2chke_ + #define F77_z2chke cz2chke_ + #define F77_ssymv cssymv_ + #define F77_ssbmv cssbmv_ + #define F77_sspmv csspmv_ + #define F77_sger csger_ + #define F77_ssyr cssyr_ + #define F77_sspr csspr_ + #define F77_ssyr2 cssyr2_ + #define F77_sspr2 csspr2_ + #define F77_dsymv cdsymv_ + #define F77_dsbmv cdsbmv_ + #define F77_dspmv cdspmv_ + #define F77_dger cdger_ + #define F77_dsyr cdsyr_ + #define F77_dspr cdspr_ + #define F77_dsyr2 cdsyr2_ + #define F77_dspr2 cdspr2_ + #define F77_chemv cchemv_ + #define F77_chbmv cchbmv_ + #define F77_chpmv cchpmv_ + #define F77_cgeru ccgeru_ + #define F77_cgerc ccgerc_ + #define F77_cher ccher_ + #define F77_chpr cchpr_ + #define F77_cher2 ccher2_ + #define F77_chpr2 cchpr2_ + #define F77_zhemv czhemv_ + #define F77_zhbmv czhbmv_ + #define F77_zhpmv czhpmv_ + #define F77_zgeru czgeru_ + #define F77_zgerc czgerc_ + #define F77_zher czher_ + #define F77_zhpr czhpr_ + #define F77_zher2 czher2_ + #define F77_zhpr2 czhpr2_ + #define F77_sgemv csgemv_ + #define F77_sgbmv csgbmv_ + #define F77_strmv cstrmv_ + #define F77_stbmv cstbmv_ + #define F77_stpmv cstpmv_ + #define F77_strsv cstrsv_ + #define F77_stbsv cstbsv_ + #define F77_stpsv cstpsv_ + #define F77_dgemv cdgemv_ + #define F77_dgbmv cdgbmv_ + #define F77_dtrmv cdtrmv_ + #define F77_dtbmv cdtbmv_ + #define F77_dtpmv cdtpmv_ + #define F77_dtrsv cdtrsv_ + #define F77_dtbsv cdtbsv_ + #define F77_dtpsv cdtpsv_ + #define F77_cgemv ccgemv_ + #define F77_cgbmv ccgbmv_ + #define F77_ctrmv cctrmv_ + #define F77_ctbmv cctbmv_ + #define F77_ctpmv cctpmv_ + #define F77_ctrsv cctrsv_ + #define F77_ctbsv cctbsv_ + #define F77_ctpsv cctpsv_ + #define F77_zgemv czgemv_ + #define F77_zgbmv czgbmv_ + #define F77_ztrmv cztrmv_ + #define F77_ztbmv cztbmv_ + #define F77_ztpmv cztpmv_ + #define F77_ztrsv cztrsv_ + #define F77_ztbsv cztbsv_ + #define F77_ztpsv cztpsv_ +/* + * Level 3 BLAS + */ + #define F77_s3chke cs3chke_ + #define F77_d3chke cd3chke_ + #define F77_c3chke cc3chke_ + #define F77_z3chke cz3chke_ + #define F77_chemm cchemm_ + #define F77_cherk ccherk_ + #define F77_cher2k ccher2k_ + #define F77_zhemm czhemm_ + #define F77_zherk czherk_ + #define F77_zher2k czher2k_ + #define F77_sgemm csgemm_ + #define F77_ssymm cssymm_ + #define F77_ssyrk cssyrk_ + #define F77_ssyr2k cssyr2k_ + #define F77_strmm cstrmm_ + #define F77_strsm cstrsm_ + #define F77_dgemm cdgemm_ + #define F77_dsymm cdsymm_ + #define F77_dsyrk cdsyrk_ + #define F77_dsyr2k cdsyr2k_ + #define F77_dtrmm cdtrmm_ + #define F77_dtrsm cdtrsm_ + #define F77_cgemm ccgemm_ + #define F77_csymm ccsymm_ + #define F77_csyrk ccsyrk_ + #define F77_csyr2k ccsyr2k_ + #define F77_ctrmm cctrmm_ + #define F77_ctrsm cctrsm_ + #define F77_zgemm czgemm_ + #define F77_zsymm czsymm_ + #define F77_zsyrk czsyrk_ + #define F77_zsyr2k czsyr2k_ + #define F77_ztrmm cztrmm_ + #define F77_ztrsm cztrsm_ +#elif defined(UPCASE) + #define F77_xerbla XERBLA +/* + * Level 1 BLAS + */ + #define F77_srotg SROTGTEST + #define F77_srotmg SROTMGTEST + #define F77_srot SROTTEST + #define F77_srotm SROTMTEST + #define F77_drotg DROTGTEST + #define F77_drotmg DROTMGTEST + #define F77_drot DROTTEST + #define F77_drotm DROTMTEST + #define F77_sswap SSWAPTEST + #define F77_scopy SCOPYTEST + #define F77_saxpy SAXPYTEST + #define F77_isamax ISAMAXTEST + #define F77_dswap DSWAPTEST + #define F77_dcopy DCOPYTEST + #define F77_daxpy DAXPYTEST + #define F77_idamax IDAMAXTEST + #define F77_cswap CSWAPTEST + #define F77_ccopy CCOPYTEST + #define F77_caxpy CAXPYTEST + #define F77_icamax ICAMAXTEST + #define F77_zswap ZSWAPTEST + #define F77_zcopy ZCOPYTEST + #define F77_zaxpy ZAXPYTEST + #define F77_izamax IZAMAXTEST + #define F77_sdot SDOTTESTSUB + #define F77_ddot DDOTTESTSUB + #define F77_dsdot DSDOTTEST + #define F77_sscal SSCALTEST + #define F77_dscal DSCALTEST + #define F77_cscal CSCALTEST + #define F77_zscal ZSCALTEST + #define F77_csscal CSSCALTEST + #define F77_zdscal ZDSCALTEST + #define F77_cdotu CDOTUTEST + #define F77_cdotc CDOTCTEST + #define F77_zdotu ZDOTUTEST + #define F77_zdotc ZDOTCTEST + #define F77_snrm2 SNRM2TESTSUB + #define F77_sasum SASUMTESTSUB + #define F77_dnrm2 DNRM2TESTSUB + #define F77_dasum DASUMTESTSUB + #define F77_scnrm2 SCNRM2TESTSUB + #define F77_scasum SCASUMTESTSUB + #define F77_dznrm2 DZNRM2TESTSUB + #define F77_dzasum DZASUMTESTSUB + #define F77_sdsdot SDSDOTTEST +/* + * Level 2 BLAS + */ + #define F77_s2chke CS2CHKE + #define F77_d2chke CD2CHKE + #define F77_c2chke CC2CHKE + #define F77_z2chke CZ2CHKE + #define F77_ssymv CSSYMV + #define F77_ssbmv CSSBMV + #define F77_sspmv CSSPMV + #define F77_sger CSGER + #define F77_ssyr CSSYR + #define F77_sspr CSSPR + #define F77_ssyr2 CSSYR2 + #define F77_sspr2 CSSPR2 + #define F77_dsymv CDSYMV + #define F77_dsbmv CDSBMV + #define F77_dspmv CDSPMV + #define F77_dger CDGER + #define F77_dsyr CDSYR + #define F77_dspr CDSPR + #define F77_dsyr2 CDSYR2 + #define F77_dspr2 CDSPR2 + #define F77_chemv CCHEMV + #define F77_chbmv CCHBMV + #define F77_chpmv CCHPMV + #define F77_cgeru CCGERU + #define F77_cgerc CCGERC + #define F77_cher CCHER + #define F77_chpr CCHPR + #define F77_cher2 CCHER2 + #define F77_chpr2 CCHPR2 + #define F77_zhemv CZHEMV + #define F77_zhbmv CZHBMV + #define F77_zhpmv CZHPMV + #define F77_zgeru CZGERU + #define F77_zgerc CZGERC + #define F77_zher CZHER + #define F77_zhpr CZHPR + #define F77_zher2 CZHER2 + #define F77_zhpr2 CZHPR2 + #define F77_sgemv CSGEMV + #define F77_sgbmv CSGBMV + #define F77_strmv CSTRMV + #define F77_stbmv CSTBMV + #define F77_stpmv CSTPMV + #define F77_strsv CSTRSV + #define F77_stbsv CSTBSV + #define F77_stpsv CSTPSV + #define F77_dgemv CDGEMV + #define F77_dgbmv CDGBMV + #define F77_dtrmv CDTRMV + #define F77_dtbmv CDTBMV + #define F77_dtpmv CDTPMV + #define F77_dtrsv CDTRSV + #define F77_dtbsv CDTBSV + #define F77_dtpsv CDTPSV + #define F77_cgemv CCGEMV + #define F77_cgbmv CCGBMV + #define F77_ctrmv CCTRMV + #define F77_ctbmv CCTBMV + #define F77_ctpmv CCTPMV + #define F77_ctrsv CCTRSV + #define F77_ctbsv CCTBSV + #define F77_ctpsv CCTPSV + #define F77_zgemv CZGEMV + #define F77_zgbmv CZGBMV + #define F77_ztrmv CZTRMV + #define F77_ztbmv CZTBMV + #define F77_ztpmv CZTPMV + #define F77_ztrsv CZTRSV + #define F77_ztbsv CZTBSV + #define F77_ztpsv CZTPSV +/* + * Level 3 BLAS + */ + #define F77_s3chke CS3CHKE + #define F77_d3chke CD3CHKE + #define F77_c3chke CC3CHKE + #define F77_z3chke CZ3CHKE + #define F77_chemm CCHEMM + #define F77_cherk CCHERK + #define F77_cher2k CCHER2K + #define F77_zhemm CZHEMM + #define F77_zherk CZHERK + #define F77_zher2k CZHER2K + #define F77_sgemm CSGEMM + #define F77_ssymm CSSYMM + #define F77_ssyrk CSSYRK + #define F77_ssyr2k CSSYR2K + #define F77_strmm CSTRMM + #define F77_strsm CSTRSM + #define F77_dgemm CDGEMM + #define F77_dsymm CDSYMM + #define F77_dsyrk CDSYRK + #define F77_dsyr2k CDSYR2K + #define F77_dtrmm CDTRMM + #define F77_dtrsm CDTRSM + #define F77_cgemm CCGEMM + #define F77_csymm CCSYMM + #define F77_csyrk CCSYRK + #define F77_csyr2k CCSYR2K + #define F77_ctrmm CCTRMM + #define F77_ctrsm CCTRSM + #define F77_zgemm CZGEMM + #define F77_zsymm CZSYMM + #define F77_zsyrk CZSYRK + #define F77_zsyr2k CZSYR2K + #define F77_ztrmm CZTRMM + #define F77_ztrsm CZTRSM +#elif defined(NOCHANGE) + #define F77_xerbla xerbla +/* + * Level 1 BLAS + */ + #define F77_srotg srotgtest + #define F77_srotmg srotmgtest + #define F77_srot srottest + #define F77_srotm srotmtest + #define F77_drotg drotgtest + #define F77_drotmg drotmgtest + #define F77_drot drottest + #define F77_drotm drotmtest + #define F77_sswap sswaptest + #define F77_scopy scopytest + #define F77_saxpy saxpytest + #define F77_isamax isamaxtest + #define F77_dswap dswaptest + #define F77_dcopy dcopytest + #define F77_daxpy daxpytest + #define F77_idamax idamaxtest + #define F77_cswap cswaptest + #define F77_ccopy ccopytest + #define F77_caxpy caxpytest + #define F77_icamax icamaxtest + #define F77_zswap zswaptest + #define F77_zcopy zcopytest + #define F77_zaxpy zaxpytest + #define F77_izamax izamaxtest + #define F77_sdot sdottestsub + #define F77_ddot ddottestsub + #define F77_dsdot dsdottest + #define F77_sscal sscaltest + #define F77_dscal dscaltest + #define F77_cscal cscaltest + #define F77_zscal zscaltest + #define F77_csscal csscaltest + #define F77_zdscal zdscaltest + #define F77_cdotu cdotutest + #define F77_cdotc cdotctest + #define F77_zdotu zdotutest + #define F77_zdotc zdotctest + #define F77_snrm2 snrm2testsub + #define F77_sasum sasumtestsub + #define F77_dnrm2 dnrm2testsub + #define F77_dasum dasumtestsub + #define F77_scnrm2 scnrm2testsub + #define F77_scasum scasumtestsub + #define F77_dznrm2 dznrm2testsub + #define F77_dzasum dzasumtestsub + #define F77_sdsdot sdsdottest +/* + * Level 2 BLAS + */ + #define F77_s2chke cs2chke + #define F77_d2chke cd2chke + #define F77_c2chke cc2chke + #define F77_z2chke cz2chke + #define F77_ssymv cssymv + #define F77_ssbmv cssbmv + #define F77_sspmv csspmv + #define F77_sger csger + #define F77_ssyr cssyr + #define F77_sspr csspr + #define F77_ssyr2 cssyr2 + #define F77_sspr2 csspr2 + #define F77_dsymv cdsymv + #define F77_dsbmv cdsbmv + #define F77_dspmv cdspmv + #define F77_dger cdger + #define F77_dsyr cdsyr + #define F77_dspr cdspr + #define F77_dsyr2 cdsyr2 + #define F77_dspr2 cdspr2 + #define F77_chemv cchemv + #define F77_chbmv cchbmv + #define F77_chpmv cchpmv + #define F77_cgeru ccgeru + #define F77_cgerc ccgerc + #define F77_cher ccher + #define F77_chpr cchpr + #define F77_cher2 ccher2 + #define F77_chpr2 cchpr2 + #define F77_zhemv czhemv + #define F77_zhbmv czhbmv + #define F77_zhpmv czhpmv + #define F77_zgeru czgeru + #define F77_zgerc czgerc + #define F77_zher czher + #define F77_zhpr czhpr + #define F77_zher2 czher2 + #define F77_zhpr2 czhpr2 + #define F77_sgemv csgemv + #define F77_sgbmv csgbmv + #define F77_strmv cstrmv + #define F77_stbmv cstbmv + #define F77_stpmv cstpmv + #define F77_strsv cstrsv + #define F77_stbsv cstbsv + #define F77_stpsv cstpsv + #define F77_dgemv cdgemv + #define F77_dgbmv cdgbmv + #define F77_dtrmv cdtrmv + #define F77_dtbmv cdtbmv + #define F77_dtpmv cdtpmv + #define F77_dtrsv cdtrsv + #define F77_dtbsv cdtbsv + #define F77_dtpsv cdtpsv + #define F77_cgemv ccgemv + #define F77_cgbmv ccgbmv + #define F77_ctrmv cctrmv + #define F77_ctbmv cctbmv + #define F77_ctpmv cctpmv + #define F77_ctrsv cctrsv + #define F77_ctbsv cctbsv + #define F77_ctpsv cctpsv + #define F77_zgemv czgemv + #define F77_zgbmv czgbmv + #define F77_ztrmv cztrmv + #define F77_ztbmv cztbmv + #define F77_ztpmv cztpmv + #define F77_ztrsv cztrsv + #define F77_ztbsv cztbsv + #define F77_ztpsv cztpsv +/* + * Level 3 BLAS + */ + #define F77_s3chke cs3chke + #define F77_d3chke cd3chke + #define F77_c3chke cc3chke + #define F77_z3chke cz3chke + #define F77_chemm cchemm + #define F77_cherk ccherk + #define F77_cher2k ccher2k + #define F77_zhemm czhemm + #define F77_zherk czherk + #define F77_zher2k czher2k + #define F77_sgemm csgemm + #define F77_ssymm cssymm + #define F77_ssyrk cssyrk + #define F77_ssyr2k cssyr2k + #define F77_strmm cstrmm + #define F77_strsm cstrsm + #define F77_dgemm cdgemm + #define F77_dsymm cdsymm + #define F77_dsyrk cdsyrk + #define F77_dsyr2k cdsyr2k + #define F77_dtrmm cdtrmm + #define F77_dtrsm cdtrsm + #define F77_cgemm ccgemm + #define F77_csymm ccsymm + #define F77_csyrk ccsyrk + #define F77_csyr2k ccsyr2k + #define F77_ctrmm cctrmm + #define F77_ctrsm cctrsm + #define F77_zgemm czgemm + #define F77_zsymm czsymm + #define F77_zsyrk czsyrk + #define F77_zsyr2k czsyr2k + #define F77_ztrmm cztrmm + #define F77_ztrsm cztrsm +#endif + +void get_transpose_type(char *type, enum CBLAS_TRANSPOSE *trans); +void get_uplo_type(char *type, enum CBLAS_UPLO *uplo); +void get_diag_type(char *type, enum CBLAS_DIAG *diag); +void get_side_type(char *type, enum CBLAS_SIDE *side); + +#endif /* CBLAS_TEST_H */ diff --git a/kaldi_io/src/tools/ATLAS/include/clapack.h b/kaldi_io/src/tools/ATLAS/include/clapack.h new file mode 100644 index 0000000..c5dde3f --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/clapack.h @@ -0,0 +1,149 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef CLAPACK_H + +#define CLAPACK_H +#include "cblas.h" + +#ifndef ATLAS_ORDER + #define ATLAS_ORDER CBLAS_ORDER +#endif +#ifndef ATLAS_UPLO + #define ATLAS_UPLO CBLAS_UPLO +#endif +#ifndef ATLAS_DIAG + #define ATLAS_DIAG CBLAS_DIAG +#endif +int clapack_sgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, + float *A, const int lda, int *ipiv, + float *B, const int ldb); +int clapack_sgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + float *A, const int lda, int *ipiv); +int clapack_sgetrs + (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const float *A, const int lda, + const int *ipiv, float *B, const int ldb); +int clapack_sgetri(const enum CBLAS_ORDER Order, const int N, float *A, + const int lda, const int *ipiv); +int clapack_sposv(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, const int NRHS, float *A, const int lda, + float *B, const int ldb); +int clapack_spotrf(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, float *A, const int lda); +int clapack_spotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const float *A, const int lda, + float *B, const int ldb); +int clapack_spotri(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, float *A, const int lda); +int clapack_slauum(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, float *A, const int lda); +int clapack_strtri(const enum ATLAS_ORDER Order,const enum ATLAS_UPLO Uplo, + const enum ATLAS_DIAG Diag,const int N, float *A, const int lda); + +int clapack_dgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, + double *A, const int lda, int *ipiv, + double *B, const int ldb); +int clapack_dgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + double *A, const int lda, int *ipiv); +int clapack_dgetrs + (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const double *A, const int lda, + const int *ipiv, double *B, const int ldb); +int clapack_dgetri(const enum CBLAS_ORDER Order, const int N, double *A, + const int lda, const int *ipiv); +int clapack_dposv(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, const int NRHS, double *A, const int lda, + double *B, const int ldb); +int clapack_dpotrf(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, double *A, const int lda); +int clapack_dpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const double *A, const int lda, + double *B, const int ldb); +int clapack_dpotri(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, double *A, const int lda); +int clapack_dlauum(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, double *A, const int lda); +int clapack_dtrtri(const enum ATLAS_ORDER Order,const enum ATLAS_UPLO Uplo, + const enum ATLAS_DIAG Diag,const int N, double *A, const int lda); + +int clapack_cgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, + void *A, const int lda, int *ipiv, + void *B, const int ldb); +int clapack_cgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + void *A, const int lda, int *ipiv); +int clapack_cgetrs + (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const void *A, const int lda, + const int *ipiv, void *B, const int ldb); +int clapack_cgetri(const enum CBLAS_ORDER Order, const int N, void *A, + const int lda, const int *ipiv); +int clapack_cposv(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, const int NRHS, void *A, const int lda, + void *B, const int ldb); +int clapack_cpotrf(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_cpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const void *A, const int lda, + void *B, const int ldb); +int clapack_cpotri(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_clauum(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_ctrtri(const enum ATLAS_ORDER Order,const enum ATLAS_UPLO Uplo, + const enum ATLAS_DIAG Diag,const int N, void *A, const int lda); + +int clapack_zgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, + void *A, const int lda, int *ipiv, + void *B, const int ldb); +int clapack_zgetrf(const enum CBLAS_ORDER Order, const int M, const int N, + void *A, const int lda, int *ipiv); +int clapack_zgetrs + (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, + const int N, const int NRHS, const void *A, const int lda, + const int *ipiv, void *B, const int ldb); +int clapack_zgetri(const enum CBLAS_ORDER Order, const int N, void *A, + const int lda, const int *ipiv); +int clapack_zposv(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, const int NRHS, void *A, const int lda, + void *B, const int ldb); +int clapack_zpotrf(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_zpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int NRHS, const void *A, const int lda, + void *B, const int ldb); +int clapack_zpotri(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_zlauum(const enum ATLAS_ORDER Order, const enum ATLAS_UPLO Uplo, + const int N, void *A, const int lda); +int clapack_ztrtri(const enum ATLAS_ORDER Order,const enum ATLAS_UPLO Uplo, + const enum ATLAS_DIAG Diag,const int N, void *A, const int lda); + +#endif diff --git a/kaldi_io/src/tools/ATLAS/include/contrib/ATL_gemv_ger_SSE.h b/kaldi_io/src/tools/ATLAS/include/contrib/ATL_gemv_ger_SSE.h new file mode 100644 index 0000000..118d3de --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/contrib/ATL_gemv_ger_SSE.h @@ -0,0 +1,188 @@ +#ifdef GER +#undef NO_TRANSPOSE +#define NO_TRANSPOSE +#endif + + +#if NDPM > 4 +#error Max NDPM is 4 +#endif + +#if !defined(ATL_SSE1) && ( defined(SREAL) || defined(SCPLX) ) +#error This routine needs ATL_SSE1 defined +#endif + +#if !defined(ATL_SSE2) && ( defined(DREAL) || defined(DCPLX) ) +#error This routine needs ATL_SSE2 defined +#endif + +#include <stdio.h> +#include <stdlib.h> + +#include "camm_util.h" + +#ifndef GER +#if defined(BETAX) || defined(BETAXI0) +#include "camm_scale.h" +#endif +#endif + +#if NDPM >= 4 +#define EXT4 Mjoin(4dp,BLC) +#undef NDP +#define NDP 4 +#undef EXT +#define EXT EXT4 +#include "camm_dpa.h" +#endif + +#if NDPM >= 3 +#define EXT3 Mjoin(3dp,BLC) +#undef NDP +#define NDP 3 +#undef EXT +#define EXT EXT3 +#include "camm_dpa.h" +#endif + +#if NDPM >= 2 +#define EXT2 Mjoin(2dp,BLC) +#undef NDP +#define NDP 2 +#undef EXT +#define EXT EXT2 +#include "camm_dpa.h" +#endif + +#define EXT1 Mjoin(1dp,BLC) +#undef NDP +#define NDP 1 +#undef EXT +#define EXT EXT1 +#include "camm_dpa.h" + +#undef NDP +#define NDP NDPM +#undef EXT +#define EXT Mjoin(Mjoin(NDP,Mjoin(dp,BLC)),m) +#include "camm_dpa.h" + +#ifdef GER +#if defined(SCPLX) || defined(DCPLX) +#ifdef Conj_ +#define IM 1c +#else +#define IM 1u +#endif +#else +#define IM 1 +#endif + + +#define FN Mjoin(Mjoin(Mjoin(ATL_,PREC),Mjoin(ger,IM)),_a1_x1_yX) + +#undef MY_FUNCTION +#define MY_FUNCTION FN + +void +MY_FUNCTION(int m,int n, const SCALAR alpha,const TYPE *c, + int cinc,const TYPE *b,int binc, + TYPE *a,int lda) { + +#else + + +#define FN Mjoin(Mjoin(Mjoin(ATL_,PREC),gemv),Mjoin(FEXT,Mjoin(_a1_x1_,Mjoin(BL,_y1)))) + +#undef MY_FUNCTION +#define MY_FUNCTION FN + +void +MY_FUNCTION(int m,int n, const SCALAR alpha,const TYPE *a, + int lda,const TYPE *b,int binc, + const SCALAR beta,TYPE *c,int cinc) { + +#endif + + int i,mm,nn; + const TYPE *ae; +#ifdef NO_TRANSPOSE + int len=m,w=n; +#define zz b +#else + int len=n,w=m; +#define zz c +#endif + +#ifdef GER +#define zzinc binc +#else +#define zzinc 1 + + +#if defined(NO_TRANSPOSE) && defined(BETA0) + memset(c,0,m*sizeof(*c)); +#endif + +#if defined(BETAX) || defined(BETAXI0) +#if defined(SCPLX) || defined(DCPLX) + SCALE(beta,c,m); +#endif +#if defined(SREAL) || defined(DREAL) + SCALE(&beta,c,m); +#endif +#endif + +#endif + + ae=a+w*lda; + nn=STRIDE*lda; + + +#if NDPM == 1 + for (;a<ae;a+=lda,zz+=zzinc) + Mjoin(dp,EXT)(a,nn,b,c,STRIDE*zzinc,len); + +#else + + while (a+NDPM*nn<=ae) { + for (i=0;i<STRIDE;i++,a+=lda,zz+=zzinc) + Mjoin(dp,EXT)(a,nn,b,c,STRIDE*zzinc,len); + + a+=(NDPM-1)*nn; + zz+=(NDPM-1)*STRIDE*zzinc; + } + + for (i=0;a<ae && i<STRIDE;i++,a+=lda,zz+=zzinc) { + + mm=(ae-a)/nn; +#if STRIDE > 1 + if (((ae-a)/lda)%STRIDE) + mm++; +#endif + + if (mm == 1) + Mjoin(dp,EXT1)(a,nn,b,c,STRIDE*zzinc,len); + +#if ( NDPM == 2 && STRIDE > 1 ) || NDPM > 2 + else if (mm == 2) + Mjoin(dp,EXT2)(a,nn,b,c,STRIDE*zzinc,len); +#endif + +#if ( NDPM == 3 && STRIDE > 1 ) || NDPM > 3 + else if (mm == 3) + Mjoin(dp,EXT3)(a,nn,b,c,STRIDE*zzinc,len); +#endif + +#if ( NDPM == 4 && STRIDE > 1 ) || NDPM > 4 + else if (mm == 4) + Mjoin(dp,EXT4)(a,nn,b,c,STRIDE*zzinc,len); +#endif + + + } + +#endif + +} + diff --git a/kaldi_io/src/tools/ATLAS/include/contrib/Make.ext b/kaldi_io/src/tools/ATLAS/include/contrib/Make.ext new file mode 100644 index 0000000..f7f9a0a --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/contrib/Make.ext @@ -0,0 +1,39 @@ + +topd = /home/whaley/atlas3.8/AtlasBase +incs = -def topd /home/whaley/atlas3.8/AtlasBase \ + -def incd /home/whaley/atlas3.8/AtlasBase/Clint \ + -def BASEdir /home/whaley/atlas3.8/AtlasBase/Antoine/ \ + -def basd /home/whaley/atlas3.8/AtlasBase/Clint +ext = extract +extF = $(ext) -langF -lnlen71 -Remtblank -llwarn2 -LAPACK1 $(incs) +extC = $(ext) -langC -lnlen79 -Remtblank -llwarn2 $(incs) +extM = $(ext) -langM -lnlen79 -llwarn2 $(incs) + +default: all +force_build: +basd = /home/whaley/atlas3.8/AtlasBase/Clint +basdRCW = /home/whaley/atlas3.8/AtlasBase/Clint +basdAPP = /home/whaley/atlas3.8/AtlasBase/Antoine +incf = /home/whaley/atlas3.8/AtlasBase/gen.inc + +files = ATL_gemv_ger_SSE.h SSE3Dnow.h camm_dpa.h camm_pipe3.h camm_scale.h \ + camm_strat1.h camm_tpipe.h camm_util.h + +all : $(files) + +camm_strat1.h : $(topd)/kernel/CammMaguire/camm_strat1.h + cp $(topd)/kernel/CammMaguire/camm_strat1.h . +camm_tpipe.h : $(topd)/kernel/CammMaguire/camm_tpipe.h + cp $(topd)/kernel/CammMaguire/camm_tpipe.h . +camm_pipe3.h : $(topd)/kernel/CammMaguire/camm_pipe3.h + cp $(topd)/kernel/CammMaguire/camm_pipe3.h . +ATL_gemv_ger_SSE.h : $(topd)/kernel/CammMaguire/ATL_gemv_ger_SSE.h + cp $(topd)/kernel/CammMaguire/ATL_gemv_ger_SSE.h . +camm_util.h : $(topd)/kernel/CammMaguire/camm_util.h + cp $(topd)/kernel/CammMaguire/camm_util.h . +camm_scale.h : $(topd)/kernel/CammMaguire/camm_scale.h + cp $(topd)/kernel/CammMaguire/camm_scale.h . +camm_dpa.h : $(topd)/kernel/CammMaguire/camm_dpa.h + cp $(topd)/kernel/CammMaguire/camm_dpa.h . +SSE3Dnow.h : $(topd)/kernel/PeterSoendergaard/SSE3Dnow.h + cp $(topd)/kernel/PeterSoendergaard/SSE3Dnow.h . diff --git a/kaldi_io/src/tools/ATLAS/include/contrib/SSE3Dnow.h b/kaldi_io/src/tools/ATLAS/include/contrib/SSE3Dnow.h new file mode 100644 index 0000000..a783749 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/contrib/SSE3Dnow.h @@ -0,0 +1,709 @@ +#if !defined(ATL_GAS_x8632) && !defined(ATL_GAS_x8664) + #error "This kernel requires gas x86 assembler!" +#endif +#ifndef Mstr /* Added by RCW to make multiline macros work */ + #define Mstr2(m) # m + #define Mstr(m) Mstr2(m) +#endif +/* The mening of the defined macros is as follows: + * VECLEN: The length of a singleprecision vector register + * vec_add: Add to single precision vectors. + * vec_mul: Multiply to single precision vectors. + * vec_mov: Moves data around + * vec_mov1: Load one element in a vector and zero all other entries! + * vec_splat: Load one element relpicated in all positions in the vector. + * vec_load_apart: Load elements from different memory positions into a register. + * vec_sum: Sums a register. + * vec_store_one: Stores lowest element in vector to memory, no zero-extend! + * Meaning of suffixes is as follows: + * mr means memory to register + * rr means register to register + * rm means register to memory + * a means that instruction needs aligned data + * 1 means that the instructions only operates on the lowest element of the + * vector. + * + * The _1 instructions work under one important assumption: That you never mix + * them with regular instructions, e.g. loading into a register with a normal + * mov, and then using add_rr_1 will not work under 3dnow! since it is in + * reality a normal add. However, if using a mov_1 first, the upper part of + * the register will be zeroed, and it will therefore work. The _1 system is + * more robust under SSE, but other architectures might be implemented the + * same way as 3dnow! + * + * RCW: I added the following functionality for SSE only (note that vw may + * be overwritten with intermediate results, but is not used as input, + * and that all input array may be overwritten wt intermediate results. + * VL : vector length -1): + * vec_red(vd, vw) : vd[0] = sum(vd[0:VL]) + * vec_red2(v1, v2, vw) : v1[0] = sum(v1[0:VL]); v1[1] = sum(v2[0:VL]) + * vec_red4(v0, v1, v2, v3 vw1, vw2) : + * v0[0] = sum(v0[0:VL]); v0[1] = sum(v1[0:VL]) + * if type = double: + * v2[0] = sum(v2[0:VL]); v2[1] = sum(v3[0:VL]) + * else + * v0[2] = sum(v2[0:VL]); v0[3] = sum(v3[0:VL]) + * vec_zero(vd) : vd[0:VL] = 0.0 + */ + + +/* Things to try: + * Non-temporal stores + * Sequences of instructions instead of movups + * + * + * + * + */ + + + +#define gen_vec_rr(op,reg1,reg2) \ + __asm__ __volatile__ (#op " " #reg1 ", " #reg2 \ + : /* nothing */ \ + : /* nothing */) + + +#define w(p) p + +#define nop() __asm__ __volatile__ ("nop") + +#define rep() __asm__ __volatile__ ("rep") + +#define align() __asm__ __volatile__ (".align 16") + + +#ifdef x87double + +#define st0 %%st(0) +#define st1 %%st(1) +#define st2 %%st(2) +#define st3 %%st(3) +#define st4 %%st(4) +#define st5 %%st(5) +#define st6 %%st(6) +#define st7 %%st(7) + + +#define gen_stack_rt(op,reg) \ + __asm__ __volatile__ (#op " " #reg \ + : /* nothing */ \ + : /* nothing */) + +#define gen_stack_tr(op,reg) \ + __asm__ __volatile__ (#op " %%st(0)," #reg \ + : \ + : ) + + +#define gen_stack_rr(op,reg1,reg2) \ + __asm__ __volatile__ (#op " " #reg1 ", " #reg2 \ + : /* nothing */ \ + : /* nothing */) + +#define gen_stack_t(op) \ + __asm__ __volatile__ (#op \ + : /* nothing */ \ + : /* nothing */) + + +#define gen_stack_tm(op,mem) \ + __asm__ __volatile__ (#op " %0" \ + : "=m" (((mem)[0])) \ + : ) + +#define gen_stack_mt(op,mem) \ + __asm__ __volatile__ (#op " %0" \ + : \ + : "m" (((mem)[0]))) + + +#define stack_mov_mt_push(mem) gen_stack_mt(fldl,mem) + +#define stack_add_tr_pop(reg) gen_stack_tr(faddp,reg) +#define stack_add_mt(mem) gen_stack_mt(faddl,mem) + +#define stack_mul_tr(reg) gen_stack_tr(fmul,reg) +#define stack_mul_tr_pop(reg) gen_stack_tr(fmulp,reg) +#define stack_mul_mt(mem) gen_stack_mt(fmul,mem) + +#define stack_mov_tm_pop(mem) gen_stack_tm(fstpl,mem) + +#define stack_zero_push() gen_stack_t(fldz) + +#endif /* x87double */ + +#ifdef SSE + +/* Peculiarities of SSE: Alignment is good, but not mandatory. It is possible to + * load/store from misaligned adresses using movups at a cost of some cycles. Loading + * using mul/add must always be aligned. Alignment is 16 bytes. + * No muladd. + */ + + + +#define gen_vec_mr(op,mem,reg) \ + __asm__ __volatile__ (#op " %0, " #reg \ + : /* nothing */ \ + : "m" (((mem)[0])), "m" (((mem)[1])), "m" (((mem)[2])), "m" (((mem)[3]))) + + +#define gen_vec_rm(op,reg,mem) \ + __asm__ __volatile__ (#op " " #reg ", %0" \ + : "=m" (((mem)[0])), "=m" (((mem)[1])), "=m" (((mem)[2])), "=m" (((mem)[3])) \ + : /* nothing */ ) + + + + +#define VECLEN 4 + +#define reg0 %%xmm0 +#define reg1 %%xmm1 +#define reg2 %%xmm2 +#define reg3 %%xmm3 +#define reg4 %%xmm4 +#define reg5 %%xmm5 +#define reg6 %%xmm6 +#define reg7 %%xmm7 +#ifdef ATL_GAS_x8664 + #define reg8 %%xmm8 + #define reg9 %%xmm9 + #define reg10 %%xmm10 + #define reg11 %%xmm11 + #define reg12 %%xmm12 + #define reg13 %%xmm13 + #define reg14 %%xmm14 + #define reg15 %%xmm15 +#endif + +#define vec_mov_mr(mem,reg) gen_vec_mr(movups,mem,reg) +#define vec_mov_rm(reg,mem) gen_vec_rm(movups,reg,mem) +#define vec_mov_mr_a(mem,reg) gen_vec_mr(movaps,mem,reg) +#define vec_mov_rm_a(reg,mem) gen_vec_rm(movaps,reg,mem) +#define vec_mov_rr(reg1,reg2) gen_vec_rr(movaps,reg1,reg2) + +#define vec_add_mr_a(mem,reg) gen_vec_mr(addps,mem,reg) +#define vec_mul_mr_a(mem,reg) gen_vec_mr(mulps,mem,reg) + +#define vec_add_rr(mem,reg) gen_vec_rr(addps,mem,reg) +#define vec_mul_rr(mem,reg) gen_vec_rr(mulps,mem,reg) + +#define vec_mov_mr_1(mem,reg) gen_vec_mr(movss,mem,reg) +#define vec_mov_rm_1(reg,mem) gen_vec_rm(movss,reg,mem) +#define vec_mov_rr_1(reg1,reg2) gen_vec_rr(movss,reg1,reg2) + +#define vec_add_mr_1(mem,reg) gen_vec_mr(addss,mem,reg) +#define vec_add_rr_1(reg1,reg2) gen_vec_rr(addss,reg1,reg2) + +#define vec_mul_mr_1(mem,reg) gen_vec_mr(mulss,mem,reg) +#define vec_mul_rr_1(reg1,reg2) gen_vec_rr(mulss,reg1,reg2) + +#define vec_unpack_low(reg1,reg2) gen_vec_rr(unpcklps,reg1,reg2) +#define vec_unpack_high(reg1,reg2) gen_vec_rr(unpckhps,reg1,reg2) +#define vec_shuffle(mode,reg1,reg2) vec_shuffle_wrap(mode,reg1,reg2) +#define vec_shuffle_wrap(mode,reg1,reg2) \ + __asm__ __volatile__ ("shufps " #mode ", " #reg1 ", " #reg2 \ + : /* nothing */\ + : /* nothing */) + +/* Hack! */ +/* To use this instruction be sure that register 7 is not in use!!! */ +/* It must be possible to reduce this sequence to only four instructions. + * please tell me how! */ +#define vec_sum(reg) vec_sum_wrap(reg) +#define vec_sum_wrap(reg) \ + __asm__ __volatile__ ("movhlps " #reg ", %%xmm7\n"\ + "addps " #reg ", %%xmm7\n"\ + "movaps %%xmm7, " #reg "\n"\ + "shufps $1, " #reg ", %%xmm7\n"\ + "addss %%xmm7, " #reg "\n"\ + : /* nothing */\ + : /* nothing */) + +/* RCW: added to safely replace vec_sum (vec reduce), and use SSE3 when avail */ +#define vec_zero(vd) __asm__ __volatile__("xorps " Mstr(vd) ", " Mstr(vd) ::) +#ifdef ATL_SSE3 + #define vec_red(vr, vwrk) \ + __asm__ __volatile__("haddps " Mstr(vr) ", " Mstr(vr) "\n"\ + "haddps " Mstr(vr) ", " Mstr(vr) "\n" ::) +/* + * haddps v1 v0 # v0 = {v1cd, v1ab, v0cd, v0ab} + * haddps v0 v0 # v0 = {v1abcd, v0abcd, v1abcd, v0abcd} + */ + #define vec_red2(v0, v1, vwork) \ + __asm__ __volatile__("haddps " Mstr(v1) ", " Mstr(v0) "\n"\ + "haddps " Mstr(v0) ", " Mstr(v0) "\n" ::) +/* + * haddps v1, v0 # v0 = {v1cd,v1ab,v0cd,v0ab} + * haddps v3, v2 # v2 = {v3cd,v3ab,v2cd,v2ab} + * haddps v2, v0 # v0 = {v3abcd,v2abcd,v1abcd, v0abcd} + */ + #define vec_red4(v0, v1, v2, v3, w0, w1) \ + __asm__ __volatile__("haddps " Mstr(v1) ", " Mstr(v0) "\n"\ + "haddps " Mstr(v3) ", " Mstr(v2) "\n"\ + "haddps " Mstr(v2) ", " Mstr(v0) "\n" ::) +#elif defined(ATL_SSE2) + #define vec_red(vr, vwrk) \ + __asm__ __volatile__ ("pshufd $0xEE, " Mstr(vr) ", " Mstr(vwrk) "\n"\ + "addps " Mstr(vwrk) ", " Mstr(vr) "\n"\ + "pshufd $0xE5, " Mstr(vr) ", " Mstr(vwrk) "\n"\ + "addss " Mstr(vwrk) ", " Mstr(vr) "\n"\ + ::) +#else + #define vec_red(vr, vwrk) \ + __asm__ __volatile__ ("movhlps " Mstr(vr) ", " Mstr(vwrk) "\n"\ + "addps " Mstr(vwrk) ", " Mstr(vr) "\n"\ + "movaps " Mstr(vr) ", " Mstr(vwrk) "\n"\ + "shufps $0xE5, " Mstr(vr) ", " Mstr(vr) "\n"\ + "addss " Mstr(vwrk) ", " Mstr(vr) "\n"\ + ::) +#endif +#ifndef ATL_SSE3 /* codes that are the same for SSE2 and SSE1 */ +/* + # v0 = {v0d,v0c,v0b,v0a} + # v1 = {v1d,v1c,v1b,v1a} + movaps v0, vw # vw = {v0d,v0c,v0b,v0a} + unpacklps v1, v0 # v0 = {v1b,v0b,v1a,v0a} + unpackhps v1, vw # vw = {v1d,v0d,v1c,v0c} + addps vw, v0 # v0 = {v1bd,v0bd,v1ac,v0ac} + movhlps v0, vw # vw = {X , X,v1bd,v0bd} + addps vw, v0 # v0 = {X , X,v1abcd,v0abcd} +*/ + #define vec_red2(v0, v1, vw) \ + __asm__ __volatile__ ("movaps " Mstr(v0) ", " Mstr(vw) "\n"\ + "unpcklps " Mstr(v1) ", " Mstr(v0) "\n"\ + "unpckhps " Mstr(v1) ", " Mstr(vw) "\n"\ + "addps " Mstr(vw) ", " Mstr(v0) "\n"\ + "movhlps " Mstr(v0) ", " Mstr(vw) "\n"\ + "addps " Mstr(vw) ", " Mstr(v0) "\n"\ + ::) +/* + * movaps v0, w0 # w0 = {v0d, v0c, v0b, v0a} + * unpcklps v1, v0 # v0 = {v1b, v0b, v1a, v0a} + * movaps v2, w1 # w1 = {v2d, v2c, v2b, v2a} + * unpckhps v1, w0 # w0 = {v1d, v0d, v1c, v0c} + * unpcklps v3, v2 # v2 = {v3b, v2b, v3a, v2a} + * addps w0, v0 # v0 = {v1bd, v0bd, v1ac, v0ac} + * unpckhps v3, w1 # w1 = {v3d, v2d, v3c, v2c} + * movaps v0, w0 # w0 = {v1bd, v0bd, v1ac, v0ac} + * addps w1, v2 # v2 = {v3bd, v2bd, v3ac, v2ac} + * shufps $0x44,v2,v0 # v0 = {v3ac, v2ac, v1ac, v0ac} + * shufps $0xEE,v2,w0 # w0 = {v3bd, v2bd, v1bd, v0bd} + * addps w0, v0 # v0 = {v3abcd, v2abcd, v1abcd, v0abcd} + */ + #define vec_red4(v0, v1, v2, v3, w0, w1) \ + __asm__ __volatile__ ("movaps " Mstr(v0) ", " Mstr(w0) "\n"\ + "unpcklps " Mstr(v1) ", " Mstr(v0) "\n"\ + "movaps " Mstr(v2) ", " Mstr(w1) "\n"\ + "unpckhps " Mstr(v1) ", " Mstr(w0) "\n"\ + "unpcklps " Mstr(v3) ", " Mstr(v2) "\n"\ + "addps " Mstr(w0) ", " Mstr(v0) "\n"\ + "unpckhps " Mstr(v3) ", " Mstr(w1) "\n"\ + "movaps " Mstr(v0) ", " Mstr(w0) "\n"\ + "addps " Mstr(w1) ", " Mstr(v2) "\n"\ + "shufps $0x44, " Mstr(v2) ", " Mstr(v0) "\n"\ + "shufps $0xEE, " Mstr(v2) ", " Mstr(w0) "\n"\ + "addps " Mstr(w0) ", " Mstr(v0) "\n"\ + ::) +#endif + +#define vec_splat(mem,reg) vec_splat_wrap(mem,reg) +#define vec_splat_wrap(mem,reg) \ + __asm__ __volatile__ ("movss %0, " #reg "\n"\ + "unpcklps " #reg ", " #reg "\n"\ + "movlhps " #reg ", " #reg "\n"\ + : /* nothing */ \ + : "m" ((mem)[0])) + + +/* This instruction sequence appears courtesy of Camm Maguire. */ +#define vec_sum_full(reg0,reg1,reg2,reg3,regout,empty0,empty1) vec_sum_full_wrap(reg0,reg1,reg2,reg3,regout,empty0,empty1) +#define vec_sum_full_wrap(reg0,reg1,reg2,reg3,regout,empty0,empty1) \ + __asm__ __volatile__ ("movaps " #reg0 "," #empty0 "\n"\ + "unpcklps " #reg1 "," #reg0 "\n"\ + "movaps " #reg2 "," #empty1 "\n"\ + "unpckhps " #reg1 "," #empty0 "\n"\ + "unpcklps " #reg3 "," #reg2 "\n"\ + "addps " #empty0 "," #reg0 "\n"\ + "unpckhps " #reg3 "," #empty1 "\n"\ + "movaps " #reg0 "," #regout "\n"\ + "addps " #empty1 "," #reg2 "\n"\ + "shufps $0x44," #reg2 "," #reg0 "\n"\ + "shufps $0xee," #reg2 "," #regout "\n"\ + "addps " #reg0 "," #regout "\n"\ + : /* nothing */ \ + : /* nothing */) + + + +typedef float vector[VECLEN]; + +#endif /* end ifdef SSE */ + + +#ifdef SSE2 + +/* Peculiarities of SSE: Alignment is good, but not mandatory. It is possible to + * load/store from misaligned adresses using movups at a cost of some cycles. Loading + * using mul/add must always be aligned. Alignment is 16 bytes. + * No muladd. + */ + + + +#define gen_vec_mr(op,mem,reg) \ + __asm__ __volatile__ (#op " %0, " #reg \ + : /* nothing */ \ + : "m" (((mem)[0])), "m" (((mem)[1]))) + + +#define gen_vec_rm(op,reg,mem) \ + __asm__ __volatile__ (#op " " #reg ", %0" \ + : "=m" (((mem)[0])), "=m" (((mem)[1])) \ + : /* nothing */ ) + + + + +#define VECLEN 2 + +#define reg0 %%xmm0 +#define reg1 %%xmm1 +#define reg2 %%xmm2 +#define reg3 %%xmm3 +#define reg4 %%xmm4 +#define reg5 %%xmm5 +#define reg6 %%xmm6 +#define reg7 %%xmm7 +#ifdef ATL_GAS_x8664 + #define reg8 %%xmm8 + #define reg9 %%xmm9 + #define reg10 %%xmm10 + #define reg11 %%xmm11 + #define reg12 %%xmm12 + #define reg13 %%xmm13 + #define reg14 %%xmm14 + #define reg15 %%xmm15 +#endif + + +#define vec_mov_mr(mem,reg) gen_vec_mr(movupd,mem,reg) +#define vec_mov_rm(reg,mem) gen_vec_rm(movupd,reg,mem) +#define vec_mov_mr_a(mem,reg) gen_vec_mr(movapd,mem,reg) +#define vec_mov_rm_a(reg,mem) gen_vec_rm(movapd,reg,mem) +#define vec_mov_rr(reg1,reg2) gen_vec_rr(movapd,reg1,reg2) + +#define vec_add_mr_a(mem,reg) gen_vec_mr(addpd,mem,reg) +#define vec_mul_mr_a(mem,reg) gen_vec_mr(mulpd,mem,reg) + +#define vec_add_rr(mem,reg) gen_vec_rr(addpd,mem,reg) +#define vec_mul_rr(mem,reg) gen_vec_rr(mulpd,mem,reg) + +#define vec_mov_mr_1(mem,reg) gen_vec_mr(movsd,mem,reg) +#define vec_mov_rm_1(reg,mem) gen_vec_rm(movsd,reg,mem) +#define vec_mov_rr_1(reg1,reg2) gen_vec_rr(movsd,reg1,reg2) + +#define vec_add_mr_1(mem,reg) gen_vec_mr(addsd,mem,reg) +#define vec_add_rr_1(reg1,reg2) gen_vec_rr(addsd,reg1,reg2) + +#define vec_mul_mr_1(mem,reg) gen_vec_mr(mulsd,mem,reg) +#define vec_mul_rr_1(reg1,reg2) gen_vec_rr(mulsd,reg1,reg2) + +#define vec_splat(mem,reg) vec_splat_wrap(mem,reg) +#define vec_splat_wrap(mem,reg) \ + __asm__ __volatile__ ("movsd %0, " #reg "\n"\ + "unpcklpd " #reg ", " #reg \ + : /* nothing */ \ + : "m" ((mem)[0])) + +/* Hack! */ +/* To use this instruction be sure that register 7 is not in use!!! */ +#define vec_sum(reg) vec_sum_wrap(reg) +#define vec_sum_wrap(reg) \ + __asm__ __volatile__ ("movhlps " #reg ", %%xmm7\n"\ + "addpd %%xmm7, " #reg "\n"\ + : /* nothing */\ + : /* nothing */) +/* + * Added by RCW to improve performance and avoid xmm7 hack (replace vec_sum) + */ +#define vec_zero(vd) __asm__ __volatile__("xorps " Mstr(vd) ", " Mstr(vd) ::) +#ifdef ATL_SSE3 + #define vec_red(vr, vwrk) \ + __asm__ __volatile__("haddpd " Mstr(vr) ", " Mstr(vr) "\n" ::) + #define vec_red2(v0, v1, vw) \ + __asm__ __volatile__("haddpd " Mstr(v1) ", " Mstr(v0) "\n" ::) + #define vec_red4(v0, v1, v2, v3, w0, w1) \ + __asm__ __volatile__("haddpd " Mstr(v1) ", " Mstr(v0) "\n"\ + "haddpd " Mstr(v3) ", " Mstr(v2) "\n"\ + ::) +#else + #define vec_red(vr, vwrk) \ + __asm__ __volatile__ ("pshufd $0xEE, " Mstr(vr) ", " Mstr(vwrk) "\n"\ + "addsd " Mstr(vwrk) ", " Mstr(vr) "\n" ::) +/* + * movapd v0, vw # vw = {v0b, v0a} + * unpcklpd v1,v0 # v0 = {v1a, v0a} + * unpckhpd v1, vw # vw = {v1b, v0b} + * addpd vw, v0 # v0 = {v1ab,v0ab} + */ + #define vec_red2(v0, v1, vw) \ + __asm__ __volatile__("movapd " Mstr(v0) ", " Mstr(vw) "\n"\ + "unpcklpd " Mstr(v1) ", " Mstr(v0) "\n"\ + "unpckhpd " Mstr(v1) ", " Mstr(vw) "\n"\ + "addpd " Mstr(vw) ", " Mstr(v0) "\n"\ + ::) +/* + * movapd v0, w0 # w0 = {v0b, v0a} + * movapd v2, w1 # w1 = {v2b, v2a} + * unpcklpd v1, v0 # v0 = {v1a, v0a} + * unpcklpd v3, v2 # v2 = {v3a, v2a} + * unpckhpd v1, w0 # w0 = {v1b, v0b} + * unpckhpd v3, w1 # w1 = {v3b, v2b} + * addpd w0, v0 # v0 = {v1ab, v0ab} + * addpd w1, v2 # v2 = {v3ab, v2ab} + */ + #define vec_red4(v0, v1, v2, v3, w0, w1) \ + __asm__ __volatile__("movapd " Mstr(v0) ", " Mstr(w0) "\n"\ + "movapd " Mstr(v2) ", " Mstr(w1) "\n"\ + "unpcklpd " Mstr(v1) ", " Mstr(v0) "\n"\ + "unpcklpd " Mstr(v3) ", " Mstr(v2) "\n"\ + "unpckhpd " Mstr(v1) ", " Mstr(w0) "\n"\ + "unpckhpd " Mstr(v3) ", " Mstr(w1) "\n"\ + "addpd " Mstr(w0) ", " Mstr(v0) "\n"\ + "addpd " Mstr(w1) ", " Mstr(v2) "\n"\ + ::) +#endif + +#define vec_sum_full(reg1,reg2,empty1) vec_sum_full_wrap(reg1,reg2,empty1) +#define vec_sum_full_wrap(reg1,reg2,empty1) \ + __asm__ __volatile__ ("movhlps " #reg2 ", " #empty1 "\n"\ + "movlhps " #reg2 ", " #empty1 "\n"\ + "addpd " #empty1 ", " #reg1 "\n"\ + : /* nothing */\ + : /* nothing */) + + +typedef double vector[VECLEN]; + +#endif /* end ifdef SSE2 */ + + +#ifdef THREEDNOW + +/* Peculiarities of 3DNOW. Alignment is not an issue, + * all alignments are legal, however alignment gives a speed increase. + * The vec_acc instruction can be used to sum to registers at once more efficiently + * than a series of vec_sum and vec_store_one + * No muladd. + */ + + +#define gen_vec_mr(op,mem,reg) \ + __asm__ __volatile__ (#op " %0, " #reg \ + : /* nothing */ \ + : "m" (((mem)[0])), "m" (((mem)[1]))) + +#define gen_vec_rm(op,reg,mem) \ + __asm__ __volatile__ (#op " " #reg ", %0" \ + : "=m" (((mem)[0])), "=m" (((mem)[1])) \ + : /* nothing */ ) + + + + +#define VECLEN 2 + +#define reg0 %%mm0 +#define reg1 %%mm1 +#define reg2 %%mm2 +#define reg3 %%mm3 +#define reg4 %%mm4 +#define reg5 %%mm5 +#define reg6 %%mm6 +#define reg7 %%mm7 + +#define vec_add_mr(mem,reg) gen_vec_mr(pfadd,mem,reg) +#define vec_mul_mr(mem,reg) gen_vec_mr(pfmul,mem,reg) +#define vec_mov_mr(mem,reg) gen_vec_mr(movq,mem,reg) +#define vec_mov_rm(reg,mem) gen_vec_rm(movq,reg,mem) +#define vec_add_rr(reg1,reg2) gen_vec_rr(pfadd,reg1,reg2) +#define vec_mul_rr(reg1,reg2) gen_vec_rr(pfmul,reg1,reg2) +#define vec_acc_rr(reg1,reg2) gen_vec_rr(pfacc,reg1,reg2) +#define vec_mov_rr(reg1,reg2) gen_vec_rr(movq,reg1,reg2) + +#define vec_sum(reg) gen_vec_rr(pfacc,reg,reg) +#define vec_sum_full(reg1,reg2) gen_vec_rr(pfacc,reg1,reg2) + +#define vec_mov_mr_1(mem,reg) gen_vec_mr(movd,mem,reg) +#define vec_mov_rm_1(reg,mem) gen_vec_rm(movd,reg,mem) +#define vec_mov_rr_1(reg1,reg2) gen_vec_rr(movd,reg1,reg2) + +#define vec_add_rr_1(reg1,reg2) gen_vec_rr(pfadd,reg1,reg2) +#define vec_mul_rr_1(reg1,reg2) gen_vec_rr(pfmul,reg1,reg2) + + +#define vec_splat(mem,reg) vec_splat_wrap(mem,reg) +#define vec_splat_wrap(mem,reg) \ + __asm__ __volatile__ ("movd %0, " #reg "\n"\ + "punpckldq " #reg ", " #reg \ + : /* nothing */ \ + : "m" ((mem)[0])) + + +#define vec_load_apart(mem1,mem2,reg) vec_load_apart_wrap(mem1,mem2,reg) +#define vec_load_apart_wrap(mem1,mem2,reg) \ + __asm__ __volatile__ ("movd %0, " #reg "\n"\ + "punpckldq %1, " #reg \ + : /* nothing */ \ + : "m" ((mem1)[0]), "m" (((mem2)[0]))) + + +#define vec_zero(reg) gen_vec_rr(pxor,reg,reg) + +#define vec_enter() __asm__ __volatile__ ("femms") +#define vec_exit() __asm__ __volatile__ ("femms") + +#define align() __asm__ __volatile__ (".align 16") + + +typedef float vector[VECLEN]; + +#endif + + + + + +#ifdef ALTIVEC + +#define VECLEN 4 + +#define reg0 %%vr0 +#define reg1 %%vr1 +#define reg2 %%vr2 +#define reg3 %%vr3 +#define reg4 %%vr4 +#define reg5 %%vr5 +#define reg6 %%vr6 +#define reg7 %%vr7 +#define reg8 %%vr8 +#define reg9 %%vr9 +#define reg10 %%vr10 +#define reg11 %%vr11 +#define reg12 %%vr12 +#define reg13 %%vr13 +#define reg14 %%vr14 +#define reg15 %%vr15 +#define reg16 %%vr16 +#define reg17 %%vr17 +#define reg18 %%vr18 +#define reg19 %%vr19 +#define reg20 %%vr20 +#define reg21 %%vr21 +#define reg22 %%vr22 +#define reg23 %%vr23 +#define reg24 %%vr24 +#define reg25 %%vr25 +#define reg26 %%vr26 +#define reg27 %%vr27 +#define reg28 %%vr28 +#define reg29 %%vr29 +#define reg30 %%vr30 +#define reg31 %%vr31 + +#define gen_vec_mr(op,mem,reg) \ + __asm__ __volatile__ (#op " %0, " #reg \ + : /* nothing */ \ + : "m" (((mem)[0])), "m" (((mem)[1])), "m" (((mem)[2])), "m" (((mem)[3]))) + + +#define gen_vec_rm(op,reg,mem) \ + __asm__ __volatile__ (#op " " #reg ", %0" \ + : "=m" (((mem)[0])), "=m" (((mem)[1])), "=m" (((mem)[2])), "=m" (((mem)[3])) \ + : /* nothing */ ) + + +#define gen_alti3(op,reg1,reg2,regout) \ + __asm__ __volatile__ (#op " " #reg1 ", " #reg2 ", " #regout \ + : /* nothing */ \ + : /* nothing */) + +#define gen_alti_muladd(op,reg1,reg2,regout) \ + __asm__ __volatile__ (#op " " #reg1 ", " #reg2 ", " #regout ", " #regout \ + : /* nothing */ \ + : /* nothing */) + + + +#define vec_mov_mr_a(mem,reg) gen_vec_mr(lvx,mem,reg) +#define vec_mov_rm_a(reg,mem) gen_vec_rm(svx,reg,mem) +#define vec_muladd(reg1,reg2,regout) gen_alti3(vmaddfp,reg1,reg2,regout) + +#define vec_zero(reg) gen_alti3(vxor,reg,reg,reg) + + +typedef float vector[VECLEN]; + +#endif + + +#ifdef ALTIVEC_C + +/* These macros have been written by, or greatly inspired by, + * Nicholas A. Coult . Thanks. + */ + +/* assumes that last four registers are not in use! */ +#define transpose(x0,x1,x2,x3) \ +reg28 = vec_mergeh(x0,x2); \ +reg29 = vec_mergeh(x1,x3); \ +reg30 = vec_mergel(x0,x2); \ +reg31 = vec_mergel(x1,x3); \ +x0 = vec_mergeh(reg28,reg29); \ +x1 = vec_mergel(reg28,reg29); \ +x2 = vec_mergeh(reg30,reg31); \ +x3 = vec_mergel(reg30,reg31) + +#define vec_mov_rm(v, where) \ +low = vec_ld(0, (where)); \ +high = vec_ld(16, (where)); \ +p_vector = vec_lvsr(0, (int *)(where)); \ +mask = vec_perm((vector unsigned char)(0), (vector unsigned char)(-1), p_vector); \ +v = vec_perm(v, v, p_vector); \ +low = vec_sel(low, v, mask); \ +high = vec_sel(v, high, mask); \ +vec_st(low, 0, (where)); \ +vec_st(high, 16, (where)) + +#define vec_mov_mr_a(mem,reg) reg = vec_ld(0, mem) + +#define vec_mov_mr(u,v) \ +p_vector = (vector unsigned char)vec_lvsl(0, (int*)(v)); \ +low = (vector unsigned char)vec_ld(0, (v)); \ +high = (vector unsigned char)vec_ld(16, (v)); \ +u=(vector float)vec_perm(low, high, p_vector) + +#define vec_muladd(reg1,reg2,regout) regout = vec_madd(reg1,reg2,regout) +#define vec_add_rr(reg1,reg2) reg2 = vec_add(reg1,reg2) + +#define vec_zero(reg) reg = vec_xor(reg,reg) + +#define vec_sum_full(reg0,reg1,reg2,reg3,regout,empty0,empty1) \ +transpose(reg0, reg1,reg2,reg3,regout,empty0,empty1); \ +empty0 = vec_add(reg0,reg1); \ +empty1 = vec_add(reg2,reg3); \ +regout = vec_add(empty0,empty1) + + +#endif /* ALTIVEC_C */ + + + + + + + + diff --git a/kaldi_io/src/tools/ATLAS/include/contrib/camm_dpa.h b/kaldi_io/src/tools/ATLAS/include/contrib/camm_dpa.h new file mode 100644 index 0000000..af9c6b1 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/contrib/camm_dpa.h @@ -0,0 +1,1626 @@ +#include <stdlib.h> +#include <sys/time.h> +#include <stdio.h> + +#include "camm_util.h" + + +#if defined(ALIGN) +#if( defined(SCPLX) || defined(DCPLX)) +#error Cannot align complex routines +#endif +#if defined(SREAL) && ( NDPM != 1 ) && ( STRIDE % 4 != 0) +#error Can only align SREAL with NDPM 1 or STRIDE % 4 = 0 +#endif +#if defined(DREAL) && ( NDPM != 1 ) && ( STRIDE % 2 != 0) +#error Can only align DREAL with NDPM 1 or STRIDE % 2 = 0 +#endif +#endif + +/****************************************************************************** + * Single Precision Complex Macros + ******************************************************************************/ + +#ifdef SCPLX + +#ifdef NO_TRANSPOSE + +#if NDPM > 3 +#error Max NDPM is 3 for SCPLX NO_TRANSPOSE +#endif + +#undef plax +#define plax + +#undef R1 +#define R1 2 +#undef R2 +#define R2 4 +#undef R3 +#define R3 6 +#undef R4 +#define R4 6 + +#undef TREG +#define TREG 1 +#undef SREG +#define SREG 0 +#undef CREG +#define CREG 0 + +#ifdef GER +#undef AREG +#define AREG 0 +#undef targ +#define targ(a_) AREG +#undef wb +#define wb(a_,b_) pu(AREG,a_,b_) +#undef wbd +#define wbd(a_,b_) pud(AREG,a_,b_) +#undef w +#define w(a_) +#undef w1_2 +#define w1_2(a_) +#else +#undef AREG +#define AREG TREG +#undef targ +#define targ(a_) CREG +#undef wb +#define wb(a_,b_) +#undef wbd +#define wbd(a_,b_) +#undef w +#define w(a_) pu(CREG,a_ ## 0,si) +#undef w1_2 +#define w1_2(a_) pud(CREG,a_ ## 0,si) +#endif + +#undef src +#define src(a_) a_ +#undef mpx +#define mpx(a_) pls(0,si,a_) ps(0,a_,a_) pls(4,si,P(a_,1)) \ + ps(0,P(a_,1),P(a_,1)) sign(a_) +#undef madd +#define madd(a_,b_,c_) pas(a_,b_,c_) +#undef ulfa +#define ulfa(a_) + +#else + +#undef R1 +#define R1 4 +#undef R2 +#define R2 5 +#undef R3 +#define R3 6 +#undef R4 +#define R4 7 + +#undef TREG +#define TREG 3 +#undef SREG +#define SREG 2 +#undef CREG +#define CREG 0 +#undef targ +#define targ(a_) a_ +#undef src +#define src(a_) 0 +#undef w +#define w(a_) +#undef w1_2 +#define w1_2(a_) +#undef mpx +#define mpx(a_) px(a_) +#ifdef BETA0 +#undef ulfa +#define ulfa(a_) phl(a_,0) pa(0,a_) pud(a_,0,si) +#else +#undef ulfa +#define ulfa(a_) pld(0,si,TREG) phl(a_,0) pa(0,a_) pa(TREG,a_) pud(a_,0,si) +#endif +#undef AREG +#define AREG TREG +#undef wb +#define wb(a_,b_) +#undef wbd +#define wbd(a_,b_) +#undef wbs +#define wbs(a_,b_) + + +#undef plax +#define plax pc(CREG,1) ps(160,CREG,CREG) ps(245,1,1) sign(CREG) + + + +#endif + +#if defined(Conj_) && ! defined(GER) +#undef sign +#define sign(a_) pm(SREG,a_) +#else +#undef sign +#define sign(a_) pm(SREG,P(a_,1)) +#endif + + + +#undef plb +#define plb(a_,b_) pl(a_,b_,AREG) +#undef plbd +#define plbd(a_,b_) px(AREG) pld(a_,b_,AREG) + +#undef dpr +#define dpr(a_) pm(src(a_),TREG) pa(TREG,targ(a_)) +#undef dprp +#define dprp(a_,b_,c_) pf(b_,c_) pm(src(a_),TREG) pa(TREG,targ(a_)) +#undef dpi +#define dpi(a_) pm(P(src(a_),1),TREG) ps(177,TREG,TREG) pa(TREG,targ(a_)) + +#ifndef GER + + +#undef plaa +#define plaa(a_) pl(a_ ## 0,si,CREG) plax +#undef wa +#define wa(a_) w(a_) +#undef dp +#define dp(a_,b_,c_) plb(a_ ## 0,b_) dpr(c_) plb(a_ ## 0,b_) dpi(c_) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) plb(a_ ## 0,b_) dprp(c_,d_,e_) plb(a_ ## 0,b_) dpi(c_) +#undef ddp +#define ddp(a_,b_,c_) dp(a_,b_,c_) +#undef ddpp +#define ddpp(a_,b_,c_,d_,e_) dpp(a_,b_,c_,d_,e_) + +#undef plaa1_2 +#define plaa1_2(a_) px(CREG) pld(a_ ## 0,si,CREG) plax +#undef wa1_2 +#define wa1_2(a_) w1_2(a_) +#undef dp1_2 +#define dp1_2(a_,b_,c_) plbd(a_ ## 0,b_) dpr(c_) plbd(a_ ## 0,b_) dpi(c_) +#undef dpp1_2 +#define dpp1_2(a_,b_,c_,d_,e_) plbd(a_ ## 0,b_) dprp(c_,d_,e_) plbd(a_ ## 0,b_) dpi(c_) +#undef ddp1_2 +#define ddp1_2(a_,b_,c_) dp1_2(a_,b_,c_) +#undef ddpp1_2 +#define ddpp1_2(a_,b_,c_,d_,e_) dpp1_2(a_,b_,c_,d_,e_) + + +#else + +#undef lqc +#define lqc(a_) pl(a_ ## 0,si,TREG) +#undef lqc1 +#define lqc1_2(a_) px(TREG) pld(a_ ## 0,si,TREG) + + +#undef plaa +#define plaa(a_) +#undef wa +#define wa(a_) +#undef dp +#define dp(a_,b_,c_) lqc(a_) plb(a_ ## 0,b_) dpr(c_) \ + lqc(a_) dpi(c_) wb(a_ ## 0,b_) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) lqc(a_) plb(a_ ## 0,b_) dpr(c_) pf(d_,e_) \ + lqc(a_) dpi(c_) wb(a_ ## 0,b_) +#undef ddp +#define ddp(a_,b_,c_) dp(a_,b_,c_) +#undef ddpp +#define ddpp(a_,b_,c_,d_,e_) dpp(a_,b_,c_,d_,e_) + +#undef plaa1_2 +#define plaa1_2(a_) +#undef wa1_2 +#define wa1_2(a_) +#undef dp1_2 +#define dp1_2(a_,b_,c_) lqc1_2(a_) plbd(a_ ## 0,b_) dpr(c_) \ + lqc1_2(a_) dpi(c_) wbd(a_ ## 0,b_) +#undef dpp1_2 +#define dpp1_2(a_,b_,c_,d_,e_) lqc1_2(a_) plbd(a_ ## 0,b_) dpr(c_) pf(d_,e_) \ + lqc1_2(a_) dpi(c_) wbd(a_ ## 0,b_) +#undef ddp1_2 +#define ddp1_2(a_,b_,c_) dp1_2(a_,b_,c_) +#undef ddpp1_2 +#define ddpp1_2(a_,b_,c_,d_,e_) dpp1_2(a_,b_,c_,d_,e_) + +#endif + +#endif + +/****************************************************************************** + * Single Precision Real Macros + ******************************************************************************/ + +#ifdef SREAL + +#ifdef NO_TRANSPOSE + +#undef mpx +#define mpx(a_) pls(0,si,a_) ps(0,a_,a_) +#undef madd +#define madd(a_,b_,c_) pas(a_,b_,c_) +#undef TREG +#define TREG 1 +#undef targ +#define targ(a_) 0 +#undef src +#define src(a_) a_ +#undef ulfa +#define ulfa(a_) + +#ifdef GER +#undef w +#define w(a_) +#undef w1_2 +#define w1_2(a_) +#undef w1_4 +#define w1_4(a_) +#undef CREG +#define CREG 2 +#undef AREG +#define AREG 0 +#undef cp +#define cp pc(CREG,TREG) +#undef wb +#define wb(a_,b_) pu(AREG,a_,b_) +#undef wbd +#define wbd(a_,b_) pud(AREG,a_,b_) +#undef wbs +#define wbs(a_,b_) pus(AREG,a_,b_) +#else +#undef CREG +#define CREG 0 +#undef AREG +#define AREG TREG +#undef cp +#define cp +#undef wb +#define wb(a_,b_) +#undef wbd +#define wbd(a_,b_) +#undef wbs +#define wbs(a_,b_) +#undef w +#define w(a_) pu(CREG,a_ ## 0,si) +#undef w1_2 +#define w1_2(a_) pud(CREG,a_ ## 0,si) +#undef w1_4 +#define w1_4(a_) pus(CREG,a_ ## 0,si) +#endif + +#else + +#undef mpx +#define mpx(a_) px(a_) +#ifdef BETA0 +#undef madd +#define madd(a_,b_,c_) +#else +#undef madd +#define madd(a_,b_,c_) pas(a_,b_,c_) +#endif +#undef TREG +#define TREG 3 +#undef targ +#define targ(a_) a_ +#undef src +#define src(a_) 0 +#undef w +#define w(a_) +#undef w1_2 +#define w1_2(a_) +#undef w1_4 +#define w1_4(a_) +#undef ulfa +#undef ulfa +#define ulfa(a_) phl(a_,0) pa(0,a_) pc(a_,0) ps(1,0,0) pa(0,a_) \ + madd(0,si,a_) pus(a_,0,si) + +#undef CREG +#define CREG 0 +#undef AREG +#define AREG TREG +#undef cp +#define cp +#undef wb +#define wb(a_,b_) +#undef wbd +#define wbd(a_,b_) +#undef wbs +#define wbs(a_,b_) + +#endif + +#if defined(ALIGN) +#undef plb +#define plb(a_,b_) pla(a_,b_,AREG) +#else +#undef plb +#define plb(a_,b_) pl(a_,b_,AREG) +#endif +#undef plbd +#define plbd(a_,b_) px(AREG) pld(a_,b_,AREG) +#undef plbs +#define plbs(a_,b_) pls(a_,b_,AREG) +#undef dpr +#define dpr(a_) pm(src(a_),TREG) pa(TREG,targ(a_)) +#undef dprp +#define dprp(a_,b_,c_) pf(b_,c_) pm(src(a_),TREG) pa(TREG,targ(a_)) +#undef dprs +#define dprs(a_) pmsr(src(a_),TREG) pasr(TREG,targ(a_)) +#undef dprps +#define dprps(a_,b_,c_) pf(b_,c_) pmsr(src(a_),TREG) pasr(TREG,targ(a_)) + +#undef plaa +#define plaa(a_) pl(a_ ## 0,si,CREG) +#undef wa +#define wa(a_) w(a_) +#undef dp +#define dp(a_,b_,c_) cp plb(a_ ## 0,b_) dpr(c_) wb(a_ ## 0,b_) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) cp plb(a_ ## 0,b_) dprp(c_,d_,e_) wb(a_ ## 0,b_) +#undef ddp +#define ddp(a_,b_,c_) dp(a_,b_,c_) +#undef ddpp +#define ddpp(a_,b_,c_,d_,e_) dpp(a_,b_,c_,d_,e_) + +#undef plaa1_2 +#define plaa1_2(a_) px(CREG) pld(a_ ## 0,si,CREG) +#undef wa1_2 +#define wa1_2(a_) w1_2(a_) +#undef dp1_2 +#define dp1_2(a_,b_,c_) cp plbd(a_ ## 0,b_) dpr(c_) wbd(a_ ## 0,b_) +#undef dpp1_2 +#define dpp1_2(a_,b_,c_,d_,e_) cp plbd(a_ ## 0,b_) dprp(c_,d_,e_) wbd(a_ ## 0,b_) +#undef ddp1_2 +#define ddp1_2(a_,b_,c_) dp1_2(a_,b_,c_) +#undef ddpp1_2 +#define ddpp1_2(a_,b_,c_,d_,e_) dpp1_2(a_,b_,c_,d_,e_) + +#undef plaa1_4 +#define plaa1_4(a_) pls(a_ ## 0,si,CREG) +#undef wa1_4 +#define wa1_4(a_) w1_4(a_) +#undef dp1_4 +#define dp1_4(a_,b_,c_) cp plbs(a_ ## 0,b_) dprs(c_) wbs(a_ ## 0,b_) +#undef dpp1_4 +#define dpp1_4(a_,b_,c_,d_,e_) cp plbs(a_ ## 0,b_) dprps(c_,d_,e_) wbs(a_ ## 0,b_) +#undef ddp1_4 +#define ddp1_4(a_,b_,c_) dp1_4(a_,b_,c_) +#undef ddpp1_4 +#define ddpp1_4(a_,b_,c_,d_,e_) dpp1_4(a_,b_,c_,d_,e_) + + + +#undef R1 +#define R1 4 +#undef R2 +#define R2 5 +#undef R3 +#define R3 6 +#undef R4 +#define R4 7 + +#endif + +/****************************************************************************** + * Double Precision Real Macros + ******************************************************************************/ + +#ifdef DREAL + +#ifdef ATL_SSE2 + +#ifdef NO_TRANSPOSE + +#undef mpx +#define mpx(a_) pls(0,si,a_) ps(0,a_,a_) +#undef madd +#define madd(a_,b_,c_) pas(a_,b_,c_) +#undef TREG +#define TREG 1 +#undef targ +#define targ(a_) 0 +#undef src +#define src(a_) a_ +#undef ulfa +#define ulfa(a_) + +#ifdef GER +#undef w +#define w(a_) +#undef w1_2 +#define w1_2(a_) +#undef w1_4 +#define w1_4(a_) +#undef CREG +#define CREG 2 +#undef AREG +#define AREG 0 +#undef cp +#define cp pc(CREG,TREG) +#undef wb +#define wb(a_,b_) pu(AREG,a_,b_) +#undef wbd +#define wbd(a_,b_) pus(AREG,a_,b_) +#undef wbs +/* #define wbs(a_,b_) pus(AREG,a_,b_) */ +#else +#undef CREG +#define CREG 0 +#undef AREG +#define AREG TREG +#undef cp +#define cp +#undef wb +#define wb(a_,b_) +#undef wbd +#define wbd(a_,b_) +#undef wbs +/* #define wbs(a_,b_) */ +#undef w +#define w(a_) pu(CREG,a_ ## 0,si) +#undef w1_2 +#define w1_2(a_) pus(CREG,a_ ## 0,si) +#undef w1_4 +/* #define w1_4(a_) pus(CREG,a_ ## 0,si) */ +#endif + +#else + +#undef mpx +#define mpx(a_) px(a_) +#ifdef BETA0 +#undef madd +#define madd(a_,b_,c_) +#else +#undef madd +#define madd(a_,b_,c_) pas(a_,b_,c_) +#endif +#undef TREG +#define TREG 3 +#undef targ +#define targ(a_) a_ +#undef src +#define src(a_) 0 +#undef w +#define w(a_) +#undef w1_2 +#define w1_2(a_) +#undef w1_4 +#define w1_4(a_) +#undef ulfa +#undef ulfa +#define ulfa(a_) /* phl(a_,0) pa(0,a_) */ pc(a_,0) ps(1,0,0) pa(0,a_) \ + madd(0,si,a_) pus(a_,0,si) + +#undef CREG +#define CREG 0 +#undef AREG +#define AREG TREG +#undef cp +#define cp +#undef wb +#define wb(a_,b_) +#undef wbd +#define wbd(a_,b_) +#undef wbs +#define wbs(a_,b_) + +#endif + +#if defined(ALIGN) +#undef plb +#define plb(a_,b_) pla(a_,b_,AREG) +#else +#undef plb +#define plb(a_,b_) pl(a_,b_,AREG) +#endif +#undef plbd +#define plbd(a_,b_) /* px(AREG) */pls(a_,b_,AREG) +#undef plbs +/* #define plbs(a_,b_) pls(a_,b_,AREG) */ +#undef dpr +#define dpr(a_) pm(src(a_),TREG) pa(TREG,targ(a_)) +#undef dprp +#define dprp(a_,b_,c_) pf(b_,c_) pm(src(a_),TREG) pa(TREG,targ(a_)) +#undef dprs +#define dprs(a_) pmsr(src(a_),TREG) pasr(TREG,targ(a_)) +#undef dprps +#define dprps(a_,b_,c_) pf(b_,c_) pmsr(src(a_),TREG) pasr(TREG,targ(a_)) + +#undef plaa +#define plaa(a_) pl(a_ ## 0,si,CREG) +#undef wa +#define wa(a_) w(a_) +#undef dp +#define dp(a_,b_,c_) cp plb(a_ ## 0,b_) dpr(c_) wb(a_ ## 0,b_) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) cp plb(a_ ## 0,b_) dprp(c_,d_,e_) wb(a_ ## 0,b_) +#undef ddp +#define ddp(a_,b_,c_) dp(a_,b_,c_) +#undef ddpp +#define ddpp(a_,b_,c_,d_,e_) dpp(a_,b_,c_,d_,e_) + +#undef plaa1_2 +#define plaa1_2(a_) /* px(CREG) */pls(a_ ## 0,si,CREG) +#undef wa1_2 +#define wa1_2(a_) w1_2(a_) +#undef dp1_2 +#define dp1_2(a_,b_,c_) cp plbd(a_ ## 0,b_) dprs(c_) wbd(a_ ## 0,b_) +#undef dpp1_2 +#define dpp1_2(a_,b_,c_,d_,e_) cp plbd(a_ ## 0,b_) dprps(c_,d_,e_) wbd(a_ ## 0,b_) +#undef ddp1_2 +#define ddp1_2(a_,b_,c_) dp1_2(a_,b_,c_) +#undef ddpp1_2 +#define ddpp1_2(a_,b_,c_,d_,e_) dpp1_2(a_,b_,c_,d_,e_) + +#undef plaa1_4 +/* #define plaa1_4(a_) pls(a_ ## 0,si,CREG) */ +#undef wa1_4 +/* #define wa1_4(a_) w1_4(a_) */ +#undef dp1_4 +/* #define dp1_4(a_,b_,c_) cp plbs(a_ ## 0,b_) dprs(c_) wbs(a_ ## 0,b_) */ +#undef dpp1_4 +/* #define dpp1_4(a_,b_,c_,d_,e_) cp plbs(a_ ## 0,b_) dprps(c_,d_,e_) wbs(a_ ## 0,b_) */ +#undef ddp1_4 +/* #define ddp1_4(a_,b_,c_) dp1_4(a_,b_,c_) */ +#undef ddpp1_4 +/* #define ddpp1_4(a_,b_,c_,d_,e_) dpp1_4(a_,b_,c_,d_,e_) */ + + + +#undef R1 +#define R1 4 +#undef R2 +#define R2 5 +#undef R3 +#define R3 6 +#undef R4 +#define R4 7 + +#else + +#ifdef NO_TRANSPOSE + +#undef t0 +#define t0(a_) 1 +#undef s0 +#define s0(a_) a_ +#undef t8 +#define t8(a_) 2 +#undef s8 +#define s8(a_) a_ +#undef w +#define w(a_) fp(a_ ## 0,si) fp(a_ ## 8,si) +#undef w1_2 +#define w1_2(a_) fp(a_ ## 0,si) +#undef mpx +#define mpx(a_) fl(0,si) fc(M(a_,2)) +#undef madd +#define madd(a_,b_,c_) faa(a_,b_) +#undef ulfa +#define ulfa(a_) fc(0) + +#else + +#undef t0 +#define t0(a_) a_ +#undef s0 +#define s0(a_) 1 +#undef t8 +#define t8(a_) a_ +#undef s8 +#define s8(a_) 2 +#undef w +#define w(a_) +#undef w1_2 +#define w1_2(a_) +#undef mpx +#define mpx(a_) fz +#ifdef BETA0 +#undef madd +#define madd(a_,b_,c_) +#else +#undef madd +#define madd(a_,b_,c_) faa(a_,b_) +#endif +#undef ulfa +#define ulfa(a_) madd(0,si,a_) fp(0,si) + +#endif + + +#ifndef GER + +#undef plaa1_2 +#define plaa1_2(a_) fl(a_ ## 0,si) +#undef wa1_2 +#define wa1_2(a_) w1_2(a_) +#ifdef NO_TRANSPOSE +#undef ddp1_2 +#define ddp1_2(a_,b_,c_) fl(a_ ## 0,b_) fm(M(s0(c_),1),0) fap(0,t0(c_)) +#undef dp1_2 +#define dp1_2(a_,b_,c_) ddp1_2(a_,b_,c_) +#else +#undef ddp1_2 +#define ddp1_2(a_,b_,c_) fl(a_ ## 0,b_) fm(s0(c_),0) fap(0,M(t0(c_),1)) +#undef dp1_2 +#define dp1_2(a_,b_,c_) fl(a_ ## 0,b_) fmp(0,s0(c_)) fap(0,M(t0(c_),2)) +#endif + +#else + +#undef plaa1_2 +#define plaa1_2(a_) fl(a_ ## 0,si) +#undef wa1_2 +#define wa1_2(a_) +#undef ddp1_2 +#define ddp1_2(a_,b_,c_) fd(M(s0(c_),2)) fm(t0(c_),0) faa(a_ ## 0,b_) fp(a_ ## 0,b_) +#undef dp1_2 +#define dp1_2(a_,b_,c_) fm(M(s0(c_),2),0) faa(a_ ## 0,b_) fp(a_ ## 0,b_) + +#endif + + + +#undef plaa +#define plaa(a_) fl(a_ ## 0,si) fl(a_ ## 8,si) fx1 + +#ifndef GER + + +#undef wa +#define wa(a_) w(a_) + + +#undef ddp +#define ddp(a_,b_,c_) fl(a_ ## 0,b_) fm(s0(c_),0) fl(a_ ## 8,b_) \ + fm(P(s8(c_),1),0) fx1 fap(0,P(t0(c_),1)) \ + fap(0,t8(c_)) +#undef ddpp +#define ddpp(a_,b_,c_,d_,e_) fl(a_ ## 0,b_) fm(s0(c_),0) fl(a_ ## 8,b_) \ + fm(P(s8(c_),1),0) pf(d_,e_) fx1 fap(0,P(t0(c_),1)) \ + fap(0,t8(c_)) + +/* #define ddp(a_,b_,c_) fd(M(s0(c_),1)) fma(a_ ## 0,b_) fap(0,t0(c_)) \ */ +/* fd(M(s8(c_),1)) fma(a_ ## 8,b_) fap(0,t8(c_)) */ +/* #define ddpp(a_,b_,c_,d_,e_) fd(M(s0(c_),1)) fma(a_ ## 0,b_) fap(0,t0(c_)) \ */ +/* \ */ +/* fd(M(s8(c_),1)) fma(a_ ## 8,b_) fap(0,t8(c_)) pf(d_,e_) */ + +#ifdef NO_TRANSPOSE + +#undef dp +#define dp(a_,b_,c_) ddp(a_,b_,c_) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) ddpp(a_,b_,c_,d_,e_) + +#else + +#undef dp +#define dp(a_,b_,c_) fl(a_ ## 0,b_) fmp(0,s0(c_)) fl(a_ ## 8,b_) \ + fmp(0,s8(c_)) fap(0,M(t0(c_),1)) fap(0,M(t8(c_),2)) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) fl(a_ ## 0,b_) pf(d_ ,e_) fmp(0,s0(c_)) fl(a_ ## 8,b_) \ + fmp(0,s8(c_)) fap(0,M(t0(c_),1)) fap(0,M(t8(c_),2)) + +/* #define dp(a_,b_,c_) fma(a_ ## 0,b_) fap(0,M(t0(c_),1)) \ */ +/* fma(a_ ## 8,b_) fap(0,M(t8(c_),2)) */ +/* #define dpp(a_,b_,c_,d_,e_) fma(a_ ## 0,b_) fap(0,M(t0(c_),1)) \ */ +/* \ */ +/* fma(a_ ## 8,b_) fap(0,M(t8(c_),2)) pf(d_,e_) */ + +#endif + + +#else + +#undef wa +#define wa(a_) +#undef ddp +#define ddp(a_,b_,c_) fd(M(s0(c_),1)) fm(t0(c_),0) faa(a_ ## 0,b_) fp(a_ ## 0,b_) \ + fd(M(s8(c_),1)) fm(t8(c_),0) faa(a_ ## 8,b_) fp(a_ ## 8,b_) +#undef ddpp +#define ddpp(a_,b_,c_,d_,e_) fd(M(s0(c_),1)) fm(t0(c_),0) faa(a_ ## 0,b_) fp(a_ ## 0,b_) \ + fd(M(s8(c_),1)) fm(t8(c_),0) faa(a_ ## 8,b_) fp(a_ ## 8,b_) pf(d_,e_) + +#undef dp +#define dp(a_,b_,c_) fm(M(s0(c_),1),0) faa(a_ ## 0,b_) fp(a_ ## 0,b_) \ + fm(M(s8(c_),2),0) faa(a_ ## 8,b_) fp(a_ ## 8,b_) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) fm(M(s0(c_),1),0) faa(a_ ## 0,b_) fp(a_ ## 0,b_) \ + fm(M(s8(c_),2),0) faa(a_ ## 8,b_) fp(a_ ## 8,b_) pf(d_,e_) + +#endif + + +#undef R1 +#define R1 3 +#undef R2 +#define R2 4 +#undef R3 +#define R3 5 +#undef R4 +#define R4 6 + +#endif + +#endif + +/****************************************************************************** + * Double Precision Complex Macros + ******************************************************************************/ + +#ifdef DCPLX + +#ifdef ATL_SSE2 +#ifdef NO_TRANSPOSE + +#if NDPM > 3 +#error Max NDPM is 3 for DCPLX NO_TRANSPOSE +#endif + +#undef plax +#define plax + +#undef R1 +#define R1 2 +#undef R2 +#define R2 4 +#undef R3 +#define R3 6 +#undef R4 +#define R4 6 + +#undef TREG +#define TREG 1 +#undef SREG +#define SREG 0 +#undef CREG +#define CREG 0 + +#ifdef GER +#undef AREG +#define AREG 0 +#undef targ +#define targ(a_) AREG +#undef wb +#define wb(a_,b_) pu(AREG,a_,b_) +#undef wbd +/* #define wbd(a_,b_) pud(AREG,a_,b_) */ +#undef w +#define w(a_) +#undef w1_2 +/* #define w1_2(a_) */ +#else +#undef AREG +#define AREG TREG +#undef targ +#define targ(a_) CREG +#undef wb +#define wb(a_,b_) +#undef wbd +/* #define wbd(a_,b_) */ +#undef w +#define w(a_) pu(CREG,a_ ## 0,si) +#undef w1_2 +/* #define w1_2(a_) pud(CREG,a_ ## 0,si) */ +#endif + +#undef src +#define src(a_) a_ +#undef mpx +#define mpx(a_) pls(0,si,a_) ps(0,a_,a_) pls(8,si,P(a_,1)) \ + ps(0,P(a_,1),P(a_,1)) sign(a_) +#undef madd +#define madd(a_,b_,c_) pas(a_,b_,c_) +#undef ulfa +#define ulfa(a_) + +#else + +#undef R1 +#define R1 4 +#undef R2 +#define R2 5 +#undef R3 +#define R3 6 +#undef R4 +#define R4 7 + +#undef TREG +#define TREG 3 +#undef SREG +#define SREG 2 +#undef CREG +#define CREG 0 +#undef targ +#define targ(a_) a_ +#undef src +#define src(a_) 0 +#undef w +#define w(a_) +#undef w1_2 +#define w1_2(a_) +#undef mpx +#define mpx(a_) px(a_) +#ifdef BETA0 +#undef ulfa +#define ulfa(a_) /* phl(a_,0) pa(0,a_) */pu(a_,0,si) +#else +#undef ulfa +#define ulfa(a_) pl(0,si,TREG) /* phl(a_,0) pa(0,a_) */ pa(TREG,a_) pu(a_,0,si) +#endif +#undef AREG +#define AREG TREG +#undef wb +#define wb(a_,b_) +#undef wbd +#define wbd(a_,b_) +#undef wbs +#define wbs(a_,b_) + + +#undef plax +#define plax pc(CREG,1) ps(0,CREG,CREG) ps(3,1,1) sign(CREG) + + + +#endif + +#if defined(Conj_) && ! defined(GER) +#undef sign +#define sign(a_) pm(SREG,a_) +#else +#undef sign +#define sign(a_) pm(SREG,P(a_,1)) +#endif + + + +#undef plb +#define plb(a_,b_) pl(a_,b_,AREG) +#undef plbd +/* #define plbd(a_,b_) px(AREG) pld(a_,b_,AREG) */ + +#undef dpr +#define dpr(a_) pm(src(a_),TREG) pa(TREG,targ(a_)) +#undef dprp +#define dprp(a_,b_,c_) pf(b_,c_) pm(src(a_),TREG) pa(TREG,targ(a_)) +#undef dpi +#define dpi(a_) pm(P(src(a_),1),TREG) ps(1,TREG,TREG) pa(TREG,targ(a_)) + +#ifndef GER + +#undef plaa +#define plaa(a_) pl(a_ ## 0,si,CREG) plax +#undef wa +#define wa(a_) w(a_) +#undef dp +#define dp(a_,b_,c_) plb(a_ ## 0,b_) dpr(c_) plb(a_ ## 0,b_) dpi(c_) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) plb(a_ ## 0,b_) dprp(c_,d_,e_) plb(a_ ## 0,b_) dpi(c_) +#undef ddp +#define ddp(a_,b_,c_) dp(a_,b_,c_) +#undef ddpp +#define ddpp(a_,b_,c_,d_,e_) dpp(a_,b_,c_,d_,e_) + +#undef plaa1_2 +/* #define plaa1_2(a_) px(CREG) pld(a_ ## 0,si,CREG) plax */ +#undef wa1_2 +/* #define wa1_2(a_) w1_2(a_) */ +#undef dp1_2 +/* #define dp1_2(a_,b_,c_) plbd(a_ ## 0,b_) dpr(c_) plbd(a_ ## 0,b_) dpi(c_) */ +#undef dpp1_2 +/* #define dpp1_2(a_,b_,c_,d_,e_) plbd(a_ ## 0,b_) dprp(c_,d_,e_) plbd(a_ ## 0,b_) dpi(c_) */ +#undef ddp1_2 +/* #define ddp1_2(a_,b_,c_) dp1_2(a_,b_,c_) */ +#undef ddpp1_2 +/* #define ddpp1_2(a_,b_,c_,d_,e_) dpp1_2(a_,b_,c_,d_,e_) */ + + +#else + +#undef lqc +#define lqc(a_) pl(a_ ## 0,si,TREG) +#undef lqc1 +/* #define lqc1_2(a_) px(TREG) pld(a_ ## 0,si,TREG) */ + + +#undef plaa +#define plaa(a_) +#undef wa +#define wa(a_) +#undef dp +#define dp(a_,b_,c_) lqc(a_) plb(a_ ## 0,b_) dpr(c_) \ + lqc(a_) dpi(c_) wb(a_ ## 0,b_) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) lqc(a_) plb(a_ ## 0,b_) dpr(c_) pf(d_,e_) \ + lqc(a_) dpi(c_) wb(a_ ## 0,b_) +#undef ddp +#define ddp(a_,b_,c_) dp(a_,b_,c_) +#undef ddpp +#define ddpp(a_,b_,c_,d_,e_) dpp(a_,b_,c_,d_,e_) + +#undef plaa1_2 +/* #define plaa1_2(a_) */ +#undef wa1_2 +/* #define wa1_2(a_) */ +#undef dp1_2 +/* #define dp1_2(a_,b_,c_) lqc1_2(a_) plbd(a_ ## 0,b_) dpr(c_) \ */ +/* lqc1_2(a_) dpi(c_) wbd(a_ ## 0,b_) */ +#undef dpp1_2 +/* #define dpp1_2(a_,b_,c_,d_,e_) lqc1_2(a_) plbd(a_ ## 0,b_) dpr(c_) pf(d_,e_) \ */ +/* lqc1_2(a_) dpi(c_) wbd(a_ ## 0,b_) */ +#undef ddp1_2 +/* #define ddp1_2(a_,b_,c_) dp1_2(a_,b_,c_) */ +#undef ddpp1_2 +/* #define ddpp1_2(a_,b_,c_,d_,e_) dpp1_2(a_,b_,c_,d_,e_) */ + +#endif + +#else + +#if NDPM > 2 +#error Max NDPM is 2 for DCPLX +#endif + +#undef TREG +#define TREG 2 + +#ifdef NO_TRANSPOSE + +#undef w +#define w(a_) fp(a_ ## 0,si) fp(a_ ## 8,si) +#undef plax +#define plax fx1 +#undef srr +#define srr(a_) a_ +#undef sri +#define sri(a_) a_ +#undef sir +#define sir(a_) a_ +#undef sii +#define sii(a_) a_ +#undef trr +#define trr(a_) P(TREG,1) +#undef tri +#define tri(a_) M(TREG,1) +#undef tir +#define tir(a_) TREG +#undef tii +#define tii(a_) TREG +#undef mpx +#define mpx(a_) fl(0,si) fl(8,si) fc(M(a_,2)) fc(M(a_,2)) +#undef madd +#define madd(a_,b_,c_) faa(a_,b_) +#undef ulfa +#define ulfa(a_) fc(0) fc(0) + +#else + +#undef srr +#define srr(a_) P(TREG,1) +#undef sri +#define sri(a_) M(TREG,1) +#undef sir +#define sir(a_) TREG +#undef sii +#define sii(a_) TREG +#undef trr +#define trr(a_) a_ +#undef tri +#define tri(a_) a_ +#undef tir +#define tir(a_) a_ +#undef tii +#define tii(a_) a_ +#undef w +#define w(a_) +#undef plax +#define plax +#undef mpx +#define mpx(a_) fz fz +#ifdef BETA0 +#undef madd +#define madd(a_,b_,c_) +#else +#undef madd +#define madd(a_,b_,c_) faa(a_,b_) +#endif +#undef ulfa +#define ulfa(a_) madd(0,si,a_) fp(0,si) madd(8,si,a_) fp(8,si) + +#endif + + + +#ifdef Conj_ +#undef fapi +#define fapi(a_,b_) fsp(b_) +#undef fspi +#define fspi(a_,b_) fap(a_,b_) +#else +#undef fapi +#define fapi(a_,b_) fap(a_,b_) +#undef fspi +#define fspi(a_,b_) fsp(b_) +#endif + +#ifndef GER + + +#undef plaa +#define plaa(a_) fl(a_ ## 0,si) fl(a_ ## 8,si) plax +#undef wa +#define wa(a_) w(a_) +#undef ddp +#define ddp(a_,b_,c_) fl(a_ ## 0,b_) fd(0) fm(srr(c_),0) fap(0,trr(c_)) \ + fm(sri(c_),0) fap(0,tri(c_))\ + fl(a_ ## 8,b_) fd(0) fm(sir(c_),0) fspi(0,tir(c_)) \ + fm(sii(c_),0) fapi(0,tii(c_)) +#undef ddpp +#define ddpp(a_,b_,c_,d_,e_) fl(a_ ## 0,b_) fd(0) fm(srr(c_),0) fap(0,trr(c_)) \ + fm(sri(c_),0) fap(0,tri(c_))\ + fl(a_ ## 8,b_) fd(0) pf(d_,e_) fm(sir(c_),0) fspi(0,tir(c_))\ + fm(sii(c_),0) fapi(0,tii(c_)) + + + +#ifdef NO_TRANSPOSE + + + +#undef dp +#define dp(a_,b_,c_) ddp(a_,b_,c_) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) ddpp(a_,b_,c_,d_,e_) + + + +#else + +#undef dp +#define dp(a_,b_,c_) fl(a_ ## 0,b_) fd(0) fm(srr(c_),0) fap(0,trr(c_)) \ + fm(sri(c_),0) fap(0,tri(c_))\ + fl(a_ ## 8,b_) fm(0,sir(c_)) fmp(0,M(sir(c_),1)) \ + fspi(0,M(tir(c_),2)) fapi(0,M(tii(c_),2)) + +#undef dpp +#define dpp(a_,b_,c_,d_,e_) fl(a_ ## 0,b_) fd(0) fm(srr(c_),0) fap(0,trr(c_)) \ + pf(d_,e_) fm(sri(c_),0) fap(0,tri(c_))\ + fl(a_ ## 8,b_) fm(0,sir(c_)) fmp(0,M(sir(c_),1)) \ + fspi(0,M(tir(c_),2)) fapi(0,M(tii(c_),2)) + + +#endif + +#else + +#undef plaa +#define plaa(a_) fl(a_ ## 0,si) fl(a_ ## 8,si) plax +#undef wa +#define wa(a_) + +#undef ddprr +#define ddprr(a_,b_,c_) fl(a_ ## 0,b_) \ + fd(tri(c_)) fm(P(sri(c_),1),0) fap(0,1) \ + fd(M(trr(c_),1)) fm(srr(c_),0) fspi(0,1) \ + fp(a_ ## 0,b_) +#undef ddpri +#define ddpri(a_,b_,c_) fl(a_ ## 8,b_) \ + fd(tii(c_)) fm(P(sii(c_),1),0) fap(0,1) \ + fd(M(tir(c_),1)) fm(sir(c_),0) fapi(0,1) \ + fp(a_ ## 8,b_) +#undef dpri +#define dpri(a_,b_,c_) fl(a_ ## 8,b_) \ + fx(2) fm(sir(c_),0) fap(0,2) \ + fm(M(sii(c_),2),0) fapi(0,1) \ + fp(a_ ## 8,b_) + + +#undef ddpp +#define ddpp(a_,b_,c_,d_,e_) ddprr(a_,b_,c_) pf(d_,e_) ddpri(a_,b_,c_) +#undef ddp +#define ddp(a_,b_,c_) ddprr(a_,b_,c_) ddpri(a_,b_,c_) +#undef dpp +#define dpp(a_,b_,c_,d_,e_) ddprr(a_,b_,c_) pf(d_,e_) dpri(a_,b_,c_) +#undef dp +#define dp(a_,b_,c_) ddprr(a_,b_,c_) dpri(a_,b_,c_) + +#endif + + +#undef R1 +#define R1 4 +#undef R2 +#define R2 6 +#undef R3 +#define R3 6 +#undef R4 +#define R4 6 + +#endif + +#endif + + +/****************************************************************************** + * General Macros + ******************************************************************************/ + + + + +#undef bla1 +#define bla1(a_,b_) plaa(a_) dpp(a_,ax,R1,b_,si) wa(a_) +#undef blb1 +#define blb1(a_,b_) plaa(a_) dpp(a_,ax,R1,b_,ax) wa(a_) + +#undef bla2 +#undef bla2 +#define bla2(a_,b_) pf(b_,si) plaa(a_) ddp(a_,ax,R1) pf(b_,ax) dp(a_,bx,R2) wa(a_) +#undef blb2 +#undef blb2 +#define blb2(a_,b_) plaa(a_) ddpp(a_,ax,R1,b_,bx) dp(a_,bx,R2) wa(a_) + +#undef bla3 +#define bla3(a_,b_) plaa(a_) ddpp(a_,ax,R1,b_,si) ddp(a_,bx,R2) \ + dpp(a_,cx,R3,b_,ax) wa(a_) +#undef blb3 +#define blb3(a_,b_) plaa(a_) ddpp(a_,ax,R1,b_,bx) ddp(a_,bx,R2) \ + dpp(a_,cx,R3,b_,cx) wa(a_) + +#undef bla4 +#define bla4(a_,b_) plaa(a_) ddpp(a_,ax,R1,b_,si) ddpp(a_,bx,R2,b_,ax) \ + ddp(a_,cx,R3) dpp(a_,dx,R4,b_,bx) wa(a_) +#undef blb4 +#define blb4(a_,b_) plaa(a_) ddp(a_,ax,R1) ddpp(a_,bx,R2,b_,cx) \ + ddp(a_,cx,R3) dpp(a_,dx,R4,b_,dx) wa(a_) + +#undef bla +#define bla(a_,b_) Mjoin(bla,NDP)(a_,b_) +#undef blb +#define blb(a_,b_) Mjoin(blb,NDP)(a_,b_) + + + +#undef bla11_2 +#define bla11_2(a_) plaa1_2(a_) dp1_2(a_,ax,R1) wa1_2(a_) +#undef bla21_2 +#define bla21_2(a_) plaa1_2(a_) ddp1_2(a_,ax,R1) dp1_2(a_,bx,R2) wa1_2(a_) +#undef bla31_2 +#define bla31_2(a_) plaa1_2(a_) ddp1_2(a_,ax,R1) ddp1_2(a_,bx,R2) \ + dp1_2(a_,cx,R3) wa1_2(a_) +#undef bla41_2 +#define bla41_2(a_) plaa1_2(a_) ddp1_2(a_,ax,R1) ddp1_2(a_,bx,R2) \ + ddp1_2(a_,cx,R3) dp1_2(a_,dx,R4) wa1_2(a_) + +#undef bla1_2 +#define bla1_2(a_) Mjoin(Mjoin(bla,NDP),1_2)(a_) + + + +#undef bla11_4 +#define bla11_4(a_) plaa1_4(a_) dp1_4(a_,ax,R1) wa1_4(a_) +#undef bla21_4 +#define bla21_4(a_) plaa1_4(a_) ddp1_4(a_,ax,R1) dp1_4(a_,bx,R2) wa1_4(a_) +#undef bla31_4 +#define bla31_4(a_) plaa1_4(a_) ddp1_4(a_,ax,R1) ddp1_4(a_,bx,R2) \ + dp1_4(a_,cx,R3) wa1_4(a_) +#undef bla41_4 +#define bla41_4(a_) plaa1_4(a_) ddp1_4(a_,ax,R1) ddp1_4(a_,bx,R2) \ + ddp1_4(a_,cx,R3) dp1_4(a_,dx,R4) wa1_4(a_) + +#undef bla1_4 +#define bla1_4(a_) Mjoin(Mjoin(bla,NDP),1_4)(a_) + + + +#undef inc1 +#define inc1(a_) a(a_,si) a(a_,ax) +#undef inc2 +#define inc2(a_) inc1(a_) a(a_,bx) +#undef inc3 +#define inc3(a_) inc2(a_) a(a_,cx) +#undef inc4 +#define inc4(a_) inc3(a_) a(a_,dx) + +#undef inc +#define inc(a_) Mjoin(inc,NDP)(a_) + + +#ifdef PREFETCH +/* #include "camm_arith.h" */ +#undef S +#define S(a_,b_) (a_) + (b_) +#undef PF1 +#define PF1 PREFETCH +#undef PF2 +#define PF2 S(PF1,32) +#undef PF3 +#define PF3 S(PF1,64) +#undef PF4 +#define PF4 S(PF1,96) +#undef PF5 +#define PF5 S(PF1,128) +#undef PF6 +#define PF6 S(PF1,160) +#undef PF7 +#define PF7 S(PF1,192) +#undef PF8 +#define PF8 S(PF1,224) +#else +#undef PF1 +#define PF1 64 +#undef PF2 +#define PF2 96 +#undef PF3 +#define PF3 128 +#undef PF4 +#define PF4 160 +#undef PF5 +#define PF5 192 +#undef PF6 +#define PF6 224 +#undef PF7 +#define PF7 256 +#undef PF8 +#define PF8 288 +#endif + + +#if defined(NO_TRANSPOSE) && !defined(SREAL) && !defined(GER) +#undef pf +#define pf(a_,b_) f(t0,a_,b_) +#else +#undef pf +#define pf(a_,b_) f(nta,a_,b_) +#endif + +#undef bl1 +#define bl1 bla1_4(0x0) inc(4) +#undef bl2 +#define bl2 bla1_2(0x0) inc(8) +#undef bl4 +#define bl4 bla(0x0,PF1) inc(16) +#undef bl8 +#define bl8 bla(0x0,PF1) blb(0x1,PF1) inc(32) +#undef bl16 +#define bl16 bla(0x0,PF1) blb(0x1,PF1) bla(0x2,PF2) blb(0x3,PF2) inc(64) +#undef bl32 +#define bl32 bla(0x0,PF1) blb(0x1,PF1) bla(0x2,PF2) blb(0x3,PF2) \ + bla(0x4,PF3) blb(0x5,PF3) bla(0x6,PF4) blb(0x7,PF4) inc(128) +#undef bl64 +#define bl64 bla(0x0,PF1) blb(0x1,PF1) bla(0x2,PF2) blb(0x3,PF2) \ + bla(0x4,PF3) blb(0x5,PF3) bla(0x6,PF4) blb(0x7,PF4) \ + bla(0x8,PF5) blb(0x9,PF5) bla(0xa,PF6) blb(0xb,PF6) \ + bla(0xc,PF7) blb(0xd,PF7) bla(0xe,PF8) blb(0xf,PF8) inc(256) + +/* #define in2 inc(8) */ +/* #define in4 inc(16) */ +/* #define in8 inc(32) */ +/* #define in16 inc(64) */ + +#undef in2 +#define in2 +#undef in4 +#define in4 +#undef in8 +#define in8 +#undef in16 +#define in16 + +#ifdef NO_TRANSPOSE +#undef incf +#define incf ra(di,si) +#else +#undef incf +#define incf +#endif + +#undef lf1 +#define lf1 mpx(R1) +#undef lf2 +#define lf2 lf1 incf mpx(R2) +#undef lf3 +#define lf3 lf2 incf mpx(R3) +#undef lf4 +#define lf4 lf3 incf mpx(R4) + +#undef lf +#define lf Mjoin(lf,NDP) + + +#undef ulf1 +#define ulf1 ulfa(R1) +#undef ulf2 +#define ulf2 ulf1 ra(di,si) ulfa(R2) +#undef ulf3 +#define ulf3 ulf2 ra(di,si) ulfa(R3) +#undef ulf4 +#define ulf4 ulf3 ra(di,si) ulfa(R4) + +#undef ulf +#define ulf Mjoin(ulf,NDP) + +#undef lpba +#define lpba(a_) "movl %%esi,%%e" #a_ "\n\t" + +#undef lpb1 +#define lpb1 lpba(ax) +#undef lpb2 +#define lpb2 lpb1 ra(di,si) lpba(bx) +#undef lpb3 +#define lpb3 lpb2 ra(di,si) lpba(cx) +#undef lpb4 +#define lpb4 lpb3 ra(di,si) lpba(dx) + +#undef lpb +#define lpb Mjoin(lpb,NDP) + +#undef ipf1 +#define ipf1(a_) pf(a_,si) pf(a_,ax) +#undef ipf2 +#define ipf2(a_) ipf1(a_) pf(a_,bx) +#undef ipf3 +#define ipf3(a_) ipf2(a_) pf(a_,cx) +#undef ipf4 +#define ipf4(a_) ipf3(a_) pf(a_,dx) + +#undef ipf +#define ipf(a_) Mjoin(ipf,NDP)(a_) + +#ifdef LUNROLL +#undef UNROLL +#ifdef SREAL +#undef UNROLL +#define UNROLL LUNROLL +#elif defined(DREAL) || defined(SCPLX) +#undef UNROLL +#define UNROLL LUNROLL*2 +#elif defined(DCPLX) +#undef UNROLL +#define UNROLL LUNROLL*4 +#endif +#else +#undef UNROLL +#define UNROLL 16 +#endif + +#undef UNROLL1_2 +#if UNROLL == 64 +#undef blUNROLL +#define blUNROLL bl64 +#undef UNROLL1_2 +#define UNROLL1_2 32 +#elif UNROLL == 32 +#undef blUNROLL +#define blUNROLL bl32 +#undef UNROLL1_2 +#define UNROLL1_2 16 +#elif UNROLL == 16 +#undef blUNROLL +#define blUNROLL bl16 +#undef UNROLL1_2 +#define UNROLL1_2 8 +#elif UNROLL == 8 +#undef blUNROLL +#define blUNROLL bl8 +#undef UNROLL1_2 +#define UNROLL1_2 4 +#elif UNROLL == 4 +#undef blUNROLL +#define blUNROLL bl4 +#undef UNROLL1_2 +#define UNROLL1_2 2 +#elif UNROLL == 2 +#undef blUNROLL +#define blUNROLL bl2 +#undef UNROLL1_2 +#define UNROLL1_2 1 +#elif UNROLL == 1 +#undef blUNROLL +#define blUNROLL bl1 +#undef UNROLL1_2 +#define UNROLL1_2 stop +#endif +#ifndef UNROLL1_2 +#error UNROLL must be set to power of 2 < 128 +#endif + + +#ifdef GER +#undef aconst +#define aconst +#undef cconst +#define cconst const +#else +#undef aconst +#define aconst const +#undef cconst +#define cconst +#endif + +#undef MY_FUNCTION +#define MY_FUNCTION Mjoin(dp,EXT) + +static void +MY_FUNCTION(aconst TYPE *a,int lda, + const TYPE *b, + cconst TYPE *c,int stride,int len) { + +#ifdef SCPLX +#if defined(GER) && defined(Conj_) + const TYPE w1[2]={{-1.0,1.0},{-1.0,1.0}},*w=w1; +#else + const TYPE w1[2]={{1.0,-1.0},{1.0,-1.0}},*w=w1; +#endif +#endif + +#if defined(DCPLX) && defined(ATL_SSE2) +#if defined(GER) && defined(Conj_) + const TYPE w1[1]={{-1.0,1.0}},*w=w1; +#else + const TYPE w1[1]={{1.0,-1.0}},*w=w1; +#endif +#endif + +#ifdef NO_TRANSPOSE +#undef movm +#define movm c +#undef fixm +#define fixm b +#else +#undef movm +#define movm b +#undef fixm +#define fixm c +#endif + NO_INLINE + unsigned u1=stride*sizeof(*fixm),u2=lda*sizeof(*a),u3=len*sizeof(*movm)/sizeof(float); + + ASM ( + + "pushl %%ebx\n\t" + a(4,sp) + +#if defined(SCPLX) || (defined(DCPLX) && defined(ATL_SSE2)) + "movl %6,%%esi\n\t" + pl(0,si,SREG) +#endif + +#ifdef NO_TRANSPOSE + "movl %1,%%esi\n\t" /* fixm */ + "movl %2,%%edi\n\t" /* fixm2fixm */ +#endif + + lf + + "movl %3,%%esi\n\t" /* a */ + "movl %4,%%edi\n\t" /* a2a */ + + lpb + + ipf(0) + + "movl %0,%%esi\n\t" /* movm */ + "movl %5,%%edi\n\t" /* len */ + +#if defined(ALIGN) + +#if defined(SREAL) + + test(4,ax) + je(Mjoin(a1,EXT)) + test(-1,di) + je(Mjoin(a1,EXT)) + sub(1,di) + bl1 + + lab(Mjoin(a1,EXT)) + +#endif + +#if defined(DREAL) || defined(SREAL) + + test(8,ax) + je(Mjoin(as,EXT)) + test(-2,di) + je(Mjoin(as,EXT)) + sub(2,di) + bl2 + + lab(Mjoin(as,EXT)) + +#endif + +#endif + + + ipf(32) + + lab(Mjoin(loop,EXT)) + + test(-UNROLL,di) + je(Mjoin(UNROLL1_2,EXT)) + sub(UNROLL,di) + + blUNROLL + + jmp(Mjoin(loop,EXT)) + +#if UNROLL > 32 + lab(Mjoin(32,EXT)) + test(32,di) + je(Mjoin(16,EXT)) + bl32 +#endif + +#if UNROLL > 16 + lab(Mjoin(16,EXT)) + test(16,di) + je(Mjoin(8,EXT)) + bl16 +#endif + +#if UNROLL > 8 + lab(Mjoin(8,EXT)) + test(8,di) + je(Mjoin(4,EXT)) + bl8 +#endif + +#if UNROLL > 4 + lab(Mjoin(4,EXT)) + test(4,di) + je(Mjoin(2,EXT)) + bl4 +#endif + +#if UNROLL > 2 + lab(Mjoin(2,EXT)) +#ifndef DCPLX + test(2,di) + je(Mjoin(1,EXT)) + bl2 +#endif +#endif + +#if UNROLL > 1 + lab(Mjoin(1,EXT)) +#ifdef SREAL + test(1,di) + je(Mjoin(stop,EXT)) + bl1 +#endif +#endif + + lab(Mjoin(stop,EXT)) + +#ifndef NO_TRANSPOSE + "movl %1,%%esi\n\t" /* fixm */ + "movl %2,%%edi\n\t" /* fixm2fixm */ +#endif + + ulf + + a(-4,sp) + "popl %%ebx\n\t" + + + ::"m" (movm),"m" (fixm),"m" (u1),"m" (a),"m" (u2),"m" (u3) + +#if defined(SCPLX) || (defined(DCPLX) && defined(ATL_SSE2)) + ,"m" (w) +#endif + :"ax","bx","cx","dx","si","di"); + + +} + diff --git a/kaldi_io/src/tools/ATLAS/include/contrib/camm_pipe3.h b/kaldi_io/src/tools/ATLAS/include/contrib/camm_pipe3.h new file mode 100644 index 0000000..7fd1404 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/contrib/camm_pipe3.h @@ -0,0 +1,295 @@ +#include "camm_util.h" + +#ifndef N +#error N must be defined in camm_pipe3.h +#endif +#ifndef KB +#error KB must be defined in camm_pipe3.h +#endif + +#undef p1 +#define p1(a_) Mjoin(p1_4_,N)(a_) +#undef p2 +#define p2(a_) Mjoin(p1_2_,N)(a_) +#undef p4 +#define p4(a_) Mjoin(p1_,N)(a_) +#undef load_pipe +#define load_pipe(a_) Mjoin(lp,N)(a_) +#undef drain_pipe +#define drain_pipe(a_) Mjoin(dp,N)(a_) +#undef pipe_len +#define pipe_len Mjoin(pl,N) + +#undef p8 +#if pipe_len > 4 +#define p8(a_) Mjoin(p2_,N)(a_) +#else +#define p8(a_) p4(a_) p4(SS(a_,16)) +#endif + +#undef p16 +#if pipe_len > 8 +#define p16(a_) Mjoin(p4_,N)(a_) +#else +#define p16(a_) p8(a_) p8(SS(a_,32)) +#endif + +#undef p32 +#if pipe_len > 16 +#define p32(a_) Mjoin(p8_,N)(a_) +#else +#define p32(a_) p16(a_) p16(SS(a_,64)) +#endif + +#undef p64 +#if pipe_len > 32 +#define p64(a_) Mjoin(p16_,N)(a_) +#else +#define p64(a_) p32(a_) p32(SS(a_,128)) +#endif + +#undef p128 +#if pipe_len > 64 +#define p128(a_) Mjoin(p32_,N)(a_) +#else +#define p128(a_) p64(a_) p64(SS(a_,256)) +#endif + +#undef p256 +#if pipe_len > 128 +#define p256(a_) Mjoin(p64_,N)(a_) +#else +#define p256(a_) p128(a_) p128(SS(a_,512)) +#endif + +#if KB < pipe_len +#undef pipe_len +#define pipe_len 0 +#undef load_pipe +#define load_pipe(a_) +#undef drain_pipe +#define drain_pipe(a_) +#endif + + +#undef MKB +/* #ifdef SREAL */ +#define MKB KB +/* #elif defined (DCPLX) */ +/* #define MKB ( KB * 4 ) */ +/* #else */ +/* #define MKB ( KB * 2 ) */ +/* #endif */ + +#if MKB >= 512 +#error MKB must be less than 512 +#endif + +#undef x0 +#undef o0 +#define x0 load_pipe(0) +#define o0 0 + +#undef MKBB +#define MKBB ( MKB - pipe_len ) + +#undef xx1 +#undef oo1 +#if MKBB >= 256 +#define xx1 x0 p256(o0) +#define oo1 SS(1024,o0) +#else +#define xx1 x0 +#define oo1 o0 +#endif + +#undef xx1a +#undef oo1a +#if pipe_len == 256 +#define xx1a xx1 drain_pipe(oo1) +#define oo1a SS(1024,oo1) +#undef MKBB +#define MKBB MKB +#else +#define xx1a xx1 +#define oo1a oo1 +#endif + +#undef x1 +#undef o1 +#if ( MKBB / 128 ) % 2 +#define x1 xx1a p128(oo1a) +#define o1 SS(512,oo1a) +#else +#define x1 xx1a +#define o1 oo1a +#endif + +#undef x1a +#undef o1a +#if pipe_len == 128 +#define x1a x1 drain_pipe(o1) +#define o1a SS(512,o1) +#undef MKBB +#define MKBB MKB +#else +#define x1a x1 +#define o1a o1 +#endif + +#undef x2 +#undef o2 +#if ( MKBB / 64 ) % 2 +#define x2 x1a p64(o1a) +#define o2 SS(256,o1a) +#else +#define x2 x1a +#define o2 o1a +#endif + +#undef x2a +#undef o2a +#if pipe_len == 64 +#define x2a x2 drain_pipe(o2) +#define o2a SS(256,o2) +#undef MKBB +#define MKBB MKB +#else +#define x2a x2 +#define o2a o2 +#endif + +#undef x3 +#undef o3 +#if ( MKBB / 32 ) % 2 +#define x3 x2a p32(o2a) +#define o3 SS(128,o2a) +#else +#define x3 x2a +#define o3 o2a +#endif + +#undef x3a +#undef o3a +#if pipe_len == 32 +#define x3a x3 drain_pipe(o3) +#define o3a SS(128,o3) +#undef MKBB +#define MKBB MKB +#else +#define x3a x3 +#define o3a o3 +#endif + +#undef x4 +#undef o4 +#if ( MKBB / 16 ) % 2 +#define x4 x3a p16(o3a) +#define o4 SS(64,o3a) +#else +#define x4 x3a +#define o4 o3a +#endif + +#undef x4a +#undef o4a +#if pipe_len == 16 +#define x4a x4 drain_pipe(o4) +#define o4a SS(64,o4) +#undef MKBB +#define MKBB MKB +#else +#define x4a x4 +#define o4a o4 +#endif + +#undef x5 +#undef o5 +#if ( MKBB / 8 ) % 2 +#define x5 x4a p8(o4a) +#define o5 SS(32,o4a) +#else +#define x5 x4a +#define o5 o4a +#endif + +#undef x5a +#undef o5a +#if pipe_len == 8 +#define x5a x5 drain_pipe(o5) +#define o5a SS(32,o5) +#undef MKBB +#define MKBB MKB +#else +#define x5a x5 +#define o5a o5 +#endif + +#undef x6 +#undef o6 +#if ( MKBB / 4 ) % 2 +#define x6 x5a p4(o5a) +#define o6 SS(16,o5a) +#else +#define x6 x5a +#define o6 o5a +#endif + +#undef x6a +#undef o6a +#if pipe_len == 4 +#define x6a x6 drain_pipe(o6) +#define o6a SS(16,o6) +#undef MKBB +#define MKBB MKB +#else +#define x6a x6 +#define o6a o6 +#endif + +#undef x7 +#undef o7 +#if ( MKB / 2 ) % 2 +#define x7 x6a p2(o6a) +#define o7 SS(8,o6a) +#else +#define x7 x6a +#define o7 o6a +#endif + +#undef x7a +#undef o7a +#if pipe_len == 2 +#define x7a x7 drain_pipe(o7) +#define o7a SS(8,o7) +#undef MKBB +#define MKBB MKB +#else +#define x7a x7 +#define o7a o7 +#endif + +#undef x8 +#undef o8 +#if ( MKB / 1 ) % 2 +#define x8 x7a p1(o7a) +#define o8 SS(4,o7a) +#else +#define x8 x7a +#define o8 o7a +#endif + +#undef x8a +#undef o8a +#if pipe_len == 1 +#define x8a x8 drain_pipe(o8) +#define o8a SS(4,o8) +#undef MKBB +#define MKBB MKB +#else +#define x8a x8 +#define o8a o8 +#endif + +#undef KB_block +#define KB_block x8a diff --git a/kaldi_io/src/tools/ATLAS/include/contrib/camm_scale.h b/kaldi_io/src/tools/ATLAS/include/contrib/camm_scale.h new file mode 100644 index 0000000..35e9e59 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/contrib/camm_scale.h @@ -0,0 +1,215 @@ +#ifndef CAMM_SCALE_H +#define CAMM_SCALE_H /*+ To stop multiple inclusions. +*/ + +#include "camm_util.h" + +#undef spf +#define spf(a_,b_) f(t0,a_,b_) + +#ifdef SCPLX +#ifdef BETAX +#undef SSREG +#define SSREG 2 +#undef lbx +#define lbx pls(4,ax,1) ps(0,1,1) pm(SSREG,1) +#undef cxx +#define cxx pm(1,3) ps(177,3,3) pa(3,2) +#undef pcx +#define pcx pc(2,3) +#else +#undef lbx +#define lbx +#undef cxx +#define cxx +#undef pcx +#define pcx +#endif +#undef lb +#define lb pls(0,ax,0) ps(0,0,0) lbx +#undef c +#define c(a_) pl(a_ ## 0,si,2) pcx pm(0,2) cxx pu(2,a_ ## 0,si) +#undef cp +#define cp(a_,b_) pl(a_ ## 0,si,2) pcx pm(0,2) spf(b_,si) cxx pu(2,a_ ## 0,si) +#undef c1_2 +#define c1_2(a_) px(2) pld(a_ ## 0,si,2) pcx pm(0,2) cxx pud(2,a_ ## 0,si) +#undef ub +#define ub +#endif + +#ifdef SREAL +#undef lb +#define lb pls(0,ax,0) ps(0,0,0) +#undef c +#define c(a_) pl(a_ ## 0,si,2) pm(0,2) pu(2,a_ ## 0,si) +#undef cp +#define cp(a_,b_) pl(a_ ## 0,si,2) spf(b_,si) pm(0,2) pu(2,a_ ## 0,si) +#undef c1_2 +#define c1_2(a_) px(2) pld(a_ ## 0,si,2) pm(0,2) pud(2,a_ ## 0,si) +#undef c1_4 +#define c1_4(a_) pls(a_ ## 0,si,2) pm(0,2) pus(2,a_ ## 0,si) +#undef ub +#define ub +#endif + +#ifdef DREAL +#undef lb +#define lb fl(0,ax) +#undef c +#define c(a_) fl(a_ ## 0,si) fm(1,0) fl(a_ ## 8,si) fm(2,0) fx1 \ + fp(a_ ## 0,si) fp(a_ ## 8,si) +#undef cp +#define cp(a_,b_) fl(a_ ## 0,si) fm(1,0) fl(a_ ## 8,si) spf(b_,si) fm(2,0) fx1 \ + fp(a_ ## 0,si) fp(a_ ## 8,si) +#undef c1_2 +#define c1_2(a_) fl(a_ ## 0,si) fm(1,0) fp(a_ ## 0,si) +#undef ub +#define ub fc(0) +#endif + +#ifdef DCPLX +#undef lb +#define lb fl(0,ax) fl(8,ax) +#undef c +#define c(a_) fl(a_ ## 0,si) fl(a_ ## 8,si) fd(3) fm(2,0) fd(3) \ + fm(2,0) fx(3) fm(4,0) fx(2) fm(5,0) fap(0,2) fx(2) fsp(2) fx1 \ + fp(a_ ## 0,si) fp(a_ ## 8,si) +#undef cp +#define cp(a_,b_) fl(a_ ## 0,si) fl(a_ ## 8,si) fd(3) fm(2,0) fd(3) \ + fm(2,0) fx(3) spf(b_,si) fm(4,0) fx(2) fm(5,0) fap(0,2) fx(2) \ + fsp(2) fx1 fp(a_ ## 0,si) fp(a_ ## 8,si) +#undef ub +#define ub fc(0) fc(0) +#endif + +#undef sbl1 +#define sbl1 c1_4(0x0) +#undef sbl2 +#define sbl2 c1_2(0x0) +#undef sbl4 +#define sbl4 cp(0x0,0x40) +#undef sbl8 +#define sbl8 sbl4 c(0x1) +#undef sbl16 +#define sbl16 sbl8 cp(0x2,0x60) c(0x3) + +#undef sinc16 +#define sinc16 a(0x40,si) +#undef sinc8 +#define sinc8 a(0x20,si) +#undef sinc4 +#define sinc4 a(0x10,si) +#undef sinc2 +#define sinc2 a(0x8,si) +#undef sinc1 +#define sinc1 a(0x4,si) + +#undef SCALE +#define SCALE Mjoin(Mjoin(PREC,Mjoin(scale,BLC)),FEXT) + +#undef MY_FUNCTION +#define MY_FUNCTION SCALE + +static void +MY_FUNCTION(const TYPE *b,TYPE *c,int len) { + + const TYPE *ce=c+len; +#if defined(BETAX) && defined(SCPLX) + const TYPE z1[2]={{1.0,-1.0},{1.0,-1.0}},*z=z1; +#endif + NO_INLINE + +#ifndef SREAL + len+=len; +#endif +#ifdef DCPLX + len+=len; +#endif + + + ASM( + + "pushl %%ebx\n\t" + a(4,sp) + + + "movl %0,%%esi\n\t" + + spf(0x00,si) + spf(0x20,si) + + "movl %1,%%eax\n\t" + "movl %2,%%edi\n\t" + +#if defined(BETAX) && defined(SCPLX) + "movl %3,%%ebx\n\t" + pl(0,bx,SSREG) +#endif + + lb + + lab(loop) + + test(-16,di) + je(8) + sub(16,di) + align + + sbl16 + sinc16 + + jmp(loop) + align + + lab(8) + + test(8,di) + je(4) + + sbl8 + sinc8 + + lab(4) + + test(4,di) + je(2) + + sbl4 + sinc4 + + lab(2) + +#ifndef DCPLX + test(2,di) + je(1) + + sbl2 + sinc2 + + lab(1) + +#ifdef SREAL + test(1,di) + je(stop) + + sbl1 + sinc1 + + lab(stop) +#endif +#endif + + ub + + a(-4,sp) + "popl %%ebx\n\t" + + + ::"m" (c),"m" (b), "m" (len) +#if defined(BETAX) && defined(SCPLX) + ,"m" (z) +#endif + : "si","ax","di"); + + +} +#endif /* CAMM_SCALE_H */ diff --git a/kaldi_io/src/tools/ATLAS/include/contrib/camm_strat1.h b/kaldi_io/src/tools/ATLAS/include/contrib/camm_strat1.h new file mode 100644 index 0000000..4a92006 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/contrib/camm_strat1.h @@ -0,0 +1,2982 @@ +#include "camm_util.h" + +#undef p1_4_swap_1 +#define p1_4_swap_1(a_) \ + pls(a_,ax,1) \ + pls(a_,cx,0) \ + pus(0,a_,ax) \ + pus(1,a_,cx) +#undef p1_2_swap_1 +#define p1_2_swap_1(a_) \ + px(1) \ + pld(a_,ax,1) \ + px(0) \ + pld(a_,cx,0) \ + pud(0,a_,ax) \ + pud(1,a_,cx) +#undef p1_swap_1 +#define p1_swap_1(a_) \ + plq(a_,ax,1) \ + pl(a_,cx,0) \ + puq(0,a_,ax) \ + pu(1,a_,cx) +#undef p2_swap_1 +#define p2_swap_1(a_) \ + plq(SS(a_,RS4),ax,3) \ + pl(SS(a_,RS4),cx,2) \ + puq(0,a_,ax) \ + pu(1,a_,cx) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(2,RS4)),cx,0) \ + puq(2,SS(a_,RS4),ax) \ + pu(3,SS(a_,RS4),cx) +#undef lpswap_1 +#define lpswap_1(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,1) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,0) +#undef dpswap_1 +#define dpswap_1(a_) \ + plq(SS(a_,RS4),ax,3) \ + pl(SS(a_,RS4),cx,2) \ + puq(0,a_,ax) \ + pu(1,a_,cx) \ + puq(2,SS(a_,RS4),ax) \ + pu(3,SS(a_,RS4),cx) +#undef plswap_1 +#define plswap_1 8 + + +#undef p1_4_scal_3 +#define p1_4_scal_3(a_) \ + pls(a_,ax,0) \ + pmsr(6,0) \ + pus(0,a_,ax) +#undef p1_2_scal_3 +#define p1_2_scal_3(a_) \ + pld(a_,ax,0) \ + pm(6,0) \ + pud(0,a_,ax) +#undef p1_scal_3 +#define p1_scal_3(a_) \ + plq(a_,ax,0) \ + pm(6,0) \ + puq(0,a_,ax) +#undef p2_scal_3 +#define p2_scal_3(a_) \ + plq(a_,ax,0) \ + plq(SS(a_,RS4),ax,1) \ + pm(6,0) \ + pm(6,1) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) +#undef p4_scal_3 +#define p4_scal_3(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pm(6,2) \ + puq(0,a_,ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,0) \ + pm(6,3) \ + puq(1,SS(a_,RS4),ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(5,RS4)),ax,1) \ + pm(6,0) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + plq(SS(a_,MM(6,RS4)),ax,2) \ + pm(6,1) \ + puq(3,SS(a_,MM(3,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) +#undef lpscal_3 +#define lpscal_3(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pm(6,0) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pm(6,1) +#undef dpscal_3 +#define dpscal_3(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pm(6,2) \ + puq(0,a_,ax) \ + pm(6,3) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef plscal_3 +#define plscal_3 16 + +#undef p1_4_scal_3c +#define p1_4_scal_3c(a_) +#undef p1_2_scal_3c +#define p1_2_scal_3c(a_) \ + pld(a_,ax,0) \ + pc(0,1) \ + pm(6,0) \ + ps(CSHUF,1,1) \ + pm(7,1) \ + pa(1,0) \ + pud(0,a_,ax) +#undef p1_scal_3c +#define p1_scal_3c(a_) \ + plq(a_,ax,0) \ + pc(0,1) \ + pm(6,0) \ + ps(CSHUF,1,1) \ + pm(7,1) \ + pa(1,0) \ + puq(0,a_,ax) +#undef p2_scal_3c +#define p2_scal_3c(a_) \ + plq(a_,ax,0) \ + plq(SS(a_,RS4),ax,1) \ + pc(0,2) \ + pm(6,0) \ + ps(CSHUF,2,2) \ + pm(7,2) \ + pa(2,0) \ + puq(0,a_,ax) \ + pc(1,3) \ + pm(6,1) \ + ps(CSHUF,3,3) \ + pm(7,3) \ + pa(3,1) \ + puq(1,SS(a_,RS4),ax) +#undef p4_scal_3c +#define p4_scal_3c(a_) \ + pm(7,5) \ + pa(5,1) \ + puq(0,a_,ax) \ + ps(CSHUF,4,4) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,0) \ + pc(3,5) \ + pm(6,3) \ + pm(7,4) \ + pa(4,2) \ + puq(1,SS(a_,RS4),ax) \ + ps(CSHUF,5,5) \ + plq(SS(a_,MM(5,RS4)),ax,1) \ + pc(0,4) \ + pm(6,0) \ + pm(7,5) \ + pa(5,3) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + ps(CSHUF,4,4) \ + plq(SS(a_,MM(6,RS4)),ax,2) \ + pc(1,5) \ + pm(6,1) \ + pm(7,4) \ + pa(4,0) \ + puq(3,SS(a_,MM(3,RS4)),ax) \ + ps(CSHUF,5,5) \ + plq(SS(a_,MM(7,RS4)),ax,3) \ + pc(2,4) \ + pm(6,2) +#undef lpscal_3c +#define lpscal_3c(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pc(0,4) \ + pm(6,0) \ + ps(CSHUF,4,4) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pc(1,5) \ + pm(6,1) \ + pm(7,4) \ + pa(4,0) \ + ps(CSHUF,5,5) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pc(2,4) \ + pm(6,2) +#undef dpscal_3c +#define dpscal_3c(a_) \ + pm(7,5) \ + pa(5,1) \ + ps(CSHUF,4,4) \ + puq(0,a_,ax) \ + pm(7,4) \ + pa(4,2) \ + pc(3,5) \ + pm(6,3) \ + puq(1,SS(a_,RS4),ax) \ + ps(CSHUF,5,5) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + pm(7,5) \ + pa(5,3) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef plscal_3c +#define plscal_3c 16 + +#undef p1_4_scal_4 +#define p1_4_scal_4(a_) \ + pls(SS(a_,MM(0,RS4)),ax,0) \ + pmsr(6,0) \ + pus(0,a_,ax) +#undef p1_2_scal_4 +#define p1_2_scal_4(a_) \ + pld(SS(a_,MM(0,RS4)),ax,0) \ + pm(6,0) \ + pud(0,a_,ax) +#undef p1_scal_4 +#define p1_scal_4(a_) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + pm(6,0) \ + puq(0,a_,ax) +#undef p2_scal_4 +#define p2_scal_4(a_) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pm(6,0) \ + pm(6,1) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) +#undef p4_scal_4 +#define p4_scal_4(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pm(6,0) \ + pm(6,1) \ + pm(6,2) \ + pm(6,3) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef p8_scal_4 +#define p8_scal_4(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + plq(SS(a_,MM(4,RS4)),ax,4) \ + plq(SS(a_,MM(5,RS4)),ax,5) \ + plq(SS(a_,MM(6,RS4)),ax,7) \ + pm(6,0) \ + pm(6,1) \ + pm(6,2) \ + puq(0,a_,ax) \ + pm(6,3) \ + pm(6,4) \ + pm(6,5) \ + plq(SS(a_,MM(7,RS4)),ax,0) \ + pm(6,7) \ + pm(6,0) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) \ + puq(4,SS(a_,MM(4,RS4)),ax) \ + puq(5,SS(a_,MM(5,RS4)),ax) \ + puq(7,SS(a_,MM(6,RS4)),ax) \ + puq(0,SS(a_,MM(7,RS4)),ax) +#undef lpscal_4 +#define lpscal_4(a_) +#undef dpscal_4 +#define dpscal_4(a_) p4_scal_4(a_) +#undef plscal_4 +#define plscal_4 16 + +#undef p1_4_scal_4c +#define p1_4_scal_4c(a_) +#undef p1_2_scal_4c +#define p1_2_scal_4c(a_) \ + pld(a_,ax,0) \ + pc(0,1) \ + pm(6,0) \ + ps(CSHUF,1,1) \ + pm(7,1) \ + pa(1,0) \ + pud(0,a_,ax) +#undef p1_scal_4c +#define p1_scal_4c(a_) \ + plq(a_,ax,0) \ + pc(0,1) \ + pm(6,0) \ + ps(CSHUF,1,1) \ + pm(7,1) \ + pa(1,0) \ + puq(0,a_,ax) +#undef p2_scal_4c +#define p2_scal_4c(a_) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pc(0,4) \ + pc(1,5) \ + pm(6,0) \ + pm(6,1) \ + ps(CSHUF,4,4) \ + ps(CSHUF,5,5) \ + pm(7,4) \ + pa(4,0) \ + pm(7,5) \ + pa(5,1) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) +#undef p4_scal_4c +#define p4_scal_4c(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pc(0,4) \ + pc(1,5) \ + pm(6,0) \ + pm(6,1) \ + ps(CSHUF,4,4) \ + ps(CSHUF,5,5) \ + pm(7,4) \ + pa(4,0) \ + pc(2,4) \ + pm(7,5) \ + pa(5,1) \ + pc(3,5) \ + pm(6,2) \ + pm(6,3) \ + ps(CSHUF,4,4) \ + ps(CSHUF,5,5) \ + pm(7,4) \ + pa(4,2) \ + pm(7,5) \ + pa(5,3) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef lpscal_4c +#define lpscal_4c(a_) +#undef dpscal_4c +#define dpscal_4c(a_) p4_scal_4c(a_) +#undef plscal_4c +#define plscal_4c 16 + +#undef p1_4_scal_1 +#define p1_4_scal_1(a_) \ + pls(a_,ax,1) \ + pmsr(0,1) \ + pus(1,a_,ax) +#undef p1_2_scal_1 +#define p1_2_scal_1(a_) \ + px(1) \ + pld(a_,ax,1) \ + pm(0,1) \ + pud(1,a_,ax) +#undef p1_scal_1 +#define p1_scal_1(a_) \ + plq(a_,ax,1) \ + pm(0,1) \ + puq(1,a_,ax) +#undef p2_scal_1 +#define p2_scal_1(a_) \ + plq(a_,ax,1) \ + plq(SS(a_,RS4),ax,2) \ + pm(0,1) \ + pm(0,2) \ + puq(1,a_,ax) \ + puq(2,SS(a_,RS4),ax) +#undef p4_scal_1 +#define p4_scal_1(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pm(0,3) \ + puq(7,a_,ax) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pm(0,1) \ + puq(3,SS(a_,MM(1,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,7) \ + pm(0,2) \ + puq(1,SS(a_,MM(2,RS4)),ax) \ + plq(SS(a_,MM(5,RS4)),ax,3) \ + pm(0,7) \ + puq(2,SS(a_,MM(3,RS4)),ax) +#undef lpscal_1 +#define lpscal_1(a_) \ + plq(a_,ax,7) \ + plq(SS(a_,MM(1,RS4)),ax,3) \ + pm(0,7) +#undef dpscal_1 +#define dpscal_1(a_) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pm(0,3) \ + puq(7,a_,ax) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pm(0,1) \ + puq(3,SS(a_,MM(1,RS4)),ax) \ + pm(0,2) \ + puq(1,SS(a_,MM(2,RS4)),ax) \ + puq(2,SS(a_,MM(3,RS4)),ax) +#undef plscal_1 +#define plscal_1 RS4 + + +#undef p1_4_set_1 +#define p1_4_set_1(a_) \ + pls(a_,ax,1) \ + pcs(0,1) \ + pus(1,a_,ax) +#undef p1_2_set_1 +#define p1_2_set_1(a_) \ + px(1) \ + pld(a_,ax,1) \ + pc(0,1) \ + pud(1,a_,ax) +#undef p1_set_1 +#define p1_set_1(a_) \ + plq(a_,ax,1) \ + pc(0,1) \ + puq(1,a_,ax) +#undef p2_set_1 +#define p2_set_1(a_) \ + plq(a_,ax,1) \ + plq(SS(a_,RS4),ax,2) \ + pc(0,1) \ + pc(0,2) \ + puq(1,a_,ax) \ + puq(2,SS(a_,RS4),ax) +#undef p4_set_1 +#define p4_set_1(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pc(0,3) \ + puq(7,a_,ax) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pc(0,1) \ + puq(3,SS(a_,MM(1,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,7) \ + pc(0,2) \ + puq(1,SS(a_,MM(2,RS4)),ax) \ + plq(SS(a_,MM(5,RS4)),ax,3) \ + pc(0,7) \ + puq(2,SS(a_,MM(3,RS4)),ax) +#undef lpset_1 +#define lpset_1(a_) \ + plq(a_,ax,7) \ + plq(SS(a_,MM(1,RS4)),ax,3) \ + pc(0,7) +#undef dpset_1 +#define dpset_1(a_) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pc(0,3) \ + puq(7,a_,ax) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pc(0,1) \ + puq(3,SS(a_,MM(1,RS4)),ax) \ + pc(0,2) \ + puq(1,SS(a_,MM(2,RS4)),ax) \ + puq(2,SS(a_,MM(3,RS4)),ax) +#undef plset_1 +#define plset_1 RS4 + + +#undef p1_4_set_2 +#define p1_4_set_2(a_) \ + pus(0,a_,ax) +#undef p1_2_set_2 +#define p1_2_set_2(a_) \ + pud(0,a_,ax) +#undef p1_set_2 +#define p1_set_2(a_) \ + puq(0,a_,ax) +#undef p2_set_2 +#define p2_set_2(a_) \ + puq(0,a_,ax) \ + puq(0,SS(a_,RS4),ax) +#undef p4_set_2 +#define p4_set_2(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + puq(0,a_,ax) \ + puq(0,SS(a_,MM(1,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + puq(0,SS(a_,MM(2,RS4)),ax) \ + puq(0,SS(a_,MM(3,RS4)),ax) +#undef lpset_2 +#define lpset_2(a_) +#undef dpset_2 +#define dpset_2(a_) \ + puq(0,a_,ax) \ + puq(0,SS(a_,MM(1,RS4)),ax) \ + puq(0,SS(a_,MM(2,RS4)),ax) \ + puq(0,SS(a_,MM(3,RS4)),ax) +#undef plset_2 +#define plset_2 RS4 + + +#undef p1_4_set_3 +#define p1_4_set_3(a_) \ + pus(0,a_,ax) +#undef p1_2_set_3 +#define p1_2_set_3(a_) \ + pud(0,a_,ax) +#undef p1_set_3 +#define p1_set_3(a_) \ + puq(0,SS(a_,MM(0,RS4)),ax) +#undef p2_set_3 +#define p2_set_3(a_) \ + puq(0,SS(a_,MM(0,RS4)),ax) \ + puq(0,SS(a_,MM(1,RS4)),ax) +#undef p4_set_3 +#define p4_set_3(a_) \ + puq(0,SS(a_,MM(0,RS4)),ax) \ + puq(0,SS(a_,MM(1,RS4)),ax) \ + puq(0,SS(a_,MM(2,RS4)),ax) \ + puq(0,SS(a_,MM(3,RS4)),ax) +#undef p8_set_3 +#define p8_set_3(a_) \ + puq(0,SS(a_,MM(0,RS4)),ax) \ + puq(0,SS(a_,MM(1,RS4)),ax) \ + puq(0,SS(a_,MM(2,RS4)),ax) \ + puq(0,SS(a_,MM(3,RS4)),ax) \ + puq(0,SS(a_,MM(4,RS4)),ax) \ + puq(0,SS(a_,MM(5,RS4)),ax) \ + puq(0,SS(a_,MM(6,RS4)),ax) \ + puq(0,SS(a_,MM(7,RS4)),ax) +#undef lpset_3 +#define lpset_3(a_) +#undef dpset_3 +#define dpset_3(a_) p8_set_3(a_) +#undef plset_3 +#define plset_3 32 + + +#undef p1_4_0x1_nrm2_1 +#define p1_4_0x1_nrm2_1(a_) \ + pls(a_,ax,1) \ + pmsr(1,1) \ + pasr(1,0) +#undef p1_2_0x1_nrm2_1 +#define p1_2_0x1_nrm2_1(a_) \ + px(1) \ + pld(a_,ax,1) \ + pm(1,1) \ + pa(1,0) +#undef p1_0x1_nrm2_1 +#define p1_0x1_nrm2_1(a_) \ + plq(a_,ax,1) \ + pm(1,1) \ + pa(1,0) +#undef p2_0x1_nrm2_1 +#define p2_0x1_nrm2_1(a_) \ + plq(a_,ax,1) \ + plq(SS(a_,RS4),ax,2) \ + pm(1,1) \ + pm(2,2) \ + pa(1,0) \ + pm(2,0) +#undef p4_0x1_nrm2_1 +#define p4_0x1_nrm2_1(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pm(3,3) \ + pa(7,0) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pm(1,1) \ + pa(3,0) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,7) \ + pm(2,2) \ + pa(1,0) \ + plq(SS(a_,MM(5,RS4)),ax,3) \ + pm(7,7) \ + pa(2,0) +#undef lp0x1_nrm2_1 +#define lp0x1_nrm2_1(a_) \ + plq(a_,ax,7) \ + plq(SS(a_,MM(1,RS4)),ax,3) \ + pm(7,7) +#undef dp0x1_nrm2_1 +#define dp0x1_nrm2_1(a_) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pm(3,3) \ + pa(7,0) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pm(1,1) \ + pa(3,0) \ + pm(2,2) \ + pa(1,0) \ + pa(2,0) +#undef pl0x1_nrm2_1 +#define pl0x1_nrm2_1 RS4 + + +#undef p1_4_nrm2_2 +#define p1_4_nrm2_2(a_) \ + pls(a_,ax,1) dbg(1) \ + pan(4,1) dbg(1) \ + pcs(5,6) dbg(6) \ + pcs(5,7) dbg(7) \ + paxs(1,5) dbg(5) \ + prps(5,2) dbg(2) \ + px(3) \ + pcms(0,2,3) dbg(3) \ + pan(3,7) dbg(7) \ + pann(5,3) dbg(3) \ + pasr(3,7) dbg(7) \ + pcs(7,5) dbg(5) \ + pdsr(5,6) dbg(6) \ + pdsr(5,1) dbg(1) \ + pmsr(6,6) dbg(6) \ + pmsr(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pasr(1,0) dbg(0) +#undef p1_2_nrm2_2 +#define p1_2_nrm2_2(a_) \ + px(1) pld(a_,ax,1) dbg(1) \ + pan(4,1) dbg(1) \ + pc(5,6) dbg(6) \ + pc(5,7) dbg(7) \ + pax(1,5) dbg(5) \ + prp(5,2) dbg(2) \ + px(3) \ + pcm(0,2,3)dbg(3) \ + pan(3,7) dbg(7) \ + pann(5,3) dbg(3) \ + pa(3,7) dbg(7) \ + pc(7,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) +#undef p1_nrm2_2 +#define p1_nrm2_2(a_) \ + plq(a_,ax,1) dbg(1) \ + pan(4,1) dbg(1) \ + pc(5,6) dbg(6) \ + pc(5,7) dbg(7) \ + pax(1,5) dbg(5) \ + prp(5,2) dbg(2) \ + px(3) \ + pcm(0,2,3)dbg(3) \ + pan(3,7) dbg(7) \ + pann(5,3) dbg(3) \ + pa(3,7) dbg(7) \ + pc(7,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) +#define p2_nrm2_2(a_) \ + plq(SS(a_,RS4),ax,1) dbg(1) \ + pan(4,1) dbg(1) \ + pc(5,6) dbg(6) \ + pc(5,7) dbg(7) \ + pax(1,5) dbg(5) \ + prp(5,2) dbg(2) \ + px(3) \ + pcm(0,2,3)dbg(3) \ + pan(3,7) dbg(7) \ + pann(5,3) dbg(3) \ + pa(3,7) dbg(7) \ + pc(7,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) dbg(1) \ + pan(4,1) dbg(1) \ + pc(5,6) dbg(6) \ + pc(5,7) dbg(7) \ + pax(1,5) dbg(5) \ + prp(5,2) dbg(2) \ + px(3) \ + pcm(0,2,3)dbg(3) \ + pan(3,7) dbg(7) \ + pann(5,3) dbg(3) \ + pa(3,7) dbg(7) \ + pc(7,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) +#undef lpnrm2_2 +#define lpnrm2_2(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,1) dbg(1) \ + pan(4,1) dbg(1) \ + pc(5,6) dbg(6) \ + pc(5,7) dbg(7) \ + pax(1,5) dbg(5) \ + prp(5,2) dbg(2) \ + px(3) \ + pcm(0,2,3)dbg(3) \ + pan(3,7) dbg(7) \ + pann(5,3) dbg(3) \ + pa(3,7) dbg(7) \ + pc(7,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) +#undef dpnrm2_2 +#define dpnrm2_2(a_) \ + plq(SS(a_,RS4),ax,1) dbg(1) \ + pan(4,1) dbg(1) \ + pc(5,6) dbg(6) \ + pc(5,7) dbg(7) \ + pax(1,5) dbg(5) \ + prp(5,2) dbg(2) \ + px(3) \ + pcm(0,2,3)dbg(3) \ + pan(3,7) dbg(7) \ + pann(5,3) dbg(3) \ + pa(3,7) dbg(7) \ + pc(7,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) +#undef plnrm2_2 +#define plnrm2_2 8 + + +#undef p1_4_nrm2_3 +#define p1_4_nrm2_3(a_) \ + pls(a_,ax,1) dbg(1) \ + pcs(5,6) dbg(6) \ + pan(4,1) dbg(1) \ + paxs(1,5) dbg(5) \ + pdsr(5,6) dbg(6) \ + pdsr(5,1) dbg(1) \ + pmsr(6,6) dbg(6) \ + pmsr(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pasr(1,0) dbg(0) +#undef p1_2_nrm2_3 +#define p1_2_nrm2_3(a_) \ + px(1) pld(a_,ax,1) dbg(1) \ + pc(5,6) dbg(6) \ + pan(4,1) dbg(1) \ + pax(1,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) +#undef p1_nrm2_3 +#define p1_nrm2_3(a_) \ + plq(a_,ax,1) dbg(1) \ + pc(5,6) dbg(6) \ + pan(4,1) dbg(1) \ + pax(1,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) +#define p2_nrm2_3(a_) \ + plq(SS(a_,RS4),ax,1) dbg(1) \ + pc(5,6) dbg(6) \ + pan(4,1) dbg(1) \ + pax(1,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) dbg(1) \ + pc(5,6) dbg(6) \ + pan(4,1) dbg(1) \ + pax(1,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) +#undef lpnrm2_3 +#define lpnrm2_3(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,1) dbg(1) \ + pc(5,6) dbg(6) \ + pan(4,1) dbg(1) \ + pax(1,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) +#undef dpnrm2_3 +#define dpnrm2_3(a_) \ + plq(SS(a_,RS4),ax,1) dbg(1) \ + pc(5,6) dbg(6) \ + pan(4,1) dbg(1) \ + pax(1,5) dbg(5) \ + pd(5,6) dbg(6) \ + pd(5,1) dbg(1) \ + pm(6,6) dbg(6) \ + pm(1,1) dbg(1) \ + pm(6,0) dbg(0) \ + pa(1,0) dbg(0) +#undef plnrm2_3 +#define plnrm2_3 8 + +#define block_nrm2_4(a_,b_) \ + Mjoin(pc,a_)(5,6) dbg(6) \ + pan(4,1) dbg(1) \ + Mjoin(pax,a_)(1,5) dbg(5) \ + Mjoin(pc,a_)(2,7) dbg(7) \ + Mjoin(pd,b_)(5,7) dbg(7) \ + Mjoin(pm,b_)(7,6) dbg(6) \ + Mjoin(pm,b_)(7,1) dbg(1) \ + Mjoin(pm,b_)(6,6) dbg(6) \ + Mjoin(pm,b_)(6,0) dbg(0) \ + Mjoin(pm,b_)(1,1) dbg(1) \ + Mjoin(pa,b_)(1,0) dbg(0) + + +/* #undef p1_4_nrm2_4 */ +/* #define p1_4_nrm2_4(a_) \ */ +/* pls(a_,ax,1) dbg(1) \ */ +/* pcs(5,6) dbg(6) \ */ +/* pan(4,1) dbg(1) \ */ +/* paxs(1,5) dbg(5) \ */ +/* pcs(2,7) dbg(7) \ */ +/* pdsr(5,7) dbg(7) \ */ +/* pmsr(7,6) dbg(6) \ */ +/* pmsr(7,1) dbg(1) \ */ +/* pmsr(6,6) dbg(6) \ */ +/* pmsr(6,0) dbg(0) \ */ +/* pmsr(1,1) dbg(1) \ */ +/* pasr(1,0) dbg(0) */ +#undef p1_4_nrm2_4 +#define p1_4_nrm2_4(a_) \ + pls(a_,ax,1) dbg(1) \ + block_nrm2_4(s,sr) +#undef p1_2_nrm2_4 +#define p1_2_nrm2_4(a_) \ + px(1) pld(a_,ax,1) dbg(1) \ + block_nrm2_4(,) +#undef p1_nrm2_4 +#define p1_nrm2_4(a_) \ + plq(a_,ax,1) dbg(1) \ + block_nrm2_4(,) +#define p2_nrm2_4(a_) \ + plq(SS(a_,RS4),ax,1) dbg(1) \ + block_nrm2_4(,) \ + plq(SS(a_,MM(2,RS4)),ax,1) dbg(1) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + block_nrm2_4(,) +#undef lpnrm2_4 +#define lpnrm2_4(a_) \ + plq(SS(a_,MM(0,RS4)),ax,1) dbg(1) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + block_nrm2_4(,) +#undef dpnrm2_4 +#define dpnrm2_4(a_) \ + plq(SS(a_,RS4),ax,1) dbg(1) \ + block_nrm2_4(,) +#undef plnrm2_4 +#define plnrm2_4 8 + + +#undef p1_4_1x1_1 +#define p1_4_1x1_1(a_) \ + pls(a_,ax,1) \ + pls(a_,bx,0) \ + pm(0,1) \ + pa(1,6) +#undef p1_2_1x1_1 +#define p1_2_1x1_1(a_) \ + pld(a_,ax,1) \ + pld(a_,bx,0) \ + pm(0,1) \ + pa(1,6) +#undef p1_1x1_1 +#define p1_1x1_1(a_) \ + plq(a_,ax,1) \ + plq(a_,bx,0) \ + pm(0,1) \ + pa(0,6) +#undef p2_1x1_1 +#define p2_1x1_1(a_) \ + plq(a_,ax,1) \ + plq(a_,bx,0) \ + plq(SS(a_,RS4),ax,2) \ + plq(SS(a_,RS4),bx,3) \ + pm(0,1) \ + pm(2,3) \ + pa(1,6) \ + pa(3,6) +#undef p4_1x1_1 +#define p4_1x1_1(a_) \ + f(nta,SS(a_,MM(4,RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pm(0,3) \ + puq(7,a_,ax) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pm(0,1) \ + puq(3,SS(a_,RS4),ax) \ + f(nta,SS(a_,MM(6,RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,7) \ + pm(0,2) \ + puq(1,SS(a_,MM(2,RS4)),ax) \ + plq(SS(a_,MM(5,RS4)),ax,3) \ + pm(0,7) \ + puq(2,SS(a_,MM(3,RS4)),ax) +#undef lp1x1_1 +#define lp1x1_1(a_) \ + plq(a_,ax,7) \ + plq(SS(a_,RS4),ax,3) \ + pm(0,7) +#undef dp1x1_1 +#define dp1x1_1(a_) \ + plq(SS(,a_,MM(2,RS4)),ax,1) \ + pm(0,3) \ + puq(7,a_,ax) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pm(0,1) \ + puq(3,SS(a_,RS4),ax) \ + pm(0,2) \ + puq(1,SS(a_,MM(2,RS4)),ax) \ + puq(2,SS(a_,MM(3,RS4)),ax) +#undef pl1x1_1 +#define pl1x1_1 RS4 + + +#undef p1_4_0x1_asum_1 +#define p1_4_0x1_asum_1(a_) \ + pls(a_,ax,1) \ + pan(4,1) \ + pasr(1,0) +#undef p1_2_0x1_asum_1 +#define p1_2_0x1_asum_1(a_) \ + px(1) \ + pld(a_,ax,1) \ + pan(4,1) \ + pa(1,0) +#undef p1_0x1_asum_1 +#define p1_0x1_asum_1(a_) \ + plq(a_,ax,1) \ + pan(4,1) \ + pa(1,0) +#undef p2_0x1_asum_1 +#define p2_0x1_asum_1(a_) \ + plq(a_,ax,1) \ + plq(SS(a_,RS4),ax,2) \ + pan(4,1) \ + pan(4,2) \ + pa(1,0) \ + pa(2,0) +#undef p4_0x1_asum_1 +#define p4_0x1_asum_1(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pan(4,3) \ + pa(7,0) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pan(4,1) \ + pa(3,0) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,7) \ + pan(4,2) \ + pa(1,0) \ + plq(SS(a_,MM(5,RS4)),ax,3) \ + pan(4,7) \ + pa(2,0) +#undef lp0x1_asum_1 +#define lp0x1_asum_1(a_) \ + plq(a_,ax,7) \ + plq(SS(a_,MM(1,RS4)),ax,3) \ + pan(4,7) +#undef dp0x1_asum_1 +#define dp0x1_asum_1(a_) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pan(4,3) \ + pa(7,0) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pan(4,1) \ + pa(3,0) \ + pan(4,2) \ + pa(1,0) \ + pa(2,0) +#undef pl0x1_asum_1 +#define pl0x1_asum_1 RS4 + + +#undef p1_4_sum_1 +#define p1_4_sum_1(a_) \ + pls(a_,ax,1) \ + pasr(1,0) +#undef p1_2_sum_1 +#define p1_2_sum_1(a_) \ + px(1) \ + pld(a_,ax,1) \ + pa(1,0) +#undef p1_sum_1 +#define p1_sum_1(a_) \ + plq(a_,ax,1) \ + pa(1,0) +#undef p2_sum_1 +#define p2_sum_1(a_) \ + plq(a_,ax,1) \ + plq(SS(a_,RS4),ax,2) \ + pa(1,0) \ + pa(2,0) +#undef p4_sum_1 +#define p4_sum_1(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pa(7,0) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pa(3,0) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,7) \ + pa(1,0) \ + plq(SS(a_,MM(5,RS4)),ax,3) \ + pa(2,0) +#undef lpsum_1 +#define lpsum_1(a_) \ + plq(a_,ax,7) \ + plq(SS(a_,MM(1,RS4)),ax,3) +#undef dpsum_1 +#define dpsum_1(a_) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + pa(7,0) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pa(3,0) \ + pa(1,0) \ + pa(2,0) +#undef plsum_1 +#define plsum_1 RS4 + + +#undef p1_4_dot_1 +#define p1_4_dot_1(a_) \ + pls(a_,ax,1) \ + pls(a_,cx,2) \ + pmsr(2,1) \ + pasr(1,0) +#undef p1_2_dot_1 +#define p1_2_dot_1(a_) \ + px(1) \ + pld(a_,ax,1) \ + px(2) \ + pld(a_,cx,2) \ + pm(2,1) \ + pa(1,0) +#undef p1_dot_1 +#define p1_dot_1(a_) \ + plq(a_,ax,1) \ + pl(a_,cx,2) \ + pm(2,1) \ + pa(1,0) +#undef p2_dot_1 +#define p2_dot_1(a_) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pl(SS(a_,MM(1,RS4)),cx,2) \ + pm(4,3) \ + pa(3,0) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,3) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(2,RS4)),cx,4) \ + pm(2,1) \ + pa(1,0) +#undef lpdot_1 +#define lpdot_1(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(a_,ax,3) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(a_,cx,4) +#undef dpdot_1 +#define dpdot_1(a_) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pl(SS(a_,MM(1,RS4)),cx,2) \ + pm(4,3) \ + pa(3,0) \ + pm(2,1) \ + pa(1,0) +#undef pldot_1 +#define pldot_1 8 + +#undef p1_4_dot_1c +#define p1_4_dot_1c(a_) +#undef p1_2_dot_1c +#define p1_2_dot_1c(a_) \ + px(1) \ + pld(a_,ax,1) \ + px(2) \ + pld(a_,cx,2) \ + pc(1,3) \ + ps(HSHUF,1,1) \ + ps(LSHUF,3,3) \ + pm(7,1) \ + pm(2,3) \ + pa(3,0) \ + pm(2,1) \ + pa(1,6) +#undef p1_dot_1c +#define p1_dot_1c(a_) \ + plq(a_,ax,1) \ + pl(a_,cx,2) \ + pc(1,3) \ + ps(HSHUF,1,1) \ + ps(LSHUF,3,3) \ + pm(7,1) \ + pm(2,3) \ + pa(3,0) \ + pm(2,1) \ + pa(1,6) +#undef p2_dot_1c +#define p2_dot_1c(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pl(SS(a_,MM(1,RS4)),cx,2) \ + pc(3,5) \ + ps(HSHUF,3,3) \ + ps(LSHUF,5,5) \ + pm(7,3) \ + pm(4,5) \ + pa(5,0) \ + pm(4,3) \ + pa(3,6) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(2,RS4)),cx,4) \ + plq(SS(a_,MM(2,RS4)),ax,3) \ + pc(1,5) \ + ps(HSHUF,1,1) \ + ps(LSHUF,5,5) \ + pm(7,1) \ + pm(2,5) \ + pa(5,0) \ + pm(2,1) \ + pa(1,6) +#undef lpdot_1c +#define lpdot_1c(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(a_,ax,3) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(a_,cx,4) +#undef dpdot_1c +#define dpdot_1c(a_) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pl(SS(a_,MM(1,RS4)),cx,2) \ + pc(3,5) \ + ps(HSHUF,3,3) \ + ps(LSHUF,5,5) \ + pm(7,3) \ + pm(4,5) \ + pa(5,0) \ + pm(4,3) \ + pa(3,6) \ + pc(1,5) \ + ps(HSHUF,1,1) \ + ps(LSHUF,5,5) \ + pm(7,1) \ + pm(2,5) \ + pa(5,0) \ + pm(2,1) \ + pa(1,6) +#undef pldot_1c +#define pldot_1c 8 + +#undef p1_4_dot_2c +#define p1_4_dot_2c(a_) +#undef p1_2_dot_2c +#define p1_2_dot_2c(a_) \ + px(1) \ + pld(a_,ax,1) \ + px(2) \ + pld(a_,cx,2) \ + pc(1,3) \ + ps(CSHUF,1,1) \ + pm(2,3) \ + pa(3,0) \ + pm(2,1) \ + pa(1,6) +#undef p1_dot_2c +#define p1_dot_2c(a_) \ + plq(a_,ax,1) \ + pl(a_,cx,2) \ + pc(1,3) \ + ps(CSHUF,1,1) \ + pm(2,3) \ + pa(3,0) \ + pm(2,1) \ + pa(1,6) +#undef p2_dot_2c +#define p2_dot_2c(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pl(SS(a_,MM(1,RS4)),cx,2) \ + pc(3,5) \ + ps(CSHUF,3,3) \ + pm(4,5) \ + pa(5,0) \ + pm(4,3) \ + pa(3,6) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(2,RS4)),cx,4) \ + plq(SS(a_,MM(2,RS4)),ax,3) \ + pc(1,5) \ + ps(CSHUF,1,1) \ + pm(2,5) \ + pa(5,0) \ + pm(2,1) \ + pa(1,6) +#undef lpdot_2c +#define lpdot_2c(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(a_,ax,3) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(a_,cx,4) +#undef dpdot_2c +#define dpdot_2c(a_) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pl(SS(a_,MM(1,RS4)),cx,2) \ + pc(3,5) \ + ps(CSHUF,3,3) \ + pm(4,5) \ + pa(5,0) \ + pm(4,3) \ + pa(3,6) \ + pc(1,5) \ + ps(CSHUF,1,1) \ + pm(2,5) \ + pa(5,0) \ + pm(2,1) \ + pa(1,6) +#undef pldot_2c +#define pldot_2c 8 + +#undef p1_4_axpby_3 +#define p1_4_axpby_3(a_) \ + pls(a_,ax,0) \ + pls(a_,cx,3) \ + pmsr(5,0) \ + pmsr(6,3) \ + pasr(3,0) \ + pus(0,a_,ax) +#undef p1_2_axpby_3 +#define p1_2_axpby_3(a_) \ + pld(a_,ax,0) \ + pld(a_,cx,3) \ + pm(5,0) \ + pm(6,3) \ + pa(3,0) \ + pud(0,a_,ax) +#undef p1_axpby_3 +#define p1_axpby_3(a_) \ + plq(a_,ax,0) \ + pl(a_,cx,3) \ + pm(5,0) \ + pm(6,3) \ + pa(3,0) \ + punt(0,a_,ax) +#undef p2_axpby_3 +#define p2_axpby_3(a_) \ + plq(a_,ax,0) \ + pl(a_,cx,3) \ + plq(SS(a_,RS4),ax,1) \ + pm(5,0) \ + pm(6,3) \ + pa(3,0) \ + pl(SS(a_,RS4),cx,3) \ + punt(0,a_,ax) \ + pm(5,1) \ + pm(6,3) \ + pa(3,1) \ + punt(1,SS(a_,RS4),ax) +#undef p4_axpby_3 +#define p4_axpby_3(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pm(5,2) \ + pl(SS(a_,MM(3,RS4)),cx,7) \ + pm(6,4) \ + pa(4,2) \ + punt(0,a_,ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(4,RS4)),cx,4) \ + pm(5,3) \ + plq(SS(a_,MM(4,RS4)),ax,0) \ + pm(6,7) \ + pa(7,3) \ + punt(1,SS(a_,RS4),ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(5,RS4)),ax,1) \ + pm(5,0) \ + pl(SS(a_,MM(5,RS4)),cx,7) \ + pm(6,4) \ + pa(4,0) \ + punt(2,SS(a_,MM(2,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),cx) \ + pl(SS(a_,MM(6,RS4)),cx,4) \ + pm(5,1) \ + plq(SS(a_,MM(6,RS4)),ax,2) \ + pm(6,7) \ + pa(7,1) \ + punt(3,SS(a_,MM(3,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) +#undef lpaxpby_3 +#define lpaxpby_3(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,4) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + pl(SS(a_,MM(1,RS4)),cx,7) \ + pm(5,0) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pm(6,4) \ + pa(4,0) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pm(5,1) \ + pl(SS(a_,MM(2,RS4)),cx,4) \ + pm(6,7) \ + pa(7,1) +#undef dpaxpby_3 +#define dpaxpby_3(a_) \ + pl(SS(a_,MM(3,RS4)),cx,7) \ + pm(5,2) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pm(6,4) \ + pa(4,2) \ + pm(5,3) \ + punt(0,a_,ax) \ + pm(6,7) \ + pa(7,3) \ + punt(1,SS(a_,RS4),ax) \ + punt(2,SS(a_,MM(2,RS4)),ax) \ + punt(3,SS(a_,MM(3,RS4)),ax) +#undef plaxpby_3 +#define plaxpby_3 16 + +#undef p1_4_axpby_3c +#define p1_4_axpby_3c(a_) +#undef p1_2_axpby_3c +#define p1_2_axpby_3c(a_) \ + pld(a_,ax,0) \ + pld(a_,cx,2) \ + pc(0,3) \ + pm(5,0) \ + ps(CSHUF,3,3) \ + pm(4,3) \ + pa(3,0) \ + pc(2,3) \ + pm(6,2) \ + pa(2,0) \ + ps(CSHUF,3,3) \ + pm(7,3) \ + pa(3,0) \ + pud(0,a_,ax) +#undef p1_axpby_3c +#define p1_axpby_3c(a_) \ + plq(a_,ax,0) \ + pl(a_,cx,2) \ + pc(0,3) \ + pm(5,0) \ + ps(CSHUF,3,3) \ + pm(4,3) \ + pa(3,0) \ + pc(2,3) \ + pm(6,2) \ + pa(2,0) \ + ps(CSHUF,3,3) \ + pm(7,3) \ + pa(3,0) \ + puq(0,a_,ax) +#undef p2_axpby_3c +#define p2_axpby_3c(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pl(SS(a_,MM(1,RS4)),cx,3) \ + pc(1,2) \ + pm(5,1) \ + ps(CSHUF,2,2) \ + pm(4,2) \ + pa(2,1) \ + pc(3,2) \ + pm(6,3) \ + pa(3,1) \ + ps(CSHUF,2,2) \ + pm(7,2) \ + pa(2,1) \ + puq(0,a_,ax) \ + plq(SS(a_,MM(2,RS4)),ax,0) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + pc(0,3) \ + pm(5,0) \ + ps(CSHUF,3,3) \ + pm(4,3) \ + pa(3,0) \ + pc(2,3) \ + pm(6,2) \ + pa(2,0) \ + ps(CSHUF,3,3) \ + pm(7,3) \ + pa(3,0) \ + puq(1,SS(a_,RS4),ax) +#undef lpaxpby_3c +#define lpaxpby_3c(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,2) \ + pc(0,3) \ + pm(5,0) \ + ps(CSHUF,3,3) \ + pm(4,3) \ + pa(3,0) \ + pc(2,3) \ + pm(6,2) \ + pa(2,0) \ + ps(CSHUF,3,3) \ + pm(7,3) \ + pa(3,0) +#undef dpaxpby_3c +#define dpaxpby_3c(a_) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pl(SS(a_,MM(1,RS4)),cx,3) \ + pc(1,2) \ + pm(5,1) \ + ps(CSHUF,2,2) \ + pm(4,2) \ + pa(2,1) \ + pc(3,2) \ + pm(6,3) \ + pa(3,1) \ + ps(CSHUF,2,2) \ + pm(7,2) \ + pa(2,1) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) +#undef plaxpby_3c +#define plaxpby_3c 8 + +#undef p1_4_axpby_2 +#define p1_4_axpby_2(a_) \ + pls(a_,cx,5) \ + pls(a_,ax,0) \ + pmsr(6,5) \ + pasr(5,0) \ + pus(0,a_,ax) +#undef p1_2_axpby_2 +#define p1_2_axpby_2(a_) \ + pld(a_,cx,5) \ + pld(a_,ax,0) \ + pm(6,5) \ + pa(5,0) \ + pud(0,a_,ax) +#undef p1_axpby_2 +#define p1_axpby_2(a_) \ + pl(a_,cx,5) \ + plq(a_,ax,0) \ + pm(6,5) \ + pa(5,0) \ + puq(0,a_,ax) +#undef p2_axpby_2 +#define p2_axpby_2(a_) \ + pl(a_,cx,5) \ + plq(a_,ax,0) \ + pl(SS(a_,RS4),cx,4) \ + pm(6,5) \ + pa(5,0) \ + plq(SS(a_,RS4),ax,1) \ + puq(0,a_,ax) \ + pm(6,4) \ + pa(4,1) \ + puq(1,SS(a_,RS4),ax) +#undef p4_axpby_2 +#define p4_axpby_2(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pl(SS(a_,MM(3,RS4)),cx,5) \ + pm(6,4) \ + pa(4,2) \ + puq(0,a_,ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(4,RS4)),cx,4) \ + plq(SS(a_,MM(4,RS4)),ax,0) \ + pm(6,5) \ + pa(5,3) \ + puq(1,SS(a_,RS4),ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(5,RS4)),ax,1) \ + pl(SS(a_,MM(5,RS4)),cx,5) \ + pm(6,4) \ + pa(4,0) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),cx) \ + pl(SS(a_,MM(6,RS4)),cx,4) \ + plq(SS(a_,MM(6,RS4)),ax,2) \ + pm(6,5) \ + pa(5,1) \ + puq(3,SS(a_,MM(3,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) +#undef lpaxpby_2 +#define lpaxpby_2(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,4) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + pl(SS(a_,MM(1,RS4)),cx,5) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pm(6,4) \ + pa(4,0) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pl(SS(a_,MM(2,RS4)),cx,4) \ + pm(6,5) \ + pa(5,1) +#undef dpaxpby_2 +#define dpaxpby_2(a_) \ + pl(SS(a_,MM(3,RS4)),cx,5) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pm(6,4) \ + pa(4,2) \ + puq(0,a_,ax) \ + pm(6,5) \ + pa(5,3) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef plaxpby_2 +#define plaxpby_2 16 + +#undef p1_4_axpby_2c +#define p1_4_axpby_2c(a_) +#undef p1_2_axpby_2c +#define p1_2_axpby_2c(a_) \ + pld(a_,cx,5) \ + pld(a_,ax,0) \ + pc(5,1) \ + pm(6,5) \ + pa(5,0) \ + ps(CSHUF,1,1) \ + pm(7,1) \ + pa(1,0) \ + pud(0,a_,ax) +#undef p1_axpby_2c +#define p1_axpby_2c(a_) \ + pl(a_,cx,5) \ + plq(a_,ax,0) \ + pc(5,1) \ + pm(6,5) \ + pa(5,0) \ + ps(CSHUF,1,1) \ + pm(7,1) \ + pa(1,0) \ + puq(0,a_,ax) +#undef p2_axpby_2c +#define p2_axpby_2c(a_) \ + pl(a_,cx,5) \ + plq(a_,ax,0) \ + pl(SS(a_,RS4),cx,4) \ + pc(5,1) \ + pm(6,5) \ + pa(5,0) \ + ps(CSHUF,2,2) \ + pm(7,2) \ + pa(2,0) \ + plq(SS(a_,RS4),ax,1) \ + puq(0,a_,ax) \ + pc(4,3) \ + pm(6,4) \ + pa(4,1) \ + ps(CSHUF,3,3) \ + pm(7,3) \ + pa(3,1) \ + puq(1,SS(a_,RS4),ax) +#undef p4_axpby_2c +#define p4_axpby_2c(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + puq(0,a_,ax) \ + pc(4,0) \ + pm(6,4) \ + pa(4,2) \ + ps(CSHUF,0,0) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(4,RS4)),cx,4) \ + pm(7,0) \ + pa(0,2) \ + plq(SS(a_,MM(4,RS4)),ax,0) \ + puq(1,SS(a_,RS4),ax) \ + pc(5,1) \ + pm(6,5) \ + pa(5,3) \ + ps(CSHUF,1,1) \ + pl(SS(a_,MM(5,RS4)),cx,5) \ + pm(7,1) \ + pa(1,3) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(5,RS4)),ax,1) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + pc(4,2) \ + pm(6,4) \ + pa(4,0) \ + ps(CSHUF,2,2) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),cx) \ + pl(SS(a_,MM(6,RS4)),cx,4) \ + pm(7,2) \ + pa(2,0) \ + plq(SS(a_,MM(6,RS4)),ax,2) \ + puq(3,SS(a_,MM(3,RS4)),ax) \ + pc(5,3) \ + pm(6,5) \ + pa(5,1) \ + ps(CSHUF,3,3) \ + pl(SS(a_,MM(7,RS4)),cx,5) \ + pm(7,3) \ + pa(3,1) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) +#undef lpaxpby_2c +#define lpaxpby_2c(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,4) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + pl(SS(a_,MM(1,RS4)),cx,5) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pc(4,2) \ + pm(6,4) \ + pa(4,0) \ + ps(CSHUF,2,2) \ + pl(SS(a_,MM(2,RS4)),cx,4) \ + pm(7,2) \ + pa(2,0) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pc(5,3) \ + pm(6,5) \ + pa(5,1) \ + ps(CSHUF,3,3) \ + pl(SS(a_,MM(3,RS4)),cx,5) \ + pm(7,3) \ + pa(3,1) +#undef dpaxpby_2c +#define dpaxpby_2c(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + puq(0,a_,ax) \ + pc(4,0) \ + pm(6,4) \ + pa(4,2) \ + ps(CSHUF,0,0) \ + puq(1,SS(a_,RS4),ax) \ + pm(7,0) \ + pa(0,2) \ + pc(5,1) \ + pm(6,5) \ + pa(5,3) \ + ps(CSHUF,1,1) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + pm(7,1) \ + pa(1,3) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef plaxpby_2c +#define plaxpby_2c 16 + +#undef p1_4_axpby_1 +#define p1_4_axpby_1(a_) \ + pls(a_,ax,1) \ + pls(a_,cx,2) \ + pmsr(5,1) \ + pmsr(6,2) \ + pasr(2,1) \ + pus(1,a_,ax) +#undef p1_2_axpby_1 +#define p1_2_axpby_1(a_) \ + pld(a_,ax,1) \ + pld(a_,cx,2) \ + pm(5,1) \ + pm(6,2) \ + pa(2,1) \ + pud(1,a_,ax) +#undef p1_axpby_1 +#define p1_axpby_1(a_) \ + plq(a_,ax,1) \ + pl(a_,cx,2) \ + pm(5,1) \ + pm(6,2) \ + pa(2,1) \ + puq(1,a_,ax) +#undef p2_axpby_1 +#define p2_axpby_1(a_) \ + plq(SS(a_,RS4),ax,3) \ + pl(SS(a_,RS4),cx,4) \ + pm(5,1) \ + pm(6,2) \ + pa(2,1) \ + puq(1,a_,ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + pm(5,3) \ + pm(6,4) \ + pa(4,3) \ + puq(3,SS(a_,RS4),ax) +#undef lpaxpby_1 +#define lpaxpby_1(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,1) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,2) +#undef dpaxpby_1 +#define dpaxpby_1(a_) \ + plq(SS(a_,RS4),ax,3) \ + pl(SS(a_,RS4),cx,4) \ + pm(5,1) \ + pm(6,2) \ + pa(2,1) \ + puq(1,a_,ax) \ + pm(5,3) \ + pm(6,4) \ + pa(4,3) \ + puq(3,SS(a_,RS4),ax) +#undef plaxpby_1 +#define plaxpby_1 8 + +#undef p1_4_axpy_0 +#define p1_4_axpy_0(a_) \ + pls(a_,cx,2) \ + pls(a_,ax,1) \ + pmsr(6,2) \ + pasr(2,1) \ + pus(1,a_,ax) +#undef p1_2_axpy_0 +#define p1_2_axpy_0(a_) \ + pld(a_,cx,2) \ + pld(a_,ax,1) \ + pm(6,2) \ + pa(2,1) \ + pud(1,a_,ax) +#undef p1_axpy_0 +#define p1_axpy_0(a_) \ + pl(a_,cx,2) \ + plq(a_,ax,1) \ + pm(6,2) \ + pa(2,1) \ + puq(1,a_,ax) +#undef p2_axpy_0 +#define p2_axpy_0(a_) \ + pl(SS(a_,RS4),cx,4) \ + pm(6,2) \ + pa(2,1) \ + plq(SS(a_,RS4),ax,3) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + puq(1,a_,ax) \ + pm(6,4) \ + pa(4,3) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + puq(3,SS(a_,RS4),ax) +#undef lpaxpy_0 +#define lpaxpy_0(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,2) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,1) +#undef dpaxpy_0 +#define dpaxpy_0(a_) \ + pl(SS(a_,RS4),cx,4) \ + pm(6,2) \ + pa(2,1) \ + plq(SS(a_,RS4),ax,3) \ + puq(1,a_,ax) \ + pm(6,4) \ + pa(4,3) \ + puq(3,SS(a_,RS4),ax) +#undef plaxpy_0 +#define plaxpy_0 8 + +#undef p1_4_axpy_1 +#define p1_4_axpy_1(a_) \ + pls(a_,cx,2) \ + pls(a_,ax,1) \ + pmsr(6,2) \ + pasr(2,1) \ + pus(1,a_,ax) +#undef p1_2_axpy_1 +#define p1_2_axpy_1(a_) \ + pld(a_,cx,2) \ + pld(a_,ax,1) \ + pm(6,2) \ + pa(2,1) \ + pud(1,a_,ax) +#undef p1_axpy_1 +#define p1_axpy_1(a_) \ + pl(a_,cx,2) \ + pm(6,2) \ + pam(a_,ax,2) \ + puq(2,a_,ax) +#undef p2_axpy_1 +#define p2_axpy_1(a_) \ + pl(a_,cx,2) \ + pm(6,2) \ + pl(SS(a_,RS4),cx,4) \ + pam(a_,ax,2) \ + pm(6,4) \ + puq(2,a_,ax) \ + pam(SS(a_,RS4),ax,4) \ + puq(4,SS(a_,RS4),ax) +#undef p4_axpy_1 +#define p4_axpy_1(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(3,RS4)),cx,3) \ + pm(6,2) \ + pam(SS(a_,MM(2,RS4)),ax,2) \ + puq(0,a_,ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + pl(SS(a_,MM(4,RS4)),cx,0) \ + pm(6,3) \ + pam(SS(a_,MM(3,RS4)),ax,3) \ + puq(1,SS(a_,RS4),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),cx) \ + pl(SS(a_,MM(5,RS4)),cx,1) \ + pm(6,0) \ + pam(SS(a_,MM(4,RS4)),ax,0) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + pl(SS(a_,MM(6,RS4)),cx,2) \ + pm(6,1) \ + pam(SS(a_,MM(5,RS4)),ax,1) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef lpaxpy_1 +#define lpaxpy_1(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(a_,cx,0) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + pl(SS(a_,RS4),cx,1) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pm(6,0) \ + pam(a_,ax,0) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + pm(6,1) \ + pam(SS(a_,RS4),ax,1) +#undef dpaxpy_1 +#define dpaxpy_1(a_) \ + pl(SS(a_,MM(3,RS4)),cx,3) \ + pm(6,2) \ + pam(SS(a_,MM(2,RS4)),ax,2) \ + puq(0,a_,ax) \ + pm(6,3) \ + pam(SS(a_,MM(3,RS4)),ax,3) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef plaxpy_1 +#define plaxpy_1 16 + +#undef p1_4_axpy_2 +#define p1_4_axpy_2(a_) \ + pls(a_,cx,5) \ + pls(a_,ax,0) \ + pmsr(6,5) \ + pasr(5,0) \ + pus(0,a_,ax) +#undef p1_2_axpy_2 +#define p1_2_axpy_2(a_) \ + pld(a_,cx,5) \ + pld(a_,ax,0) \ + pm(6,5) \ + pa(5,0) \ + pud(0,a_,ax) +#undef p1_axpy_2 +#define p1_axpy_2(a_) \ + pl(a_,cx,5) \ + plq(a_,ax,0) \ + pm(6,5) \ + pa(5,0) \ + puq(0,a_,ax) +#undef p2_axpy_2 +#define p2_axpy_2(a_) \ + pl(a_,cx,5) \ + plq(a_,ax,0) \ + pl(SS(a_,RS4),cx,4) \ + pm(6,5) \ + pa(5,0) \ + plq(SS(a_,RS4),ax,1) \ + puq(0,a_,ax) \ + pm(6,4) \ + pa(4,1) \ + puq(1,SS(a_,RS4),ax) +#undef p4_axpy_2 +#define p4_axpy_2(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pl(SS(a_,MM(3,RS4)),cx,5) \ + pm(6,4) \ + pa(4,2) \ + puq(0,a_,ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(4,RS4)),cx,4) \ + plq(SS(a_,MM(4,RS4)),ax,0) \ + pm(6,5) \ + pa(5,3) \ + puq(1,SS(a_,RS4),ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(5,RS4)),ax,1) \ + pl(SS(a_,MM(5,RS4)),cx,5) \ + pm(6,4) \ + pa(4,0) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),cx) \ + pl(SS(a_,MM(6,RS4)),cx,4) \ + plq(SS(a_,MM(6,RS4)),ax,2) \ + pm(6,5) \ + pa(5,1) \ + puq(3,SS(a_,MM(3,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) +#undef lpaxpy_2 +#define lpaxpy_2(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,4) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + pl(SS(a_,MM(1,RS4)),cx,5) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pm(6,4) \ + pa(4,0) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pl(SS(a_,MM(2,RS4)),cx,4) \ + pm(6,5) \ + pa(5,1) +#undef dpaxpy_2 +#define dpaxpy_2(a_) \ + pl(SS(a_,MM(3,RS4)),cx,5) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pm(6,4) \ + pa(4,2) \ + puq(0,a_,ax) \ + pm(6,5) \ + pa(5,3) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef plaxpy_2 +#define plaxpy_2 16 + +#undef p1_4_axpy_2c +#define p1_4_axpy_2c(a_) +#undef p1_2_axpy_2c +#define p1_2_axpy_2c(a_) \ + pld(a_,cx,4) \ + pld(a_,ax,0) \ + pc(4,2) \ + pm(6,4) \ + pa(4,0) \ + ps(CSHUF,2,2) \ + pm(7,2) \ + pa(2,0) \ + pud(0,a_,ax) +#undef p1_axpy_2c +#define p1_axpy_2c(a_) \ + pl(a_,cx,4) \ + plq(a_,ax,0) \ + pc(4,2) \ + pm(6,4) \ + pa(4,0) \ + ps(CSHUF,2,2) \ + pm(7,2) \ + pa(2,0) \ + puq(0,a_,ax) +#undef p2_axpy_2c +#define p2_axpy_2c(a_) \ + pl(a_,cx,4) \ + plq(a_,ax,0) \ + pl(SS(a_,RS4),cx,5) \ + pc(4,2) \ + pm(6,4) \ + pa(4,0) \ + ps(CSHUF,2,2) \ + pm(7,2) \ + pa(2,0) \ + plq(SS(a_,RS4),ax,1) \ + puq(0,a_,ax) \ + pc(5,3) \ + pm(6,5) \ + pa(5,1) \ + ps(CSHUF,3,3) \ + pm(7,3) \ + pa(3,1) \ + puq(1,SS(a_,RS4),ax) +#undef p4_axpy_2c +#define p4_axpy_2c(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + puq(0,a_,ax) \ + pc(4,0) \ + pm(6,4) \ + pa(4,2) \ + ps(CSHUF,0,0) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(4,RS4)),cx,4) \ + pm(7,0) \ + pa(0,2) \ + plq(SS(a_,MM(4,RS4)),ax,0) \ + puq(1,SS(a_,RS4),ax) \ + pc(5,1) \ + pm(6,5) \ + pa(5,3) \ + ps(CSHUF,1,1) \ + pl(SS(a_,MM(5,RS4)),cx,5) \ + pm(7,1) \ + pa(1,3) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(5,RS4)),ax,1) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + pc(4,2) \ + pm(6,4) \ + pa(4,0) \ + ps(CSHUF,2,2) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),cx) \ + pl(SS(a_,MM(6,RS4)),cx,4) \ + pm(7,2) \ + pa(2,0) \ + plq(SS(a_,MM(6,RS4)),ax,2) \ + puq(3,SS(a_,MM(3,RS4)),ax) \ + pc(5,3) \ + pm(6,5) \ + pa(5,1) \ + ps(CSHUF,3,3) \ + pl(SS(a_,MM(7,RS4)),cx,5) \ + pm(7,3) \ + pa(3,1) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) +#undef lpaxpy_2c +#define lpaxpy_2c(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,4) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + pl(SS(a_,MM(1,RS4)),cx,5) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pc(4,2) \ + pm(6,4) \ + pa(4,0) \ + ps(CSHUF,2,2) \ + pl(SS(a_,MM(2,RS4)),cx,4) \ + pm(7,2) \ + pa(2,0) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pc(5,3) \ + pm(6,5) \ + pa(5,1) \ + ps(CSHUF,3,3) \ + pl(SS(a_,MM(3,RS4)),cx,5) \ + pm(7,3) \ + pa(3,1) +#undef dpaxpy_2c +#define dpaxpy_2c(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + puq(0,a_,ax) \ + pc(4,0) \ + pm(6,4) \ + pa(4,2) \ + ps(CSHUF,0,0) \ + puq(1,SS(a_,RS4),ax) \ + pm(7,0) \ + pa(0,2) \ + pc(5,1) \ + pm(6,5) \ + pa(5,3) \ + ps(CSHUF,1,1) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + pm(7,1) \ + pa(1,3) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef plaxpy_2c +#define plaxpy_2c 16 + +#undef p1_4_axpy_1c +#define p1_4_axpy_1c(a_) +#undef p1_2_axpy_1c +#define p1_2_axpy_1c(a_) \ + pld(a_,cx,2) \ + pc(2,0) \ + pld(a_,ax,1) \ + ps(CSHUF,0,0) \ + pm(6,2) \ + pa(2,1) \ + pm(7,0) \ + pa(0,1) \ + pud(1,a_,ax) +#undef p1_axpy_1c +#define p1_axpy_1c(a_) \ + pl(a_,cx,2) \ + pc(2,0) \ + plq(a_,ax,1) \ + ps(CSHUF,0,0) \ + pm(6,2) \ + pa(2,1) \ + pm(7,0) \ + pa(0,1) \ + puq(1,a_,ax) +#undef p2_axpy_1c +#define p2_axpy_1c(a_) \ + plq(SS(a_,RS4),ax,3) \ + ps(CSHUF,0,0) \ + pl(SS(a_,RS4),cx,4) \ + pm(6,2) \ + pa(2,1) \ + pm(7,0) \ + pa(0,1) \ + pc(4,0) \ + puq(1,a_,ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,1) \ + ps(CSHUF,0,0) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + pm(6,4) \ + pa(4,3) \ + pm(7,0) \ + pa(0,3) \ + pc(2,0) \ + puq(3,SS(a_,RS4),ax) +#undef lpaxpy_1c +#define lpaxpy_1c(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,2) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,1) \ + pc(2,0) +#undef dpaxpy_1c +#define dpaxpy_1c(a_) \ + plq(SS(a_,RS4),ax,3) \ + ps(CSHUF,0,0) \ + pl(SS(a_,RS4),cx,4) \ + pm(6,2) \ + pa(2,1) \ + pm(7,0) \ + pa(0,1) \ + pc(4,0) \ + puq(1,a_,ax) \ + ps(CSHUF,0,0) \ + pm(6,4) \ + pa(4,3) \ + pm(7,0) \ + pa(0,3) \ + puq(3,SS(a_,RS4),ax) +#undef plaxpy_1c +#define plaxpy_1c 8 + +#undef p1_4_copy_1 +#define p1_4_copy_1(a_) \ + pls(a_,cx,2) \ + pus(2,a_,ax) +#undef p1_2_copy_1 +#define p1_2_copy_1(a_) \ + pld(a_,cx,2) \ + pud(2,a_,ax) +#undef p1_copy_1 +#define p1_copy_1(a_) \ + pl(a_,cx,2) \ + puq(2,a_,ax) +#undef p2_copy_1 +#define p2_copy_1(a_) \ + pl(SS(a_,RS4),cx,4) \ + puq(2,a_,ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + puq(4,SS(a_,RS4),ax) +#undef lpcopy_1 +#define lpcopy_1(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,2) +#undef dpcopy_1 +#define dpcopy_1(a_) \ + pl(SS(a_,RS4),cx,4) \ + puq(2,a_,ax) \ + puq(4,SS(a_,RS4),ax) +#undef plcopy_1 +#define plcopy_1 8 + +#undef p1_4_copy_2 +#define p1_4_copy_2(a_) \ + pls(a_,ax,2) \ + pus(2,a_,cx) +#undef p1_2_copy_2 +#define p1_2_copy_2(a_) \ + pld(a_,ax,2) \ + pud(2,a_,cx) +#undef p1_copy_2 +#define p1_copy_2(a_) \ + plq(a_,ax,2) \ + pu(2,a_,cx) +#undef p2_copy_2 +#define p2_copy_2(a_) \ + plq(SS(a_,RS4),ax,4) \ + pu(2,a_,cx) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pu(4,SS(a_,RS4),cx) +#undef lpcopy_2 +#define lpcopy_2(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,2) +#undef dpcopy_2 +#define dpcopy_2(a_) \ + plq(SS(a_,RS4),ax,4) \ + pu(2,a_,cx) \ + pu(4,SS(a_,RS4),cx) +#undef plcopy_2 +#define plcopy_2 8 + +#undef p1_4_copy_3 +#define p1_4_copy_3(a_) \ + pls(a_,cx,2) \ + pus(2,a_,ax) +#undef p1_2_copy_3 +#define p1_2_copy_3(a_) \ + pld(a_,cx,2) \ + pud(2,a_,ax) +#undef p1_copy_3 +#define p1_copy_3(a_) \ + pl(a_,cx,2) \ + punt(2,a_,ax) +#undef p2_copy_3 +#define p2_copy_3(a_) \ + pl(SS(a_,MM(0,RS4)),cx,0) \ + pl(SS(a_,MM(1,RS4)),cx,1) \ + punt(0,SS(a_,MM(0,RS4)),ax) \ + punt(1,SS(a_,MM(1,RS4)),ax) +#undef p4_copy_3 +#define p4_copy_3(a_) \ + pl(SS(a_,MM(0,RS4)),cx,0) \ + pl(SS(a_,MM(1,RS4)),cx,1) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + pl(SS(a_,MM(3,RS4)),cx,3) \ + punt(0,SS(a_,MM(0,RS4)),ax) \ + punt(1,SS(a_,MM(1,RS4)),ax) \ + punt(2,SS(a_,MM(2,RS4)),ax) \ + punt(3,SS(a_,MM(3,RS4)),ax) +#undef p8_copy_3 +#define p8_copy_3(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,0) \ + pl(SS(a_,MM(1,RS4)),cx,1) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + pl(SS(a_,MM(3,RS4)),cx,3) \ + pl(SS(a_,MM(4,RS4)),cx,4) \ + pl(SS(a_,MM(5,RS4)),cx,5) \ + pl(SS(a_,MM(6,RS4)),cx,6) \ + pl(SS(a_,MM(7,RS4)),cx,7) \ + punt(0,SS(a_,MM(0,RS4)),ax) \ + punt(1,SS(a_,MM(1,RS4)),ax) \ + punt(2,SS(a_,MM(2,RS4)),ax) \ + punt(3,SS(a_,MM(3,RS4)),ax) \ + punt(4,SS(a_,MM(4,RS4)),ax) \ + punt(5,SS(a_,MM(5,RS4)),ax) \ + punt(6,SS(a_,MM(6,RS4)),ax) \ + punt(7,SS(a_,MM(7,RS4)),ax) +#undef lpcopy_3 +#define lpcopy_3(a_) +#undef dpcopy_3 +#define dpcopy_3(a_) p8_copy_3(a_) +#undef plcopy_3 +#define plcopy_3 32 + +#undef p1_4_cpsc_3 +#define p1_4_cpsc_3(a_) \ + pls(a_,ax,0) \ + pmsr(6,0) \ + pus(0,a_,cx) +#undef p1_2_cpsc_3 +#define p1_2_cpsc_3(a_) \ + pld(a_,ax,0) \ + pm(6,0) \ + pud(0,a_,cx) +#undef p1_cpsc_3 +#define p1_cpsc_3(a_) \ + plq(a_,ax,0) \ + pm(6,0) \ + pu(0,a_,cx) +#undef p2_cpsc_3 +#define p2_cpsc_3(a_) \ + plq(a_,ax,0) \ + plq(SS(a_,RS4),ax,1) \ + pm(6,0) \ + pm(6,1) \ + pu(0,a_,cx) \ + pu(1,SS(a_,RS4),cx) +#undef p4_cpsc_3 +#define p4_cpsc_3(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pm(6,2) \ + pu(0,a_,cx) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,0) \ + pm(6,3) \ + pu(1,SS(a_,RS4),cx) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + plq(SS(a_,MM(5,RS4)),ax,1) \ + pm(6,0) \ + pu(2,SS(a_,MM(2,RS4)),cx) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + plq(SS(a_,MM(6,RS4)),ax,2) \ + pm(6,1) \ + pu(3,SS(a_,MM(3,RS4)),cx) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),cx) +#undef lpcpsc_3 +#define lpcpsc_3(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pm(6,0) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pm(6,1) +#undef dpcpsc_3 +#define dpcpsc_3(a_) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pm(6,2) \ + pu(0,a_,cx) \ + pm(6,3) \ + pu(1,SS(a_,RS4),cx) \ + pu(2,SS(a_,MM(2,RS4)),cx) \ + pu(3,SS(a_,MM(3,RS4)),cx) +#undef plcpsc_3 +#define plcpsc_3 16 + +#undef p1_4_cpsc_3c +#define p1_4_cpsc_3c(a_) +#undef p1_2_cpsc_3c +#define p1_2_cpsc_3c(a_) \ + pld(a_,ax,0) \ + pc(0,1) \ + pm(6,0) \ + ps(CSHUF,1,1) \ + pm(7,1) \ + pa(1,0) \ + pud(0,a_,cx) +#undef p1_cpsc_3c +#define p1_cpsc_3c(a_) \ + plq(a_,ax,0) \ + pc(0,1) \ + pm(6,0) \ + ps(CSHUF,1,1) \ + pm(7,1) \ + pa(1,0) \ + pu(0,a_,cx) +#undef p2_cpsc_3c +#define p2_cpsc_3c(a_) \ + plq(a_,ax,0) \ + plq(SS(a_,RS4),ax,1) \ + pc(0,2) \ + pm(6,0) \ + ps(CSHUF,2,2) \ + pm(7,2) \ + pa(2,0) \ + pu(0,a_,cx) \ + pc(1,3) \ + pm(6,1) \ + ps(CSHUF,3,3) \ + pm(7,3) \ + pa(3,1) \ + pu(1,SS(a_,RS4),cx) +#undef p4_cpsc_3c +#define p4_cpsc_3c(a_) \ + pu(0,a_,cx) \ + pc(2,4) \ + pm(6,2) \ + ps(CSHUF,4,4) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,0) \ + pm(7,4) \ + pa(4,2) \ + pu(1,SS(a_,RS4),cx) \ + pc(3,4) \ + pm(6,3) \ + ps(CSHUF,4,4) \ + plq(SS(a_,MM(5,RS4)),ax,1) \ + pm(7,4) \ + pa(4,3) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pu(2,SS(a_,MM(2,RS4)),cx) \ + pc(0,4) \ + pm(6,0) \ + ps(CSHUF,4,4) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + plq(SS(a_,MM(6,RS4)),ax,2) \ + pm(7,4) \ + pa(4,0) \ + pu(3,SS(a_,MM(3,RS4)),cx) \ + pc(1,4) \ + pm(6,1) \ + ps(CSHUF,4,4) \ + plq(SS(a_,MM(7,RS4)),ax,3) \ + pm(7,4) \ + pa(4,1) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),cx) +#undef lpcpsc_3c +#define lpcpsc_3c(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,0) \ + plq(SS(a_,MM(1,RS4)),ax,1) \ + pc(0,4) \ + pm(6,0) \ + ps(CSHUF,4,4) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pm(7,4) \ + pa(4,0) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pc(1,4) \ + pm(6,1) \ + ps(CSHUF,4,4) \ + plq(SS(a_,MM(3,RS4)),ax,3) \ + pm(7,4) \ + pa(4,1) +#undef dpcpsc_3c +#define dpcpsc_3c(a_) \ + pu(0,a_,cx) \ + pc(2,4) \ + pm(6,2) \ + ps(CSHUF,4,4) \ + pu(1,SS(a_,RS4),cx) \ + pm(7,4) \ + pa(4,2) \ + pc(3,4) \ + pm(6,3) \ + ps(CSHUF,4,4) \ + pu(2,SS(a_,MM(2,RS4)),cx) \ + pm(7,4) \ + pa(4,3) \ + pu(3,SS(a_,MM(3,RS4)),cx) +#undef plcpsc_3c +#define plcpsc_3c 16 + +#undef p1_4_cpsc_4 +#define p1_4_cpsc_4(a_) \ + pls(a_,cx,0) \ + pmsr(6,0) \ + pus(0,a_,ax) +#undef p1_2_cpsc_4 +#define p1_2_cpsc_4(a_) \ + pld(a_,cx,0) \ + pm(6,0) \ + pud(0,a_,ax) +#undef p1_cpsc_4 +#define p1_cpsc_4(a_) \ + pl(a_,cx,0) \ + pm(6,0) \ + puq(0,a_,ax) +#undef p2_cpsc_4 +#define p2_cpsc_4(a_) \ + pl(a_,cx,0) \ + pl(SS(a_,RS4),cx,1) \ + pm(6,0) \ + pm(6,1) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) +#undef p4_cpsc_4 +#define p4_cpsc_4(a_) \ + pl(SS(a_,MM(3,RS4)),cx,3) \ + pm(6,2) \ + puq(0,a_,ax) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(4,RS4)),cx,0) \ + pm(6,3) \ + puq(1,SS(a_,RS4),ax) \ + pl(SS(a_,MM(5,RS4)),cx,1) \ + pm(6,0) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),cx) \ + pl(SS(a_,MM(6,RS4)),cx,2) \ + pm(6,1) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef lpcpsc_4 +#define lpcpsc_4(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,0) \ + pl(SS(a_,MM(1,RS4)),cx,1) \ + pm(6,0) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + pm(6,1) +#undef dpcpsc_4 +#define dpcpsc_4(a_) \ + pl(SS(a_,MM(3,RS4)),cx,3) \ + pm(6,2) \ + puq(0,a_,ax) \ + pm(6,3) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef plcpsc_4 +#define plcpsc_4 16 + +#undef p1_4_cpsc_5 +#define p1_4_cpsc_5(a_) \ + pls(a_,cx,0) \ + pmsr(6,0) \ + pus(0,a_,ax) +#undef p1_2_cpsc_5 +#define p1_2_cpsc_5(a_) \ + pld(a_,cx,0) \ + pm(6,0) \ + pud(0,a_,ax) +#undef p1_cpsc_5 +#define p1_cpsc_5(a_) \ + pl(a_,cx,0) \ + pm(6,0) \ + puq(0,a_,ax) +#undef p2_cpsc_5 +#define p2_cpsc_5(a_) \ + pl(a_,cx,0) \ + pl(SS(a_,RS4),cx,1) \ + pm(6,0) \ + pm(6,1) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) +#undef p4_cpsc_5 +#define p4_cpsc_5(a_) \ + pl(SS(a_,MM(0,RS4)),cx,0) \ + pl(SS(a_,MM(1,RS4)),cx,1) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + pl(SS(a_,MM(3,RS4)),cx,3) \ + pm(6,0) \ + pm(6,1) \ + pm(6,2) \ + pm(6,3) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef p8_cpsc_5 +#define p8_cpsc_5(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,0) \ + pl(SS(a_,MM(1,RS4)),cx,1) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + pl(SS(a_,MM(3,RS4)),cx,3) \ + pl(SS(a_,MM(4,RS4)),cx,4) \ + pl(SS(a_,MM(5,RS4)),cx,5) \ + pl(SS(a_,MM(6,RS4)),cx,7) \ + pm(6,0) \ + pm(6,1) \ + pm(6,2) \ + pm(6,3) \ + puq(0,a_,ax) \ + pl(SS(a_,MM(7,RS4)),cx,0) \ + pm(6,4) \ + pm(6,5) \ + pm(6,7) \ + pm(6,0) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) \ + puq(4,SS(a_,MM(4,RS4)),ax) \ + puq(5,SS(a_,MM(5,RS4)),ax) \ + puq(7,SS(a_,MM(6,RS4)),ax) \ + puq(0,SS(a_,MM(7,RS4)),ax) +#undef lpcpsc_5 +#define lpcpsc_5(a_) +#undef dpcpsc_5 +#define dpcpsc_5(a_) p8_cpsc_5(a_) +#undef plcpsc_5 +#define plcpsc_5 32 + +#undef cpsc_cdp +#define cpsc_cdp(a_) pc(a_,5) pm(6,a_) ps(CSHUF,5,5) pm(7,5) pa(5,a_) +#undef p1_4_cpsc_5c +#define p1_4_cpsc_5c(a_) +#undef p1_2_cpsc_5c +#define p1_2_cpsc_5c(a_) \ + pld(a_,cx,0) \ + cpsc_cdp(0) \ + pud(0,a_,ax) +#undef p1_cpsc_5c +#define p1_cpsc_5c(a_) \ + pl(a_,cx,0) \ + cpsc_cdp(0) \ + puq(0,a_,ax) +#undef p2_cpsc_5c +#define p2_cpsc_5c(a_) \ + pl(a_,cx,0) \ + pl(SS(a_,RS4),cx,1) \ + cpsc_cdp(0) \ + cpsc_cdp(1) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) +#undef p4_cpsc_5c +#define p4_cpsc_5c(a_) \ + pl(SS(a_,MM(0,RS4)),cx,0) \ + pl(SS(a_,MM(1,RS4)),cx,1) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + pl(SS(a_,MM(3,RS4)),cx,3) \ + cpsc_cdp(0) \ + cpsc_cdp(1) \ + cpsc_cdp(2) \ + cpsc_cdp(3) \ + puq(0,a_,ax) \ + puq(1,SS(a_,RS4),ax) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + puq(3,SS(a_,MM(3,RS4)),ax) +#undef p8_cpsc_5c +#define p8_cpsc_5c(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + pl(SS(a_,MM(0,RS4)),cx,0) \ + pl(SS(a_,MM(1,RS4)),cx,1) \ + pl(SS(a_,MM(2,RS4)),cx,2) \ + pl(SS(a_,MM(3,RS4)),cx,3) \ + pl(SS(a_,MM(4,RS4)),cx,4) \ + cpsc_cdp(0) \ + cpsc_cdp(1) \ + puq(0,a_,ax) \ + pl(SS(a_,MM(5,RS4)),cx,0) \ + cpsc_cdp(2) \ + cpsc_cdp(3) \ + puq(1,SS(a_,RS4),ax) \ + pl(SS(a_,MM(6,RS4)),cx,1) \ + cpsc_cdp(4) \ + cpsc_cdp(0) \ + puq(2,SS(a_,MM(2,RS4)),ax) \ + pl(SS(a_,MM(7,RS4)),cx,2) \ + cpsc_cdp(1) \ + cpsc_cdp(2) \ + puq(3,SS(a_,MM(3,RS4)),ax) \ + puq(4,SS(a_,MM(4,RS4)),ax) \ + puq(0,SS(a_,MM(5,RS4)),ax) \ + puq(1,SS(a_,MM(6,RS4)),ax) \ + puq(2,SS(a_,MM(7,RS4)),ax) +#undef lpcpsc_5c +#define lpcpsc_5c(a_) +#undef dpcpsc_5c +#define dpcpsc_5c(a_) p8_cpsc_5c(a_) +#undef plcpsc_5c +#define plcpsc_5c 32 + +#undef p1_4_cpsc_1 +#define p1_4_cpsc_1(a_) \ + pls(a_,ax,2) \ + pmsr(3,2) \ + pus(2,a_,cx) +#undef p1_2_cpsc_1 +#define p1_2_cpsc_1(a_) \ + pld(a_,ax,2) \ + pm(3,2) \ + pud(2,a_,cx) +#undef p1_cpsc_1 +#define p1_cpsc_1(a_) \ + plq(a_,ax,2) \ + pm(3,2) \ + pu(2,a_,cx) +#undef p2_cpsc_1 +#define p2_cpsc_1(a_) \ + plq(SS(a_,RS4),ax,4) \ + pm(3,2) \ + pu(2,a_,cx) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,2) \ + pm(3,4) \ + pu(4,SS(a_,RS4),cx) +#undef lpcpsc_1 +#define lpcpsc_1(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,2) +#undef dpcpsc_1 +#define dpcpsc_1(a_) \ + plq(SS(a_,RS4),ax,4) \ + pm(3,2) \ + pu(2,a_,cx) \ + pm(3,4) \ + pu(4,SS(a_,RS4),cx) +#undef plcpsc_1 +#define plcpsc_1 8 + +#undef p1_4_cpsc_2 +#define p1_4_cpsc_2(a_) \ + pls(a_,ax,2) \ + pmsr(3,2) \ + pus(2,a_,cx) +#undef p1_2_cpsc_2 +#define p1_2_cpsc_2(a_) \ + pld(a_,ax,2) \ + pm(3,2) \ + pud(2,a_,cx) +#undef p1_cpsc_2 +#define p1_cpsc_2(a_) \ + plq(a_,ax,2) \ + pm(3,2) \ + pu(2,a_,cx) +#undef p2_cpsc_2 +#define p2_cpsc_2(a_) \ + plq(a_,ax,2) \ + plq(SS(a_,RS4),ax,4) \ + pm(3,2) \ + pm(3,4) \ + pu(2,a_,cx) \ + pu(4,SS(a_,RS4),cx) +#undef p4_cpsc_2 +#define p4_cpsc_2(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,7) \ + pm(3,6) \ + pu(4,a_,cx) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pm(3,7) \ + pu(6,SS(a_,RS4),cx) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),cx) \ + f(nta,SS(a_,MM((SS(4,CL)),RS4)),ax) \ + plq(SS(a_,MM(4,RS4)),ax,4) \ + pm(3,2) \ + pu(7,SS(a_,MM(2,RS4)),cx) \ + plq(SS(a_,MM(5,RS4)),ax,6) \ + pm(3,4) \ + pu(2,SS(a_,MM(3,RS4)),cx) +#undef lpcpsc_2 +#define lpcpsc_2(a_) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),cx) \ + f(nta,SS(a_,MM((SS(0,CL)),RS4)),ax) \ + plq(SS(a_,MM(0,RS4)),ax,4) \ + plq(SS(a_,MM(1,RS4)),ax,6) \ + pm(3,4) +#undef dpcpsc_2 +#define dpcpsc_2(a_) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),cx) \ + f(nta,SS(a_,MM((SS(2,CL)),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,7) \ + pm(3,6) \ + pu(4,a_,cx) \ + plq(SS(a_,MM(3,RS4)),ax,2) \ + pm(3,7) \ + pu(6,SS(a_,RS4),cx) \ + pm(3,2) \ + pu(7,SS(a_,MM(2,RS4)),cx) \ + pu(2,SS(a_,MM(3,RS4)),cx) +#undef plcpsc_2 +#define plcpsc_2 RS4 + + +#undef p1_4_iamax_1 +#define p1_4_iamax_1(a_) \ + px(4) \ + pls(a_,ax,4) \ + pan(2,4) \ + pc(3,5) \ + pcm(6,4,5) \ + paxs(4,3) \ + pan(5,6) \ + pann(0,5) \ + pasr(5,6) \ + pasr(1,0) \ + ps(57,0,0) +#undef p1_2_iamax_1 +#define p1_2_iamax_1(a_) \ + px(4) \ + pld(a_,ax,4) \ + pan(2,4) \ + pc(3,5) \ + pcm(6,4,5) \ + pax(4,3) \ + pan(5,6) \ + pann(0,5) \ + pa(5,6) \ + pasr(1,0) \ + ps(57,0,0)\ + pasr(1,0) \ + ps(57,0,0) +#undef p1_iamax_1 +#define p1_iamax_1(a_) \ + plq(a_,ax,4) \ + pan(2,4) \ + pc(3,5) \ + pcm(6,4,5) \ + pax(4,3) \ + pan(5,6) \ + pann(0,5) \ + pa(5,6) \ + pa(1,0) +#define p2_iamax_1(a_) \ + plq(SS(a_,RS4),ax,4) \ + pan(2,4) \ + pc(3,5) \ + pcm(6,4,5) \ + pax(4,3) \ + pan(5,6) \ + pann(0,5) \ + pa(5,6) \ + pa(1,0) \ + f(nta,SS(a_,MM(SS(2,CL),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,4) \ + pan(2,4) \ + pc(3,5) \ + pcm(6,4,5) \ + pax(4,3) \ + pan(5,6) \ + pann(0,5) \ + pa(5,6) \ + pa(1,0) +#undef lpiamax_1 +#define lpiamax_1(a_) \ + f(nta,SS(a_,MM(CL,RS4)),ax) \ + plq(a_,ax,4) \ + pan(2,4) \ + pc(3,5) \ + pcm(6,4,5) \ + pax(4,3) \ + pan(5,6) \ + pann(0,5) \ + pa(5,6) \ + pa(1,0) +#undef dpiamax_1 +#define dpiamax_1(a_) \ + plq(SS(a_,RS4),ax,4) \ + pan(2,4) \ + pc(3,5) \ + pcm(6,4,5) \ + pax(4,3) \ + pan(5,6) \ + pann(0,5) \ + pa(5,6) \ + pa(1,0) +#undef pliamax_1 +#define pliamax_1 8 + +#undef p1_4_iamax_1d +#define p1_4_iamax_1d(a_) +#undef p1_2_iamax_1d +#define p1_2_iamax_1d(a_) \ + px(4) \ + pld(a_,ax,4) \ + dbg(2) \ + pan(2,4) \ + dbg(4) \ + pc(3,5) \ + dbg(5) \ + pcm(6,4,5) \ + dbg(5) \ + pax(4,3) \ + dbg(3) \ + pan(5,6) \ + dbg(6) \ + pann(0,5) \ + dbg(5) \ + pa(5,6) \ + dbg(6) \ + pasr(1,0) \ + dbg(0) \ + ps(1,0,0) +#undef p1_iamax_1d +#define p1_iamax_1d(a_) \ + plq(a_,ax,4) \ + dbg(2) \ + pan(2,4) \ + dbg(4) \ + pc(3,5) \ + dbg(5) \ + pcm(6,4,5) \ + dbg(5) \ + pax(4,3) \ + dbg(3) \ + pan(5,6) \ + dbg(6) \ + pann(0,5) \ + dbg(5) \ + pa(5,6) \ + dbg(6) \ + pa(1,0) +#define p2_iamax_1d(a_) \ + plq(SS(a_,RS4),ax,4) \ + dbg(2) \ + pan(2,4) \ + dbg(4) \ + pc(3,5) \ + dbg(5) \ + pcm(6,4,5) \ + dbg(5) \ + pax(4,3) \ + dbg(3) \ + pan(5,6) \ + dbg(6) \ + pann(0,5) \ + dbg(5) \ + pa(5,6) \ + dbg(6) \ + pa(1,0) \ + dbg(0) \ + f(nta,SS(a_,MM(SS(2,CL),RS4)),ax) \ + plq(SS(a_,MM(2,RS4)),ax,4) \ + dbg(2) \ + pan(2,4) \ + dbg(4) \ + pc(3,5) \ + dbg(5) \ + pcm(6,4,5) \ + dbg(5) \ + pax(4,3) \ + dbg(3) \ + pan(5,6) \ + dbg(6) \ + pann(0,5) \ + dbg(5) \ + pa(5,6) \ + dbg(6) \ + pa(1,0) +#undef lpiamax_1d +#define lpiamax_1d(a_) \ + f(nta,SS(a_,MM(CL,RS4)),ax) \ + plq(a_,ax,4) \ + dbg(2) \ + pan(2,4) \ + dbg(4) \ + pc(3,5) \ + dbg(5) \ + pcm(6,4,5) \ + dbg(5) \ + pax(4,3) \ + dbg(3) \ + pan(5,6) \ + dbg(6) \ + pann(0,5) \ + dbg(5) \ + pa(5,6) \ + dbg(6) \ + pa(1,0) +#undef dpiamax_1d +#define dpiamax_1d(a_) \ + plq(SS(a_,RS4),ax,4) \ + dbg(2) \ + pan(2,4) \ + dbg(4) \ + pc(3,5) \ + dbg(5) \ + pcm(6,4,5) \ + dbg(5) \ + pax(4,3) \ + dbg(3) \ + pan(5,6) \ + dbg(6) \ + pann(0,5) \ + dbg(5) \ + pa(5,6) \ + dbg(6) \ + pa(1,0) +#undef pliamax_1d +#define pliamax_1d 8 + diff --git a/kaldi_io/src/tools/ATLAS/include/contrib/camm_tpipe.h b/kaldi_io/src/tools/ATLAS/include/contrib/camm_tpipe.h new file mode 100644 index 0000000..03486cf --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/contrib/camm_tpipe.h @@ -0,0 +1,331 @@ +/*************************************** + $Header: /cvsroot/math-atlas/AtlasBase/kernel/CammMaguire/camm_tpipe.h,v 1.2 2003/10/18 18:13:30 yycamm Exp $ + + +***************************************/ + + +/* #ifndef CAMM_TPIPE_H */ +/* #define CAMM_TPIPE_H */ /*+ To stop multiple inclusions. +*/ + +#ifndef BITS +#error BITS must be defined in camm_tpipe.h +#endif +#ifndef DIV +#error DIV must be defined in camm_tpipe.h +#endif +#ifndef INC +#error INC(a_) must be defined in camm_tpipe.h +#endif +#ifndef LR +#error LR must be defined in camm_tpipe.h +#endif + +#ifdef ALIGN + +#if defined(SREAL) + + test(4,ax) + je(a2) + +#undef KB +#define KB ( 1 /* / DIV */ ) +#include "camm_pipe3.h" + + KB_block + INC(4) + sub(1,LR) + + lab(a2) + +#endif + +#if defined(SREAL) || defined(DREAL) + + test(8,ax) + je(a4) + test(-2,LR) + je(a4) + +#undef KB +#define KB ( 2 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(8) + sub(2,LR) + + lab(a4) + +#endif +#endif + +/* "movl %%edx,%%edi\n\t" */ + push(LR) + shr(BITS,LR) + shl(BITS,LR) + m(4,LR) + ra(ax,LR) + +#if defined(ALIGN) && ( defined(SCPLX) || defined(DCPLX) ) + test(12,ax) + je(loopa) +#endif + +#if !defined(ALIGN) || defined(SCPLX) || defined(DCPLX) +#undef plq +#define plq(a_,b_,c_) pl(a_,b_,c_) +#undef puq +#define puq(a_,b_,c_) pu(a_,b_,c_) +#undef plqx +#define plqx(a_,b_,c_,d_,e_) plx(a_,b_,c_,d_,e_) +#undef puqx +#define puqx(a_,b_,c_,d_,e_) pux(a_,b_,c_,d_,e_) +#else +#undef plq +#define plq(a_,b_,c_) pla(a_,b_,c_) +#undef puq +#define puq(a_,b_,c_) punt(a_,b_,c_) +#undef plqx +#define plqx(a_,b_,c_,d_,e_) plax(a_,b_,c_,d_,e_) +#undef puqx +#define puqx(a_,b_,c_,d_,e_) puax(a_,b_,c_,d_,e_) +#endif + + align + lab(loop) + cmp(ax,LR) + je(stop) + +#undef KB +#define KB ( (1 << BITS) /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(4*KB/**DIV*/) + + jmp(loop) + + lab(stop) + pop(LR) + +#if ( 1 << BITS ) > 128 + test(128,LR) + je(64) +#undef KB +#define KB ( 128 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(512) + + lab(64) +#endif + +#if ( 1 << BITS ) > 64 + test(64,LR) + je(32) +#undef KB +#define KB ( 64 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(256) + + lab(32) +#endif + +#if ( 1 << BITS ) > 32 + test(32,LR) + je(16) +#undef KB +#define KB ( 32 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(128) + + lab(16) +#endif + +#if ( 1 << BITS ) > 16 + test(16,LR) + je(8) +#undef KB +#define KB ( 16 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(64) + + lab(8) +#endif + +#if ( 1 << BITS ) > 8 + test(8,LR) + je(4) +#undef KB +#define KB ( 8 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(32) + + lab(4) +#endif + +#if ( 1 << BITS ) > 4 + test(4,LR) + je(2) +#undef KB +#define KB ( 4 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(16) + + lab(2) +#endif + +#if DIV != 4 && ( 1 << BITS ) > 2 + test(2,LR) + je(1) +#undef KB +#define KB ( 2 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(8) + + lab(1) +#endif + +#if DIV == 1 && ( 1 << BITS ) > 1 + test(1,LR) + je(end) +#undef KB +#define KB ( 1 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + lab(end) +#endif + +#if defined (ALIGN) && ( defined(SCPLX) || defined(DCPLX) ) + + jmp(tend) + +#undef plq +#define plq(a_,b_,c_) pla(a_,b_,c_) +#undef puq +#define puq(a_,b_,c_) punt(a_,b_,c_) +#undef plqx +#define plqx(a_,b_,c_,d_,e_) plax(a_,b_,c_,d_,e_) +#undef puqx +#define puqx(a_,b_,c_,d_,e_) puax(a_,b_,c_,d_,e_) + + align + lab(loopa) + cmp(ax,LR) + je(stopa) + +#undef KB +#define KB ( (1 << BITS) /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(4*KB/**DIV*/) + + jmp(loopa) + + lab(stopa) + pop(LR) + +#if ( 1 << BITS ) > 128 + test(128,LR) + je(64a) +#undef KB +#define KB ( 128 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(512) + + lab(64a) +#endif + +#if ( 1 << BITS ) > 64 + test(64,LR) + je(32a) +#undef KB +#define KB ( 64 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(256) + + lab(32a) +#endif + +#if ( 1 << BITS ) > 32 + test(32,LR) + je(16a) +#undef KB +#define KB ( 32 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(128) + + lab(16a) +#endif + +#if ( 1 << BITS ) > 16 + test(16,LR) + je(8a) +#undef KB +#define KB ( 16 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(64) + + lab(8a) +#endif + +#if ( 1 << BITS ) > 8 + test(8,LR) + je(4a) +#undef KB +#define KB ( 8 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(32) + + lab(4a) +#endif + +#if ( 1 << BITS ) > 4 + test(4,LR) + je(2a) +#undef KB +#define KB ( 4 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(16) + + lab(2a) +#endif + +#if DIV != 4 && ( 1 << BITS ) > 2 + test(2,LR) + je(1a) +#undef KB +#define KB ( 2 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + INC(8) + + lab(1a) +#endif + +#if DIV == 1 && ( 1 << BITS ) > 1 + test(1,LR) + je(enda) +#undef KB +#define KB ( 1 /* / DIV */ ) +#include "camm_pipe3.h" + KB_block + lab(enda) +#endif + + lab(tend) + +#endif + +/* #endif */ /* CAMM_TPIPE_H */ diff --git a/kaldi_io/src/tools/ATLAS/include/contrib/camm_util.h b/kaldi_io/src/tools/ATLAS/include/contrib/camm_util.h new file mode 100644 index 0000000..6b150d3 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/contrib/camm_util.h @@ -0,0 +1,508 @@ +#ifndef CAMM_UTIL_H +#define CAMM_UTIL_H /*+ To stop multiple inclusions. +*/ + +typedef struct { + float r,i; +} Complex; + +typedef struct { + double r,i; +} Dcomplex; + +#undef str +#define str(a_) xstr(a_) +#undef xstr +#define xstr(a_) #a_ + +#undef val +#define val(a_) xval(a_) +#undef xval +#define xval(a_) a_ + +#ifndef Mjoin +#define Mjoin(a,b) mjoin(a,b) +#ifdef mjoin + #undef mjoin +#endif +#define mjoin(a,b) a ## b +#endif + +#undef VOLATILE +#define VOLATILE __volatile__ +#undef ASM +#define ASM __asm__ VOLATILE + +#ifdef BETA0 +#undef BL +#define BL b0 +#endif +#ifdef BETA1 +#undef BL +#define BL b1 +#endif +#ifdef BETAX +#undef BL +#define BL bX +#endif +#ifdef BETAXI0 +#undef BL +#define BL bXi0 +#endif + +#ifdef NO_TRANSPOSE +#ifdef GER +#ifdef Conj_ +#undef FEXT +#define FEXT Gc +#else +#undef FEXT +#define FEXT Gu +#endif +#else +#ifdef Conj_ +#undef FEXT +#define FEXT Nc +#else +#undef FEXT +#define FEXT N +#endif +#endif +#else +#ifdef Conj_ +#undef FEXT +#define FEXT C +#else +#undef FEXT +#define FEXT T +#endif +#endif + +#undef BLC +#define BLC Mjoin(FEXT,BL) + +#ifdef __GNUC__ +#undef NO_INLINE +#define NO_INLINE double sq(double x) {return x*x;} +#else +#undef NO_INLINE +#define NO_INLINE +#endif + +#undef lab +#define lab(a_) "\n" str(MY_FUNCTION) "_" str(N) "_" str(a_) ":\n\t" +#undef jmp +#define jmp(a_) "jmp " str(MY_FUNCTION) "_" str(N) "_" str(a_) "\n\t" +#undef je +#define je(a_) "je " str(MY_FUNCTION) "_" str(N) "_" str(a_) "\n\t" +#undef jge +#define jge(a_) "jge " str(MY_FUNCTION) "_" str(N) "_" str(a_) "\n\t" +#undef jle +#define jle(a_) "jle " str(MY_FUNCTION) "_" str(N) "_" str(a_) "\n\t" +#undef jl +#define jl(a_) "jl " str(MY_FUNCTION) "_" str(N) "_" str(a_) "\n\t" +#undef jne +#define jne(a_) "jne " str(MY_FUNCTION) "_" str(N) "_" str(a_) "\n\t" +#undef align +#define align ".align 16\n\t" +#undef test +#define test(a_,b_) "testl $" str(a_) ",%%e" str(b_) "\n\t" +#undef and +#define and(a_,b_) "andl $" str(a_) ",%%e" str(b_) "\n\t" +#undef sub +#define sub(a_,b_) "subl $" str(a_) ",%%e" str(b_) "\n\t" +#undef SS +#define SS(a_,b_) a_ + b_ +#undef MM +#define MM(a_,b_) a_ * b_ +#undef E4 +#define E4(a_) (( a_ >> 2 ) << 2 ) + +#undef TYPE +#undef SCALAR +#undef PREC +#undef CSHUF +#undef LSHUF +#undef HSHUF +#undef ISHUF +#undef RSHUF +#undef SINGLE +#undef REAL +#undef DIV + +#ifdef SCPLX +#define TYPE Complex +#define SCALAR Complex * +#define PREC c +#define CSHUF 177 +#define LSHUF 160 +#define HSHUF 245 +#define ISHUF 13*17 +#define RSHUF 8*17 +#define SINGLE +#define DIV 2 +/* #ifdef Conj_ */ +/* static const TYPE signd[2]={{-1.0,1.0},{-1.0,1.0}}; */ +/* #else */ + static const TYPE signd[2]={{1.0,-1.0},{1.0,-1.0}}; +/* #endif */ +#endif + +#ifdef SREAL +#define TYPE float +#define SCALAR float +#define PREC s +#define SINGLE +#define REAL +#define DIV 1 +#endif + +#ifdef DREAL +#define TYPE double +#define SCALAR double +#define PREC d +#define REAL +#define DIV 2 +#endif + +#ifdef DCPLX +#define TYPE Dcomplex +#define SCALAR Dcomplex * +#define PREC z +#define CSHUF 1 +#define LSHUF 0 +#define HSHUF 3 +#define ISHUF 3 +#define RSHUF 0 +#define DIV 4 +/* #ifdef Conj_ */ +/* static const TYPE signd[1]={{-1.0,1.0}}; */ +/* #else */ + static const TYPE signd[1]={{1.0,-1.0}}; +/* #endif */ +#endif + +#undef M11 +#define M11 0 +#undef M12 +#define M12 1 +#undef M13 +#define M13 2 +#undef M14 +#define M14 3 +#undef M15 +#define M15 4 +#undef M16 +#define M16 5 +#undef M17 +#define M17 6 +#undef M18 +#define M18 7 + +#undef M23 +#define M23 1 +#undef M24 +#define M24 2 +#undef M25 +#define M25 3 +#undef M26 +#define M26 4 +#undef M27 +#define M27 5 +#undef M28 +#define M28 6 + +#undef M33 +#define M33 0 +#undef M34 +#define M34 1 +#undef M35 +#define M35 2 +#undef M36 +#define M36 3 +#undef M37 +#define M37 4 +#undef M38 +#define M38 5 + +#undef P10 +#define P10 1 +#undef P11 +#define P11 2 +#undef P12 +#define P12 3 +#undef P13 +#define P13 4 +#undef P14 +#define P14 5 +#undef P15 +#define P15 6 +#undef P16 +#define P16 7 + +#undef XM +#define XM(a_,b_) M ## b_ ## a_ +#undef M +#define M(a_,b_) XM(a_,b_) + +#undef XP +#define XP(a_,b_) P ## b_ ## a_ +#undef P +#define P(a_,b_) XP(a_,b_) + +#undef mex +#define mex(a_) str(%%e ## a_) +#undef msx +#define msx(a_) "%%st(" str(a_) ")" + +#undef cmp +#define cmp(a_,b_) "cmp " mex(a_) "," mex(b_) "\n\t" +#undef icmpr +#define icmpr(a_,b_) "cmp " mex(a_) ",(" mex(b_) ")\n\t" +#undef f +#define f(a_,b_,c_) "prefetch" str(a_) " " str(b_) "(%%e" #c_ ")\n\t" +#undef pfx +#define pfx(a_,b_,c_,d_,e_) "prefetch" str(a_) " " str(b_) "(%%e" #c_ ",%%e" #d_ "," str(e_) ")\n\t" +#undef a +#define a(a_,b_) "addl $" str(a_) "," mex(b_) "\n\t" +#undef m +#define m(a_,b_) "imul $" str(a_) "," mex(b_) "\n\t" +#undef pop +#define pop(a_) "popl %%e" str(a_) "\n\t" +#undef push +#define push(a_) "pushl %%e" str(a_) "\n\t" +#undef d +#define d(a_,b_) "idiv $" str(a_) "," mex(b_) "\n\t" +#undef shl +#define shl(a_,b_) "shl $" str(a_) "," mex(b_) "\n\t" +#undef shr +#define shr(a_,b_) "shr $" str(a_) "," mex(b_) "\n\t" +#undef mm +#define mm(a_,b_) "mov $" str(a_) "," mex(b_) "\n\t" +#undef ra +#define ra(a_,b_) "addl %%e" str(a_) "," mex(b_) "\n\t" +#undef rs +#define rs(a_,b_) "subl %%e" str(a_) "," mex(b_) "\n\t" + +#undef fl +#define fl(a_,b_) "fldl " str(a_) "(" mex(b_) ")\n\t" +#undef fp +#define fp(a_,b_) "fstpl " str(a_) "(" mex(b_) ")\n\t" +#undef fd +#define fd(a_) "fld " msx(a_) "\n\t" +#undef fap +#define fap(a_,b_) "faddp " msx(a_) "," msx(b_) "\n\t" +/* #define fsp(a_) fx(a_) "fsubp %%st," msx(a_) "\n\t" */ +#undef fsp +#define fsp(a_) "fsubrp %%st," msx(a_) "\n\t" +#undef fmp +#define fmp(a_,b_) "fmulp " msx(a_) "," msx(b_) "\n\t" +#undef fa +#define fa(a_,b_) "fadd " msx(a_) "," msx(b_) "\n\t" +#undef fm +#define fm(a_,b_) "fmul " msx(a_) "," msx(b_) "\n\t" +#undef faa +#define faa(a_,b_) "faddl " str(a_) "(" mex(b_) ")\n\t" +#undef fma +#define fma(a_,b_) "fmull " str(a_) "(" mex(b_) ")\n\t" +#undef fz +#define fz "fldz\n\t" +#undef fx +#define fx(a_) "fxch " msx(a_) "\n\t" +#undef fx1 +#define fx1 "fxch\n\t" +#undef fc +#define fc(a_) "fstp " msx(a_) "\n\t" + + +#ifndef ATHLON + + +#if defined(DREAL) || defined(DCPLX) +#undef SSESUF +#define SSESUF "d " +#undef RS4 +#define RS4 16 +#undef RS +#define RS 4 +#else +#undef SSESUF +#define SSESUF "s " +#undef RS4 +#define RS4 16 +#undef RS +#define RS 4 +#endif + +#undef mxx +#define mxx(a_) str(%%xmm ## a_) +#undef prp +#define prp(a_,b_) "rcpp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef prps +#define prps(a_,b_) "rcps" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pann +#define pann(a_,b_) "andnp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef psqs +#define psqs(a_,b_) "sqrts" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef por +#define por(a_,b_) "orp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pan +#define pan(a_,b_) "andp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pcm +#define pcm(a_,b_,c_) "cmpp" SSESUF " $" str(a_) "," mxx(b_) "," mxx(c_) "\n\t" +#undef pcms +#define pcms(a_,b_,c_) "cmps" SSESUF " $" str(a_) "," mxx(b_) "," mxx(c_) "\n\t" +#undef pax +#define pax(a_,b_) "maxp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef paxs +#define paxs(a_,b_) "maxs" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pd +#define pd(a_,b_) "divp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pdsr +#define pdsr(a_,b_) "divs" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pxx +#define pxx(a_,b_) "xorp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef px +#define px(a_) "xorp" SSESUF mxx(a_) "," mxx(a_) "\n\t" +#undef pm +#define pm(a_,b_) "mulp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pa +#define pa(a_,b_) "addp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pmm +#define pmm(a_,b_,c_) "mulp" SSESUF str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pam +#define pam(a_,b_,c_) "addp" SSESUF str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pl +#define pl(a_,b_,c_) "movup" SSESUF str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pla +#define pla(a_,b_,c_) "movap" SSESUF str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pu +#define pu(a_,b_,c_) "movup" SSESUF mxx(a_) "," str(b_) "(" mex(c_) ")\n\t" +#undef punt +#define punt(a_,b_,c_) "movntp" SSESUF mxx(a_) "," str(b_) "(" mex(c_) ")\n\t" +#undef pua +#define pua(a_,b_,c_) "movap" SSESUF mxx(a_) "," str(b_) "(" mex(c_) ")\n\t" +#undef pud +#define pud(a_,b_,c_) "movlp" SSESUF mxx(a_) "," str(b_) "(" mex(c_) ")\n\t" +#undef pudr +#define pudr(a_,b_) "movlp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pc +#define pc(a_,b_) "movap" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef ps +#define ps(a_,b_,c_) "shufp" SSESUF " $" str(a_) "," mxx(b_) "," mxx(c_) "\n\t" +#undef phl +#define phl(a_,b_) "movhlp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pus +#define pus(a_,b_,c_) "movs" SSESUF mxx(a_) "," str(b_) "(" mex(c_) ")\n\t" +#undef pls +#define pls(a_,b_,c_) "movs" SSESUF str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pld +#define pld(a_,b_,c_) "movlp" SSESUF str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef plh +#define plh(a_,b_) "movlhp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pas +#define pas(a_,b_,c_) "adds" SSESUF str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pms +#define pms(a_,b_,c_) "muls" SSESUF str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pcs +#define pcs(a_,b_) "movs" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pasr +#define pasr(a_,b_) "adds" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pmsr +#define pmsr(a_,b_) "muls" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef pul +#define pul(a_,b_) "unpcklp" SSESUF mxx(a_) "," mxx(b_) "\n\t" +#undef puh +#define puh(a_,b_) "unpckhp" SSESUF mxx(a_) "," mxx(b_) "\n\t" + +#undef plsx +#define plsx(a_,b_,c_,d_,e_) \ + "movs" SSESUF str(a_) "(" mex(b_) "," mex(c_) "," #d_ ")," mxx(e_) "\n\t" +#undef plx +#define plx(a_,b_,c_,d_,e_) \ + "movup" SSESUF str(a_) "(" mex(b_) "," mex(c_) "," #d_ ")," mxx(e_) "\n\t" +#undef plax +#define plax(a_,b_,c_,d_,e_) \ + "movap" SSESUF str(a_) "(" mex(b_) "," mex(c_) "," #d_ ")," mxx(e_) "\n\t" +#undef pasx +#define pasx(a_,b_,c_,d_,e_) \ + "adds" SSESUF str(a_) "(" mex(b_) "," mex(c_) "," #d_ ")," mxx(e_) "\n\t" +#undef pusx +#define pusx(a_,b_,c_,d_,e_) \ + "movs" SSESUF mxx(a_) "," str(b_) "(" mex(c_) "," mex(d_) "," #e_ ")\n\t" +#undef pux +#define pux(a_,b_,c_,d_,e_) \ + "movup" SSESUF mxx(a_) "," str(b_) "(" mex(c_) "," mex(d_) "," #e_ ")\n\t" +#undef puax +#define puax(a_,b_,c_,d_,e_) \ + "movap" SSESUF mxx(a_) "," str(b_) "(" mex(c_) "," mex(d_) "," #e_ ")\n\t" +#undef pudx +#define pudx(a_,b_,c_,d_,e_) \ + "movlp" SSESUF mxx(a_) "," str(b_) "(" mex(c_) "," mex(d_) "," #e_ ")\n\t" + +#undef pldx +#define pldx(a_,b_,c_,d_,e_) \ + "movlp" SSESUF str(a_) "(" mex(b_) "," mex(c_) "," #d_ ")," mxx(e_) "\n\t" + +#else + +#undef RS4 +#define RS4 8 +#undef RS +#define RS 2 + +#undef mxx +#define mxx(a_) str(%%mm ## a_) +#undef pul +#define pul(a_,b_) "punpckldq " mxx(a_) "," mxx(b_) "\n\t" +#undef puh +#define puh(a_,b_) "punpckhdq " mxx(a_) "," mxx(b_) "\n\t" + +#undef px +#define px(a_) "pxor " mxx(a_) "," mxx(a_) "\n\t" +#undef pm +#define pm(a_,b_) "pfmul " mxx(a_) "," mxx(b_) "\n\t" +#undef pa +#define pa(a_,b_) "pfadd " mxx(a_) "," mxx(b_) "\n\t" +#undef pac +#define pac(a_,b_) "pfacc " mxx(a_) "," mxx(b_) "\n\t" +#undef pmm +#define pmm(a_,b_,c_) "pfmul " str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pam +#define pam(a_,b_,c_) "pfadd " str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pl +#define pl(a_,b_,c_) "movq " str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pla +#define pla(a_,b_,c_) "movq " str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" +#undef pu +#define pu(a_,b_,c_) "movq " mxx(a_) "," str(b_) "(" mex(c_) ")\n\t" +#undef pc +#define pc(a_,b_) "movq " mxx(a_) "," mxx(b_) "\n\t" +#undef ps +#define ps(a_,b_,c_) "pswapd " mxx(b_) "," mxx(c_) "\n\t" +#undef phl +#define phl(a_,b_) "punpckhdq " mxx(a_) "," mxx(b_) "\n\t" +#undef plh +#define plh(a_,b_) "punpckldq " mxx(a_) "," mxx(b_) "\n\t" +#undef pus +#define pus(a_,b_,c_) "movd " mxx(a_) "," str(b_) "(" mex(c_) ")\n\t" +#undef pls +#define pls(a_,b_,c_) "movd " str(a_) "(" mex(b_) ")," mxx(c_) "\n\t" + +#undef plsx +#define plsx(a_,b_,c_,d_,e_) \ + "movd " str(a_) "(" mex(b_) "," mex(c_) "," #d_ ")," mxx(e_) "\n\t" +#undef plx +#define plx(a_,b_,c_,d_,e_) \ + "movq " str(a_) "(" mex(b_) "," mex(c_) "," #d_ ")," mxx(e_) "\n\t" +#undef pasx +#define pasx(a_,b_,c_,d_,e_) \ + "addss " str(a_) "(" mex(b_) "," mex(c_) "," #d_ ")," mxx(e_) "\n\t" +#undef pusx +#define pusx(a_,b_,c_,d_,e_) \ + "movd " mxx(a_) "," str(b_) "(" mex(c_) "," mex(d_) "," #e_ ")\n\t" +#undef pux +#define pux(a_,b_,c_,d_,e_) \ + "movq " mxx(a_) "," str(b_) "(" mex(c_) "," mex(d_) "," #e_ ")\n\t" +#endif + +#endif /* CAMM_UTIL_H */ diff --git a/kaldi_io/src/tools/ATLAS/include/f77wrap_lapack.h b/kaldi_io/src/tools/ATLAS/include/f77wrap_lapack.h new file mode 100644 index 0000000..89417f7 --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/f77wrap_lapack.h @@ -0,0 +1,91 @@ +/* + * Automatically Tuned Linear Algebra Software v3.8.3 + * (C) Copyright 1999 R. Clint Whaley + * + * Code contributers : R. Clint Whaley, Antoine P. Petitet + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the ATLAS group or the names of its contributers may + * not be used to endorse or promote products derived from this + * software without specific written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef F77WRAP_LAPACK_H +#define F77WRAP_LAPACK_H + +#include "atlas_misc.h" +#include "atlas_f77.h" + +#ifdef UpCase + #define PFW Mjoin(ATL_F77WRAP_,PREU) +#else + #define PFW Mjoin(atl_f77wrap_,PRE) +#endif + +#ifdef Add_ + #define F77WRAP_GETRI Mjoin(PFW,getri_) + #define F77WRAP_LAUUM Mjoin(PFW,lauum_) + #define F77WRAP_TRTRI Mjoin(PFW,trtri_) + #define F77WRAP_GETNB Mjoin(PFW,getnb_) + #define F77WRAP_GETRS Mjoin(PFW,getrs_) + #define F77WRAP_GETRF Mjoin(PFW,getrf_) + #define F77WRAP_GESV Mjoin(PFW,gesv_) + #define F77WRAP_POTRS Mjoin(PFW,potrs_) + #define F77WRAP_POTRF Mjoin(PFW,potrf_) + #define F77WRAP_POSV Mjoin(PFW,posv_) +#elif defined(Add__) + #define F77WRAP_GETRI Mjoin(PFW,getri__) + #define F77WRAP_LAUUM Mjoin(PFW,lauum__) + #define F77WRAP_TRTRI Mjoin(PFW,trtri__) + #define F77WRAP_GETNB Mjoin(PFW,getnb__) + #define F77WRAP_GETRS Mjoin(PFW,getrs__) + #define F77WRAP_GETRF Mjoin(PFW,getrf__) + #define F77WRAP_GESV Mjoin(PFW,gesv__) + #define F77WRAP_POTRS Mjoin(PFW,potrs__) + #define F77WRAP_POTRF Mjoin(PFW,potrf__) + #define F77WRAP_POSV Mjoin(PFW,posv__) +#elif defined(NoChange) + #define F77WRAP_GETRI Mjoin(PFW,getri) + #define F77WRAP_LAUUM Mjoin(PFW,lauum) + #define F77WRAP_TRTRI Mjoin(PFW,trtri) + #define F77WRAP_GETNB Mjoin(PFW,getnb) + #define F77WRAP_GETRS Mjoin(PFW,getrs) + #define F77WRAP_GETRF Mjoin(PFW,getrf) + #define F77WRAP_GESV Mjoin(PFW,gesv) + #define F77WRAP_POTRS Mjoin(PFW,potrs) + #define F77WRAP_POTRF Mjoin(PFW,potrf) + #define F77WRAP_POSV Mjoin(PFW,posv) +#elif defined(UpCase) + #define F77WRAP_GETRI Mjoin(PFW,GETRI) + #define F77WRAP_LAUUM Mjoin(PFW,LAUUM) + #define F77WRAP_TRTRI Mjoin(PFW,TRTRI) + #define F77WRAP_GETNB Mjoin(PFW,GETNB) + #define F77WRAP_GETRS Mjoin(PFW,GETRS) + #define F77WRAP_GETRF Mjoin(PFW,GETRF) + #define F77WRAP_GESV Mjoin(PFW,GESV) + #define F77WRAP_POTRS Mjoin(PFW,POTRS) + #define F77WRAP_POTRF Mjoin(PFW,POTRF) + #define F77WRAP_POSV Mjoin(PFW,POSV) +#endif + +#endif diff --git a/kaldi_io/src/tools/openfst/include/fst/accumulator.h b/kaldi_io/src/tools/openfst/include/fst/accumulator.h new file mode 100644 index 0000000..81d1847 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/accumulator.h @@ -0,0 +1,745 @@ +// accumulator.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes to accumulate arc weights. Useful for weight lookahead. + +#ifndef FST_LIB_ACCUMULATOR_H__ +#define FST_LIB_ACCUMULATOR_H__ + +#include <algorithm> +#include <functional> +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <vector> +using std::vector; + +#include <fst/arcfilter.h> +#include <fst/arcsort.h> +#include <fst/dfs-visit.h> +#include <fst/expanded-fst.h> +#include <fst/replace.h> + +namespace fst { + +// This class accumulates arc weights using the semiring Plus(). +template <class A> +class DefaultAccumulator { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + DefaultAccumulator() {} + + DefaultAccumulator(const DefaultAccumulator<A> &acc) {} + + void Init(const Fst<A>& fst, bool copy = false) {} + + void SetState(StateId) {} + + Weight Sum(Weight w, Weight v) { + return Plus(w, v); + } + + template <class ArcIterator> + Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, + ssize_t end) { + Weight sum = w; + aiter->Seek(begin); + for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) + sum = Plus(sum, aiter->Value().weight); + return sum; + } + + bool Error() const { return false; } + + private: + void operator=(const DefaultAccumulator<A> &); // Disallow +}; + + +// This class accumulates arc weights using the log semiring Plus() +// assuming an arc weight has a WeightConvert specialization to +// and from log64 weights. +template <class A> +class LogAccumulator { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + LogAccumulator() {} + + LogAccumulator(const LogAccumulator<A> &acc) {} + + void Init(const Fst<A>& fst, bool copy = false) {} + + void SetState(StateId) {} + + Weight Sum(Weight w, Weight v) { + return LogPlus(w, v); + } + + template <class ArcIterator> + Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, + ssize_t end) { + Weight sum = w; + aiter->Seek(begin); + for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) + sum = LogPlus(sum, aiter->Value().weight); + return sum; + } + + bool Error() const { return false; } + + private: + double LogPosExp(double x) { return log(1.0F + exp(-x)); } + + Weight LogPlus(Weight w, Weight v) { + double f1 = to_log_weight_(w).Value(); + double f2 = to_log_weight_(v).Value(); + if (f1 > f2) + return to_weight_(f2 - LogPosExp(f1 - f2)); + else + return to_weight_(f1 - LogPosExp(f2 - f1)); + } + + WeightConvert<Weight, Log64Weight> to_log_weight_; + WeightConvert<Log64Weight, Weight> to_weight_; + + void operator=(const LogAccumulator<A> &); // Disallow +}; + + +// Stores shareable data for fast log accumulator copies. +class FastLogAccumulatorData { + public: + FastLogAccumulatorData() {} + + vector<double> *Weights() { return &weights_; } + vector<ssize_t> *WeightPositions() { return &weight_positions_; } + double *WeightEnd() { return &(weights_[weights_.size() - 1]); }; + int RefCount() const { return ref_count_.count(); } + int IncrRefCount() { return ref_count_.Incr(); } + int DecrRefCount() { return ref_count_.Decr(); } + + private: + // Cummulative weight per state for all states s.t. # of arcs > + // arc_limit_ with arcs in order. Special first element per state + // being Log64Weight::Zero(); + vector<double> weights_; + // Maps from state to corresponding beginning weight position in + // weights_. Position -1 means no pre-computed weights for that + // state. + vector<ssize_t> weight_positions_; + RefCounter ref_count_; // Reference count. + + DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData); +}; + + +// This class accumulates arc weights using the log semiring Plus() +// assuming an arc weight has a WeightConvert specialization to and +// from log64 weights. The member function Init(fst) has to be called +// to setup pre-computed weight information. +template <class A> +class FastLogAccumulator { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10) + : arc_limit_(arc_limit), + arc_period_(arc_period), + data_(new FastLogAccumulatorData()), + error_(false) {} + + FastLogAccumulator(const FastLogAccumulator<A> &acc) + : arc_limit_(acc.arc_limit_), + arc_period_(acc.arc_period_), + data_(acc.data_), + error_(acc.error_) { + data_->IncrRefCount(); + } + + ~FastLogAccumulator() { + if (!data_->DecrRefCount()) + delete data_; + } + + void SetState(StateId s) { + vector<double> &weights = *data_->Weights(); + vector<ssize_t> &weight_positions = *data_->WeightPositions(); + + if (weight_positions.size() <= s) { + FSTERROR() << "FastLogAccumulator::SetState: invalid state id."; + error_ = true; + return; + } + + ssize_t pos = weight_positions[s]; + if (pos >= 0) + state_weights_ = &(weights[pos]); + else + state_weights_ = 0; + } + + Weight Sum(Weight w, Weight v) { + return LogPlus(w, v); + } + + template <class ArcIterator> + Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, + ssize_t end) { + if (error_) return Weight::NoWeight(); + Weight sum = w; + // Finds begin and end of pre-stored weights + ssize_t index_begin = -1, index_end = -1; + ssize_t stored_begin = end, stored_end = end; + if (state_weights_ != 0) { + index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0; + index_end = end / arc_period_; + stored_begin = index_begin * arc_period_; + stored_end = index_end * arc_period_; + } + // Computes sum before pre-stored weights + if (begin < stored_begin) { + ssize_t pos_end = min(stored_begin, end); + aiter->Seek(begin); + for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos) + sum = LogPlus(sum, aiter->Value().weight); + } + // Computes sum between pre-stored weights + if (stored_begin < stored_end) { + sum = LogPlus(sum, LogMinus(state_weights_[index_end], + state_weights_[index_begin])); + } + // Computes sum after pre-stored weights + if (stored_end < end) { + ssize_t pos_start = max(stored_begin, stored_end); + aiter->Seek(pos_start); + for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos) + sum = LogPlus(sum, aiter->Value().weight); + } + return sum; + } + + template <class F> + void Init(const F &fst, bool copy = false) { + if (copy) + return; + vector<double> &weights = *data_->Weights(); + vector<ssize_t> &weight_positions = *data_->WeightPositions(); + if (!weights.empty() || arc_limit_ < arc_period_) { + FSTERROR() << "FastLogAccumulator: initialization error."; + error_ = true; + return; + } + weight_positions.reserve(CountStates(fst)); + + ssize_t weight_position = 0; + for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + if (fst.NumArcs(s) >= arc_limit_) { + double sum = FloatLimits<double>::PosInfinity(); + weight_positions.push_back(weight_position); + weights.push_back(sum); + ++weight_position; + ssize_t narcs = 0; + for(ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { + const A &arc = aiter.Value(); + sum = LogPlus(sum, arc.weight); + // Stores cumulative weight distribution per arc_period_. + if (++narcs % arc_period_ == 0) { + weights.push_back(sum); + ++weight_position; + } + } + } else { + weight_positions.push_back(-1); + } + } + } + + bool Error() const { return error_; } + + private: + double LogPosExp(double x) { + return x == FloatLimits<double>::PosInfinity() ? + 0.0 : log(1.0F + exp(-x)); + } + + double LogMinusExp(double x) { + return x == FloatLimits<double>::PosInfinity() ? + 0.0 : log(1.0F - exp(-x)); + } + + Weight LogPlus(Weight w, Weight v) { + double f1 = to_log_weight_(w).Value(); + double f2 = to_log_weight_(v).Value(); + if (f1 > f2) + return to_weight_(f2 - LogPosExp(f1 - f2)); + else + return to_weight_(f1 - LogPosExp(f2 - f1)); + } + + double LogPlus(double f1, Weight v) { + double f2 = to_log_weight_(v).Value(); + if (f1 == FloatLimits<double>::PosInfinity()) + return f2; + else if (f1 > f2) + return f2 - LogPosExp(f1 - f2); + else + return f1 - LogPosExp(f2 - f1); + } + + Weight LogMinus(double f1, double f2) { + if (f1 >= f2) { + FSTERROR() << "FastLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1 + << " and f2 = " << f2; + error_ = true; + return Weight::NoWeight(); + } + if (f2 == FloatLimits<double>::PosInfinity()) + return to_weight_(f1); + else + return to_weight_(f1 - LogMinusExp(f2 - f1)); + } + + WeightConvert<Weight, Log64Weight> to_log_weight_; + WeightConvert<Log64Weight, Weight> to_weight_; + + ssize_t arc_limit_; // Minimum # of arcs to pre-compute state + ssize_t arc_period_; // Save cumulative weights per 'arc_period_'. + bool init_; // Cumulative weights initialized? + FastLogAccumulatorData *data_; + double *state_weights_; + bool error_; + + void operator=(const FastLogAccumulator<A> &); // Disallow +}; + + +// Stores shareable data for cache log accumulator copies. +// All copies share the same cache. +template <class A> +class CacheLogAccumulatorData { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + CacheLogAccumulatorData(bool gc, size_t gc_limit) + : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {} + + ~CacheLogAccumulatorData() { + for(typename unordered_map<StateId, CacheState>::iterator it = cache_.begin(); + it != cache_.end(); + ++it) + delete it->second.weights; + } + + bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; } + + vector<double> *GetWeights(StateId s) { + typename unordered_map<StateId, CacheState>::iterator it = cache_.find(s); + if (it != cache_.end()) { + it->second.recent = true; + return it->second.weights; + } else { + return 0; + } + } + + void AddWeights(StateId s, vector<double> *weights) { + if (cache_gc_ && cache_size_ >= cache_limit_) + GC(false); + cache_.insert(make_pair(s, CacheState(weights, true))); + if (cache_gc_) + cache_size_ += weights->capacity() * sizeof(double); + } + + int RefCount() const { return ref_count_.count(); } + int IncrRefCount() { return ref_count_.Incr(); } + int DecrRefCount() { return ref_count_.Decr(); } + + private: + // Cached information for a given state. + struct CacheState { + vector<double>* weights; // Accumulated weights for this state. + bool recent; // Has this state been accessed since last GC? + + CacheState(vector<double> *w, bool r) : weights(w), recent(r) {} + }; + + // Garbage collect: Delete from cache states that have not been + // accessed since the last GC ('free_recent = false') until + // 'cache_size_' is 2/3 of 'cache_limit_'. If it does not free enough + // memory, start deleting recently accessed states. + void GC(bool free_recent) { + size_t cache_target = (2 * cache_limit_)/3 + 1; + typename unordered_map<StateId, CacheState>::iterator it = cache_.begin(); + while (it != cache_.end() && cache_size_ > cache_target) { + CacheState &cs = it->second; + if (free_recent || !cs.recent) { + cache_size_ -= cs.weights->capacity() * sizeof(double); + delete cs.weights; + cache_.erase(it++); + } else { + cs.recent = false; + ++it; + } + } + if (!free_recent && cache_size_ > cache_target) + GC(true); + } + + unordered_map<StateId, CacheState> cache_; // Cache + bool cache_gc_; // Enable garbage collection + size_t cache_limit_; // # of bytes cached + size_t cache_size_; // # of bytes allowed before GC + RefCounter ref_count_; + + DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData); +}; + +// This class accumulates arc weights using the log semiring Plus() +// has a WeightConvert specialization to and from log64 weights. It +// is similar to the FastLogAccumator. However here, the accumulated +// weights are pre-computed and stored only for the states that are +// visited. The member function Init(fst) has to be called to setup +// this accumulator. +template <class A> +class CacheLogAccumulator { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false, + size_t gc_limit = 10 * 1024 * 1024) + : arc_limit_(arc_limit), fst_(0), data_( + new CacheLogAccumulatorData<A>(gc, gc_limit)), s_(kNoStateId), + error_(false) {} + + CacheLogAccumulator(const CacheLogAccumulator<A> &acc) + : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0), + data_(acc.data_), s_(kNoStateId), error_(acc.error_) { + data_->IncrRefCount(); + } + + ~CacheLogAccumulator() { + if (fst_) + delete fst_; + if (!data_->DecrRefCount()) + delete data_; + } + + // Arg 'arc_limit' specifies minimum # of arcs to pre-compute state. + void Init(const Fst<A> &fst, bool copy = false) { + if (copy) { + delete fst_; + } else if (fst_) { + FSTERROR() << "CacheLogAccumulator: initialization error."; + error_ = true; + return; + } + fst_ = fst.Copy(); + } + + void SetState(StateId s, int depth = 0) { + if (s == s_) + return; + s_ = s; + + if (data_->CacheDisabled() || error_) { + weights_ = 0; + return; + } + + if (!fst_) { + FSTERROR() << "CacheLogAccumulator::SetState: incorrectly initialized."; + error_ = true; + weights_ = 0; + return; + } + + weights_ = data_->GetWeights(s); + if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) { + weights_ = new vector<double>; + weights_->reserve(fst_->NumArcs(s) + 1); + weights_->push_back(FloatLimits<double>::PosInfinity()); + data_->AddWeights(s, weights_); + } + } + + Weight Sum(Weight w, Weight v) { + return LogPlus(w, v); + } + + template <class Iterator> + Weight Sum(Weight w, Iterator *aiter, ssize_t begin, + ssize_t end) { + if (weights_ == 0) { + Weight sum = w; + aiter->Seek(begin); + for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) + sum = LogPlus(sum, aiter->Value().weight); + return sum; + } else { + if (weights_->size() <= end) + for (aiter->Seek(weights_->size() - 1); + weights_->size() <= end; + aiter->Next()) + weights_->push_back(LogPlus(weights_->back(), + aiter->Value().weight)); + return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin])); + } + } + + template <class Iterator> + size_t LowerBound(double w, Iterator *aiter) { + if (weights_ != 0) { + return lower_bound(weights_->begin() + 1, + weights_->end(), + w, + std::greater<double>()) + - weights_->begin() - 1; + } else { + size_t n = 0; + double x = FloatLimits<double>::PosInfinity(); + for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) { + x = LogPlus(x, aiter->Value().weight); + if (x < w) break; + } + return n; + } + } + + bool Error() const { return error_; } + + private: + double LogPosExp(double x) { + return x == FloatLimits<double>::PosInfinity() ? + 0.0 : log(1.0F + exp(-x)); + } + + double LogMinusExp(double x) { + return x == FloatLimits<double>::PosInfinity() ? + 0.0 : log(1.0F - exp(-x)); + } + + Weight LogPlus(Weight w, Weight v) { + double f1 = to_log_weight_(w).Value(); + double f2 = to_log_weight_(v).Value(); + if (f1 > f2) + return to_weight_(f2 - LogPosExp(f1 - f2)); + else + return to_weight_(f1 - LogPosExp(f2 - f1)); + } + + double LogPlus(double f1, Weight v) { + double f2 = to_log_weight_(v).Value(); + if (f1 == FloatLimits<double>::PosInfinity()) + return f2; + else if (f1 > f2) + return f2 - LogPosExp(f1 - f2); + else + return f1 - LogPosExp(f2 - f1); + } + + Weight LogMinus(double f1, double f2) { + if (f1 >= f2) { + FSTERROR() << "CacheLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1 + << " and f2 = " << f2; + error_ = true; + return Weight::NoWeight(); + } + if (f2 == FloatLimits<double>::PosInfinity()) + return to_weight_(f1); + else + return to_weight_(f1 - LogMinusExp(f2 - f1)); + } + + WeightConvert<Weight, Log64Weight> to_log_weight_; + WeightConvert<Log64Weight, Weight> to_weight_; + + ssize_t arc_limit_; // Minimum # of arcs to cache a state + vector<double> *weights_; // Accumulated weights for cur. state + const Fst<A>* fst_; // Input fst + CacheLogAccumulatorData<A> *data_; // Cache data + StateId s_; // Current state + bool error_; + + void operator=(const CacheLogAccumulator<A> &); // Disallow +}; + + +// Stores shareable data for replace accumulator copies. +template <class Accumulator, class T> +class ReplaceAccumulatorData { + public: + typedef typename Accumulator::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef T StateTable; + typedef typename T::StateTuple StateTuple; + + ReplaceAccumulatorData() : state_table_(0) {} + + ReplaceAccumulatorData(const vector<Accumulator*> &accumulators) + : state_table_(0), accumulators_(accumulators) {} + + ~ReplaceAccumulatorData() { + for (size_t i = 0; i < fst_array_.size(); ++i) + delete fst_array_[i]; + for (size_t i = 0; i < accumulators_.size(); ++i) + delete accumulators_[i]; + } + + void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples, + const StateTable *state_table) { + state_table_ = state_table; + accumulators_.resize(fst_tuples.size()); + for (size_t i = 0; i < accumulators_.size(); ++i) { + if (!accumulators_[i]) + accumulators_[i] = new Accumulator; + accumulators_[i]->Init(*(fst_tuples[i].second)); + fst_array_.push_back(fst_tuples[i].second->Copy()); + } + } + + const StateTuple &GetTuple(StateId s) const { + return state_table_->Tuple(s); + } + + Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; } + + const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i]; } + + int RefCount() const { return ref_count_.count(); } + int IncrRefCount() { return ref_count_.Incr(); } + int DecrRefCount() { return ref_count_.Decr(); } + + private: + const T * state_table_; + vector<Accumulator*> accumulators_; + vector<const Fst<Arc>*> fst_array_; + RefCounter ref_count_; + + DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData); +}; + +// This class accumulates weights in a ReplaceFst. The 'Init' method +// takes as input the argument used to build the ReplaceFst and the +// ReplaceFst state table. It uses accumulators of type 'Accumulator' +// in the underlying FSTs. +template <class Accumulator, + class T = DefaultReplaceStateTable<typename Accumulator::Arc> > +class ReplaceAccumulator { + public: + typedef typename Accumulator::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef T StateTable; + typedef typename T::StateTuple StateTuple; + + ReplaceAccumulator() + : init_(false), data_(new ReplaceAccumulatorData<Accumulator, T>()), + error_(false) {} + + ReplaceAccumulator(const vector<Accumulator*> &accumulators) + : init_(false), + data_(new ReplaceAccumulatorData<Accumulator, T>(accumulators)), + error_(false) {} + + ReplaceAccumulator(const ReplaceAccumulator<Accumulator, T> &acc) + : init_(acc.init_), data_(acc.data_), error_(acc.error_) { + if (!init_) + FSTERROR() << "ReplaceAccumulator: can't copy unintialized accumulator"; + data_->IncrRefCount(); + } + + ~ReplaceAccumulator() { + if (!data_->DecrRefCount()) + delete data_; + } + + // Does not take ownership of the state table, the state table + // is own by the ReplaceFst + void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples, + const StateTable *state_table) { + init_ = true; + data_->Init(fst_tuples, state_table); + } + + void SetState(StateId s) { + if (!init_) { + FSTERROR() << "ReplaceAccumulator::SetState: incorrectly initialized."; + error_ = true; + return; + } + StateTuple tuple = data_->GetTuple(s); + fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based + data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state); + if ((tuple.prefix_id != 0) && + (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) { + offset_ = 1; + offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state); + } else { + offset_ = 0; + offset_weight_ = Weight::Zero(); + } + } + + Weight Sum(Weight w, Weight v) { + if (error_) return Weight::NoWeight(); + return data_->GetAccumulator(fst_id_)->Sum(w, v); + } + + template <class ArcIterator> + Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, + ssize_t end) { + if (error_) return Weight::NoWeight(); + Weight sum = begin == end ? Weight::Zero() + : data_->GetAccumulator(fst_id_)->Sum( + w, aiter, begin ? begin - offset_ : 0, end - offset_); + if (begin == 0 && end != 0 && offset_ > 0) + sum = Sum(offset_weight_, sum); + return sum; + } + + bool Error() const { return error_; } + + private: + bool init_; + ReplaceAccumulatorData<Accumulator, T> *data_; + Label fst_id_; + size_t offset_; + Weight offset_weight_; + bool error_; + + void operator=(const ReplaceAccumulator<Accumulator, T> &); // Disallow +}; + +} // namespace fst + +#endif // FST_LIB_ACCUMULATOR_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/add-on.h b/kaldi_io/src/tools/openfst/include/fst/add-on.h new file mode 100644 index 0000000..ee21a93 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/add-on.h @@ -0,0 +1,306 @@ +// add-on.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Fst implementation class to attach an arbitrary object with a +// read/write method to an FST and its file rep. The FST is given a +// new type name. + +#ifndef FST_LIB_ADD_ON_FST_H__ +#define FST_LIB_ADD_ON_FST_H__ + +#include <stddef.h> +#include <string> + +#include <fst/fst.h> + + +namespace fst { + +// Identifies stream data as an add-on fst. +static const int32 kAddOnMagicNumber = 446681434; + + +// +// Some useful add-on objects. +// + +// Nothing to save. +class NullAddOn { + public: + NullAddOn() {} + + static NullAddOn *Read(istream &istrm) { + return new NullAddOn(); + }; + + bool Write(ostream &ostrm) const { return true; } + + int RefCount() const { return ref_count_.count(); } + int IncrRefCount() { return ref_count_.Incr(); } + int DecrRefCount() { return ref_count_.Decr(); } + + private: + RefCounter ref_count_; + + DISALLOW_COPY_AND_ASSIGN(NullAddOn); +}; + + +// Create a new add-on from a pair of add-ons. +template <class A1, class A2> +class AddOnPair { + public: + // Argument reference count incremented. + AddOnPair(A1 *a1, A2 *a2) + : a1_(a1), a2_(a2) { + if (a1_) + a1_->IncrRefCount(); + if (a2_) + a2_->IncrRefCount(); + } + + ~AddOnPair() { + if (a1_ && !a1_->DecrRefCount()) + delete a1_; + if (a2_ && !a2_->DecrRefCount()) + delete a2_; + } + + A1 *First() const { return a1_; } + A2 *Second() const { return a2_; } + + static AddOnPair<A1, A2> *Read(istream &istrm) { + A1 *a1 = 0; + bool have_addon1 = false; + ReadType(istrm, &have_addon1); + if (have_addon1) + a1 = A1::Read(istrm); + + A2 *a2 = 0; + bool have_addon2 = false; + ReadType(istrm, &have_addon2); + if (have_addon2) + a2 = A2::Read(istrm); + + AddOnPair<A1, A2> *a = new AddOnPair<A1, A2>(a1, a2); + if (a1) + a1->DecrRefCount(); + if (a2) + a2->DecrRefCount(); + return a; + }; + + bool Write(ostream &ostrm) const { + bool have_addon1 = a1_; + WriteType(ostrm, have_addon1); + if (have_addon1) + a1_->Write(ostrm); + bool have_addon2 = a2_; + WriteType(ostrm, have_addon2); + if (have_addon2) + a2_->Write(ostrm); + return true; + } + + int RefCount() const { return ref_count_.count(); } + + int IncrRefCount() { + return ref_count_.Incr(); + } + + int DecrRefCount() { + return ref_count_.Decr(); + } + + private: + A1 *a1_; + A2 *a2_; + RefCounter ref_count_; + + DISALLOW_COPY_AND_ASSIGN(AddOnPair); +}; + + +// Add to an Fst F a type T object. T must have a 'T* Read(istream &)', +// a 'bool Write(ostream &)' method, and 'int RecCount(), 'int IncrRefCount()' +// and 'int DecrRefCount()' methods (e.g. 'MatcherData' in matcher-fst.h). +// The result is a new Fst implemenation with type name 'type'. +template<class F, class T> +class AddOnImpl : public FstImpl<typename F::Arc> { + public: + typedef typename F::Arc Arc; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + using FstImpl<Arc>::SetType; + using FstImpl<Arc>::SetProperties; + using FstImpl<Arc>::WriteHeader; + + // If 't' is non-zero, its reference count is incremented. + AddOnImpl(const F &fst, const string &type, T *t = 0) + : fst_(fst), t_(t) { + SetType(type); + SetProperties(fst_.Properties(kFstProperties, false)); + if (t_) + t_->IncrRefCount(); + } + + // If 't' is non-zero, its reference count is incremented. + AddOnImpl(const Fst<Arc> &fst, const string &type, T *t = 0) + : fst_(fst), t_(t) { + SetType(type); + SetProperties(fst_.Properties(kFstProperties, false)); + if (t_) + t_->IncrRefCount(); + } + + AddOnImpl(const AddOnImpl<F, T> &impl) + : fst_(impl.fst_), t_(impl.t_) { + SetType(impl.Type()); + SetProperties(fst_.Properties(kCopyProperties, false)); + if (t_) + t_->IncrRefCount(); + } + + ~AddOnImpl() { + if (t_ && !t_->DecrRefCount()) + delete t_; + } + + StateId Start() const { return fst_.Start(); } + Weight Final(StateId s) const { return fst_.Final(s); } + size_t NumArcs(StateId s) const { return fst_.NumArcs(s); } + + size_t NumInputEpsilons(StateId s) const { + return fst_.NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) const { + return fst_.NumOutputEpsilons(s); + } + + size_t NumStates() const { return fst_.NumStates(); } + + static AddOnImpl<F, T> *Read(istream &strm, const FstReadOptions &opts) { + FstReadOptions nopts(opts); + FstHeader hdr; + if (!nopts.header) { + hdr.Read(strm, nopts.source); + nopts.header = &hdr; + } + AddOnImpl<F, T> *impl = new AddOnImpl<F, T>(nopts.header->FstType()); + if (!impl->ReadHeader(strm, nopts, kMinFileVersion, &hdr)) + return 0; + delete impl; // Used here only for checking types. + + int32 magic_number = 0; + ReadType(strm, &magic_number); // Ensures this is an add-on Fst. + if (magic_number != kAddOnMagicNumber) { + LOG(ERROR) << "AddOnImpl::Read: Bad add-on header: " << nopts.source; + return 0; + } + + FstReadOptions fopts(opts); + fopts.header = 0; // Contained header was written out. + F *fst = F::Read(strm, fopts); + if (!fst) + return 0; + + T *t = 0; + bool have_addon = false; + ReadType(strm, &have_addon); + if (have_addon) { // Read add-on object if present. + t = T::Read(strm); + if (!t) + return 0; + } + impl = new AddOnImpl<F, T>(*fst, nopts.header->FstType(), t); + delete fst; + if (t) + t->DecrRefCount(); + return impl; + } + + bool Write(ostream &strm, const FstWriteOptions &opts) const { + FstHeader hdr; + FstWriteOptions nopts(opts); + nopts.write_isymbols = false; // Let contained FST hold any symbols. + nopts.write_osymbols = false; + WriteHeader(strm, nopts, kFileVersion, &hdr); + WriteType(strm, kAddOnMagicNumber); // Ensures this is an add-on Fst. + FstWriteOptions fopts(opts); + fopts.write_header = true; // Force writing contained header. + if (!fst_.Write(strm, fopts)) + return false; + bool have_addon = t_; + WriteType(strm, have_addon); + if (have_addon) // Write add-on object if present. + t_->Write(strm); + return true; + } + + void InitStateIterator(StateIteratorData<Arc> *data) const { + fst_.InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + fst_.InitArcIterator(s, data); + } + + F &GetFst() { return fst_; } + + const F &GetFst() const { return fst_; } + + T *GetAddOn() const { return t_; } + + // If 't' is non-zero, its reference count is incremented. + void SetAddOn(T *t) { + if (t == t_) + return; + if (t_ && !t_->DecrRefCount()) + delete t_; + t_ = t; + if (t_) + t_->IncrRefCount(); + } + + private: + explicit AddOnImpl(const string &type) : t_(0) { + SetType(type); + SetProperties(kExpanded); + } + + // Current file format version + static const int kFileVersion = 1; + // Minimum file format version supported + static const int kMinFileVersion = 1; + + F fst_; + T *t_; + + void operator=(const AddOnImpl<F, T> &fst); // Disallow +}; + +template <class F, class T> const int AddOnImpl<F, T>::kFileVersion; +template <class F, class T> const int AddOnImpl<F, T>::kMinFileVersion; + + +} // namespace fst + +#endif // FST_LIB_ADD_ON_FST_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/arc-map.h b/kaldi_io/src/tools/openfst/include/fst/arc-map.h new file mode 100644 index 0000000..914f81c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/arc-map.h @@ -0,0 +1,1146 @@ +// arc-map.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to map over/transform arcs e.g., change semirings or +// implement project/invert. Consider using when operation does +// not change the number of arcs (except possibly superfinal arcs). + +#ifndef FST_LIB_ARC_MAP_H__ +#define FST_LIB_ARC_MAP_H__ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <string> +#include <utility> +using std::pair; using std::make_pair; + +#include <fst/cache.h> +#include <fst/mutable-fst.h> + + +namespace fst { + +// This determines how final weights are mapped. +enum MapFinalAction { + // A final weight is mapped into a final weight. An error + // is raised if this is not possible. + MAP_NO_SUPERFINAL, + + // A final weight is mapped to an arc to the superfinal state + // when the result cannot be represented as a final weight. + // The superfinal state will be added only if it is needed. + MAP_ALLOW_SUPERFINAL, + + // A final weight is mapped to an arc to the superfinal state + // unless the result can be represented as a final weight of weight + // Zero(). The superfinal state is always added (if the input is + // not the empty Fst). + MAP_REQUIRE_SUPERFINAL +}; + +// This determines how symbol tables are mapped. +enum MapSymbolsAction { + // Symbols should be cleared in the result by the map. + MAP_CLEAR_SYMBOLS, + + // Symbols should be copied from the input FST by the map. + MAP_COPY_SYMBOLS, + + // Symbols should not be modified in the result by the map itself. + // (They may set by the mapper). + MAP_NOOP_SYMBOLS +}; + +// ArcMapper Interface - class determinies how arcs and final weights +// are mapped. Useful for implementing operations that do not change +// the number of arcs (expect possibly superfinal arcs). +// +// class ArcMapper { +// public: +// typedef A FromArc; +// typedef B ToArc; +// +// // Maps an arc type A to arc type B. +// B operator()(const A &arc); +// // Specifies final action the mapper requires (see above). +// // The mapper will be passed final weights as arcs of the +// // form A(0, 0, weight, kNoStateId). +// MapFinalAction FinalAction() const; +// // Specifies input symbol table action the mapper requires (see above). +// MapSymbolsAction InputSymbolsAction() const; +// // Specifies output symbol table action the mapper requires (see above). +// MapSymbolsAction OutputSymbolsAction() const; +// // This specifies the known properties of an Fst mapped by this +// // mapper. It takes as argument the input Fst's known properties. +// uint64 Properties(uint64 props) const; +// }; +// +// The ArcMap functions and classes below will use the FinalAction() +// method of the mapper to determine how to treat final weights, +// e.g. whether to add a superfinal state. They will use the Properties() +// method to set the result Fst properties. +// +// We include a various map versions below. One dimension of +// variation is whether the mapping mutates its input, writes to a +// new result Fst, or is an on-the-fly Fst. Another dimension is how +// we pass the mapper. We allow passing the mapper by pointer +// for cases that we need to change the state of the user's mapper. +// This is the case with the encode mapper, which is reused during +// decoding. We also include map versions that pass the mapper +// by value or const reference when this suffices. + + +// Maps an arc type A using a mapper function object C, passed +// by pointer. This version modifies its Fst input. +template<class A, class C> +void ArcMap(MutableFst<A> *fst, C* mapper) { + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) + fst->SetInputSymbols(0); + + if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) + fst->SetOutputSymbols(0); + + if (fst->Start() == kNoStateId) + return; + + uint64 props = fst->Properties(kFstProperties, false); + + MapFinalAction final_action = mapper->FinalAction(); + StateId superfinal = kNoStateId; + if (final_action == MAP_REQUIRE_SUPERFINAL) { + superfinal = fst->AddState(); + fst->SetFinal(superfinal, Weight::One()); + } + + for (StateId s = 0; s < fst->NumStates(); ++s) { + for (MutableArcIterator< MutableFst<A> > aiter(fst, s); + !aiter.Done(); aiter.Next()) { + const A &arc = aiter.Value(); + aiter.SetValue((*mapper)(arc)); + } + + switch (final_action) { + case MAP_NO_SUPERFINAL: + default: { + A final_arc = (*mapper)(A(0, 0, fst->Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + FSTERROR() << "ArcMap: non-zero arc labels for superfinal arc"; + fst->SetProperties(kError, kError); + } + + fst->SetFinal(s, final_arc.weight); + break; + } + case MAP_ALLOW_SUPERFINAL: { + if (s != superfinal) { + A final_arc = (*mapper)(A(0, 0, fst->Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + // Add a superfinal state if not already done. + if (superfinal == kNoStateId) { + superfinal = fst->AddState(); + fst->SetFinal(superfinal, Weight::One()); + } + final_arc.nextstate = superfinal; + fst->AddArc(s, final_arc); + fst->SetFinal(s, Weight::Zero()); + } else { + fst->SetFinal(s, final_arc.weight); + } + break; + } + } + case MAP_REQUIRE_SUPERFINAL: { + if (s != superfinal) { + A final_arc = (*mapper)(A(0, 0, fst->Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0 || + final_arc.weight != Weight::Zero()) + fst->AddArc(s, A(final_arc.ilabel, final_arc.olabel, + final_arc.weight, superfinal)); + fst->SetFinal(s, Weight::Zero()); + } + break; + } + } + } + fst->SetProperties(mapper->Properties(props), kFstProperties); +} + + +// Maps an arc type A using a mapper function object C, passed +// by value. This version modifies its Fst input. +template<class A, class C> +void ArcMap(MutableFst<A> *fst, C mapper) { + ArcMap(fst, &mapper); +} + + +// Maps an arc type A to an arc type B using mapper function +// object C, passed by pointer. This version writes the mapped +// input Fst to an output MutableFst. +template<class A, class B, class C> +void ArcMap(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) { + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + ofst->DeleteStates(); + + if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) + ofst->SetInputSymbols(ifst.InputSymbols()); + else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) + ofst->SetInputSymbols(0); + + if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) + ofst->SetOutputSymbols(ifst.OutputSymbols()); + else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) + ofst->SetOutputSymbols(0); + + uint64 iprops = ifst.Properties(kCopyProperties, false); + + if (ifst.Start() == kNoStateId) { + if (iprops & kError) ofst->SetProperties(kError, kError); + return; + } + + MapFinalAction final_action = mapper->FinalAction(); + if (ifst.Properties(kExpanded, false)) { + ofst->ReserveStates(CountStates(ifst) + + final_action == MAP_NO_SUPERFINAL ? 0 : 1); + } + + // Add all states. + for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) + ofst->AddState(); + + StateId superfinal = kNoStateId; + if (final_action == MAP_REQUIRE_SUPERFINAL) { + superfinal = ofst->AddState(); + ofst->SetFinal(superfinal, B::Weight::One()); + } + for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + if (s == ifst.Start()) + ofst->SetStart(s); + + ofst->ReserveArcs(s, ifst.NumArcs(s)); + for (ArcIterator< Fst<A> > aiter(ifst, s); !aiter.Done(); aiter.Next()) + ofst->AddArc(s, (*mapper)(aiter.Value())); + + switch (final_action) { + case MAP_NO_SUPERFINAL: + default: { + B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + FSTERROR() << "ArcMap: non-zero arc labels for superfinal arc"; + ofst->SetProperties(kError, kError); + } + ofst->SetFinal(s, final_arc.weight); + break; + } + case MAP_ALLOW_SUPERFINAL: { + B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + // Add a superfinal state if not already done. + if (superfinal == kNoStateId) { + superfinal = ofst->AddState(); + ofst->SetFinal(superfinal, B::Weight::One()); + } + final_arc.nextstate = superfinal; + ofst->AddArc(s, final_arc); + ofst->SetFinal(s, B::Weight::Zero()); + } else { + ofst->SetFinal(s, final_arc.weight); + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0 || + final_arc.weight != B::Weight::Zero()) + ofst->AddArc(s, B(final_arc.ilabel, final_arc.olabel, + final_arc.weight, superfinal)); + ofst->SetFinal(s, B::Weight::Zero()); + break; + } + } + } + uint64 oprops = ofst->Properties(kFstProperties, false); + ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties); +} + +// Maps an arc type A to an arc type B using mapper function +// object C, passed by value. This version writes the mapped input +// Fst to an output MutableFst. +template<class A, class B, class C> +void ArcMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) { + ArcMap(ifst, ofst, &mapper); +} + + +struct ArcMapFstOptions : public CacheOptions { + // ArcMapFst default caching behaviour is to do no caching. Most + // mappers are cheap and therefore we save memory by not doing + // caching. + ArcMapFstOptions() : CacheOptions(true, 0) {} + ArcMapFstOptions(const CacheOptions& opts) : CacheOptions(opts) {} +}; + + +template <class A, class B, class C> class ArcMapFst; + +// Implementation of delayed ArcMapFst. +template <class A, class B, class C> +class ArcMapFstImpl : public CacheImpl<B> { + public: + using FstImpl<B>::SetType; + using FstImpl<B>::SetProperties; + using FstImpl<B>::SetInputSymbols; + using FstImpl<B>::SetOutputSymbols; + + using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates; + + using CacheImpl<B>::PushArc; + using CacheImpl<B>::HasArcs; + using CacheImpl<B>::HasFinal; + using CacheImpl<B>::HasStart; + using CacheImpl<B>::SetArcs; + using CacheImpl<B>::SetFinal; + using CacheImpl<B>::SetStart; + + friend class StateIterator< ArcMapFst<A, B, C> >; + + typedef B Arc; + typedef typename B::Weight Weight; + typedef typename B::StateId StateId; + + ArcMapFstImpl(const Fst<A> &fst, const C &mapper, + const ArcMapFstOptions& opts) + : CacheImpl<B>(opts), + fst_(fst.Copy()), + mapper_(new C(mapper)), + own_mapper_(true), + superfinal_(kNoStateId), + nstates_(0) { + Init(); + } + + ArcMapFstImpl(const Fst<A> &fst, C *mapper, + const ArcMapFstOptions& opts) + : CacheImpl<B>(opts), + fst_(fst.Copy()), + mapper_(mapper), + own_mapper_(false), + superfinal_(kNoStateId), + nstates_(0) { + Init(); + } + + ArcMapFstImpl(const ArcMapFstImpl<A, B, C> &impl) + : CacheImpl<B>(impl), + fst_(impl.fst_->Copy(true)), + mapper_(new C(*impl.mapper_)), + own_mapper_(true), + superfinal_(kNoStateId), + nstates_(0) { + Init(); + } + + ~ArcMapFstImpl() { + delete fst_; + if (own_mapper_) delete mapper_; + } + + StateId Start() { + if (!HasStart()) + SetStart(FindOState(fst_->Start())); + return CacheImpl<B>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + switch (final_action_) { + case MAP_NO_SUPERFINAL: + default: { + B final_arc = (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), + kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + FSTERROR() << "ArcMapFst: non-zero arc labels for superfinal arc"; + SetProperties(kError, kError); + } + SetFinal(s, final_arc.weight); + break; + } + case MAP_ALLOW_SUPERFINAL: { + if (s == superfinal_) { + SetFinal(s, Weight::One()); + } else { + B final_arc = (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), + kNoStateId)); + if (final_arc.ilabel == 0 && final_arc.olabel == 0) + SetFinal(s, final_arc.weight); + else + SetFinal(s, Weight::Zero()); + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + SetFinal(s, s == superfinal_ ? Weight::One() : Weight::Zero()); + break; + } + } + } + return CacheImpl<B>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<B>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<B>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<B>::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && (fst_->Properties(kError, false) || + (mapper_->Properties(0) & kError))) + SetProperties(kError, kError); + return FstImpl<Arc>::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData<B> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<B>::InitArcIterator(s, data); + } + + void Expand(StateId s) { + // Add exiting arcs. + if (s == superfinal_) { SetArcs(s); return; } + + for (ArcIterator< Fst<A> > aiter(*fst_, FindIState(s)); + !aiter.Done(); aiter.Next()) { + A aarc(aiter.Value()); + aarc.nextstate = FindOState(aarc.nextstate); + const B& barc = (*mapper_)(aarc); + PushArc(s, barc); + } + + // Check for superfinal arcs. + if (!HasFinal(s) || Final(s) == Weight::Zero()) + switch (final_action_) { + case MAP_NO_SUPERFINAL: + default: + break; + case MAP_ALLOW_SUPERFINAL: { + B final_arc = (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), + kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + if (superfinal_ == kNoStateId) + superfinal_ = nstates_++; + final_arc.nextstate = superfinal_; + PushArc(s, final_arc); + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + B final_arc = (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), + kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0 || + final_arc.weight != B::Weight::Zero()) + PushArc(s, B(final_arc.ilabel, final_arc.olabel, + final_arc.weight, superfinal_)); + break; + } + } + SetArcs(s); + } + + private: + void Init() { + SetType("map"); + + if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) + SetInputSymbols(fst_->InputSymbols()); + else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) + SetInputSymbols(0); + + if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) + SetOutputSymbols(fst_->OutputSymbols()); + else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) + SetOutputSymbols(0); + + if (fst_->Start() == kNoStateId) { + final_action_ = MAP_NO_SUPERFINAL; + SetProperties(kNullProperties); + } else { + final_action_ = mapper_->FinalAction(); + uint64 props = fst_->Properties(kCopyProperties, false); + SetProperties(mapper_->Properties(props)); + if (final_action_ == MAP_REQUIRE_SUPERFINAL) + superfinal_ = 0; + } + } + + // Maps from output state to input state. + StateId FindIState(StateId s) { + if (superfinal_ == kNoStateId || s < superfinal_) + return s; + else + return s - 1; + } + + // Maps from input state to output state. + StateId FindOState(StateId is) { + StateId os; + if (superfinal_ == kNoStateId || is < superfinal_) + os = is; + else + os = is + 1; + + if (os >= nstates_) + nstates_ = os + 1; + + return os; + } + + + const Fst<A> *fst_; + C* mapper_; + bool own_mapper_; + MapFinalAction final_action_; + + StateId superfinal_; + StateId nstates_; + + void operator=(const ArcMapFstImpl<A, B, C> &); // disallow +}; + + +// Maps an arc type A to an arc type B using Mapper function object +// C. This version is a delayed Fst. +template <class A, class B, class C> +class ArcMapFst : public ImplToFst< ArcMapFstImpl<A, B, C> > { + public: + friend class ArcIterator< ArcMapFst<A, B, C> >; + friend class StateIterator< ArcMapFst<A, B, C> >; + + typedef B Arc; + typedef typename B::Weight Weight; + typedef typename B::StateId StateId; + typedef CacheState<B> State; + typedef ArcMapFstImpl<A, B, C> Impl; + + ArcMapFst(const Fst<A> &fst, const C &mapper, const ArcMapFstOptions& opts) + : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {} + + ArcMapFst(const Fst<A> &fst, C* mapper, const ArcMapFstOptions& opts) + : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {} + + ArcMapFst(const Fst<A> &fst, const C &mapper) + : ImplToFst<Impl>(new Impl(fst, mapper, ArcMapFstOptions())) {} + + ArcMapFst(const Fst<A> &fst, C* mapper) + : ImplToFst<Impl>(new Impl(fst, mapper, ArcMapFstOptions())) {} + + // See Fst<>::Copy() for doc. + ArcMapFst(const ArcMapFst<A, B, C> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this ArcMapFst. See Fst<>::Copy() for further doc. + virtual ArcMapFst<A, B, C> *Copy(bool safe = false) const { + return new ArcMapFst<A, B, C>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<B> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const ArcMapFst<A, B, C> &fst); // disallow +}; + + +// Specialization for ArcMapFst. +template<class A, class B, class C> +class StateIterator< ArcMapFst<A, B, C> > : public StateIteratorBase<B> { + public: + typedef typename B::StateId StateId; + + explicit StateIterator(const ArcMapFst<A, B, C> &fst) + : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0), + superfinal_(impl_->final_action_ == MAP_REQUIRE_SUPERFINAL) + { CheckSuperfinal(); } + + bool Done() const { return siter_.Done() && !superfinal_; } + + StateId Value() const { return s_; } + + void Next() { + ++s_; + if (!siter_.Done()) { + siter_.Next(); + CheckSuperfinal(); + } + else if (superfinal_) + superfinal_ = false; + } + + void Reset() { + s_ = 0; + siter_.Reset(); + superfinal_ = impl_->final_action_ == MAP_REQUIRE_SUPERFINAL; + CheckSuperfinal(); + } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + bool Done_() const { return Done(); } + StateId Value_() const { return Value(); } + void Next_() { Next(); } + void Reset_() { Reset(); } + + void CheckSuperfinal() { + if (impl_->final_action_ != MAP_ALLOW_SUPERFINAL || superfinal_) + return; + if (!siter_.Done()) { + B final_arc = (*impl_->mapper_)(A(0, 0, impl_->fst_->Final(s_), + kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) + superfinal_ = true; + } + } + + const ArcMapFstImpl<A, B, C> *impl_; + StateIterator< Fst<A> > siter_; + StateId s_; + bool superfinal_; // true if there is a superfinal state and not done + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + + +// Specialization for ArcMapFst. +template <class A, class B, class C> +class ArcIterator< ArcMapFst<A, B, C> > + : public CacheArcIterator< ArcMapFst<A, B, C> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const ArcMapFst<A, B, C> &fst, StateId s) + : CacheArcIterator< ArcMapFst<A, B, C> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +template <class A, class B, class C> inline +void ArcMapFst<A, B, C>::InitStateIterator(StateIteratorData<B> *data) + const { + data->base = new StateIterator< ArcMapFst<A, B, C> >(*this); +} + + +// +// Utility Mappers +// + +// Mapper that returns its input. +template <class A> +struct IdentityArcMapper { + typedef A FromArc; + typedef A ToArc; + + A operator()(const A &arc) const { return arc; } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { return props; } +}; + + +// Mapper that returns its input with final states redirected to +// a single super-final state. +template <class A> +struct SuperFinalMapper { + typedef A FromArc; + typedef A ToArc; + + A operator()(const A &arc) const { return arc; } + + MapFinalAction FinalAction() const { return MAP_REQUIRE_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { + return props & kAddSuperFinalProperties; + } +}; + + +// Mapper that leaves labels and nextstate unchanged and constructs a new weight +// from the underlying value of the arc weight. Requires that there is a +// WeightConvert class specialization that converts the weights. +template <class A, class B> +class WeightConvertMapper { + public: + typedef A FromArc; + typedef B ToArc; + typedef typename FromArc::Weight FromWeight; + typedef typename ToArc::Weight ToWeight; + + ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, + convert_weight_(arc.weight), arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { return props; } + + private: + WeightConvert<FromWeight, ToWeight> convert_weight_; +}; + +// Non-precision-changing weight conversions. +// Consider using more efficient Cast (fst.h) instead. +typedef WeightConvertMapper<StdArc, LogArc> StdToLogMapper; +typedef WeightConvertMapper<LogArc, StdArc> LogToStdMapper; + +// Precision-changing weight conversions. +typedef WeightConvertMapper<StdArc, Log64Arc> StdToLog64Mapper; +typedef WeightConvertMapper<LogArc, Log64Arc> LogToLog64Mapper; +typedef WeightConvertMapper<Log64Arc, StdArc> Log64ToStdMapper; +typedef WeightConvertMapper<Log64Arc, LogArc> Log64ToLogMapper; + +// Mapper from A to GallicArc<A>. +template <class A, StringType S = STRING_LEFT> +struct ToGallicMapper { + typedef A FromArc; + typedef GallicArc<A, S> ToArc; + + typedef StringWeight<typename A::Label, S> SW; + typedef typename A::Weight AW; + typedef typename GallicArc<A, S>::Weight GW; + + ToArc operator()(const A &arc) const { + // 'Super-final' arc. + if (arc.nextstate == kNoStateId && arc.weight != AW::Zero()) + return ToArc(0, 0, GW(SW::One(), arc.weight), kNoStateId); + // 'Super-non-final' arc. + else if (arc.nextstate == kNoStateId) + return ToArc(0, 0, GW(SW::Zero(), arc.weight), kNoStateId); + // Epsilon label. + else if (arc.olabel == 0) + return ToArc(arc.ilabel, arc.ilabel, + GW(SW::One(), arc.weight), arc.nextstate); + // Regular label. + else + return ToArc(arc.ilabel, arc.ilabel, + GW(SW(arc.olabel), arc.weight), arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;} + + uint64 Properties(uint64 props) const { + return ProjectProperties(props, true) & kWeightInvariantProperties; + } +}; + + +// Mapper from GallicArc<A> to A. +template <class A, StringType S = STRING_LEFT> +struct FromGallicMapper { + typedef GallicArc<A, S> FromArc; + typedef A ToArc; + + typedef typename A::Label Label; + typedef StringWeight<Label, S> SW; + typedef typename A::Weight AW; + typedef typename GallicArc<A, S>::Weight GW; + + FromGallicMapper(Label superfinal_label = 0) + : superfinal_label_(superfinal_label), error_(false) {} + + A operator()(const FromArc &arc) const { + // 'Super-non-final' arc. + if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) + return A(arc.ilabel, 0, AW::Zero(), kNoStateId); + + SW w1 = arc.weight.Value1(); + AW w2 = arc.weight.Value2(); + StringWeightIterator<Label, S> iter1(w1); + + Label l = w1.Size() == 1 ? iter1.Value() : 0; + + if (l == kStringInfinity || l == kStringBad || + arc.ilabel != arc.olabel || w1.Size() > 1) { + FSTERROR() << "FromGallicMapper: unrepesentable weight"; + error_ = true; + } + + if (arc.ilabel == 0 && l != 0 && arc.nextstate == kNoStateId) + return A(superfinal_label_, l, w2, arc.nextstate); + else + return A(arc.ilabel, l, w2, arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;} + + uint64 Properties(uint64 inprops) const { + uint64 outprops = inprops & kOLabelInvariantProperties & + kWeightInvariantProperties & kAddSuperFinalProperties; + if (error_) + outprops |= kError; + return outprops; + } + + private: + Label superfinal_label_; + mutable bool error_; +}; + + +// Mapper from GallicArc<A> to A. +template <class A, StringType S = STRING_LEFT> +struct GallicToNewSymbolsMapper { + typedef GallicArc<A, S> FromArc; + typedef A ToArc; + + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef StringWeight<Label, S> SW; + typedef typename A::Weight AW; + typedef typename GallicArc<A, S>::Weight GW; + + GallicToNewSymbolsMapper(MutableFst<ToArc> *fst) + : fst_(fst), lmax_(0), osymbols_(fst->OutputSymbols()), + isymbols_(0), error_(false) { + fst_->DeleteStates(); + state_ = fst_->AddState(); + fst_->SetStart(state_); + fst_->SetFinal(state_, AW::One()); + if (osymbols_) { + string name = osymbols_->Name() + "_from_gallic"; + fst_->SetInputSymbols(new SymbolTable(name)); + isymbols_ = fst_->MutableInputSymbols(); + isymbols_->AddSymbol(osymbols_->Find((int64) 0), 0); + } else { + fst_->SetInputSymbols(0); + } + } + + A operator()(const FromArc &arc) { + // 'Super-non-final' arc. + if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) + return A(arc.ilabel, 0, AW::Zero(), kNoStateId); + + SW w1 = arc.weight.Value1(); + AW w2 = arc.weight.Value2(); + Label l; + + if (w1.Size() == 0) { + l = 0; + } else { + typename Map::iterator miter = map_.find(w1); + if (miter != map_.end()) { + l = (*miter).second; + } else { + l = ++lmax_; + map_.insert(pair<const SW, Label>(w1, l)); + StringWeightIterator<Label, S> iter1(w1); + StateId n; + string s; + for(size_t i = 0, p = state_; + i < w1.Size(); + ++i, iter1.Next(), p = n) { + n = i == w1.Size() - 1 ? state_ : fst_->AddState(); + fst_->AddArc(p, ToArc(i ? 0 : l, iter1.Value(), AW::One(), n)); + if (isymbols_) { + if (i) s = s + "_"; + s = s + osymbols_->Find(iter1.Value()); + } + } + if (isymbols_) + isymbols_->AddSymbol(s, l); + } + } + + if (l == kStringInfinity || l == kStringBad || arc.ilabel != arc.olabel) { + FSTERROR() << "GallicToNewSymbolMapper: unrepesentable weight"; + error_ = true; + } + + return A(arc.ilabel, l, w2, arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; } + + uint64 Properties(uint64 inprops) const { + uint64 outprops = inprops & kOLabelInvariantProperties & + kWeightInvariantProperties & kAddSuperFinalProperties; + if (error_) + outprops |= kError; + return outprops; + } + + private: + class StringKey { + public: + size_t operator()(const SW &x) const { + return x.Hash(); + } + }; + + typedef unordered_map<SW, Label, StringKey> Map; + + MutableFst<ToArc> *fst_; + Map map_; + Label lmax_; + StateId state_; + const SymbolTable *osymbols_; + SymbolTable *isymbols_; + mutable bool error_; + + DISALLOW_COPY_AND_ASSIGN(GallicToNewSymbolsMapper); +}; + + +// Mapper to add a constant to all weights. +template <class A> +struct PlusMapper { + typedef A FromArc; + typedef A ToArc; + typedef typename A::Weight Weight; + + explicit PlusMapper(Weight w) : weight_(w) {} + + A operator()(const A &arc) const { + if (arc.weight == Weight::Zero()) + return arc; + Weight w = Plus(arc.weight, weight_); + return A(arc.ilabel, arc.olabel, w, arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + + + + Weight weight_; +}; + + +// Mapper to (right) multiply a constant to all weights. +template <class A> +struct TimesMapper { + typedef A FromArc; + typedef A ToArc; + typedef typename A::Weight Weight; + + explicit TimesMapper(Weight w) : weight_(w) {} + + A operator()(const A &arc) const { + if (arc.weight == Weight::Zero()) + return arc; + Weight w = Times(arc.weight, weight_); + return A(arc.ilabel, arc.olabel, w, arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + Weight weight_; +}; + + +// Mapper to reciprocate all non-Zero() weights. +template <class A> +struct InvertWeightMapper { + typedef A FromArc; + typedef A ToArc; + typedef typename A::Weight Weight; + + A operator()(const A &arc) const { + if (arc.weight == Weight::Zero()) + return arc; + Weight w = Divide(Weight::One(), arc.weight); + return A(arc.ilabel, arc.olabel, w, arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } +}; + + +// Mapper to map all non-Zero() weights to One(). +template <class A, class B = A> +struct RmWeightMapper { + typedef A FromArc; + typedef B ToArc; + typedef typename FromArc::Weight FromWeight; + typedef typename ToArc::Weight ToWeight; + + B operator()(const A &arc) const { + ToWeight w = arc.weight != FromWeight::Zero() ? + ToWeight::One() : ToWeight::Zero(); + return B(arc.ilabel, arc.olabel, w, arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { + return (props & kWeightInvariantProperties) | kUnweighted; + } +}; + + +// Mapper to quantize all weights. +template <class A, class B = A> +struct QuantizeMapper { + typedef A FromArc; + typedef B ToArc; + typedef typename FromArc::Weight FromWeight; + typedef typename ToArc::Weight ToWeight; + + QuantizeMapper() : delta_(kDelta) {} + + explicit QuantizeMapper(float d) : delta_(d) {} + + B operator()(const A &arc) const { + ToWeight w = arc.weight.Quantize(delta_); + return B(arc.ilabel, arc.olabel, w, arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + float delta_; +}; + + +// Mapper from A to B under the assumption: +// B::Weight = A::Weight::ReverseWeight +// B::Label == A::Label +// B::StateId == A::StateId +// The weight is reversed, while the label and nextstate preserved +// in the mapping. +template <class A, class B> +struct ReverseWeightMapper { + typedef A FromArc; + typedef B ToArc; + + B operator()(const A &arc) const { + return B(arc.ilabel, arc.olabel, arc.weight.Reverse(), arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { return props; } +}; + +} // namespace fst + +#endif // FST_LIB_ARC_MAP_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/arc.h b/kaldi_io/src/tools/openfst/include/fst/arc.h new file mode 100644 index 0000000..5f4014b --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/arc.h @@ -0,0 +1,307 @@ +// arc.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// +// Commonly used Fst arc types. + +#ifndef FST_LIB_ARC_H__ +#define FST_LIB_ARC_H__ + +#include <string> + + +#include <fst/expectation-weight.h> +#include <fst/float-weight.h> +#include <fst/lexicographic-weight.h> +#include <fst/power-weight.h> +#include <fst/product-weight.h> +#include <fst/signed-log-weight.h> +#include <fst/sparse-power-weight.h> +#include <iostream> +#include <fstream> +#include <sstream> +#include <fst/string-weight.h> + + +namespace fst { + +template <class W> +class ArcTpl { + public: + typedef W Weight; + typedef int Label; + typedef int StateId; + + ArcTpl(Label i, Label o, const Weight& w, StateId s) + : ilabel(i), olabel(o), weight(w), nextstate(s) {} + + ArcTpl() {} + + static const string &Type(void) { + static const string type = + (Weight::Type() == "tropical") ? "standard" : Weight::Type(); + return type; + } + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; +}; + +typedef ArcTpl<TropicalWeight> StdArc; +typedef ArcTpl<LogWeight> LogArc; +typedef ArcTpl<Log64Weight> Log64Arc; +typedef ArcTpl<SignedLogWeight> SignedLogArc; +typedef ArcTpl<SignedLog64Weight> SignedLog64Arc; +typedef ArcTpl<MinMaxWeight> MinMaxArc; + + +// Arc with integer labels and state Ids and string weights. +template <StringType S = STRING_LEFT> +class StringArc { + public: + typedef int Label; + typedef StringWeight<int, S> Weight; + typedef int StateId; + + StringArc(Label i, Label o, Weight w, StateId s) + : ilabel(i), olabel(o), weight(w), nextstate(s) {} + + StringArc() {} + + static const string &Type() { // Arc type name + static const string type = + S == STRING_LEFT ? "standard_string" : + (S == STRING_RIGHT ? "right_standard_string" : + (S == STRING_LEFT_RESTRICT ? "restricted_string" : + "right_restricted_string")); + return type; + } + + Label ilabel; // Transition input label + Label olabel; // Transition output label + Weight weight; // Transition weight + StateId nextstate; // Transition destination state +}; + + +// Arc with label and state Id type the same as template arg and with +// weights over the Gallic semiring w.r.t the output labels and weights of A. +template <class A, StringType S = STRING_LEFT> +struct GallicArc { + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef GallicWeight<Label, typename A::Weight, S> Weight; + + GallicArc() {} + + GallicArc(Label i, Label o, Weight w, StateId s) + : ilabel(i), olabel(o), weight(w), nextstate(s) {} + + GallicArc(const A &arc) + : ilabel(arc.ilabel), olabel(arc.ilabel), + weight(arc.olabel, arc.weight), nextstate(arc.nextstate) {} + + static const string &Type() { // Arc type name + static const string type = + (S == STRING_LEFT ? "gallic_" : + (S == STRING_RIGHT ? "right_gallic_" : + (S == STRING_LEFT_RESTRICT ? "restricted_gallic_" : + "right_restricted_gallic_"))) + A::Type(); + return type; + } + + Label ilabel; // Transition input label + Label olabel; // Transition output label + Weight weight; // Transition weight + StateId nextstate; // Transition destination state +}; + + +// Arc with the reverse of the weight found in its template arg. +template <class A> struct ReverseArc { + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight AWeight; + typedef typename AWeight::ReverseWeight Weight; + typedef typename A::StateId StateId; + + ReverseArc(Label i, Label o, Weight w, StateId s) + : ilabel(i), olabel(o), weight(w), nextstate(s) {} + + ReverseArc() {} + + static const string &Type() { // Arc type name + static const string type = "reverse_" + Arc::Type(); + return type; + } + + Label ilabel; // Transition input label + Label olabel; // Transition output label + Weight weight; // Transition weight + StateId nextstate; // Transition destination state +}; + + +// Arc with integer labels and state Ids and lexicographic weights. +template<class W1, class W2> +struct LexicographicArc { + typedef int Label; + typedef LexicographicWeight<W1, W2> Weight; + typedef int StateId; + + LexicographicArc(Label i, Label o, Weight w, StateId s) + : ilabel(i), olabel(o), weight(w), nextstate(s) {} + + LexicographicArc() {} + + static const string &Type() { // Arc type name + static const string type = Weight::Type(); + return type; + } + + Label ilabel; // Transition input label + Label olabel; // Transition output label + Weight weight; // Transition weight + StateId nextstate; // Transition destination state +}; + + +// Arc with integer labels and state Ids and product weights. +template<class W1, class W2> +struct ProductArc { + typedef int Label; + typedef ProductWeight<W1, W2> Weight; + typedef int StateId; + + ProductArc(Label i, Label o, Weight w, StateId s) + : ilabel(i), olabel(o), weight(w), nextstate(s) {} + + ProductArc() {} + + static const string &Type() { // Arc type name + static const string type = Weight::Type(); + return type; + } + + Label ilabel; // Transition input label + Label olabel; // Transition output label + Weight weight; // Transition weight + StateId nextstate; // Transition destination state +}; + + +// Arc with label and state Id type the same as first template arg and with +// weights over the n-th cartesian power of the weight type of the +// template arg. +template <class A, unsigned int n> +struct PowerArc { + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef PowerWeight<typename A::Weight, n> Weight; + + PowerArc() {} + + PowerArc(Label i, Label o, Weight w, StateId s) + : ilabel(i), olabel(o), weight(w), nextstate(s) {} + + static const string &Type() { // Arc type name + static string type; + if (type.empty()) { + string power; + Int64ToStr(n, &power); + type = A::Type() + "_^" + power; + } + return type; + } + + Label ilabel; // Transition input label + Label olabel; // Transition output label + Weight weight; // Transition weight + StateId nextstate; // Transition destination state +}; + + +// Arc with label and state Id type the same as first template arg and with +// weights over the arbitrary cartesian power of the weight type. +template <class A, class K = int> +struct SparsePowerArc { + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef SparsePowerWeight<typename A::Weight, K> Weight; + + SparsePowerArc() {} + + SparsePowerArc(Label i, Label o, Weight w, StateId s) + : ilabel(i), olabel(o), weight(w), nextstate(s) {} + + static const string &Type() { // Arc type name + static string type; + if (type.empty()) { type = A::Type() + "_^n"; } + if(sizeof(K) != sizeof(uint32)) { + string size; + Int64ToStr(8 * sizeof(K), &size); + type += "_" + size; + } + return type; + } + + Label ilabel; // Transition input label + Label olabel; // Transition output label + Weight weight; // Transition weight + StateId nextstate; // Transition destination state +}; + + +// Arc with label and state Id type the same as first template arg and with +// expectation weight over the first template arg weight type and the +// second template arg. +template <class A, class X2> +struct ExpectationArc { + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight X1; + typedef ExpectationWeight<X1, X2> Weight; + + ExpectationArc() {} + + ExpectationArc(Label i, Label o, Weight w, StateId s) + : ilabel(i), olabel(o), weight(w), nextstate(s) {} + + static const string &Type() { // Arc type name + static string type; + if (type.empty()) { + type = "expectation_" + A::Type() + "_" + X2::Type(); + } + return type; + } + + Label ilabel; // Transition input label + Label olabel; // Transition output label + Weight weight; // Transition weight + StateId nextstate; // Transition destination state +}; + +} // namespace fst + +#endif // FST_LIB_ARC_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/arcfilter.h b/kaldi_io/src/tools/openfst/include/fst/arcfilter.h new file mode 100644 index 0000000..179dc2c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/arcfilter.h @@ -0,0 +1,99 @@ +// arcfilter.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Function objects to restrict which arcs are traversed in an FST. + +#ifndef FST_LIB_ARCFILTER_H__ +#define FST_LIB_ARCFILTER_H__ + + +#include <fst/fst.h> +#include <fst/util.h> + + +namespace fst { + +// True for all arcs. +template <class A> +class AnyArcFilter { +public: + bool operator()(const A &arc) const { return true; } +}; + + +// True for (input/output) epsilon arcs. +template <class A> +class EpsilonArcFilter { +public: + bool operator()(const A &arc) const { + return arc.ilabel == 0 && arc.olabel == 0; + } +}; + + +// True for input epsilon arcs. +template <class A> +class InputEpsilonArcFilter { +public: + bool operator()(const A &arc) const { + return arc.ilabel == 0; + } +}; + + +// True for output epsilon arcs. +template <class A> +class OutputEpsilonArcFilter { +public: + bool operator()(const A &arc) const { + return arc.olabel == 0; + } +}; + + +// True if specified labels match (don't match) when keep_match is +// true (false). +template <class A> +class MultiLabelArcFilter { +public: + typedef typename A::Label Label; + + MultiLabelArcFilter(bool match_input = true, bool keep_match = true) + : match_input_(match_input), + keep_match_(keep_match) {} + + + bool operator()(const A &arc) const { + Label label = match_input_ ? arc.ilabel : arc.olabel; + bool match = labels_.Find(label) != labels_.End(); + return keep_match_ ? match : !match; + } + + void AddLabel(Label label) { + labels_.Insert(label); + } + +private: + CompactSet<Label, kNoLabel> labels_; + bool match_input_; + bool keep_match_; +}; + +} // namespace fst + +#endif // FST_LIB_ARCFILTER_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/arcsort.h b/kaldi_io/src/tools/openfst/include/fst/arcsort.h new file mode 100644 index 0000000..37a51dc --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/arcsort.h @@ -0,0 +1,217 @@ +// arcsort.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Functions and classes to sort arcs in an FST. + +#ifndef FST_LIB_ARCSORT_H__ +#define FST_LIB_ARCSORT_H__ + +#include <algorithm> +#include <string> +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/state-map.h> +#include <fst/test-properties.h> + + +namespace fst { + +template <class Arc, class Compare> +class ArcSortMapper { + public: + typedef Arc FromArc; + typedef Arc ToArc; + + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + ArcSortMapper(const Fst<Arc> &fst, const Compare &comp) + : fst_(fst), comp_(comp), i_(0) {} + + // Allows updating Fst argument; pass only if changed. + ArcSortMapper(const ArcSortMapper<Arc, Compare> &mapper, + const Fst<Arc> *fst = 0) + : fst_(fst ? *fst : mapper.fst_), comp_(mapper.comp_), i_(0) {} + + StateId Start() { return fst_.Start(); } + Weight Final(StateId s) const { return fst_.Final(s); } + + void SetState(StateId s) { + i_ = 0; + arcs_.clear(); + arcs_.reserve(fst_.NumArcs(s)); + for (ArcIterator< Fst<Arc> > aiter(fst_, s); !aiter.Done(); aiter.Next()) + arcs_.push_back(aiter.Value()); + sort(arcs_.begin(), arcs_.end(), comp_); + } + + bool Done() const { return i_ >= arcs_.size(); } + const Arc &Value() const { return arcs_[i_]; } + void Next() { ++i_; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + uint64 Properties(uint64 props) const { return comp_.Properties(props); } + + private: + const Fst<Arc> &fst_; + const Compare &comp_; + vector<Arc> arcs_; + ssize_t i_; // current arc position + + void operator=(const ArcSortMapper<Arc, Compare> &); // disallow +}; + + +// Sorts the arcs in an FST according to function object 'comp' of +// type Compare. This version modifies its input. Comparison function +// objects ILabelCompare and OLabelCompare are provived by the +// library. In general, Compare must meet the requirements for an STL +// sort comparision function object. It must also have a member +// Properties(uint64) that specifies the known properties of the +// sorted FST; it takes as argument the input FST's known properties +// before the sort. +// +// Complexity: +// - Time: O(V D log D) +// - Space: O(D) +// where V = # of states and D = maximum out-degree. +template<class Arc, class Compare> +void ArcSort(MutableFst<Arc> *fst, Compare comp) { + ArcSortMapper<Arc, Compare> mapper(*fst, comp); + StateMap(fst, mapper); +} + +typedef CacheOptions ArcSortFstOptions; + +// Sorts the arcs in an FST according to function object 'comp' of +// type Compare. This version is a delayed Fst. Comparsion function +// objects ILabelCompare and OLabelCompare are provided by the +// library. In general, Compare must meet the requirements for an STL +// comparision function object (e.g. as used for STL sort). It must +// also have a member Properties(uint64) that specifies the known +// properties of the sorted FST; it takes as argument the input FST's +// known properties. +// +// Complexity: +// - Time: O(v d log d) +// - Space: O(d) +// where v = # of states visited, d = maximum out-degree of states +// visited. Constant time and space to visit an input state is assumed +// and exclusive of caching. +template <class A, class C> +class ArcSortFst : public StateMapFst<A, A, ArcSortMapper<A, C> > { + using StateMapFst<A, A, ArcSortMapper<A, C> >::GetImpl; + public: + typedef A Arc; + typedef typename Arc::StateId StateId; + typedef ArcSortMapper<A, C> M; + + ArcSortFst(const Fst<A> &fst, const C &comp) + : StateMapFst<A, A, M>(fst, ArcSortMapper<A, C>(fst, comp)) {} + + ArcSortFst(const Fst<A> &fst, const C &comp, const ArcSortFstOptions &opts) + : StateMapFst<A, A, M>(fst, ArcSortMapper<A, C>(fst, comp), opts) {} + + // See Fst<>::Copy() for doc. + ArcSortFst(const ArcSortFst<A, C> &fst, bool safe = false) + : StateMapFst<A, A, M>(fst, safe) {} + + // Get a copy of this ArcSortFst. See Fst<>::Copy() for further doc. + virtual ArcSortFst<A, C> *Copy(bool safe = false) const { + return new ArcSortFst(*this, safe); + } + + virtual size_t NumArcs(StateId s) const { + return GetImpl()->GetFst().NumArcs(s); + } + + virtual size_t NumInputEpsilons(StateId s) const { + return GetImpl()->GetFst().NumInputEpsilons(s); + } + + virtual size_t NumOutputEpsilons(StateId s) const { + return GetImpl()->GetFst().NumOutputEpsilons(s); + } +}; + + +// Specialization for ArcSortFst. +template <class A, class C> +class StateIterator< ArcSortFst<A, C> > + : public StateIterator< StateMapFst<A, A, ArcSortMapper<A, C> > > { + public: + explicit StateIterator(const ArcSortFst<A, C> &fst) + : StateIterator< StateMapFst<A, A, ArcSortMapper<A, C> > >(fst) {} +}; + + +// Specialization for ArcSortFst. +template <class A, class C> +class ArcIterator< ArcSortFst<A, C> > + : public ArcIterator< StateMapFst<A, A, ArcSortMapper<A, C> > > { + public: + ArcIterator(const ArcSortFst<A, C> &fst, typename A::StateId s) + : ArcIterator< StateMapFst<A, A, ArcSortMapper<A, C> > >(fst, s) {} +}; + + +// Compare class for comparing input labels of arcs. +template<class A> class ILabelCompare { + public: + bool operator() (A arc1, A arc2) const { + return arc1.ilabel < arc2.ilabel; + } + + uint64 Properties(uint64 props) const { + return (props & kArcSortProperties) | kILabelSorted | + (props & kAcceptor ? kOLabelSorted : 0); + } +}; + + +// Compare class for comparing output labels of arcs. +template<class A> class OLabelCompare { + public: + bool operator() (const A &arc1, const A &arc2) const { + return arc1.olabel < arc2.olabel; + } + + uint64 Properties(uint64 props) const { + return (props & kArcSortProperties) | kOLabelSorted | + (props & kAcceptor ? kILabelSorted : 0); + } +}; + + +// Useful aliases when using StdArc. +template<class C> class StdArcSortFst : public ArcSortFst<StdArc, C> { + public: + typedef StdArc Arc; + typedef C Compare; +}; + +typedef ILabelCompare<StdArc> StdILabelCompare; + +typedef OLabelCompare<StdArc> StdOLabelCompare; + +} // namespace fst + +#endif // FST_LIB_ARCSORT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/bi-table.h b/kaldi_io/src/tools/openfst/include/fst/bi-table.h new file mode 100644 index 0000000..d220ce4 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/bi-table.h @@ -0,0 +1,532 @@ +// bi-table.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes for representing a bijective mapping between an arbitrary entry +// of type T and a signed integral ID. + +#ifndef FST_LIB_BI_TABLE_H__ +#define FST_LIB_BI_TABLE_H__ + +#include <deque> +using std::deque; +#include <functional> +#include <vector> +using std::vector; + +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; + +namespace fst { + +// BI TABLES - these determine a bijective mapping between an +// arbitrary entry of type T and an signed integral ID of type I. The IDs are +// allocated starting from 0 in order. +// +// template <class I, class T> +// class BiTable { +// public: +// +// // Required constructors. +// BiTable(); +// +// // Lookup integer ID from entry. If it doesn't exist and 'insert' +// / is true, then add it. Otherwise return -1. +// I FindId(const T &entry, bool insert = true); +// // Lookup entry from integer ID. +// const T &FindEntry(I) const; +// // # of stored entries. +// I Size() const; +// }; + +// An implementation using a hash map for the entry to ID mapping. +// H is the hash function and E is the equality function. +// If passed to the constructor, ownership is given to this class. + +template <class I, class T, class H, class E = std::equal_to<T> > +class HashBiTable { + public: + // Reserves space for 'table_size' elements. + explicit HashBiTable(size_t table_size = 0, H *h = 0, E *e = 0) + : hash_func_(h), + hash_equal_(e), + entry2id_(table_size, (h ? *h : H()), (e ? *e : E())) { + if (table_size) + id2entry_.reserve(table_size); + } + + HashBiTable(const HashBiTable<I, T, H, E> &table) + : hash_func_(table.hash_func_ ? new H(*table.hash_func_) : 0), + hash_equal_(table.hash_equal_ ? new E(*table.hash_equal_) : 0), + entry2id_(table.entry2id_.begin(), table.entry2id_.end(), + table.entry2id_.size(), + (hash_func_ ? *hash_func_ : H()), + (hash_equal_ ? *hash_equal_ : E())), + id2entry_(table.id2entry_) { } + + ~HashBiTable() { + delete hash_func_; + delete hash_equal_; + } + + I FindId(const T &entry, bool insert = true) { + I &id_ref = entry2id_[entry]; + if (id_ref == 0) { // T not found + if (insert) { // store and assign it a new ID + id2entry_.push_back(entry); + id_ref = id2entry_.size(); + } else { + return -1; + } + } + return id_ref - 1; // NB: id_ref = ID + 1 + } + + const T &FindEntry(I s) const { + return id2entry_[s]; + } + + I Size() const { return id2entry_.size(); } + + private: + H *hash_func_; + E *hash_equal_; + unordered_map<T, I, H, E> entry2id_; + vector<T> id2entry_; + + void operator=(const HashBiTable<I, T, H, E> &table); // disallow +}; + + +// Enables alternative hash set representations below. +// typedef enum { HS_STL = 0, HS_DENSE = 1, HS_SPARSE = 2 } HSType; +typedef enum { HS_STL = 0, HS_DENSE = 1, HS_SPARSE = 2 } HSType; + +// Default hash set is STL hash_set +template<class K, class H, class E, HSType> +struct HashSet : public unordered_set<K, H, E> { + HashSet(size_t n = 0, const H &h = H(), const E &e = E()) + : unordered_set<K, H, E>(n, h, e) { } + + void rehash(size_t n) { } +}; + + +// An implementation using a hash set for the entry to ID mapping. +// The hash set holds 'keys' which are either the ID or kCurrentKey. +// These keys can be mapped to entrys either by looking up in the +// entry vector or, if kCurrentKey, in current_entry_ member. The hash +// and key equality functions map to entries first. H +// is the hash function and E is the equality function. If passed to +// the constructor, ownership is given to this class. +template <class I, class T, class H, + class E = std::equal_to<T>, HSType HS = HS_DENSE> +class CompactHashBiTable { + public: + friend class HashFunc; + friend class HashEqual; + + // Reserves space for 'table_size' elements. + explicit CompactHashBiTable(size_t table_size = 0, H *h = 0, E *e = 0) + : hash_func_(h), + hash_equal_(e), + compact_hash_func_(*this), + compact_hash_equal_(*this), + keys_(table_size, compact_hash_func_, compact_hash_equal_) { + if (table_size) + id2entry_.reserve(table_size); + } + + CompactHashBiTable(const CompactHashBiTable<I, T, H, E, HS> &table) + : hash_func_(table.hash_func_ ? new H(*table.hash_func_) : 0), + hash_equal_(table.hash_equal_ ? new E(*table.hash_equal_) : 0), + compact_hash_func_(*this), + compact_hash_equal_(*this), + keys_(table.keys_.size(), compact_hash_func_, compact_hash_equal_), + id2entry_(table.id2entry_) { + keys_.insert(table.keys_.begin(), table.keys_.end()); + } + + ~CompactHashBiTable() { + delete hash_func_; + delete hash_equal_; + } + + I FindId(const T &entry, bool insert = true) { + current_entry_ = &entry; + typename KeyHashSet::const_iterator it = keys_.find(kCurrentKey); + if (it == keys_.end()) { // T not found + if (insert) { // store and assign it a new ID + I key = id2entry_.size(); + id2entry_.push_back(entry); + keys_.insert(key); + return key; + } else { + return -1; + } + } else { + return *it; + } + } + + const T &FindEntry(I s) const { return id2entry_[s]; } + + I Size() const { return id2entry_.size(); } + + // Clear content. With argument, erases last n IDs. + void Clear(ssize_t n = -1) { + if (n < 0 || n > id2entry_.size()) + n = id2entry_.size(); + while (n-- > 0) { + I key = id2entry_.size() - 1; + keys_.erase(key); + id2entry_.pop_back(); + } + keys_.rehash(0); + } + + private: + static const I kCurrentKey; // -1 + static const I kEmptyKey; // -2 + static const I kDeletedKey; // -3 + + class HashFunc { + public: + HashFunc(const CompactHashBiTable &ht) : ht_(&ht) {} + + size_t operator()(I k) const { + if (k >= kCurrentKey) { + return (*ht_->hash_func_)(ht_->Key2Entry(k)); + } else { + return 0; + } + } + + private: + const CompactHashBiTable *ht_; + }; + + class HashEqual { + public: + HashEqual(const CompactHashBiTable &ht) : ht_(&ht) {} + + bool operator()(I k1, I k2) const { + if (k1 >= kCurrentKey && k2 >= kCurrentKey) { + return (*ht_->hash_equal_)(ht_->Key2Entry(k1), ht_->Key2Entry(k2)); + } else { + return k1 == k2; + } + } + private: + const CompactHashBiTable *ht_; + }; + + typedef HashSet<I, HashFunc, HashEqual, HS> KeyHashSet; + + const T &Key2Entry(I k) const { + if (k == kCurrentKey) + return *current_entry_; + else + return id2entry_[k]; + } + + H *hash_func_; + E *hash_equal_; + HashFunc compact_hash_func_; + HashEqual compact_hash_equal_; + KeyHashSet keys_; + vector<T> id2entry_; + const T *current_entry_; + + void operator=(const CompactHashBiTable<I, T, H, E, HS> &table); // disallow +}; + + +template <class I, class T, class H, class E, HSType HS> +const I CompactHashBiTable<I, T, H, E, HS>::kCurrentKey = -1; + +template <class I, class T, class H, class E, HSType HS> +const I CompactHashBiTable<I, T, H, E, HS>::kEmptyKey = -2; + +template <class I, class T, class H, class E, HSType HS> +const I CompactHashBiTable<I, T, H, E, HS>::kDeletedKey = -3; + + +// An implementation using a vector for the entry to ID mapping. +// It is passed a function object FP that should fingerprint entries +// uniquely to an integer that can used as a vector index. Normally, +// VectorBiTable constructs the FP object. The user can instead +// pass in this object; in that case, VectorBiTable takes its +// ownership. +template <class I, class T, class FP> +class VectorBiTable { + public: + // Reserves space for 'table_size' elements. + explicit VectorBiTable(FP *fp = 0, size_t table_size = 0) + : fp_(fp ? fp : new FP()) { + if (table_size) + id2entry_.reserve(table_size); + } + + VectorBiTable(const VectorBiTable<I, T, FP> &table) + : fp_(table.fp_ ? new FP(*table.fp_) : 0), + fp2id_(table.fp2id_), + id2entry_(table.id2entry_) { } + + ~VectorBiTable() { delete fp_; } + + I FindId(const T &entry, bool insert = true) { + ssize_t fp = (*fp_)(entry); + if (fp >= fp2id_.size()) + fp2id_.resize(fp + 1); + I &id_ref = fp2id_[fp]; + if (id_ref == 0) { // T not found + if (insert) { // store and assign it a new ID + id2entry_.push_back(entry); + id_ref = id2entry_.size(); + } else { + return -1; + } + } + return id_ref - 1; // NB: id_ref = ID + 1 + } + + const T &FindEntry(I s) const { return id2entry_[s]; } + + I Size() const { return id2entry_.size(); } + + const FP &Fingerprint() const { return *fp_; } + + private: + FP *fp_; + vector<I> fp2id_; + vector<T> id2entry_; + + void operator=(const VectorBiTable<I, T, FP> &table); // disallow +}; + + +// An implementation using a vector and a compact hash table. The +// selecting functor S returns true for entries to be hashed in the +// vector. The fingerprinting functor FP returns a unique fingerprint +// for each entry to be hashed in the vector (these need to be +// suitable for indexing in a vector). The hash functor H is used +// when hashing entry into the compact hash table. If passed to the +// constructor, ownership is given to this class. +template <class I, class T, class S, class FP, class H, HSType HS = HS_DENSE> +class VectorHashBiTable { + public: + friend class HashFunc; + friend class HashEqual; + + explicit VectorHashBiTable(S *s, FP *fp = 0, H *h = 0, + size_t vector_size = 0, + size_t entry_size = 0) + : selector_(s), + fp_(fp ? fp : new FP()), + h_(h ? h : new H()), + hash_func_(*this), + hash_equal_(*this), + keys_(0, hash_func_, hash_equal_) { + if (vector_size) + fp2id_.reserve(vector_size); + if (entry_size) + id2entry_.reserve(entry_size); + } + + VectorHashBiTable(const VectorHashBiTable<I, T, S, FP, H, HS> &table) + : selector_(new S(table.s_)), + fp_(table.fp_ ? new FP(*table.fp_) : 0), + h_(table.h_ ? new H(*table.h_) : 0), + id2entry_(table.id2entry_), + fp2id_(table.fp2id_), + hash_func_(*this), + hash_equal_(*this), + keys_(table.keys_.size(), hash_func_, hash_equal_) { + keys_.insert(table.keys_.begin(), table.keys_.end()); + } + + ~VectorHashBiTable() { + delete selector_; + delete fp_; + delete h_; + } + + I FindId(const T &entry, bool insert = true) { + if ((*selector_)(entry)) { // Use the vector if 'selector_(entry) == true' + uint64 fp = (*fp_)(entry); + if (fp2id_.size() <= fp) + fp2id_.resize(fp + 1, 0); + if (fp2id_[fp] == 0) { // T not found + if (insert) { // store and assign it a new ID + id2entry_.push_back(entry); + fp2id_[fp] = id2entry_.size(); + } else { + return -1; + } + } + return fp2id_[fp] - 1; // NB: assoc_value = ID + 1 + } else { // Use the hash table otherwise. + current_entry_ = &entry; + typename KeyHashSet::const_iterator it = keys_.find(kCurrentKey); + if (it == keys_.end()) { + if (insert) { + I key = id2entry_.size(); + id2entry_.push_back(entry); + keys_.insert(key); + return key; + } else { + return -1; + } + } else { + return *it; + } + } + } + + const T &FindEntry(I s) const { + return id2entry_[s]; + } + + I Size() const { return id2entry_.size(); } + + const S &Selector() const { return *selector_; } + + const FP &Fingerprint() const { return *fp_; } + + const H &Hash() const { return *h_; } + + private: + static const I kCurrentKey; // -1 + static const I kEmptyKey; // -2 + + class HashFunc { + public: + HashFunc(const VectorHashBiTable &ht) : ht_(&ht) {} + + size_t operator()(I k) const { + if (k >= kCurrentKey) { + return (*(ht_->h_))(ht_->Key2Entry(k)); + } else { + return 0; + } + } + private: + const VectorHashBiTable *ht_; + }; + + class HashEqual { + public: + HashEqual(const VectorHashBiTable &ht) : ht_(&ht) {} + + bool operator()(I k1, I k2) const { + if (k1 >= kCurrentKey && k2 >= kCurrentKey) { + return ht_->Key2Entry(k1) == ht_->Key2Entry(k2); + } else { + return k1 == k2; + } + } + private: + const VectorHashBiTable *ht_; + }; + + typedef HashSet<I, HashFunc, HashEqual, HS> KeyHashSet; + + const T &Key2Entry(I k) const { + if (k == kCurrentKey) + return *current_entry_; + else + return id2entry_[k]; + } + + S *selector_; // Returns true if entry hashed into vector + FP *fp_; // Fingerprint used when hashing entry into vector + H *h_; // Hash function used when hashing entry into hash_set + + vector<T> id2entry_; // Maps state IDs to entry + vector<I> fp2id_; // Maps entry fingerprints to IDs + + // Compact implementation of the hash table mapping entrys to + // state IDs using the hash function 'h_' + HashFunc hash_func_; + HashEqual hash_equal_; + KeyHashSet keys_; + const T *current_entry_; + + // disallow + void operator=(const VectorHashBiTable<I, T, S, FP, H, HS> &table); +}; + +template <class I, class T, class S, class FP, class H, HSType HS> +const I VectorHashBiTable<I, T, S, FP, H, HS>::kCurrentKey = -1; + +template <class I, class T, class S, class FP, class H, HSType HS> +const I VectorHashBiTable<I, T, S, FP, H, HS>::kEmptyKey = -3; + + +// An implementation using a hash map for the entry to ID +// mapping. This version permits erasing of arbitrary states. The +// entry T must have == defined and its default constructor must +// produce a entry that will never be seen. F is the hash function. +template <class I, class T, class F> +class ErasableBiTable { + public: + ErasableBiTable() : first_(0) {} + + I FindId(const T &entry, bool insert = true) { + I &id_ref = entry2id_[entry]; + if (id_ref == 0) { // T not found + if (insert) { // store and assign it a new ID + id2entry_.push_back(entry); + id_ref = id2entry_.size() + first_; + } else { + return -1; + } + } + return id_ref - 1; // NB: id_ref = ID + 1 + } + + const T &FindEntry(I s) const { return id2entry_[s - first_]; } + + I Size() const { return id2entry_.size(); } + + void Erase(I s) { + T &entry = id2entry_[s - first_]; + typename unordered_map<T, I, F>::iterator it = + entry2id_.find(entry); + entry2id_.erase(it); + id2entry_[s - first_] = empty_entry_; + while (!id2entry_.empty() && id2entry_.front() == empty_entry_) { + id2entry_.pop_front(); + ++first_; + } + } + + private: + unordered_map<T, I, F> entry2id_; + deque<T> id2entry_; + const T empty_entry_; + I first_; // I of first element in the deque; + + // disallow + void operator=(const ErasableBiTable<I, T, F> &table); //disallow +}; + +} // namespace fst + +#endif // FST_LIB_BI_TABLE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/cache.h b/kaldi_io/src/tools/openfst/include/fst/cache.h new file mode 100644 index 0000000..7c96fe1 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/cache.h @@ -0,0 +1,861 @@ +// cache.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// An Fst implementation that caches FST elements of a delayed +// computation. + +#ifndef FST_LIB_CACHE_H__ +#define FST_LIB_CACHE_H__ + +#include <vector> +using std::vector; +#include <list> + +#include <fst/vector-fst.h> + + +DECLARE_bool(fst_default_cache_gc); +DECLARE_int64(fst_default_cache_gc_limit); + +namespace fst { + +struct CacheOptions { + bool gc; // enable GC + size_t gc_limit; // # of bytes allowed before GC + + CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {} + CacheOptions() + : gc(FLAGS_fst_default_cache_gc), + gc_limit(FLAGS_fst_default_cache_gc_limit) {} +}; + +// A CacheStateAllocator allocates and frees CacheStates +// template <class S> +// struct CacheStateAllocator { +// S *Allocate(StateId s); +// void Free(S *state, StateId s); +// }; +// + +// A simple allocator class, can be overridden as needed, +// maintains a single entry cache. +template <class S> +struct DefaultCacheStateAllocator { + typedef typename S::Arc::StateId StateId; + + DefaultCacheStateAllocator() : mru_(NULL) { } + + ~DefaultCacheStateAllocator() { + delete mru_; + } + + S *Allocate(StateId s) { + if (mru_) { + S *state = mru_; + mru_ = NULL; + state->Reset(); + return state; + } + return new S(); + } + + void Free(S *state, StateId s) { + if (mru_) { + delete mru_; + } + mru_ = state; + } + + private: + S *mru_; +}; + +// VectorState but additionally has a flags data member (see +// CacheState below). This class is used to cache FST elements with +// the flags used to indicate what has been cached. Use HasStart() +// HasFinal(), and HasArcs() to determine if cached and SetStart(), +// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note +// you must set the final weight even if the state is non-final to +// mark it as cached. If the 'gc' option is 'false', cached items have +// the extent of the FST - minimizing computation. If the 'gc' option +// is 'true', garbage collection of states (not in use in an arc +// iterator and not 'protected') is performed, in a rough +// approximation of LRU order, when 'gc_limit' bytes is reached - +// controlling memory use. When 'gc_limit' is 0, special optimizations +// apply - minimizing memory use. + +template <class S, class C = DefaultCacheStateAllocator<S> > +class CacheBaseImpl : public VectorFstBaseImpl<S> { + public: + typedef S State; + typedef C Allocator; + typedef typename State::Arc Arc; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + using FstImpl<Arc>::Type; + using FstImpl<Arc>::Properties; + using FstImpl<Arc>::SetProperties; + using VectorFstBaseImpl<State>::NumStates; + using VectorFstBaseImpl<State>::Start; + using VectorFstBaseImpl<State>::AddState; + using VectorFstBaseImpl<State>::SetState; + using VectorFstBaseImpl<State>::ReserveStates; + + explicit CacheBaseImpl(C *allocator = 0) + : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0), + cache_first_state_id_(kNoStateId), cache_first_state_(0), + cache_gc_(FLAGS_fst_default_cache_gc), cache_size_(0), + cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit || + FLAGS_fst_default_cache_gc_limit == 0 ? + FLAGS_fst_default_cache_gc_limit : kMinCacheLimit), + protect_(false) { + allocator_ = allocator ? allocator : new C(); + } + + explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0) + : cache_start_(false), nknown_states_(0), + min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId), + cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0), + cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ? + opts.gc_limit : kMinCacheLimit), + protect_(false) { + allocator_ = allocator ? allocator : new C(); + } + + // Preserve gc parameters. If preserve_cache true, also preserves + // cache data. + CacheBaseImpl(const CacheBaseImpl<S, C> &impl, bool preserve_cache = false) + : VectorFstBaseImpl<S>(), cache_start_(false), nknown_states_(0), + min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId), + cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0), + cache_limit_(impl.cache_limit_), + protect_(impl.protect_) { + allocator_ = new C(); + if (preserve_cache) { + cache_start_ = impl.cache_start_; + nknown_states_ = impl.nknown_states_; + expanded_states_ = impl.expanded_states_; + min_unexpanded_state_id_ = impl.min_unexpanded_state_id_; + if (impl.cache_first_state_id_ != kNoStateId) { + cache_first_state_id_ = impl.cache_first_state_id_; + cache_first_state_ = allocator_->Allocate(cache_first_state_id_); + *cache_first_state_ = *impl.cache_first_state_; + } + cache_states_ = impl.cache_states_; + cache_size_ = impl.cache_size_; + ReserveStates(impl.NumStates()); + for (StateId s = 0; s < impl.NumStates(); ++s) { + const S *state = + static_cast<const VectorFstBaseImpl<S> &>(impl).GetState(s); + if (state) { + S *copied_state = allocator_->Allocate(s); + *copied_state = *state; + AddState(copied_state); + } else { + AddState(0); + } + } + VectorFstBaseImpl<S>::SetStart(impl.Start()); + } + } + + ~CacheBaseImpl() { + allocator_->Free(cache_first_state_, cache_first_state_id_); + delete allocator_; + } + + // Gets a state from its ID; state must exist. + const S *GetState(StateId s) const { + if (s == cache_first_state_id_) + return cache_first_state_; + else + return VectorFstBaseImpl<S>::GetState(s); + } + + // Gets a state from its ID; state must exist. + S *GetState(StateId s) { + if (s == cache_first_state_id_) + return cache_first_state_; + else + return VectorFstBaseImpl<S>::GetState(s); + } + + // Gets a state from its ID; return 0 if it doesn't exist. + const S *CheckState(StateId s) const { + if (s == cache_first_state_id_) + return cache_first_state_; + else if (s < NumStates()) + return VectorFstBaseImpl<S>::GetState(s); + else + return 0; + } + + // Gets a state from its ID; add it if necessary. + S *ExtendState(StateId s); + + void SetStart(StateId s) { + VectorFstBaseImpl<S>::SetStart(s); + cache_start_ = true; + if (s >= nknown_states_) + nknown_states_ = s + 1; + } + + void SetFinal(StateId s, Weight w) { + S *state = ExtendState(s); + state->final = w; + state->flags |= kCacheFinal | kCacheRecent | kCacheModified; + } + + // AddArc adds a single arc to state s and does incremental cache + // book-keeping. For efficiency, prefer PushArc and SetArcs below + // when possible. + void AddArc(StateId s, const Arc &arc) { + S *state = ExtendState(s); + state->arcs.push_back(arc); + if (arc.ilabel == 0) { + ++state->niepsilons; + } + if (arc.olabel == 0) { + ++state->noepsilons; + } + const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back()); + SetProperties(AddArcProperties(Properties(), s, arc, parc)); + state->flags |= kCacheModified; + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { + cache_size_ += sizeof(Arc); + if (cache_size_ > cache_limit_) + GC(s, false); + } + } + + // Adds a single arc to state s but delays cache book-keeping. + // SetArcs must be called when all PushArc calls at a state are + // complete. Do not mix with calls to AddArc. + void PushArc(StateId s, const Arc &arc) { + S *state = ExtendState(s); + state->arcs.push_back(arc); + } + + // Marks arcs of state s as cached and does cache book-keeping after all + // calls to PushArc have been completed. Do not mix with calls to AddArc. + void SetArcs(StateId s) { + S *state = ExtendState(s); + vector<Arc> &arcs = state->arcs; + state->niepsilons = state->noepsilons = 0; + for (size_t a = 0; a < arcs.size(); ++a) { + const Arc &arc = arcs[a]; + if (arc.nextstate >= nknown_states_) + nknown_states_ = arc.nextstate + 1; + if (arc.ilabel == 0) + ++state->niepsilons; + if (arc.olabel == 0) + ++state->noepsilons; + } + ExpandedState(s); + state->flags |= kCacheArcs | kCacheRecent | kCacheModified; + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { + cache_size_ += arcs.capacity() * sizeof(Arc); + if (cache_size_ > cache_limit_) + GC(s, false); + } + }; + + void ReserveArcs(StateId s, size_t n) { + S *state = ExtendState(s); + state->arcs.reserve(n); + } + + void DeleteArcs(StateId s, size_t n) { + S *state = ExtendState(s); + const vector<Arc> &arcs = state->arcs; + for (size_t i = 0; i < n; ++i) { + size_t j = arcs.size() - i - 1; + if (arcs[j].ilabel == 0) + --state->niepsilons; + if (arcs[j].olabel == 0) + --state->noepsilons; + } + + state->arcs.resize(arcs.size() - n); + SetProperties(DeleteArcsProperties(Properties())); + state->flags |= kCacheModified; + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { + cache_size_ -= n * sizeof(Arc); + } + } + + void DeleteArcs(StateId s) { + S *state = ExtendState(s); + size_t n = state->arcs.size(); + state->niepsilons = 0; + state->noepsilons = 0; + state->arcs.clear(); + SetProperties(DeleteArcsProperties(Properties())); + state->flags |= kCacheModified; + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { + cache_size_ -= n * sizeof(Arc); + } + } + + void DeleteStates(const vector<StateId> &dstates) { + size_t old_num_states = NumStates(); + vector<StateId> newid(old_num_states, 0); + for (size_t i = 0; i < dstates.size(); ++i) + newid[dstates[i]] = kNoStateId; + StateId nstates = 0; + for (StateId s = 0; s < old_num_states; ++s) { + if (newid[s] != kNoStateId) { + newid[s] = nstates; + ++nstates; + } + } + // just for states_.resize(), does unnecessary walk. + VectorFstBaseImpl<S>::DeleteStates(dstates); + SetProperties(DeleteStatesProperties(Properties())); + // Update list of cached states. + typename list<StateId>::iterator siter = cache_states_.begin(); + while (siter != cache_states_.end()) { + if (newid[*siter] != kNoStateId) { + *siter = newid[*siter]; + ++siter; + } else { + cache_states_.erase(siter++); + } + } + } + + void DeleteStates() { + cache_states_.clear(); + allocator_->Free(cache_first_state_, cache_first_state_id_); + for (int s = 0; s < NumStates(); ++s) { + allocator_->Free(VectorFstBaseImpl<S>::GetState(s), s); + SetState(s, 0); + } + nknown_states_ = 0; + min_unexpanded_state_id_ = 0; + cache_first_state_id_ = kNoStateId; + cache_first_state_ = 0; + cache_size_ = 0; + cache_start_ = false; + VectorFstBaseImpl<State>::DeleteStates(); + SetProperties(DeleteAllStatesProperties(Properties(), + kExpanded | kMutable)); + } + + // Is the start state cached? + bool HasStart() const { + if (!cache_start_ && Properties(kError)) + cache_start_ = true; + return cache_start_; + } + + // Is the final weight of state s cached? + bool HasFinal(StateId s) const { + const S *state = CheckState(s); + if (state && state->flags & kCacheFinal) { + state->flags |= kCacheRecent; + return true; + } else { + return false; + } + } + + // Are arcs of state s cached? + bool HasArcs(StateId s) const { + const S *state = CheckState(s); + if (state && state->flags & kCacheArcs) { + state->flags |= kCacheRecent; + return true; + } else { + return false; + } + } + + Weight Final(StateId s) const { + const S *state = GetState(s); + return state->final; + } + + size_t NumArcs(StateId s) const { + const S *state = GetState(s); + return state->arcs.size(); + } + + size_t NumInputEpsilons(StateId s) const { + const S *state = GetState(s); + return state->niepsilons; + } + + size_t NumOutputEpsilons(StateId s) const { + const S *state = GetState(s); + return state->noepsilons; + } + + // Provides information needed for generic arc iterator. + void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + const S *state = GetState(s); + data->base = 0; + data->narcs = state->arcs.size(); + data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0; + data->ref_count = &(state->ref_count); + ++(*data->ref_count); + } + + // Number of known states. + StateId NumKnownStates() const { return nknown_states_; } + + // Update number of known states taking in account the existence of state s. + void UpdateNumKnownStates(StateId s) { + if (s >= nknown_states_) + nknown_states_ = s + 1; + } + + // Find the mininum never-expanded state Id + StateId MinUnexpandedState() const { + while (min_unexpanded_state_id_ < expanded_states_.size() && + expanded_states_[min_unexpanded_state_id_]) + ++min_unexpanded_state_id_; + return min_unexpanded_state_id_; + } + + // Removes from cache_states_ and uncaches (not referenced-counted + // or protected) states that have not been accessed since the last + // GC until at most cache_fraction * cache_limit_ bytes are cached. + // If that fails to free enough, recurs uncaching recently visited + // states as well. If still unable to free enough memory, then + // widens cache_limit_ to fulfill condition. + void GC(StateId current, bool free_recent, float cache_fraction = 0.666); + + // Setc/clears GC protection: if true, new states are protected + // from garbage collection. + void GCProtect(bool on) { protect_ = on; } + + void ExpandedState(StateId s) { + if (s < min_unexpanded_state_id_) + return; + while (expanded_states_.size() <= s) + expanded_states_.push_back(false); + expanded_states_[s] = true; + } + + C *GetAllocator() const { + return allocator_; + } + + // Caching on/off switch, limit and size accessors. + bool GetCacheGc() const { return cache_gc_; } + size_t GetCacheLimit() const { return cache_limit_; } + size_t GetCacheSize() const { return cache_size_; } + + private: + static const size_t kMinCacheLimit = 8096; // Minimum (non-zero) cache limit + + static const uint32 kCacheFinal = 0x0001; // Final weight has been cached + static const uint32 kCacheArcs = 0x0002; // Arcs have been cached + static const uint32 kCacheRecent = 0x0004; // Mark as visited since GC + static const uint32 kCacheProtect = 0x0008; // Mark state as GC protected + + public: + static const uint32 kCacheModified = 0x0010; // Mark state as modified + static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent + | kCacheProtect | kCacheModified; + + private: + C *allocator_; // used to allocate new states + mutable bool cache_start_; // Is the start state cached? + StateId nknown_states_; // # of known states + vector<bool> expanded_states_; // states that have been expanded + mutable StateId min_unexpanded_state_id_; // minimum never-expanded state Id + StateId cache_first_state_id_; // First cached state id + S *cache_first_state_; // First cached state + list<StateId> cache_states_; // list of currently cached states + bool cache_gc_; // enable GC + size_t cache_size_; // # of bytes cached + size_t cache_limit_; // # of bytes allowed before GC + bool protect_; // Protect new states from GC + + void operator=(const CacheBaseImpl<S, C> &impl); // disallow +}; + +// Gets a state from its ID; add it if necessary. +template <class S, class C> +S *CacheBaseImpl<S, C>::ExtendState(typename S::Arc::StateId s) { + // If 'protect_' true and a new state, protects from garbage collection. + if (s == cache_first_state_id_) { + return cache_first_state_; // Return 1st cached state + } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) { + cache_first_state_id_ = s; // Remember 1st cached state + cache_first_state_ = allocator_->Allocate(s); + if (protect_) cache_first_state_->flags |= kCacheProtect; + return cache_first_state_; + } else if (cache_first_state_id_ != kNoStateId && + cache_first_state_->ref_count == 0 && + !(cache_first_state_->flags & kCacheProtect)) { + // With Default allocator, the Free and Allocate will reuse the same S*. + allocator_->Free(cache_first_state_, cache_first_state_id_); + cache_first_state_id_ = s; + cache_first_state_ = allocator_->Allocate(s); + if (protect_) cache_first_state_->flags |= kCacheProtect; + return cache_first_state_; // Return 1st cached state + } else { + while (NumStates() <= s) // Add state to main cache + AddState(0); + S *state = VectorFstBaseImpl<S>::GetState(s); + if (!state) { + state = allocator_->Allocate(s); + if (protect_) state->flags |= kCacheProtect; + SetState(s, state); + if (cache_first_state_id_ != kNoStateId) { // Forget 1st cached state + while (NumStates() <= cache_first_state_id_) + AddState(0); + SetState(cache_first_state_id_, cache_first_state_); + if (cache_gc_ && !(cache_first_state_->flags & kCacheProtect)) { + cache_states_.push_back(cache_first_state_id_); + cache_size_ += sizeof(S) + + cache_first_state_->arcs.capacity() * sizeof(Arc); + } + cache_limit_ = kMinCacheLimit; + cache_first_state_id_ = kNoStateId; + cache_first_state_ = 0; + } + if (cache_gc_ && !protect_) { + cache_states_.push_back(s); + cache_size_ += sizeof(S); + if (cache_size_ > cache_limit_) + GC(s, false); + } + } + return state; + } +} + +// Removes from cache_states_ and uncaches (not referenced-counted or +// protected) states that have not been accessed since the last GC +// until at most cache_fraction * cache_limit_ bytes are cached. If +// that fails to free enough, recurs uncaching recently visited states +// as well. If still unable to free enough memory, then widens cache_limit_ +// to fulfill condition. +template <class S, class C> +void CacheBaseImpl<S, C>::GC(typename S::Arc::StateId current, + bool free_recent, float cache_fraction) { + if (!cache_gc_) + return; + VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this + << "), free recently cached = " << free_recent + << ", cache size = " << cache_size_ + << ", cache frac = " << cache_fraction + << ", cache limit = " << cache_limit_ << "\n"; + typename list<StateId>::iterator siter = cache_states_.begin(); + + size_t cache_target = cache_fraction * cache_limit_; + while (siter != cache_states_.end()) { + StateId s = *siter; + S* state = VectorFstBaseImpl<S>::GetState(s); + if (cache_size_ > cache_target && state->ref_count == 0 && + (free_recent || !(state->flags & kCacheRecent)) && s != current) { + cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc); + allocator_->Free(state, s); + SetState(s, 0); + cache_states_.erase(siter++); + } else { + state->flags &= ~kCacheRecent; + ++siter; + } + } + if (!free_recent && cache_size_ > cache_target) { // recurses on recent + GC(current, true); + } else if (cache_target > 0) { // widens cache limit + while (cache_size_ > cache_target) { + cache_limit_ *= 2; + cache_target *= 2; + } + } else if (cache_size_ > 0) { + FSTERROR() << "CacheImpl:GC: Unable to free all cached states"; + } + VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this + << "), free recently cached = " << free_recent + << ", cache size = " << cache_size_ + << ", cache frac = " << cache_fraction + << ", cache limit = " << cache_limit_ << "\n"; +} + +template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal; +template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs; +template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent; +template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheModified; +template <class S, class C> const size_t CacheBaseImpl<S, C>::kMinCacheLimit; + +// Arcs implemented by an STL vector per state. Similar to VectorState +// but adds flags and ref count to keep track of what has been cached. +template <class A> +struct CacheState { + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + CacheState() : final(Weight::Zero()), flags(0), ref_count(0) {} + + void Reset() { + flags = 0; + ref_count = 0; + arcs.resize(0); + } + + Weight final; // Final weight + vector<A> arcs; // Arcs represenation + size_t niepsilons; // # of input epsilons + size_t noepsilons; // # of output epsilons + mutable uint32 flags; + mutable int ref_count; +}; + +// A CacheBaseImpl with a commonly used CacheState. +template <class A> +class CacheImpl : public CacheBaseImpl< CacheState<A> > { + public: + typedef CacheState<A> State; + + CacheImpl() {} + + explicit CacheImpl(const CacheOptions &opts) + : CacheBaseImpl< CacheState<A> >(opts) {} + + CacheImpl(const CacheImpl<A> &impl, bool preserve_cache = false) + : CacheBaseImpl<State>(impl, preserve_cache) {} + + private: + void operator=(const CacheImpl<State> &impl); // disallow +}; + + +// Use this to make a state iterator for a CacheBaseImpl-derived Fst, +// which must have type 'State' defined. Note this iterator only +// returns those states reachable from the initial state, so consider +// implementing a class-specific one. +template <class F> +class CacheStateIterator : public StateIteratorBase<typename F::Arc> { + public: + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename F::State State; + typedef CacheBaseImpl<State> Impl; + + CacheStateIterator(const F &fst, Impl *impl) + : fst_(fst), impl_(impl), s_(0) { + fst_.Start(); // force start state + } + + bool Done() const { + if (s_ < impl_->NumKnownStates()) + return false; + if (s_ < impl_->NumKnownStates()) + return false; + for (StateId u = impl_->MinUnexpandedState(); + u < impl_->NumKnownStates(); + u = impl_->MinUnexpandedState()) { + // force state expansion + ArcIterator<F> aiter(fst_, u); + aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache); + for (; !aiter.Done(); aiter.Next()) + impl_->UpdateNumKnownStates(aiter.Value().nextstate); + impl_->ExpandedState(u); + if (s_ < impl_->NumKnownStates()) + return false; + } + return true; + } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual bool Done_() const { return Done(); } + virtual StateId Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual void Reset_() { Reset(); } + + const F &fst_; + Impl *impl_; + StateId s_; +}; + + +// Use this to make an arc iterator for a CacheBaseImpl-derived Fst, +// which must have types 'Arc' and 'State' defined. +template <class F, + class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > > +class CacheArcIterator { + public: + typedef typename F::Arc Arc; + typedef typename F::State State; + typedef typename Arc::StateId StateId; + typedef CacheBaseImpl<State, C> Impl; + + CacheArcIterator(Impl *impl, StateId s) : i_(0) { + state_ = impl->ExtendState(s); + ++state_->ref_count; + } + + ~CacheArcIterator() { --state_->ref_count; } + + bool Done() const { return i_ >= state_->arcs.size(); } + + const Arc& Value() const { return state_->arcs[i_]; } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + uint32 Flags() const { + return kArcValueFlags; + } + + void SetFlags(uint32 flags, uint32 mask) {} + + private: + const State *state_; + size_t i_; + + DISALLOW_COPY_AND_ASSIGN(CacheArcIterator); +}; + +// Use this to make a mutable arc iterator for a CacheBaseImpl-derived Fst, +// which must have types 'Arc' and 'State' defined. +template <class F, + class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > > +class CacheMutableArcIterator + : public MutableArcIteratorBase<typename F::Arc> { + public: + typedef typename F::State State; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef CacheBaseImpl<State, C> Impl; + + // You will need to call MutateCheck() in the constructor. + CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) { + state_ = impl_->ExtendState(s_); + ++state_->ref_count; + }; + + ~CacheMutableArcIterator() { + --state_->ref_count; + } + + bool Done() const { return i_ >= state_->arcs.size(); } + + const Arc& Value() const { return state_->arcs[i_]; } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + void SetValue(const Arc& arc) { + state_->flags |= CacheBaseImpl<State, C>::kCacheModified; + uint64 properties = impl_->Properties(); + Arc& oarc = state_->arcs[i_]; + if (oarc.ilabel != oarc.olabel) + properties &= ~kNotAcceptor; + if (oarc.ilabel == 0) { + --state_->niepsilons; + properties &= ~kIEpsilons; + if (oarc.olabel == 0) + properties &= ~kEpsilons; + } + if (oarc.olabel == 0) { + --state_->noepsilons; + properties &= ~kOEpsilons; + } + if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) + properties &= ~kWeighted; + oarc = arc; + if (arc.ilabel != arc.olabel) { + properties |= kNotAcceptor; + properties &= ~kAcceptor; + } + if (arc.ilabel == 0) { + ++state_->niepsilons; + properties |= kIEpsilons; + properties &= ~kNoIEpsilons; + if (arc.olabel == 0) { + properties |= kEpsilons; + properties &= ~kNoEpsilons; + } + } + if (arc.olabel == 0) { + ++state_->noepsilons; + properties |= kOEpsilons; + properties &= ~kNoOEpsilons; + } + if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) { + properties |= kWeighted; + properties &= ~kUnweighted; + } + properties &= kSetArcProperties | kAcceptor | kNotAcceptor | + kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons | + kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted; + impl_->SetProperties(properties); + } + + uint32 Flags() const { + return kArcValueFlags; + } + + void SetFlags(uint32 f, uint32 m) {} + + private: + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual size_t Position_() const { return Position(); } + virtual void Reset_() { Reset(); } + virtual void Seek_(size_t a) { Seek(a); } + virtual void SetValue_(const Arc &a) { SetValue(a); } + uint32 Flags_() const { return Flags(); } + void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } + + size_t i_; + StateId s_; + Impl *impl_; + State *state_; + + DISALLOW_COPY_AND_ASSIGN(CacheMutableArcIterator); +}; + +} // namespace fst + +#endif // FST_LIB_CACHE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/closure.h b/kaldi_io/src/tools/openfst/include/fst/closure.h new file mode 100644 index 0000000..541562b --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/closure.h @@ -0,0 +1,155 @@ +// closure.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Functions and classes to compute the concatenative closure of an Fst. + +#ifndef FST_LIB_CLOSURE_H__ +#define FST_LIB_CLOSURE_H__ + +#include <vector> +using std::vector; +#include <algorithm> + +#include <fst/mutable-fst.h> +#include <fst/rational.h> + + +namespace fst { + +// Computes the concatenative closure. This version modifies its +// MutableFst input. If FST transduces string x to y with weight a, +// then the closure transduces x to y with weight a, xx to yy with +// weight Times(a, a), xxx to yyy with with Times(Times(a, a), a), +// etc. If closure_type == CLOSURE_STAR, then the empty string is +// transduced to itself with weight Weight::One() as well. +// +// Complexity: +// - Time: O(V) +// - Space: O(V) +// where V = # of states. +template<class Arc> +void Closure(MutableFst<Arc> *fst, ClosureType closure_type) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + uint64 props = fst->Properties(kFstProperties, false); + StateId start = fst->Start(); + for (StateIterator< MutableFst<Arc> > siter(*fst); + !siter.Done(); + siter.Next()) { + StateId s = siter.Value(); + Weight final = fst->Final(s); + if (final != Weight::Zero()) + fst->AddArc(s, Arc(0, 0, final, start)); + } + if (closure_type == CLOSURE_STAR) { + fst->ReserveStates(fst->NumStates() + 1); + StateId nstart = fst->AddState(); + fst->SetStart(nstart); + fst->SetFinal(nstart, Weight::One()); + if (start != kNoLabel) + fst->AddArc(nstart, Arc(0, 0, Weight::One(), start)); + } + fst->SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR), + kFstProperties); +} + +// Computes the concatenative closure. This version modifies its +// RationalFst input. +template<class Arc> +void Closure(RationalFst<Arc> *fst, ClosureType closure_type) { + fst->GetImpl()->AddClosure(closure_type); +} + + +struct ClosureFstOptions : RationalFstOptions { + ClosureType type; + + ClosureFstOptions(const RationalFstOptions &opts, ClosureType t) + : RationalFstOptions(opts), type(t) {} + explicit ClosureFstOptions(ClosureType t) : type(t) {} + ClosureFstOptions() : type(CLOSURE_STAR) {} +}; + + +// Computes the concatenative closure. This version is a delayed +// Fst. If FST transduces string x to y with weight a, then the +// closure transduces x to y with weight a, xx to yy with weight +// Times(a, a), xxx to yyy with weight Times(Times(a, a), a), etc. If +// closure_type == CLOSURE_STAR, then The empty string is transduced +// to itself with weight Weight::One() as well. +// +// Complexity: +// - Time: O(v) +// - Space: O(v) +// where v = # of states visited. Constant time and space to visit an +// input state or arc is assumed and exclusive of caching. +template <class A> +class ClosureFst : public RationalFst<A> { + public: + using ImplToFst< RationalFstImpl<A> >::GetImpl; + + typedef A Arc; + + ClosureFst(const Fst<A> &fst, ClosureType closure_type) { + GetImpl()->InitClosure(fst, closure_type); + } + + ClosureFst(const Fst<A> &fst, const ClosureFstOptions &opts) + : RationalFst<A>(opts) { + GetImpl()->InitClosure(fst, opts.type); + } + + // See Fst<>::Copy() for doc. + ClosureFst(const ClosureFst<A> &fst, bool safe = false) + : RationalFst<A>(fst, safe) {} + + // Get a copy of this ClosureFst. See Fst<>::Copy() for further doc. + virtual ClosureFst<A> *Copy(bool safe = false) const { + return new ClosureFst<A>(*this, safe); + } +}; + + +// Specialization for ClosureFst. +template <class A> +class StateIterator< ClosureFst<A> > : public StateIterator< RationalFst<A> > { + public: + explicit StateIterator(const ClosureFst<A> &fst) + : StateIterator< RationalFst<A> >(fst) {} +}; + + +// Specialization for ClosureFst. +template <class A> +class ArcIterator< ClosureFst<A> > : public ArcIterator< RationalFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const ClosureFst<A> &fst, StateId s) + : ArcIterator< RationalFst<A> >(fst, s) {} +}; + + +// Useful alias when using StdArc. +typedef ClosureFst<StdArc> StdClosureFst; + +} // namespace fst + +#endif // FST_LIB_CLOSURE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/compact-fst.h b/kaldi_io/src/tools/openfst/include/fst/compact-fst.h new file mode 100644 index 0000000..6db3317 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/compact-fst.h @@ -0,0 +1,1438 @@ +// compact-fst.h + + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// FST Class for memory-efficient representation of common types of +// FSTs: linear automata, acceptors, unweighted FSTs, ... + +#ifndef FST_LIB_COMPACT_FST_H__ +#define FST_LIB_COMPACT_FST_H__ + +#include <iterator> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/expanded-fst.h> +#include <fst/fst-decl.h> // For optional argument declarations +#include <fst/mapped-file.h> +#include <fst/matcher.h> +#include <fst/test-properties.h> +#include <fst/util.h> + + +namespace fst { + +struct CompactFstOptions : public CacheOptions { + // CompactFst default caching behaviour is to do no caching. Most + // compactors are cheap and therefore we save memory by not doing + // caching. + CompactFstOptions() : CacheOptions(true, 0) {} + CompactFstOptions(const CacheOptions &opts) : CacheOptions(opts) {} +}; + +// Compactor Interface - class determinies how arcs and final weights +// are compacted and expanded. +// +// Final weights are treated as transitions to the superfinal state, +// i.e. ilabel = olabel = kNoLabel and nextstate = kNoStateId. +// +// There are two types of compactors: +// +// * Fixed out-degree compactors: 'compactor.Size()' returns a +// positive integer 's'. An FST can be compacted by this compactor +// only if each state has exactly 's' outgoing transitions (counting a +// non-Zero() final weight as a transition). A typical example is a +// compactor for string FSTs, i.e. 's == 1'. +// +// * Variable out-degree compactors: 'compactor.Size() == -1'. There +// are no out-degree restrictions for these compactors. +// +// +// class Compactor { +// public: +// // Element is the type of the compacted transitions. +// typedef ... Element; +// // Return the compacted representation of a transition 'arc' +// // at a state 's'. +// Element Compact(StateId s, const Arc &arc); +// // Return the transition at state 's' represented by the compacted +// // transition 'e'. +// Arc Expand(StateId s, const Element &e); +// // Return -1 for variable out-degree compactors, and the mandatory +// // out-degree otherwise. +// ssize_t Size(); +// // Test whether 'fst' can be compacted by this compactor. +// bool Compatible(const Fst<A> &fst); +// // Return the properties that are always true for an fst +// // compacted using this compactor +// uint64 Properties(); +// // Return a string identifying the type of compactor. +// static const string &Type(); +// // Write a compactor to a file. +// bool Write(ostream &strm); +// // Read a compactor from a file. +// static Compactor *Read(istream &strm); +// // Default constructor (optional, see comment below). +// Compactor(); +// }; +// +// The default constructor is only required for FST_REGISTER to work +// (i.e. enabling Convert() and the command-line utilities to work +// with this new compactor). However, a default constructor always +// needs to be specify for this code to compile, but one can have it +// simply raised an error when called: +// +// Compactor::Compactor() { +// FSTERROR() << "Compactor: no default constructor"; +// } + + +// Implementation data for Compact Fst, which can shared between otherwise +// independent copies. +// +// The implementation contains two arrays: 'states_' and 'compacts_'. +// +// For fixed out-degree compactors, the 'states_' array is unallocated. +// The 'compacts_' contains the compacted transitions. Its size is +// 'ncompacts_'. The outgoing transitions at a given state are stored +// consecutively. For a given state 's', its 'compactor.Size()' outgoing +// transitions (including superfinal transition when 's' is final), are +// stored in position ['s*compactor.Size()', '(s+1)*compactor_.Size()'). +// +// For variable out-degree compactors, the states_ array has size +// 'nstates_ + 1' and contains pointers to positions into 'compacts_'. +// For a given state 's', the compacted transitions of 's' are +// stored in positions [ 'states_[s]', 'states_[s + 1]' ) in 'compacts_'. +// By convention, 'states_[nstates_] == ncompacts_'. +// +// In both cases, the superfinal transitons (when 's' is final, i.e. +// 'Final(s) != Weight::Zero()') is stored first. +// +// The unsigned type U is used to represent indices into the compacts_ +// array. +template <class E, class U> +class CompactFstData { + public: + typedef E CompactElement; + typedef U Unsigned; + + CompactFstData() + : states_region_(0), + compacts_region_(0), + states_(0), + compacts_(0), + nstates_(0), + ncompacts_(0), + narcs_(0), + start_(kNoStateId), + error_(false) {} + + template <class A, class Compactor> + CompactFstData(const Fst<A> &fst, const Compactor &compactor); + + template <class Iterator, class Compactor> + CompactFstData(const Iterator &begin, const Iterator &end, + const Compactor &compactor); + + ~CompactFstData() { + if (states_region_ == NULL) { + delete [] states_; + } + delete states_region_; + if (compacts_region_ == NULL) { + delete [] compacts_; + } + delete compacts_region_; + } + + template <class Compactor> + static CompactFstData<E, U> *Read(istream &strm, + const FstReadOptions &opts, + const FstHeader &hdr, + const Compactor &compactor); + + bool Write(ostream &strm, const FstWriteOptions &opts) const; + + Unsigned States(ssize_t i) const { return states_[i]; } + const CompactElement &Compacts(size_t i) const { return compacts_[i]; } + size_t NumStates() const { return nstates_; } + size_t NumCompacts() const { return ncompacts_; } + size_t NumArcs() const { return narcs_; } + ssize_t Start() const { return start_; } + + int RefCount() const { return ref_count_.count(); } + int IncrRefCount() { return ref_count_.Incr(); } + int DecrRefCount() { return ref_count_.Decr(); } + + bool Error() const { return error_; } + + private: + MappedFile *states_region_; + MappedFile *compacts_region_; + Unsigned *states_; + CompactElement *compacts_; + size_t nstates_; + size_t ncompacts_; + size_t narcs_; + ssize_t start_; + RefCounter ref_count_; + bool error_; +}; + +template <class E, class U> +template <class A, class C> +CompactFstData<E, U>::CompactFstData(const Fst<A> &fst, const C &compactor) + : states_region_(0), + compacts_region_(0), + states_(0), + compacts_(0), + nstates_(0), + ncompacts_(0), + narcs_(0), + start_(kNoStateId), + error_(false) { + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + start_ = fst.Start(); + // Count # of states and arcs. + StateId nfinals = 0; + for (StateIterator< Fst<A> > siter(fst); + !siter.Done(); + siter.Next()) { + ++nstates_; + StateId s = siter.Value(); + for (ArcIterator< Fst<A> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) + ++narcs_; + if (fst.Final(s) != Weight::Zero()) ++nfinals; + } + if (compactor.Size() == -1) { + states_ = new Unsigned[nstates_ + 1]; + ncompacts_ = narcs_ + nfinals; + compacts_ = new CompactElement[ncompacts_]; + states_[nstates_] = ncompacts_; + } else { + states_ = 0; + ncompacts_ = nstates_ * compactor.Size(); + if ((narcs_ + nfinals) != ncompacts_) { + FSTERROR() << "CompactFstData: compactor incompatible with fst"; + error_ = true; + return; + } + compacts_ = new CompactElement[ncompacts_]; + } + size_t pos = 0, fpos = 0; + for (StateId s = 0; s < nstates_; ++s) { + fpos = pos; + if (compactor.Size() == -1) + states_[s] = pos; + if (fst.Final(s) != Weight::Zero()) + compacts_[pos++] = compactor.Compact(s, A(kNoLabel, kNoLabel, + fst.Final(s), kNoStateId)); + for (ArcIterator< Fst<A> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) { + compacts_[pos++] = compactor.Compact(s, aiter.Value()); + } + if ((compactor.Size() != -1) && ((pos - fpos) != compactor.Size())) { + FSTERROR() << "CompactFstData: compactor incompatible with fst"; + error_ = true; + return; + } + } + if (pos != ncompacts_) { + FSTERROR() << "CompactFstData: compactor incompatible with fst"; + error_ = true; + return; + } +} + +template <class E, class U> +template <class Iterator, class C> +CompactFstData<E, U>::CompactFstData(const Iterator &begin, + const Iterator &end, + const C &compactor) + : states_region_(0), + compacts_region_(0), + states_(0), + compacts_(0), + nstates_(0), + ncompacts_(0), + narcs_(0), + start_(kNoStateId), + error_(false) { + typedef typename C::Arc Arc; + typedef typename Arc::Weight Weight; + if (compactor.Size() != -1) { + ncompacts_ = distance(begin, end); + if (compactor.Size() == 1) { + // For strings, allow implicit final weight. + // Empty input is the empty string. + if (ncompacts_ == 0) { + ++ncompacts_; + } else { + Arc arc = compactor.Expand(ncompacts_ - 1, + *(begin + (ncompacts_ - 1))); + if (arc.ilabel != kNoLabel) + ++ncompacts_; + } + } + if (ncompacts_ % compactor.Size()) { + FSTERROR() << "CompactFstData: size of input container incompatible" + << " with compactor"; + error_ = true; + return; + } + if (ncompacts_ == 0) + return; + start_ = 0; + nstates_ = ncompacts_ / compactor.Size(); + compacts_ = new CompactElement[ncompacts_]; + size_t i = 0; + Iterator it = begin; + for(; it != end; ++it, ++i){ + compacts_[i] = *it; + if (compactor.Expand(i, *it).ilabel != kNoLabel) + ++narcs_; + } + if (i < ncompacts_) + compacts_[i] = compactor.Compact(i, Arc(kNoLabel, kNoLabel, + Weight::One(), kNoStateId)); + } else { + if (distance(begin, end) == 0) + return; + // Count # of states, arcs and compacts. + Iterator it = begin; + for(size_t i = 0; it != end; ++it, ++i) { + Arc arc = compactor.Expand(i, *it); + if (arc.ilabel != kNoLabel) { + ++narcs_; + ++ncompacts_; + } else { + ++nstates_; + if (arc.weight != Weight::Zero()) + ++ncompacts_; + } + } + start_ = 0; + compacts_ = new CompactElement[ncompacts_]; + states_ = new Unsigned[nstates_ + 1]; + states_[nstates_] = ncompacts_; + size_t i = 0, s = 0; + for(it = begin; it != end; ++it) { + Arc arc = compactor.Expand(i, *it); + if (arc.ilabel != kNoLabel) { + compacts_[i++] = *it; + } else { + states_[s++] = i; + if (arc.weight != Weight::Zero()) + compacts_[i++] = *it; + } + } + if ((s != nstates_) || (i != ncompacts_)) { + FSTERROR() << "CompactFstData: ill-formed input container"; + error_ = true; + return; + } + } +} + +template <class E, class U> +template <class C> +CompactFstData<E, U> *CompactFstData<E, U>::Read( + istream &strm, + const FstReadOptions &opts, + const FstHeader &hdr, + const C &compactor) { + CompactFstData<E, U> *data = new CompactFstData<E, U>(); + data->start_ = hdr.Start(); + data->nstates_ = hdr.NumStates(); + data->narcs_ = hdr.NumArcs(); + + if (compactor.Size() == -1) { + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "CompactFst::Read: Alignment failed: " << opts.source; + delete data; + return 0; + } + size_t b = (data->nstates_ + 1) * sizeof(Unsigned); + data->states_region_ = MappedFile::Map(&strm, opts, b); + if (!strm || data->states_region_ == NULL) { + LOG(ERROR) << "CompactFst::Read: Read failed: " << opts.source; + delete data; + return 0; + } + data->states_ = static_cast<Unsigned *>( + data->states_region_->mutable_data()); + } else { + data->states_ = 0; + } + data->ncompacts_ = compactor.Size() == -1 + ? data->states_[data->nstates_] + : data->nstates_ * compactor.Size(); + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "CompactFst::Read: Alignment failed: " << opts.source; + delete data; + return 0; + } + size_t b = data->ncompacts_ * sizeof(CompactElement); + data->compacts_region_ = MappedFile::Map(&strm, opts, b); + if (!strm || data->compacts_region_ == NULL) { + LOG(ERROR) << "CompactFst::Read: Read failed: " << opts.source; + delete data; + return 0; + } + data->compacts_ = static_cast<CompactElement *>( + data->compacts_region_->mutable_data()); + return data; +} + +template<class E, class U> +bool CompactFstData<E, U>::Write(ostream &strm, + const FstWriteOptions &opts) const { + if (states_) { + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source; + return false; + } + strm.write(reinterpret_cast<char *>(states_), + (nstates_ + 1) * sizeof(Unsigned)); + } + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source; + return false; + } + strm.write(reinterpret_cast<char *>(compacts_), + ncompacts_ * sizeof(CompactElement)); + + strm.flush(); + if (!strm) { + LOG(ERROR) << "CompactFst::Write: Write failed: " << opts.source; + return false; + } + return true; +} + +template <class A, class C, class U> class CompactFst; +template <class F, class G> void Cast(const F &, G *); + +// Implementation class for CompactFst, which contains CompactFstData +// and Fst cache. +template <class A, class C, class U> +class CompactFstImpl : public CacheImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + using FstImpl<A>::WriteHeader; + + using CacheImpl<A>::PushArc; + using CacheImpl<A>::HasArcs; + using CacheImpl<A>::HasFinal; + using CacheImpl<A>::HasStart; + using CacheImpl<A>::SetArcs; + using CacheImpl<A>::SetFinal; + using CacheImpl<A>::SetStart; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef C Compactor; + typedef typename C::Element CompactElement; + typedef U Unsigned; + + CompactFstImpl() + : CacheImpl<A>(CompactFstOptions()), + compactor_(0), + own_compactor_(false), + data_(0) { + string type = "compact"; + if (sizeof(U) != sizeof(uint32)) { + string size; + Int64ToStr(8 * sizeof(U), &size); + type += size; + } + type += "_"; + type += C::Type(); + SetType(type); + SetProperties(kNullProperties | kStaticProperties); + } + + CompactFstImpl(const Fst<Arc> &fst, const C &compactor, + const CompactFstOptions &opts) + : CacheImpl<A>(opts), + compactor_(new C(compactor)), + own_compactor_(true), + data_(0) { + Init(fst); + } + + CompactFstImpl(const Fst<Arc> &fst, C *compactor, + const CompactFstOptions &opts) + : CacheImpl<A>(opts), + compactor_(compactor), + own_compactor_(false), + data_(0) { + Init(fst); + } + + template <class Iterator> + CompactFstImpl(const Iterator &b, const Iterator &e, const C &compactor, + const CompactFstOptions &opts) + : CacheImpl<A>(opts), + compactor_(new C(compactor)), + own_compactor_(true), + data_(0) { + Init(b, e); + } + + template <class Iterator> + CompactFstImpl(const Iterator &b, const Iterator &e, C *compactor, + const CompactFstOptions &opts) + : CacheImpl<A>(opts), + compactor_(compactor), + own_compactor_(false), + data_(0) { + Init(b, e); + } + + CompactFstImpl(const CompactFstImpl<A, C, U> &impl) + : CacheImpl<A>(impl), + compactor_(new C(*impl.compactor_)), + own_compactor_(true), + data_(impl.data_) { + if (data_) + data_->IncrRefCount(); + SetType(impl.Type()); + SetProperties(impl.Properties()); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~CompactFstImpl(){ + if (own_compactor_) + delete compactor_; + if (data_ && !data_->DecrRefCount()) + delete data_; + } + + StateId Start() { + if (!HasStart()) { + SetStart(data_->Start()); + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (HasFinal(s)) + return CacheImpl<A>::Final(s); + Arc arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); + if ((compactor_->Size() != -1) || + (data_->States(s) != data_->States(s + 1))) + arc = ComputeArc(s, + compactor_->Size() == -1 + ? data_->States(s) + : s * compactor_->Size()); + return arc.ilabel == kNoLabel ? arc.weight : Weight::Zero(); + } + + StateId NumStates() const { + if (Properties(kError)) return 0; + return data_->NumStates(); + } + + size_t NumArcs(StateId s) { + if (HasArcs(s)) + return CacheImpl<A>::NumArcs(s); + Unsigned i, num_arcs; + if (compactor_->Size() == -1) { + i = data_->States(s); + num_arcs = data_->States(s + 1) - i; + } else { + i = s * compactor_->Size(); + num_arcs = compactor_->Size(); + } + if (num_arcs > 0) { + const A &arc = ComputeArc(s, i, kArcILabelValue); + if (arc.ilabel == kNoStateId) { + --num_arcs; + } + } + return num_arcs; + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s) && !Properties(kILabelSorted)) + Expand(s); + if (HasArcs(s)) + return CacheImpl<A>::NumInputEpsilons(s); + return CountEpsilons(s, false); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s) && !Properties(kOLabelSorted)) + Expand(s); + if (HasArcs(s)) + return CacheImpl<A>::NumOutputEpsilons(s); + return CountEpsilons(s, true); + } + + size_t CountEpsilons(StateId s, bool output_epsilons) { + size_t begin = compactor_->Size() == -1 ? + data_->States(s) : s * compactor_->Size(); + size_t end = compactor_->Size() == -1 ? + data_->States(s + 1) : (s + 1) * compactor_->Size(); + size_t num_eps = 0; + for (size_t i = begin; i < end; ++i) { + const A &arc = ComputeArc( + s, i, output_epsilons ? kArcOLabelValue : kArcILabelValue); + const typename A::Label &label = + (output_epsilons ? arc.olabel : arc.ilabel); + if (label == kNoLabel) + continue; + else if (label > 0) + break; + ++num_eps; + } + return num_eps; + } + + static CompactFstImpl<A, C, U> *Read(istream &strm, + const FstReadOptions &opts) { + CompactFstImpl<A, C, U> *impl = new CompactFstImpl<A, C, U>(); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) { + delete impl; + return 0; + } + + // Ensures compatibility + if (hdr.Version() == kAlignedFileVersion) + hdr.SetFlags(hdr.GetFlags() | FstHeader::IS_ALIGNED); + + impl->compactor_ = C::Read(strm); + if (!impl->compactor_) { + delete impl; + return 0; + } + impl->own_compactor_ = true; + impl->data_ = CompactFstData<CompactElement, U>::Read(strm, opts, hdr, + *impl->compactor_); + if (!impl->data_) { + delete impl; + return 0; + } + return impl; + } + + bool Write(ostream &strm, const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(data_->Start()); + hdr.SetNumStates(data_->NumStates()); + hdr.SetNumArcs(data_->NumArcs()); + + // Ensures compatibility + int file_version = opts.align ? kAlignedFileVersion : kFileVersion; + WriteHeader(strm, opts, file_version, &hdr); + compactor_->Write(strm); + return data_->Write(strm, opts); + } + + // Provide information needed for generic state iterator + void InitStateIterator(StateIteratorData<A> *data) const { + data->base = 0; + data->nstates = data_->NumStates(); + } + + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<A>::InitArcIterator(s, data); + } + + Arc ComputeArc(StateId s, Unsigned i, uint32 f = kArcValueFlags) const { + return compactor_->Expand(s, data_->Compacts(i), f); + } + + void Expand(StateId s) { + size_t begin = compactor_->Size() == -1 ? + data_->States(s) : s * compactor_->Size(); + size_t end = compactor_->Size() == -1 ? + data_->States(s + 1) : (s + 1) * compactor_->Size(); + for (size_t i = begin; i < end; ++i) { + const Arc &arc = ComputeArc(s, i); + if (arc.ilabel == kNoLabel) + SetFinal(s, arc.weight); + else + PushArc(s, arc); + } + if (!HasFinal(s)) + SetFinal(s, Weight::Zero()); + SetArcs(s); + } + + template <class Iterator> + void SetCompactElements(const Iterator &b, const Iterator &e) { + if (data_ && !data_->DecrRefCount()) + delete data_; + data_ = new CompactFstData<CompactElement, U>(b, e, *compactor_); + } + + C *GetCompactor() const { return compactor_; } + CompactFstData<CompactElement, U> *Data() const { return data_; } + + // Properties always true of this Fst class + static const uint64 kStaticProperties = kExpanded; + + protected: + template <class B, class D> + explicit CompactFstImpl(const CompactFstImpl<B, D, U> &impl) + : CacheImpl<A>(CacheOptions(impl.GetCacheGc(), impl.GetCacheLimit())), + compactor_(new C(*impl.GetCompactor())), + own_compactor_(true), + data_(impl.Data()) { + if (data_) + data_->IncrRefCount(); + SetType(impl.Type()); + SetProperties(impl.Properties()); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + private: + friend class CompactFst<A, C, U>; // allow access during write. + + void Init(const Fst<Arc> &fst) { + string type = "compact"; + if (sizeof(U) != sizeof(uint32)) { + string size; + Int64ToStr(8 * sizeof(U), &size); + type += size; + } + type += "_"; + type += compactor_->Type(); + SetType(type); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + data_ = new CompactFstData<CompactElement, U>(fst, *compactor_); + if (data_->Error()) + SetProperties(kError, kError); + uint64 copy_properties = fst.Properties(kCopyProperties, true); + if ((copy_properties & kError) || !compactor_->Compatible(fst)) { + FSTERROR() << "CompactFstImpl: input fst incompatible with compactor"; + SetProperties(kError, kError); + return; + } + SetProperties(copy_properties | kStaticProperties); + } + + template <class Iterator> + void Init(const Iterator &b, const Iterator &e) { + string type = "compact"; + if (sizeof(U) != sizeof(uint32)) { + string size; + Int64ToStr(8 * sizeof(U), &size); + type += size; + } + type += "_"; + type += compactor_->Type(); + SetType(type); + SetProperties(kStaticProperties | compactor_->Properties()); + data_ = new CompactFstData<CompactElement, U>(b, e, *compactor_); + if (data_->Error()) + SetProperties(kError, kError); + } + + // Current unaligned file format version + static const int kFileVersion = 2; + // Current aligned file format version + static const int kAlignedFileVersion = 1; + // Minimum file format version supported + static const int kMinFileVersion = 1; + + C *compactor_; + bool own_compactor_; + CompactFstData<CompactElement, U> *data_; +}; + +template <class A, class C, class U> +const uint64 CompactFstImpl<A, C, U>::kStaticProperties; +template <class A, class C, class U> +const int CompactFstImpl<A, C, U>::kFileVersion; +template <class A, class C, class U> +const int CompactFstImpl<A, C, U>::kAlignedFileVersion; +template <class A, class C, class U> +const int CompactFstImpl<A, C, U>::kMinFileVersion; + + +// CompactFst. This class attaches interface to implementation and +// handles reference counting, delegating most methods to +// ImplToExpandedFst. The unsigned type U is used to represent indices +// into the compact arc array (uint32 by default, declared in +// fst-decl.h). +template <class A, class C, class U> +class CompactFst : public ImplToExpandedFst< CompactFstImpl<A, C, U> > { + public: + friend class StateIterator< CompactFst<A, C, U> >; + friend class ArcIterator< CompactFst<A, C, U> >; + template <class F, class G> void friend Cast(const F &, G *); + + typedef A Arc; + typedef typename A::StateId StateId; + typedef CompactFstImpl<A, C, U> Impl; + typedef CacheState<A> State; + typedef U Unsigned; + + CompactFst() : ImplToExpandedFst<Impl>(new Impl()) {} + + explicit CompactFst(const Fst<A> &fst, const C &compactor = C(), + const CompactFstOptions &opts = CompactFstOptions()) + : ImplToExpandedFst<Impl>(new Impl(fst, compactor, opts)) {} + + CompactFst(const Fst<A> &fst, C *compactor, + const CompactFstOptions &opts = CompactFstOptions()) + : ImplToExpandedFst<Impl>(new Impl(fst, compactor, opts)) {} + + // The following 2 constructors take as input two iterators delimiting + // a set of (already) compacted transitions, starting with the + // transitions out of the initial state. The format of the input + // differs for fixed out-degree and variable out-degree compactors. + // + // - For fixed out-degree compactors, the final weight (encoded as a + // compacted transition) needs to be given only for final + // states. All strings (compactor of size 1) will be assume to be + // terminated by a final state even when the final state is not + // implicitely given. + // + // - For variable out-degree compactors, the final weight (encoded + // as a compacted transition) needs to be given for all states and + // must appeared first in the list (for state s, final weight of s, + // followed by outgoing transitons in s). + // + // These 2 constructors allows the direct construction of a CompactFst + // without first creating a more memory hungry 'regular' FST. This + // is useful when memory usage is severely constrained. + template <class Iterator> + explicit CompactFst(const Iterator &begin, const Iterator &end, + const C &compactor = C(), + const CompactFstOptions &opts = CompactFstOptions()) + : ImplToExpandedFst<Impl>(new Impl(begin, end, compactor, opts)) {} + + template <class Iterator> + CompactFst(const Iterator &begin, const Iterator &end, + C *compactor, const CompactFstOptions &opts = CompactFstOptions()) + : ImplToExpandedFst<Impl>(new Impl(begin, end, compactor, opts)) {} + + // See Fst<>::Copy() for doc. + CompactFst(const CompactFst<A, C, U> &fst, bool safe = false) + : ImplToExpandedFst<Impl>(fst, safe) {} + + // Get a copy of this CompactFst. See Fst<>::Copy() for further doc. + virtual CompactFst<A, C, U> *Copy(bool safe = false) const { + return new CompactFst<A, C, U>(*this, safe); + } + + // Read a CompactFst from an input stream; return NULL on error + static CompactFst<A, C, U> *Read(istream &strm, const FstReadOptions &opts) { + Impl* impl = Impl::Read(strm, opts); + return impl ? new CompactFst<A, C, U>(impl) : 0; + } + + // Read a CompactFst from a file; return NULL on error + // Empty filename reads from standard input + static CompactFst<A, C, U> *Read(const string &filename) { + Impl* impl = ImplToExpandedFst<Impl>::Read(filename); + return impl ? new CompactFst<A, C, U>(impl) : 0; + } + + virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { + return GetImpl()->Write(strm, opts); + } + + virtual bool Write(const string &filename) const { + return Fst<A>::WriteFile(filename); + } + + template <class F> + static bool WriteFst(const F &fst, const C &compactor, ostream &strm, + const FstWriteOptions &opts); + + virtual void InitStateIterator(StateIteratorData<A> *data) const { + GetImpl()->InitStateIterator(data); + } + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + virtual MatcherBase<A> *InitMatcher(MatchType match_type) const { + return new SortedMatcher<CompactFst<A, C, U> >(*this, match_type); + } + + template <class Iterator> + void SetCompactElements(const Iterator &b, const Iterator &e) { + GetImpl()->SetCompactElements(b, e); + } + + private: + CompactFst(Impl *impl) : ImplToExpandedFst<Impl>(impl) {} + + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl, ExpandedFst<A> >::GetImpl(); } + + void SetImpl(Impl *impl, bool own_impl = false) { + ImplToFst< Impl, ExpandedFst<A> >::SetImpl(impl, own_impl); + } + + // Use overloading to extract the type of the argument. + static Impl* GetImplIfCompactFst(const CompactFst<A, C, U> &compact_fst) { + return compact_fst.GetImpl(); + } + + // This does not give privileged treatment to subclasses of CompactFst. + template<typename NonCompactFst> + static Impl* GetImplIfCompactFst(const NonCompactFst& fst) { + return NULL; + } + + void operator=(const CompactFst<A, C, U> &fst); // disallow +}; + +// Writes Fst in Compact format, potentially with a pass over the machine +// before writing to compute the number of states and arcs. +// +template <class A, class C, class U> +template <class F> +bool CompactFst<A, C, U>::WriteFst(const F &fst, + const C &compactor, + ostream &strm, + const FstWriteOptions &opts) { + typedef U Unsigned; + typedef typename C::Element CompactElement; + typedef typename A::Weight Weight; + int file_version = opts.align ? + CompactFstImpl<A, C, U>::kAlignedFileVersion : + CompactFstImpl<A, C, U>::kFileVersion; + size_t num_arcs = -1, num_states = -1, num_compacts = -1; + C first_pass_compactor = compactor; + if (Impl* impl = GetImplIfCompactFst(fst)) { + num_arcs = impl->Data()->NumArcs(); + num_states = impl->Data()->NumStates(); + num_compacts = impl->Data()->NumCompacts(); + first_pass_compactor = *impl->GetCompactor(); + } else { + // A first pass is needed to compute the state of the compactor, which + // is saved ahead of the rest of the data structures. This unfortunately + // means forcing a complete double compaction when writing in this format. + // TODO(allauzen): eliminate mutable state from compactors. + num_arcs = 0; + num_states = 0; + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + const StateId s = siter.Value(); + ++num_states; + if (fst.Final(s) != Weight::Zero()) { + first_pass_compactor.Compact( + s, A(kNoLabel, kNoLabel, fst.Final(s), kNoStateId)); + } + for (ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { + ++num_arcs; + first_pass_compactor.Compact(s, aiter.Value()); + } + } + } + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(num_states); + hdr.SetNumArcs(num_arcs); + string type = "compact"; + if (sizeof(U) != sizeof(uint32)) { + string size; + Int64ToStr(8 * sizeof(U), &size); + type += size; + } + type += "_"; + type += C::Type(); + uint64 copy_properties = fst.Properties(kCopyProperties, true); + if ((copy_properties & kError) || !compactor.Compatible(fst)) { + LOG(ERROR) << "fst incompatible with compactor"; + return false; + } + uint64 properties = copy_properties | + CompactFstImpl<A, C, U>::kStaticProperties; + FstImpl<A>::WriteFstHeader(fst, strm, opts, file_version, type, properties, + &hdr); + first_pass_compactor.Write(strm); + if (first_pass_compactor.Size() == -1) { + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source; + return false; + } + Unsigned compacts = 0; + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + const StateId s = siter.Value(); + strm.write(reinterpret_cast<const char *>(&compacts), sizeof(compacts)); + if (fst.Final(s) != Weight::Zero()) { + ++compacts; + } + compacts += fst.NumArcs(s); + } + strm.write(reinterpret_cast<const char *>(&compacts), sizeof(compacts)); + } + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "Could not align file during write after writing states"; + } + C second_pass_compactor = compactor; + CompactElement element; + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + const StateId s = siter.Value(); + if (fst.Final(s) != Weight::Zero()) { + element = second_pass_compactor.Compact( + s, A(kNoLabel, kNoLabel, fst.Final(s), kNoStateId)); + strm.write(reinterpret_cast<const char *>(&element), sizeof(element)); + } + for (ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { + element = second_pass_compactor.Compact(s, aiter.Value()); + strm.write(reinterpret_cast<const char *>(&element), sizeof(element)); + } + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "CompactFst write failed: " << opts.source; + return false; + } + return true; +} + + +// Specialization for CompactFst; see generic version in fst.h +// for sample usage (but use the CompactFst type!). This version +// should inline. +template <class A, class C, class U> +class StateIterator< CompactFst<A, C, U> > { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const CompactFst<A, C, U> &fst) + : nstates_(fst.GetImpl()->NumStates()), s_(0) {} + + bool Done() const { return s_ >= nstates_; } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + StateId nstates_; + StateId s_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + +// Specialization for CompactFst. +// Never caches, always iterates over the underlying compact elements. +template <class A, class C, class U> +class ArcIterator< CompactFst<A, C, U> > { + public: + typedef typename A::StateId StateId; + typedef typename C::Element CompactElement; + + ArcIterator(const CompactFst<A, C, U> &fst, StateId s) + : compactor_(fst.GetImpl()->GetCompactor()), state_(s), compacts_(0), + pos_(0), flags_(kArcValueFlags) { + + const CompactFstData<CompactElement, U> *data = fst.GetImpl()->Data(); + size_t offset; + if (compactor_->Size() == -1) { // Variable out-degree compactor + offset = data->States(s); + num_arcs_ = data->States(s + 1) - offset; + } else { // Fixed out-degree compactor + offset = s * compactor_->Size(); + num_arcs_ = compactor_->Size(); + } + if (num_arcs_ > 0) { + compacts_ = &(data->Compacts(offset)); + arc_ = compactor_->Expand(s, *compacts_, kArcILabelValue); + if (arc_.ilabel == kNoStateId) { + ++compacts_; + --num_arcs_; + } + } + } + + ~ArcIterator() {} + + bool Done() const { return pos_ >= num_arcs_; } + + const A& Value() const { + arc_ = compactor_->Expand(state_, compacts_[pos_], flags_); + return arc_; + } + + void Next() { ++pos_; } + + size_t Position() const { return pos_; } + + void Reset() { pos_ = 0; } + + void Seek(size_t pos) { pos_ = pos; } + + uint32 Flags() const { return flags_; } + + void SetFlags(uint32 f, uint32 m) { + flags_ &= ~m; + flags_ |= (f & kArcValueFlags); + } + + private: + C *compactor_; + StateId state_; + const CompactElement *compacts_; + size_t pos_; + size_t num_arcs_; + mutable A arc_; + uint32 flags_; + + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +// // Specialization for CompactFst. +// // This is an optionally caching arc iterator. +// // TODO(allauzen): implements the kArcValueFlags, the current +// // implementation only implements the kArcNoCache flag. +// template <class A, class C, class U> +// class ArcIterator< CompactFst<A, C, U> > { +// public: +// typedef typename A::StateId StateId; + +// ArcIterator(const CompactFst<A, C, U> &fst, StateId s) +// : fst_(fst), state_(s), pos_(0), num_arcs_(0), offset_(0), +// flags_(kArcValueFlags) { +// cache_data_.ref_count = 0; + +// if (fst_.GetImpl()->HasArcs(state_)) { +// fst_.GetImpl()->InitArcIterator(s, &cache_data_); +// num_arcs_ = cache_data_.narcs; +// return; +// } + +// const C *compactor = fst_.GetImpl()->GetCompactor(); +// const CompactFstData<A, C, U> *data = fst_.GetImpl()->Data(); +// if (compactor->Size() == -1) { // Variable out-degree compactor +// offset_ = data->States(s); +// num_arcs_ = data->States(s + 1) - offset_; +// } else { // Fixed out-degree compactor +// offset_ = s * compactor->Size(); +// num_arcs_ = compactor->Size(); +// } +// if (num_arcs_ > 0) { +// const A &arc = fst_.GetImpl()->ComputeArc(s, offset_); +// if (arc.ilabel == kNoStateId) { +// ++offset_; +// --num_arcs_; +// } +// } +// } + + +// ~ArcIterator() { +// if (cache_data_.ref_count) +// --(*cache_data_.ref_count); +// } + +// bool Done() const { return pos_ >= num_arcs_; } + +// const A& Value() const { +// if (cache_data_.ref_count == 0) { +// if (flags_ & kArcNoCache) { +// arc_ = fst_.GetImpl()->ComputeArc(state_, pos_ + offset_); +// return arc_; +// } else { +// fst_.GetImpl()->InitArcIterator(state_, &cache_data_); +// } +// } +// return cache_data_.arcs[pos_]; +// } + +// void Next() { ++pos_; } + +// size_t Position() const { return pos_; } + +// void Reset() { pos_ = 0; } + +// void Seek(size_t pos) { pos_ = pos; } + +// uint32 Flags() const { return flags_; } + +// void SetFlags(uint32 f, uint32 m) { +// flags_ &= ~m; +// flags_ |= f; + +// if (!(flags_ & kArcNoCache) && cache_data_.ref_count == 0) +// fst_.GetImpl()->InitArcIterator(state_, &cache_data_); +// } + +// private: +// mutable const CompactFst<A, C, U> &fst_; +// StateId state_; +// size_t pos_; +// size_t num_arcs_; +// size_t offset_; +// uint32 flags_; +// mutable A arc_; +// mutable ArcIteratorData<A> cache_data_; + +// DISALLOW_COPY_AND_ASSIGN(ArcIterator); +// }; + + +// +// Utility Compactors +// + +// Compactor for unweighted string FSTs +template <class A> +class StringCompactor { + public: + typedef A Arc; + typedef typename A::Label Element; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + Element Compact(StateId s, const A &arc) const { return arc.ilabel; } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p, p, Weight::One(), p != kNoLabel ? s + 1 : kNoStateId); + } + + ssize_t Size() const { return 1; } + + uint64 Properties() const { + return kString | kAcceptor | kUnweighted; + } + + bool Compatible(const Fst<A> &fst) const { + uint64 props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string type = "string"; + return type; + } + + bool Write(ostream &strm) const { return true; } + + static StringCompactor *Read(istream &strm) { + return new StringCompactor; + } +}; + + +// Compactor for weighted string FSTs +template <class A> +class WeightedStringCompactor { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + typedef pair<Label, Weight> Element; + + Element Compact(StateId s, const A &arc) const { + return make_pair(arc.ilabel, arc.weight); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first, p.first, p.second, + p.first != kNoLabel ? s + 1 : kNoStateId); + } + + ssize_t Size() const { return 1;} + + uint64 Properties() const { + return kString | kAcceptor; + } + + bool Compatible(const Fst<A> &fst) const { + uint64 props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string type = "weighted_string"; + return type; + } + + bool Write(ostream &strm) const { return true; } + + static WeightedStringCompactor *Read(istream &strm) { + return new WeightedStringCompactor; + } +}; + + +// Compactor for unweighted acceptor FSTs +template <class A> +class UnweightedAcceptorCompactor { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + typedef pair<Label, StateId> Element; + + Element Compact(StateId s, const A &arc) const { + return make_pair(arc.ilabel, arc.nextstate); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first, p.first, Weight::One(), p.second); + } + + ssize_t Size() const { return -1;} + + uint64 Properties() const { + return kAcceptor | kUnweighted; + } + + bool Compatible(const Fst<A> &fst) const { + uint64 props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string type = "unweighted_acceptor"; + return type; + } + + bool Write(ostream &strm) const { return true; } + + static UnweightedAcceptorCompactor *Read(istream &istrm) { + return new UnweightedAcceptorCompactor; + } +}; + + +// Compactor for weighted acceptor FSTs +template <class A> +class AcceptorCompactor { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + typedef pair< pair<Label, Weight>, StateId > Element; + + Element Compact(StateId s, const A &arc) const { + return make_pair(make_pair(arc.ilabel, arc.weight), arc.nextstate); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first.first, p.first.first, p.first.second, p.second); + } + + ssize_t Size() const { return -1;} + + uint64 Properties() const { + return kAcceptor; + } + + bool Compatible(const Fst<A> &fst) const { + uint64 props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string type = "acceptor"; + return type; + } + + bool Write(ostream &strm) const { return true; } + + static AcceptorCompactor *Read(istream &strm) { + return new AcceptorCompactor; + } +}; + + +// Compactor for unweighted FSTs +template <class A> +class UnweightedCompactor { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + typedef pair< pair<Label, Label>, StateId > Element; + + Element Compact(StateId s, const A &arc) const { + return make_pair(make_pair(arc.ilabel, arc.olabel), arc.nextstate); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first.first, p.first.second, Weight::One(), p.second); + } + + ssize_t Size() const { return -1; } + + uint64 Properties() const { + return kUnweighted; + } + + bool Compatible(const Fst<A> &fst) const { + uint64 props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string type = "unweighted"; + return type; + } + + bool Write(ostream &strm) const { return true; } + + static UnweightedCompactor *Read(istream &strm) { + return new UnweightedCompactor; + } +}; + + +// Uselful aliases when using StdArc +typedef CompactFst< StdArc, StringCompactor<StdArc> > +StdCompactStringFst; +typedef CompactFst< StdArc, WeightedStringCompactor<StdArc> > +StdCompactWeightedStringFst; +typedef CompactFst<StdArc, AcceptorCompactor<StdArc> > +StdCompactAcceptorFst; +typedef CompactFst<StdArc, UnweightedCompactor<StdArc> > +StdCompactUnweightedFst; +typedef CompactFst<StdArc, UnweightedAcceptorCompactor<StdArc> > +StdCompactUnweightedAcceptorFst; + +} // namespace fst + +#endif // FST_LIB_COMPACT_FST_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/compat.h b/kaldi_io/src/tools/openfst/include/fst/compat.h new file mode 100644 index 0000000..3b5275d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/compat.h @@ -0,0 +1,131 @@ +// compat.h +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: [email protected] (Michael Riley) +// +// \file +// Google compatibility declarations and inline definitions. + +#ifndef FST_LIB_COMPAT_H__ +#define FST_LIB_COMPAT_H__ + +#include <dlfcn.h> + +#include <climits> +#include <cstdlib> +#include <cstring> +#include <iostream> +#include <string> +#include <vector> + +// Makes copy constructor and operator= private +#define DISALLOW_COPY_AND_ASSIGN(type) \ + type(const type&); \ + void operator=(const type&) + +#include <fst/config.h> +#include <fst/types.h> +#include <fst/lock.h> +#include <fst/flags.h> +#include <fst/log.h> +#include <fst/icu.h> + +using std::cin; +using std::cout; +using std::cerr; +using std::endl; +using std::string; + +void FailedNewHandler(); + +namespace fst { + +using namespace std; + +void SplitToVector(char *line, const char *delim, + std::vector<char *> *vec, bool omit_empty_strings); + +// Downcasting +template<typename To, typename From> +inline To down_cast(From* f) { + return static_cast<To>(f); +} + +// Bitcasting +template <class Dest, class Source> +inline Dest bit_cast(const Source& source) { + // Compile time assertion: sizeof(Dest) == sizeof(Source) + // A compile error here means your Dest and Source have different sizes. + typedef char VerifySizesAreEqual [sizeof(Dest) == sizeof(Source) ? 1 : + -1]; + Dest dest; + memcpy(&dest, &source, sizeof(dest)); + return dest; +} + +// Check sums +class CheckSummer { + public: + CheckSummer() : count_(0) { + check_sum_.resize(kCheckSumLength, '\0'); + } + + void Reset() { + count_ = 0; + for (int i = 0; i < kCheckSumLength; ++i) + check_sum_[i] = '\0'; + } + + void Update(void const *data, int size) { + const char *p = reinterpret_cast<const char *>(data); + for (int i = 0; i < size; ++i) + check_sum_[(count_++) % kCheckSumLength] ^= p[i]; + } + + void Update(string const &data) { + for (int i = 0; i < data.size(); ++i) + check_sum_[(count_++) % kCheckSumLength] ^= data[i]; + } + + string Digest() { + return check_sum_; + } + + private: + static const int kCheckSumLength = 32; + int count_; + string check_sum_; + + DISALLOW_COPY_AND_ASSIGN(CheckSummer); +}; + +} // namespace fst + + +// Define missing hash functions if needed +#ifndef HAVE_STD__TR1__HASH_LONG_LONG_UNSIGNED_ +namespace std { +namespace tr1 { + +template <class T> class hash; + +template<> struct hash<uint64> { + size_t operator()(uint64 x) const { return x; } +}; + +} +} +#endif // HAVE_STD__TR1__HASH_LONG_LONG_UNSIGNED_ + +#endif // FST_LIB_COMPAT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/complement.h b/kaldi_io/src/tools/openfst/include/fst/complement.h new file mode 100644 index 0000000..dacf396 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/complement.h @@ -0,0 +1,338 @@ +// complement.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to complement an Fst. + +#ifndef FST_LIB_COMPLEMENT_H__ +#define FST_LIB_COMPLEMENT_H__ + +#include <algorithm> +#include <string> +#include <vector> +using std::vector; + +#include <fst/fst.h> +#include <fst/test-properties.h> + + +namespace fst { + +template <class A> class ComplementFst; + +// Implementation of delayed ComplementFst. The algorithm used +// completes the (deterministic) FSA and then exchanges final and +// non-final states. Completion, i.e. ensuring that all labels can be +// read from every state, is accomplished by using RHO labels, which +// match all labels that are otherwise not found leaving a state. The +// first state in the output is reserved to be a new state that is the +// destination of all RHO labels. Each remaining output state s +// corresponds to input state s - 1. The first arc in the output at +// these states is the rho label, the remaining arcs correspond to the +// input arcs. +template <class A> +class ComplementFstImpl : public FstImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + friend class StateIterator< ComplementFst<A> >; + friend class ArcIterator< ComplementFst<A> >; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + explicit ComplementFstImpl(const Fst<A> &fst) : fst_(fst.Copy()) { + SetType("complement"); + uint64 props = fst.Properties(kILabelSorted, false); + SetProperties(ComplementProperties(props), kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + ComplementFstImpl(const ComplementFstImpl<A> &impl) + : fst_(impl.fst_->Copy()) { + SetType("complement"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~ComplementFstImpl() { delete fst_; } + + StateId Start() const { + if (Properties(kError)) + return kNoStateId; + + StateId start = fst_->Start(); + if (start != kNoStateId) + return start + 1; + else + return 0; + } + + // Exchange final and non-final states; make rho destination state final. + Weight Final(StateId s) const { + if (s == 0 || fst_->Final(s - 1) == Weight::Zero()) + return Weight::One(); + else + return Weight::Zero(); + } + + size_t NumArcs(StateId s) const { + if (s == 0) + return 1; + else + return fst_->NumArcs(s - 1) + 1; + } + + size_t NumInputEpsilons(StateId s) const { + return s == 0 ? 0 : fst_->NumInputEpsilons(s - 1); + } + + size_t NumOutputEpsilons(StateId s) const { + return s == 0 ? 0 : fst_->NumOutputEpsilons(s - 1); + } + + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && fst_->Properties(kError, false)) + SetProperties(kError, kError); + return FstImpl<Arc>::Properties(mask); + } + + + private: + const Fst<A> *fst_; + + void operator=(const ComplementFstImpl<A> &fst); // Disallow +}; + + +// Complements an automaton. This is a library-internal operation that +// introduces a (negative) 'rho' label; use Difference/DifferenceFst in +// user code, which will not see this label. This version is a delayed Fst. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class ComplementFst : public ImplToFst< ComplementFstImpl<A> > { + public: + friend class StateIterator< ComplementFst<A> >; + friend class ArcIterator< ComplementFst<A> >; + + using ImplToFst< ComplementFstImpl<A> >::GetImpl; + + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef ComplementFstImpl<A> Impl; + + explicit ComplementFst(const Fst<A> &fst) + : ImplToFst<Impl>(new Impl(fst)) { + uint64 props = kUnweighted | kNoEpsilons | kIDeterministic | kAcceptor; + if (fst.Properties(props, true) != props) { + FSTERROR() << "ComplementFst: argument not an unweighted " + << "epsilon-free deterministic acceptor"; + GetImpl()->SetProperties(kError, kError); + } + } + + // See Fst<>::Copy() for doc. + ComplementFst(const ComplementFst<A> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this ComplementFst. See Fst<>::Copy() for further doc. + virtual ComplementFst<A> *Copy(bool safe = false) const { + return new ComplementFst<A>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual inline void InitArcIterator(StateId s, + ArcIteratorData<A> *data) const; + + // Label that represents the rho transition. + // We use a negative value, which is thus private to the library and + // which will preserve FST label sort order. + static const Label kRhoLabel = -2; + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const ComplementFst<A> &fst); // disallow +}; + +template <class A> const typename A::Label ComplementFst<A>::kRhoLabel; + + +// Specialization for ComplementFst. +template <class A> +class StateIterator< ComplementFst<A> > : public StateIteratorBase<A> { + public: + typedef typename A::StateId StateId; + typedef typename A::Label Label; + + explicit StateIterator(const ComplementFst<A> &fst) + : siter_(*fst.GetImpl()->fst_), s_(0) { + } + + bool Done() const { return s_ > 0 && siter_.Done(); } + + StateId Value() const { return s_; } + + void Next() { + if (s_ != 0) + siter_.Next(); + ++s_; + } + + void Reset() { + siter_.Reset(); + s_ = 0; + } + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual bool Done_() const { return Done(); } + virtual StateId Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual void Reset_() { Reset(); } + + StateIterator< Fst<A> > siter_; + StateId s_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + + +// Specialization for ComplementFst. +template <class A> +class ArcIterator< ComplementFst<A> > : public ArcIteratorBase<A> { + public: + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + ArcIterator(const ComplementFst<A> &fst, StateId s) + : aiter_(0), s_(s), pos_(0) { + if (s_ != 0) + aiter_ = new ArcIterator< Fst<A> >(*fst.GetImpl()->fst_, s - 1); + } + + virtual ~ArcIterator() { delete aiter_; } + + bool Done() const { + if (s_ != 0) + return pos_ > 0 && aiter_->Done(); + else + return pos_ > 0; + } + + // Adds the rho label to the rho destination state. + const A& Value() const { + if (pos_ == 0) { + arc_.ilabel = arc_.olabel = ComplementFst<A>::kRhoLabel; + arc_.weight = Weight::One(); + arc_.nextstate = 0; + } else { + arc_ = aiter_->Value(); + ++arc_.nextstate; + } + return arc_; + } + + void Next() { + if (s_ != 0 && pos_ > 0) + aiter_->Next(); + ++pos_; + } + + size_t Position() const { + return pos_; + } + + void Reset() { + if (s_ != 0) + aiter_->Reset(); + pos_ = 0; + } + + void Seek(size_t a) { + if (s_ != 0) { + if (a == 0) { + aiter_->Reset(); + } else { + aiter_->Seek(a - 1); + } + } + pos_ = a; + } + + uint32 Flags() const { + return kArcValueFlags; + } + + void SetFlags(uint32 f, uint32 m) {} + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual bool Done_() const { return Done(); } + virtual const A& Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual size_t Position_() const { return Position(); } + virtual void Reset_() { Reset(); } + virtual void Seek_(size_t a) { Seek(a); } + uint32 Flags_() const { return Flags(); } + void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } + + ArcIterator< Fst<A> > *aiter_; + StateId s_; + size_t pos_; + mutable A arc_; + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + + +template <class A> inline void +ComplementFst<A>::InitStateIterator(StateIteratorData<A> *data) const { + data->base = new StateIterator< ComplementFst<A> >(*this); +} + +template <class A> inline void +ComplementFst<A>::InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + data->base = new ArcIterator< ComplementFst<A> >(*this, s); +} + + +// Useful alias when using StdArc. +typedef ComplementFst<StdArc> StdComplementFst; + +} // namespace fst + +#endif // FST_LIB_COMPLEMENT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/compose-filter.h b/kaldi_io/src/tools/openfst/include/fst/compose-filter.h new file mode 100644 index 0000000..6bf7736 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/compose-filter.h @@ -0,0 +1,542 @@ +// compose-filter.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes for filtering the composition matches, e.g. for correct epsilon +// handling. + +#ifndef FST_LIB_COMPOSE_FILTER_H__ +#define FST_LIB_COMPOSE_FILTER_H__ + +#include <fst/fst.h> +#include <fst/fst-decl.h> // For optional argument declarations +#include <fst/matcher.h> + + +namespace fst { + + +// COMPOSITION FILTER STATE - this represents the state of +// the composition filter. It has the form: +// +// class FilterState { +// public: +// // Required constructors +// FilterState(); +// FilterState(const FilterState &f); +// // An invalid filter state. +// static const FilterState NoState(); +// // Maps state to integer for hashing. +// size_t Hash() const; +// // Equality of filter states. +// bool operator==(const FilterState &f) const; +// // Inequality of filter states. +// bool operator!=(const FilterState &f) const; +// // Assignment to filter states. +// FilterState& operator=(const FilterState& f); +// }; + + +// Filter state that is a signed integral type. +template <typename T> +class IntegerFilterState { + public: + IntegerFilterState() : state_(kNoStateId) {} + explicit IntegerFilterState(T s) : state_(s) {} + + static const IntegerFilterState NoState() { return IntegerFilterState(); } + + size_t Hash() const { return static_cast<size_t>(state_); } + + bool operator==(const IntegerFilterState &f) const { + return state_ == f.state_; + } + + bool operator!=(const IntegerFilterState &f) const { + return state_ != f.state_; + } + + T GetState() const { return state_; } + + void SetState(T state) { state_ = state; } + +private: + T state_; +}; + +typedef IntegerFilterState<signed char> CharFilterState; +typedef IntegerFilterState<short> ShortFilterState; +typedef IntegerFilterState<int> IntFilterState; + + +// Filter state that is a weight (class). +template <class W> +class WeightFilterState { + public: + WeightFilterState() : weight_(W::Zero()) {} + explicit WeightFilterState(W w) : weight_(w) {} + + static const WeightFilterState NoState() { return WeightFilterState(); } + + size_t Hash() const { return weight_.Hash(); } + + bool operator==(const WeightFilterState &f) const { + return weight_ == f.weight_; + } + + bool operator!=(const WeightFilterState &f) const { + return weight_ != f.weight_; + } + + W GetWeight() const { return weight_; } + + void SetWeight(W w) { weight_ = w; } + +private: + W weight_; +}; + + +// Filter state that is the combination of two filter states. +template <class F1, class F2> +class PairFilterState { + public: + PairFilterState() : f1_(F1::NoState()), f2_(F2::NoState()) {} + + PairFilterState(const F1 &f1, const F2 &f2) : f1_(f1), f2_(f2) {} + + static const PairFilterState NoState() { return PairFilterState(); } + + size_t Hash() const { + size_t h1 = f1_.Hash(); + size_t h2 = f2_.Hash(); + const int lshift = 5; + const int rshift = CHAR_BIT * sizeof(size_t) - 5; + return h1 << lshift ^ h1 >> rshift ^ h2; + } + + bool operator==(const PairFilterState &f) const { + return f1_ == f.f1_ && f2_ == f.f2_; + } + + bool operator!=(const PairFilterState &f) const { + return f1_ != f.f1_ || f2_ != f.f2_; + } + + const F1 &GetState1() const { return f1_; } + const F2 &GetState2() const { return f2_; } + + void SetState(const F1 &f1, const F2 &f2) { + f1_ = f1; + f2_ = f2; + } + +private: + F1 f1_; + F2 f2_; +}; + + +// COMPOSITION FILTERS - these determine which matches are allowed to +// proceed. The filter's state is represented by the type +// ComposeFilter::FilterState. The basic filters handle correct +// epsilon matching. Their interface is: +// +// template <class M1, class M2> +// class ComposeFilter { +// public: +// typedef typename M1::FST1 FST1; +// typedef typename M1::FST2 FST2; +// typedef typename FST1::Arc Arc; +// typedef ... FilterState; +// typedef ... Matcher1; +// typedef ... Matcher2; +// +// // Required constructors. +// ComposeFilter(const FST1 &fst1, const FST2 &fst2, +// // M1 *matcher1 = 0, M2 *matcher2 = 0); +// // If safe=true, the copy is thread-safe. See Fst<>::Copy() +// // for further doc. +// ComposeFilter(const ComposeFilter<M1, M2> &filter, +// // bool safe = false); +// // Return start state of filter. +// FilterState Start() const; +// // Specifies current composition state. +// void SetState(StateId s1, StateId s2, const FilterState &f); +// +// // Apply filter at current composition state to these transitions. +// // If an arc label to be matched is kNolabel, then that side +// // does not consume a symbol. Returns the new filter state or, +// // if disallowed, FilterState::NoState(). The filter is permitted to +// // modify its inputs, e.g. for optimizations. +// FilterState FilterArc(Arc *arc1, Arc *arc2) const; + +// // Apply filter at current composition state to these final weights +// // (cf. superfinal transitions). The filter may modify its inputs, +// // e.g. for optimizations. +// void FilterFinal(Weight *final1, Weight *final2) const; +// +// // Return resp matchers. Ownership stays with filter. These +// // methods allow the filter to access and possibly modify +// // the composition matchers (useful e.g. with lookahead). +// Matcher1 *GetMatcher1(); +// Matcher2 *GetMatcher2(); +// +// // This specifies how the filter affects the composition result +// // properties. It takes as argument the properties that would +// // apply with a trivial composition fitler. +// uint64 Properties(uint64 props) const; +// }; + +// This filter requires epsilons on FST1 to be read before epsilons on FST2. +template <class M1, class M2> +class SequenceComposeFilter { + public: + typedef typename M1::FST FST1; + typedef typename M2::FST FST2; + typedef typename FST1::Arc Arc; + typedef CharFilterState FilterState; + typedef M1 Matcher1; + typedef M2 Matcher2; + + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + SequenceComposeFilter(const FST1 &fst1, const FST2 &fst2, + M1 *matcher1 = 0, M2 *matcher2 = 0) + : matcher1_(matcher1 ? matcher1 : new M1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new M2(fst2, MATCH_INPUT)), + fst1_(matcher1_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + f_(kNoStateId) {} + + SequenceComposeFilter(const SequenceComposeFilter<M1, M2> &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst1_(matcher1_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + f_(kNoStateId) {} + + ~SequenceComposeFilter() { + delete matcher1_; + delete matcher2_; + } + + FilterState Start() const { return FilterState(0); } + + void SetState(StateId s1, StateId s2, const FilterState &f) { + if (s1_ == s1 && s2_ == s2 && f == f_) + return; + s1_ = s1; + s2_ = s2; + f_ = f; + size_t na1 = internal::NumArcs(fst1_, s1); + size_t ne1 = internal::NumOutputEpsilons(fst1_, s1); + bool fin1 = internal::Final(fst1_, s1) != Weight::Zero(); + alleps1_ = na1 == ne1 && !fin1; + noeps1_ = ne1 == 0; + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + if (arc1->olabel == kNoLabel) + return alleps1_ ? FilterState::NoState() : + noeps1_ ? FilterState(0) : FilterState(1); + else if (arc2->ilabel == kNoLabel) + return f_ != FilterState(0) ? FilterState::NoState() : FilterState(0); + else + return arc1->olabel == 0 ? FilterState::NoState() : FilterState(0); + } + + void FilterFinal(Weight *, Weight *) const {} + + // Return resp matchers. Ownership stays with filter. + Matcher1 *GetMatcher1() { return matcher1_; } + Matcher2 *GetMatcher2() { return matcher2_; } + + uint64 Properties(uint64 props) const { return props; } + + private: + Matcher1 *matcher1_; + Matcher2 *matcher2_; + const FST1 &fst1_; + StateId s1_; // Current fst1_ state; + StateId s2_; // Current fst2_ state; + FilterState f_; // Current filter state + bool alleps1_; // Only epsilons (and non-final) leaving s1_? + bool noeps1_; // No epsilons leaving s1_? + + void operator=(const SequenceComposeFilter<M1, M2> &); // disallow +}; + + +// This filter requires epsilons on FST2 to be read before epsilons on FST1. +template <class M1, class M2> +class AltSequenceComposeFilter { + public: + typedef typename M1::FST FST1; + typedef typename M2::FST FST2; + typedef typename FST1::Arc Arc; + typedef CharFilterState FilterState; + typedef M1 Matcher1; + typedef M2 Matcher2; + + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + AltSequenceComposeFilter(const FST1 &fst1, const FST2 &fst2, + M1 *matcher1 = 0, M2 *matcher2 = 0) + : matcher1_(matcher1 ? matcher1 : new M1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new M2(fst2, MATCH_INPUT)), + fst2_(matcher2_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + f_(kNoStateId) {} + + AltSequenceComposeFilter(const AltSequenceComposeFilter<M1, M2> &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst2_(matcher2_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + f_(kNoStateId) {} + + ~AltSequenceComposeFilter() { + delete matcher1_; + delete matcher2_; + } + + FilterState Start() const { return FilterState(0); } + + void SetState(StateId s1, StateId s2, const FilterState &f) { + if (s1_ == s1 && s2_ == s2 && f == f_) + return; + s1_ = s1; + s2_ = s2; + f_ = f; + size_t na2 = internal::NumArcs(fst2_, s2); + size_t ne2 = internal::NumInputEpsilons(fst2_, s2); + bool fin2 = internal::Final(fst2_, s2) != Weight::Zero(); + alleps2_ = na2 == ne2 && !fin2; + noeps2_ = ne2 == 0; + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + if (arc2->ilabel == kNoLabel) + return alleps2_ ? FilterState::NoState() : + noeps2_ ? FilterState(0) : FilterState(1); + else if (arc1->olabel == kNoLabel) + return f_ == FilterState(1) ? FilterState::NoState() : FilterState(0); + else + return arc1->olabel == 0 ? FilterState::NoState() : FilterState(0); + } + + void FilterFinal(Weight *, Weight *) const {} + + // Return resp matchers. Ownership stays with filter. + Matcher1 *GetMatcher1() { return matcher1_; } + Matcher2 *GetMatcher2() { return matcher2_; } + + uint64 Properties(uint64 props) const { return props; } + + private: + Matcher1 *matcher1_; + Matcher2 *matcher2_; + const FST2 &fst2_; + StateId s1_; // Current fst1_ state; + StateId s2_; // Current fst2_ state; + FilterState f_; // Current filter state + bool alleps2_; // Only epsilons (and non-final) leaving s2_? + bool noeps2_; // No epsilons leaving s2_? + +void operator=(const AltSequenceComposeFilter<M1, M2> &); // disallow +}; + + +// This filter requires epsilons on FST1 to be matched with epsilons on FST2 +// whenever possible. +template <class M1, class M2> +class MatchComposeFilter { + public: + typedef typename M1::FST FST1; + typedef typename M2::FST FST2; + typedef typename FST1::Arc Arc; + typedef CharFilterState FilterState; + typedef M1 Matcher1; + typedef M2 Matcher2; + + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + MatchComposeFilter(const FST1 &fst1, const FST2 &fst2, + M1 *matcher1 = 0, M2 *matcher2 = 0) + : matcher1_(matcher1 ? matcher1 : new M1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new M2(fst2, MATCH_INPUT)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + f_(kNoStateId) {} + + MatchComposeFilter(const MatchComposeFilter<M1, M2> &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + f_(kNoStateId) {} + + ~MatchComposeFilter() { + delete matcher1_; + delete matcher2_; + } + + FilterState Start() const { return FilterState(0); } + + void SetState(StateId s1, StateId s2, const FilterState &f) { + if (s1_ == s1 && s2_ == s2 && f == f_) + return; + s1_ = s1; + s2_ = s2; + f_ = f; + size_t na1 = internal::NumArcs(fst1_, s1); + size_t ne1 = internal::NumOutputEpsilons(fst1_, s1); + bool f1 = internal::Final(fst1_, s1) != Weight::Zero(); + alleps1_ = na1 == ne1 && !f1; + noeps1_ = ne1 == 0; + size_t na2 = internal::NumArcs(fst2_, s2); + size_t ne2 = internal::NumInputEpsilons(fst2_, s2); + bool f2 = internal::Final(fst2_, s2) != Weight::Zero(); + alleps2_ = na2 == ne2 && !f2; + noeps2_ = ne2 == 0; + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + if (arc2->ilabel == kNoLabel) // Epsilon on Fst1 + return f_ == FilterState(0) ? + (noeps2_ ? FilterState(0) : + (alleps2_ ? FilterState::NoState(): FilterState(1))) : + (f_ == FilterState(1) ? FilterState(1) : FilterState::NoState()); + else if (arc1->olabel == kNoLabel) // Epsilon on Fst2 + return f_ == FilterState(0) ? + (noeps1_ ? FilterState(0) : + (alleps1_ ? FilterState::NoState() : FilterState(2))) : + (f_ == FilterState(2) ? FilterState(2) : FilterState::NoState()); + else if (arc1->olabel == 0) // Epsilon on both + return f_ == FilterState(0) ? FilterState(0) : FilterState::NoState(); + else // Both are non-epsilons + return FilterState(0); + } + + void FilterFinal(Weight *, Weight *) const {} + + // Return resp matchers. Ownership stays with filter. + Matcher1 *GetMatcher1() { return matcher1_; } + Matcher2 *GetMatcher2() { return matcher2_; } + + uint64 Properties(uint64 props) const { return props; } + + private: + Matcher1 *matcher1_; + Matcher2 *matcher2_; + const FST1 &fst1_; + const FST2 &fst2_; + StateId s1_; // Current fst1_ state; + StateId s2_; // Current fst2_ state; + FilterState f_; // Current filter state ID + bool alleps1_, alleps2_; // Only epsilons (and non-final) leaving s1, s2? + bool noeps1_, noeps2_; // No epsilons leaving s1, s2? + + void operator=(const MatchComposeFilter<M1, M2> &); // disallow +}; + + +// This filter works with the MultiEpsMatcher to determine if +// 'multi-epsilons' are preserved in the composition output +// (rather than rewritten as 0) and ensures correct properties. +template <class F> +class MultiEpsFilter { + public: + typedef typename F::FST1 FST1; + typedef typename F::FST2 FST2; + typedef typename F::Arc Arc; + typedef typename F::Matcher1 Matcher1; + typedef typename F::Matcher2 Matcher2; + typedef typename F::FilterState FilterState; + typedef MultiEpsFilter<F> Filter; + + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + MultiEpsFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = 0, Matcher2 *matcher2 = 0, + bool keep_multi_eps = false) + : filter_(fst1, fst2, matcher1, matcher2), + keep_multi_eps_(keep_multi_eps) {} + + MultiEpsFilter(const Filter &filter, bool safe = false) + : filter_(filter.filter_, safe), + keep_multi_eps_(filter.keep_multi_eps_) {} + + FilterState Start() const { return filter_.Start(); } + + void SetState(StateId s1, StateId s2, const FilterState &f) { + return filter_.SetState(s1, s2, f); + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + FilterState f = filter_.FilterArc(arc1, arc2); + if (keep_multi_eps_) { + if (arc1->olabel == kNoLabel) + arc1->ilabel = arc2->ilabel; + if (arc2->ilabel == kNoLabel) + arc2->olabel = arc1->olabel; + } + return f; + } + + void FilterFinal(Weight *w1, Weight *w2) const { + return filter_.FilterFinal(w1, w2); + } + + // Return resp matchers. Ownership stays with filter. + Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); } + Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); } + + uint64 Properties(uint64 iprops) const { + uint64 oprops = filter_.Properties(iprops); + return oprops & kILabelInvariantProperties & kOLabelInvariantProperties; + } + + private: + F filter_; + bool keep_multi_eps_; +}; + +} // namespace fst + + +#endif // FST_LIB_COMPOSE_FILTER_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/compose.h b/kaldi_io/src/tools/openfst/include/fst/compose.h new file mode 100644 index 0000000..db5ea3a --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/compose.h @@ -0,0 +1,728 @@ +// compose.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to compute the composition of two FSTs + +#ifndef FST_LIB_COMPOSE_H__ +#define FST_LIB_COMPOSE_H__ + +#include <algorithm> +#include <string> +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/compose-filter.h> +#include <fst/lookahead-filter.h> +#include <fst/matcher.h> +#include <fst/state-table.h> +#include <fst/test-properties.h> + + +namespace fst { + +// Delayed composition options templated on the arc type, the matcher, +// the composition filter, and the composition state table. By +// default, the matchers, filter, and state table are constructed by +// composition. If set below, the user can instead pass in these +// objects; in that case, ComposeFst takes their ownership. This +// version controls composition implemented between generic Fst<Arc> +// types and a shared matcher type M for Fst<Arc>. This should be +// adequate for most applications, giving a reasonable tradeoff +// between efficiency and code sharing (but see ComposeFstImplOptions). +template <class A, + class M = Matcher<Fst<A> >, + class F = SequenceComposeFilter<M>, + class T = GenericComposeStateTable<A, typename F::FilterState> > +struct ComposeFstOptions : public CacheOptions { + M *matcher1; // FST1 matcher (see matcher.h) + M *matcher2; // FST2 matcher + F *filter; // Composition filter (see compose-filter.h) + T *state_table; // Composition state table (see compose-state-table.h) + + explicit ComposeFstOptions(const CacheOptions &opts, + M *mat1 = 0, M *mat2 = 0, + F *filt = 0, T *sttable= 0) + : CacheOptions(opts), matcher1(mat1), matcher2(mat2), + filter(filt), state_table(sttable) {} + + ComposeFstOptions() : matcher1(0), matcher2(0), filter(0), state_table(0) {} +}; + + +// Delayed composition options templated on the two matcher types, the +// composition filter, and the composition state table. By default, +// the matchers, filter, and state table are constructed by +// composition. If set below, the user can instead pass in these +// objects; in that case, ComposeFst takes their ownership. This +// version controls composition implemented using arbitrary matchers +// (of the same Arc type but otherwise arbitrary Fst type). The user +// must ensure the matchers are compatible. These options permit the +// most efficient use, but shares the least code. This is for advanced +// use only in the most demanding or specialized applications that can +// benefit from it (o.w. prefer ComposeFstOptions). +template <class M1, class M2, + class F = SequenceComposeFilter<M1, M2>, + class T = GenericComposeStateTable<typename M1::Arc, + typename F::FilterState> > +struct ComposeFstImplOptions : public CacheOptions { + M1 *matcher1; // FST1 matcher (see matcher.h) + M2 *matcher2; // FST2 matcher + F *filter; // Composition filter (see compose-filter.h) + T *state_table; // Composition state table (see compose-state-table.h) + + explicit ComposeFstImplOptions(const CacheOptions &opts, + M1 *mat1 = 0, M2 *mat2 = 0, + F *filt = 0, T *sttable= 0) + : CacheOptions(opts), matcher1(mat1), matcher2(mat2), + filter(filt), state_table(sttable) {} + + ComposeFstImplOptions() + : matcher1(0), matcher2(0), filter(0), state_table(0) {} +}; + + +// Implementation of delayed composition. This base class is +// common to the variants with different matchers, composition filters +// and state tables. +template <class A> +class ComposeFstImplBase : public CacheImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<A> >::HasStart; + using CacheBaseImpl< CacheState<A> >::HasFinal; + using CacheBaseImpl< CacheState<A> >::HasArcs; + using CacheBaseImpl< CacheState<A> >::SetFinal; + using CacheBaseImpl< CacheState<A> >::SetStart; + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + + ComposeFstImplBase(const Fst<A> &fst1, const Fst<A> &fst2, + const CacheOptions &opts) + : CacheImpl<A>(opts) { + VLOG(2) << "ComposeFst(" << this << "): Begin"; + SetType("compose"); + + if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) { + FSTERROR() << "ComposeFst: output symbol table of 1st argument " + << "does not match input symbol table of 2nd argument"; + SetProperties(kError, kError); + } + + SetInputSymbols(fst1.InputSymbols()); + SetOutputSymbols(fst2.OutputSymbols()); + } + + ComposeFstImplBase(const ComposeFstImplBase<A> &impl) + : CacheImpl<A>(impl, true) { + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + virtual ComposeFstImplBase<A> *Copy() = 0; + + virtual ~ComposeFstImplBase() {} + + StateId Start() { + if (!HasStart()) { + StateId start = ComputeStart(); + if (start != kNoStateId) { + SetStart(start); + } + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + Weight final = ComputeFinal(s); + SetFinal(s, final); + } + return CacheImpl<A>::Final(s); + } + + virtual void Expand(StateId s) = 0; + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<A>::InitArcIterator(s, data); + } + + protected: + virtual StateId ComputeStart() = 0; + virtual Weight ComputeFinal(StateId s) = 0; +}; + + +// Implementaion of delayed composition templated on the matchers (see +// matcher.h), composition filter (see compose-filter-inl.h) and +// the composition state table (see compose-state-table.h). +template <class M1, class M2, class F, class T> +class ComposeFstImpl : public ComposeFstImplBase<typename M1::Arc> { + typedef typename M1::FST FST1; + typedef typename M2::FST FST2; + typedef typename M1::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename F::FilterState FilterState; + typedef typename F::Matcher1 Matcher1; + typedef typename F::Matcher2 Matcher2; + + using CacheBaseImpl<CacheState<Arc> >::SetArcs; + using FstImpl<Arc>::SetType; + using FstImpl<Arc>::SetProperties; + + typedef ComposeStateTuple<StateId, FilterState> StateTuple; + + public: + ComposeFstImpl(const FST1 &fst1, const FST2 &fst2, + const ComposeFstImplOptions<M1, M2, F, T> &opts); + + ComposeFstImpl(const ComposeFstImpl<M1, M2, F, T> &impl) + : ComposeFstImplBase<Arc>(impl), + filter_(new F(*impl.filter_, true)), + matcher1_(filter_->GetMatcher1()), + matcher2_(filter_->GetMatcher2()), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()), + state_table_(new T(*impl.state_table_)), + match_type_(impl.match_type_) {} + + ~ComposeFstImpl() { + VLOG(2) << "ComposeFst(" << this + << "): End: # of visited states: " << state_table_->Size(); + + delete filter_; + delete state_table_; + } + + virtual ComposeFstImpl<M1, M2, F, T> *Copy() { + return new ComposeFstImpl<M1, M2, F, T>(*this); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && + (fst1_.Properties(kError, false) || + fst2_.Properties(kError, false) || + (matcher1_->Properties(0) & kError) || + (matcher2_->Properties(0) & kError) | + (filter_->Properties(0) & kError) || + state_table_->Error())) { + SetProperties(kError, kError); + } + return FstImpl<Arc>::Properties(mask); + } + + // Arranges it so that the first arg to OrderedExpand is the Fst + // that will be matched on. + void Expand(StateId s) { + const StateTuple &tuple = state_table_->Tuple(s); + StateId s1 = tuple.state_id1; + StateId s2 = tuple.state_id2; + filter_->SetState(s1, s2, tuple.filter_state); + if (match_type_ == MATCH_OUTPUT || + (match_type_ == MATCH_BOTH && + internal::NumArcs(fst1_, s1) > internal::NumArcs(fst2_, s2))) + OrderedExpand(s, fst1_, s1, fst2_, s2, matcher1_, false); + else + OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_, true); + } + + const FST1 &GetFst1() { return fst1_; } + const FST2 &GetFst2() { return fst2_; } + M1 *GetMatcher1() { return matcher1_; } + M2 *GetMatcher2() { return matcher2_; } + F *GetFilter() { return filter_; } + T *GetStateTable() { return state_table_; } + + private: + // This does that actual matching of labels in the composition. The + // arguments are ordered so matching is called on state 'sa' of + // 'fsta' for each arc leaving state 'sb' of 'fstb'. The 'match_input' arg + // determines whether the input or output label of arcs at 'sb' is + // the one to match on. + template <class FST, class Matcher> + void OrderedExpand(StateId s, const Fst<Arc> &, StateId sa, + const FST &fstb, StateId sb, + Matcher *matchera, bool match_input) { + matchera->SetState(sa); + + // First process non-consuming symbols (e.g., epsilons) on FSTA. + Arc loop(match_input ? 0 : kNoLabel, match_input ? kNoLabel : 0, + Weight::One(), sb); + MatchArc(s, matchera, loop, match_input); + + // Then process matches on FSTB. + for (ArcIterator<FST> iterb(fstb, sb); !iterb.Done(); iterb.Next()) + MatchArc(s, matchera, iterb.Value(), match_input); + + SetArcs(s); + } + + // Matches a single transition from 'fstb' against 'fata' at 's'. + template <class Matcher> + void MatchArc(StateId s, Matcher *matchera, + const Arc &arc, bool match_input) { + if (matchera->Find(match_input ? arc.olabel : arc.ilabel)) { + for (; !matchera->Done(); matchera->Next()) { + Arc arca = matchera->Value(); + Arc arcb = arc; + if (match_input) { + const FilterState &f = filter_->FilterArc(&arcb, &arca); + if (f != FilterState::NoState()) + AddArc(s, arcb, arca, f); + } else { + const FilterState &f = filter_->FilterArc(&arca, &arcb); + if (f != FilterState::NoState()) + AddArc(s, arca, arcb, f); + } + } + } + } + + // Add a matching transition at 's'. + void AddArc(StateId s, const Arc &arc1, const Arc &arc2, + const FilterState &f) { + StateTuple tuple(arc1.nextstate, arc2.nextstate, f); + Arc oarc(arc1.ilabel, arc2.olabel, Times(arc1.weight, arc2.weight), + state_table_->FindState(tuple)); + CacheImpl<Arc>::PushArc(s, oarc); + } + + StateId ComputeStart() { + StateId s1 = fst1_.Start(); + if (s1 == kNoStateId) + return kNoStateId; + + StateId s2 = fst2_.Start(); + if (s2 == kNoStateId) + return kNoStateId; + + const FilterState &f = filter_->Start(); + StateTuple tuple(s1, s2, f); + return state_table_->FindState(tuple); + } + + Weight ComputeFinal(StateId s) { + const StateTuple &tuple = state_table_->Tuple(s); + StateId s1 = tuple.state_id1; + Weight final1 = internal::Final(fst1_, s1); + if (final1 == Weight::Zero()) + return final1; + + StateId s2 = tuple.state_id2; + Weight final2 = internal::Final(fst2_, s2); + if (final2 == Weight::Zero()) + return final2; + + filter_->SetState(s1, s2, tuple.filter_state); + filter_->FilterFinal(&final1, &final2); + return Times(final1, final2); + } + + // Identifies and verifies the capabilities of the matcher to be used for + // composition. + void SetMatchType(); + + F *filter_; + Matcher1 *matcher1_; + Matcher2 *matcher2_; + const FST1 &fst1_; + const FST2 &fst2_; + T *state_table_; + + MatchType match_type_; + + void operator=(const ComposeFstImpl<M1, M2, F, T> &); // disallow +}; + +template <class M1, class M2, class F, class T> inline +ComposeFstImpl<M1, M2, F, T>::ComposeFstImpl( + const FST1 &fst1, const FST2 &fst2, + const ComposeFstImplOptions<M1, M2, F, T> &opts) + : ComposeFstImplBase<Arc>(fst1, fst2, opts), + filter_(opts.filter ? opts.filter : + new F(fst1, fst2, opts.matcher1, opts.matcher2)), + matcher1_(filter_->GetMatcher1()), + matcher2_(filter_->GetMatcher2()), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()), + state_table_(opts.state_table ? opts.state_table : + new T(fst1_, fst2_)) { + SetMatchType(); + if (match_type_ == MATCH_NONE) + SetProperties(kError, kError); + VLOG(2) << "ComposeFst(" << this << "): Match type: " + << (match_type_ == MATCH_OUTPUT ? "output" : + (match_type_ == MATCH_INPUT ? "input" : + (match_type_ == MATCH_BOTH ? "both" : + (match_type_ == MATCH_NONE ? "none" : "unknown")))); + + uint64 fprops1 = fst1.Properties(kFstProperties, false); + uint64 fprops2 = fst2.Properties(kFstProperties, false); + uint64 mprops1 = matcher1_->Properties(fprops1); + uint64 mprops2 = matcher2_->Properties(fprops2); + uint64 cprops = ComposeProperties(mprops1, mprops2); + SetProperties(filter_->Properties(cprops), kCopyProperties); + if (state_table_->Error()) SetProperties(kError, kError); + VLOG(2) << "ComposeFst(" << this << "): Initialized"; +} + +template <class M1, class M2, class F, class T> +void ComposeFstImpl<M1, M2, F, T>::SetMatchType() { + MatchType type1 = matcher1_->Type(false); + MatchType type2 = matcher2_->Type(false); + uint32 flags1 = matcher1_->Flags(); + uint32 flags2 = matcher2_->Flags(); + if (flags1 & flags2 & kRequireMatch) { + FSTERROR() << "ComposeFst: only one argument can require matching."; + match_type_ = MATCH_NONE; + } else if (flags1 & kRequireMatch) { + if (matcher1_->Type(true) != MATCH_OUTPUT) { + FSTERROR() << "ComposeFst: 1st argument requires matching but cannot."; + match_type_ = MATCH_NONE; + } + match_type_ = MATCH_OUTPUT; + } else if (flags2 & kRequireMatch) { + if (matcher2_->Type(true) != MATCH_INPUT) { + FSTERROR() << "ComposeFst: 2nd argument requires matching but cannot."; + match_type_ = MATCH_NONE; + } + match_type_ = MATCH_INPUT; + } else if (flags1 & flags2 & kPreferMatch && + type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) { + match_type_ = MATCH_BOTH; + } else if (flags1 & kPreferMatch && type1 == MATCH_OUTPUT) { + match_type_ = MATCH_OUTPUT; + } else if (flags2 & kPreferMatch && type2 == MATCH_INPUT) { + match_type_ = MATCH_INPUT; + } else if (type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) { + match_type_ = MATCH_BOTH; + } else if (type1 == MATCH_OUTPUT) { + match_type_ = MATCH_OUTPUT; + } else if (type2 == MATCH_INPUT) { + match_type_ = MATCH_INPUT; + } else if (flags1 & kPreferMatch && matcher1_->Type(true) == MATCH_OUTPUT) { + match_type_ = MATCH_OUTPUT; + } else if (flags2 & kPreferMatch && matcher2_->Type(true) == MATCH_INPUT) { + match_type_ = MATCH_INPUT; + } else if (matcher1_->Type(true) == MATCH_OUTPUT) { + match_type_ = MATCH_OUTPUT; + } else if (matcher2_->Type(true) == MATCH_INPUT) { + match_type_ = MATCH_INPUT; + } else { + FSTERROR() << "ComposeFst: 1st argument cannot match on output labels " + << "and 2nd argument cannot match on input labels (sort?)."; + match_type_ = MATCH_NONE; + } +} + + +// Computes the composition of two transducers. This version is a +// delayed Fst. If FST1 transduces string x to y with weight a and FST2 +// transduces y to z with weight b, then their composition transduces +// string x to z with weight Times(x, z). +// +// The output labels of the first transducer or the input labels of +// the second transducer must be sorted (with the default matcher). +// The weights need to form a commutative semiring (valid for +// TropicalWeight and LogWeight). +// +// Complexity: +// Assuming the first FST is unsorted and the second is sorted: +// - Time: O(v1 v2 d1 (log d2 + m2)), +// - Space: O(v1 v2) +// where vi = # of states visited, di = maximum out-degree, and mi the +// maximum multiplicity of the states visited for the ith +// FST. Constant time and space to visit an input state or arc is +// assumed and exclusive of caching. +// +// Caveats: +// - ComposeFst does not trim its output (since it is a delayed operation). +// - The efficiency of composition can be strongly affected by several factors: +// - the choice of which tnansducer is sorted - prefer sorting the FST +// that has the greater average out-degree. +// - the amount of non-determinism +// - the presence and location of epsilon transitions - avoid epsilon +// transitions on the output side of the first transducer or +// the input side of the second transducer or prefer placing +// them later in a path since they delay matching and can +// introduce non-coaccessible states and transitions. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class ComposeFst : public ImplToFst< ComposeFstImplBase<A> > { + public: + friend class ArcIterator< ComposeFst<A> >; + friend class StateIterator< ComposeFst<A> >; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef ComposeFstImplBase<A> Impl; + + using ImplToFst<Impl>::SetImpl; + + // Compose specifying only caching options. + ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2, + const CacheOptions &opts = CacheOptions()) + : ImplToFst<Impl>(CreateBase(fst1, fst2, opts)) {} + + // Compose specifying one shared matcher type M. Requires input + // Fsts and matcher FST type (M::FST) be Fst<A>. Recommended for + // best code-sharing and matcher compatiblity. + template <class M, class F, class T> + ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2, + const ComposeFstOptions<A, M, F, T> &opts) + : ImplToFst<Impl>(CreateBase1(fst1, fst2, opts)) {} + + // Compose specifying two matcher types M1 and M2. Requires input + // Fsts (of the same Arc type but o.w. arbitrary) match the + // corresponding matcher FST types (M1::FST, M2::FST). Recommended + // only for advanced use in demanding or specialized applications + // due to potential code bloat and matcher incompatibilities. + template <class M1, class M2, class F, class T> + ComposeFst(const typename M1::FST &fst1, const typename M2::FST &fst2, + const ComposeFstImplOptions<M1, M2, F, T> &opts) + : ImplToFst<Impl>(CreateBase2(fst1, fst2, opts)) {} + + // See Fst<>::Copy() for doc. + ComposeFst(const ComposeFst<A> &fst, bool safe = false) { + if (safe) + SetImpl(fst.GetImpl()->Copy()); + else + SetImpl(fst.GetImpl(), false); + } + + // Get a copy of this ComposeFst. See Fst<>::Copy() for further doc. + virtual ComposeFst<A> *Copy(bool safe = false) const { + return new ComposeFst<A>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + protected: + ComposeFst() {} + + // Create compose implementation specifying two matcher types. + template <class M1, class M2, class F, class T> + static Impl *CreateBase2( + const typename M1::FST &fst1, const typename M2::FST &fst2, + const ComposeFstImplOptions<M1, M2, F, T> &opts) { + Impl *impl = new ComposeFstImpl<M1, M2, F, T>(fst1, fst2, opts); + if (!(Weight::Properties() & kCommutative)) { + int64 props1 = fst1.Properties(kUnweighted, true); + int64 props2 = fst2.Properties(kUnweighted, true); + if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) { + FSTERROR() << "ComposeFst: Weights must be a commutative semiring: " + << Weight::Type(); + impl->SetProperties(kError, kError); + } + } + return impl; + } + + // Create compose implementation specifying one matcher type. + // Requires input Fsts and matcher FST type (M::FST) be Fst<A> + template <class M, class F, class T> + static Impl *CreateBase1(const Fst<A> &fst1, const Fst<A> &fst2, + const ComposeFstOptions<A, M, F, T> &opts) { + ComposeFstImplOptions<M, M, F, T> nopts(opts, opts.matcher1, opts.matcher2, + opts.filter, opts.state_table); + return CreateBase2(fst1, fst2, nopts); + } + + // Create compose implementation specifying no matcher type. + static Impl *CreateBase(const Fst<A> &fst1, const Fst<A> &fst2, + const CacheOptions &opts) { + switch (LookAheadMatchType(fst1, fst2)) { // Check for lookahead matchers + default: + case MATCH_NONE: { // Default composition (no look-ahead) + VLOG(2) << "ComposeFst: Default composition (no look-ahead)"; + ComposeFstOptions<Arc> nopts(opts); + return CreateBase1(fst1, fst2, nopts); + } + case MATCH_OUTPUT: { // Lookahead on fst1 + VLOG(2) << "ComposeFst: Lookahead on fst1"; + typedef typename DefaultLookAhead<Arc, MATCH_OUTPUT>::FstMatcher M; + typedef typename DefaultLookAhead<Arc, MATCH_OUTPUT>::ComposeFilter F; + ComposeFstOptions<Arc, M, F> nopts(opts); + return CreateBase1(fst1, fst2, nopts); + } + case MATCH_INPUT: { // Lookahead on fst2 + VLOG(2) << "ComposeFst: Lookahead on fst2"; + typedef typename DefaultLookAhead<Arc, MATCH_INPUT>::FstMatcher M; + typedef typename DefaultLookAhead<Arc, MATCH_INPUT>::ComposeFilter F; + ComposeFstOptions<Arc, M, F> nopts(opts); + return CreateBase1(fst1, fst2, nopts); + } + } + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const ComposeFst<A> &fst); // disallow +}; + + +// Specialization for ComposeFst. +template<class A> +class StateIterator< ComposeFst<A> > + : public CacheStateIterator< ComposeFst<A> > { + public: + explicit StateIterator(const ComposeFst<A> &fst) + : CacheStateIterator< ComposeFst<A> >(fst, fst.GetImpl()) {} +}; + + +// Specialization for ComposeFst. +template <class A> +class ArcIterator< ComposeFst<A> > + : public CacheArcIterator< ComposeFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const ComposeFst<A> &fst, StateId s) + : CacheArcIterator< ComposeFst<A> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +template <class A> inline +void ComposeFst<A>::InitStateIterator(StateIteratorData<A> *data) const { + data->base = new StateIterator< ComposeFst<A> >(*this); +} + +// Useful alias when using StdArc. +typedef ComposeFst<StdArc> StdComposeFst; + +enum ComposeFilter { AUTO_FILTER, SEQUENCE_FILTER, ALT_SEQUENCE_FILTER, + MATCH_FILTER }; + +struct ComposeOptions { + bool connect; // Connect output + ComposeFilter filter_type; // Which pre-defined filter to use + + ComposeOptions(bool c, ComposeFilter ft = AUTO_FILTER) + : connect(c), filter_type(ft) {} + ComposeOptions() : connect(true), filter_type(AUTO_FILTER) {} +}; + +// Computes the composition of two transducers. This version writes +// the composed FST into a MurableFst. If FST1 transduces string x to +// y with weight a and FST2 transduces y to z with weight b, then +// their composition transduces string x to z with weight +// Times(x, z). +// +// The output labels of the first transducer or the input labels of +// the second transducer must be sorted. The weights need to form a +// commutative semiring (valid for TropicalWeight and LogWeight). +// +// Complexity: +// Assuming the first FST is unsorted and the second is sorted: +// - Time: O(V1 V2 D1 (log D2 + M2)), +// - Space: O(V1 V2 D1 M2) +// where Vi = # of states, Di = maximum out-degree, and Mi is +// the maximum multiplicity for the ith FST. +// +// Caveats: +// - Compose trims its output. +// - The efficiency of composition can be strongly affected by several factors: +// - the choice of which tnansducer is sorted - prefer sorting the FST +// that has the greater average out-degree. +// - the amount of non-determinism +// - the presence and location of epsilon transitions - avoid epsilon +// transitions on the output side of the first transducer or +// the input side of the second transducer or prefer placing +// them later in a path since they delay matching and can +// introduce non-coaccessible states and transitions. +template<class Arc> +void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2, + MutableFst<Arc> *ofst, + const ComposeOptions &opts = ComposeOptions()) { + typedef Matcher< Fst<Arc> > M; + + if (opts.filter_type == AUTO_FILTER) { + CacheOptions nopts; + nopts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts); + } else if (opts.filter_type == SEQUENCE_FILTER) { + ComposeFstOptions<Arc> copts; + copts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); + } else if (opts.filter_type == ALT_SEQUENCE_FILTER) { + ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M> > copts; + copts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); + } else if (opts.filter_type == MATCH_FILTER) { + ComposeFstOptions<Arc, M, MatchComposeFilter<M> > copts; + copts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); + } + + if (opts.connect) + Connect(ofst); +} + +} // namespace fst + +#endif // FST_LIB_COMPOSE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/concat.h b/kaldi_io/src/tools/openfst/include/fst/concat.h new file mode 100644 index 0000000..8500d50 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/concat.h @@ -0,0 +1,246 @@ +// concat.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Functions and classes to compute the concat of two FSTs. + +#ifndef FST_LIB_CONCAT_H__ +#define FST_LIB_CONCAT_H__ + +#include <vector> +using std::vector; +#include <algorithm> + +#include <fst/mutable-fst.h> +#include <fst/rational.h> + + +namespace fst { + +// Computes the concatenation (product) of two FSTs. If FST1 +// transduces string x to y with weight a and FST2 transduces string w +// to v with weight b, then their concatenation transduces string xw +// to yv with Times(a, b). +// +// This version modifies its MutableFst argument (in first position). +// +// Complexity: +// - Time: O(V1 + V2 + E2) +// - Space: O(V1 + V2 + E2) +// where Vi = # of states and Ei = # of arcs of the ith FST. +// +template<class Arc> +void Concat(MutableFst<Arc> *fst1, const Fst<Arc> &fst2) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + // TODO(riley): restore when voice actions issues fixed + // Check that the symbol table are compatible + if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) { + LOG(ERROR) << "Concat: input/output symbol tables of 1st argument " + << "do not match input/output symbol tables of 2nd argument"; + // fst1->SetProperties(kError, kError); + // return; + } + + uint64 props1 = fst1->Properties(kFstProperties, false); + uint64 props2 = fst2.Properties(kFstProperties, false); + + StateId start1 = fst1->Start(); + if (start1 == kNoStateId) { + if (props2 & kError) fst1->SetProperties(kError, kError); + return; + } + + StateId numstates1 = fst1->NumStates(); + if (fst2.Properties(kExpanded, false)) + fst1->ReserveStates(numstates1 + CountStates(fst2)); + + for (StateIterator< Fst<Arc> > siter2(fst2); + !siter2.Done(); + siter2.Next()) { + StateId s1 = fst1->AddState(); + StateId s2 = siter2.Value(); + fst1->SetFinal(s1, fst2.Final(s2)); + fst1->ReserveArcs(s1, fst2.NumArcs(s2)); + for (ArcIterator< Fst<Arc> > aiter(fst2, s2); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + arc.nextstate += numstates1; + fst1->AddArc(s1, arc); + } + } + + StateId start2 = fst2.Start(); + for (StateId s1 = 0; s1 < numstates1; ++s1) { + Weight final = fst1->Final(s1); + if (final != Weight::Zero()) { + fst1->SetFinal(s1, Weight::Zero()); + if (start2 != kNoStateId) + fst1->AddArc(s1, Arc(0, 0, final, start2 + numstates1)); + } + } + if (start2 != kNoStateId) + fst1->SetProperties(ConcatProperties(props1, props2), kFstProperties); +} + +// Computes the concatentation of two FSTs. This version modifies its +// MutableFst argument (in second position). +// +// Complexity: +// - Time: O(V1 + E1) +// - Space: O(V1 + E1) +// where Vi = # of states and Ei = # of arcs of the ith FST. +// +template<class Arc> +void Concat(const Fst<Arc> &fst1, MutableFst<Arc> *fst2) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + // Check that the symbol table are compatible + if (!CompatSymbols(fst1.InputSymbols(), fst2->InputSymbols()) || + !CompatSymbols(fst1.OutputSymbols(), fst2->OutputSymbols())) { + LOG(ERROR) << "Concat: input/output symbol tables of 1st argument " + << "do not match input/output symbol tables of 2nd argument"; + // fst2->SetProperties(kError, kError); + // return; + } + + uint64 props1 = fst1.Properties(kFstProperties, false); + uint64 props2 = fst2->Properties(kFstProperties, false); + + StateId start2 = fst2->Start(); + if (start2 == kNoStateId) { + if (props1 & kError) fst2->SetProperties(kError, kError); + return; + } + + StateId numstates2 = fst2->NumStates(); + if (fst1.Properties(kExpanded, false)) + fst2->ReserveStates(numstates2 + CountStates(fst1)); + + for (StateIterator< Fst<Arc> > siter(fst1); + !siter.Done(); + siter.Next()) { + StateId s1 = siter.Value(); + StateId s2 = fst2->AddState(); + Weight final = fst1.Final(s1); + fst2->ReserveArcs(s2, fst1.NumArcs(s1) + (final != Weight::Zero() ? 1 : 0)); + if (final != Weight::Zero()) + fst2->AddArc(s2, Arc(0, 0, final, start2)); + for (ArcIterator< Fst<Arc> > aiter(fst1, s1); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + arc.nextstate += numstates2; + fst2->AddArc(s2, arc); + } + } + StateId start1 = fst1.Start(); + fst2->SetStart(start1 == kNoStateId ? fst2->AddState() : start1 + numstates2); + if (start1 != kNoStateId) + fst2->SetProperties(ConcatProperties(props1, props2), kFstProperties); +} + + +// Computes the concatentation of two FSTs. This version modifies its +// RationalFst input (in first position). +template<class Arc> +void Concat(RationalFst<Arc> *fst1, const Fst<Arc> &fst2) { + fst1->GetImpl()->AddConcat(fst2, true); +} + +// Computes the concatentation of two FSTs. This version modifies its +// RationalFst input (in second position). +template<class Arc> +void Concat(const Fst<Arc> &fst1, RationalFst<Arc> *fst2) { + fst2->GetImpl()->AddConcat(fst1, false); +} + +typedef RationalFstOptions ConcatFstOptions; + + +// Computes the concatenation (product) of two FSTs; this version is a +// delayed Fst. If FST1 transduces string x to y with weight a and FST2 +// transduces string w to v with weight b, then their concatenation +// transduces string xw to yv with Times(a, b). +// +// Complexity: +// - Time: O(v1 + e1 + v2 + e2), +// - Space: O(v1 + v2) +// where vi = # of states visited and ei = # of arcs visited of the +// ith FST. Constant time and space to visit an input state or arc is +// assumed and exclusive of caching. +template <class A> +class ConcatFst : public RationalFst<A> { + public: + using ImplToFst< RationalFstImpl<A> >::GetImpl; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + ConcatFst(const Fst<A> &fst1, const Fst<A> &fst2) { + GetImpl()->InitConcat(fst1, fst2); + } + + ConcatFst(const Fst<A> &fst1, const Fst<A> &fst2, + const ConcatFstOptions &opts) : RationalFst<A>(opts) { + GetImpl()->InitConcat(fst1, fst2); + } + + // See Fst<>::Copy() for doc. + ConcatFst(const ConcatFst<A> &fst, bool safe = false) + : RationalFst<A>(fst, safe) {} + + // Get a copy of this ConcatFst. See Fst<>::Copy() for further doc. + virtual ConcatFst<A> *Copy(bool safe = false) const { + return new ConcatFst<A>(*this, safe); + } +}; + + +// Specialization for ConcatFst. +template <class A> +class StateIterator< ConcatFst<A> > : public StateIterator< RationalFst<A> > { + public: + explicit StateIterator(const ConcatFst<A> &fst) + : StateIterator< RationalFst<A> >(fst) {} +}; + + +// Specialization for ConcatFst. +template <class A> +class ArcIterator< ConcatFst<A> > : public ArcIterator< RationalFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const ConcatFst<A> &fst, StateId s) + : ArcIterator< RationalFst<A> >(fst, s) {} +}; + + +// Useful alias when using StdArc. +typedef ConcatFst<StdArc> StdConcatFst; + +} // namespace fst + +#endif // FST_LIB_CONCAT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/config.h b/kaldi_io/src/tools/openfst/include/fst/config.h new file mode 100644 index 0000000..046b49c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/config.h @@ -0,0 +1,12 @@ +/* src/include/fst/config.h. Generated from config.h.in by configure. */ +// OpenFst config file + +/* Define to 1 if you have the ICU library. */ +/* #undef HAVE_ICU */ + +/* Define to 1 if the system has the type `std::tr1::hash<long long + unsigned>'. */ +#define HAVE_STD__TR1__HASH_LONG_LONG_UNSIGNED_ 1 + +/* Define to 1 if the system has the type `__gnu_cxx::slist<int>'. */ +#define HAVE___GNU_CXX__SLIST_INT_ 1 diff --git a/kaldi_io/src/tools/openfst/include/fst/connect.h b/kaldi_io/src/tools/openfst/include/fst/connect.h new file mode 100644 index 0000000..427808c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/connect.h @@ -0,0 +1,319 @@ +// connect.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes and functions to remove unsuccessful paths from an Fst. + +#ifndef FST_LIB_CONNECT_H__ +#define FST_LIB_CONNECT_H__ + +#include <vector> +using std::vector; + +#include <fst/dfs-visit.h> +#include <fst/union-find.h> +#include <fst/mutable-fst.h> + + +namespace fst { + +// Finds and returns connected components. Use with Visit(). +template <class A> +class CcVisitor { + public: + typedef A Arc; + typedef typename Arc::Weight Weight; + typedef typename A::StateId StateId; + + // cc[i]: connected component number for state i. + CcVisitor(vector<StateId> *cc) + : comps_(new UnionFind<StateId>(0, kNoStateId)), + cc_(cc), + nstates_(0) { } + + // comps: connected components equiv classes. + CcVisitor(UnionFind<StateId> *comps) + : comps_(comps), + cc_(0), + nstates_(0) { } + + ~CcVisitor() { + if (cc_) // own comps_? + delete comps_; + } + + void InitVisit(const Fst<A> &fst) { } + + bool InitState(StateId s, StateId root) { + ++nstates_; + if (comps_->FindSet(s) == kNoStateId) + comps_->MakeSet(s); + return true; + } + + bool WhiteArc(StateId s, const A &arc) { + comps_->MakeSet(arc.nextstate); + comps_->Union(s, arc.nextstate); + return true; + } + + bool GreyArc(StateId s, const A &arc) { + comps_->Union(s, arc.nextstate); + return true; + } + + bool BlackArc(StateId s, const A &arc) { + comps_->Union(s, arc.nextstate); + return true; + } + + void FinishState(StateId s) { } + + void FinishVisit() { + if (cc_) + GetCcVector(cc_); + } + + // cc[i]: connected component number for state i. + // Returns number of components. + int GetCcVector(vector<StateId> *cc) { + cc->clear(); + cc->resize(nstates_, kNoStateId); + StateId ncomp = 0; + for (StateId i = 0; i < nstates_; ++i) { + StateId rep = comps_->FindSet(i); + StateId &comp = (*cc)[rep]; + if (comp == kNoStateId) { + comp = ncomp; + ++ncomp; + } + (*cc)[i] = comp; + } + return ncomp; + } + + private: + UnionFind<StateId> *comps_; // Components + vector<StateId> *cc_; // State's cc number + StateId nstates_; // State count +}; + + +// Finds and returns strongly-connected components, accessible and +// coaccessible states and related properties. Uses Tarjan's single +// DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer +// Algorithms", 189pp). Use with DfsVisit(); +template <class A> +class SccVisitor { + public: + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + // scc[i]: strongly-connected component number for state i. + // SCC numbers will be in topological order for acyclic input. + // access[i]: accessibility of state i. + // coaccess[i]: coaccessibility of state i. + // Any of above can be NULL. + // props: related property bits (cyclicity, initial cyclicity, + // accessibility, coaccessibility) set/cleared (o.w. unchanged). + SccVisitor(vector<StateId> *scc, vector<bool> *access, + vector<bool> *coaccess, uint64 *props) + : scc_(scc), access_(access), coaccess_(coaccess), props_(props) {} + SccVisitor(uint64 *props) + : scc_(0), access_(0), coaccess_(0), props_(props) {} + + void InitVisit(const Fst<A> &fst); + + bool InitState(StateId s, StateId root); + + bool TreeArc(StateId s, const A &arc) { return true; } + + bool BackArc(StateId s, const A &arc) { + StateId t = arc.nextstate; + if ((*dfnumber_)[t] < (*lowlink_)[s]) + (*lowlink_)[s] = (*dfnumber_)[t]; + if ((*coaccess_)[t]) + (*coaccess_)[s] = true; + *props_ |= kCyclic; + *props_ &= ~kAcyclic; + if (arc.nextstate == start_) { + *props_ |= kInitialCyclic; + *props_ &= ~kInitialAcyclic; + } + return true; + } + + bool ForwardOrCrossArc(StateId s, const A &arc) { + StateId t = arc.nextstate; + if ((*dfnumber_)[t] < (*dfnumber_)[s] /* cross edge */ && + (*onstack_)[t] && (*dfnumber_)[t] < (*lowlink_)[s]) + (*lowlink_)[s] = (*dfnumber_)[t]; + if ((*coaccess_)[t]) + (*coaccess_)[s] = true; + return true; + } + + void FinishState(StateId s, StateId p, const A *); + + void FinishVisit() { + // Numbers SCC's in topological order when acyclic. + if (scc_) + for (StateId i = 0; i < scc_->size(); ++i) + (*scc_)[i] = nscc_ - 1 - (*scc_)[i]; + if (coaccess_internal_) + delete coaccess_; + delete dfnumber_; + delete lowlink_; + delete onstack_; + delete scc_stack_; + } + + private: + vector<StateId> *scc_; // State's scc number + vector<bool> *access_; // State's accessibility + vector<bool> *coaccess_; // State's coaccessibility + uint64 *props_; + const Fst<A> *fst_; + StateId start_; + StateId nstates_; // State count + StateId nscc_; // SCC count + bool coaccess_internal_; + vector<StateId> *dfnumber_; // state discovery times + vector<StateId> *lowlink_; // lowlink[s] == dfnumber[s] => SCC root + vector<bool> *onstack_; // is a state on the SCC stack + vector<StateId> *scc_stack_; // SCC stack (w/ random access) +}; + +template <class A> inline +void SccVisitor<A>::InitVisit(const Fst<A> &fst) { + if (scc_) + scc_->clear(); + if (access_) + access_->clear(); + if (coaccess_) { + coaccess_->clear(); + coaccess_internal_ = false; + } else { + coaccess_ = new vector<bool>; + coaccess_internal_ = true; + } + *props_ |= kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible; + *props_ &= ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible); + fst_ = &fst; + start_ = fst.Start(); + nstates_ = 0; + nscc_ = 0; + dfnumber_ = new vector<StateId>; + lowlink_ = new vector<StateId>; + onstack_ = new vector<bool>; + scc_stack_ = new vector<StateId>; +} + +template <class A> inline +bool SccVisitor<A>::InitState(StateId s, StateId root) { + scc_stack_->push_back(s); + while (dfnumber_->size() <= s) { + if (scc_) + scc_->push_back(-1); + if (access_) + access_->push_back(false); + coaccess_->push_back(false); + dfnumber_->push_back(-1); + lowlink_->push_back(-1); + onstack_->push_back(false); + } + (*dfnumber_)[s] = nstates_; + (*lowlink_)[s] = nstates_; + (*onstack_)[s] = true; + if (root == start_) { + if (access_) + (*access_)[s] = true; + } else { + if (access_) + (*access_)[s] = false; + *props_ |= kNotAccessible; + *props_ &= ~kAccessible; + } + ++nstates_; + return true; +} + +template <class A> inline +void SccVisitor<A>::FinishState(StateId s, StateId p, const A *) { + if (fst_->Final(s) != Weight::Zero()) + (*coaccess_)[s] = true; + if ((*dfnumber_)[s] == (*lowlink_)[s]) { // root of new SCC + bool scc_coaccess = false; + size_t i = scc_stack_->size(); + StateId t; + do { + t = (*scc_stack_)[--i]; + if ((*coaccess_)[t]) + scc_coaccess = true; + } while (s != t); + do { + t = scc_stack_->back(); + if (scc_) + (*scc_)[t] = nscc_; + if (scc_coaccess) + (*coaccess_)[t] = true; + (*onstack_)[t] = false; + scc_stack_->pop_back(); + } while (s != t); + if (!scc_coaccess) { + *props_ |= kNotCoAccessible; + *props_ &= ~kCoAccessible; + } + ++nscc_; + } + if (p != kNoStateId) { + if ((*coaccess_)[s]) + (*coaccess_)[p] = true; + if ((*lowlink_)[s] < (*lowlink_)[p]) + (*lowlink_)[p] = (*lowlink_)[s]; + } +} + + +// Trims an FST, removing states and arcs that are not on successful +// paths. This version modifies its input. +// +// Complexity: +// - Time: O(V + E) +// - Space: O(V + E) +// where V = # of states and E = # of arcs. +template<class Arc> +void Connect(MutableFst<Arc> *fst) { + typedef typename Arc::StateId StateId; + + vector<bool> access; + vector<bool> coaccess; + uint64 props = 0; + SccVisitor<Arc> scc_visitor(0, &access, &coaccess, &props); + DfsVisit(*fst, &scc_visitor); + vector<StateId> dstates; + for (StateId s = 0; s < access.size(); ++s) + if (!access[s] || !coaccess[s]) + dstates.push_back(s); + fst->DeleteStates(dstates); + fst->SetProperties(kAccessible | kCoAccessible, kAccessible | kCoAccessible); +} + +} // namespace fst + +#endif // FST_LIB_CONNECT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/const-fst.h b/kaldi_io/src/tools/openfst/include/fst/const-fst.h new file mode 100644 index 0000000..e6e85af --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/const-fst.h @@ -0,0 +1,497 @@ +// const-fst.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Simple concrete immutable FST whose states and arcs are each stored +// in single arrays. + +#ifndef FST_LIB_CONST_FST_H__ +#define FST_LIB_CONST_FST_H__ + +#include <string> +#include <vector> +using std::vector; + +#include <fst/expanded-fst.h> +#include <fst/fst-decl.h> // For optional argument declarations +#include <fst/mapped-file.h> +#include <fst/test-properties.h> +#include <fst/util.h> + + +namespace fst { + +template <class A, class U> class ConstFst; +template <class F, class G> void Cast(const F &, G *); + +// States and arcs each implemented by single arrays, templated on the +// Arc definition. The unsigned type U is used to represent indices into +// the arc array. +template <class A, class U> +class ConstFstImpl : public FstImpl<A> { + public: + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef U Unsigned; + + ConstFstImpl() + : states_region_(0), arcs_region_(0), states_(0), arcs_(0), nstates_(0), + narcs_(0), start_(kNoStateId) { + string type = "const"; + if (sizeof(U) != sizeof(uint32)) { + string size; + Int64ToStr(8 * sizeof(U), &size); + type += size; + } + SetType(type); + SetProperties(kNullProperties | kStaticProperties); + } + + explicit ConstFstImpl(const Fst<A> &fst); + + ~ConstFstImpl() { + delete arcs_region_; + delete states_region_; + } + + StateId Start() const { return start_; } + + Weight Final(StateId s) const { return states_[s].final; } + + StateId NumStates() const { return nstates_; } + + size_t NumArcs(StateId s) const { return states_[s].narcs; } + + size_t NumInputEpsilons(StateId s) const { return states_[s].niepsilons; } + + size_t NumOutputEpsilons(StateId s) const { return states_[s].noepsilons; } + + static ConstFstImpl<A, U> *Read(istream &strm, const FstReadOptions &opts); + + A *Arcs(StateId s) { return arcs_ + states_[s].pos; } + + // Provide information needed for generic state iterator + void InitStateIterator(StateIteratorData<A> *data) const { + data->base = 0; + data->nstates = nstates_; + } + + // Provide information needed for the generic arc iterator + void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + data->base = 0; + data->arcs = arcs_ + states_[s].pos; + data->narcs = states_[s].narcs; + data->ref_count = 0; + } + + private: + friend class ConstFst<A, U>; // Allow finding narcs_, nstates_ during Write + + // States implemented by array *states_ below, arcs by (single) *arcs_. + struct State { + Weight final; // Final weight + Unsigned pos; // Start of state's arcs in *arcs_ + Unsigned narcs; // Number of arcs (per state) + Unsigned niepsilons; // # of input epsilons + Unsigned noepsilons; // # of output epsilons + State() : final(Weight::Zero()), niepsilons(0), noepsilons(0) {} + }; + + // Properties always true of this Fst class + static const uint64 kStaticProperties = kExpanded; + // Current unaligned file format version. The unaligned version was added and + // made the default since the aligned version does not work on pipes. + static const int kFileVersion = 2; + // Current aligned file format version + static const int kAlignedFileVersion = 1; + // Minimum file format version supported + static const int kMinFileVersion = 1; + + MappedFile *states_region_; // Mapped file for states + MappedFile *arcs_region_; // Mapped file for arcs + State *states_; // States represenation + A *arcs_; // Arcs representation + StateId nstates_; // Number of states + size_t narcs_; // Number of arcs (per FST) + StateId start_; // Initial state + + DISALLOW_COPY_AND_ASSIGN(ConstFstImpl); +}; + +template <class A, class U> +const uint64 ConstFstImpl<A, U>::kStaticProperties; +template <class A, class U> +const int ConstFstImpl<A, U>::kFileVersion; +template <class A, class U> +const int ConstFstImpl<A, U>::kAlignedFileVersion; +template <class A, class U> +const int ConstFstImpl<A, U>::kMinFileVersion; + + +template<class A, class U> +ConstFstImpl<A, U>::ConstFstImpl(const Fst<A> &fst) : nstates_(0), narcs_(0) { + string type = "const"; + if (sizeof(U) != sizeof(uint32)) { + string size; + Int64ToStr(sizeof(U) * 8, &size); + type += size; + } + SetType(type); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + start_ = fst.Start(); + + // Count # of states and arcs. + for (StateIterator< Fst<A> > siter(fst); + !siter.Done(); + siter.Next()) { + ++nstates_; + StateId s = siter.Value(); + for (ArcIterator< Fst<A> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) + ++narcs_; + } + states_region_ = MappedFile::Allocate(nstates_ * sizeof(*states_)); + arcs_region_ = MappedFile::Allocate(narcs_ * sizeof(*arcs_)); + states_ = reinterpret_cast<State*>(states_region_->mutable_data()); + arcs_ = reinterpret_cast<A*>(arcs_region_->mutable_data()); + size_t pos = 0; + for (StateId s = 0; s < nstates_; ++s) { + states_[s].final = fst.Final(s); + states_[s].pos = pos; + states_[s].narcs = 0; + states_[s].niepsilons = 0; + states_[s].noepsilons = 0; + for (ArcIterator< Fst<A> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + ++states_[s].narcs; + if (arc.ilabel == 0) + ++states_[s].niepsilons; + if (arc.olabel == 0) + ++states_[s].noepsilons; + arcs_[pos++] = arc; + } + } + SetProperties(fst.Properties(kCopyProperties, true) | kStaticProperties); +} + + +template<class A, class U> +ConstFstImpl<A, U> *ConstFstImpl<A, U>::Read(istream &strm, + const FstReadOptions &opts) { + ConstFstImpl<A, U> *impl = new ConstFstImpl<A, U>; + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) { + delete impl; + return 0; + } + impl->start_ = hdr.Start(); + impl->nstates_ = hdr.NumStates(); + impl->narcs_ = hdr.NumArcs(); + + // Ensures compatibility + if (hdr.Version() == kAlignedFileVersion) + hdr.SetFlags(hdr.GetFlags() | FstHeader::IS_ALIGNED); + + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source; + delete impl; + return 0; + } + + size_t b = impl->nstates_ * sizeof(typename ConstFstImpl<A, U>::State); + impl->states_region_ = MappedFile::Map(&strm, opts, b); + if (!strm || impl->states_region_ == NULL) { + LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source; + delete impl; + return 0; + } + impl->states_ = reinterpret_cast<State*>( + impl->states_region_->mutable_data()); + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source; + delete impl; + return 0; + } + + b = impl->narcs_ * sizeof(A); + impl->arcs_region_ = MappedFile::Map(&strm, opts, b); + if (!strm || impl->arcs_region_ == NULL) { + LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source; + delete impl; + return 0; + } + impl->arcs_ = reinterpret_cast<A*>(impl->arcs_region_->mutable_data()); + return impl; +} + +// Simple concrete immutable FST. This class attaches interface to +// implementation and handles reference counting, delegating most +// methods to ImplToExpandedFst. The unsigned type U is used to +// represent indices into the arc array (uint32 by default, declared +// in fst-decl.h). +template <class A, class U> +class ConstFst : public ImplToExpandedFst< ConstFstImpl<A, U> > { + public: + friend class StateIterator< ConstFst<A, U> >; + friend class ArcIterator< ConstFst<A, U> >; + template <class F, class G> void friend Cast(const F &, G *); + + typedef A Arc; + typedef typename A::StateId StateId; + typedef ConstFstImpl<A, U> Impl; + typedef U Unsigned; + + ConstFst() : ImplToExpandedFst<Impl>(new Impl()) {} + + explicit ConstFst(const Fst<A> &fst) + : ImplToExpandedFst<Impl>(new Impl(fst)) {} + + ConstFst(const ConstFst<A, U> &fst) : ImplToExpandedFst<Impl>(fst) {} + + // Get a copy of this ConstFst. See Fst<>::Copy() for further doc. + virtual ConstFst<A, U> *Copy(bool safe = false) const { + return new ConstFst<A, U>(*this); + } + + // Read a ConstFst from an input stream; return NULL on error + static ConstFst<A, U> *Read(istream &strm, const FstReadOptions &opts) { + Impl* impl = Impl::Read(strm, opts); + return impl ? new ConstFst<A, U>(impl) : 0; + } + + // Read a ConstFst from a file; return NULL on error + // Empty filename reads from standard input + static ConstFst<A, U> *Read(const string &filename) { + Impl* impl = ImplToExpandedFst<Impl>::Read(filename); + return impl ? new ConstFst<A, U>(impl) : 0; + } + + virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { + return WriteFst(*this, strm, opts); + } + + virtual bool Write(const string &filename) const { + return Fst<A>::WriteFile(filename); + } + + template <class F> + static bool WriteFst(const F &fst, ostream &strm, + const FstWriteOptions &opts); + + virtual void InitStateIterator(StateIteratorData<Arc> *data) const { + GetImpl()->InitStateIterator(data); + } + + virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + private: + explicit ConstFst(Impl *impl) : ImplToExpandedFst<Impl>(impl) {} + + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl, ExpandedFst<A> >::GetImpl(); } + + void SetImpl(Impl *impl, bool own_impl = true) { + ImplToFst< Impl, ExpandedFst<A> >::SetImpl(impl, own_impl); + } + + // Use overloading to extract the type of the argument. + static Impl* GetImplIfConstFst(const ConstFst &const_fst) { + return const_fst.GetImpl(); + } + + // Note that this does not give privileged treatment to subtypes of ConstFst. + template<typename NonConstFst> + static Impl* GetImplIfConstFst(const NonConstFst& fst) { + return NULL; + } + + void operator=(const ConstFst<A, U> &fst); // disallow +}; + +// Writes Fst in Const format, potentially with a pass over the machine +// before writing to compute number of states and arcs. +// +template <class A, class U> +template <class F> +bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm, + const FstWriteOptions &opts) { + int file_version = opts.align ? ConstFstImpl<A, U>::kAlignedFileVersion : + ConstFstImpl<A, U>::kFileVersion; + size_t num_arcs = -1, num_states = -1; + size_t start_offset = 0; + bool update_header = true; + if (Impl* impl = GetImplIfConstFst(fst)) { + num_arcs = impl->narcs_; + num_states = impl->nstates_; + update_header = false; + } else if ((start_offset = strm.tellp()) == -1) { + // precompute values needed for header when we cannot seek to rewrite it. + num_arcs = 0; + num_states = 0; + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + num_arcs += fst.NumArcs(siter.Value()); + ++num_states; + } + update_header = false; + } + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(num_states); + hdr.SetNumArcs(num_arcs); + string type = "const"; + if (sizeof(U) != sizeof(uint32)) { + string size; + Int64ToStr(8 * sizeof(U), &size); + type += size; + } + uint64 properties = fst.Properties(kCopyProperties, true) | + ConstFstImpl<A, U>::kStaticProperties; + FstImpl<A>::WriteFstHeader(fst, strm, opts, file_version, type, properties, + &hdr); + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "Could not align file during write after header"; + return false; + } + size_t pos = 0, states = 0; + typename ConstFstImpl<A, U>::State state; + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + state.final = fst.Final(siter.Value()); + state.pos = pos; + state.narcs = fst.NumArcs(siter.Value()); + state.niepsilons = fst.NumInputEpsilons(siter.Value()); + state.noepsilons = fst.NumOutputEpsilons(siter.Value()); + strm.write(reinterpret_cast<const char *>(&state), sizeof(state)); + pos += state.narcs; + ++states; + } + hdr.SetNumStates(states); + hdr.SetNumArcs(pos); + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "Could not align file during write after writing states"; + } + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + for (ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { + const A &arc = aiter.Value(); + strm.write(reinterpret_cast<const char *>(&arc), sizeof(arc)); + } + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "ConstFst Write write failed: " << opts.source; + return false; + } + if (update_header) { + return FstImpl<A>::UpdateFstHeader(fst, strm, opts, file_version, type, + properties, &hdr, start_offset); + } else { + if (hdr.NumStates() != num_states) { + LOG(ERROR) << "Inconsistent number of states observed during write"; + return false; + } + if (hdr.NumArcs() != num_arcs) { + LOG(ERROR) << "Inconsistent number of arcs observed during write"; + return false; + } + } + return true; +} + +// Specialization for ConstFst; see generic version in fst.h +// for sample usage (but use the ConstFst type!). This version +// should inline. +template <class A, class U> +class StateIterator< ConstFst<A, U> > { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const ConstFst<A, U> &fst) + : nstates_(fst.GetImpl()->NumStates()), s_(0) {} + + bool Done() const { return s_ >= nstates_; } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + StateId nstates_; + StateId s_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + + +// Specialization for ConstFst; see generic version in fst.h +// for sample usage (but use the ConstFst type!). This version +// should inline. +template <class A, class U> +class ArcIterator< ConstFst<A, U> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const ConstFst<A, U> &fst, StateId s) + : arcs_(fst.GetImpl()->Arcs(s)), + narcs_(fst.GetImpl()->NumArcs(s)), i_(0) {} + + bool Done() const { return i_ >= narcs_; } + + const A& Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + uint32 Flags() const { + return kArcValueFlags; + } + + void SetFlags(uint32 f, uint32 m) {} + + private: + const A *arcs_; + size_t narcs_; + size_t i_; + + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +// A useful alias when using StdArc. +typedef ConstFst<StdArc> StdConstFst; + +} // namespace fst + +#endif // FST_LIB_CONST_FST_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/determinize.h b/kaldi_io/src/tools/openfst/include/fst/determinize.h new file mode 100644 index 0000000..9ff8723 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/determinize.h @@ -0,0 +1,1015 @@ +// determinize.h + + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Functions and classes to determinize an FST. + +#ifndef FST_LIB_DETERMINIZE_H__ +#define FST_LIB_DETERMINIZE_H__ + +#include <algorithm> +#include <climits> +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <map> +#include <fst/slist.h> +#include <string> +#include <vector> +using std::vector; + +#include <fst/arc-map.h> +#include <fst/cache.h> +#include <fst/bi-table.h> +#include <fst/factor-weight.h> +#include <fst/prune.h> +#include <fst/test-properties.h> + + +namespace fst { + +// +// COMMON DIVISORS - these are used in determinization to compute +// the transition weights. In the simplest case, it is just the same +// as the semiring Plus(). However, other choices permit more efficient +// determinization when the output contains strings. +// + +// The default common divisor uses the semiring Plus. +template <class W> +class DefaultCommonDivisor { + public: + typedef W Weight; + + W operator()(const W &w1, const W &w2) const { return Plus(w1, w2); } +}; + + +// The label common divisor for a (left) string semiring selects a +// single letter common prefix or the empty string. This is used in +// the determinization of output strings so that at most a single +// letter will appear in the output of a transtion. +template <typename L, StringType S> +class LabelCommonDivisor { + public: + typedef StringWeight<L, S> Weight; + + Weight operator()(const Weight &w1, const Weight &w2) const { + StringWeightIterator<L, S> iter1(w1); + StringWeightIterator<L, S> iter2(w2); + + if (!(StringWeight<L, S>::Properties() & kLeftSemiring)) { + FSTERROR() << "LabelCommonDivisor: Weight needs to be left semiring"; + return Weight::NoWeight(); + } else if (w1.Size() == 0 || w2.Size() == 0) { + return Weight::One(); + } else if (w1 == Weight::Zero()) { + return Weight(iter2.Value()); + } else if (w2 == Weight::Zero()) { + return Weight(iter1.Value()); + } else if (iter1.Value() == iter2.Value()) { + return Weight(iter1.Value()); + } else { + return Weight::One(); + } + } +}; + + +// The gallic common divisor uses the label common divisor on the +// string component and the template argument D common divisor on the +// weight component, which defaults to the default common divisor. +template <class L, class W, StringType S, class D = DefaultCommonDivisor<W> > +class GallicCommonDivisor { + public: + typedef GallicWeight<L, W, S> Weight; + + Weight operator()(const Weight &w1, const Weight &w2) const { + return Weight(label_common_divisor_(w1.Value1(), w2.Value1()), + weight_common_divisor_(w1.Value2(), w2.Value2())); + } + + private: + LabelCommonDivisor<L, S> label_common_divisor_; + D weight_common_divisor_; +}; + + +// Represents an element in a subset +template <class A> +struct DeterminizeElement { + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + DeterminizeElement() {} + + DeterminizeElement(StateId s, Weight w) : state_id(s), weight(w) {} + + bool operator==(const DeterminizeElement<A> & element) const { + return state_id == element.state_id && weight == element.weight; + } + + bool operator<(const DeterminizeElement<A> & element) const { + return state_id < element.state_id || + (state_id == element.state_id && weight == element.weight); + } + + StateId state_id; // Input state Id + Weight weight; // Residual weight +}; + + +// +// DETERMINIZE FILTERS - these can be used in determinization to compute +// transformations on the subsets prior to their being added as destination +// states. The filter operates on a map between a label and the +// corresponding destination subsets. The possibly modified map is +// then used to construct the destination states for arcs exiting state 's'. +// It must define the ordered map type LabelMap and have a default +// and copy constructor. + +// A determinize filter that does not modify its input. +template <class Arc> +struct IdentityDeterminizeFilter { + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef slist< DeterminizeElement<Arc> > Subset; + typedef map<Label, Subset*> LabelMap; + + static uint64 Properties(uint64 props) { return props; } + + void operator()(StateId s, LabelMap *label_map) {} +}; + + +// +// DETERMINIZATION STATE TABLES +// +// The determiziation state table has the form: +// +// template <class Arc> +// class DeterminizeStateTable { +// public: +// typedef typename Arc::StateId StateId; +// typedef DeterminizeElement<Arc> Element; +// typedef slist<Element> Subset; +// +// // Required constuctor +// DeterminizeStateTable(); +// +// // Required copy constructor that does not copy state +// DeterminizeStateTable(const DeterminizeStateTable<A,P> &table); +// +// // Lookup state ID by subset (not depending of the element order). +// // If it doesn't exist, then add it. FindState takes +// // ownership of the subset argument (so that it doesn't have to +// // copy it if it creates a new state). +// StateId FindState(Subset *subset); +// +// // Lookup subset by ID. +// const Subset *FindSubset(StateId id) const; +// }; +// + +// The default determinization state table based on the +// compact hash bi-table. +template <class Arc> +class DefaultDeterminizeStateTable { + public: + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef DeterminizeElement<Arc> Element; + typedef slist<Element> Subset; + + explicit DefaultDeterminizeStateTable(size_t table_size = 0) + : table_size_(table_size), + subsets_(table_size_, new SubsetKey(), new SubsetEqual(&elements_)) { } + + DefaultDeterminizeStateTable(const DefaultDeterminizeStateTable<Arc> &table) + : table_size_(table.table_size_), + subsets_(table_size_, new SubsetKey(), new SubsetEqual(&elements_)) { } + + ~DefaultDeterminizeStateTable() { + for (StateId s = 0; s < subsets_.Size(); ++s) + delete subsets_.FindEntry(s); + } + + // Finds the state corresponding to a subset. Only creates a new + // state if the subset is not found. FindState takes ownership of + // the subset argument (so that it doesn't have to copy it if it + // creates a new state). + StateId FindState(Subset *subset) { + StateId ns = subsets_.Size(); + StateId s = subsets_.FindId(subset); + if (s != ns) delete subset; // subset found + return s; + } + + const Subset* FindSubset(StateId s) { return subsets_.FindEntry(s); } + + private: + // Comparison object for hashing Subset(s). Subsets are not sorted in this + // implementation, so ordering must not be assumed in the equivalence + // test. + class SubsetEqual { + public: + SubsetEqual() { // needed for compilation but should never be called + FSTERROR() << "SubsetEqual: default constructor not implemented"; + } + + // Constructor takes vector needed to check equality. See immediately + // below for constraints on it. + explicit SubsetEqual(vector<Element *> *elements) + : elements_(elements) {} + + // At each call to operator(), the elements_ vector should contain + // only NULLs. When this operator returns, elements_ will still + // have this property. + bool operator()(Subset* subset1, Subset* subset2) const { + if (!subset1 && !subset2) + return true; + if ((subset1 && !subset2) || (!subset1 && subset2)) + return false; + + if (subset1->size() != subset2->size()) + return false; + + // Loads first subset elements in element vector. + for (typename Subset::iterator iter1 = subset1->begin(); + iter1 != subset1->end(); + ++iter1) { + Element &element1 = *iter1; + while (elements_->size() <= element1.state_id) + elements_->push_back(0); + (*elements_)[element1.state_id] = &element1; + } + + // Checks second subset matches first via element vector. + for (typename Subset::iterator iter2 = subset2->begin(); + iter2 != subset2->end(); + ++iter2) { + Element &element2 = *iter2; + while (elements_->size() <= element2.state_id) + elements_->push_back(0); + Element *element1 = (*elements_)[element2.state_id]; + if (!element1 || element1->weight != element2.weight) { + // Mismatch found. Resets element vector before returning false. + for (typename Subset::iterator iter1 = subset1->begin(); + iter1 != subset1->end(); + ++iter1) + (*elements_)[iter1->state_id] = 0; + return false; + } else { + (*elements_)[element2.state_id] = 0; // Clears entry + } + } + return true; + } + private: + vector<Element *> *elements_; + }; + + // Hash function for Subset to Fst states. Subset elements are not + // sorted in this implementation, so the hash must be invariant + // under subset reordering. + class SubsetKey { + public: + size_t operator()(const Subset* subset) const { + size_t hash = 0; + if (subset) { + for (typename Subset::const_iterator iter = subset->begin(); + iter != subset->end(); + ++iter) { + const Element &element = *iter; + int lshift = element.state_id % (CHAR_BIT * sizeof(size_t) - 1) + 1; + int rshift = CHAR_BIT * sizeof(size_t) - lshift; + size_t n = element.state_id; + hash ^= n << lshift ^ n >> rshift ^ element.weight.Hash(); + } + } + return hash; + } + }; + + size_t table_size_; + + typedef CompactHashBiTable<StateId, Subset *, + SubsetKey, SubsetEqual, HS_STL> SubsetTable; + + SubsetTable subsets_; + vector<Element *> elements_; + + void operator=(const DefaultDeterminizeStateTable<Arc> &); // disallow +}; + +// Options for finite-state transducer determinization templated on +// the arc type, common divisor, the determinization filter and the +// state table. DeterminizeFst takes ownership of the determinization +// filter and state table if provided. +template <class Arc, + class D = DefaultCommonDivisor<typename Arc::Weight>, + class F = IdentityDeterminizeFilter<Arc>, + class T = DefaultDeterminizeStateTable<Arc> > +struct DeterminizeFstOptions : CacheOptions { + typedef typename Arc::Label Label; + float delta; // Quantization delta for subset weights + Label subsequential_label; // Label used for residual final output + // when producing subsequential transducers. + F *filter; // Determinization filter + T *state_table; // Determinization state table + + explicit DeterminizeFstOptions(const CacheOptions &opts, + float del = kDelta, Label lab = 0, + F *filt = 0, + T *table = 0) + : CacheOptions(opts), delta(del), subsequential_label(lab), + filter(filt), state_table(table) {} + + explicit DeterminizeFstOptions(float del = kDelta, Label lab = 0, + F *filt = 0, T *table = 0) + : delta(del), subsequential_label(lab), filter(filt), + state_table(table) {} +}; + +// Implementation of delayed DeterminizeFst. This base class is +// common to the variants that implement acceptor and transducer +// determinization. +template <class A> +class DeterminizeFstImplBase : public CacheImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<A> >::HasStart; + using CacheBaseImpl< CacheState<A> >::HasFinal; + using CacheBaseImpl< CacheState<A> >::HasArcs; + using CacheBaseImpl< CacheState<A> >::SetFinal; + using CacheBaseImpl< CacheState<A> >::SetStart; + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + + template <class D, class F, class T> + DeterminizeFstImplBase(const Fst<A> &fst, + const DeterminizeFstOptions<A, D, F, T> &opts) + : CacheImpl<A>(opts), fst_(fst.Copy()) { + SetType("determinize"); + uint64 iprops = fst.Properties(kFstProperties, false); + uint64 dprops = DeterminizeProperties(iprops, + opts.subsequential_label != 0); + SetProperties(F::Properties(dprops), kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + DeterminizeFstImplBase(const DeterminizeFstImplBase<A> &impl) + : CacheImpl<A>(impl), + fst_(impl.fst_->Copy(true)) { + SetType("determinize"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + virtual ~DeterminizeFstImplBase() { delete fst_; } + + virtual DeterminizeFstImplBase<A> *Copy() = 0; + + StateId Start() { + if (!HasStart()) { + StateId start = ComputeStart(); + if (start != kNoStateId) { + SetStart(start); + } + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + Weight final = ComputeFinal(s); + SetFinal(s, final); + } + return CacheImpl<A>::Final(s); + } + + virtual void Expand(StateId s) = 0; + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<A>::InitArcIterator(s, data); + } + + virtual StateId ComputeStart() = 0; + + virtual Weight ComputeFinal(StateId s) = 0; + + const Fst<A> &GetFst() const { return *fst_; } + + private: + const Fst<A> *fst_; // Input Fst + + void operator=(const DeterminizeFstImplBase<A> &); // disallow +}; + + +// Implementation of delayed determinization for weighted acceptors. +// It is templated on the arc type A and the common divisor D. +template <class A, class D, class F, class T> +class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> { + public: + using FstImpl<A>::SetProperties; + using DeterminizeFstImplBase<A>::GetFst; + using DeterminizeFstImplBase<A>::SetArcs; + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef DeterminizeElement<A> Element; + typedef slist<Element> Subset; + typedef typename F::LabelMap LabelMap; + + DeterminizeFsaImpl(const Fst<A> &fst, + const vector<Weight> *in_dist, vector<Weight> *out_dist, + const DeterminizeFstOptions<A, D, F, T> &opts) + : DeterminizeFstImplBase<A>(fst, opts), + delta_(opts.delta), + in_dist_(in_dist), + out_dist_(out_dist), + filter_(opts.filter ? opts.filter : new F()), + state_table_(opts.state_table ? opts.state_table : new T()) { + if (!fst.Properties(kAcceptor, true)) { + FSTERROR() << "DeterminizeFst: argument not an acceptor"; + SetProperties(kError, kError); + } + if (!(Weight::Properties() & kLeftSemiring)) { + FSTERROR() << "DeterminizeFst: Weight needs to be left distributive: " + << Weight::Type(); + SetProperties(kError, kError); + } + if (out_dist_) + out_dist_->clear(); + } + + DeterminizeFsaImpl(const DeterminizeFsaImpl<A, D, F, T> &impl) + : DeterminizeFstImplBase<A>(impl), + delta_(impl.delta_), + in_dist_(0), + out_dist_(0), + filter_(new F(*impl.filter_)), + state_table_(new T(*impl.state_table_)) { + if (impl.out_dist_) { + FSTERROR() << "DeterminizeFsaImpl: cannot copy with out_dist vector"; + SetProperties(kError, kError); + } + } + + virtual ~DeterminizeFsaImpl() { + delete filter_; + delete state_table_; + } + + virtual DeterminizeFsaImpl<A, D, F, T> *Copy() { + return new DeterminizeFsaImpl<A, D, F, T>(*this); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && (GetFst().Properties(kError, false))) + SetProperties(kError, kError); + return FstImpl<A>::Properties(mask); + } + + virtual StateId ComputeStart() { + StateId s = GetFst().Start(); + if (s == kNoStateId) + return kNoStateId; + Element element(s, Weight::One()); + Subset *subset = new Subset; + subset->push_front(element); + return FindState(subset); + } + + virtual Weight ComputeFinal(StateId s) { + const Subset *subset = state_table_->FindSubset(s); + Weight final = Weight::Zero(); + for (typename Subset::const_iterator siter = subset->begin(); + siter != subset->end(); + ++siter) { + const Element &element = *siter; + final = Plus(final, Times(element.weight, + GetFst().Final(element.state_id))); + if (!final.Member()) + SetProperties(kError, kError); + } + return final; + } + + StateId FindState(Subset *subset) { + StateId s = state_table_->FindState(subset); + if (in_dist_ && out_dist_->size() <= s) + out_dist_->push_back(ComputeDistance(subset)); + return s; + } + + // Compute distance from a state to the final states in the DFA + // given the distances in the NFA. + Weight ComputeDistance(const Subset *subset) { + Weight outd = Weight::Zero(); + for (typename Subset::const_iterator siter = subset->begin(); + siter != subset->end(); ++siter) { + const Element &element = *siter; + Weight ind = element.state_id < in_dist_->size() ? + (*in_dist_)[element.state_id] : Weight::Zero(); + outd = Plus(outd, Times(element.weight, ind)); + } + return outd; + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + virtual void Expand(StateId s) { + + LabelMap label_map; + LabelSubsets(s, &label_map); + + for (typename LabelMap::iterator liter = label_map.begin(); + liter != label_map.end(); + ++liter) + AddArc(s, liter->first, liter->second); + SetArcs(s); + } + + private: + // Constructs destination subsets per label. At return, subset + // element weights include the input automaton label weights and the + // subsets may contain duplicate states. + void LabelSubsets(StateId s, LabelMap *label_map) { + const Subset *src_subset = state_table_->FindSubset(s); + + for (typename Subset::const_iterator siter = src_subset->begin(); + siter != src_subset->end(); + ++siter) { + const Element &src_element = *siter; + for (ArcIterator< Fst<A> > aiter(GetFst(), src_element.state_id); + !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + Element dest_element(arc.nextstate, + Times(src_element.weight, arc.weight)); + + // The LabelMap may be a e.g. multimap with more complex + // determinization filters, so we insert efficiently w/o using []. + typename LabelMap::iterator liter = label_map->lower_bound(arc.ilabel); + Subset* dest_subset; + if (liter == label_map->end() || liter->first != arc.ilabel) { + dest_subset = new Subset; + label_map->insert(liter, make_pair(arc.ilabel, dest_subset)); + } else { + dest_subset = liter->second; + } + + dest_subset->push_front(dest_element); + } + } + // Applies the determinization filter + (*filter_)(s, label_map); + } + + // Adds an arc from state S to the destination state associated + // with subset DEST_SUBSET (as created by LabelSubsets). + void AddArc(StateId s, Label label, Subset *dest_subset) { + A arc; + arc.ilabel = label; + arc.olabel = label; + arc.weight = Weight::Zero(); + + typename Subset::iterator oiter; + for (typename Subset::iterator diter = dest_subset->begin(); + diter != dest_subset->end();) { + Element &dest_element = *diter; + // Computes label weight. + arc.weight = common_divisor_(arc.weight, dest_element.weight); + + while (elements_.size() <= dest_element.state_id) + elements_.push_back(0); + Element *matching_element = elements_[dest_element.state_id]; + if (matching_element) { + // Found duplicate state: sums state weight and deletes dup. + matching_element->weight = Plus(matching_element->weight, + dest_element.weight); + if (!matching_element->weight.Member()) + SetProperties(kError, kError); + ++diter; + dest_subset->erase_after(oiter); + } else { + // Saves element so we can check for duplicate for this state. + elements_[dest_element.state_id] = &dest_element; + oiter = diter; + ++diter; + } + } + + // Divides out label weight from destination subset elements. + // Quantizes to ensure comparisons are effective. + // Clears element vector. + for (typename Subset::iterator diter = dest_subset->begin(); + diter != dest_subset->end(); + ++diter) { + Element &dest_element = *diter; + dest_element.weight = Divide(dest_element.weight, arc.weight, + DIVIDE_LEFT); + dest_element.weight = dest_element.weight.Quantize(delta_); + elements_[dest_element.state_id] = 0; + } + + arc.nextstate = FindState(dest_subset); + CacheImpl<A>::PushArc(s, arc); + } + + float delta_; // Quantization delta for subset weights + const vector<Weight> *in_dist_; // Distance to final NFA states + vector<Weight> *out_dist_; // Distance to final DFA states + + D common_divisor_; + F *filter_; + T *state_table_; + + vector<Element *> elements_; + + void operator=(const DeterminizeFsaImpl<A, D, F, T> &); // disallow +}; + + +// Implementation of delayed determinization for transducers. +// Transducer determinization is implemented by mapping the input to +// the Gallic semiring as an acceptor whose weights contain the output +// strings and using acceptor determinization above to determinize +// that acceptor. +template <class A, StringType S, class D, class F, class T> +class DeterminizeFstImpl : public DeterminizeFstImplBase<A> { + public: + using FstImpl<A>::SetProperties; + using DeterminizeFstImplBase<A>::GetFst; + using CacheBaseImpl< CacheState<A> >::GetCacheGc; + using CacheBaseImpl< CacheState<A> >::GetCacheLimit; + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + typedef ToGallicMapper<A, S> ToMapper; + typedef FromGallicMapper<A, S> FromMapper; + + typedef typename ToMapper::ToArc ToArc; + typedef ArcMapFst<A, ToArc, ToMapper> ToFst; + typedef ArcMapFst<ToArc, A, FromMapper> FromFst; + + typedef GallicCommonDivisor<Label, Weight, S, D> CommonDivisor; + typedef GallicFactor<Label, Weight, S> FactorIterator; + + DeterminizeFstImpl(const Fst<A> &fst, + const DeterminizeFstOptions<A, D, F, T> &opts) + : DeterminizeFstImplBase<A>(fst, opts), + delta_(opts.delta), + subsequential_label_(opts.subsequential_label) { + Init(GetFst()); + } + + DeterminizeFstImpl(const DeterminizeFstImpl<A, S, D, F, T> &impl) + : DeterminizeFstImplBase<A>(impl), + delta_(impl.delta_), + subsequential_label_(impl.subsequential_label_) { + Init(GetFst()); + } + + ~DeterminizeFstImpl() { delete from_fst_; } + + virtual DeterminizeFstImpl<A, S, D, F, T> *Copy() { + return new DeterminizeFstImpl<A, S, D, F, T>(*this); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && (GetFst().Properties(kError, false) || + from_fst_->Properties(kError, false))) + SetProperties(kError, kError); + return FstImpl<A>::Properties(mask); + } + + virtual StateId ComputeStart() { return from_fst_->Start(); } + + virtual Weight ComputeFinal(StateId s) { return from_fst_->Final(s); } + + virtual void Expand(StateId s) { + for (ArcIterator<FromFst> aiter(*from_fst_, s); + !aiter.Done(); + aiter.Next()) + CacheImpl<A>::PushArc(s, aiter.Value()); + CacheImpl<A>::SetArcs(s); + } + + private: + // Initialization of transducer determinization implementation, which + // is defined after DeterminizeFst since it calls it. + void Init(const Fst<A> &fst); + + float delta_; + Label subsequential_label_; + FromFst *from_fst_; + + void operator=(const DeterminizeFstImpl<A, S, D, F, T> &); // disallow +}; + + +// Determinizes a weighted transducer. This version is a delayed +// Fst. The result will be an equivalent FST that has the property +// that no state has two transitions with the same input label. +// For this algorithm, epsilon transitions are treated as regular +// symbols (cf. RmEpsilon). +// +// The transducer must be functional. The weights must be (weakly) +// left divisible (valid for TropicalWeight and LogWeight for instance) +// and be zero-sum-free if for all a,b: (Plus(a, b) = 0 => a = b = 0. +// +// Complexity: +// - Determinizable: exponential (polynomial in the size of the output) +// - Non-determinizable) does not terminate +// +// The determinizable automata include all unweighted and all acyclic input. +// +// References: +// - Mehryar Mohri, "Finite-State Transducers in Language and Speech +// Processing". Computational Linguistics, 23:2, 1997. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class DeterminizeFst : public ImplToFst< DeterminizeFstImplBase<A> > { + public: + friend class ArcIterator< DeterminizeFst<A> >; + friend class StateIterator< DeterminizeFst<A> >; + template <class B, StringType S, class D, class F, class T> + friend class DeterminizeFstImpl; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef CacheState<A> State; + typedef DeterminizeFstImplBase<A> Impl; + + using ImplToFst<Impl>::SetImpl; + + explicit DeterminizeFst(const Fst<A> &fst) { + typedef DefaultCommonDivisor<Weight> D; + typedef IdentityDeterminizeFilter<A> F; + typedef DefaultDeterminizeStateTable<A> T; + DeterminizeFstOptions<A, D, F, T> opts; + if (fst.Properties(kAcceptor, true)) { + // Calls implementation for acceptors. + SetImpl(new DeterminizeFsaImpl<A, D, F, T>(fst, 0, 0, opts)); + } else { + // Calls implementation for transducers. + SetImpl(new + DeterminizeFstImpl<A, STRING_LEFT_RESTRICT, D, F, T>(fst, opts)); + } + } + + template <class D, class F, class T> + DeterminizeFst(const Fst<A> &fst, + const DeterminizeFstOptions<A, D, F, T> &opts) { + if (fst.Properties(kAcceptor, true)) { + // Calls implementation for acceptors. + SetImpl(new DeterminizeFsaImpl<A, D, F, T>(fst, 0, 0, opts)); + } else { + // Calls implementation for transducers. + SetImpl(new + DeterminizeFstImpl<A, STRING_LEFT_RESTRICT, D, F, T>(fst, opts)); + } + } + + // This acceptor-only version additionally computes the distance to + // final states in the output if provided with those distances for the + // input. Useful for e.g. unique N-shortest paths. + template <class D, class F, class T> + DeterminizeFst(const Fst<A> &fst, + const vector<Weight> *in_dist, vector<Weight> *out_dist, + const DeterminizeFstOptions<A, D, F, T> &opts) { + if (!fst.Properties(kAcceptor, true)) { + FSTERROR() << "DeterminizeFst:" + << " distance to final states computed for acceptors only"; + GetImpl()->SetProperties(kError, kError); + } + SetImpl(new DeterminizeFsaImpl<A, D, F, T>(fst, in_dist, out_dist, opts)); + } + + // See Fst<>::Copy() for doc. + DeterminizeFst(const DeterminizeFst<A> &fst, bool safe = false) { + if (safe) + SetImpl(fst.GetImpl()->Copy()); + else + SetImpl(fst.GetImpl(), false); + } + + // Get a copy of this DeterminizeFst. See Fst<>::Copy() for further doc. + virtual DeterminizeFst<A> *Copy(bool safe = false) const { + return new DeterminizeFst<A>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const DeterminizeFst<A> &fst); // Disallow +}; + + +// Initialization of transducer determinization implementation. which +// is defined after DeterminizeFst since it calls it. +template <class A, StringType S, class D, class F, class T> +void DeterminizeFstImpl<A, S, D, F, T>::Init(const Fst<A> &fst) { + // Mapper to an acceptor. + ToFst to_fst(fst, ToMapper()); + + // Determinizes acceptor. + // This recursive call terminates since it passes the common divisor + // to a private constructor. + CacheOptions copts(GetCacheGc(), GetCacheLimit()); + DeterminizeFstOptions<ToArc, CommonDivisor> dopts(copts, delta_); + // Uses acceptor-only constructor to avoid template recursion + DeterminizeFst<ToArc> det_fsa(to_fst, 0, 0, dopts); + + // Mapper back to transducer. + FactorWeightOptions<ToArc> fopts(CacheOptions(true, 0), delta_, + kFactorFinalWeights, + subsequential_label_, + subsequential_label_); + FactorWeightFst<ToArc, FactorIterator> factored_fst(det_fsa, fopts); + from_fst_ = new FromFst(factored_fst, FromMapper(subsequential_label_)); +} + + +// Specialization for DeterminizeFst. +template <class A> +class StateIterator< DeterminizeFst<A> > + : public CacheStateIterator< DeterminizeFst<A> > { + public: + explicit StateIterator(const DeterminizeFst<A> &fst) + : CacheStateIterator< DeterminizeFst<A> >(fst, fst.GetImpl()) {} +}; + + +// Specialization for DeterminizeFst. +template <class A> +class ArcIterator< DeterminizeFst<A> > + : public CacheArcIterator< DeterminizeFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const DeterminizeFst<A> &fst, StateId s) + : CacheArcIterator< DeterminizeFst<A> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + + +template <class A> inline +void DeterminizeFst<A>::InitStateIterator(StateIteratorData<A> *data) const +{ + data->base = new StateIterator< DeterminizeFst<A> >(*this); +} + + +// Useful aliases when using StdArc. +typedef DeterminizeFst<StdArc> StdDeterminizeFst; + + +template <class Arc> +struct DeterminizeOptions { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + float delta; // Quantization delta for subset weights. + Weight weight_threshold; // Pruning weight threshold. + StateId state_threshold; // Pruning state threshold. + Label subsequential_label; // Label used for residual final output + // when producing subsequential transducers. + + explicit DeterminizeOptions(float d = kDelta, Weight w = Weight::Zero(), + StateId n = kNoStateId, Label l = 0) + : delta(d), weight_threshold(w), state_threshold(n), + subsequential_label(l) {} +}; + + +// Determinizes a weighted transducer. This version writes the +// determinized Fst to an output MutableFst. The result will be an +// equivalent FST that has the property that no state has two +// transitions with the same input label. For this algorithm, epsilon +// transitions are treated as regular symbols (cf. RmEpsilon). +// +// The transducer must be functional. The weights must be (weakly) +// left divisible (valid for TropicalWeight and LogWeight). +// +// Complexity: +// - Determinizable: exponential (polynomial in the size of the output) +// - Non-determinizable: does not terminate +// +// The determinizable automata include all unweighted and all acyclic input. +// +// References: +// - Mehryar Mohri, "Finite-State Transducers in Language and Speech +// Processing". Computational Linguistics, 23:2, 1997. +template <class Arc> +void Determinize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, + const DeterminizeOptions<Arc> &opts + = DeterminizeOptions<Arc>()) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + DeterminizeFstOptions<Arc> nopts; + nopts.delta = opts.delta; + nopts.subsequential_label = opts.subsequential_label; + + nopts.gc_limit = 0; // Cache only the last state for fastest copy. + + if (opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) { + if (ifst.Properties(kAcceptor, false)) { + vector<Weight> idistance, odistance; + ShortestDistance(ifst, &idistance, true); + DeterminizeFst<Arc> dfst(ifst, &idistance, &odistance, nopts); + PruneOptions< Arc, AnyArcFilter<Arc> > popts(opts.weight_threshold, + opts.state_threshold, + AnyArcFilter<Arc>(), + &odistance); + Prune(dfst, ofst, popts); + } else { + *ofst = DeterminizeFst<Arc>(ifst, nopts); + Prune(ofst, opts.weight_threshold, opts.state_threshold); + } + } else { + *ofst = DeterminizeFst<Arc>(ifst, nopts); + } +} + + +} // namespace fst + +#endif // FST_LIB_DETERMINIZE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/dfs-visit.h b/kaldi_io/src/tools/openfst/include/fst/dfs-visit.h new file mode 100644 index 0000000..4d93a39 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/dfs-visit.h @@ -0,0 +1,205 @@ +// dfs-visit.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Depth-first search visitation. See visit.h for more general +// search queue disciplines. + +#ifndef FST_LIB_DFS_VISIT_H__ +#define FST_LIB_DFS_VISIT_H__ + +#include <stack> +#include <vector> +using std::vector; + +#include <fst/arcfilter.h> +#include <fst/fst.h> + + +namespace fst { + +// Visitor Interface - class determines actions taken during a Dfs. +// If any of the boolean member functions return false, the DFS is +// aborted by first calling FinishState() on all currently grey states +// and then calling FinishVisit(). +// +// Note this is similar to the more general visitor interface in visit.h +// except that FinishState returns additional information appropriate only for +// a DFS and some methods names here are better suited to a DFS. +// +// template <class Arc> +// class Visitor { +// public: +// typedef typename Arc::StateId StateId; +// +// Visitor(T *return_data); +// // Invoked before DFS visit +// void InitVisit(const Fst<Arc> &fst); +// // Invoked when state discovered (2nd arg is DFS tree root) +// bool InitState(StateId s, StateId root); +// // Invoked when tree arc examined (to white/undiscovered state) +// bool TreeArc(StateId s, const Arc &a); +// // Invoked when back arc examined (to grey/unfinished state) +// bool BackArc(StateId s, const Arc &a); +// // Invoked when forward or cross arc examined (to black/finished state) +// bool ForwardOrCrossArc(StateId s, const Arc &a); +// // Invoked when state finished (PARENT is kNoStateID and ARC == NULL +// // when S is tree root) +// void FinishState(StateId s, StateId parent, const Arc *parent_arc); +// // Invoked after DFS visit +// void FinishVisit(); +// }; + +// An Fst state's DFS status +const int kDfsWhite = 0; // Undiscovered +const int kDfsGrey = 1; // Discovered & unfinished +const int kDfsBlack = 2; // Finished + +// An Fst state's DFS stack state +template <class Arc> +struct DfsState { + typedef typename Arc::StateId StateId; + + DfsState(const Fst<Arc> &fst, StateId s): state_id(s), arc_iter(fst, s) {} + + StateId state_id; // Fst state ... + ArcIterator< Fst<Arc> > arc_iter; // and its corresponding arcs +}; + + +// Performs depth-first visitation. Visitor class argument determines +// actions and contains any return data. ArcFilter determines arcs +// that are considered. +// +// Note this is similar to Visit() in visit.h called with a LIFO +// queue except this version has a Visitor class specialized and +// augmented for a DFS. +template <class Arc, class V, class ArcFilter> +void DfsVisit(const Fst<Arc> &fst, V *visitor, ArcFilter filter) { + typedef typename Arc::StateId StateId; + + visitor->InitVisit(fst); + + StateId start = fst.Start(); + if (start == kNoStateId) { + visitor->FinishVisit(); + return; + } + + vector<char> state_color; // Fst state DFS status + stack<DfsState<Arc> *> state_stack; // DFS execution stack + + StateId nstates = start + 1; // # of known states in general case + bool expanded = false; + if (fst.Properties(kExpanded, false)) { // tests if expanded case, then + nstates = CountStates(fst); // uses ExpandedFst::NumStates(). + expanded = true; + } + + state_color.resize(nstates, kDfsWhite); + StateIterator< Fst<Arc> > siter(fst); + + // Continue DFS while true + bool dfs = true; + + // Iterate over trees in DFS forest. + for (StateId root = start; dfs && root < nstates;) { + state_color[root] = kDfsGrey; + state_stack.push(new DfsState<Arc>(fst, root)); + dfs = visitor->InitState(root, root); + while (!state_stack.empty()) { + DfsState<Arc> *dfs_state = state_stack.top(); + StateId s = dfs_state->state_id; + if (s >= state_color.size()) { + nstates = s + 1; + state_color.resize(nstates, kDfsWhite); + } + ArcIterator< Fst<Arc> > &aiter = dfs_state->arc_iter; + if (!dfs || aiter.Done()) { + state_color[s] = kDfsBlack; + delete dfs_state; + state_stack.pop(); + if (!state_stack.empty()) { + DfsState<Arc> *parent_state = state_stack.top(); + StateId p = parent_state->state_id; + ArcIterator< Fst<Arc> > &piter = parent_state->arc_iter; + visitor->FinishState(s, p, &piter.Value()); + piter.Next(); + } else { + visitor->FinishState(s, kNoStateId, 0); + } + continue; + } + const Arc &arc = aiter.Value(); + if (arc.nextstate >= state_color.size()) { + nstates = arc.nextstate + 1; + state_color.resize(nstates, kDfsWhite); + } + if (!filter(arc)) { + aiter.Next(); + continue; + } + int next_color = state_color[arc.nextstate]; + switch (next_color) { + default: + case kDfsWhite: + dfs = visitor->TreeArc(s, arc); + if (!dfs) break; + state_color[arc.nextstate] = kDfsGrey; + state_stack.push(new DfsState<Arc>(fst, arc.nextstate)); + dfs = visitor->InitState(arc.nextstate, root); + break; + case kDfsGrey: + dfs = visitor->BackArc(s, arc); + aiter.Next(); + break; + case kDfsBlack: + dfs = visitor->ForwardOrCrossArc(s, arc); + aiter.Next(); + break; + } + } + + // Find next tree root + for (root = root == start ? 0 : root + 1; + root < nstates && state_color[root] != kDfsWhite; + ++root) { + } + + // Check for a state beyond the largest known state + if (!expanded && root == nstates) { + for (; !siter.Done(); siter.Next()) { + if (siter.Value() == nstates) { + ++nstates; + state_color.push_back(kDfsWhite); + break; + } + } + } + } + visitor->FinishVisit(); +} + + +template <class Arc, class V> +void DfsVisit(const Fst<Arc> &fst, V *visitor) { + DfsVisit(fst, visitor, AnyArcFilter<Arc>()); +} + +} // namespace fst + +#endif // FST_LIB_DFS_VISIT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/difference.h b/kaldi_io/src/tools/openfst/include/fst/difference.h new file mode 100644 index 0000000..8a3306f --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/difference.h @@ -0,0 +1,189 @@ +// difference.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to compute the difference between two FSAs + +#ifndef FST_LIB_DIFFERENCE_H__ +#define FST_LIB_DIFFERENCE_H__ + +#include <vector> +using std::vector; +#include <algorithm> + +#include <fst/cache.h> +#include <fst/compose.h> +#include <fst/complement.h> + + +namespace fst { + +template <class A, + class M = Matcher<Fst<A> >, + class F = SequenceComposeFilter<M>, + class T = GenericComposeStateTable<A, typename F::FilterState> > +struct DifferenceFstOptions : public ComposeFstOptions<A, M, F, T> { + explicit DifferenceFstOptions(const CacheOptions &opts, + M *mat1 = 0, M *mat2 = 0, + F *filt = 0, T *sttable= 0) + : ComposeFstOptions<A, M, F, T>(mat1, mat2, filt, sttable) { } + + DifferenceFstOptions() {} +}; + +// Computes the difference between two FSAs. This version is a delayed +// Fst. Only strings that are in the first automaton but not in second +// are retained in the result. +// +// The first argument must be an acceptor; the second argument must be +// an unweighted, epsilon-free, deterministic acceptor. One of the +// arguments must be label-sorted. +// +// Complexity: same as ComposeFst. +// +// Caveats: same as ComposeFst. +template <class A> +class DifferenceFst : public ComposeFst<A> { + public: + using ImplToFst< ComposeFstImplBase<A> >::SetImpl; + using ImplToFst< ComposeFstImplBase<A> >::GetImpl; + + using ComposeFst<A>::CreateBase1; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + // A - B = A ^ B'. + DifferenceFst(const Fst<A> &fst1, const Fst<A> &fst2, + const CacheOptions &opts = CacheOptions()) { + typedef RhoMatcher< Matcher<Fst<A> > > R; + + ComplementFst<A> cfst(fst2); + ComposeFstOptions<A, R> copts(CacheOptions(), + new R(fst1, MATCH_NONE), + new R(cfst, MATCH_INPUT, + ComplementFst<A>::kRhoLabel)); + SetImpl(CreateBase1(fst1, cfst, copts)); + + if (!fst1.Properties(kAcceptor, true)) { + FSTERROR() << "DifferenceFst: 1st argument not an acceptor"; + GetImpl()->SetProperties(kError, kError); + } + } + + template <class M, class F, class T> + DifferenceFst(const Fst<A> &fst1, const Fst<A> &fst2, + const DifferenceFstOptions<A, M, F, T> &opts) { + typedef RhoMatcher<M> R; + + ComplementFst<A> cfst(fst2); + ComposeFstOptions<A, R> copts(opts); + copts.matcher1 = new R(fst1, MATCH_NONE, kNoLabel, MATCHER_REWRITE_ALWAYS, + opts.matcher1); + copts.matcher2 = new R(cfst, MATCH_INPUT, ComplementFst<A>::kRhoLabel, + MATCHER_REWRITE_ALWAYS, opts.matcher2); + + SetImpl(CreateBase1(fst1, cfst, copts)); + + if (!fst1.Properties(kAcceptor, true)) { + FSTERROR() << "DifferenceFst: 1st argument not an acceptor"; + GetImpl()->SetProperties(kError, kError); + } + } + + // See Fst<>::Copy() for doc. + DifferenceFst(const DifferenceFst<A> &fst, bool safe = false) + : ComposeFst<A>(fst, safe) {} + + // Get a copy of this DifferenceFst. See Fst<>::Copy() for further doc. + virtual DifferenceFst<A> *Copy(bool safe = false) const { + return new DifferenceFst<A>(*this, safe); + } +}; + + +// Specialization for DifferenceFst. +template <class A> +class StateIterator< DifferenceFst<A> > + : public StateIterator< ComposeFst<A> > { + public: + explicit StateIterator(const DifferenceFst<A> &fst) + : StateIterator< ComposeFst<A> >(fst) {} +}; + + +// Specialization for DifferenceFst. +template <class A> +class ArcIterator< DifferenceFst<A> > + : public ArcIterator< ComposeFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const DifferenceFst<A> &fst, StateId s) + : ArcIterator< ComposeFst<A> >(fst, s) {} +}; + +// Useful alias when using StdArc. +typedef DifferenceFst<StdArc> StdDifferenceFst; + + +typedef ComposeOptions DifferenceOptions; + + +// Computes the difference between two FSAs. This version is writes +// the difference to an output MutableFst. Only strings that are in +// the first automaton but not in second are retained in the result. +// +// The first argument must be an acceptor; the second argument must be +// an unweighted, epsilon-free, deterministic acceptor. One of the +// arguments must be label-sorted. +// +// Complexity: same as Compose. +// +// Caveats: same as Compose. +template<class Arc> +void Difference(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2, + MutableFst<Arc> *ofst, + const DifferenceOptions &opts = DifferenceOptions()) { + typedef Matcher< Fst<Arc> > M; + + if (opts.filter_type == AUTO_FILTER) { + CacheOptions nopts; + nopts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = DifferenceFst<Arc>(ifst1, ifst2, nopts); + } else if (opts.filter_type == SEQUENCE_FILTER) { + DifferenceFstOptions<Arc> dopts; + dopts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = DifferenceFst<Arc>(ifst1, ifst2, dopts); + } else if (opts.filter_type == ALT_SEQUENCE_FILTER) { + DifferenceFstOptions<Arc, M, AltSequenceComposeFilter<M> > dopts; + dopts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = DifferenceFst<Arc>(ifst1, ifst2, dopts); + } else if (opts.filter_type == MATCH_FILTER) { + DifferenceFstOptions<Arc, M, MatchComposeFilter<M> > dopts; + dopts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = DifferenceFst<Arc>(ifst1, ifst2, dopts); + } + + if (opts.connect) + Connect(ofst); +} + +} // namespace fst + +#endif // FST_LIB_DIFFERENCE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/edit-fst.h b/kaldi_io/src/tools/openfst/include/fst/edit-fst.h new file mode 100644 index 0000000..bd33b9d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/edit-fst.h @@ -0,0 +1,779 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Dan Bikel) +// +// An \ref Fst implementation that allows non-destructive edit operations on an +// existing fst. + +#ifndef FST_LIB_EDIT_FST_H_ +#define FST_LIB_EDIT_FST_H_ + +#include <vector> +using std::vector; + +#include <fst/cache.h> + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + +namespace fst { + +// The EditFst class enables non-destructive edit operations on a wrapped +// ExpandedFst. The implementation uses copy-on-write semantics at the node +// level: if a user has an underlying fst on which he or she wants to perform a +// relatively small number of edits (read: mutations), then this implementation +// will copy the edited node to an internal MutableFst and perform any edits in +// situ on that copied node. This class supports all the methods of MutableFst +// except for DeleteStates(const vector<StateId> &); thus, new nodes may also be +// added, and one may add transitions from existing nodes of the wrapped fst to +// new nodes. +// +// N.B.: The documentation for Fst::Copy(true) says that its behavior is +// undefined if invoked on an fst that has already been accessed. This class +// requires that the Fst implementation it wraps provides consistent, reliable +// behavior when its Copy(true) method is invoked, where consistent means +// the graph structure, graph properties and state numbering and do not change. +// VectorFst and CompactFst, for example, are both well-behaved in this regard. + +// The EditFstData class is a container for all mutable data for EditFstImpl; +// also, this class provides most of the actual implementation of what EditFst +// does (that is, most of EditFstImpl's methods delegate to methods in this, the +// EditFstData class). Instances of this class are reference-counted and can be +// shared between otherwise independent EditFstImpl instances. This scheme +// allows EditFstImpl to implement the thread-safe, copy-on-write semantics +// required by Fst::Copy(true). +// +// template parameters: +// A the type of arc to use +// WrappedFstT the type of fst wrapped by the EditFst instance that +// this EditFstData instance is backing +// MutableFstT the type of mutable fst to use internally for edited states; +// crucially, MutableFstT::Copy(false) *must* yield an fst that is +// thread-safe for reading (VectorFst, for example, has this property) +template <typename A, + typename WrappedFstT = ExpandedFst<A>, + typename MutableFstT = VectorFst<A> > +class EditFstData { + public: + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef typename unordered_map<StateId, StateId>::const_iterator + IdMapIterator; + typedef typename unordered_map<StateId, Weight>::const_iterator + FinalWeightIterator; + + + EditFstData() : num_new_states_(0) { + SetEmptyAndDeleteKeysForInternalMaps(); + } + + EditFstData(const EditFstData &other) : + edits_(other.edits_), + external_to_internal_ids_(other.external_to_internal_ids_), + edited_final_weights_(other.edited_final_weights_), + num_new_states_(other.num_new_states_) { + } + + ~EditFstData() { + } + + static EditFstData<A, WrappedFstT, MutableFstT> *Read(istream &strm, + const FstReadOptions &opts); + + bool Write(ostream &strm, const FstWriteOptions &opts) const { + // Serialize all private data members of this class. + FstWriteOptions edits_opts(opts); + edits_opts.write_header = true; // Force writing contained header. + edits_.Write(strm, edits_opts); + WriteType(strm, external_to_internal_ids_); + WriteType(strm, edited_final_weights_); + WriteType(strm, num_new_states_); + if (!strm) { + LOG(ERROR) << "EditFstData::Write: write failed: " << opts.source; + return false; + } + return true; + } + + int RefCount() const { return ref_count_.count(); } + int IncrRefCount() { return ref_count_.Incr(); } + int DecrRefCount() { return ref_count_.Decr(); } + + StateId NumNewStates() const { + return num_new_states_; + } + + // accessor methods for the fst holding edited states + StateId EditedStart() const { + return edits_.Start(); + } + + Weight Final(StateId s, const WrappedFstT *wrapped) const { + FinalWeightIterator final_weight_it = GetFinalWeightIterator(s); + if (final_weight_it == NotInFinalWeightMap()) { + IdMapIterator it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? + wrapped->Final(s) : edits_.Final(it->second); + } + else { + return final_weight_it->second; + } + } + + size_t NumArcs(StateId s, const WrappedFstT *wrapped) const { + IdMapIterator it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? + wrapped->NumArcs(s) : edits_.NumArcs(it->second); + } + + size_t NumInputEpsilons(StateId s, const WrappedFstT *wrapped) const { + IdMapIterator it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? + wrapped->NumInputEpsilons(s) : + edits_.NumInputEpsilons(it->second); + } + + size_t NumOutputEpsilons(StateId s, const WrappedFstT *wrapped) const { + IdMapIterator it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? + wrapped->NumOutputEpsilons(s) : + edits_.NumOutputEpsilons(it->second); + } + + void SetEditedProperties(uint64 props, uint64 mask) { + edits_.SetProperties(props, mask); + } + + // non-const MutableFst operations + + // Sets the start state for this fst. + void SetStart(StateId s) { + edits_.SetStart(s); + } + + // Sets the final state for this fst. + Weight SetFinal(StateId s, Weight w, const WrappedFstT *wrapped) { + Weight old_weight = Final(s, wrapped); + IdMapIterator it = GetEditedIdMapIterator(s); + // if we haven't already edited state s, don't add it to edited_ (which can + // be expensive if s has many transitions); just use the + // edited_final_weights_ map + if (it == NotInEditedMap()) { + edited_final_weights_[s] = w; + } + else { + edits_.SetFinal(GetEditableInternalId(s, wrapped), w); + } + return old_weight; + } + + // Adds a new state to this fst, initially with no arcs. + StateId AddState(StateId curr_num_states) { + StateId internal_state_id = edits_.AddState(); + StateId external_state_id = curr_num_states; + external_to_internal_ids_[external_state_id] = internal_state_id; + num_new_states_++; + return external_state_id; + } + + // Adds the specified arc to the specified state of this fst. + const A *AddArc(StateId s, const Arc &arc, const WrappedFstT *wrapped) { + StateId internal_id = GetEditableInternalId(s, wrapped); + + size_t num_arcs = edits_.NumArcs(internal_id); + ArcIterator<MutableFstT> arc_it(edits_, internal_id); + const A *prev_arc = NULL; + if (num_arcs > 0) { + // grab the final arc associated with this state in edits_ + arc_it.Seek(num_arcs - 1); + prev_arc = &(arc_it.Value()); + } + edits_.AddArc(internal_id, arc); + return prev_arc; + } + + void DeleteStates() { + edits_.DeleteStates(); + num_new_states_ = 0; + external_to_internal_ids_.clear(); + edited_final_weights_.clear(); + } + + // Removes all but the first n outgoing arcs of the specified state. + void DeleteArcs(StateId s, size_t n, const WrappedFstT *wrapped) { + edits_.DeleteArcs(GetEditableInternalId(s, wrapped), n); + } + + // Removes all outgoing arcs from the specified state. + void DeleteArcs(StateId s, const WrappedFstT *wrapped) { + edits_.DeleteArcs(GetEditableInternalId(s, wrapped)); + } + + // end methods for non-const MutableFst operations + + // Provides information for the generic arc iterator. + void InitArcIterator(StateId s, ArcIteratorData<Arc> *data, + const WrappedFstT *wrapped) const { + IdMapIterator id_map_it = GetEditedIdMapIterator(s); + if (id_map_it == NotInEditedMap()) { + VLOG(3) << "EditFstData::InitArcIterator: iterating on state " + << s << " of original fst"; + wrapped->InitArcIterator(s, data); + } else { + VLOG(2) << "EditFstData::InitArcIterator: iterating on edited state " + << s << " (internal state id: " << id_map_it->second << ")"; + edits_.InitArcIterator(id_map_it->second, data); + } + } + + // Provides information for the generic mutable arc iterator. + void InitMutableArcIterator(StateId s, MutableArcIteratorData<A> *data, + const WrappedFstT *wrapped) { + data->base = + new MutableArcIterator<MutableFstT>(&edits_, + GetEditableInternalId(s, wrapped)); + } + + // Prints out the map from external to internal state id's (for debugging + // purposes). + void PrintMap() { + for (IdMapIterator map_it = external_to_internal_ids_.begin(); + map_it != NotInEditedMap(); ++map_it) { + LOG(INFO) << "(external,internal)=(" + << map_it->first << "," << map_it->second << ")"; + } + } + + + private: + void SetEmptyAndDeleteKeysForInternalMaps() { + } + + // Returns the iterator of the map from external to internal state id's + // of edits_ for the specified external state id. + IdMapIterator GetEditedIdMapIterator(StateId s) const { + return external_to_internal_ids_.find(s); + } + IdMapIterator NotInEditedMap() const { + return external_to_internal_ids_.end(); + } + + FinalWeightIterator GetFinalWeightIterator(StateId s) const { + return edited_final_weights_.find(s); + } + FinalWeightIterator NotInFinalWeightMap() const { + return edited_final_weights_.end(); + } + + // Returns the internal state id of the specified external id if the state has + // already been made editable, or else copies the state from wrapped_ + // to edits_ and returns the state id of the newly editable state in edits_. + // + // \return makes the specified state editable if it isn't already and returns + // its state id in edits_ + StateId GetEditableInternalId(StateId s, const WrappedFstT *wrapped) { + IdMapIterator id_map_it = GetEditedIdMapIterator(s); + if (id_map_it == NotInEditedMap()) { + StateId new_internal_id = edits_.AddState(); + VLOG(2) << "EditFstData::GetEditableInternalId: editing state " << s + << " of original fst; new internal state id:" << new_internal_id; + external_to_internal_ids_[s] = new_internal_id; + for (ArcIterator< Fst<A> > arc_iterator(*wrapped, s); + !arc_iterator.Done(); + arc_iterator.Next()) { + edits_.AddArc(new_internal_id, arc_iterator.Value()); + } + // copy the final weight + FinalWeightIterator final_weight_it = GetFinalWeightIterator(s); + if (final_weight_it == NotInFinalWeightMap()) { + edits_.SetFinal(new_internal_id, wrapped->Final(s)); + } else { + edits_.SetFinal(new_internal_id, final_weight_it->second); + edited_final_weights_.erase(s); + } + return new_internal_id; + } else { + return id_map_it->second; + } + } + + // A mutable fst (by default, a VectorFst) to contain new states, and/or + // copies of states from a wrapped ExpandedFst that have been modified in + // some way. + MutableFstT edits_; + // A mapping from external state id's to the internal id's of states that + // appear in edits_. + unordered_map<StateId, StateId> external_to_internal_ids_; + // A mapping from external state id's to final state weights assigned to + // those states. The states in this map are *only* those whose final weight + // has been modified; if any other part of the state has been modified, + // the entire state is copied to edits_, and all modifications reside there. + unordered_map<StateId, Weight> edited_final_weights_; + // The number of new states added to this mutable fst impl, which is <= the + // number of states in edits_ (since edits_ contains both edited *and* new + // states). + StateId num_new_states_; + RefCounter ref_count_; +}; + +// EditFstData method implementations: just the Read method. +template <typename A, typename WrappedFstT, typename MutableFstT> +EditFstData<A, WrappedFstT, MutableFstT> * +EditFstData<A, WrappedFstT, MutableFstT>::Read(istream &strm, + const FstReadOptions &opts) { + EditFstData<A, WrappedFstT, MutableFstT> *data = + new EditFstData<A, WrappedFstT, MutableFstT>(); + // next read in MutabelFstT machine that stores edits + FstReadOptions edits_opts(opts); + edits_opts.header = 0; // Contained header was written out, so read it in. + + // Because our internal representation of edited states is a solid object + // of type MutableFstT (defaults to VectorFst<A>) and not a pointer, + // and because the static Read method allocates a new object on the heap, + // we need to call Read, check if there was a failure, use + // MutableFstT::operator= to assign the object (not the pointer) to the + // edits_ data member (which will increase the ref count by 1 on the impl) + // and, finally, delete the heap-allocated object. + MutableFstT *edits = MutableFstT::Read(strm, edits_opts); + if (!edits) { + return 0; + } + data->edits_ = *edits; + delete edits; + // finally, read in rest of private data members + ReadType(strm, &data->external_to_internal_ids_); + ReadType(strm, &data->edited_final_weights_); + ReadType(strm, &data->num_new_states_); + if (!strm) { + LOG(ERROR) << "EditFst::Read: read failed: " << opts.source; + return 0; + } + return data; +} + +// This class enables non-destructive edit operations on a wrapped ExpandedFst. +// The implementation uses copy-on-write semantics at the node level: if a user +// has an underlying fst on which he or she wants to perform a relatively small +// number of edits (read: mutations), then this implementation will copy the +// edited node to an internal MutableFst and perform any edits in situ on that +// copied node. This class supports all the methods of MutableFst except for +// DeleteStates(const vector<StateId> &); thus, new nodes may also be added, and +// one may add transitions from existing nodes of the wrapped fst to new nodes. +// +// template parameters: +// A the type of arc to use +// WrappedFstT the type of fst wrapped by the EditFst instance that +// this EditFstImpl instance is backing +// MutableFstT the type of mutable fst to use internally for edited states; +// crucially, MutableFstT::Copy(false) *must* yield an fst that is +// thread-safe for reading (VectorFst, for example, has this property) +template <typename A, + typename WrappedFstT = ExpandedFst<A>, + typename MutableFstT = VectorFst<A> > +class EditFstImpl : public FstImpl<A> { + public: + using FstImpl<A>::SetProperties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + using FstImpl<A>::WriteHeader; + + typedef A Arc; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + // Constructs an editable fst implementation with no states. Effectively, + // this initially-empty fst will in every way mimic the behavior of + // a VectorFst--more precisely, a VectorFstImpl instance--but with slightly + // slower performance (by a constant factor), due to the fact that + // this class maintains a mapping between external state id's and + // their internal equivalents. + EditFstImpl() { + FstImpl<A>::SetType("edit"); + wrapped_ = new MutableFstT(); + InheritPropertiesFromWrapped(); + data_ = new EditFstData<A, WrappedFstT, MutableFstT>(); + } + + // Wraps the specified ExpandedFst. This constructor requires that the + // specified Fst is an ExpandedFst instance. This requirement is only enforced + // at runtime. (See below for the reason.) + // + // This library uses the pointer-to-implementation or "PIMPL" design pattern. + // In particular, to make it convenient to bind an implementation class to its + // interface, there are a pair of template "binder" classes, one for immutable + // and one for mutable fst's (ImplToFst and ImplToMutableFst, respectively). + // As it happens, the API for the ImplToMutableFst<I,F> class requires that + // the implementation class--the template parameter "I"--have a constructor + // taking a const Fst<A> reference. Accordingly, the constructor here must + // perform a static_cast to the WrappedFstT type required by EditFst and + // therefore EditFstImpl. + explicit EditFstImpl(const Fst<A> &wrapped) + : wrapped_(static_cast<WrappedFstT *>(wrapped.Copy())) { + FstImpl<A>::SetType("edit"); + + data_ = new EditFstData<A, WrappedFstT, MutableFstT>(); + // have edits_ inherit all properties from wrapped_ + data_->SetEditedProperties(wrapped_->Properties(kFstProperties, false), + kFstProperties); + InheritPropertiesFromWrapped(); + } + + // A copy constructor for this implementation class, used to implement + // the Copy() method of the Fst interface. + EditFstImpl(const EditFstImpl &impl) + : FstImpl<A>(), + wrapped_(static_cast<WrappedFstT *>(impl.wrapped_->Copy(true))), + data_(impl.data_) { + data_->IncrRefCount(); + SetProperties(impl.Properties()); + } + + ~EditFstImpl() { + delete wrapped_; + if (!data_->DecrRefCount()) { + delete data_; + } + } + + // const Fst/ExpandedFst operations, declared in the Fst and ExpandedFst + // interfaces + StateId Start() const { + StateId edited_start = data_->EditedStart(); + return edited_start == kNoStateId ? wrapped_->Start() : edited_start; + } + + Weight Final(StateId s) const { + return data_->Final(s, wrapped_); + } + + size_t NumArcs(StateId s) const { + return data_->NumArcs(s, wrapped_); + } + + size_t NumInputEpsilons(StateId s) const { + return data_->NumInputEpsilons(s, wrapped_); + } + + size_t NumOutputEpsilons(StateId s) const { + return data_->NumOutputEpsilons(s, wrapped_); + } + + StateId NumStates() const { + return wrapped_->NumStates() + data_->NumNewStates(); + } + + static EditFstImpl<A, WrappedFstT, MutableFstT> * + Read(istream &strm, + const FstReadOptions &opts); + + bool Write(ostream &strm, const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(Start()); + hdr.SetNumStates(NumStates()); + FstWriteOptions header_opts(opts); + header_opts.write_isymbols = false; // Let contained FST hold any symbols. + header_opts.write_osymbols = false; + WriteHeader(strm, header_opts, kFileVersion, &hdr); + + // First, serialize wrapped fst to stream. + FstWriteOptions wrapped_opts(opts); + wrapped_opts.write_header = true; // Force writing contained header. + wrapped_->Write(strm, wrapped_opts); + + data_->Write(strm, opts); + + strm.flush(); + if (!strm) { + LOG(ERROR) << "EditFst::Write: write failed: " << opts.source; + return false; + } + return true; + } + // end const Fst operations + + // non-const MutableFst operations + + // Sets the start state for this fst. + void SetStart(StateId s) { + MutateCheck(); + data_->SetStart(s); + SetProperties(SetStartProperties(FstImpl<A>::Properties())); + } + + // Sets the final state for this fst. + void SetFinal(StateId s, Weight w) { + MutateCheck(); + Weight old_weight = data_->SetFinal(s, w, wrapped_); + SetProperties(SetFinalProperties(FstImpl<A>::Properties(), old_weight, w)); + } + + // Adds a new state to this fst, initially with no arcs. + StateId AddState() { + MutateCheck(); + SetProperties(AddStateProperties(FstImpl<A>::Properties())); + return data_->AddState(NumStates()); + } + + // Adds the specified arc to the specified state of this fst. + void AddArc(StateId s, const Arc &arc) { + MutateCheck(); + const A *prev_arc = data_->AddArc(s, arc, wrapped_); + SetProperties(AddArcProperties(FstImpl<A>::Properties(), s, arc, prev_arc)); + } + + void DeleteStates(const vector<StateId>& dstates) { + FSTERROR() << ": EditFstImpl::DeleteStates(const std::vector<StateId>&): " + << " not implemented"; + SetProperties(kError, kError); + } + + // Deletes all states in this fst. + void DeleteStates(); + + // Removes all but the first n outgoing arcs of the specified state. + void DeleteArcs(StateId s, size_t n) { + MutateCheck(); + data_->DeleteArcs(s, n, wrapped_); + SetProperties(DeleteArcsProperties(FstImpl<A>::Properties())); + } + + // Removes all outgoing arcs from the specified state. + void DeleteArcs(StateId s) { + MutateCheck(); + data_->DeleteArcs(s, wrapped_); + SetProperties(DeleteArcsProperties(FstImpl<A>::Properties())); + } + + void ReserveStates(StateId s) { + } + + void ReserveArcs(StateId s, size_t n) { + } + + // end non-const MutableFst operations + + // Provides information for the generic state iterator. + void InitStateIterator(StateIteratorData<Arc> *data) const { + data->base = 0; + data->nstates = NumStates(); + } + + // Provides information for the generic arc iterator. + void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + data_->InitArcIterator(s, data, wrapped_); + } + + // Provides information for the generic mutable arc iterator. + void InitMutableArcIterator(StateId s, MutableArcIteratorData<A> *data) { + MutateCheck(); + data_->InitMutableArcIterator(s, data, wrapped_); + } + + private: + typedef typename unordered_map<StateId, StateId>::const_iterator + IdMapIterator; + typedef typename unordered_map<StateId, Weight>::const_iterator + FinalWeightIterator; + // Properties always true of this Fst class + static const uint64 kStaticProperties = kExpanded | kMutable; + // Current file format version + static const int kFileVersion = 2; + // Minimum file format version supported + static const int kMinFileVersion = 2; + + // Causes this fst to inherit all the properties from its wrapped fst, except + // for the two properties that always apply to EditFst instances: kExpanded + // and kMutable. + void InheritPropertiesFromWrapped() { + SetProperties(wrapped_->Properties(kCopyProperties, false) | + kStaticProperties); + SetInputSymbols(wrapped_->InputSymbols()); + SetOutputSymbols(wrapped_->OutputSymbols()); + } + + // This method ensures that any operations that alter the mutable data + // portion of this EditFstImpl cause the data_ member to be copied when its + // reference count is greater than 1. Note that this method is distinct from + // MutableFst::Mutate, which gets invoked whenever one of the basic mutation + // methods defined in MutableFst is invoked, such as SetInputSymbols. + // The MutateCheck here in EditFstImpl is invoked whenever one of the + // mutating methods specifically related to the types of edits provided + // by EditFst is performed, such as changing an arc of an existing state + // of the wrapped fst via a MutableArcIterator, or adding a new state via + // AddState(). + void MutateCheck() { + if (data_->RefCount() > 1) { + EditFstData<A, WrappedFstT, MutableFstT> *data_copy = + new EditFstData<A, WrappedFstT, MutableFstT>(*data_); + if (data_ && !data_->DecrRefCount()) { + delete data_; + } + data_ = data_copy; + } + } + + // The fst that this fst wraps. The purpose of this class is to enable + // non-destructive edits on this wrapped fst. + const WrappedFstT *wrapped_; + // The mutable data for this EditFst instance, with delegates for all the + // methods that can mutate data. + EditFstData<A, WrappedFstT, MutableFstT> *data_; +}; + +template <typename A, typename WrappedFstT, typename MutableFstT> +const uint64 EditFstImpl<A, WrappedFstT, MutableFstT>::kStaticProperties; + +// EditFstImpl IMPLEMENTATION STARTS HERE + +template<typename A, typename WrappedFstT, typename MutableFstT> +inline void EditFstImpl<A, WrappedFstT, MutableFstT>::DeleteStates() { + data_->DeleteStates(); + delete wrapped_; + // we are deleting all states, so just forget about pointer to wrapped_ + // and do what default constructor does: set wrapped_ to a new VectorFst + wrapped_ = new MutableFstT(); + uint64 newProps = DeleteAllStatesProperties(FstImpl<A>::Properties(), + kStaticProperties); + FstImpl<A>::SetProperties(newProps); +} + +template <typename A, typename WrappedFstT, typename MutableFstT> +EditFstImpl<A, WrappedFstT, MutableFstT> * +EditFstImpl<A, WrappedFstT, MutableFstT>::Read(istream &strm, + const FstReadOptions &opts) { + EditFstImpl<A, WrappedFstT, MutableFstT> *impl = new EditFstImpl(); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) { + return 0; + } + impl->SetStart(hdr.Start()); + + // first, read in wrapped fst + FstReadOptions wrapped_opts(opts); + wrapped_opts.header = 0; // Contained header was written out, so read it in. + Fst<A> *wrapped_fst = Fst<A>::Read(strm, wrapped_opts); + if (!wrapped_fst) { + return 0; + } + impl->wrapped_ = static_cast<WrappedFstT *>(wrapped_fst); + + impl->data_ = EditFstData<A, WrappedFstT, MutableFstT>::Read(strm, opts); + + if (!impl->data_) { + delete wrapped_fst; + return 0; + } + + return impl; +} + +// END EditFstImpl IMPLEMENTATION + +// Concrete, editable FST. This class attaches interface to implementation. +template <typename A, + typename WrappedFstT = ExpandedFst<A>, + typename MutableFstT = VectorFst<A> > +class EditFst : + public ImplToMutableFst< EditFstImpl<A, WrappedFstT, MutableFstT> > { + public: + friend class MutableArcIterator< EditFst<A, WrappedFstT, MutableFstT> >; + + typedef A Arc; + typedef typename A::StateId StateId; + typedef EditFstImpl<A, WrappedFstT, MutableFstT> Impl; + + EditFst() : ImplToMutableFst<Impl>(new Impl()) {} + + explicit EditFst(const Fst<A> &fst) : + ImplToMutableFst<Impl>(new Impl(fst)) {} + + explicit EditFst(const WrappedFstT &fst) : + ImplToMutableFst<Impl>(new Impl(fst)) {} + + // See Fst<>::Copy() for doc. + EditFst(const EditFst<A, WrappedFstT, MutableFstT> &fst, bool safe = false) : + ImplToMutableFst<Impl>(fst, safe) {} + + virtual ~EditFst() {} + + // Get a copy of this EditFst. See Fst<>::Copy() for further doc. + virtual EditFst<A, WrappedFstT, MutableFstT> *Copy(bool safe = false) const { + return new EditFst<A, WrappedFstT, MutableFstT>(*this, safe); + } + + EditFst<A, WrappedFstT, MutableFstT> & + operator=(const EditFst<A, WrappedFstT, MutableFstT> &fst) { + SetImpl(fst.GetImpl(), false); + return *this; + } + + virtual EditFst<A, WrappedFstT, MutableFstT> &operator=(const Fst<A> &fst) { + if (this != &fst) { + SetImpl(new Impl(fst)); + } + return *this; + } + + // Read an EditFst from an input stream; return NULL on error. + static EditFst<A, WrappedFstT, MutableFstT> * + Read(istream &strm, + const FstReadOptions &opts) { + Impl* impl = Impl::Read(strm, opts); + return impl ? new EditFst<A>(impl) : 0; + } + + // Read an EditFst from a file; return NULL on error. + // Empty filename reads from standard input. + static EditFst<A, WrappedFstT, MutableFstT> *Read(const string &filename) { + Impl* impl = ImplToExpandedFst<Impl, MutableFst<A> >::Read(filename); + return impl ? new EditFst<A, WrappedFstT, MutableFstT>(impl) : 0; + } + + virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { + return GetImpl()->Write(strm, opts); + } + + virtual bool Write(const string &filename) const { + return Fst<A>::WriteFile(filename); + } + + virtual void InitStateIterator(StateIteratorData<Arc> *data) const { + GetImpl()->InitStateIterator(data); + } + + virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + virtual + void InitMutableArcIterator(StateId s, MutableArcIteratorData<A> *data) { + GetImpl()->InitMutableArcIterator(s, data); + } + private: + explicit EditFst(Impl *impl) : ImplToMutableFst<Impl>(impl) {} + + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst< Impl, MutableFst<A> >::GetImpl(); } + + void SetImpl(Impl *impl, bool own_impl = true) { + ImplToFst< Impl, MutableFst<A> >::SetImpl(impl, own_impl); + } +}; + +} // namespace fst + +#endif // FST_LIB_EDIT_FST_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/encode.h b/kaldi_io/src/tools/openfst/include/fst/encode.h new file mode 100644 index 0000000..08b84cb --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/encode.h @@ -0,0 +1,599 @@ +// encode.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Johan Schalkwyk) +// +// \file +// Class to encode and decoder an fst. + +#ifndef FST_LIB_ENCODE_H__ +#define FST_LIB_ENCODE_H__ + +#include <climits> +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <string> +#include <vector> +using std::vector; + +#include <fst/arc-map.h> +#include <fst/rmfinalepsilon.h> + + +namespace fst { + +static const uint32 kEncodeLabels = 0x0001; +static const uint32 kEncodeWeights = 0x0002; +static const uint32 kEncodeFlags = 0x0003; // All non-internal flags + +static const uint32 kEncodeHasISymbols = 0x0004; // For internal use +static const uint32 kEncodeHasOSymbols = 0x0008; // For internal use + +enum EncodeType { ENCODE = 1, DECODE = 2 }; + +// Identifies stream data as an encode table (and its endianity) +static const int32 kEncodeMagicNumber = 2129983209; + + +// The following class encapsulates implementation details for the +// encoding and decoding of label/weight tuples used for encoding +// and decoding of Fsts. The EncodeTable is bidirectional. I.E it +// stores both the Tuple of encode labels and weights to a unique +// label, and the reverse. +template <class A> class EncodeTable { + public: + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + // Encoded data consists of arc input/output labels and arc weight + struct Tuple { + Tuple() {} + Tuple(Label ilabel_, Label olabel_, Weight weight_) + : ilabel(ilabel_), olabel(olabel_), weight(weight_) {} + Tuple(const Tuple& tuple) + : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {} + + Label ilabel; + Label olabel; + Weight weight; + }; + + // Comparison object for hashing EncodeTable Tuple(s). + class TupleEqual { + public: + bool operator()(const Tuple* x, const Tuple* y) const { + return (x->ilabel == y->ilabel && + x->olabel == y->olabel && + x->weight == y->weight); + } + }; + + // Hash function for EncodeTabe Tuples. Based on the encode flags + // we either hash the labels, weights or combination of them. + class TupleKey { + public: + TupleKey() + : encode_flags_(kEncodeLabels | kEncodeWeights) {} + + TupleKey(const TupleKey& key) + : encode_flags_(key.encode_flags_) {} + + explicit TupleKey(uint32 encode_flags) + : encode_flags_(encode_flags) {} + + size_t operator()(const Tuple* x) const { + size_t hash = x->ilabel; + const int lshift = 5; + const int rshift = CHAR_BIT * sizeof(size_t) - 5; + if (encode_flags_ & kEncodeLabels) + hash = hash << lshift ^ hash >> rshift ^ x->olabel; + if (encode_flags_ & kEncodeWeights) + hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash(); + return hash; + } + + private: + int32 encode_flags_; + }; + + typedef unordered_map<const Tuple*, + Label, + TupleKey, + TupleEqual> EncodeHash; + + explicit EncodeTable(uint32 encode_flags) + : flags_(encode_flags), + encode_hash_(1024, TupleKey(encode_flags)), + isymbols_(0), osymbols_(0) {} + + ~EncodeTable() { + for (size_t i = 0; i < encode_tuples_.size(); ++i) { + delete encode_tuples_[i]; + } + delete isymbols_; + delete osymbols_; + } + + // Given an arc encode either input/ouptut labels or input/costs or both + Label Encode(const A &arc) { + const Tuple tuple(arc.ilabel, + flags_ & kEncodeLabels ? arc.olabel : 0, + flags_ & kEncodeWeights ? arc.weight : Weight::One()); + typename EncodeHash::const_iterator it = encode_hash_.find(&tuple); + if (it == encode_hash_.end()) { + encode_tuples_.push_back(new Tuple(tuple)); + encode_hash_[encode_tuples_.back()] = encode_tuples_.size(); + return encode_tuples_.size(); + } else { + return it->second; + } + } + + // Given an arc, look up its encoded label. Returns kNoLabel if not found. + Label GetLabel(const A &arc) const { + const Tuple tuple(arc.ilabel, + flags_ & kEncodeLabels ? arc.olabel : 0, + flags_ & kEncodeWeights ? arc.weight : Weight::One()); + typename EncodeHash::const_iterator it = encode_hash_.find(&tuple); + if (it == encode_hash_.end()) { + return kNoLabel; + } else { + return it->second; + } + } + + // Given an encode arc Label decode back to input/output labels and costs + const Tuple* Decode(Label key) const { + if (key < 1 || key > encode_tuples_.size()) { + LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key; + return 0; + } + return encode_tuples_[key - 1]; + } + + size_t Size() const { return encode_tuples_.size(); } + + bool Write(ostream &strm, const string &source) const; + + static EncodeTable<A> *Read(istream &strm, const string &source); + + const uint32 flags() const { return flags_ & kEncodeFlags; } + + int RefCount() const { return ref_count_.count(); } + int IncrRefCount() { return ref_count_.Incr(); } + int DecrRefCount() { return ref_count_.Decr(); } + + + SymbolTable *InputSymbols() const { return isymbols_; } + + SymbolTable *OutputSymbols() const { return osymbols_; } + + void SetInputSymbols(const SymbolTable* syms) { + if (isymbols_) delete isymbols_; + if (syms) { + isymbols_ = syms->Copy(); + flags_ |= kEncodeHasISymbols; + } else { + isymbols_ = 0; + flags_ &= ~kEncodeHasISymbols; + } + } + + void SetOutputSymbols(const SymbolTable* syms) { + if (osymbols_) delete osymbols_; + if (syms) { + osymbols_ = syms->Copy(); + flags_ |= kEncodeHasOSymbols; + } else { + osymbols_ = 0; + flags_ &= ~kEncodeHasOSymbols; + } + } + + private: + uint32 flags_; + vector<Tuple*> encode_tuples_; + EncodeHash encode_hash_; + RefCounter ref_count_; + SymbolTable *isymbols_; // Pre-encoded ilabel symbol table + SymbolTable *osymbols_; // Pre-encoded olabel symbol table + + DISALLOW_COPY_AND_ASSIGN(EncodeTable); +}; + +template <class A> inline +bool EncodeTable<A>::Write(ostream &strm, const string &source) const { + WriteType(strm, kEncodeMagicNumber); + WriteType(strm, flags_); + int64 size = encode_tuples_.size(); + WriteType(strm, size); + for (size_t i = 0; i < size; ++i) { + const Tuple* tuple = encode_tuples_[i]; + WriteType(strm, tuple->ilabel); + WriteType(strm, tuple->olabel); + tuple->weight.Write(strm); + } + + if (flags_ & kEncodeHasISymbols) + isymbols_->Write(strm); + + if (flags_ & kEncodeHasOSymbols) + osymbols_->Write(strm); + + strm.flush(); + if (!strm) { + LOG(ERROR) << "EncodeTable::Write: write failed: " << source; + return false; + } + return true; +} + +template <class A> inline +EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) { + int32 magic_number = 0; + ReadType(strm, &magic_number); + if (magic_number != kEncodeMagicNumber) { + LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source; + return 0; + } + uint32 flags; + ReadType(strm, &flags); + EncodeTable<A> *table = new EncodeTable<A>(flags); + + int64 size; + ReadType(strm, &size); + if (!strm) { + LOG(ERROR) << "EncodeTable::Read: read failed: " << source; + return 0; + } + + for (size_t i = 0; i < size; ++i) { + Tuple* tuple = new Tuple(); + ReadType(strm, &tuple->ilabel); + ReadType(strm, &tuple->olabel); + tuple->weight.Read(strm); + if (!strm) { + LOG(ERROR) << "EncodeTable::Read: read failed: " << source; + return 0; + } + table->encode_tuples_.push_back(tuple); + table->encode_hash_[table->encode_tuples_.back()] = + table->encode_tuples_.size(); + } + + if (flags & kEncodeHasISymbols) + table->isymbols_ = SymbolTable::Read(strm, source); + + if (flags & kEncodeHasOSymbols) + table->osymbols_ = SymbolTable::Read(strm, source); + + return table; +} + + +// A mapper to encode/decode weighted transducers. Encoding of an +// Fst is useful for performing classical determinization or minimization +// on a weighted transducer by treating it as an unweighted acceptor over +// encoded labels. +// +// The Encode mapper stores the encoding in a local hash table (EncodeTable) +// This table is shared (and reference counted) between the encoder and +// decoder. A decoder has read only access to the EncodeTable. +// +// The EncodeMapper allows on the fly encoding of the machine. As the +// EncodeTable is generated the same table may by used to decode the machine +// on the fly. For example in the following sequence of operations +// +// Encode -> Determinize -> Decode +// +// we will use the encoding table generated during the encode step in the +// decode, even though the encoding is not complete. +// +template <class A> class EncodeMapper { + typedef typename A::Weight Weight; + typedef typename A::Label Label; + public: + EncodeMapper(uint32 flags, EncodeType type) + : flags_(flags), + type_(type), + table_(new EncodeTable<A>(flags)), + error_(false) {} + + EncodeMapper(const EncodeMapper& mapper) + : flags_(mapper.flags_), + type_(mapper.type_), + table_(mapper.table_), + error_(false) { + table_->IncrRefCount(); + } + + // Copy constructor but setting the type, typically to DECODE + EncodeMapper(const EncodeMapper& mapper, EncodeType type) + : flags_(mapper.flags_), + type_(type), + table_(mapper.table_), + error_(mapper.error_) { + table_->IncrRefCount(); + } + + ~EncodeMapper() { + if (!table_->DecrRefCount()) delete table_; + } + + A operator()(const A &arc); + + MapFinalAction FinalAction() const { + return (type_ == ENCODE && (flags_ & kEncodeWeights)) ? + MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL; + } + + MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;} + + uint64 Properties(uint64 inprops) { + uint64 outprops = inprops; + if (error_) outprops |= kError; + + uint64 mask = kFstProperties; + if (flags_ & kEncodeLabels) + mask &= kILabelInvariantProperties & kOLabelInvariantProperties; + if (flags_ & kEncodeWeights) + mask &= kILabelInvariantProperties & kWeightInvariantProperties & + (type_ == ENCODE ? kAddSuperFinalProperties : + kRmSuperFinalProperties); + + return outprops & mask; + } + + const uint32 flags() const { return flags_; } + const EncodeType type() const { return type_; } + const EncodeTable<A> &table() const { return *table_; } + + bool Write(ostream &strm, const string& source) { + return table_->Write(strm, source); + } + + bool Write(const string& filename) { + ofstream strm(filename.c_str(), ofstream::out | ofstream::binary); + if (!strm) { + LOG(ERROR) << "EncodeMap: Can't open file: " << filename; + return false; + } + return Write(strm, filename); + } + + static EncodeMapper<A> *Read(istream &strm, + const string& source, + EncodeType type = ENCODE) { + EncodeTable<A> *table = EncodeTable<A>::Read(strm, source); + return table ? new EncodeMapper(table->flags(), type, table) : 0; + } + + static EncodeMapper<A> *Read(const string& filename, + EncodeType type = ENCODE) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "EncodeMap: Can't open file: " << filename; + return NULL; + } + return Read(strm, filename, type); + } + + SymbolTable *InputSymbols() const { return table_->InputSymbols(); } + + SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); } + + void SetInputSymbols(const SymbolTable* syms) { + table_->SetInputSymbols(syms); + } + + void SetOutputSymbols(const SymbolTable* syms) { + table_->SetOutputSymbols(syms); + } + + private: + uint32 flags_; + EncodeType type_; + EncodeTable<A>* table_; + bool error_; + + explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table) + : flags_(flags), type_(type), table_(table) {} + void operator=(const EncodeMapper &); // Disallow. +}; + +template <class A> inline +A EncodeMapper<A>::operator()(const A &arc) { + if (type_ == ENCODE) { // labels and/or weights to single label + if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) || + (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) && + arc.weight == Weight::Zero())) { + return arc; + } else { + Label label = table_->Encode(arc); + return A(label, + flags_ & kEncodeLabels ? label : arc.olabel, + flags_ & kEncodeWeights ? Weight::One() : arc.weight, + arc.nextstate); + } + } else { // type_ == DECODE + if (arc.nextstate == kNoStateId) { + return arc; + } else { + if (arc.ilabel == 0) return arc; + if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) { + FSTERROR() << "EncodeMapper: Label-encoded arc has different " + "input and output labels"; + error_ = true; + } + if (flags_ & kEncodeWeights && arc.weight != Weight::One()) { + FSTERROR() << + "EncodeMapper: Weight-encoded arc has non-trivial weight"; + error_ = true; + } + const typename EncodeTable<A>::Tuple* tuple = table_->Decode(arc.ilabel); + if (!tuple) { + FSTERROR() << "EncodeMapper: decode failed"; + error_ = true; + return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate); + } else { + return A(tuple->ilabel, + flags_ & kEncodeLabels ? tuple->olabel : arc.olabel, + flags_ & kEncodeWeights ? tuple->weight : arc.weight, + arc.nextstate); + } + } + } +} + + +// Complexity: O(nstates + narcs) +template<class A> inline +void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) { + mapper->SetInputSymbols(fst->InputSymbols()); + mapper->SetOutputSymbols(fst->OutputSymbols()); + ArcMap(fst, mapper); +} + +template<class A> inline +void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) { + ArcMap(fst, EncodeMapper<A>(mapper, DECODE)); + RmFinalEpsilon(fst); + fst->SetInputSymbols(mapper.InputSymbols()); + fst->SetOutputSymbols(mapper.OutputSymbols()); +} + + +// On the fly label and/or weight encoding of input Fst +// +// Complexity: +// - Constructor: O(1) +// - Traversal: O(nstates_visited + narcs_visited), assuming constant +// time to visit an input state or arc. +template <class A> +class EncodeFst : public ArcMapFst<A, A, EncodeMapper<A> > { + public: + typedef A Arc; + typedef EncodeMapper<A> C; + typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl; + using ImplToFst<Impl>::GetImpl; + + EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder) + : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) { + encoder->SetInputSymbols(fst.InputSymbols()); + encoder->SetOutputSymbols(fst.OutputSymbols()); + } + + EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) + : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {} + + // See Fst<>::Copy() for doc. + EncodeFst(const EncodeFst<A> &fst, bool copy = false) + : ArcMapFst<A, A, C>(fst, copy) {} + + // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc. + virtual EncodeFst<A> *Copy(bool safe = false) const { + if (safe) { + FSTERROR() << "EncodeFst::Copy(true): not allowed."; + GetImpl()->SetProperties(kError, kError); + } + return new EncodeFst(*this); + } +}; + + +// On the fly label and/or weight encoding of input Fst +// +// Complexity: +// - Constructor: O(1) +// - Traversal: O(nstates_visited + narcs_visited), assuming constant +// time to visit an input state or arc. +template <class A> +class DecodeFst : public ArcMapFst<A, A, EncodeMapper<A> > { + public: + typedef A Arc; + typedef EncodeMapper<A> C; + typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl; + using ImplToFst<Impl>::GetImpl; + + DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) + : ArcMapFst<A, A, C>(fst, + EncodeMapper<A>(encoder, DECODE), + ArcMapFstOptions()) { + GetImpl()->SetInputSymbols(encoder.InputSymbols()); + GetImpl()->SetOutputSymbols(encoder.OutputSymbols()); + } + + // See Fst<>::Copy() for doc. + DecodeFst(const DecodeFst<A> &fst, bool safe = false) + : ArcMapFst<A, A, C>(fst, safe) {} + + // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc. + virtual DecodeFst<A> *Copy(bool safe = false) const { + return new DecodeFst(*this, safe); + } +}; + + +// Specialization for EncodeFst. +template <class A> +class StateIterator< EncodeFst<A> > + : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > { + public: + explicit StateIterator(const EncodeFst<A> &fst) + : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {} +}; + + +// Specialization for EncodeFst. +template <class A> +class ArcIterator< EncodeFst<A> > + : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > { + public: + ArcIterator(const EncodeFst<A> &fst, typename A::StateId s) + : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {} +}; + + +// Specialization for DecodeFst. +template <class A> +class StateIterator< DecodeFst<A> > + : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > { + public: + explicit StateIterator(const DecodeFst<A> &fst) + : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {} +}; + + +// Specialization for DecodeFst. +template <class A> +class ArcIterator< DecodeFst<A> > + : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > { + public: + ArcIterator(const DecodeFst<A> &fst, typename A::StateId s) + : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {} +}; + + +// Useful aliases when using StdArc. +typedef EncodeFst<StdArc> StdEncodeFst; + +typedef DecodeFst<StdArc> StdDecodeFst; + +} // namespace fst + +#endif // FST_LIB_ENCODE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/epsnormalize.h b/kaldi_io/src/tools/openfst/include/fst/epsnormalize.h new file mode 100644 index 0000000..9d178b1 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/epsnormalize.h @@ -0,0 +1,73 @@ +// epsnormalize.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Function that implements epsilon normalization. + +#ifndef FST_LIB_EPSNORMALIZE_H__ +#define FST_LIB_EPSNORMALIZE_H__ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + + +#include <fst/factor-weight.h> +#include <fst/invert.h> +#include <fst/arc-map.h> +#include <fst/rmepsilon.h> + + +namespace fst { + +enum EpsNormalizeType {EPS_NORM_INPUT, EPS_NORM_OUTPUT}; + +// Returns an equivalent FST that is epsilon-normalized. An acceptor is +// epsilon-normalized if it is epsilon-removed. A transducer is input +// epsilon-normalized if additionally if on each path any epsilon input +// label follows all non-epsilon input labels. Output epsilon-normalized +// is defined similarly. +// +// The input FST needs to be functional. +// +// References: +// - Mehryar Mohri. "Generic epsilon-removal and input epsilon-normalization +// algorithms for weighted transducers", International Journal of Computer +// Science, 13(1): 129-143, 2002. +template <class Arc> +void EpsNormalize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, + EpsNormalizeType type = EPS_NORM_INPUT) { + VectorFst< GallicArc<Arc, STRING_RIGHT_RESTRICT> > gfst; + if (type == EPS_NORM_INPUT) + ArcMap(ifst, &gfst, ToGallicMapper<Arc, STRING_RIGHT_RESTRICT>()); + else // type == EPS_NORM_OUTPUT + ArcMap(InvertFst<Arc>(ifst), &gfst, + ToGallicMapper<Arc, STRING_RIGHT_RESTRICT>()); + RmEpsilon(&gfst); + FactorWeightFst< GallicArc<Arc, STRING_RIGHT_RESTRICT>, + GallicFactor<typename Arc::Label, + typename Arc::Weight, STRING_RIGHT_RESTRICT> > + fwfst(gfst); + ArcMap(fwfst, ofst, FromGallicMapper<Arc, STRING_RIGHT_RESTRICT>()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + if(type == EPS_NORM_OUTPUT) + Invert(ofst); +} + +} // namespace fst + +#endif // FST_LIB_EPSNORMALIZE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/equal.h b/kaldi_io/src/tools/openfst/include/fst/equal.h new file mode 100644 index 0000000..33be198 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/equal.h @@ -0,0 +1,124 @@ +// test.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Function to test equality of two Fsts. + +#ifndef FST_LIB_EQUAL_H__ +#define FST_LIB_EQUAL_H__ + +#include <fst/fst.h> + + +namespace fst { + +// Tests if two Fsts have the same states and arcs in the same order. +template<class Arc> +bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2, float delta = kDelta) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + if (fst1.Start() != fst2.Start()) { + VLOG(1) << "Equal: mismatched start states"; + return false; + } + + StateIterator< Fst<Arc> > siter1(fst1); + StateIterator< Fst<Arc> > siter2(fst2); + + while (!siter1.Done() || !siter2.Done()) { + if (siter1.Done() || siter2.Done()) { + VLOG(1) << "Equal: mismatched # of states"; + return false; + } + StateId s1 = siter1.Value(); + StateId s2 = siter2.Value(); + if (s1 != s2) { + VLOG(1) << "Equal: mismatched states:" + << ", state1 = " << s1 + << ", state2 = " << s2; + return false; + } + Weight final1 = fst1.Final(s1); + Weight final2 = fst2.Final(s2); + if (!ApproxEqual(final1, final2, delta)) { + VLOG(1) << "Equal: mismatched final weights:" + << " state = " << s1 + << ", final1 = " << final1 + << ", final2 = " << final2; + return false; + } + ArcIterator< Fst<Arc> > aiter1(fst1, s1); + ArcIterator< Fst<Arc> > aiter2(fst2, s2); + for (size_t a = 0; !aiter1.Done() || !aiter2.Done(); ++a) { + if (aiter1.Done() || aiter2.Done()) { + VLOG(1) << "Equal: mismatched # of arcs" + << " state = " << s1; + return false; + } + Arc arc1 = aiter1.Value(); + Arc arc2 = aiter2.Value(); + if (arc1.ilabel != arc2.ilabel) { + VLOG(1) << "Equal: mismatched arc input labels:" + << " state = " << s1 + << ", arc = " << a + << ", ilabel1 = " << arc1.ilabel + << ", ilabel2 = " << arc2.ilabel; + return false; + } else if (arc1.olabel != arc2.olabel) { + VLOG(1) << "Equal: mismatched arc output labels:" + << " state = " << s1 + << ", arc = " << a + << ", olabel1 = " << arc1.olabel + << ", olabel2 = " << arc2.olabel; + return false; + } else if (!ApproxEqual(arc1.weight, arc2.weight, delta)) { + VLOG(1) << "Equal: mismatched arc weights:" + << " state = " << s1 + << ", arc = " << a + << ", weight1 = " << arc1.weight + << ", weight2 = " << arc2.weight; + return false; + } else if (arc1.nextstate != arc2.nextstate) { + VLOG(1) << "Equal: mismatched input label:" + << " state = " << s1 + << ", arc = " << a + << ", nextstate1 = " << arc1.nextstate + << ", nextstate2 = " << arc2.nextstate; + return false; + } + aiter1.Next(); + aiter2.Next(); + + } + // Sanity checks: should never fail + if (fst1.NumArcs(s1) != fst2.NumArcs(s2) || + fst1.NumInputEpsilons(s1) != fst2.NumInputEpsilons(s2) || + fst1.NumOutputEpsilons(s1) != fst2.NumOutputEpsilons(s2)) { + FSTERROR() << "Equal: inconsistent arc/epsilon counts"; + } + + siter1.Next(); + siter2.Next(); + } + return true; +} + +} // namespace fst + + +#endif // FST_LIB_EQUAL_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/equivalent.h b/kaldi_io/src/tools/openfst/include/fst/equivalent.h new file mode 100644 index 0000000..e28fea1 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/equivalent.h @@ -0,0 +1,275 @@ +// equivalent.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Wojciech Skut) +// +// \file Functions and classes to determine the equivalence of two +// FSTs. + +#ifndef FST_LIB_EQUIVALENT_H__ +#define FST_LIB_EQUIVALENT_H__ + +#include <algorithm> +#include <deque> +using std::deque; +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/encode.h> +#include <fst/push.h> +#include <fst/union-find.h> +#include <fst/vector-fst.h> + + +namespace fst { + +// Traits-like struct holding utility functions/typedefs/constants for +// the equivalence algorithm. +// +// Encoding device: in order to make the statesets of the two acceptors +// disjoint, we map Arc::StateId on the type MappedId. The states of +// the first acceptor are mapped on odd numbers (s -> 2s + 1), and +// those of the second one on even numbers (s -> 2s + 2). The number 0 +// is reserved for an implicit (non-final) 'dead state' (required for +// the correct treatment of non-coaccessible states; kNoStateId is +// mapped to kDeadState for both acceptors). The union-find algorithm +// operates on the mapped IDs. +template <class Arc> +struct EquivalenceUtil { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef StateId MappedId; // ID for an equivalence class. + + // MappedId for an implicit dead state. + static const MappedId kDeadState = 0; + + // MappedId for lookup failure. + static const MappedId kInvalidId = -1; + + // Maps state ID to the representative of the corresponding + // equivalence class. The parameter 'which_fst' takes the values 1 + // and 2, identifying the input FST. + static MappedId MapState(StateId s, int32 which_fst) { + return + (kNoStateId == s) + ? + kDeadState + : + (static_cast<MappedId>(s) << 1) + which_fst; + } + // Maps set ID to State ID. + static StateId UnMapState(MappedId id) { + return static_cast<StateId>((--id) >> 1); + } + // Convenience function: checks if state with MappedId 's' is final + // in acceptor 'fa'. + static bool IsFinal(const Fst<Arc> &fa, MappedId s) { + return + (kDeadState == s) ? + false : (fa.Final(UnMapState(s)) != Weight::Zero()); + } + // Convenience function: returns the representative of 'id' in 'sets', + // creating a new set if needed. + static MappedId FindSet(UnionFind<MappedId> *sets, MappedId id) { + MappedId repr = sets->FindSet(id); + if (repr != kInvalidId) { + return repr; + } else { + sets->MakeSet(id); + return id; + } + } +}; + +template <class Arc> const +typename EquivalenceUtil<Arc>::MappedId EquivalenceUtil<Arc>::kDeadState; + +template <class Arc> const +typename EquivalenceUtil<Arc>::MappedId EquivalenceUtil<Arc>::kInvalidId; + + +// Equivalence checking algorithm: determines if the two FSTs +// <code>fst1</code> and <code>fst2</code> are equivalent. The input +// FSTs must be deterministic input-side epsilon-free acceptors, +// unweighted or with weights over a left semiring. Two acceptors are +// considered equivalent if they accept exactly the same set of +// strings (with the same weights). +// +// The algorithm (cf. Aho, Hopcroft and Ullman, "The Design and +// Analysis of Computer Programs") successively constructs sets of +// states that can be reached by the same prefixes, starting with a +// set containing the start states of both acceptors. A disjoint tree +// forest (the union-find algorithm) is used to represent the sets of +// states. The algorithm returns 'false' if one of the constructed +// sets contains both final and non-final states. Returns optional error +// value (when FLAGS_error_fatal = false). +// +// Complexity: quasi-linear, i.e. O(n G(n)), where +// n = |S1| + |S2| is the number of states in both acceptors +// G(n) is a very slowly growing function that can be approximated +// by 4 by all practical purposes. +// +template <class Arc> +bool Equivalent(const Fst<Arc> &fst1, + const Fst<Arc> &fst2, + double delta = kDelta, bool *error = 0) { + typedef typename Arc::Weight Weight; + if (error) *error = false; + + // Check that the symbol table are compatible + if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) { + FSTERROR() << "Equivalent: input/output symbol tables of 1st argument " + << "do not match input/output symbol tables of 2nd argument"; + if (error) *error = true; + return false; + } + // Check properties first: + uint64 props = kNoEpsilons | kIDeterministic | kAcceptor; + if (fst1.Properties(props, true) != props) { + FSTERROR() << "Equivalent: first argument not an" + << " epsilon-free deterministic acceptor"; + if (error) *error = true; + return false; + } + if (fst2.Properties(props, true) != props) { + FSTERROR() << "Equivalent: second argument not an" + << " epsilon-free deterministic acceptor"; + if (error) *error = true; + return false; + } + + if ((fst1.Properties(kUnweighted , true) != kUnweighted) + || (fst2.Properties(kUnweighted , true) != kUnweighted)) { + VectorFst<Arc> efst1(fst1); + VectorFst<Arc> efst2(fst2); + Push(&efst1, REWEIGHT_TO_INITIAL, delta); + Push(&efst2, REWEIGHT_TO_INITIAL, delta); + ArcMap(&efst1, QuantizeMapper<Arc>(delta)); + ArcMap(&efst2, QuantizeMapper<Arc>(delta)); + EncodeMapper<Arc> mapper(kEncodeWeights|kEncodeLabels, ENCODE); + ArcMap(&efst1, &mapper); + ArcMap(&efst2, &mapper); + return Equivalent(efst1, efst2); + } + + // Convenience typedefs: + typedef typename Arc::StateId StateId; + typedef EquivalenceUtil<Arc> Util; + typedef typename Util::MappedId MappedId; + enum { FST1 = 1, FST2 = 2 }; // Required by Util::MapState(...) + + MappedId s1 = Util::MapState(fst1.Start(), FST1); + MappedId s2 = Util::MapState(fst2.Start(), FST2); + + // The union-find structure. + UnionFind<MappedId> eq_classes(1000, Util::kInvalidId); + + // Initialize the union-find structure. + eq_classes.MakeSet(s1); + eq_classes.MakeSet(s2); + + // Data structure for the (partial) acceptor transition function of + // fst1 and fst2: input labels mapped to pairs of MappedId's + // representing destination states of the corresponding arcs in fst1 + // and fst2, respectively. + typedef + unordered_map<typename Arc::Label, pair<MappedId, MappedId> > + Label2StatePairMap; + + Label2StatePairMap arc_pairs; + + // Pairs of MappedId's to be processed, organized in a queue. + deque<pair<MappedId, MappedId> > q; + + bool ret = true; + // Early return if the start states differ w.r.t. being final. + if (Util::IsFinal(fst1, s1) != Util::IsFinal(fst2, s2)) { + ret = false; + } + + // Main loop: explores the two acceptors in a breadth-first manner, + // updating the equivalence relation on the statesets. Loop + // invariant: each block of states contains either final states only + // or non-final states only. + for (q.push_back(make_pair(s1, s2)); ret && !q.empty(); q.pop_front()) { + s1 = q.front().first; + s2 = q.front().second; + + // Representatives of the equivalence classes of s1/s2. + MappedId rep1 = Util::FindSet(&eq_classes, s1); + MappedId rep2 = Util::FindSet(&eq_classes, s2); + + if (rep1 != rep2) { + eq_classes.Union(rep1, rep2); + arc_pairs.clear(); + + // Copy outgoing arcs starting at s1 into the hashtable. + if (Util::kDeadState != s1) { + ArcIterator<Fst<Arc> > arc_iter(fst1, Util::UnMapState(s1)); + for (; !arc_iter.Done(); arc_iter.Next()) { + const Arc &arc = arc_iter.Value(); + if (arc.weight != Weight::Zero()) { // Zero-weight arcs + // are treated as + // non-exisitent. + arc_pairs[arc.ilabel].first = Util::MapState(arc.nextstate, FST1); + } + } + } + // Copy outgoing arcs starting at s2 into the hashtable. + if (Util::kDeadState != s2) { + ArcIterator<Fst<Arc> > arc_iter(fst2, Util::UnMapState(s2)); + for (; !arc_iter.Done(); arc_iter.Next()) { + const Arc &arc = arc_iter.Value(); + if (arc.weight != Weight::Zero()) { // Zero-weight arcs + // are treated as + // non-existent. + arc_pairs[arc.ilabel].second = Util::MapState(arc.nextstate, FST2); + } + } + } + // Iterate through the hashtable and process pairs of target + // states. + for (typename Label2StatePairMap::const_iterator + arc_iter = arc_pairs.begin(); + arc_iter != arc_pairs.end(); + ++arc_iter) { + const pair<MappedId, MappedId> &p = arc_iter->second; + if (Util::IsFinal(fst1, p.first) != Util::IsFinal(fst2, p.second)) { + // Detected inconsistency: return false. + ret = false; + break; + } + q.push_back(p); + } + } + } + + if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) { + if (error) *error = true; + return false; + } + + return ret; +} + +} // namespace fst + +#endif // FST_LIB_EQUIVALENT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/expanded-fst.h b/kaldi_io/src/tools/openfst/include/fst/expanded-fst.h new file mode 100644 index 0000000..676ceb3 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/expanded-fst.h @@ -0,0 +1,189 @@ +// expanded-fst.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Generic FST augmented with state count - interface class definition. +// + +#ifndef FST_LIB_EXPANDED_FST_H__ +#define FST_LIB_EXPANDED_FST_H__ + +#include <sys/types.h> +#include <string> + +#include <fst/fst.h> + + +namespace fst { + +// A generic FST plus state count. +template <class A> +class ExpandedFst : public Fst<A> { + public: + typedef A Arc; + typedef typename A::StateId StateId; + + virtual StateId NumStates() const = 0; // State count + + // Get a copy of this ExpandedFst. See Fst<>::Copy() for further doc. + virtual ExpandedFst<A> *Copy(bool safe = false) const = 0; + + // Read an ExpandedFst from an input stream; return NULL on error. + static ExpandedFst<A> *Read(istream &strm, const FstReadOptions &opts) { + FstReadOptions ropts(opts); + FstHeader hdr; + if (ropts.header) + hdr = *opts.header; + else { + if (!hdr.Read(strm, opts.source)) + return 0; + ropts.header = &hdr; + } + if (!(hdr.Properties() & kExpanded)) { + LOG(ERROR) << "ExpandedFst::Read: Not an ExpandedFst: " << ropts.source; + return 0; + } + FstRegister<A> *registr = FstRegister<A>::GetRegister(); + const typename FstRegister<A>::Reader reader = + registr->GetReader(hdr.FstType()); + if (!reader) { + LOG(ERROR) << "ExpandedFst::Read: Unknown FST type \"" << hdr.FstType() + << "\" (arc type = \"" << A::Type() + << "\"): " << ropts.source; + return 0; + } + Fst<A> *fst = reader(strm, ropts); + if (!fst) return 0; + return static_cast<ExpandedFst<A> *>(fst); + } + + // Read an ExpandedFst from a file; return NULL on error. + // Empty filename reads from standard input. + static ExpandedFst<A> *Read(const string &filename) { + if (!filename.empty()) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << filename; + return 0; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(cin, FstReadOptions("standard input")); + } + } +}; + + +namespace internal { + +// ExpandedFst<A> case - abstract methods. +template <class A> inline +typename A::Weight Final(const ExpandedFst<A> &fst, typename A::StateId s) { + return fst.Final(s); +} + +template <class A> inline +ssize_t NumArcs(const ExpandedFst<A> &fst, typename A::StateId s) { + return fst.NumArcs(s); +} + +template <class A> inline +ssize_t NumInputEpsilons(const ExpandedFst<A> &fst, typename A::StateId s) { + return fst.NumInputEpsilons(s); +} + +template <class A> inline +ssize_t NumOutputEpsilons(const ExpandedFst<A> &fst, typename A::StateId s) { + return fst.NumOutputEpsilons(s); +} + +} // namespace internal + + +// A useful alias when using StdArc. +typedef ExpandedFst<StdArc> StdExpandedFst; + + +// This is a helper class template useful for attaching an ExpandedFst +// interface to its implementation, handling reference counting. It +// delegates to ImplToFst the handling of the Fst interface methods. +template < class I, class F = ExpandedFst<typename I::Arc> > +class ImplToExpandedFst : public ImplToFst<I, F> { + public: + typedef typename I::Arc Arc; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + using ImplToFst<I, F>::GetImpl; + + virtual StateId NumStates() const { return GetImpl()->NumStates(); } + + protected: + ImplToExpandedFst() : ImplToFst<I, F>() {} + + ImplToExpandedFst(I *impl) : ImplToFst<I, F>(impl) {} + + ImplToExpandedFst(const ImplToExpandedFst<I, F> &fst) + : ImplToFst<I, F>(fst) {} + + ImplToExpandedFst(const ImplToExpandedFst<I, F> &fst, bool safe) + : ImplToFst<I, F>(fst, safe) {} + + // Read FST implementation from a file; return NULL on error. + // Empty filename reads from standard input. + static I *Read(const string &filename) { + if (!filename.empty()) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << filename; + return 0; + } + return I::Read(strm, FstReadOptions(filename)); + } else { + return I::Read(cin, FstReadOptions("standard input")); + } + } + + private: + // Disallow + ImplToExpandedFst<I, F> &operator=(const ImplToExpandedFst<I, F> &fst); + + ImplToExpandedFst<I, F> &operator=(const Fst<Arc> &fst) { + FSTERROR() << "ImplToExpandedFst: Assignment operator disallowed"; + GetImpl()->SetProperties(kError, kError); + return *this; + } +}; + +// Function to return the number of states in an FST, counting them +// if necessary. +template <class Arc> +typename Arc::StateId CountStates(const Fst<Arc> &fst) { + if (fst.Properties(kExpanded, false)) { + const ExpandedFst<Arc> *efst = static_cast<const ExpandedFst<Arc> *>(&fst); + return efst->NumStates(); + } else { + typename Arc::StateId nstates = 0; + for (StateIterator< Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) + ++nstates; + return nstates; + } +} + +} // namespace fst + +#endif // FST_LIB_EXPANDED_FST_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/expectation-weight.h b/kaldi_io/src/tools/openfst/include/fst/expectation-weight.h new file mode 100644 index 0000000..5226cad --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/expectation-weight.h @@ -0,0 +1,142 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Kasturi Rangan Raghavan) +// Inspiration: [email protected] (Masha Maria Shugrina) +// \file +// Expectation semiring as described by Jason Eisner: +// See: doi=10.1.1.22.9398 +// Multiplex semiring operations and identities: +// One: <One, Zero> +// Zero: <Zero, Zero> +// Plus: <a1, b1> + <a2, b2> = < (a1 + a2) , (b1 + b2) > +// Times: <a1, b1> * <a2, b2> = < (a1 * a2) , [(a1 * b2) + (a2 * b1)] > +// Division: Undefined (currently) +// +// Usually used to store the pair <probability, random_variable> so that +// ShortestDistance[Fst<ArcTpl<ExpectationWeight<P, V> > >] +// == < PosteriorProbability, Expected_Value[V] > + +#ifndef FST_LIB_EXPECTATION_WEIGHT_H_ +#define FST_LIB_EXPECTATION_WEIGHT_H_ + +#include<string> + +#include <fst/pair-weight.h> + + +namespace fst { + +// X1 is usually a probability weight like LogWeight +// X2 is usually a random variable or vector +// see SignedLogWeight or SparsePowerWeight +// +// If X1 is distinct from X2, it is required that there is an external +// product between X1 and X2 and if both semriring are commutative, or +// left or right semirings, then result must have those properties. +template <class X1, class X2> +class ExpectationWeight : public PairWeight<X1, X2> { + public: + using PairWeight<X1, X2>::Value1; + using PairWeight<X1, X2>::Value2; + + using PairWeight<X1, X2>::Reverse; + using PairWeight<X1, X2>::Quantize; + using PairWeight<X1, X2>::Member; + + typedef X1 W1; + typedef X2 W2; + + typedef ExpectationWeight<typename X1::ReverseWeight, + typename X2::ReverseWeight> ReverseWeight; + + ExpectationWeight() : PairWeight<X1, X2>(Zero()) { } + + ExpectationWeight(const ExpectationWeight<X1, X2>& w) + : PairWeight<X1, X2> (w) { } + + ExpectationWeight(const PairWeight<X1, X2>& w) + : PairWeight<X1, X2> (w) { } + + ExpectationWeight(const X1& x1, const X2& x2) + : PairWeight<X1, X2>(x1, x2) { } + + static const ExpectationWeight<X1, X2> &Zero() { + static const ExpectationWeight<X1, X2> zero(X1::Zero(), X2::Zero()); + return zero; + } + + static const ExpectationWeight<X1, X2> &One() { + static const ExpectationWeight<X1, X2> one(X1::One(), X2::Zero()); + return one; + } + + static const ExpectationWeight<X1, X2> &NoWeight() { + static const ExpectationWeight<X1, X2> no_weight(X1::NoWeight(), + X2::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string type = "expectation_" + X1::Type() + "_" + X2::Type(); + return type; + } + + PairWeight<X1, X2> Quantize(float delta = kDelta) const { + return PairWeight<X1, X2>::Quantize(); + } + + ReverseWeight Reverse() const { + return PairWeight<X1, X2>::Reverse(); + } + + bool Member() const { + return PairWeight<X1, X2>::Member(); + } + + static uint64 Properties() { + uint64 props1 = W1::Properties(); + uint64 props2 = W2::Properties(); + return props1 & props2 & (kLeftSemiring | kRightSemiring | + kCommutative | kIdempotent); + } +}; + +template <class X1, class X2> +inline ExpectationWeight<X1, X2> Plus(const ExpectationWeight<X1, X2> &w, + const ExpectationWeight<X1, X2> &v) { + return ExpectationWeight<X1, X2>(Plus(w.Value1(), v.Value1()), + Plus(w.Value2(), v.Value2())); +} + + +template <class X1, class X2> +inline ExpectationWeight<X1, X2> Times(const ExpectationWeight<X1, X2> &w, + const ExpectationWeight<X1, X2> &v) { + return ExpectationWeight<X1, X2>(Times(w.Value1(), v.Value1()), + Plus(Times(w.Value1(), v.Value2()), + Times(w.Value2(), v.Value1()))); +} + +template <class X1, class X2> +inline ExpectationWeight<X1, X2> Divide(const ExpectationWeight<X1, X2> &w, + const ExpectationWeight<X1, X2> &v, + DivideType typ = DIVIDE_ANY) { + FSTERROR() << "ExpectationWeight::Divide: not implemented"; + return ExpectationWeight<X1, X2>::NoWeight(); +} + +} // namespace fst + +#endif // FST_LIB_EXPECTATION_WEIGHT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/compile-strings.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/compile-strings.h new file mode 100644 index 0000000..ca247db --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/compile-strings.h @@ -0,0 +1,304 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Authors: [email protected] (Cyril Allauzen) +// [email protected] (Terry Tai) +// [email protected] (Jake Ratkiewicz) + + +#ifndef FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_ +#define FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_ + +#include <libgen.h> +#include <string> +#include <vector> +using std::vector; + +#include <fst/extensions/far/far.h> +#include <fst/string.h> + +namespace fst { + +// Construct a reader that provides FSTs from a file (stream) either on a +// line-by-line basis or on a per-stream basis. Note that the freshly +// constructed reader is already set to the first input. +// +// Sample Usage: +// for (StringReader<Arc> reader(...); !reader.Done(); reader.Next()) { +// Fst *fst = reader.GetVectorFst(); +// } +template <class A> +class StringReader { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename StringCompiler<A>::TokenType TokenType; + + enum EntryType { LINE = 1, FILE = 2 }; + + StringReader(istream &istrm, + const string &source, + EntryType entry_type, + TokenType token_type, + bool allow_negative_labels, + const SymbolTable *syms = 0, + Label unknown_label = kNoStateId) + : nline_(0), strm_(istrm), source_(source), entry_type_(entry_type), + token_type_(token_type), symbols_(syms), done_(false), + compiler_(token_type, syms, unknown_label, allow_negative_labels) { + Next(); // Initialize the reader to the first input. + } + + bool Done() { + return done_; + } + + void Next() { + VLOG(1) << "Processing source " << source_ << " at line " << nline_; + if (!strm_) { // We're done if we have no more input. + done_ = true; + return; + } + if (entry_type_ == LINE) { + getline(strm_, content_); + ++nline_; + } else { + content_.clear(); + string line; + while (getline(strm_, line)) { + ++nline_; + content_.append(line); + content_.append("\n"); + } + } + if (!strm_ && content_.empty()) // We're also done if we read off all the + done_ = true; // whitespace at the end of a file. + } + + VectorFst<A> *GetVectorFst(bool keep_symbols = false) { + VectorFst<A> *fst = new VectorFst<A>; + if (keep_symbols) { + fst->SetInputSymbols(symbols_); + fst->SetOutputSymbols(symbols_); + } + if (compiler_(content_, fst)) { + return fst; + } else { + delete fst; + return NULL; + } + } + + CompactFst<A, StringCompactor<A> > *GetCompactFst(bool keep_symbols = false) { + CompactFst<A, StringCompactor<A> > *fst; + if (keep_symbols) { + VectorFst<A> tmp; + tmp.SetInputSymbols(symbols_); + tmp.SetOutputSymbols(symbols_); + fst = new CompactFst<A, StringCompactor<A> >(tmp); + } else { + fst = new CompactFst<A, StringCompactor<A> >; + } + if (compiler_(content_, fst)) { + return fst; + } else { + delete fst; + return NULL; + } + } + + private: + size_t nline_; + istream &strm_; + string source_; + EntryType entry_type_; + TokenType token_type_; + const SymbolTable *symbols_; + bool done_; + StringCompiler<A> compiler_; + string content_; // The actual content of the input stream's next FST. + + DISALLOW_COPY_AND_ASSIGN(StringReader); +}; + +// Compute the minimal length required to encode each line number as a decimal +// number. +int KeySize(const char *filename); + +template <class Arc> +void FarCompileStrings(const vector<string> &in_fnames, + const string &out_fname, + const string &fst_type, + const FarType &far_type, + int32 generate_keys, + FarEntryType fet, + FarTokenType tt, + const string &symbols_fname, + const string &unknown_symbol, + bool keep_symbols, + bool initial_symbols, + bool allow_negative_labels, + bool file_list_input, + const string &key_prefix, + const string &key_suffix) { + typename StringReader<Arc>::EntryType entry_type; + if (fet == FET_LINE) { + entry_type = StringReader<Arc>::LINE; + } else if (fet == FET_FILE) { + entry_type = StringReader<Arc>::FILE; + } else { + FSTERROR() << "FarCompileStrings: unknown entry type"; + return; + } + + typename StringCompiler<Arc>::TokenType token_type; + if (tt == FTT_SYMBOL) { + token_type = StringCompiler<Arc>::SYMBOL; + } else if (tt == FTT_BYTE) { + token_type = StringCompiler<Arc>::BYTE; + } else if (tt == FTT_UTF8) { + token_type = StringCompiler<Arc>::UTF8; + } else { + FSTERROR() << "FarCompileStrings: unknown token type"; + return; + } + + bool compact; + if (fst_type.empty() || (fst_type == "vector")) { + compact = false; + } else if (fst_type == "compact") { + compact = true; + } else { + FSTERROR() << "FarCompileStrings: unknown fst type: " + << fst_type; + return; + } + + const SymbolTable *syms = 0; + typename Arc::Label unknown_label = kNoLabel; + if (!symbols_fname.empty()) { + SymbolTableTextOptions opts; + opts.allow_negative = allow_negative_labels; + syms = SymbolTable::ReadText(symbols_fname, opts); + if (!syms) { + FSTERROR() << "FarCompileStrings: error reading symbol table: " + << symbols_fname; + return; + } + if (!unknown_symbol.empty()) { + unknown_label = syms->Find(unknown_symbol); + if (unknown_label == kNoLabel) { + FSTERROR() << "FarCompileStrings: unknown label \"" << unknown_label + << "\" missing from symbol table: " << symbols_fname; + return; + } + } + } + + FarWriter<Arc> *far_writer = + FarWriter<Arc>::Create(out_fname, far_type); + if (!far_writer) return; + + vector<string> inputs; + if (file_list_input) { + for (int i = 1; i < in_fnames.size(); ++i) { + istream *istrm = in_fnames.empty() ? &cin : + new ifstream(in_fnames[i].c_str()); + string str; + while (getline(*istrm, str)) + inputs.push_back(str); + if (!in_fnames.empty()) + delete istrm; + } + } else { + inputs = in_fnames; + } + + for (int i = 0, n = 0; i < inputs.size(); ++i) { + if (generate_keys == 0 && inputs[i].empty()) { + FSTERROR() << "FarCompileStrings: read from a file instead of stdin or" + << " set the --generate_keys flags."; + delete far_writer; + delete syms; + return; + } + int key_size = generate_keys ? generate_keys : + (entry_type == StringReader<Arc>::FILE ? 1 : + KeySize(inputs[i].c_str())); + istream *istrm = inputs[i].empty() ? &cin : + new ifstream(inputs[i].c_str()); + + bool keep_syms = keep_symbols; + for (StringReader<Arc> reader( + *istrm, inputs[i].empty() ? "stdin" : inputs[i], + entry_type, token_type, allow_negative_labels, + syms, unknown_label); + !reader.Done(); + reader.Next()) { + ++n; + const Fst<Arc> *fst; + if (compact) + fst = reader.GetCompactFst(keep_syms); + else + fst = reader.GetVectorFst(keep_syms); + if (initial_symbols) + keep_syms = false; + if (!fst) { + FSTERROR() << "FarCompileStrings: compiling string number " << n + << " in file " << inputs[i] << " failed with token_type = " + << (tt == FTT_BYTE ? "byte" : + (tt == FTT_UTF8 ? "utf8" : + (tt == FTT_SYMBOL ? "symbol" : "unknown"))) + << " and entry_type = " + << (fet == FET_LINE ? "line" : + (fet == FET_FILE ? "file" : "unknown")); + delete far_writer; + delete syms; + if (!inputs[i].empty()) delete istrm; + return; + } + ostringstream keybuf; + keybuf.width(key_size); + keybuf.fill('0'); + keybuf << n; + string key; + if (generate_keys > 0) { + key = keybuf.str(); + } else { + char* filename = new char[inputs[i].size() + 1]; + strcpy(filename, inputs[i].c_str()); + key = basename(filename); + if (entry_type != StringReader<Arc>::FILE) { + key += "-"; + key += keybuf.str(); + } + delete[] filename; + } + far_writer->Add(key_prefix + key + key_suffix, *fst); + delete fst; + } + if (generate_keys == 0) + n = 0; + if (!inputs[i].empty()) + delete istrm; + } + + delete far_writer; +} + +} // namespace fst + + +#endif // FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/create.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/create.h new file mode 100644 index 0000000..edb31e7 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/create.h @@ -0,0 +1,87 @@ +// create-main.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// Modified: [email protected] (Jake Ratkiewicz) to use new dispatch +// +// \file +// Creates a finite-state archive from component FSTs. Includes +// helper function for farcreate.cc that templates the main on the arc +// type to support multiple and extensible arc types. +// + +#ifndef FST_EXTENSIONS_FAR_CREATE_H__ +#define FST_EXTENSIONS_FAR_CREATE_H__ + +#include <libgen.h> +#include <string> +#include <vector> +using std::vector; + +#include <fst/extensions/far/far.h> + +namespace fst { + +template <class Arc> +void FarCreate(const vector<string> &in_fnames, + const string &out_fname, + const int32 generate_keys, + const bool file_list_input, + const FarType &far_type, + const string &key_prefix, + const string &key_suffix) { + FarWriter<Arc> *far_writer = + FarWriter<Arc>::Create(out_fname, far_type); + if (!far_writer) return; + + vector<string> inputs; + if (file_list_input) { + for (int i = 1; i < in_fnames.size(); ++i) { + ifstream istrm(in_fnames[i].c_str()); + string str; + while (getline(istrm, str)) + inputs.push_back(str); + } + } else { + inputs = in_fnames; + } + + for (int i = 0; i < inputs.size(); ++i) { + Fst<Arc> *ifst = Fst<Arc>::Read(inputs[i]); + if (!ifst) return; + string key; + if (generate_keys > 0) { + ostringstream keybuf; + keybuf.width(generate_keys); + keybuf.fill('0'); + keybuf << i + 1; + key = keybuf.str(); + } else { + char* filename = new char[inputs[i].size() + 1]; + strcpy(filename, inputs[i].c_str()); + key = basename(filename); + delete[] filename; + } + + far_writer->Add(key_prefix + key + key_suffix, *ifst); + delete ifst; + } + + delete far_writer; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_CREATE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/equal.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/equal.h new file mode 100644 index 0000000..be82e2d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/equal.h @@ -0,0 +1,99 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) + +#ifndef FST_EXTENSIONS_FAR_EQUAL_H_ +#define FST_EXTENSIONS_FAR_EQUAL_H_ + +#include <string> + +#include <fst/extensions/far/far.h> +#include <fst/equal.h> + +namespace fst { + +template <class Arc> +bool FarEqual(const string &filename1, + const string &filename2, + float delta = kDelta, + const string &begin_key = string(), + const string &end_key = string()) { + + FarReader<Arc> *reader1 = FarReader<Arc>::Open(filename1); + FarReader<Arc> *reader2 = FarReader<Arc>::Open(filename2); + if (!reader1 || !reader2) { + delete reader1; + delete reader2; + VLOG(1) << "FarEqual: cannot open input Far file(s)"; + return false; + } + + if (!begin_key.empty()) { + bool find_begin1 = reader1->Find(begin_key); + bool find_begin2 = reader2->Find(begin_key); + if (!find_begin1 || !find_begin2) { + bool ret = !find_begin1 && !find_begin2; + if (!ret) { + VLOG(1) << "FarEqual: key \"" << begin_key << "\" missing from " + << (find_begin1 ? "second" : "first") << " archive."; + } + delete reader1; + delete reader2; + return ret; + } + } + + for(; !reader1->Done() && !reader2->Done(); + reader1->Next(), reader2->Next()) { + const string key1 = reader1->GetKey(); + const string key2 = reader2->GetKey(); + if (!end_key.empty() && end_key < key1 && end_key < key2) { + delete reader1; + delete reader2; + return true; + } + if (key1 != key2) { + VLOG(1) << "FarEqual: mismatched keys \"" + << key1 << "\" <> \"" << key2 << "\"."; + delete reader1; + delete reader2; + return false; + } + if (!Equal(reader1->GetFst(), reader2->GetFst(), delta)) { + VLOG(1) << "FarEqual: Fsts for key \"" << key1 << "\" are not equal."; + delete reader1; + delete reader2; + return false; + } + } + + if (!reader1->Done() || !reader2->Done()) { + VLOG(1) << "FarEqual: key \"" + << (reader1->Done() ? reader2->GetKey() : reader1->GetKey()) + << "\" missing form " << (reader2->Done() ? "first" : "second") + << " archive."; + delete reader1; + delete reader2; + return false; + } + + delete reader1; + delete reader2; + return true; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_EQUAL_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/extract.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/extract.h new file mode 100644 index 0000000..95866de --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/extract.h @@ -0,0 +1,140 @@ +// extract-main.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// Modified: [email protected] (Jake Ratkiewicz) to use the new arc-dispatch + +// \file +// Extracts component FSTs from an finite-state archive. +// + +#ifndef FST_EXTENSIONS_FAR_EXTRACT_H__ +#define FST_EXTENSIONS_FAR_EXTRACT_H__ + +#include <string> +#include <vector> +using std::vector; + +#include <fst/extensions/far/far.h> + +namespace fst { + +template<class Arc> +inline void FarWriteFst(const Fst<Arc>* fst, string key, + string* okey, int* nrep, + const int32 &generate_filenames, int i, + const string &filename_prefix, + const string &filename_suffix) { + if (key == *okey) + ++*nrep; + else + *nrep = 0; + + *okey = key; + + string ofilename; + if (generate_filenames) { + ostringstream tmp; + tmp.width(generate_filenames); + tmp.fill('0'); + tmp << i; + ofilename = tmp.str(); + } else { + if (*nrep > 0) { + ostringstream tmp; + tmp << '.' << nrep; + key.append(tmp.str().data(), tmp.str().size()); + } + ofilename = key; + } + fst->Write(filename_prefix + ofilename + filename_suffix); +} + +template<class Arc> +void FarExtract(const vector<string> &ifilenames, + const int32 &generate_filenames, + const string &keys, + const string &key_separator, + const string &range_delimiter, + const string &filename_prefix, + const string &filename_suffix) { + FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames); + if (!far_reader) return; + + string okey; + int nrep = 0; + + vector<char *> key_vector; + // User has specified a set of fsts to extract, where some of the "fsts" could + // be ranges. + if (!keys.empty()) { + char *keys_cstr = new char[keys.size()+1]; + strcpy(keys_cstr, keys.c_str()); + SplitToVector(keys_cstr, key_separator.c_str(), &key_vector, true); + int i = 0; + for (int k = 0; k < key_vector.size(); ++k, ++i) { + string key = string(key_vector[k]); + char *key_cstr = new char[key.size()+1]; + strcpy(key_cstr, key.c_str()); + vector<char *> range_vector; + SplitToVector(key_cstr, range_delimiter.c_str(), &range_vector, false); + if (range_vector.size() == 1) { // Not a range + if (!far_reader->Find(key)) { + LOG(ERROR) << "FarExtract: Cannot find key: " << key; + return; + } + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); + } else if (range_vector.size() == 2) { // A legal range + string begin_key = string(range_vector[0]); + string end_key = string(range_vector[1]); + if (begin_key.empty() || end_key.empty()) { + LOG(ERROR) << "FarExtract: Illegal range specification: " << key; + return; + } + if (!far_reader->Find(begin_key)) { + LOG(ERROR) << "FarExtract: Cannot find key: " << begin_key; + return; + } + for ( ; !far_reader->Done(); far_reader->Next(), ++i) { + string ikey = far_reader->GetKey(); + if (end_key < ikey) break; + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, ikey, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); + } + } else { + LOG(ERROR) << "FarExtract: Illegal range specification: " << key; + return; + } + delete key_cstr; + } + delete keys_cstr; + return; + } + // Nothing specified: extract everything. + for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) { + string key = far_reader->GetKey(); + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); + } + return; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_EXTRACT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h new file mode 100644 index 0000000..acce76e --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/far.h @@ -0,0 +1,532 @@ +// far.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Finite-State Transducer (FST) archive classes. +// + +#ifndef FST_EXTENSIONS_FAR_FAR_H__ +#define FST_EXTENSIONS_FAR_FAR_H__ + +#include <fst/extensions/far/stlist.h> +#include <fst/extensions/far/sttable.h> +#include <fst/fst.h> +#include <fst/vector-fst.h> + +namespace fst { + +enum FarEntryType { FET_LINE, FET_FILE }; +enum FarTokenType { FTT_SYMBOL, FTT_BYTE, FTT_UTF8 }; + +inline bool IsFst(const string &filename) { + ifstream strm(filename.c_str()); + if (!strm) + return false; + return IsFstHeader(strm, filename); +} + +// FST archive header class +class FarHeader { + public: + const string &FarType() const { return fartype_; } + const string &ArcType() const { return arctype_; } + + bool Read(const string &filename) { + FstHeader fsthdr; + if (filename.empty()) { + // Header reading unsupported on stdin. Assumes STList and StdArc. + fartype_ = "stlist"; + arctype_ = "standard"; + return true; + } else if (IsSTTable(filename)) { // Check if STTable + ReadSTTableHeader(filename, &fsthdr); + fartype_ = "sttable"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } else if (IsSTList(filename)) { // Check if STList + ReadSTListHeader(filename, &fsthdr); + fartype_ = "sttable"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } else if (IsFst(filename)) { // Check if Fst + ifstream istrm(filename.c_str()); + fsthdr.Read(istrm, filename); + fartype_ = "fst"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } + return false; + } + + private: + string fartype_; + string arctype_; +}; + +enum FarType { + FAR_DEFAULT = 0, + FAR_STTABLE = 1, + FAR_STLIST = 2, + FAR_FST = 3, +}; + +// This class creates an archive of FSTs. +template <class A> +class FarWriter { + public: + typedef A Arc; + + // Creates a new (empty) FST archive; returns NULL on error. + static FarWriter *Create(const string &filename, FarType type = FAR_DEFAULT); + + // Adds an FST to the end of an archive. Keys must be non-empty and + // in lexicographic order. FSTs must have a suitable write method. + virtual void Add(const string &key, const Fst<A> &fst) = 0; + + virtual FarType Type() const = 0; + + virtual bool Error() const = 0; + + virtual ~FarWriter() {} + + protected: + FarWriter() {} + + private: + DISALLOW_COPY_AND_ASSIGN(FarWriter); +}; + + +// This class iterates through an existing archive of FSTs. +template <class A> +class FarReader { + public: + typedef A Arc; + + // Opens an existing FST archive in a single file; returns NULL on error. + // Sets current position to the beginning of the achive. + static FarReader *Open(const string &filename); + + // Opens an existing FST archive in multiple files; returns NULL on error. + // Sets current position to the beginning of the achive. + static FarReader *Open(const vector<string> &filenames); + + // Resets current posision to beginning of archive. + virtual void Reset() = 0; + + // Sets current position to first entry >= key. Returns true if a match. + virtual bool Find(const string &key) = 0; + + // Current position at end of archive? + virtual bool Done() const = 0; + + // Move current position to next FST. + virtual void Next() = 0; + + // Returns key at the current position. This reference is invalidated if + // the current position in the archive is changed. + virtual const string &GetKey() const = 0; + + // Returns FST at the current position. This reference is invalidated if + // the current position in the archive is changed. + virtual const Fst<A> &GetFst() const = 0; + + virtual FarType Type() const = 0; + + virtual bool Error() const = 0; + + virtual ~FarReader() {} + + protected: + FarReader() {} + + private: + DISALLOW_COPY_AND_ASSIGN(FarReader); +}; + + +template <class A> +class FstWriter { + public: + void operator()(ostream &strm, const Fst<A> &fst) const { + fst.Write(strm, FstWriteOptions()); + } +}; + + +template <class A> +class STTableFarWriter : public FarWriter<A> { + public: + typedef A Arc; + + static STTableFarWriter *Create(const string &filename) { + STTableWriter<Fst<A>, FstWriter<A> > *writer = + STTableWriter<Fst<A>, FstWriter<A> >::Create(filename); + return new STTableFarWriter(writer); + } + + void Add(const string &key, const Fst<A> &fst) { writer_->Add(key, fst); } + + FarType Type() const { return FAR_STTABLE; } + + bool Error() const { return writer_->Error(); } + + ~STTableFarWriter() { delete writer_; } + + private: + explicit STTableFarWriter(STTableWriter<Fst<A>, FstWriter<A> > *writer) + : writer_(writer) {} + + private: + STTableWriter<Fst<A>, FstWriter<A> > *writer_; + + DISALLOW_COPY_AND_ASSIGN(STTableFarWriter); +}; + + +template <class A> +class STListFarWriter : public FarWriter<A> { + public: + typedef A Arc; + + static STListFarWriter *Create(const string &filename) { + STListWriter<Fst<A>, FstWriter<A> > *writer = + STListWriter<Fst<A>, FstWriter<A> >::Create(filename); + return new STListFarWriter(writer); + } + + void Add(const string &key, const Fst<A> &fst) { writer_->Add(key, fst); } + + FarType Type() const { return FAR_STLIST; } + + bool Error() const { return writer_->Error(); } + + ~STListFarWriter() { delete writer_; } + + private: + explicit STListFarWriter(STListWriter<Fst<A>, FstWriter<A> > *writer) + : writer_(writer) {} + + private: + STListWriter<Fst<A>, FstWriter<A> > *writer_; + + DISALLOW_COPY_AND_ASSIGN(STListFarWriter); +}; + + +template <class A> +class FstFarWriter : public FarWriter<A> { + public: + typedef A Arc; + + explicit FstFarWriter(const string &filename) + : filename_(filename), error_(false), written_(false) {} + + static FstFarWriter *Create(const string &filename) { + return new FstFarWriter(filename); + } + + void Add(const string &key, const Fst<A> &fst) { + if (written_) { + LOG(WARNING) << "FstFarWriter::Add: only one Fst supported," + << " subsequent entries discarded."; + } else { + error_ = !fst.Write(filename_); + written_ = true; + } + } + + FarType Type() const { return FAR_FST; } + + bool Error() const { return error_; } + + ~FstFarWriter() {} + + private: + string filename_; + bool error_; + bool written_; + + DISALLOW_COPY_AND_ASSIGN(FstFarWriter); +}; + + +template <class A> +FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) { + switch(type) { + case FAR_DEFAULT: + if (filename.empty()) + return STListFarWriter<A>::Create(filename); + case FAR_STTABLE: + return STTableFarWriter<A>::Create(filename); + break; + case FAR_STLIST: + return STListFarWriter<A>::Create(filename); + break; + case FAR_FST: + return FstFarWriter<A>::Create(filename); + break; + default: + LOG(ERROR) << "FarWriter::Create: unknown far type"; + return 0; + } +} + + +template <class A> +class FstReader { + public: + Fst<A> *operator()(istream &strm) const { + return Fst<A>::Read(strm, FstReadOptions()); + } +}; + + +template <class A> +class STTableFarReader : public FarReader<A> { + public: + typedef A Arc; + + static STTableFarReader *Open(const string &filename) { + STTableReader<Fst<A>, FstReader<A> > *reader = + STTableReader<Fst<A>, FstReader<A> >::Open(filename); + // TODO: error check + return new STTableFarReader(reader); + } + + static STTableFarReader *Open(const vector<string> &filenames) { + STTableReader<Fst<A>, FstReader<A> > *reader = + STTableReader<Fst<A>, FstReader<A> >::Open(filenames); + // TODO: error check + return new STTableFarReader(reader); + } + + void Reset() { reader_->Reset(); } + + bool Find(const string &key) { return reader_->Find(key); } + + bool Done() const { return reader_->Done(); } + + void Next() { return reader_->Next(); } + + const string &GetKey() const { return reader_->GetKey(); } + + const Fst<A> &GetFst() const { return reader_->GetEntry(); } + + FarType Type() const { return FAR_STTABLE; } + + bool Error() const { return reader_->Error(); } + + ~STTableFarReader() { delete reader_; } + + private: + explicit STTableFarReader(STTableReader<Fst<A>, FstReader<A> > *reader) + : reader_(reader) {} + + private: + STTableReader<Fst<A>, FstReader<A> > *reader_; + + DISALLOW_COPY_AND_ASSIGN(STTableFarReader); +}; + + +template <class A> +class STListFarReader : public FarReader<A> { + public: + typedef A Arc; + + static STListFarReader *Open(const string &filename) { + STListReader<Fst<A>, FstReader<A> > *reader = + STListReader<Fst<A>, FstReader<A> >::Open(filename); + // TODO: error check + return new STListFarReader(reader); + } + + static STListFarReader *Open(const vector<string> &filenames) { + STListReader<Fst<A>, FstReader<A> > *reader = + STListReader<Fst<A>, FstReader<A> >::Open(filenames); + // TODO: error check + return new STListFarReader(reader); + } + + void Reset() { reader_->Reset(); } + + bool Find(const string &key) { return reader_->Find(key); } + + bool Done() const { return reader_->Done(); } + + void Next() { return reader_->Next(); } + + const string &GetKey() const { return reader_->GetKey(); } + + const Fst<A> &GetFst() const { return reader_->GetEntry(); } + + FarType Type() const { return FAR_STLIST; } + + bool Error() const { return reader_->Error(); } + + ~STListFarReader() { delete reader_; } + + private: + explicit STListFarReader(STListReader<Fst<A>, FstReader<A> > *reader) + : reader_(reader) {} + + private: + STListReader<Fst<A>, FstReader<A> > *reader_; + + DISALLOW_COPY_AND_ASSIGN(STListFarReader); +}; + +template <class A> +class FstFarReader : public FarReader<A> { + public: + typedef A Arc; + + static FstFarReader *Open(const string &filename) { + vector<string> filenames; + filenames.push_back(filename); + return new FstFarReader<A>(filenames); + } + + static FstFarReader *Open(const vector<string> &filenames) { + return new FstFarReader<A>(filenames); + } + + FstFarReader(const vector<string> &filenames) + : keys_(filenames), has_stdin_(false), pos_(0), fst_(0), error_(false) { + sort(keys_.begin(), keys_.end()); + streams_.resize(keys_.size(), 0); + for (size_t i = 0; i < keys_.size(); ++i) { + if (keys_[i].empty()) { + if (!has_stdin_) { + streams_[i] = &cin; + //sources_[i] = "stdin"; + has_stdin_ = true; + } else { + FSTERROR() << "FstFarReader::FstFarReader: stdin should only " + << "appear once in the input file list."; + error_ = true; + return; + } + } else { + streams_[i] = new ifstream( + keys_[i].c_str(), ifstream::in | ifstream::binary); + } + } + if (pos_ >= keys_.size()) return; + ReadFst(); + } + + void Reset() { + if (has_stdin_) { + FSTERROR() << "FstFarReader::Reset: operation not supported on stdin"; + error_ = true; + return; + } + pos_ = 0; + ReadFst(); + } + + bool Find(const string &key) { + if (has_stdin_) { + FSTERROR() << "FstFarReader::Find: operation not supported on stdin"; + error_ = true; + return false; + } + pos_ = 0;//TODO + ReadFst(); + return true; + } + + bool Done() const { return error_ || pos_ >= keys_.size(); } + + void Next() { + ++pos_; + ReadFst(); + } + + const string &GetKey() const { + return keys_[pos_]; + } + + const Fst<A> &GetFst() const { + return *fst_; + } + + FarType Type() const { return FAR_FST; } + + bool Error() const { return error_; } + + ~FstFarReader() { + if (fst_) delete fst_; + for (size_t i = 0; i < keys_.size(); ++i) + delete streams_[i]; + } + + private: + void ReadFst() { + if (fst_) delete fst_; + if (pos_ >= keys_.size()) return; + streams_[pos_]->seekg(0); + fst_ = Fst<A>::Read(*streams_[pos_], FstReadOptions()); + if (!fst_) { + FSTERROR() << "FstFarReader: error reading Fst from: " << keys_[pos_]; + error_ = true; + } + } + + private: + vector<string> keys_; + vector<istream*> streams_; + bool has_stdin_; + size_t pos_; + mutable Fst<A> *fst_; + mutable bool error_; + + DISALLOW_COPY_AND_ASSIGN(FstFarReader); +}; + +template <class A> +FarReader<A> *FarReader<A>::Open(const string &filename) { + if (filename.empty()) + return STListFarReader<A>::Open(filename); + else if (IsSTTable(filename)) + return STTableFarReader<A>::Open(filename); + else if (IsSTList(filename)) + return STListFarReader<A>::Open(filename); + else if (IsFst(filename)) + return FstFarReader<A>::Open(filename); + return 0; +} + + +template <class A> +FarReader<A> *FarReader<A>::Open(const vector<string> &filenames) { + if (!filenames.empty() && filenames[0].empty()) + return STListFarReader<A>::Open(filenames); + else if (!filenames.empty() && IsSTTable(filenames[0])) + return STTableFarReader<A>::Open(filenames); + else if (!filenames.empty() && IsSTList(filenames[0])) + return STListFarReader<A>::Open(filenames); + else if (!filenames.empty() && IsFst(filenames[0])) + return FstFarReader<A>::Open(filenames); + return 0; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_FAR_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/farlib.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/farlib.h new file mode 100644 index 0000000..91ba224 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/farlib.h @@ -0,0 +1,31 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +// A finite-state archive (FAR) is used to store an indexable collection of +// FSTs in a single file. Utilities are provided to create FARs from FSTs, +// to iterate over FARs, and to extract specific FSTs from FARs. + +#ifndef FST_EXTENSIONS_FAR_FARLIB_H_ +#define FST_EXTENSIONS_FAR_FARLIB_H_ + +#include <fst/extensions/far/far.h> +#include <fst/extensions/far/compile-strings.h> +#include <fst/extensions/far/create.h> +#include <fst/extensions/far/extract.h> +#include <fst/extensions/far/info.h> +#include <fst/extensions/far/print-strings.h> + +#endif // FST_EXTENSIONS_FAR_FARLIB_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/farscript.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/farscript.h new file mode 100644 index 0000000..cfd9167 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/farscript.h @@ -0,0 +1,273 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +// Convenience file for including all of the FAR operations, +// or registering them for new arc types. + +#ifndef FST_EXTENSIONS_FAR_FARSCRIPT_H_ +#define FST_EXTENSIONS_FAR_FARSCRIPT_H_ + +#include <vector> +using std::vector; +#include <string> + +#include <fst/script/arg-packs.h> +#include <fst/extensions/far/compile-strings.h> +#include <fst/extensions/far/create.h> +#include <fst/extensions/far/equal.h> +#include <fst/extensions/far/extract.h> +#include <fst/extensions/far/info.h> +#include <fst/extensions/far/print-strings.h> +#include <fst/extensions/far/far.h> + +#include <fst/types.h> + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! +struct FarCompileStringsArgs { + const vector<string> &in_fnames; + const string &out_fname; + const string &fst_type; + const FarType &far_type; + const int32 generate_keys; + const FarEntryType fet; + const FarTokenType tt; + const string &symbols_fname; + const string &unknown_symbol; + const bool keep_symbols; + const bool initial_symbols; + const bool allow_negative_labels; + const bool file_list_input; + const string &key_prefix; + const string &key_suffix; + + FarCompileStringsArgs(const vector<string> &in_fnames, + const string &out_fname, + const string &fst_type, + const FarType &far_type, + int32 generate_keys, + FarEntryType fet, + FarTokenType tt, + const string &symbols_fname, + const string &unknown_symbol, + bool keep_symbols, + bool initial_symbols, + bool allow_negative_labels, + bool file_list_input, + const string &key_prefix, + const string &key_suffix) : + in_fnames(in_fnames), out_fname(out_fname), fst_type(fst_type), + far_type(far_type), generate_keys(generate_keys), fet(fet), + tt(tt), symbols_fname(symbols_fname), unknown_symbol(unknown_symbol), + keep_symbols(keep_symbols), initial_symbols(initial_symbols), + allow_negative_labels(allow_negative_labels), + file_list_input(file_list_input), key_prefix(key_prefix), + key_suffix(key_suffix) { } +}; + +template <class Arc> +void FarCompileStrings(FarCompileStringsArgs *args) { + fst::FarCompileStrings<Arc>( + args->in_fnames, args->out_fname, args->fst_type, args->far_type, + args->generate_keys, args->fet, args->tt, args->symbols_fname, + args->unknown_symbol, args->keep_symbols, args->initial_symbols, + args->allow_negative_labels, args->file_list_input, + args->key_prefix, args->key_suffix); +} + +void FarCompileStrings( + const vector<string> &in_fnames, + const string &out_fname, + const string &arc_type, + const string &fst_type, + const FarType &far_type, + int32 generate_keys, + FarEntryType fet, + FarTokenType tt, + const string &symbols_fname, + const string &unknown_symbol, + bool keep_symbols, + bool initial_symbols, + bool allow_negative_labels, + bool file_list_input, + const string &key_prefix, + const string &key_suffix); + + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! +struct FarCreateArgs { + const vector<string> &in_fnames; + const string &out_fname; + const int32 generate_keys; + const bool file_list_input; + const FarType &far_type; + const string &key_prefix; + const string &key_suffix; + + FarCreateArgs( + const vector<string> &in_fnames, const string &out_fname, + const int32 generate_keys, const bool file_list_input, + const FarType &far_type, const string &key_prefix, + const string &key_suffix) + : in_fnames(in_fnames), out_fname(out_fname), + generate_keys(generate_keys), file_list_input(file_list_input), + far_type(far_type), key_prefix(key_prefix), key_suffix(key_suffix) { } +}; + +template<class Arc> +void FarCreate(FarCreateArgs *args) { + fst::FarCreate<Arc>(args->in_fnames, args->out_fname, args->generate_keys, + args->file_list_input, args->far_type, + args->key_prefix, args->key_suffix); +} + +void FarCreate(const vector<string> &in_fnames, + const string &out_fname, + const string &arc_type, + const int32 generate_keys, + const bool file_list_input, + const FarType &far_type, + const string &key_prefix, + const string &key_suffix); + + +typedef args::Package<const string &, const string &, float, + const string &, const string &> FarEqualInnerArgs; +typedef args::WithReturnValue<bool, FarEqualInnerArgs> FarEqualArgs; + +template <class Arc> +void FarEqual(FarEqualArgs *args) { + args->retval = fst::FarEqual<Arc>( + args->args.arg1, args->args.arg2, args->args.arg3, + args->args.arg4, args->args.arg5); +} + +bool FarEqual(const string &filename1, + const string &filename2, + const string &arc_type, + float delta = kDelta, + const string &begin_key = string(), + const string &end_key = string()); + + +typedef args::Package<const vector<string> &, int32, + const string&, const string&, const string&, + const string&, const string&> FarExtractArgs; + +template<class Arc> +void FarExtract(FarExtractArgs *args) { + fst::FarExtract<Arc>( + args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6, + args->arg7); +} + +void FarExtract(const vector<string> &ifilenames, + const string &arc_type, + int32 generate_filenames, + const string &keys, + const string &key_separator, + const string &range_delimiter, + const string &filename_prefix, + const string &filename_suffix); + +typedef args::Package<const vector<string> &, const string &, + const string &, const bool> FarInfoArgs; + +template <class Arc> +void FarInfo(FarInfoArgs *args) { + fst::FarInfo<Arc>(args->arg1, args->arg2, args->arg3, args->arg4); +} + +void FarInfo(const vector<string> &filenames, + const string &arc_type, + const string &begin_key, + const string &end_key, + const bool list_fsts); + +struct FarPrintStringsArgs { + const vector<string> &ifilenames; + const FarEntryType entry_type; + const FarTokenType token_type; + const string &begin_key; + const string &end_key; + const bool print_key; + const bool print_weight; + const string &symbols_fname; + const bool initial_symbols; + const int32 generate_filenames; + const string &filename_prefix; + const string &filename_suffix; + + FarPrintStringsArgs( + const vector<string> &ifilenames, const FarEntryType entry_type, + const FarTokenType token_type, const string &begin_key, + const string &end_key, const bool print_key, const bool print_weight, + const string &symbols_fname, const bool initial_symbols, + const int32 generate_filenames, + const string &filename_prefix, const string &filename_suffix) : + ifilenames(ifilenames), entry_type(entry_type), token_type(token_type), + begin_key(begin_key), end_key(end_key), + print_key(print_key), print_weight(print_weight), + symbols_fname(symbols_fname), initial_symbols(initial_symbols), + generate_filenames(generate_filenames), filename_prefix(filename_prefix), + filename_suffix(filename_suffix) { } +}; + +template <class Arc> +void FarPrintStrings(FarPrintStringsArgs *args) { + fst::FarPrintStrings<Arc>( + args->ifilenames, args->entry_type, args->token_type, + args->begin_key, args->end_key, args->print_key, args->print_weight, + args->symbols_fname, args->initial_symbols, args->generate_filenames, + args->filename_prefix, args->filename_suffix); +} + + +void FarPrintStrings(const vector<string> &ifilenames, + const string &arc_type, + const FarEntryType entry_type, + const FarTokenType token_type, + const string &begin_key, + const string &end_key, + const bool print_key, + const bool print_weight, + const string &symbols_fname, + const bool initial_symbols, + const int32 generate_filenames, + const string &filename_prefix, + const string &filename_suffix); + +} // namespace script +} // namespace fst + + +#define REGISTER_FST_FAR_OPERATIONS(ArcType) \ + REGISTER_FST_OPERATION(FarCompileStrings, ArcType, FarCompileStringsArgs); \ + REGISTER_FST_OPERATION(FarCreate, ArcType, FarCreateArgs); \ + REGISTER_FST_OPERATION(FarEqual, ArcType, FarEqualArgs); \ + REGISTER_FST_OPERATION(FarExtract, ArcType, FarExtractArgs); \ + REGISTER_FST_OPERATION(FarInfo, ArcType, FarInfoArgs); \ + REGISTER_FST_OPERATION(FarPrintStrings, ArcType, FarPrintStringsArgs) + +#endif // FST_EXTENSIONS_FAR_FARSCRIPT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/info.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/info.h new file mode 100644 index 0000000..100fe68 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/info.h @@ -0,0 +1,128 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// Modified: [email protected] (Jake Ratkiewicz) + +#ifndef FST_EXTENSIONS_FAR_INFO_H_ +#define FST_EXTENSIONS_FAR_INFO_H_ + +#include <iomanip> +#include <set> +#include <string> +#include <vector> +using std::vector; + +#include <fst/extensions/far/far.h> +#include <fst/extensions/far/main.h> // For FarTypeToString + +namespace fst { + +template <class Arc> +void CountStatesAndArcs(const Fst<Arc> &fst, size_t *nstate, size_t *narc) { + StateIterator<Fst<Arc> > siter(fst); + for (; !siter.Done(); siter.Next(), ++(*nstate)) { + ArcIterator<Fst<Arc> > aiter(fst, siter.Value()); + for (; !aiter.Done(); aiter.Next(), ++(*narc)) {} + } +} + +struct KeyInfo { + string key; + string type; + size_t nstate; + size_t narc; + + KeyInfo(string k, string t, int64 ns = 0, int64 na = 0) + : key(k), type(t), nstate(ns), narc(na) {} +}; + +template <class Arc> +void FarInfo(const vector<string> &filenames, const string &begin_key, + const string &end_key, const bool list_fsts) { + FarReader<Arc> *far_reader = FarReader<Arc>::Open(filenames); + if (!far_reader) return; + + if (!begin_key.empty()) + far_reader->Find(begin_key); + + vector<KeyInfo> *infos = list_fsts ? new vector<KeyInfo>() : 0; + size_t nfst = 0, nstate = 0, narc = 0; + set<string> fst_types; + for (; !far_reader->Done(); far_reader->Next()) { + string key = far_reader->GetKey(); + if (!end_key.empty() && end_key < key) + break; + ++nfst; + const Fst<Arc> &fst = far_reader->GetFst(); + fst_types.insert(fst.Type()); + if (infos) { + KeyInfo info(key, fst.Type()); + CountStatesAndArcs(fst, &info.nstate, &info.narc); + nstate += info.nstate; + nstate += info.narc; + infos->push_back(info); + } else { + CountStatesAndArcs(fst, &nstate, &narc); + } + } + + if (!infos) { + cout << std::left << setw(50) << "far type" + << FarTypeToString(far_reader->Type()) << endl; + cout << std::left << setw(50) << "arc type" << Arc::Type() << endl; + cout << std::left << setw(50) << "fst type"; + for (set<string>::const_iterator iter = fst_types.begin(); + iter != fst_types.end(); + ++iter) { + if (iter != fst_types.begin()) + cout << ","; + cout << *iter; + } + cout << endl; + cout << std::left << setw(50) << "# of FSTs" << nfst << endl; + cout << std::left << setw(50) << "total # of states" << nstate << endl; + cout << std::left << setw(50) << "total # of arcs" << narc << endl; + + } else { + int wkey = 10, wtype = 10, wnstate = 16, wnarc = 16; + for (size_t i = 0; i < infos->size(); ++i) { + const KeyInfo &info = (*infos)[i]; + if (info.key.size() + 2 > wkey) + wkey = info.key.size() + 2; + if (info.type.size() + 2 > wtype) + wtype = info.type.size() + 2; + if (ceil(log10(info.nstate)) + 2 > wnstate) + wnstate = ceil(log10(info.nstate)) + 2; + if (ceil(log10(info.narc)) + 2 > wnarc) + wnarc = ceil(log10(info.narc)) + 2; + } + + cout << std::left << setw(wkey) << "key" << setw(wtype) << "type" + << std::right << setw(wnstate) << "# of states" + << setw(wnarc) << "# of arcs" << endl; + + for (size_t i = 0; i < infos->size(); ++i) { + const KeyInfo &info = (*infos)[i]; + cout << std::left << setw(wkey) << info.key << setw(wtype) << info.type + << std::right << setw(wnstate) << info.nstate + << setw(wnarc) << info.narc << endl; + } + } +} + +} // namespace fst + + +#endif // FST_EXTENSIONS_FAR_INFO_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/main.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/main.h new file mode 100644 index 0000000..00ccfef --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/main.h @@ -0,0 +1,43 @@ +// main.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes and functions for registering and invoking Far main +// functions that support multiple and extensible arc types. + +#ifndef FST_EXTENSIONS_FAR_MAIN_H__ +#define FST_EXTENSIONS_FAR_MAIN_H__ + +#include <fst/extensions/far/far.h> + +namespace fst { + +FarEntryType StringToFarEntryType(const string &s); +FarTokenType StringToFarTokenType(const string &s); + +// Return the 'FarType' value corresponding to a far type name. +FarType FarTypeFromString(const string &str); + +// Return the textual name corresponding to a 'FarType;. +string FarTypeToString(FarType type); + +string LoadArcTypeFromFar(const string& far_fname); +string LoadArcTypeFromFst(const string& far_fname); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_MAIN_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/print-strings.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/print-strings.h new file mode 100644 index 0000000..dcc7351 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/print-strings.h @@ -0,0 +1,138 @@ +// printstrings-main.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// Modified by: [email protected] (Jake Ratkiewicz) +// +// \file +// Output as strings the string FSTs in a finite-state archive. + +#ifndef FST_EXTENSIONS_FAR_PRINT_STRINGS_H__ +#define FST_EXTENSIONS_FAR_PRINT_STRINGS_H__ + +#include <string> +#include <vector> +using std::vector; + +#include <fst/extensions/far/far.h> +#include <fst/shortest-distance.h> +#include <fst/string.h> + +DECLARE_string(far_field_separator); + +namespace fst { + +template <class Arc> +void FarPrintStrings( + const vector<string> &ifilenames, const FarEntryType entry_type, + const FarTokenType far_token_type, const string &begin_key, + const string &end_key, const bool print_key, const bool print_weight, + const string &symbols_fname, const bool initial_symbols, + const int32 generate_filenames, + const string &filename_prefix, const string &filename_suffix) { + + typename StringPrinter<Arc>::TokenType token_type; + if (far_token_type == FTT_SYMBOL) { + token_type = StringPrinter<Arc>::SYMBOL; + } else if (far_token_type == FTT_BYTE) { + token_type = StringPrinter<Arc>::BYTE; + } else if (far_token_type == FTT_UTF8) { + token_type = StringPrinter<Arc>::UTF8; + } else { + FSTERROR() << "FarPrintStrings: unknown token type"; + return; + } + + const SymbolTable *syms = 0; + if (!symbols_fname.empty()) { + // allow negative flag? + SymbolTableTextOptions opts; + opts.allow_negative = true; + syms = SymbolTable::ReadText(symbols_fname, opts); + if (!syms) { + FSTERROR() << "FarPrintStrings: error reading symbol table: " + << symbols_fname; + return; + } + } + + FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames); + if (!far_reader) return; + + if (!begin_key.empty()) + far_reader->Find(begin_key); + + string okey; + int nrep = 0; + for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) { + string key = far_reader->GetKey(); + if (!end_key.empty() && end_key < key) + break; + if (okey == key) + ++nrep; + else + nrep = 0; + okey = key; + + const Fst<Arc> &fst = far_reader->GetFst(); + if (i == 1 && initial_symbols && syms == 0 && fst.InputSymbols() != 0) + syms = fst.InputSymbols()->Copy(); + string str; + VLOG(2) << "Handling key: " << key; + StringPrinter<Arc> string_printer( + token_type, syms ? syms : fst.InputSymbols()); + string_printer(fst, &str); + + if (entry_type == FET_LINE) { + if (print_key) + cout << key << FLAGS_far_field_separator[0]; + cout << str; + if (print_weight) + cout << FLAGS_far_field_separator[0] << ShortestDistance(fst); + cout << endl; + } else if (entry_type == FET_FILE) { + stringstream sstrm; + if (generate_filenames) { + sstrm.fill('0'); + sstrm << std::right << setw(generate_filenames) << i; + } else { + sstrm << key; + if (nrep > 0) + sstrm << "." << nrep; + } + + string filename; + filename = filename_prefix + sstrm.str() + filename_suffix; + + ofstream ostrm(filename.c_str()); + if (!ostrm) { + FSTERROR() << "FarPrintStrings: Can't open file:" << filename; + delete syms; + delete far_reader; + return; + } + ostrm << str; + if (token_type == StringPrinter<Arc>::SYMBOL) + ostrm << "\n"; + } + } + delete syms; +} + + + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_PRINT_STRINGS_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/stlist.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/stlist.h new file mode 100644 index 0000000..ff3d98b --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/stlist.h @@ -0,0 +1,305 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// A generic (string,type) list file format. +// +// This is a stripped-down version of STTable that does +// not support the Find() operation but that does support +// reading/writting from standard in/out. + +#ifndef FST_EXTENSIONS_FAR_STLIST_H_ +#define FST_EXTENSIONS_FAR_STLIST_H_ + +#include <iostream> +#include <fstream> +#include <sstream> +#include <fst/util.h> + +#include <algorithm> +#include <functional> +#include <queue> +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +namespace fst { + +static const int32 kSTListMagicNumber = 5656924; +static const int32 kSTListFileVersion = 1; + +// String-type list writing class for object of type 'T' using functor 'W' +// to write an object of type 'T' from a stream. 'W' must conform to the +// following interface: +// +// struct Writer { +// void operator()(ostream &, const T &) const; +// }; +// +template <class T, class W> +class STListWriter { + public: + typedef T EntryType; + typedef W EntryWriter; + + explicit STListWriter(const string filename) + : stream_( + filename.empty() ? &cout : + new ofstream(filename.c_str(), ofstream::out | ofstream::binary)), + error_(false) { + WriteType(*stream_, kSTListMagicNumber); + WriteType(*stream_, kSTListFileVersion); + if (!stream_) { + FSTERROR() << "STListWriter::STListWriter: error writing to file: " + << filename; + error_ = true; + } + } + + static STListWriter<T, W> *Create(const string &filename) { + return new STListWriter<T, W>(filename); + } + + void Add(const string &key, const T &t) { + if (key == "") { + FSTERROR() << "STListWriter::Add: key empty: " << key; + error_ = true; + } else if (key < last_key_) { + FSTERROR() << "STListWriter::Add: key disorder: " << key; + error_ = true; + } + if (error_) return; + last_key_ = key; + WriteType(*stream_, key); + entry_writer_(*stream_, t); + } + + bool Error() const { return error_; } + + ~STListWriter() { + WriteType(*stream_, string()); + if (stream_ != &cout) + delete stream_; + } + + private: + EntryWriter entry_writer_; // Write functor for 'EntryType' + ostream *stream_; // Output stream + string last_key_; // Last key + bool error_; + + DISALLOW_COPY_AND_ASSIGN(STListWriter); +}; + + +// String-type list reading class for object of type 'T' using functor 'R' +// to read an object of type 'T' form a stream. 'R' must conform to the +// following interface: +// +// struct Reader { +// T *operator()(istream &) const; +// }; +// +template <class T, class R> +class STListReader { + public: + typedef T EntryType; + typedef R EntryReader; + + explicit STListReader(const vector<string> &filenames) + : sources_(filenames), entry_(0), error_(false) { + streams_.resize(filenames.size(), 0); + bool has_stdin = false; + for (size_t i = 0; i < filenames.size(); ++i) { + if (filenames[i].empty()) { + if (!has_stdin) { + streams_[i] = &cin; + sources_[i] = "stdin"; + has_stdin = true; + } else { + FSTERROR() << "STListReader::STListReader: stdin should only " + << "appear once in the input file list."; + error_ = true; + return; + } + } else { + streams_[i] = new ifstream( + filenames[i].c_str(), ifstream::in | ifstream::binary); + } + int32 magic_number = 0, file_version = 0; + ReadType(*streams_[i], &magic_number); + ReadType(*streams_[i], &file_version); + if (magic_number != kSTListMagicNumber) { + FSTERROR() << "STListReader::STListReader: wrong file type: " + << filenames[i]; + error_ = true; + return; + } + if (file_version != kSTListFileVersion) { + FSTERROR() << "STListReader::STListReader: wrong file version: " + << filenames[i]; + error_ = true; + return; + } + string key; + ReadType(*streams_[i], &key); + if (!key.empty()) + heap_.push(make_pair(key, i)); + if (!*streams_[i]) { + FSTERROR() << "STListReader: error reading file: " << sources_[i]; + error_ = true; + return; + } + } + if (heap_.empty()) return; + size_t current = heap_.top().second; + entry_ = entry_reader_(*streams_[current]); + if (!entry_ || !*streams_[current]) { + FSTERROR() << "STListReader: error reading entry for key: " + << heap_.top().first << ", file: " << sources_[current]; + error_ = true; + } + } + + ~STListReader() { + for (size_t i = 0; i < streams_.size(); ++i) { + if (streams_[i] != &cin) + delete streams_[i]; + } + if (entry_) + delete entry_; + } + + static STListReader<T, R> *Open(const string &filename) { + vector<string> filenames; + filenames.push_back(filename); + return new STListReader<T, R>(filenames); + } + + static STListReader<T, R> *Open(const vector<string> &filenames) { + return new STListReader<T, R>(filenames); + } + + void Reset() { + FSTERROR() + << "STListReader::Reset: stlist does not support reset operation"; + error_ = true; + } + + bool Find(const string &key) { + FSTERROR() + << "STListReader::Find: stlist does not support find operation"; + error_ = true; + return false; + } + + bool Done() const { + return error_ || heap_.empty(); + } + + void Next() { + if (error_) return; + size_t current = heap_.top().second; + string key; + heap_.pop(); + ReadType(*(streams_[current]), &key); + if (!*streams_[current]) { + FSTERROR() << "STListReader: error reading file: " + << sources_[current]; + error_ = true; + return; + } + if (!key.empty()) + heap_.push(make_pair(key, current)); + + if(!heap_.empty()) { + current = heap_.top().second; + if (entry_) + delete entry_; + entry_ = entry_reader_(*streams_[current]); + if (!entry_ || !*streams_[current]) { + FSTERROR() << "STListReader: error reading entry for key: " + << heap_.top().first << ", file: " << sources_[current]; + error_ = true; + } + } + } + + const string &GetKey() const { + return heap_.top().first; + } + + const EntryType &GetEntry() const { + return *entry_; + } + + bool Error() const { return error_; } + + private: + EntryReader entry_reader_; // Read functor for 'EntryType' + vector<istream*> streams_; // Input streams + vector<string> sources_; // and corresponding file names + priority_queue< + pair<string, size_t>, vector<pair<string, size_t> >, + greater<pair<string, size_t> > > heap_; // (Key, stream id) heap + mutable EntryType *entry_; // Pointer to the currently read entry + bool error_; + + DISALLOW_COPY_AND_ASSIGN(STListReader); +}; + + +// String-type list header reading function template on the entry header +// type 'H' having a member function: +// Read(istream &strm, const string &filename); +// Checks that 'filename' is an STList and call the H::Read() on the last +// entry in the STList. +// Does not support reading from stdin. +template <class H> +bool ReadSTListHeader(const string &filename, H *header) { + if (filename.empty()) { + LOG(ERROR) << "ReadSTListHeader: reading header not supported on stdin"; + return false; + } + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + int32 magic_number = 0, file_version = 0; + ReadType(strm, &magic_number); + ReadType(strm, &file_version); + if (magic_number != kSTListMagicNumber) { + LOG(ERROR) << "ReadSTListHeader: wrong file type: " << filename; + return false; + } + if (file_version != kSTListFileVersion) { + LOG(ERROR) << "ReadSTListHeader: wrong file version: " << filename; + return false; + } + string key; + ReadType(strm, &key); + header->Read(strm, filename + ":" + key); + if (!strm) { + LOG(ERROR) << "ReadSTListHeader: error reading file: " << filename; + return false; + } + return true; +} + +bool IsSTList(const string &filename); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_STLIST_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/far/sttable.h b/kaldi_io/src/tools/openfst/include/fst/extensions/far/sttable.h new file mode 100644 index 0000000..3ce0a4b --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/far/sttable.h @@ -0,0 +1,371 @@ +// sttable.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// A generic string-to-type table file format +// +// This is not meant as a generalization of SSTable. This is more of +// a simple replacement for SSTable in order to provide an open-source +// implementation of the FAR format for the external version of the +// FST Library. + +#ifndef FST_EXTENSIONS_FAR_STTABLE_H_ +#define FST_EXTENSIONS_FAR_STTABLE_H_ + +#include <algorithm> +#include <iostream> +#include <fstream> +#include <sstream> +#include <fst/util.h> + +namespace fst { + +static const int32 kSTTableMagicNumber = 2125656924; +static const int32 kSTTableFileVersion = 1; + +// String-to-type table writing class for object of type 'T' using functor 'W' +// to write an object of type 'T' from a stream. 'W' must conform to the +// following interface: +// +// struct Writer { +// void operator()(ostream &, const T &) const; +// }; +// +template <class T, class W> +class STTableWriter { + public: + typedef T EntryType; + typedef W EntryWriter; + + explicit STTableWriter(const string &filename) + : stream_(filename.c_str(), ofstream::out | ofstream::binary), + error_(false) { + WriteType(stream_, kSTTableMagicNumber); + WriteType(stream_, kSTTableFileVersion); + if (!stream_) { + FSTERROR() << "STTableWriter::STTableWriter: error writing to file: " + << filename; + error_=true; + } + } + + static STTableWriter<T, W> *Create(const string &filename) { + if (filename.empty()) { + LOG(ERROR) << "STTableWriter: writing to standard out unsupported."; + return 0; + } + return new STTableWriter<T, W>(filename); + } + + void Add(const string &key, const T &t) { + if (key == "") { + FSTERROR() << "STTableWriter::Add: key empty: " << key; + error_ = true; + } else if (key < last_key_) { + FSTERROR() << "STTableWriter::Add: key disorder: " << key; + error_ = true; + } + if (error_) return; + last_key_ = key; + positions_.push_back(stream_.tellp()); + WriteType(stream_, key); + entry_writer_(stream_, t); + } + + bool Error() const { return error_; } + + ~STTableWriter() { + WriteType(stream_, positions_); + WriteType(stream_, static_cast<int64>(positions_.size())); + } + + private: + EntryWriter entry_writer_; // Write functor for 'EntryType' + ofstream stream_; // Output stream + vector<int64> positions_; // Position in file of each key-entry pair + string last_key_; // Last key + bool error_; + + DISALLOW_COPY_AND_ASSIGN(STTableWriter); +}; + + +// String-to-type table reading class for object of type 'T' using functor 'R' +// to read an object of type 'T' form a stream. 'R' must conform to the +// following interface: +// +// struct Reader { +// T *operator()(istream &) const; +// }; +// +template <class T, class R> +class STTableReader { + public: + typedef T EntryType; + typedef R EntryReader; + + explicit STTableReader(const vector<string> &filenames) + : sources_(filenames), entry_(0), error_(false) { + compare_ = new Compare(&keys_); + keys_.resize(filenames.size()); + streams_.resize(filenames.size(), 0); + positions_.resize(filenames.size()); + for (size_t i = 0; i < filenames.size(); ++i) { + streams_[i] = new ifstream( + filenames[i].c_str(), ifstream::in | ifstream::binary); + int32 magic_number = 0, file_version = 0; + ReadType(*streams_[i], &magic_number); + ReadType(*streams_[i], &file_version); + if (magic_number != kSTTableMagicNumber) { + FSTERROR() << "STTableReader::STTableReader: wrong file type: " + << filenames[i]; + error_ = true; + return; + } + if (file_version != kSTTableFileVersion) { + FSTERROR() << "STTableReader::STTableReader: wrong file version: " + << filenames[i]; + error_ = true; + return; + } + int64 num_entries; + streams_[i]->seekg(-static_cast<int>(sizeof(int64)), ios_base::end); + ReadType(*streams_[i], &num_entries); + streams_[i]->seekg(-static_cast<int>(sizeof(int64)) * + (num_entries + 1), ios_base::end); + positions_[i].resize(num_entries); + for (size_t j = 0; (j < num_entries) && (*streams_[i]); ++j) + ReadType(*streams_[i], &(positions_[i][j])); + streams_[i]->seekg(positions_[i][0]); + if (!*streams_[i]) { + FSTERROR() << "STTableReader::STTableReader: error reading file: " + << filenames[i]; + error_ = true; + return; + } + + } + MakeHeap(); + } + + ~STTableReader() { + for (size_t i = 0; i < streams_.size(); ++i) + delete streams_[i]; + delete compare_; + if (entry_) + delete entry_; + } + + static STTableReader<T, R> *Open(const string &filename) { + if (filename.empty()) { + LOG(ERROR) << "STTableReader: reading from standard in not supported"; + return 0; + } + vector<string> filenames; + filenames.push_back(filename); + return new STTableReader<T, R>(filenames); + } + + static STTableReader<T, R> *Open(const vector<string> &filenames) { + return new STTableReader<T, R>(filenames); + } + + void Reset() { + if (error_) return; + for (size_t i = 0; i < streams_.size(); ++i) + streams_[i]->seekg(positions_[i].front()); + MakeHeap(); + } + + bool Find(const string &key) { + if (error_) return false; + for (size_t i = 0; i < streams_.size(); ++i) + LowerBound(i, key); + MakeHeap(); + return keys_[current_] == key; + } + + bool Done() const { return error_ || heap_.empty(); } + + void Next() { + if (error_) return; + if (streams_[current_]->tellg() <= positions_[current_].back()) { + ReadType(*(streams_[current_]), &(keys_[current_])); + if (!*streams_[current_]) { + FSTERROR() << "STTableReader: error reading file: " + << sources_[current_]; + error_ = true; + return; + } + push_heap(heap_.begin(), heap_.end(), *compare_); + } else { + heap_.pop_back(); + } + if (!heap_.empty()) + PopHeap(); + } + + const string &GetKey() const { + return keys_[current_]; + } + + const EntryType &GetEntry() const { + return *entry_; + } + + bool Error() const { return error_; } + + private: + // Comparison functor used to compare stream IDs in the heap + struct Compare { + Compare(const vector<string> *keys) : keys_(keys) {} + + bool operator()(size_t i, size_t j) const { + return (*keys_)[i] > (*keys_)[j]; + }; + + private: + const vector<string> *keys_; + }; + + // Position the stream with ID 'id' at the position corresponding + // to the lower bound for key 'find_key' + void LowerBound(size_t id, const string &find_key) { + ifstream *strm = streams_[id]; + const vector<int64> &positions = positions_[id]; + size_t low = 0, high = positions.size() - 1; + + while (low < high) { + size_t mid = (low + high)/2; + strm->seekg(positions[mid]); + string key; + ReadType(*strm, &key); + if (key > find_key) { + high = mid; + } else if (key < find_key) { + low = mid + 1; + } else { + for (size_t i = mid; i > low; --i) { + strm->seekg(positions[i - 1]); + ReadType(*strm, &key); + if (key != find_key) { + strm->seekg(positions[i]); + return; + } + } + strm->seekg(positions[low]); + return; + } + } + strm->seekg(positions[low]); + } + + // Add all streams to the heap + void MakeHeap() { + heap_.clear(); + for (size_t i = 0; i < streams_.size(); ++i) { + ReadType(*streams_[i], &(keys_[i])); + if (!*streams_[i]) { + FSTERROR() << "STTableReader: error reading file: " << sources_[i]; + error_ = true; + return; + } + heap_.push_back(i); + } + make_heap(heap_.begin(), heap_.end(), *compare_); + PopHeap(); + } + + // Position the stream with the lowest key at the top + // of the heap, set 'current_' to the ID of that stream + // and read the current entry from that stream + void PopHeap() { + pop_heap(heap_.begin(), heap_.end(), *compare_); + current_ = heap_.back(); + if (entry_) + delete entry_; + entry_ = entry_reader_(*streams_[current_]); + if (!entry_) + error_ = true; + if (!*streams_[current_]) { + FSTERROR() << "STTableReader: error reading entry for key: " + << keys_[current_] << ", file: " << sources_[current_]; + error_ = true; + } + } + + + EntryReader entry_reader_; // Read functor for 'EntryType' + vector<ifstream*> streams_; // Input streams + vector<string> sources_; // and corresponding file names + vector<vector<int64> > positions_; // Index of positions for each stream + vector<string> keys_; // Lowest unread key for each stream + vector<int64> heap_; // Heap containing ID of streams with unread keys + int64 current_; // Id of current stream to be read + Compare *compare_; // Functor comparing stream IDs for the heap + mutable EntryType *entry_; // Pointer to the currently read entry + bool error_; + + DISALLOW_COPY_AND_ASSIGN(STTableReader); +}; + + +// String-to-type table header reading function template on the entry header +// type 'H' having a member function: +// Read(istream &strm, const string &filename); +// Checks that 'filename' is an STTable and call the H::Read() on the last +// entry in the STTable. +template <class H> +bool ReadSTTableHeader(const string &filename, H *header) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + int32 magic_number = 0, file_version = 0; + ReadType(strm, &magic_number); + ReadType(strm, &file_version); + if (magic_number != kSTTableMagicNumber) { + LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename; + return false; + } + if (file_version != kSTTableFileVersion) { + LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename; + return false; + } + int64 i = -1; + strm.seekg(-static_cast<int>(sizeof(int64)), ios_base::end); + ReadType(strm, &i); // Read number of entries + if (!strm) { + LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename; + return false; + } + if (i == 0) return true; // No entry header to read + strm.seekg(-2 * static_cast<int>(sizeof(int64)), ios_base::end); + ReadType(strm, &i); // Read position for last entry in file + strm.seekg(i); + string key; + ReadType(strm, &key); + header->Read(strm, filename + ":" + key); + if (!strm) { + LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename; + return false; + } + return true; +} + +bool IsSTTable(const string &filename); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_STTABLE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/bitmap-index.h b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/bitmap-index.h new file mode 100644 index 0000000..f5a5ba7 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/bitmap-index.h @@ -0,0 +1,183 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jeffrey Sorensen) + +#ifndef FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_ +#define FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_ + +#include <vector> +using std::vector; + +#include <fst/compat.h> + +// This class is a bitstring storage class with an index that allows +// seeking to the Nth set or clear bit in time O(Log(N)) where N is +// the length of the bit vector. In addition, it allows counting set or +// clear bits over ranges in constant time. +// +// This is accomplished by maintaining an "secondary" index of limited +// size in bits that maintains a running count of the number of bits set +// in each block of bitmap data. A block is defined as the number of +// uint64 values that can fit in the secondary index before an overflow +// occurs. +// +// To handle overflows, a "primary" index containing a running count of +// bits set in each block is created using the type uint64. + +namespace fst { + +class BitmapIndex { + public: + static size_t StorageSize(size_t size) { + return ((size + kStorageBlockMask) >> kStorageLogBitSize); + } + + BitmapIndex() : bits_(NULL), size_(0) { } + + bool Get(size_t index) const { + return (bits_[index >> kStorageLogBitSize] & + (kOne << (index & kStorageBlockMask))) != 0; + } + + static void Set(uint64* bits, size_t index) { + bits[index >> kStorageLogBitSize] |= (kOne << (index & kStorageBlockMask)); + } + + static void Clear(uint64* bits, size_t index) { + bits[index >> kStorageLogBitSize] &= ~(kOne << (index & kStorageBlockMask)); + } + + size_t Bits() const { + return size_; + } + + size_t ArraySize() const { + return StorageSize(size_); + } + + // Returns the number of one bits in the bitmap + size_t GetOnesCount() const { + return primary_index_[primary_index_size() - 1]; + } + + // Returns the number of one bits in positions 0 to limit - 1. + // REQUIRES: limit <= Bits() + size_t Rank1(size_t end) const; + + // Returns the number of one bits in the range start to end - 1. + // REQUIRES: limit <= Bits() + size_t GetOnesCountInRange(size_t start, size_t end) const { + return Rank1(end) - Rank1(start); + } + + // Returns the number of zero bits in positions 0 to limit - 1. + // REQUIRES: limit <= Bits() + size_t Rank0(size_t end) const { + return end - Rank1(end); + } + + // Returns the number of zero bits in the range start to end - 1. + // REQUIRES: limit <= Bits() + size_t GetZeroesCountInRange(size_t start, size_t end) const { + return end - start - GetOnesCountInRange(start, end); + } + + // Return true if any bit between begin inclusive and end exclusive + // is set. 0 <= begin <= end <= Bits() is required. + // + bool TestRange(size_t start, size_t end) const { + return Rank1(end) > Rank1(start); + } + + // Returns the offset to the nth set bit (zero based) + // or Bits() if index >= number of ones + size_t Select1(size_t bit_index) const; + + // Returns the offset to the nth clear bit (zero based) + // or Bits() if index > number of + size_t Select0(size_t bit_index) const; + + // Rebuilds from index for the associated Bitmap, should be called + // whenever changes have been made to the Bitmap or else behavior + // of the indexed bitmap methods will be undefined. + void BuildIndex(const uint64 *bits, size_t size); + + // the secondary index accumulates counts until it can possibly overflow + // this constant computes the number of uint64 units that can fit into + // units the size of uint16. + static const uint64 kOne = 1; + static const uint32 kStorageBitSize = 64; + static const uint32 kStorageLogBitSize = 6; + static const uint32 kSecondaryBlockSize = ((1 << 16) - 1) + >> kStorageLogBitSize; + + private: + static const uint32 kStorageBlockMask = kStorageBitSize - 1; + + // returns, from the index, the count of ones up to array_index + size_t get_index_ones_count(size_t array_index) const; + + // because the indexes, both primary and secondary, contain a running + // count of the population of one bits contained in [0,i), there is + // no reason to have an element in the zeroth position as this value would + // necessarily be zero. (The bits are indexed in a zero based way.) Thus + // we don't store the 0th element in either index. Both of the following + // functions, if greater than 0, must be decremented by one before retreiving + // the value from the corresponding array. + // returns the 1 + the block that contains the bitindex in question + // the inverted version works the same but looks for zeros using an inverted + // view of the index + size_t find_primary_block(size_t bit_index) const; + + size_t find_inverted_primary_block(size_t bit_index) const; + + // similarly, the secondary index (which resets its count to zero at + // the end of every kSecondaryBlockSize entries) does not store the element + // at 0. Note that the rem_bit_index parameter is the number of bits + // within the secondary block, after the bits accounted for by the primary + // block have been removed (i.e. the remaining bits) And, because we + // reset to zero with each new block, there is no need to store those + // actual zeros. + // returns 1 + the secondary block that contains the bitindex in question + size_t find_secondary_block(size_t block, size_t rem_bit_index) const; + + size_t find_inverted_secondary_block(size_t block, size_t rem_bit_index) + const; + + // We create a primary index based upon the number of secondary index + // blocks. The primary index uses fields wide enough to accomodate any + // index of the bitarray so cannot overflow + // The primary index is the actual running + // count of one bits set for all blocks (and, thus, all uint64s). + size_t primary_index_size() const { + return (ArraySize() + kSecondaryBlockSize - 1) / kSecondaryBlockSize; + } + + const uint64* bits_; + size_t size_; + + // The primary index contains the running popcount of all blocks + // which means the nth value contains the popcounts of + // [0,n*kSecondaryBlockSize], however, the 0th element is omitted. + vector<uint32> primary_index_; + // The secondary index contains the running popcount of the associated + // bitmap. It is the same length (in units of uint16) as the + // bitmap's map is in units of uint64s. + vector<uint16> secondary_index_; +}; + +} // end namespace fst + +#endif // FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h new file mode 100644 index 0000000..d113fb3 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/ngram-fst.h @@ -0,0 +1,934 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jeffrey Sorensen) +// +#ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ +#define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ + +#include <stddef.h> +#include <string.h> +#include <algorithm> +#include <string> +#include <vector> +using std::vector; + +#include <fst/compat.h> +#include <fst/fstlib.h> +#include <fst/mapped-file.h> +#include <fst/extensions/ngram/bitmap-index.h> + +// NgramFst implements a n-gram language model based upon the LOUDS data +// structure. Please refer to "Unary Data Strucutres for Language Models" +// http://research.google.com/pubs/archive/37218.pdf + +namespace fst { +template <class A> class NGramFst; +template <class A> class NGramFstMatcher; + +// Instance data containing mutable state for bookkeeping repeated access to +// the same state. +template <class A> +struct NGramFstInst { + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + StateId state_; + size_t num_futures_; + size_t offset_; + size_t node_; + StateId node_state_; + vector<Label> context_; + StateId context_state_; + NGramFstInst() + : state_(kNoStateId), node_state_(kNoStateId), + context_state_(kNoStateId) { } +}; + +// Implementation class for LOUDS based NgramFst interface +template <class A> +class NGramFstImpl : public FstImpl<A> { + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + using FstImpl<A>::SetType; + using FstImpl<A>::WriteHeader; + + friend class ArcIterator<NGramFst<A> >; + friend class NGramFstMatcher<A>; + + public: + using FstImpl<A>::InputSymbols; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + NGramFstImpl() : data_region_(0), data_(0), owned_(false) { + SetType("ngram"); + SetInputSymbols(NULL); + SetOutputSymbols(NULL); + SetProperties(kStaticProperties); + } + + NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out); + + ~NGramFstImpl() { + if (owned_) { + delete [] data_; + } + delete data_region_; + } + + static NGramFstImpl<A>* Read(istream &strm, // NOLINT + const FstReadOptions &opts) { + NGramFstImpl<A>* impl = new NGramFstImpl(); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0; + uint64 num_states, num_futures, num_final; + const size_t offset = sizeof(num_states) + sizeof(num_futures) + + sizeof(num_final); + // Peek at num_states and num_futures to see how much more needs to be read. + strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states)); + strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures)); + strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final)); + size_t size = Storage(num_states, num_futures, num_final); + MappedFile *data_region = MappedFile::Allocate(size); + char *data = reinterpret_cast<char *>(data_region->mutable_data()); + // Copy num_states, num_futures and num_final back into data. + memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states)); + memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures), + sizeof(num_futures)); + memcpy(data + sizeof(num_states) + sizeof(num_futures), + reinterpret_cast<char *>(&num_final), sizeof(num_final)); + strm.read(data + offset, size - offset); + if (!strm) { + delete impl; + return NULL; + } + impl->Init(data, false, data_region); + return impl; + } + + bool Write(ostream &strm, // NOLINT + const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(Start()); + hdr.SetNumStates(num_states_); + WriteHeader(strm, opts, kFileVersion, &hdr); + strm.write(data_, Storage(num_states_, num_futures_, num_final_)); + return strm; + } + + StateId Start() const { + return 1; + } + + Weight Final(StateId state) const { + if (final_index_.Get(state)) { + return final_probs_[final_index_.Rank1(state)]; + } else { + return Weight::Zero(); + } + } + + size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const { + if (inst == NULL) { + const size_t next_zero = future_index_.Select0(state + 1); + const size_t this_zero = future_index_.Select0(state); + return next_zero - this_zero - 1; + } + SetInstFuture(state, inst); + return inst->num_futures_ + ((state == 0) ? 0 : 1); + } + + size_t NumInputEpsilons(StateId state) const { + // State 0 has no parent, thus no backoff. + if (state == 0) return 0; + return 1; + } + + size_t NumOutputEpsilons(StateId state) const { + return NumInputEpsilons(state); + } + + StateId NumStates() const { + return num_states_; + } + + void InitStateIterator(StateIteratorData<A>* data) const { + data->base = 0; + data->nstates = num_states_; + } + + static size_t Storage(uint64 num_states, uint64 num_futures, + uint64 num_final) { + uint64 b64; + Weight weight; + Label label; + size_t offset = sizeof(num_states) + sizeof(num_futures) + + sizeof(num_final); + offset += sizeof(b64) * ( + BitmapIndex::StorageSize(num_states * 2 + 1) + + BitmapIndex::StorageSize(num_futures + num_states + 1) + + BitmapIndex::StorageSize(num_states)); + offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label); + // Pad for alignemnt, see + // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding + offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); + offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) + + (num_futures + 1) * sizeof(weight); + return offset; + } + + void SetInstFuture(StateId state, NGramFstInst<A> *inst) const { + if (inst->state_ != state) { + inst->state_ = state; + const size_t next_zero = future_index_.Select0(state + 1); + const size_t this_zero = future_index_.Select0(state); + inst->num_futures_ = next_zero - this_zero - 1; + inst->offset_ = future_index_.Rank1(future_index_.Select0(state) + 1); + } + } + + void SetInstNode(NGramFstInst<A> *inst) const { + if (inst->node_state_ != inst->state_) { + inst->node_state_ = inst->state_; + inst->node_ = context_index_.Select1(inst->state_); + } + } + + void SetInstContext(NGramFstInst<A> *inst) const { + SetInstNode(inst); + if (inst->context_state_ != inst->state_) { + inst->context_state_ = inst->state_; + inst->context_.clear(); + size_t node = inst->node_; + while (node != 0) { + inst->context_.push_back(context_words_[context_index_.Rank1(node)]); + node = context_index_.Select1(context_index_.Rank0(node) - 1); + } + } + } + + // Access to the underlying representation + const char* GetData(size_t* data_size) const { + *data_size = Storage(num_states_, num_futures_, num_final_); + return data_; + } + + void Init(const char* data, bool owned, MappedFile *file = 0); + + const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const { + SetInstFuture(s, inst); + SetInstContext(inst); + return inst->context_; + } + + private: + StateId Transition(const vector<Label> &context, Label future) const; + + // Properties always true for this Fst class. + static const uint64 kStaticProperties = kAcceptor | kIDeterministic | + kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted | + kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted | + kAccessible | kCoAccessible | kNotString | kExpanded; + // Current file format version. + static const int kFileVersion = 4; + // Minimum file format version supported. + static const int kMinFileVersion = 4; + + MappedFile *data_region_; + const char* data_; + bool owned_; // True if we own data_ + uint64 num_states_, num_futures_, num_final_; + size_t root_num_children_; + const Label *root_children_; + size_t root_first_child_; + // borrowed references + const uint64 *context_, *future_, *final_; + const Label *context_words_, *future_words_; + const Weight *backoff_, *final_probs_, *future_probs_; + BitmapIndex context_index_; + BitmapIndex future_index_; + BitmapIndex final_index_; + + void operator=(const NGramFstImpl<A> &); // Disallow +}; + +template<typename A> +NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) + : data_region_(0), data_(0), owned_(false) { + typedef A Arc; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + SetType("ngram"); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + SetProperties(kStaticProperties); + + // Check basic requirements for an OpenGRM language model Fst. + int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted; + if (fst.Properties(props, true) != props) { + FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input"; + SetProperties(kError, kError); + return; + } + + int64 num_states = CountStates(fst); + Label* context = new Label[num_states]; + + // Find the unigram state by starting from the start state, following + // epsilons. + StateId unigram = fst.Start(); + while (1) { + if (unigram == kNoStateId) { + FSTERROR() << "Could not identify unigram state."; + SetProperties(kError, kError); + return; + } + ArcIterator<Fst<A> > aiter(fst, unigram); + if (aiter.Done()) { + LOG(WARNING) << "Unigram state " << unigram << " has no arcs."; + break; + } + if (aiter.Value().ilabel != 0) break; + unigram = aiter.Value().nextstate; + } + + // Each state's context is determined by the subtree it is under from the + // unigram state. + queue<pair<StateId, Label> > label_queue; + vector<bool> visited(num_states); + // Force an epsilon link to the start state. + label_queue.push(make_pair(fst.Start(), 0)); + for (ArcIterator<Fst<A> > aiter(fst, unigram); + !aiter.Done(); aiter.Next()) { + label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel)); + } + // investigate states in breadth first fashion to assign context words. + while (!label_queue.empty()) { + pair<StateId, Label> &now = label_queue.front(); + if (!visited[now.first]) { + context[now.first] = now.second; + visited[now.first] = true; + for (ArcIterator<Fst<A> > aiter(fst, now.first); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { + label_queue.push(make_pair(arc.nextstate, now.second)); + } + } + } + label_queue.pop(); + } + visited.clear(); + + // The arc from the start state should be assigned an epsilon to put it + // in front of the all other labels (which makes Start state 1 after + // unigram which is state 0). + context[fst.Start()] = 0; + + // Build the tree of contexts fst by reversing the epsilon arcs from fst. + VectorFst<Arc> context_fst; + uint64 num_final = 0; + for (int i = 0; i < num_states; ++i) { + if (fst.Final(i) != Weight::Zero()) { + ++num_final; + } + context_fst.SetFinal(context_fst.AddState(), fst.Final(i)); + } + context_fst.SetStart(unigram); + context_fst.SetInputSymbols(fst.InputSymbols()); + context_fst.SetOutputSymbols(fst.OutputSymbols()); + int64 num_context_arcs = 0; + int64 num_futures = 0; + for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) { + const StateId &state = siter.Value(); + num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state); + ArcIterator<Fst<A> > aiter(fst, state); + if (!aiter.Done()) { + const Arc &arc = aiter.Value(); + // this arc goes from state to arc.nextstate, so create an arc from + // arc.nextstate to state to reverse it. + if (arc.ilabel == 0) { + context_fst.AddArc(arc.nextstate, Arc(context[state], context[state], + arc.weight, state)); + num_context_arcs++; + } + } + } + if (num_context_arcs != context_fst.NumStates() - 1) { + FSTERROR() << "Number of contexts arcs != number of states - 1"; + SetProperties(kError, kError); + return; + } + if (context_fst.NumStates() != num_states) { + FSTERROR() << "Number of contexts != number of states"; + SetProperties(kError, kError); + return; + } + int64 context_props = context_fst.Properties(kIDeterministic | + kILabelSorted, true); + if (!(context_props & kIDeterministic)) { + FSTERROR() << "Input fst is not structured properly"; + SetProperties(kError, kError); + return; + } + if (!(context_props & kILabelSorted)) { + ArcSort(&context_fst, ILabelCompare<Arc>()); + } + + delete [] context; + + uint64 b64; + Weight weight; + Label label = kNoLabel; + const size_t storage = Storage(num_states, num_futures, num_final); + MappedFile *data_region = MappedFile::Allocate(storage); + char *data = reinterpret_cast<char *>(data_region->mutable_data()); + memset(data, 0, storage); + size_t offset = 0; + memcpy(data + offset, reinterpret_cast<char *>(&num_states), + sizeof(num_states)); + offset += sizeof(num_states); + memcpy(data + offset, reinterpret_cast<char *>(&num_futures), + sizeof(num_futures)); + offset += sizeof(num_futures); + memcpy(data + offset, reinterpret_cast<char *>(&num_final), + sizeof(num_final)); + offset += sizeof(num_final); + uint64* context_bits = reinterpret_cast<uint64*>(data + offset); + offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64); + uint64* future_bits = reinterpret_cast<uint64*>(data + offset); + offset += + BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64); + uint64* final_bits = reinterpret_cast<uint64*>(data + offset); + offset += BitmapIndex::StorageSize(num_states) * sizeof(b64); + Label* context_words = reinterpret_cast<Label*>(data + offset); + offset += (num_states + 1) * sizeof(label); + Label* future_words = reinterpret_cast<Label*>(data + offset); + offset += num_futures * sizeof(label); + offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); + Weight* backoff = reinterpret_cast<Weight*>(data + offset); + offset += (num_states + 1) * sizeof(weight); + Weight* final_probs = reinterpret_cast<Weight*>(data + offset); + offset += num_final * sizeof(weight); + Weight* future_probs = reinterpret_cast<Weight*>(data + offset); + int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0, + final_bit = 0; + + // pseudo-root bits + BitmapIndex::Set(context_bits, context_bit++); + ++context_bit; + context_words[context_arc] = label; + backoff[context_arc] = Weight::Zero(); + context_arc++; + + ++future_bit; + if (order_out) { + order_out->clear(); + order_out->resize(num_states); + } + + queue<StateId> context_q; + context_q.push(context_fst.Start()); + StateId state_number = 0; + while (!context_q.empty()) { + const StateId &state = context_q.front(); + if (order_out) { + (*order_out)[state] = state_number; + } + + const Weight &final = context_fst.Final(state); + if (final != Weight::Zero()) { + BitmapIndex::Set(final_bits, state_number); + final_probs[final_bit] = final; + ++final_bit; + } + + for (ArcIterator<VectorFst<A> > aiter(context_fst, state); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + context_words[context_arc] = arc.ilabel; + backoff[context_arc] = arc.weight; + ++context_arc; + BitmapIndex::Set(context_bits, context_bit++); + context_q.push(arc.nextstate); + } + ++context_bit; + + for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { + future_words[future_arc] = arc.ilabel; + future_probs[future_arc] = arc.weight; + ++future_arc; + BitmapIndex::Set(future_bits, future_bit++); + } + } + ++future_bit; + ++state_number; + context_q.pop(); + } + + if ((state_number != num_states) || + (context_bit != num_states * 2 + 1) || + (context_arc != num_states) || + (future_arc != num_futures) || + (future_bit != num_futures + num_states + 1) || + (final_bit != num_final)) { + FSTERROR() << "Structure problems detected during construction"; + SetProperties(kError, kError); + return; + } + + Init(data, false, data_region); +} + +template<typename A> +inline void NGramFstImpl<A>::Init(const char* data, bool owned, + MappedFile *data_region) { + if (owned_) { + delete [] data_; + } + delete data_region_; + data_region_ = data_region; + owned_ = owned; + data_ = data; + size_t offset = 0; + num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset)); + offset += sizeof(num_states_); + num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset)); + offset += sizeof(num_futures_); + num_final_ = *(reinterpret_cast<const uint64*>(data_ + offset)); + offset += sizeof(num_final_); + uint64 bits; + size_t context_bits = num_states_ * 2 + 1; + size_t future_bits = num_futures_ + num_states_ + 1; + context_ = reinterpret_cast<const uint64*>(data_ + offset); + offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits); + future_ = reinterpret_cast<const uint64*>(data_ + offset); + offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits); + final_ = reinterpret_cast<const uint64*>(data_ + offset); + offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits); + context_words_ = reinterpret_cast<const Label*>(data_ + offset); + offset += (num_states_ + 1) * sizeof(*context_words_); + future_words_ = reinterpret_cast<const Label*>(data_ + offset); + offset += num_futures_ * sizeof(*future_words_); + offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1); + backoff_ = reinterpret_cast<const Weight*>(data_ + offset); + offset += (num_states_ + 1) * sizeof(*backoff_); + final_probs_ = reinterpret_cast<const Weight*>(data_ + offset); + offset += num_final_ * sizeof(*final_probs_); + future_probs_ = reinterpret_cast<const Weight*>(data_ + offset); + + context_index_.BuildIndex(context_, context_bits); + future_index_.BuildIndex(future_, future_bits); + final_index_.BuildIndex(final_, num_states_); + + const size_t node_rank = context_index_.Rank1(0); + root_first_child_ = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(root_first_child_) == false) { + FSTERROR() << "Missing unigrams"; + SetProperties(kError, kError); + return; + } + const size_t last_child = context_index_.Select0(node_rank + 1) - 1; + root_num_children_ = last_child - root_first_child_ + 1; + root_children_ = context_words_ + context_index_.Rank1(root_first_child_); +} + +template<typename A> +inline typename A::StateId NGramFstImpl<A>::Transition( + const vector<Label> &context, Label future) const { + size_t num_children = root_num_children_; + const Label *children = root_children_; + const Label *loc = lower_bound(children, children + num_children, future); + if (loc == children + num_children || *loc != future) { + return context_index_.Rank1(0); + } + size_t node = root_first_child_ + loc - children; + size_t node_rank = context_index_.Rank1(node); + size_t first_child = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(first_child) == false) { + return context_index_.Rank1(node); + } + size_t last_child = context_index_.Select0(node_rank + 1) - 1; + num_children = last_child - first_child + 1; + for (int word = context.size() - 1; word >= 0; --word) { + children = context_words_ + context_index_.Rank1(first_child); + loc = lower_bound(children, children + last_child - first_child + 1, + context[word]); + if (loc == children + last_child - first_child + 1 || + *loc != context[word]) { + break; + } + node = first_child + loc - children; + node_rank = context_index_.Rank1(node); + first_child = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(first_child) == false) break; + last_child = context_index_.Select0(node_rank + 1) - 1; + } + return context_index_.Rank1(node); +} + +/*****************************************************************************/ +template<class A> +class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { + friend class ArcIterator<NGramFst<A> >; + friend class NGramFstMatcher<A>; + + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef NGramFstImpl<A> Impl; + + explicit NGramFst(const Fst<A> &dst) + : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {} + + NGramFst(const Fst<A> &fst, vector<StateId>* order_out) + : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {} + + // Because the NGramFstImpl is a const stateless data structure, there + // is never a need to do anything beside copy the reference. + NGramFst(const NGramFst<A> &fst, bool safe = false) + : ImplToExpandedFst<Impl>(fst, false) {} + + NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {} + + // Non-standard constructor to initialize NGramFst directly from data. + NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) { + GetImpl()->Init(data, owned, NULL); + } + + // Get method that gets the data associated with Init(). + const char* GetData(size_t* data_size) const { + return GetImpl()->GetData(data_size); + } + + const vector<Label> GetContext(StateId s) const { + return GetImpl()->GetContext(s, &inst_); + } + + virtual size_t NumArcs(StateId s) const { + return GetImpl()->NumArcs(s, &inst_); + } + + virtual NGramFst<A>* Copy(bool safe = false) const { + return new NGramFst(*this, safe); + } + + static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) { + Impl* impl = Impl::Read(strm, opts); + return impl ? new NGramFst<A>(impl) : 0; + } + + static NGramFst<A>* Read(const string &filename) { + if (!filename.empty()) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename; + return 0; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(cin, FstReadOptions("standard input")); + } + } + + virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { + return GetImpl()->Write(strm, opts); + } + + virtual bool Write(const string &filename) const { + return Fst<A>::WriteFile(filename); + } + + virtual inline void InitStateIterator(StateIteratorData<A>* data) const { + GetImpl()->InitStateIterator(data); + } + + virtual inline void InitArcIterator( + StateId s, ArcIteratorData<A>* data) const; + + virtual MatcherBase<A>* InitMatcher(MatchType match_type) const { + return new NGramFstMatcher<A>(*this, match_type); + } + + private: + explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {} + + Impl* GetImpl() const { + return + ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl(); + } + + void SetImpl(Impl* impl, bool own_impl = true) { + ImplToExpandedFst<Impl, Fst<A> >::SetImpl(impl, own_impl); + } + + mutable NGramFstInst<A> inst_; +}; + +template <class A> inline void +NGramFst<A>::InitArcIterator(StateId s, ArcIteratorData<A>* data) const { + GetImpl()->SetInstFuture(s, &inst_); + GetImpl()->SetInstNode(&inst_); + data->base = new ArcIterator<NGramFst<A> >(*this, s); +} + +/*****************************************************************************/ +template <class A> +class NGramFstMatcher : public MatcherBase<A> { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type) + : fst_(fst), inst_(fst.inst_), match_type_(match_type), + current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + swap(loop_.ilabel, loop_.olabel); + } + } + + NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false) + : fst_(matcher.fst_), inst_(matcher.inst_), + match_type_(matcher.match_type_), current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + swap(loop_.ilabel, loop_.olabel); + } + } + + virtual NGramFstMatcher<A>* Copy(bool safe = false) const { + return new NGramFstMatcher<A>(*this, safe); + } + + virtual MatchType Type(bool test) const { + return match_type_; + } + + virtual const Fst<A> &GetFst() const { + return fst_; + } + + virtual uint64 Properties(uint64 props) const { + return props; + } + + private: + virtual void SetState_(StateId s) { + fst_.GetImpl()->SetInstFuture(s, &inst_); + current_loop_ = false; + } + + virtual bool Find_(Label label) { + const Label nolabel = kNoLabel; + done_ = true; + if (label == 0 || label == nolabel) { + if (label == 0) { + current_loop_ = true; + loop_.nextstate = inst_.state_; + } + // The unigram state has no epsilon arc. + if (inst_.state_ != 0) { + arc_.ilabel = arc_.olabel = 0; + fst_.GetImpl()->SetInstNode(&inst_); + arc_.nextstate = fst_.GetImpl()->context_index_.Rank1( + fst_.GetImpl()->context_index_.Select1( + fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1)); + arc_.weight = fst_.GetImpl()->backoff_[inst_.state_]; + done_ = false; + } + } else { + const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_; + const Label *end = start + inst_.num_futures_; + const Label* search = lower_bound(start, end, label); + if (search != end && *search == label) { + size_t state = search - start; + arc_.ilabel = arc_.olabel = label; + arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state]; + fst_.GetImpl()->SetInstContext(&inst_); + arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label); + done_ = false; + } + } + return !Done_(); + } + + virtual bool Done_() const { + return !current_loop_ && done_; + } + + virtual const Arc& Value_() const { + return (current_loop_) ? loop_ : arc_; + } + + virtual void Next_() { + if (current_loop_) { + current_loop_ = false; + } else { + done_ = true; + } + } + + const NGramFst<A>& fst_; + NGramFstInst<A> inst_; + MatchType match_type_; // Supplied by caller + bool done_; + Arc arc_; + bool current_loop_; // Current arc is the implicit loop + Arc loop_; +}; + +/*****************************************************************************/ +template<class A> +class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + ArcIterator(const NGramFst<A> &fst, StateId state) + : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) { + inst_ = fst.inst_; + impl_->SetInstFuture(state, &inst_); + impl_->SetInstNode(&inst_); + } + + bool Done() const { + return i_ >= ((inst_.node_ == 0) ? inst_.num_futures_ : + inst_.num_futures_ + 1); + } + + const Arc &Value() const { + bool eps = (inst_.node_ != 0 && i_ == 0); + StateId state = (inst_.node_ == 0) ? i_ : i_ - 1; + if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) { + arc_.ilabel = + arc_.olabel = eps ? 0 : impl_->future_words_[inst_.offset_ + state]; + lazy_ &= ~(kArcILabelValue | kArcOLabelValue); + } + if (flags_ & lazy_ & kArcNextStateValue) { + if (eps) { + arc_.nextstate = impl_->context_index_.Rank1( + impl_->context_index_.Select1( + impl_->context_index_.Rank0(inst_.node_) - 1)); + } else { + if (lazy_ & kArcNextStateValue) { + impl_->SetInstContext(&inst_); // first time only. + } + arc_.nextstate = + impl_->Transition(inst_.context_, + impl_->future_words_[inst_.offset_ + state]); + } + lazy_ &= ~kArcNextStateValue; + } + if (flags_ & lazy_ & kArcWeightValue) { + arc_.weight = eps ? impl_->backoff_[inst_.state_] : + impl_->future_probs_[inst_.offset_ + state]; + lazy_ &= ~kArcWeightValue; + } + return arc_; + } + + void Next() { + ++i_; + lazy_ = ~0; + } + + size_t Position() const { return i_; } + + void Reset() { + i_ = 0; + lazy_ = ~0; + } + + void Seek(size_t a) { + if (i_ != a) { + i_ = a; + lazy_ = ~0; + } + } + + uint32 Flags() const { + return flags_; + } + + void SetFlags(uint32 f, uint32 m) { + flags_ &= ~m; + flags_ |= (f & kArcValueFlags); + } + + private: + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual size_t Position_() const { return Position(); } + virtual void Reset_() { Reset(); } + virtual void Seek_(size_t a) { Seek(a); } + uint32 Flags_() const { return Flags(); } + void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } + + mutable Arc arc_; + mutable uint32 lazy_; + const NGramFstImpl<A> *impl_; + mutable NGramFstInst<A> inst_; + + size_t i_; + uint32 flags_; + + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +/*****************************************************************************/ +// Specialization for NGramFst; see generic version in fst.h +// for sample usage (but use the ProdLmFst type!). This version +// should inline. +template <class A> +class StateIterator<NGramFst<A> > : public StateIteratorBase<A> { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const NGramFst<A> &fst) + : s_(0), num_states_(fst.NumStates()) { } + + bool Done() const { return s_ >= num_states_; } + StateId Value() const { return s_; } + void Next() { ++s_; } + void Reset() { s_ = 0; } + + private: + virtual bool Done_() const { return Done(); } + virtual StateId Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual void Reset_() { Reset(); } + + StateId s_, num_states_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; +} // namespace fst +#endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/nthbit.h b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/nthbit.h new file mode 100644 index 0000000..d4a9a5a --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/extensions/ngram/nthbit.h @@ -0,0 +1,46 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jeffrey Sorensen) +// [email protected] (Doug Rohde) + +#ifndef FST_EXTENSIONS_NGRAM_NTHBIT_H_ +#define FST_EXTENSIONS_NGRAM_NTHBIT_H_ + +#include <fst/types.h> + +extern uint32 nth_bit_bit_offset[]; + +inline uint32 nth_bit(uint64 v, uint32 r) { + uint32 shift = 0; + uint32 c = __builtin_popcount(v & 0xffffffff); + uint32 mask = -(r > c); + r -= c & mask; + shift += (32 & mask); + + c = __builtin_popcount((v >> shift) & 0xffff); + mask = -(r > c); + r -= c & mask; + shift += (16 & mask); + + c = __builtin_popcount((v >> shift) & 0xff); + mask = -(r > c); + r -= c & mask; + shift += (8 & mask); + + return shift + ((nth_bit_bit_offset[(v >> shift) & 0xff] >> + ((r - 1) << 2)) & 0xf); +} + +#endif // FST_EXTENSIONS_NGRAM_NTHBIT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/factor-weight.h b/kaldi_io/src/tools/openfst/include/fst/factor-weight.h new file mode 100644 index 0000000..685155c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/factor-weight.h @@ -0,0 +1,475 @@ +// factor-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Classes to factor weights in an FST. + +#ifndef FST_LIB_FACTOR_WEIGHT_H__ +#define FST_LIB_FACTOR_WEIGHT_H__ + +#include <algorithm> +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/test-properties.h> + + +namespace fst { + +const uint32 kFactorFinalWeights = 0x00000001; +const uint32 kFactorArcWeights = 0x00000002; + +template <class Arc> +struct FactorWeightOptions : CacheOptions { + typedef typename Arc::Label Label; + float delta; + uint32 mode; // factor arc weights and/or final weights + Label final_ilabel; // input label of arc created when factoring final w's + Label final_olabel; // output label of arc created when factoring final w's + + FactorWeightOptions(const CacheOptions &opts, float d, + uint32 m = kFactorArcWeights | kFactorFinalWeights, + Label il = 0, Label ol = 0) + : CacheOptions(opts), delta(d), mode(m), final_ilabel(il), + final_olabel(ol) {} + + explicit FactorWeightOptions( + float d, uint32 m = kFactorArcWeights | kFactorFinalWeights, + Label il = 0, Label ol = 0) + : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {} + + FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights, + Label il = 0, Label ol = 0) + : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {} +}; + + +// A factor iterator takes as argument a weight w and returns a +// sequence of pairs of weights (xi,yi) such that the sum of the +// products xi times yi is equal to w. If w is fully factored, +// the iterator should return nothing. +// +// template <class W> +// class FactorIterator { +// public: +// FactorIterator(W w); +// bool Done() const; +// void Next(); +// pair<W, W> Value() const; +// void Reset(); +// } + + +// Factor trivially. +template <class W> +class IdentityFactor { + public: + IdentityFactor(const W &w) {} + bool Done() const { return true; } + void Next() {} + pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused + void Reset() {} +}; + + +// Factor a StringWeight w as 'ab' where 'a' is a label. +template <typename L, StringType S = STRING_LEFT> +class StringFactor { + public: + StringFactor(const StringWeight<L, S> &w) + : weight_(w), done_(w.Size() <= 1) {} + + bool Done() const { return done_; } + + void Next() { done_ = true; } + + pair< StringWeight<L, S>, StringWeight<L, S> > Value() const { + StringWeightIterator<L, S> iter(weight_); + StringWeight<L, S> w1(iter.Value()); + StringWeight<L, S> w2; + for (iter.Next(); !iter.Done(); iter.Next()) + w2.PushBack(iter.Value()); + return make_pair(w1, w2); + } + + void Reset() { done_ = weight_.Size() <= 1; } + + private: + StringWeight<L, S> weight_; + bool done_; +}; + + +// Factor a GallicWeight using StringFactor. +template <class L, class W, StringType S = STRING_LEFT> +class GallicFactor { + public: + GallicFactor(const GallicWeight<L, W, S> &w) + : weight_(w), done_(w.Value1().Size() <= 1) {} + + bool Done() const { return done_; } + + void Next() { done_ = true; } + + pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const { + StringFactor<L, S> iter(weight_.Value1()); + GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2()); + GallicWeight<L, W, S> w2(iter.Value().second, W::One()); + return make_pair(w1, w2); + } + + void Reset() { done_ = weight_.Value1().Size() <= 1; } + + private: + GallicWeight<L, W, S> weight_; + bool done_; +}; + + +// Implementation class for FactorWeight +template <class A, class F> +class FactorWeightFstImpl + : public CacheImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<A> >::PushArc; + using CacheBaseImpl< CacheState<A> >::HasStart; + using CacheBaseImpl< CacheState<A> >::HasFinal; + using CacheBaseImpl< CacheState<A> >::HasArcs; + using CacheBaseImpl< CacheState<A> >::SetArcs; + using CacheBaseImpl< CacheState<A> >::SetFinal; + using CacheBaseImpl< CacheState<A> >::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef F FactorIterator; + + struct Element { + Element() {} + + Element(StateId s, Weight w) : state(s), weight(w) {} + + StateId state; // Input state Id + Weight weight; // Residual weight + }; + + FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts) + : CacheImpl<A>(opts), + fst_(fst.Copy()), + delta_(opts.delta), + mode_(opts.mode), + final_ilabel_(opts.final_ilabel), + final_olabel_(opts.final_olabel) { + SetType("factor_weight"); + uint64 props = fst.Properties(kFstProperties, false); + SetProperties(FactorWeightProperties(props), kCopyProperties); + + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + + if (mode_ == 0) + LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: " + << "factoring neither arc weights nor final weights."; + } + + FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl) + : CacheImpl<A>(impl), + fst_(impl.fst_->Copy(true)), + delta_(impl.delta_), + mode_(impl.mode_), + final_ilabel_(impl.final_ilabel_), + final_olabel_(impl.final_olabel_) { + SetType("factor_weight"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~FactorWeightFstImpl() { + delete fst_; + } + + StateId Start() { + if (!HasStart()) { + StateId s = fst_->Start(); + if (s == kNoStateId) + return kNoStateId; + StateId start = FindState(Element(fst_->Start(), Weight::One())); + SetStart(start); + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + const Element &e = elements_[s]; + // TODO: fix so cast is unnecessary + Weight w = e.state == kNoStateId + ? e.weight + : (Weight) Times(e.weight, fst_->Final(e.state)); + FactorIterator f(w); + if (!(mode_ & kFactorFinalWeights) || f.Done()) + SetFinal(s, w); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && fst_->Properties(kError, false)) + SetProperties(kError, kError); + return FstImpl<Arc>::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<A>::InitArcIterator(s, data); + } + + + // Find state corresponding to an element. Create new state + // if element not found. + StateId FindState(const Element &e) { + if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) { + while (unfactored_.size() <= e.state) + unfactored_.push_back(kNoStateId); + if (unfactored_[e.state] == kNoStateId) { + unfactored_[e.state] = elements_.size(); + elements_.push_back(e); + } + return unfactored_[e.state]; + } else { + typename ElementMap::iterator eit = element_map_.find(e); + if (eit != element_map_.end()) { + return (*eit).second; + } else { + StateId s = elements_.size(); + elements_.push_back(e); + element_map_.insert(pair<const Element, StateId>(e, s)); + return s; + } + } + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void Expand(StateId s) { + Element e = elements_[s]; + if (e.state != kNoStateId) { + for (ArcIterator< Fst<A> > ait(*fst_, e.state); + !ait.Done(); + ait.Next()) { + const A &arc = ait.Value(); + Weight w = Times(e.weight, arc.weight); + FactorIterator fit(w); + if (!(mode_ & kFactorArcWeights) || fit.Done()) { + StateId d = FindState(Element(arc.nextstate, Weight::One())); + PushArc(s, Arc(arc.ilabel, arc.olabel, w, d)); + } else { + for (; !fit.Done(); fit.Next()) { + const pair<Weight, Weight> &p = fit.Value(); + StateId d = FindState(Element(arc.nextstate, + p.second.Quantize(delta_))); + PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d)); + } + } + } + } + + if ((mode_ & kFactorFinalWeights) && + ((e.state == kNoStateId) || + (fst_->Final(e.state) != Weight::Zero()))) { + Weight w = e.state == kNoStateId + ? e.weight + : Times(e.weight, fst_->Final(e.state)); + for (FactorIterator fit(w); + !fit.Done(); + fit.Next()) { + const pair<Weight, Weight> &p = fit.Value(); + StateId d = FindState(Element(kNoStateId, + p.second.Quantize(delta_))); + PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d)); + } + } + SetArcs(s); + } + + private: + static const size_t kPrime = 7853; + + // Equality function for Elements, assume weights have been quantized. + class ElementEqual { + public: + bool operator()(const Element &x, const Element &y) const { + return x.state == y.state && x.weight == y.weight; + } + }; + + // Hash function for Elements to Fst states. + class ElementKey { + public: + size_t operator()(const Element &x) const { + return static_cast<size_t>(x.state * kPrime + x.weight.Hash()); + } + private: + }; + + typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap; + + const Fst<A> *fst_; + float delta_; + uint32 mode_; // factoring arc and/or final weights + Label final_ilabel_; // ilabel of arc created when factoring final w's + Label final_olabel_; // olabel of arc created when factoring final w's + vector<Element> elements_; // mapping Fst state to Elements + ElementMap element_map_; // mapping Elements to Fst state + // mapping between old/new 'StateId' for states that do not need to + // be factored when 'mode_' is '0' or 'kFactorFinalWeights' + vector<StateId> unfactored_; + + void operator=(const FactorWeightFstImpl<A, F> &); // disallow +}; + +template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime; + + +// FactorWeightFst takes as template parameter a FactorIterator as +// defined above. The result of weight factoring is a transducer +// equivalent to the input whose path weights have been factored +// according to the FactorIterator. States and transitions will be +// added as necessary. The algorithm is a generalization to arbitrary +// weights of the second step of the input epsilon-normalization +// algorithm due to Mohri, "Generic epsilon-removal and input +// epsilon-normalization algorithms for weighted transducers", +// International Journal of Computer Science 13(1): 129-143 (2002). +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A, class F> +class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > { + public: + friend class ArcIterator< FactorWeightFst<A, F> >; + friend class StateIterator< FactorWeightFst<A, F> >; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef FactorWeightFstImpl<A, F> Impl; + + FactorWeightFst(const Fst<A> &fst) + : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {} + + FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions<A> &opts) + : ImplToFst<Impl>(new Impl(fst, opts)) {} + + // See Fst<>::Copy() for doc. + FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy) + : ImplToFst<Impl>(fst, copy) {} + + // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc. + virtual FactorWeightFst<A, F> *Copy(bool copy = false) const { + return new FactorWeightFst<A, F>(*this, copy); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const FactorWeightFst<A, F> &fst); // Disallow +}; + + +// Specialization for FactorWeightFst. +template<class A, class F> +class StateIterator< FactorWeightFst<A, F> > + : public CacheStateIterator< FactorWeightFst<A, F> > { + public: + explicit StateIterator(const FactorWeightFst<A, F> &fst) + : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {} +}; + + +// Specialization for FactorWeightFst. +template <class A, class F> +class ArcIterator< FactorWeightFst<A, F> > + : public CacheArcIterator< FactorWeightFst<A, F> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const FactorWeightFst<A, F> &fst, StateId s) + : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +template <class A, class F> inline +void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const +{ + data->base = new StateIterator< FactorWeightFst<A, F> >(*this); +} + + +} // namespace fst + +#endif // FST_LIB_FACTOR_WEIGHT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/flags.h b/kaldi_io/src/tools/openfst/include/fst/flags.h new file mode 100644 index 0000000..b3bb66c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/flags.h @@ -0,0 +1,242 @@ +// flags.h +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: [email protected] (Michael Riley) +// +// \file +// Google-style flag handling declarations and inline definitions. + +#ifndef FST_LIB_FLAGS_H__ +#define FST_LIB_FLAGS_H__ + +#include <iostream> +#include <map> +#include <set> +#include <sstream> +#include <string> + +#include <fst/types.h> +#include <fst/lock.h> + +using std::string; + +// +// FLAGS USAGE: +// +// Definition example: +// +// DEFINE_int32(length, 0, "length"); +// +// This defines variable FLAGS_length, initialized to 0. +// +// Declaration example: +// +// DECLARE_int32(length); +// +// SET_FLAGS() can be used to set flags from the command line +// using, for example, '--length=2'. +// +// ShowUsage() can be used to print out command and flag usage. +// + +#define DECLARE_bool(name) extern bool FLAGS_ ## name +#define DECLARE_string(name) extern string FLAGS_ ## name +#define DECLARE_int32(name) extern int32 FLAGS_ ## name +#define DECLARE_int64(name) extern int64 FLAGS_ ## name +#define DECLARE_double(name) extern double FLAGS_ ## name + +template <typename T> +struct FlagDescription { + FlagDescription(T *addr, const char *doc, const char *type, + const char *file, const T val) + : address(addr), + doc_string(doc), + type_name(type), + file_name(file), + default_value(val) {} + + T *address; + const char *doc_string; + const char *type_name; + const char *file_name; + const T default_value; +}; + +template <typename T> +class FlagRegister { + public: + static FlagRegister<T> *GetRegister() { + fst::FstOnceInit(®ister_init_, &FlagRegister<T>::Init); + return register_; + } + + const FlagDescription<T> &GetFlagDescription(const string &name) const { + fst::MutexLock l(register_lock_); + typename std::map< string, FlagDescription<T> >::const_iterator it = + flag_table_.find(name); + return it != flag_table_.end() ? it->second : 0; + } + void SetDescription(const string &name, + const FlagDescription<T> &desc) { + fst::MutexLock l(register_lock_); + flag_table_.insert(make_pair(name, desc)); + } + + bool SetFlag(const string &val, bool *address) const { + if (val == "true" || val == "1" || val.empty()) { + *address = true; + return true; + } else if (val == "false" || val == "0") { + *address = false; + return true; + } + else { + return false; + } + } + bool SetFlag(const string &val, string *address) const { + *address = val; + return true; + } + bool SetFlag(const string &val, int32 *address) const { + char *p = 0; + *address = strtol(val.c_str(), &p, 0); + return !val.empty() && *p == '\0'; + } + bool SetFlag(const string &val, int64 *address) const { + char *p = 0; + *address = strtoll(val.c_str(), &p, 0); + return !val.empty() && *p == '\0'; + } + bool SetFlag(const string &val, double *address) const { + char *p = 0; + *address = strtod(val.c_str(), &p); + return !val.empty() && *p == '\0'; + } + + bool SetFlag(const string &arg, const string &val) const { + for (typename std::map< string, FlagDescription<T> >::const_iterator it = + flag_table_.begin(); + it != flag_table_.end(); + ++it) { + const string &name = it->first; + const FlagDescription<T> &desc = it->second; + if (arg == name) + return SetFlag(val, desc.address); + } + return false; + } + + void GetUsage(std::set< std::pair<string, string> > *usage_set) const { + for (typename std::map< string, + FlagDescription<T> >::const_iterator it = + flag_table_.begin(); + it != flag_table_.end(); + ++it) { + const string &name = it->first; + const FlagDescription<T> &desc = it->second; + string usage = " --" + name; + usage += ": type = "; + usage += desc.type_name; + usage += ", default = "; + usage += GetDefault(desc.default_value) + "\n "; + usage += desc.doc_string; + usage_set->insert(make_pair(desc.file_name, usage)); + } + } + + private: + static void Init() { + register_lock_ = new fst::Mutex; + register_ = new FlagRegister<T>; + } + + std::map< string, FlagDescription<T> > flag_table_; + + string GetDefault(bool default_value) const { + return default_value ? "true" : "false"; + } + + string GetDefault(const string &default_value) const { + return "\"" + default_value + "\""; + } + + template<typename V> string GetDefault(const V& default_value) const { + std::ostringstream strm; + strm << default_value; + return strm.str(); + } + + static fst::FstOnceType register_init_; // ensures only called once + static fst::Mutex* register_lock_; // multithreading lock + static FlagRegister<T> *register_; +}; + +template <class T> +fst::FstOnceType FlagRegister<T>::register_init_ = fst::FST_ONCE_INIT; + +template <class T> +fst::Mutex *FlagRegister<T>::register_lock_ = 0; + +template <class T> +FlagRegister<T> *FlagRegister<T>::register_ = 0; + + +template <typename T> +class FlagRegisterer { + public: + FlagRegisterer(const string &name, const FlagDescription<T> &desc) { + FlagRegister<T> *registr = FlagRegister<T>::GetRegister(); + registr->SetDescription(name, desc); + } + + private: + DISALLOW_COPY_AND_ASSIGN(FlagRegisterer); +}; + + +#define DEFINE_VAR(type, name, value, doc) \ + type FLAGS_ ## name = value; \ + static FlagRegisterer<type> \ + name ## _flags_registerer(#name, FlagDescription<type>(&FLAGS_ ## name, \ + doc, \ + #type, \ + __FILE__, \ + value)) + +#define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc) +#define DEFINE_string(name, value, doc) \ + DEFINE_VAR(string, name, value, doc) +#define DEFINE_int32(name, value, doc) DEFINE_VAR(int32, name, value, doc) +#define DEFINE_int64(name, value, doc) DEFINE_VAR(int64, name, value, doc) +#define DEFINE_double(name, value, doc) DEFINE_VAR(double, name, value, doc) + + +// Temporary directory +DECLARE_string(tmpdir); + +void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags, + const char *src = ""); + +#define SET_FLAGS(usage, argc, argv, rmflags) \ +SetFlags(usage, argc, argv, rmflags, __FILE__) + +// Deprecated - for backward compatibility +inline void InitFst(const char *usage, int *argc, char ***argv, bool rmflags) { + return SetFlags(usage, argc, argv, rmflags); +} + +void ShowUsage(bool long_usage = true); + +#endif // FST_LIB_FLAGS_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/float-weight.h b/kaldi_io/src/tools/openfst/include/fst/float-weight.h new file mode 100644 index 0000000..eb22638 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/float-weight.h @@ -0,0 +1,601 @@ +// float-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Float weight set and associated semiring operation definitions. +// + +#ifndef FST_LIB_FLOAT_WEIGHT_H__ +#define FST_LIB_FLOAT_WEIGHT_H__ + +#include <limits> +#include <climits> +#include <sstream> +#include <string> + +#include <fst/util.h> +#include <fst/weight.h> + + +namespace fst { + +// numeric limits class +template <class T> +class FloatLimits { + public: + static const T PosInfinity() { + static const T pos_infinity = numeric_limits<T>::infinity(); + return pos_infinity; + } + + static const T NegInfinity() { + static const T neg_infinity = -PosInfinity(); + return neg_infinity; + } + + static const T NumberBad() { + static const T number_bad = numeric_limits<T>::quiet_NaN(); + return number_bad; + } + +}; + +// weight class to be templated on floating-points types +template <class T = float> +class FloatWeightTpl { + public: + FloatWeightTpl() {} + + FloatWeightTpl(T f) : value_(f) {} + + FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {} + + FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) { + value_ = w.value_; + return *this; + } + + istream &Read(istream &strm) { + return ReadType(strm, &value_); + } + + ostream &Write(ostream &strm) const { + return WriteType(strm, value_); + } + + size_t Hash() const { + union { + T f; + size_t s; + } u; + u.s = 0; + u.f = value_; + return u.s; + } + + const T &Value() const { return value_; } + + protected: + void SetValue(const T &f) { value_ = f; } + + inline static string GetPrecisionString() { + int64 size = sizeof(T); + if (size == sizeof(float)) return ""; + size *= CHAR_BIT; + + string result; + Int64ToStr(size, &result); + return result; + } + + private: + T value_; +}; + +// Single-precision float weight +typedef FloatWeightTpl<float> FloatWeight; + +template <class T> +inline bool operator==(const FloatWeightTpl<T> &w1, + const FloatWeightTpl<T> &w2) { + // Volatile qualifier thwarts over-aggressive compiler optimizations + // that lead to problems esp. with NaturalLess(). + volatile T v1 = w1.Value(); + volatile T v2 = w2.Value(); + return v1 == v2; +} + +inline bool operator==(const FloatWeightTpl<double> &w1, + const FloatWeightTpl<double> &w2) { + return operator==<double>(w1, w2); +} + +inline bool operator==(const FloatWeightTpl<float> &w1, + const FloatWeightTpl<float> &w2) { + return operator==<float>(w1, w2); +} + +template <class T> +inline bool operator!=(const FloatWeightTpl<T> &w1, + const FloatWeightTpl<T> &w2) { + return !(w1 == w2); +} + +inline bool operator!=(const FloatWeightTpl<double> &w1, + const FloatWeightTpl<double> &w2) { + return operator!=<double>(w1, w2); +} + +inline bool operator!=(const FloatWeightTpl<float> &w1, + const FloatWeightTpl<float> &w2) { + return operator!=<float>(w1, w2); +} + +template <class T> +inline bool ApproxEqual(const FloatWeightTpl<T> &w1, + const FloatWeightTpl<T> &w2, + float delta = kDelta) { + return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta; +} + +template <class T> +inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) { + if (w.Value() == FloatLimits<T>::PosInfinity()) + return strm << "Infinity"; + else if (w.Value() == FloatLimits<T>::NegInfinity()) + return strm << "-Infinity"; + else if (w.Value() != w.Value()) // Fails for NaN + return strm << "BadNumber"; + else + return strm << w.Value(); +} + +template <class T> +inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) { + string s; + strm >> s; + if (s == "Infinity") { + w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity()); + } else if (s == "-Infinity") { + w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity()); + } else { + char *p; + T f = strtod(s.c_str(), &p); + if (p < s.c_str() + s.size()) + strm.clear(std::ios::badbit); + else + w = FloatWeightTpl<T>(f); + } + return strm; +} + + +// Tropical semiring: (min, +, inf, 0) +template <class T> +class TropicalWeightTpl : public FloatWeightTpl<T> { + public: + using FloatWeightTpl<T>::Value; + + typedef TropicalWeightTpl<T> ReverseWeight; + + TropicalWeightTpl() : FloatWeightTpl<T>() {} + + TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {} + + TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} + + static const TropicalWeightTpl<T> Zero() { + return TropicalWeightTpl<T>(FloatLimits<T>::PosInfinity()); } + + static const TropicalWeightTpl<T> One() { + return TropicalWeightTpl<T>(0.0F); } + + static const TropicalWeightTpl<T> NoWeight() { + return TropicalWeightTpl<T>(FloatLimits<T>::NumberBad()); } + + static const string &Type() { + static const string type = "tropical" + + FloatWeightTpl<T>::GetPrecisionString(); + return type; + } + + bool Member() const { + // First part fails for IEEE NaN + return Value() == Value() && Value() != FloatLimits<T>::NegInfinity(); + } + + TropicalWeightTpl<T> Quantize(float delta = kDelta) const { + if (Value() == FloatLimits<T>::NegInfinity() || + Value() == FloatLimits<T>::PosInfinity() || + Value() != Value()) + return *this; + else + return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); + } + + TropicalWeightTpl<T> Reverse() const { return *this; } + + static uint64 Properties() { + return kLeftSemiring | kRightSemiring | kCommutative | + kPath | kIdempotent; + } +}; + +// Single precision tropical weight +typedef TropicalWeightTpl<float> TropicalWeight; + +template <class T> +inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1, + const TropicalWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return TropicalWeightTpl<T>::NoWeight(); + return w1.Value() < w2.Value() ? w1 : w2; +} + +inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1, + const TropicalWeightTpl<float> &w2) { + return Plus<float>(w1, w2); +} + +inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1, + const TropicalWeightTpl<double> &w2) { + return Plus<double>(w1, w2); +} + +template <class T> +inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1, + const TropicalWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return TropicalWeightTpl<T>::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f1 == FloatLimits<T>::PosInfinity()) + return w1; + else if (f2 == FloatLimits<T>::PosInfinity()) + return w2; + else + return TropicalWeightTpl<T>(f1 + f2); +} + +inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1, + const TropicalWeightTpl<float> &w2) { + return Times<float>(w1, w2); +} + +inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1, + const TropicalWeightTpl<double> &w2) { + return Times<double>(w1, w2); +} + +template <class T> +inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1, + const TropicalWeightTpl<T> &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return TropicalWeightTpl<T>::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f2 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::NumberBad(); + else if (f1 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::PosInfinity(); + else + return TropicalWeightTpl<T>(f1 - f2); +} + +inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1, + const TropicalWeightTpl<float> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<float>(w1, w2, typ); +} + +inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1, + const TropicalWeightTpl<double> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<double>(w1, w2, typ); +} + + +// Log semiring: (log(e^-x + e^y), +, inf, 0) +template <class T> +class LogWeightTpl : public FloatWeightTpl<T> { + public: + using FloatWeightTpl<T>::Value; + + typedef LogWeightTpl ReverseWeight; + + LogWeightTpl() : FloatWeightTpl<T>() {} + + LogWeightTpl(T f) : FloatWeightTpl<T>(f) {} + + LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} + + static const LogWeightTpl<T> Zero() { + return LogWeightTpl<T>(FloatLimits<T>::PosInfinity()); + } + + static const LogWeightTpl<T> One() { + return LogWeightTpl<T>(0.0F); + } + + static const LogWeightTpl<T> NoWeight() { + return LogWeightTpl<T>(FloatLimits<T>::NumberBad()); } + + static const string &Type() { + static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString(); + return type; + } + + bool Member() const { + // First part fails for IEEE NaN + return Value() == Value() && Value() != FloatLimits<T>::NegInfinity(); + } + + LogWeightTpl<T> Quantize(float delta = kDelta) const { + if (Value() == FloatLimits<T>::NegInfinity() || + Value() == FloatLimits<T>::PosInfinity() || + Value() != Value()) + return *this; + else + return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); + } + + LogWeightTpl<T> Reverse() const { return *this; } + + static uint64 Properties() { + return kLeftSemiring | kRightSemiring | kCommutative; + } +}; + +// Single-precision log weight +typedef LogWeightTpl<float> LogWeight; +// Double-precision log weight +typedef LogWeightTpl<double> Log64Weight; + +template <class T> +inline T LogExp(T x) { return log(1.0F + exp(-x)); } + +template <class T> +inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1, + const LogWeightTpl<T> &w2) { + T f1 = w1.Value(), f2 = w2.Value(); + if (f1 == FloatLimits<T>::PosInfinity()) + return w2; + else if (f2 == FloatLimits<T>::PosInfinity()) + return w1; + else if (f1 > f2) + return LogWeightTpl<T>(f2 - LogExp(f1 - f2)); + else + return LogWeightTpl<T>(f1 - LogExp(f2 - f1)); +} + +inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1, + const LogWeightTpl<float> &w2) { + return Plus<float>(w1, w2); +} + +inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1, + const LogWeightTpl<double> &w2) { + return Plus<double>(w1, w2); +} + +template <class T> +inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1, + const LogWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return LogWeightTpl<T>::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f1 == FloatLimits<T>::PosInfinity()) + return w1; + else if (f2 == FloatLimits<T>::PosInfinity()) + return w2; + else + return LogWeightTpl<T>(f1 + f2); +} + +inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1, + const LogWeightTpl<float> &w2) { + return Times<float>(w1, w2); +} + +inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1, + const LogWeightTpl<double> &w2) { + return Times<double>(w1, w2); +} + +template <class T> +inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1, + const LogWeightTpl<T> &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return LogWeightTpl<T>::NoWeight(); + T f1 = w1.Value(), f2 = w2.Value(); + if (f2 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::NumberBad(); + else if (f1 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::PosInfinity(); + else + return LogWeightTpl<T>(f1 - f2); +} + +inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1, + const LogWeightTpl<float> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<float>(w1, w2, typ); +} + +inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1, + const LogWeightTpl<double> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<double>(w1, w2, typ); +} + +// MinMax semiring: (min, max, inf, -inf) +template <class T> +class MinMaxWeightTpl : public FloatWeightTpl<T> { + public: + using FloatWeightTpl<T>::Value; + + typedef MinMaxWeightTpl<T> ReverseWeight; + + MinMaxWeightTpl() : FloatWeightTpl<T>() {} + + MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {} + + MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} + + static const MinMaxWeightTpl<T> Zero() { + return MinMaxWeightTpl<T>(FloatLimits<T>::PosInfinity()); + } + + static const MinMaxWeightTpl<T> One() { + return MinMaxWeightTpl<T>(FloatLimits<T>::NegInfinity()); + } + + static const MinMaxWeightTpl<T> NoWeight() { + return MinMaxWeightTpl<T>(FloatLimits<T>::NumberBad()); } + + static const string &Type() { + static const string type = "minmax" + + FloatWeightTpl<T>::GetPrecisionString(); + return type; + } + + bool Member() const { + // Fails for IEEE NaN + return Value() == Value(); + } + + MinMaxWeightTpl<T> Quantize(float delta = kDelta) const { + // If one of infinities, or a NaN + if (Value() == FloatLimits<T>::NegInfinity() || + Value() == FloatLimits<T>::PosInfinity() || + Value() != Value()) + return *this; + else + return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); + } + + MinMaxWeightTpl<T> Reverse() const { return *this; } + + static uint64 Properties() { + return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath; + } +}; + +// Single-precision min-max weight +typedef MinMaxWeightTpl<float> MinMaxWeight; + +// Min +template <class T> +inline MinMaxWeightTpl<T> Plus( + const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return MinMaxWeightTpl<T>::NoWeight(); + return w1.Value() < w2.Value() ? w1 : w2; +} + +inline MinMaxWeightTpl<float> Plus( + const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { + return Plus<float>(w1, w2); +} + +inline MinMaxWeightTpl<double> Plus( + const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { + return Plus<double>(w1, w2); +} + +// Max +template <class T> +inline MinMaxWeightTpl<T> Times( + const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return MinMaxWeightTpl<T>::NoWeight(); + return w1.Value() >= w2.Value() ? w1 : w2; +} + +inline MinMaxWeightTpl<float> Times( + const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { + return Times<float>(w1, w2); +} + +inline MinMaxWeightTpl<double> Times( + const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { + return Times<double>(w1, w2); +} + +// Defined only for special cases +template <class T> +inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1, + const MinMaxWeightTpl<T> &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return MinMaxWeightTpl<T>::NoWeight(); + // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2 + return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::NumberBad(); +} + +inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1, + const MinMaxWeightTpl<float> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<float>(w1, w2, typ); +} + +inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1, + const MinMaxWeightTpl<double> &w2, + DivideType typ = DIVIDE_ANY) { + return Divide<double>(w1, w2, typ); +} + +// +// WEIGHT CONVERTER SPECIALIZATIONS. +// + +// Convert to tropical +template <> +struct WeightConvert<LogWeight, TropicalWeight> { + TropicalWeight operator()(LogWeight w) const { return w.Value(); } +}; + +template <> +struct WeightConvert<Log64Weight, TropicalWeight> { + TropicalWeight operator()(Log64Weight w) const { return w.Value(); } +}; + +// Convert to log +template <> +struct WeightConvert<TropicalWeight, LogWeight> { + LogWeight operator()(TropicalWeight w) const { return w.Value(); } +}; + +template <> +struct WeightConvert<Log64Weight, LogWeight> { + LogWeight operator()(Log64Weight w) const { return w.Value(); } +}; + +// Convert to log64 +template <> +struct WeightConvert<TropicalWeight, Log64Weight> { + Log64Weight operator()(TropicalWeight w) const { return w.Value(); } +}; + +template <> +struct WeightConvert<LogWeight, Log64Weight> { + Log64Weight operator()(LogWeight w) const { return w.Value(); } +}; + +} // namespace fst + +#endif // FST_LIB_FLOAT_WEIGHT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/fst-decl.h b/kaldi_io/src/tools/openfst/include/fst/fst-decl.h new file mode 100644 index 0000000..f27ded8 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/fst-decl.h @@ -0,0 +1,124 @@ +// fst-decl.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// This file contains declarations of classes in the Fst template library. +// + +#ifndef FST_LIB_FST_DECL_H__ +#define FST_LIB_FST_DECL_H__ + +#include <fst/types.h> + +namespace fst { + +class SymbolTable; +class SymbolTableIterator; + +template <class W> class FloatWeightTpl; +template <class W> class TropicalWeightTpl; +template <class W> class LogWeightTpl; +template <class W> class MinMaxWeightTpl; + +typedef FloatWeightTpl<float> FloatWeight; +typedef TropicalWeightTpl<float> TropicalWeight; +typedef LogWeightTpl<float> LogWeight; +typedef MinMaxWeightTpl<float> MinMaxWeight; + +template <class W> class ArcTpl; +typedef ArcTpl<TropicalWeight> StdArc; +typedef ArcTpl<LogWeight> LogArc; + +template <class A, class C, class U = uint32> class CompactFst; +template <class A, class U = uint32> class ConstFst; +template <class A, class W, class M> class EditFst; +template <class A> class ExpandedFst; +template <class A> class Fst; +template <class A> class MutableFst; +template <class A> class VectorFst; + +template <class A, class C> class ArcSortFst; +template <class A> class ClosureFst; +template <class A> class ComposeFst; +template <class A> class ConcatFst; +template <class A> class DeterminizeFst; +template <class A> class DifferenceFst; +template <class A> class IntersectFst; +template <class A> class InvertFst; +template <class A, class B, class C> class ArcMapFst; +template <class A> class ProjectFst; +template <class A, class B, class S> class RandGenFst; +template <class A> class RelabelFst; +template <class A, class T> class ReplaceFst; +template <class A> class RmEpsilonFst; +template <class A> class UnionFst; + +template <class T, class Compare, bool max> class Heap; + +template <class A> class AcceptorCompactor; +template <class A> class StringCompactor; +template <class A> class UnweightedAcceptorCompactor; +template <class A> class UnweightedCompactor; +template <class A> class WeightedStringCompactor; + +template <class A, class P> class DefaultReplaceStateTable; + +typedef CompactFst<StdArc, AcceptorCompactor<StdArc> > +StdCompactAcceptorFst; +typedef CompactFst< StdArc, StringCompactor<StdArc> > +StdCompactStringFst; +typedef CompactFst<StdArc, UnweightedAcceptorCompactor<StdArc> > +StdCompactUnweightedAcceptorFst; +typedef CompactFst<StdArc, UnweightedCompactor<StdArc> > +StdCompactUnweightedFst; +typedef CompactFst< StdArc, WeightedStringCompactor<StdArc> > +StdCompactWeightedStringFst; +typedef ConstFst<StdArc> StdConstFst; +typedef ExpandedFst<StdArc> StdExpandedFst; +typedef Fst<StdArc> StdFst; +typedef MutableFst<StdArc> StdMutableFst; +typedef VectorFst<StdArc> StdVectorFst; + + +template <class C> class StdArcSortFst; +typedef ClosureFst<StdArc> StdClosureFst; +typedef ComposeFst<StdArc> StdComposeFst; +typedef ConcatFst<StdArc> StdConcatFst; +typedef DeterminizeFst<StdArc> StdDeterminizeFst; +typedef DifferenceFst<StdArc> StdDifferenceFst; +typedef IntersectFst<StdArc> StdIntersectFst; +typedef InvertFst<StdArc> StdInvertFst; +typedef ProjectFst<StdArc> StdProjectFst; +typedef RelabelFst<StdArc> StdRelabelFst; +typedef ReplaceFst<StdArc, DefaultReplaceStateTable<StdArc, ssize_t> > +StdReplaceFst; +typedef RmEpsilonFst<StdArc> StdRmEpsilonFst; +typedef UnionFst<StdArc> StdUnionFst; + +template <typename T> class IntegerFilterState; +typedef IntegerFilterState<signed char> CharFilterState; +typedef IntegerFilterState<short> ShortFilterState; +typedef IntegerFilterState<int> IntFilterState; + +template <class F> class Matcher; +template <class M1, class M2 = M1> class SequenceComposeFilter; +template <class M1, class M2 = M1> class AltSequenceComposeFilter; +template <class M1, class M2 = M1> class MatchComposeFilter; + +} // namespace fst + +#endif // FST_LIB_FST_DECL_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/fst.h b/kaldi_io/src/tools/openfst/include/fst/fst.h new file mode 100644 index 0000000..150fc4e --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/fst.h @@ -0,0 +1,949 @@ +// fst.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Finite-State Transducer (FST) - abstract base class definition, +// state and arc iterator interface, and suggested base implementation. +// + +#ifndef FST_LIB_FST_H__ +#define FST_LIB_FST_H__ + +#include <stddef.h> +#include <sys/types.h> +#include <cmath> +#include <string> + +#include <fst/compat.h> +#include <fst/types.h> + +#include <fst/arc.h> +#include <fst/properties.h> +#include <fst/register.h> +#include <iostream> +#include <fstream> +#include <sstream> +#include <fst/symbol-table.h> +#include <fst/util.h> + + +DECLARE_bool(fst_align); + +namespace fst { + +bool IsFstHeader(istream &, const string &); + +class FstHeader; +template <class A> class StateIteratorData; +template <class A> class ArcIteratorData; +template <class A> class MatcherBase; + +struct FstReadOptions { + // FileReadMode(s) are advisory, there are many conditions than prevent a + // file from being mapped, READ mode will be selected in these cases with + // a warning indicating why it was chosen. + enum FileReadMode { READ, MAP }; + + string source; // Where you're reading from + const FstHeader *header; // Pointer to Fst header. If non-zero, use + // this info (don't read a stream header) + const SymbolTable* isymbols; // Pointer to input symbols. If non-zero, use + // this info (read and skip stream isymbols) + const SymbolTable* osymbols; // Pointer to output symbols. If non-zero, use + // this info (read and skip stream osymbols) + FileReadMode mode; // Read or map files (advisory, if possible) + + explicit FstReadOptions(const string& src = "<unspecified>", + const FstHeader *hdr = 0, + const SymbolTable* isym = 0, + const SymbolTable* osym = 0); + + explicit FstReadOptions(const string& src, + const SymbolTable* isym, + const SymbolTable* osym = 0); + + // Helper function to convert strings FileReadModes into their enum value. + static FileReadMode ReadMode(const string &mode); +}; + +struct FstWriteOptions { + string source; // Where you're writing to + bool write_header; // Write the header? + bool write_isymbols; // Write input symbols? + bool write_osymbols; // Write output symbols? + bool align; // Write data aligned where appropriate; + // this may fail on pipes + + explicit FstWriteOptions(const string& src = "<unspecifed>", + bool hdr = true, bool isym = true, + bool osym = true, bool alig = FLAGS_fst_align) + : source(src), write_header(hdr), + write_isymbols(isym), write_osymbols(osym), align(alig) {} +}; + +// +// Fst HEADER CLASS +// +// This is the recommended Fst file header representation. +// +class FstHeader { + public: + enum { + HAS_ISYMBOLS = 0x1, // Has input symbol table + HAS_OSYMBOLS = 0x2, // Has output symbol table + IS_ALIGNED = 0x4, // Memory-aligned (where appropriate) + } Flags; + + FstHeader() : version_(0), flags_(0), properties_(0), start_(-1), + numstates_(0), numarcs_(0) {} + const string &FstType() const { return fsttype_; } + const string &ArcType() const { return arctype_; } + int32 Version() const { return version_; } + int32 GetFlags() const { return flags_; } + uint64 Properties() const { return properties_; } + int64 Start() const { return start_; } + int64 NumStates() const { return numstates_; } + int64 NumArcs() const { return numarcs_; } + + void SetFstType(const string& type) { fsttype_ = type; } + void SetArcType(const string& type) { arctype_ = type; } + void SetVersion(int32 version) { version_ = version; } + void SetFlags(int32 flags) { flags_ = flags; } + void SetProperties(uint64 properties) { properties_ = properties; } + void SetStart(int64 start) { start_ = start; } + void SetNumStates(int64 numstates) { numstates_ = numstates; } + void SetNumArcs(int64 numarcs) { numarcs_ = numarcs; } + + bool Read(istream &strm, const string &source, bool rewind = false); + bool Write(ostream &strm, const string &source) const; + + private: + + string fsttype_; // E.g. "vector" + string arctype_; // E.g. "standard" + int32 version_; // Type version # + int32 flags_; // File format bits + uint64 properties_; // FST property bits + int64 start_; // Start state + int64 numstates_; // # of states + int64 numarcs_; // # of arcs +}; + + +// Specifies matcher action. +enum MatchType { MATCH_INPUT, // Match input label. + MATCH_OUTPUT, // Match output label. + MATCH_BOTH, // Match input or output label. + MATCH_NONE, // Match nothing. + MATCH_UNKNOWN }; // Match type unknown. + +// +// Fst INTERFACE CLASS DEFINITION +// + +// A generic FST, templated on the arc definition, with +// common-demoninator methods (use StateIterator and ArcIterator to +// iterate over its states and arcs). +template <class A> +class Fst { + public: + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + virtual ~Fst() {} + + virtual StateId Start() const = 0; // Initial state + + virtual Weight Final(StateId) const = 0; // State's final weight + + virtual size_t NumArcs(StateId) const = 0; // State's arc count + + virtual size_t NumInputEpsilons(StateId) + const = 0; // State's input epsilon count + + virtual size_t NumOutputEpsilons(StateId) + const = 0; // State's output epsilon count + + // If test=false, return stored properties bits for mask (some poss. unknown) + // If test=true, return property bits for mask (computing o.w. unknown) + virtual uint64 Properties(uint64 mask, bool test) + const = 0; // Property bits + + virtual const string& Type() const = 0; // Fst type name + + // Get a copy of this Fst. The copying behaves as follows: + // + // (1) The copying is constant time if safe = false or if safe = true + // and is on an otherwise unaccessed Fst. + // + // (2) If safe = true, the copy is thread-safe in that the original + // and copy can be safely accessed (but not necessarily mutated) by + // separate threads. For some Fst types, 'Copy(true)' should only be + // called on an Fst that has not otherwise been accessed. Its behavior + // is undefined otherwise. + // + // (3) If a MutableFst is copied and then mutated, then the original is + // unmodified and vice versa (often by a copy-on-write on the initial + // mutation, which may not be constant time). + virtual Fst<A> *Copy(bool safe = false) const = 0; + + // Read an Fst from an input stream; returns NULL on error + static Fst<A> *Read(istream &strm, const FstReadOptions &opts) { + FstReadOptions ropts(opts); + FstHeader hdr; + if (ropts.header) + hdr = *opts.header; + else { + if (!hdr.Read(strm, opts.source)) + return 0; + ropts.header = &hdr; + } + FstRegister<A> *registr = FstRegister<A>::GetRegister(); + const typename FstRegister<A>::Reader reader = + registr->GetReader(hdr.FstType()); + if (!reader) { + LOG(ERROR) << "Fst::Read: Unknown FST type \"" << hdr.FstType() + << "\" (arc type = \"" << A::Type() + << "\"): " << ropts.source; + return 0; + } + return reader(strm, ropts); + }; + + // Read an Fst from a file; return NULL on error + // Empty filename reads from standard input + static Fst<A> *Read(const string &filename) { + if (!filename.empty()) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "Fst::Read: Can't open file: " << filename; + return 0; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(cin, FstReadOptions("standard input")); + } + } + + // Write an Fst to an output stream; return false on error + virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { + LOG(ERROR) << "Fst::Write: No write stream method for " << Type() + << " Fst type"; + return false; + } + + // Write an Fst to a file; return false on error + // Empty filename writes to standard output + virtual bool Write(const string &filename) const { + LOG(ERROR) << "Fst::Write: No write filename method for " << Type() + << " Fst type"; + return false; + } + + // Return input label symbol table; return NULL if not specified + virtual const SymbolTable* InputSymbols() const = 0; + + // Return output label symbol table; return NULL if not specified + virtual const SymbolTable* OutputSymbols() const = 0; + + // For generic state iterator construction; not normally called + // directly by users. + virtual void InitStateIterator(StateIteratorData<A> *) const = 0; + + // For generic arc iterator construction; not normally called + // directly by users. + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *) const = 0; + + // For generic matcher construction; not normally called + // directly by users. + virtual MatcherBase<A> *InitMatcher(MatchType match_type) const; + + protected: + bool WriteFile(const string &filename) const { + if (!filename.empty()) { + ofstream strm(filename.c_str(), ofstream::out | ofstream::binary); + if (!strm) { + LOG(ERROR) << "Fst::Write: Can't open file: " << filename; + return false; + } + return Write(strm, FstWriteOptions(filename)); + } else { + return Write(cout, FstWriteOptions("standard output")); + } + } +}; + + +// +// STATE and ARC ITERATOR DEFINITIONS +// + +// State iterator interface templated on the Arc definition; used +// for StateIterator specializations returned by the InitStateIterator +// Fst method. +template <class A> +class StateIteratorBase { + public: + typedef A Arc; + typedef typename A::StateId StateId; + + virtual ~StateIteratorBase() {} + + bool Done() const { return Done_(); } // End of iterator? + StateId Value() const { return Value_(); } // Current state (when !Done) + void Next() { Next_(); } // Advance to next state (when !Done) + void Reset() { Reset_(); } // Return to initial condition + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual bool Done_() const = 0; + virtual StateId Value_() const = 0; + virtual void Next_() = 0; + virtual void Reset_() = 0; +}; + + +// StateIterator initialization data + +template <class A> struct StateIteratorData { + StateIteratorBase<A> *base; // Specialized iterator if non-zero + typename A::StateId nstates; // O.w. total # of states +}; + + +// Generic state iterator, templated on the FST definition +// - a wrapper around pointer to specific one. +// Here is a typical use: \code +// for (StateIterator<StdFst> siter(fst); +// !siter.Done(); +// siter.Next()) { +// StateId s = siter.Value(); +// ... +// } \endcode +template <class F> +class StateIterator { + public: + typedef F FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + + explicit StateIterator(const F &fst) : s_(0) { + fst.InitStateIterator(&data_); + } + + ~StateIterator() { if (data_.base) delete data_.base; } + + bool Done() const { + return data_.base ? data_.base->Done() : s_ >= data_.nstates; + } + + StateId Value() const { return data_.base ? data_.base->Value() : s_; } + + void Next() { + if (data_.base) + data_.base->Next(); + else + ++s_; + } + + void Reset() { + if (data_.base) + data_.base->Reset(); + else + s_ = 0; + } + + private: + StateIteratorData<Arc> data_; + StateId s_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + + +// Flags to control the behavior on an arc iterator: +static const uint32 kArcILabelValue = 0x0001; // Value() gives valid ilabel +static const uint32 kArcOLabelValue = 0x0002; // " " " olabel +static const uint32 kArcWeightValue = 0x0004; // " " " weight +static const uint32 kArcNextStateValue = 0x0008; // " " " nextstate +static const uint32 kArcNoCache = 0x0010; // No need to cache arcs + +static const uint32 kArcValueFlags = + kArcILabelValue | kArcOLabelValue | + kArcWeightValue | kArcNextStateValue; + +static const uint32 kArcFlags = kArcValueFlags | kArcNoCache; + + +// Arc iterator interface, templated on the Arc definition; used +// for Arc iterator specializations that are returned by the InitArcIterator +// Fst method. +template <class A> +class ArcIteratorBase { + public: + typedef A Arc; + typedef typename A::StateId StateId; + + virtual ~ArcIteratorBase() {} + + bool Done() const { return Done_(); } // End of iterator? + const A& Value() const { return Value_(); } // Current arc (when !Done) + void Next() { Next_(); } // Advance to next arc (when !Done) + size_t Position() const { return Position_(); } // Return current position + void Reset() { Reset_(); } // Return to initial condition + void Seek(size_t a) { Seek_(a); } // Random arc access by position + uint32 Flags() const { return Flags_(); } // Return current behavorial flags + void SetFlags(uint32 flags, uint32 mask) { // Set behavorial flags + SetFlags_(flags, mask); + } + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual bool Done_() const = 0; + virtual const A& Value_() const = 0; + virtual void Next_() = 0; + virtual size_t Position_() const = 0; + virtual void Reset_() = 0; + virtual void Seek_(size_t a) = 0; + virtual uint32 Flags_() const = 0; + virtual void SetFlags_(uint32 flags, uint32 mask) = 0; +}; + + +// ArcIterator initialization data +template <class A> struct ArcIteratorData { + ArcIteratorBase<A> *base; // Specialized iterator if non-zero + const A *arcs; // O.w. arcs pointer + size_t narcs; // ... and arc count + int *ref_count; // ... and reference count if non-zero +}; + + +// Generic arc iterator, templated on the FST definition +// - a wrapper around pointer to specific one. +// Here is a typical use: \code +// for (ArcIterator<StdFst> aiter(fst, s)); +// !aiter.Done(); +// aiter.Next()) { +// StdArc &arc = aiter.Value(); +// ... +// } \endcode +template <class F> +class ArcIterator { + public: + typedef F FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + + ArcIterator(const F &fst, StateId s) : i_(0) { + fst.InitArcIterator(s, &data_); + } + + explicit ArcIterator(const ArcIteratorData<Arc> &data) : data_(data), i_(0) { + if (data_.ref_count) + ++(*data_.ref_count); + } + + ~ArcIterator() { + if (data_.base) + delete data_.base; + else if (data_.ref_count) + --(*data_.ref_count); + } + + bool Done() const { + return data_.base ? data_.base->Done() : i_ >= data_.narcs; + } + + const Arc& Value() const { + return data_.base ? data_.base->Value() : data_.arcs[i_]; + } + + void Next() { + if (data_.base) + data_.base->Next(); + else + ++i_; + } + + void Reset() { + if (data_.base) + data_.base->Reset(); + else + i_ = 0; + } + + void Seek(size_t a) { + if (data_.base) + data_.base->Seek(a); + else + i_ = a; + } + + size_t Position() const { + return data_.base ? data_.base->Position() : i_; + } + + uint32 Flags() const { + if (data_.base) + return data_.base->Flags(); + else + return kArcValueFlags; + } + + void SetFlags(uint32 flags, uint32 mask) { + if (data_.base) + data_.base->SetFlags(flags, mask); + } + + private: + ArcIteratorData<Arc> data_; + size_t i_; + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +// +// MATCHER DEFINITIONS +// + +template <class A> +MatcherBase<A> *Fst<A>::InitMatcher(MatchType match_type) const { + return 0; // Use the default matcher +} + + +// +// FST ACCESSORS - Useful functions in high-performance cases. +// + +namespace internal { + +// General case - requires non-abstract, 'final' methods. Use for inlining. +template <class F> inline +typename F::Arc::Weight Final(const F &fst, typename F::Arc::StateId s) { + return fst.F::Final(s); +} + +template <class F> inline +ssize_t NumArcs(const F &fst, typename F::Arc::StateId s) { + return fst.F::NumArcs(s); +} + +template <class F> inline +ssize_t NumInputEpsilons(const F &fst, typename F::Arc::StateId s) { + return fst.F::NumInputEpsilons(s); +} + +template <class F> inline +ssize_t NumOutputEpsilons(const F &fst, typename F::Arc::StateId s) { + return fst.F::NumOutputEpsilons(s); +} + + +// Fst<A> case - abstract methods. +template <class A> inline +typename A::Weight Final(const Fst<A> &fst, typename A::StateId s) { + return fst.Final(s); +} + +template <class A> inline +ssize_t NumArcs(const Fst<A> &fst, typename A::StateId s) { + return fst.NumArcs(s); +} + +template <class A> inline +ssize_t NumInputEpsilons(const Fst<A> &fst, typename A::StateId s) { + return fst.NumInputEpsilons(s); +} + +template <class A> inline +ssize_t NumOutputEpsilons(const Fst<A> &fst, typename A::StateId s) { + return fst.NumOutputEpsilons(s); +} + +} // namespace internal + +// A useful alias when using StdArc. +typedef Fst<StdArc> StdFst; + + +// +// CONSTANT DEFINITIONS +// + +const int kNoStateId = -1; // Not a valid state ID +const int kNoLabel = -1; // Not a valid label + +// +// Fst IMPLEMENTATION BASE +// +// This is the recommended Fst implementation base class. It will +// handle reference counts, property bits, type information and symbols. +// + +template <class A> class FstImpl { + public: + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + FstImpl() + : properties_(0), type_("null"), isymbols_(0), osymbols_(0) {} + + FstImpl(const FstImpl<A> &impl) + : properties_(impl.properties_), type_(impl.type_), + isymbols_(impl.isymbols_ ? impl.isymbols_->Copy() : 0), + osymbols_(impl.osymbols_ ? impl.osymbols_->Copy() : 0) {} + + virtual ~FstImpl() { + delete isymbols_; + delete osymbols_; + } + + const string& Type() const { return type_; } + + void SetType(const string &type) { type_ = type; } + + virtual uint64 Properties() const { return properties_; } + + virtual uint64 Properties(uint64 mask) const { return properties_ & mask; } + + void SetProperties(uint64 props) { + properties_ &= kError; // kError can't be cleared + properties_ |= props; + } + + void SetProperties(uint64 props, uint64 mask) { + properties_ &= ~mask | kError; // kError can't be cleared + properties_ |= props & mask; + } + + // Allows (only) setting error bit on const FST impls + void SetProperties(uint64 props, uint64 mask) const { + if (mask != kError) + FSTERROR() << "FstImpl::SetProperties() const: can only set kError"; + properties_ |= kError; + } + + const SymbolTable* InputSymbols() const { return isymbols_; } + + const SymbolTable* OutputSymbols() const { return osymbols_; } + + SymbolTable* InputSymbols() { return isymbols_; } + + SymbolTable* OutputSymbols() { return osymbols_; } + + void SetInputSymbols(const SymbolTable* isyms) { + if (isymbols_) delete isymbols_; + isymbols_ = isyms ? isyms->Copy() : 0; + } + + void SetOutputSymbols(const SymbolTable* osyms) { + if (osymbols_) delete osymbols_; + osymbols_ = osyms ? osyms->Copy() : 0; + } + + int RefCount() const { + return ref_count_.count(); + } + + int IncrRefCount() { + return ref_count_.Incr(); + } + + int DecrRefCount() { + return ref_count_.Decr(); + } + + // Read-in header and symbols from input stream, initialize Fst, and + // return the header. If opts.header is non-null, skip read-in and + // use the option value. If opts.[io]symbols is non-null, read-in + // (if present), but use the option value. + bool ReadHeader(istream &strm, const FstReadOptions& opts, + int min_version, FstHeader *hdr); + + // Write-out header and symbols from output stream. + // If a opts.header is false, skip writing header. + // If opts.[io]symbols is false, skip writing those symbols. + // This method is needed for Impl's that implement Write methods. + void WriteHeader(ostream &strm, const FstWriteOptions& opts, + int version, FstHeader *hdr) const { + if (opts.write_header) { + hdr->SetFstType(type_); + hdr->SetArcType(A::Type()); + hdr->SetVersion(version); + hdr->SetProperties(properties_); + int32 file_flags = 0; + if (isymbols_ && opts.write_isymbols) + file_flags |= FstHeader::HAS_ISYMBOLS; + if (osymbols_ && opts.write_osymbols) + file_flags |= FstHeader::HAS_OSYMBOLS; + if (opts.align) + file_flags |= FstHeader::IS_ALIGNED; + hdr->SetFlags(file_flags); + hdr->Write(strm, opts.source); + } + if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm); + if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm); + } + + // Write-out header and symbols to output stream. + // If a opts.header is false, skip writing header. + // If opts.[io]symbols is false, skip writing those symbols. + // type is the Fst type being written. + // This method is used in the cross-type serialization methods Fst::WriteFst. + static void WriteFstHeader(const Fst<A> &fst, ostream &strm, + const FstWriteOptions& opts, int version, + const string &type, uint64 properties, + FstHeader *hdr) { + if (opts.write_header) { + hdr->SetFstType(type); + hdr->SetArcType(A::Type()); + hdr->SetVersion(version); + hdr->SetProperties(properties); + int32 file_flags = 0; + if (fst.InputSymbols() && opts.write_isymbols) + file_flags |= FstHeader::HAS_ISYMBOLS; + if (fst.OutputSymbols() && opts.write_osymbols) + file_flags |= FstHeader::HAS_OSYMBOLS; + if (opts.align) + file_flags |= FstHeader::IS_ALIGNED; + hdr->SetFlags(file_flags); + hdr->Write(strm, opts.source); + } + if (fst.InputSymbols() && opts.write_isymbols) { + fst.InputSymbols()->Write(strm); + } + if (fst.OutputSymbols() && opts.write_osymbols) { + fst.OutputSymbols()->Write(strm); + } + } + + // In serialization routines where the header cannot be written until after + // the machine has been serialized, this routine can be called to seek to + // the beginning of the file an rewrite the header with updated fields. + // It repositions the file pointer back at the end of the file. + // returns true on success, false on failure. + static bool UpdateFstHeader(const Fst<A> &fst, ostream &strm, + const FstWriteOptions& opts, int version, + const string &type, uint64 properties, + FstHeader *hdr, size_t header_offset) { + strm.seekp(header_offset); + if (!strm) { + LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source; + return false; + } + WriteFstHeader(fst, strm, opts, version, type, properties, hdr); + if (!strm) { + LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source; + return false; + } + strm.seekp(0, ios_base::end); + if (!strm) { + LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source; + return false; + } + return true; + } + + protected: + mutable uint64 properties_; // Property bits + + private: + string type_; // Unique name of Fst class + SymbolTable *isymbols_; // Ilabel symbol table + SymbolTable *osymbols_; // Olabel symbol table + RefCounter ref_count_; // Reference count + + void operator=(const FstImpl<A> &impl); // disallow +}; + +template <class A> inline +bool FstImpl<A>::ReadHeader(istream &strm, const FstReadOptions& opts, + int min_version, FstHeader *hdr) { + if (opts.header) + *hdr = *opts.header; + else if (!hdr->Read(strm, opts.source)) + return false; + + if (FLAGS_v >= 2) { + LOG(INFO) << "FstImpl::ReadHeader: source: " << opts.source + << ", fst_type: " << hdr->FstType() + << ", arc_type: " << A::Type() + << ", version: " << hdr->Version() + << ", flags: " << hdr->GetFlags(); + } + + if (hdr->FstType() != type_) { + LOG(ERROR) << "FstImpl::ReadHeader: Fst not of type \"" << type_ + << "\": " << opts.source; + return false; + } + if (hdr->ArcType() != A::Type()) { + LOG(ERROR) << "FstImpl::ReadHeader: Arc not of type \"" << A::Type() + << "\": " << opts.source; + return false; + } + if (hdr->Version() < min_version) { + LOG(ERROR) << "FstImpl::ReadHeader: Obsolete " << type_ + << " Fst version: " << opts.source; + return false; + } + properties_ = hdr->Properties(); + if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS) + isymbols_ = SymbolTable::Read(strm, opts.source); + if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS) + osymbols_ =SymbolTable::Read(strm, opts.source); + + if (opts.isymbols) { + delete isymbols_; + isymbols_ = opts.isymbols->Copy(); + } + if (opts.osymbols) { + delete osymbols_; + osymbols_ = opts.osymbols->Copy(); + } + return true; +} + + +template<class Arc> +uint64 TestProperties(const Fst<Arc> &fst, uint64 mask, uint64 *known); + + +// This is a helper class template useful for attaching an Fst interface to +// its implementation, handling reference counting. +template < class I, class F = Fst<typename I::Arc> > +class ImplToFst : public F { + public: + typedef typename I::Arc Arc; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + virtual ~ImplToFst() { if (!impl_->DecrRefCount()) delete impl_; } + + virtual StateId Start() const { return impl_->Start(); } + + virtual Weight Final(StateId s) const { return impl_->Final(s); } + + virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } + + virtual size_t NumInputEpsilons(StateId s) const { + return impl_->NumInputEpsilons(s); + } + + virtual size_t NumOutputEpsilons(StateId s) const { + return impl_->NumOutputEpsilons(s); + } + + virtual uint64 Properties(uint64 mask, bool test) const { + if (test) { + uint64 knownprops, testprops = TestProperties(*this, mask, &knownprops); + impl_->SetProperties(testprops, knownprops); + return testprops & mask; + } else { + return impl_->Properties(mask); + } + } + + virtual const string& Type() const { return impl_->Type(); } + + virtual const SymbolTable* InputSymbols() const { + return impl_->InputSymbols(); + } + + virtual const SymbolTable* OutputSymbols() const { + return impl_->OutputSymbols(); + } + + protected: + ImplToFst() : impl_(0) {} + + ImplToFst(I *impl) : impl_(impl) {} + + ImplToFst(const ImplToFst<I, F> &fst) { + impl_ = fst.impl_; + impl_->IncrRefCount(); + } + + // This constructor presumes there is a copy constructor for the + // implementation. + ImplToFst(const ImplToFst<I, F> &fst, bool safe) { + if (safe) { + impl_ = new I(*(fst.impl_)); + } else { + impl_ = fst.impl_; + impl_->IncrRefCount(); + } + } + + I *GetImpl() const { return impl_; } + + // Change Fst implementation pointer. If 'own_impl' is true, + // ownership of the input implementation is given to this + // object; otherwise, the input implementation's reference count + // should be incremented. + void SetImpl(I *impl, bool own_impl = true) { + if (!own_impl) + impl->IncrRefCount(); + if (impl_ && !impl_->DecrRefCount()) delete impl_; + impl_ = impl; + } + + private: + // Disallow + ImplToFst<I, F> &operator=(const ImplToFst<I, F> &fst); + + ImplToFst<I, F> &operator=(const Fst<Arc> &fst) { + FSTERROR() << "ImplToFst: Assignment operator disallowed"; + GetImpl()->SetProperties(kError, kError); + return *this; + } + + I *impl_; +}; + + +// Converts FSTs by casting their implementations, where this makes +// sense (which excludes implementations with weight-dependent virtual +// methods). Must be a friend of the Fst classes involved (currently +// the concrete Fsts: VectorFst, ConstFst, CompactFst). +template<class F, class G> void Cast(const F &ifst, G *ofst) { + ofst->SetImpl(reinterpret_cast<typename G::Impl *>(ifst.GetImpl()), false); +} + +// Fst Serialization +template <class A> +void FstToString(const Fst<A> &fst, string *result) { + ostringstream ostrm; + fst.Write(ostrm, FstWriteOptions("FstToString")); + *result = ostrm.str(); +} + +template <class A> +Fst<A> *StringToFst(const string &s) { + istringstream istrm(s); + return Fst<A>::Read(istrm, FstReadOptions("StringToFst")); +} + +} // namespace fst + +#endif // FST_LIB_FST_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/fstlib.h b/kaldi_io/src/tools/openfst/include/fst/fstlib.h new file mode 100644 index 0000000..de5976d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/fstlib.h @@ -0,0 +1,153 @@ +// fstlib.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \page FstLib FST - Weighted Finite State Transducers +// This is a library for constructing, combining, optimizing, and +// searching "weighted finite-state transducers" (FSTs). Weighted +// finite-state transducers are automata where each transition has an +// input label, an output label, and a weight. The more familiar +// finite-state acceptor is represented as a transducer with each +// transition's input and output the same. Finite-state acceptors +// are used to represent sets of strings (specifically, "regular" or +// "rational sets"); finite-state transducers are used to represent +// binary relations between pairs of strings (specifically, "rational +// transductions"). The weights can be used to represent the cost of +// taking a particular transition. +// +// In this library, the transducers are templated on the Arc +// (transition) definition, which allows changing the label, weight, +// and state ID sets. Labels and state IDs are restricted to signed +// integral types but the weight can be an arbitrary type whose +// members satisfy certain algebraic ("semiring") properties. +// +// For more information, see the FST Library Wiki page: +// http://wiki.corp.google.com/twiki/bin/view/Main/FstLibrary + +// \file +// This convenience file includes all other FST inl.h files. +// + +#ifndef FST_LIB_FSTLIB_H__ +#define FST_LIB_FSTLIB_H__ + + +// Abstract FST classes +#include <fst/fst.h> +#include <fst/expanded-fst.h> +#include <fst/mutable-fst.h> + +// Concrete FST classes +#include <fst/compact-fst.h> +#include <fst/const-fst.h> +#include <fst/edit-fst.h> +#include <fst/vector-fst.h> + +// FST algorithms and delayed FST classes +#include <fst/arcsort.h> +#include <fst/arc-map.h> +#include <fst/closure.h> +#include <fst/compose.h> +#include <fst/concat.h> +#include <fst/connect.h> +#include <fst/determinize.h> +#include <fst/difference.h> +#include <fst/encode.h> +#include <fst/epsnormalize.h> +#include <fst/equal.h> +#include <fst/equivalent.h> +#include <fst/factor-weight.h> +#include <fst/intersect.h> +#include <fst/invert.h> +#include <fst/map.h> +#include <fst/minimize.h> +#include <fst/project.h> +#include <fst/prune.h> +#include <fst/push.h> +#include <fst/randequivalent.h> +#include <fst/randgen.h> +#include <fst/rational.h> +#include <fst/relabel.h> +#include <fst/replace.h> +#include <fst/replace-util.h> +#include <fst/reverse.h> +#include <fst/reweight.h> +#include <fst/rmepsilon.h> +#include <fst/rmfinalepsilon.h> +#include <fst/shortest-distance.h> +#include <fst/shortest-path.h> +#include <fst/statesort.h> +#include <fst/state-map.h> +#include <fst/synchronize.h> +#include <fst/topsort.h> +#include <fst/union.h> +#include <fst/verify.h> +#include <fst/visit.h> + +// Weights +#include <fst/weight.h> +#include <fst/expectation-weight.h> +#include <fst/float-weight.h> +#include <fst/lexicographic-weight.h> +#include <fst/pair-weight.h> +#include <fst/power-weight.h> +#include <fst/product-weight.h> +#include <fst/random-weight.h> +#include <fst/signed-log-weight.h> +#include <fst/sparse-power-weight.h> +#include <fst/sparse-tuple-weight.h> +#include <fst/string-weight.h> +#include <fst/tuple-weight.h> + +// Auxiliary classes for composition +#include <fst/compose-filter.h> +#include <fst/lookahead-filter.h> +#include <fst/lookahead-matcher.h> +#include <fst/matcher-fst.h> +#include <fst/matcher.h> +#include <fst/state-table.h> + +// Data structures +#include <fst/heap.h> +#include <fst/interval-set.h> +#include <fst/queue.h> +#include <fst/union-find.h> + +// Miscellaneous +#include <fst/accumulator.h> +#include <fst/add-on.h> +#include <fst/arc.h> +#include <fst/arcfilter.h> +#include <fst/cache.h> +#include <fst/complement.h> +#include <fst/dfs-visit.h> +#include <fst/generic-register.h> +#include <fst/label-reachable.h> +#include <fst/partition.h> +#include <fst/properties.h> +#include <fst/register.h> +#include <fst/state-reachable.h> +#include <iostream> +#include <fstream> +#include <sstream> +#include <fst/string.h> +#include <fst/symbol-table.h> +#include <fst/symbol-table-ops.h> +#include <fst/test-properties.h> +#include <fst/util.h> + + +#endif // FST_LIB_FSTLIB_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/generic-register.h b/kaldi_io/src/tools/openfst/include/fst/generic-register.h new file mode 100644 index 0000000..4f8b512 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/generic-register.h @@ -0,0 +1,159 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_LIB_GENERIC_REGISTER_H_ +#define FST_LIB_GENERIC_REGISTER_H_ + +#include <map> +#include <string> + +#include <fst/compat.h> +#include <fst/types.h> + +// Generic class representing a globally-stored correspondence between +// objects of KeyType and EntryType. +// KeyType must: +// a) be such as can be stored as a key in a map<> +// b) be concatenable with a const char* with the + operator +// (or you must subclass and redefine LoadEntryFromSharedObject) +// EntryType must be default constructible. +// +// The third template parameter should be the type of a subclass of this class +// (think CRTP). This is to allow GetRegister() to instantiate and return +// an object of the appropriate type. + +namespace fst { + +template<class KeyType, class EntryType, class RegisterType> +class GenericRegister { + public: + typedef KeyType Key; + typedef EntryType Entry; + + static RegisterType *GetRegister() { + FstOnceInit(®ister_init_, + &RegisterType::Init); + + return register_; + } + + void SetEntry(const KeyType &key, + const EntryType &entry) { + MutexLock l(register_lock_); + + register_table_.insert(make_pair(key, entry)); + } + + EntryType GetEntry(const KeyType &key) const { + const EntryType *entry = LookupEntry(key); + if (entry) { + return *entry; + } else { + return LoadEntryFromSharedObject(key); + } + } + + virtual ~GenericRegister() { } + + protected: + // Override this if you want to be able to load missing definitions from + // shared object files. + virtual EntryType LoadEntryFromSharedObject(const KeyType &key) const { + string so_filename = ConvertKeyToSoFilename(key); + + void *handle = dlopen(so_filename.c_str(), RTLD_LAZY); + if (handle == 0) { + LOG(ERROR) << "GenericRegister::GetEntry : " << dlerror(); + return EntryType(); + } + + // We assume that the DSO constructs a static object in its global + // scope that does the registration. Thus we need only load it, not + // call any methods. + const EntryType *entry = this->LookupEntry(key); + if (entry == 0) { + LOG(ERROR) << "GenericRegister::GetEntry : " + << "lookup failed in shared object: " << so_filename; + return EntryType(); + } + return *entry; + } + + // Override this to define how to turn a key into an SO filename. + virtual string ConvertKeyToSoFilename(const KeyType& key) const = 0; + + virtual const EntryType *LookupEntry( + const KeyType &key) const { + MutexLock l(register_lock_); + + typename RegisterMapType::const_iterator it = register_table_.find(key); + + if (it != register_table_.end()) { + return &it->second; + } else { + return 0; + } + } + + private: + typedef map<KeyType, EntryType> RegisterMapType; + + static void Init() { + register_lock_ = new Mutex; + register_ = new RegisterType; + } + + static FstOnceType register_init_; + static Mutex *register_lock_; + static RegisterType *register_; + + RegisterMapType register_table_; +}; + +template<class KeyType, class EntryType, class RegisterType> +FstOnceType GenericRegister<KeyType, EntryType, + RegisterType>::register_init_ = FST_ONCE_INIT; + +template<class KeyType, class EntryType, class RegisterType> +Mutex *GenericRegister<KeyType, EntryType, RegisterType>::register_lock_ = 0; + +template<class KeyType, class EntryType, class RegisterType> +RegisterType *GenericRegister<KeyType, EntryType, RegisterType>::register_ = 0; + +// +// GENERIC REGISTRATION +// + +// Generic register-er class capable of creating new register entries in the +// given RegisterType template parameter. This type must define types Key +// and Entry, and have appropriate static GetRegister() and instance +// SetEntry() functions. An easy way to accomplish this is to have RegisterType +// be the type of a subclass of GenericRegister. +template<class RegisterType> +class GenericRegisterer { + public: + typedef typename RegisterType::Key Key; + typedef typename RegisterType::Entry Entry; + + GenericRegisterer(Key key, Entry entry) { + RegisterType *reg = RegisterType::GetRegister(); + reg->SetEntry(key, entry); + } +}; + +} // namespace fst + +#endif // FST_LIB_GENERIC_REGISTER_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/heap.h b/kaldi_io/src/tools/openfst/include/fst/heap.h new file mode 100644 index 0000000..a7affbd --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/heap.h @@ -0,0 +1,206 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// All Rights Reserved. +// Author: Johan Schalkwyk ([email protected]) +// +// \file +// Implementation of a heap as in STL, but allows tracking positions +// in heap using a key. The key can be used to do an in-place update of +// values in the heap. + +#ifndef FST_LIB_HEAP_H__ +#define FST_LIB_HEAP_H__ + +#include <vector> +using std::vector; +#include <functional> + +#include <fst/compat.h> +namespace fst { + +// +// \class Heap +// \brief A templated heap implementation that support in-place update +// of values. +// +// The templated heap implementation is a little different from the +// STL priority_queue and the *_heap operations in STL. This heap +// supports indexing of values in the heap via an associated key. +// +// Each value is internally associated with a key which is returned +// to the calling functions on heap insert. This key can be used +// to later update the specific value in the heap. +// +// \param T the element type of the hash, can be POD, Data or Ptr to Data +// \param Compare Comparison class for determiningg min-heapness. +// \param whether heap top should be max or min element w.r.t. Compare +// + +static const int kNoKey = -1; +template <class T, class Compare, bool max> +class Heap { + public: + + // Initialize with a specific comparator + Heap(Compare comp) : comp_(comp), size_(0) { } + + // Create a heap with initial size of internal arrays of 0 + Heap() : size_(0) { } + + ~Heap() { } + + // Insert a value into the heap + int Insert(const T& val) { + if (size_ < A_.size()) { + A_[size_] = val; + pos_[key_[size_]] = size_; + } else { + A_.push_back(val); + pos_.push_back(size_); + key_.push_back(size_); + } + + ++size_; + return Insert(val, size_ - 1); + } + + // Update a value at position given by the key. The pos array is first + // indexed by the key. The position gives the position in the heap array. + // Once we have the position we can then use the standard heap operations + // to calculate the parent and child positions. + void Update(int key, const T& val) { + int i = pos_[key]; + if (Better(val, A_[Parent(i)])) { + Insert(val, i); + } else { + A_[i] = val; + Heapify(i); + } + } + + // Return the greatest (max=true) / least (max=false) value w.r.t. + // from the heap. + T Pop() { + T top = A_[0]; + + Swap(0, size_-1); + size_--; + Heapify(0); + return top; + } + + // Return the greatest (max=true) / least (max=false) value w.r.t. + // comp object from the heap. + T Top() const { + return A_[0]; + } + + // Check if the heap is empty + bool Empty() const { + return size_ == 0; + } + + void Clear() { + size_ = 0; + } + + + // + // The following protected routines are used in a supportive role + // for managing the heap and keeping the heap properties. + // + private: + // Compute left child of parent + int Left(int i) { + return 2*(i+1)-1; // 0 -> 1, 1 -> 3 + } + + // Compute right child of parent + int Right(int i) { + return 2*(i+1); // 0 -> 2, 1 -> 4 + } + + // Given a child compute parent + int Parent(int i) { + return (i-1)/2; // 1 -> 0, 2 -> 0, 3 -> 1, 4-> 1 + } + + // Swap a child, parent. Use to move element up/down tree. + // Note a little tricky here. When we swap we need to swap: + // the value + // the associated keys + // the position of the value in the heap + void Swap(int j, int k) { + int tkey = key_[j]; + pos_[key_[j] = key_[k]] = j; + pos_[key_[k] = tkey] = k; + + T val = A_[j]; + A_[j] = A_[k]; + A_[k] = val; + } + + // Returns the greater (max=true) / least (max=false) of two + // elements. + bool Better(const T& x, const T& y) { + return max ? comp_(y, x) : comp_(x, y); + } + + // Heapify subtree rooted at index i. + void Heapify(int i) { + int l = Left(i); + int r = Right(i); + int largest; + + if (l < size_ && Better(A_[l], A_[i]) ) + largest = l; + else + largest = i; + + if (r < size_ && Better(A_[r], A_[largest]) ) + largest = r; + + if (largest != i) { + Swap(i, largest); + Heapify(largest); + } + } + + + // Insert (update) element at subtree rooted at index i + int Insert(const T& val, int i) { + int p; + while (i > 0 && !Better(A_[p = Parent(i)], val)) { + Swap(i, p); + i = p; + } + + return key_[i]; + } + + private: + Compare comp_; + + vector<int> pos_; + vector<int> key_; + vector<T> A_; + int size_; + + // DISALLOW_COPY_AND_ASSIGN(Heap); +}; + +} // namespace fst + +#endif // FST_LIB_HEAP_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/icu.h b/kaldi_io/src/tools/openfst/include/fst/icu.h new file mode 100644 index 0000000..3947716 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/icu.h @@ -0,0 +1,116 @@ +// icu.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jeffrey Sorensen) +// [email protected] (Fredrik Roubert) +// +// This library implements an unrestricted Thompson/Pike UTF-8 parser and +// serializer. UTF-8 is a restricted subset of this byte stream encoding. See +// http://en.wikipedia.org/wiki/UTF-8 for a good description of the encoding +// details. + +#ifndef FST_LIB_ICU_H_ +#define FST_LIB_ICU_H_ + +#include <iostream> +#include <fstream> +#include <sstream> + +namespace fst { + +template <class Label> +bool UTF8StringToLabels(const string &str, vector<Label> *labels) { + const char *data = str.data(); + size_t length = str.size(); + for (int i = 0; i < length; /* no update */) { + int c = data[i++] & 0xff; + if ((c & 0x80) == 0) { + labels->push_back(c); + } else { + if ((c & 0xc0) == 0x80) { + LOG(ERROR) << "UTF8StringToLabels: continuation byte as lead byte"; + return false; + } + int count = (c >= 0xc0) + (c >= 0xe0) + (c >= 0xf0) + (c >= 0xf8) + + (c >= 0xfc); + int code = c & ((1 << (6 - count)) - 1); + while (count != 0) { + if (i == length) { + LOG(ERROR) << "UTF8StringToLabels: truncated utf-8 byte sequence"; + return false; + } + char cb = data[i++]; + if ((cb & 0xc0) != 0x80) { + LOG(ERROR) << "UTF8StringToLabels: missing/invalid continuation byte"; + return false; + } + code = (code << 6) | (cb & 0x3f); + count--; + } + if (code < 0) { + // This should not be able to happen. + LOG(ERROR) << "UTF8StringToLabels: Invalid character found: " << c; + return false; + } + labels->push_back(code); + } + } + return true; +} + +template <class Label> +bool LabelsToUTF8String(const vector<Label> &labels, string *str) { + ostringstream ostr; + for (size_t i = 0; i < labels.size(); ++i) { + int32_t code = labels[i]; + if (code < 0) { + LOG(ERROR) << "LabelsToUTF8String: Invalid character found: " << code; + return false; + } else if (code < 0x80) { + ostr << static_cast<char>(code); + } else if (code < 0x800) { + ostr << static_cast<char>((code >> 6) | 0xc0); + ostr << static_cast<char>((code & 0x3f) | 0x80); + } else if (code < 0x10000) { + ostr << static_cast<char>((code >> 12) | 0xe0); + ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast<char>((code & 0x3f) | 0x80); + } else if (code < 0x200000) { + ostr << static_cast<char>((code >> 18) | 0xf0); + ostr << static_cast<char>(((code >> 12) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast<char>((code & 0x3f) | 0x80); + } else if (code < 0x4000000) { + ostr << static_cast<char>((code >> 24) | 0xf8); + ostr << static_cast<char>(((code >> 18) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 12) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast<char>((code & 0x3f) | 0x80); + } else { + ostr << static_cast<char>((code >> 30) | 0xfc); + ostr << static_cast<char>(((code >> 24) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 18) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 12) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast<char>((code & 0x3f) | 0x80); + } + } + *str = ostr.str(); + return true; +} + +} // namespace fst + +#endif // FST_LIB_ICU_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/intersect.h b/kaldi_io/src/tools/openfst/include/fst/intersect.h new file mode 100644 index 0000000..f46116f --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/intersect.h @@ -0,0 +1,172 @@ +// intersect.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to compute the intersection of two FSAs + +#ifndef FST_LIB_INTERSECT_H__ +#define FST_LIB_INTERSECT_H__ + +#include <algorithm> +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/compose.h> + + +namespace fst { + +template <class A, + class M = Matcher<Fst<A> >, + class F = SequenceComposeFilter<M>, + class T = GenericComposeStateTable<A, typename F::FilterState> > +struct IntersectFstOptions : public ComposeFstOptions<A, M, F, T> { + explicit IntersectFstOptions(const CacheOptions &opts, + M *mat1 = 0, M *mat2 = 0, + F *filt = 0, T *sttable= 0) + : ComposeFstOptions<A, M, F, T>(opts, mat1, mat2, filt, sttable) { } + + IntersectFstOptions() {} +}; + +// Computes the intersection (Hadamard product) of two FSAs. This +// version is a delayed Fst. Only strings that are in both automata +// are retained in the result. +// +// The two arguments must be acceptors. One of the arguments must be +// label-sorted. +// +// Complexity: same as ComposeFst. +// +// Caveats: same as ComposeFst. +template <class A> +class IntersectFst : public ComposeFst<A> { + public: + using ComposeFst<A>::CreateBase; + using ComposeFst<A>::CreateBase1; + using ComposeFst<A>::Properties; + using ImplToFst< ComposeFstImplBase<A> >::GetImpl; + using ImplToFst< ComposeFstImplBase<A> >::SetImpl; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + IntersectFst(const Fst<A> &fst1, const Fst<A> &fst2, + const CacheOptions opts = CacheOptions()) { + bool acceptors = fst1.Properties(kAcceptor, true) && + fst2.Properties(kAcceptor, true); + SetImpl(CreateBase(fst1, fst2, opts)); + if (!acceptors) { + FSTERROR() << "IntersectFst: input FSTs are not acceptors"; + GetImpl()->SetProperties(kError); + } + } + + template <class M, class F, class T> + IntersectFst(const Fst<A> &fst1, const Fst<A> &fst2, + const IntersectFstOptions<A, M, F, T> &opts) { + bool acceptors = fst1.Properties(kAcceptor, true) && + fst2.Properties(kAcceptor, true); + SetImpl(CreateBase1(fst1, fst2, opts)); + if (!acceptors) { + FSTERROR() << "IntersectFst: input FSTs are not acceptors"; + GetImpl()->SetProperties(kError); + } + } + + // See Fst<>::Copy() for doc. + IntersectFst(const IntersectFst<A> &fst, bool safe = false) : + ComposeFst<A>(fst, safe) {} + + // Get a copy of this IntersectFst. See Fst<>::Copy() for further doc. + virtual IntersectFst<A> *Copy(bool safe = false) const { + return new IntersectFst<A>(*this, safe); + } +}; + + +// Specialization for IntersectFst. +template <class A> +class StateIterator< IntersectFst<A> > + : public StateIterator< ComposeFst<A> > { + public: + explicit StateIterator(const IntersectFst<A> &fst) + : StateIterator< ComposeFst<A> >(fst) {} +}; + + +// Specialization for IntersectFst. +template <class A> +class ArcIterator< IntersectFst<A> > + : public ArcIterator< ComposeFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const IntersectFst<A> &fst, StateId s) + : ArcIterator< ComposeFst<A> >(fst, s) {} +}; + +// Useful alias when using StdArc. +typedef IntersectFst<StdArc> StdIntersectFst; + + +typedef ComposeOptions IntersectOptions; + + +// Computes the intersection (Hadamard product) of two FSAs. This +// version writes the intersection to an output MurableFst. Only +// strings that are in both automata are retained in the result. +// +// The two arguments must be acceptors. One of the arguments must be +// label-sorted. +// +// Complexity: same as Compose. +// +// Caveats: same as Compose. +template<class Arc> +void Intersect(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2, + MutableFst<Arc> *ofst, + const IntersectOptions &opts = IntersectOptions()) { + typedef Matcher< Fst<Arc> > M; + + if (opts.filter_type == AUTO_FILTER) { + CacheOptions nopts; + nopts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = IntersectFst<Arc>(ifst1, ifst2, nopts); + } else if (opts.filter_type == SEQUENCE_FILTER) { + IntersectFstOptions<Arc> iopts; + iopts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = IntersectFst<Arc>(ifst1, ifst2, iopts); + } else if (opts.filter_type == ALT_SEQUENCE_FILTER) { + IntersectFstOptions<Arc, M, AltSequenceComposeFilter<M> > iopts; + iopts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = IntersectFst<Arc>(ifst1, ifst2, iopts); + } else if (opts.filter_type == MATCH_FILTER) { + IntersectFstOptions<Arc, M, MatchComposeFilter<M> > iopts; + iopts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = IntersectFst<Arc>(ifst1, ifst2, iopts); + } + + if (opts.connect) + Connect(ofst); +} + +} // namespace fst + +#endif // FST_LIB_INTERSECT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/interval-set.h b/kaldi_io/src/tools/openfst/include/fst/interval-set.h new file mode 100644 index 0000000..58cad44 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/interval-set.h @@ -0,0 +1,381 @@ +// interval-set.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to represent and operate on sets of intervals. + +#ifndef FST_LIB_INTERVAL_SET_H__ +#define FST_LIB_INTERVAL_SET_H__ + +#include <iostream> +#include <vector> +using std::vector; + + +#include <fst/util.h> + + +namespace fst { + +// Stores and operates on a set of half-open integral intervals [a,b) +// of signed integers of type T. +template <typename T> +class IntervalSet { + public: + struct Interval { + T begin_; + T end_; + + Interval() : begin_(-1), end_(-1) {} + + Interval(T b, T e) : begin_(b), end_(e) {} + + bool operator<(const Interval &i) const { + return begin_ < i.begin_ || (begin_ == i.begin_ && end_ > i.end_); + } + + bool operator==(const Interval &i) const { + return begin_ == i.begin_ && end_ == i.end_; + } + + bool operator!=(const Interval &i) const { + return begin_ != i.begin_ || end_ != i.end_; + } + + istream &Read(istream &strm) { + T n; + ReadType(strm, &n); + begin_ = n; + ReadType(strm, &n); + end_ = n; + return strm; + } + + ostream &Write(ostream &strm) const { + T n = begin_; + WriteType(strm, n); + n = end_; + WriteType(strm, n); + return strm; + } + }; + + IntervalSet() : count_(-1) {} + + // Returns the interval set as a vector. + vector<Interval> *Intervals() { return &intervals_; } + + const vector<Interval> *Intervals() const { return &intervals_; } + + bool Empty() const { return intervals_.empty(); } + + T Size() const { return intervals_.size(); } + + // Number of points in the intervals (undefined if not normalized). + T Count() const { return count_; } + + void Clear() { + intervals_.clear(); + count_ = 0; + } + + // Adds an interval set to the set. The result may not be normalized. + void Union(const IntervalSet<T> &iset) { + const vector<Interval> *intervals = iset.Intervals(); + for (typename vector<Interval>::const_iterator it = intervals->begin(); + it != intervals->end(); ++it) + intervals_.push_back(*it); + } + + // Requires intervals be normalized. + bool Member(T value) const { + Interval interval(value, value); + typename vector<Interval>::const_iterator lb = + lower_bound(intervals_.begin(), intervals_.end(), interval); + if (lb == intervals_.begin()) + return false; + return (--lb)->end_ > value; + } + + // Requires intervals be normalized. + bool operator==(const IntervalSet<T>& iset) const { + return *(iset.Intervals()) == intervals_; + } + + // Requires intervals be normalized. + bool operator!=(const IntervalSet<T>& iset) const { + return *(iset.Intervals()) != intervals_; + } + + bool Singleton() const { + return intervals_.size() == 1 && + intervals_[0].begin_ + 1 == intervals_[0].end_; + } + + + // Sorts; collapses overlapping and adjacent interals; sets count. + void Normalize(); + + // Intersects an interval set with the set. Requires intervals be + // normalized. The result is normalized. + void Intersect(const IntervalSet<T> &iset, IntervalSet<T> *oset) const; + + // Complements the set w.r.t [0, maxval). Requires intervals be + // normalized. The result is normalized. + void Complement(T maxval, IntervalSet<T> *oset) const; + + // Subtract an interval set from the set. Requires intervals be + // normalized. The result is normalized. + void Difference(const IntervalSet<T> &iset, IntervalSet<T> *oset) const; + + // Determines if an interval set overlaps with the set. Requires + // intervals be normalized. + bool Overlaps(const IntervalSet<T> &iset) const; + + // Determines if an interval set overlaps with the set but neither + // is contained in the other. Requires intervals be normalized. + bool StrictlyOverlaps(const IntervalSet<T> &iset) const; + + // Determines if an interval set is contained within the set. Requires + // intervals be normalized. + bool Contains(const IntervalSet<T> &iset) const; + + istream &Read(istream &strm) { + ReadType(strm, &intervals_); + return ReadType(strm, &count_); + } + + ostream &Write(ostream &strm) const { + WriteType(strm, intervals_); + return WriteType(strm, count_); + } + + private: + vector<Interval> intervals_; + T count_; +}; + +// Sorts; collapses overlapping and adjacent interavls; sets count. +template <typename T> +void IntervalSet<T>::Normalize() { + sort(intervals_.begin(), intervals_.end()); + + count_ = 0; + T size = 0; + for (T i = 0; i < intervals_.size(); ++i) { + Interval &inti = intervals_[i]; + if (inti.begin_ == inti.end_) + continue; + for (T j = i + 1; j < intervals_.size(); ++j) { + Interval &intj = intervals_[j]; + if (intj.begin_ > inti.end_) + break; + if (intj.end_ > inti.end_) + inti.end_ = intj.end_; + ++i; + } + count_ += inti.end_ - inti.begin_; + intervals_[size++] = inti; + } + intervals_.resize(size); +} + +// Intersects an interval set with the set. Requires intervals be normalized. +// The result is normalized. +template <typename T> +void IntervalSet<T>::Intersect(const IntervalSet<T> &iset, + IntervalSet<T> *oset) const { + const vector<Interval> *iintervals = iset.Intervals(); + vector<Interval> *ointervals = oset->Intervals(); + typename vector<Interval>::const_iterator it1 = intervals_.begin(); + typename vector<Interval>::const_iterator it2 = iintervals->begin(); + + ointervals->clear(); + oset->count_ = 0; + + while (it1 != intervals_.end() && it2 != iintervals->end()) { + if (it1->end_ <= it2->begin_) { + ++it1; + } else if (it2->end_ <= it1->begin_) { + ++it2; + } else { + Interval interval; + interval.begin_ = max(it1->begin_, it2->begin_); + interval.end_ = min(it1->end_, it2->end_); + ointervals->push_back(interval); + oset->count_ += interval.end_ - interval.begin_; + if (it1->end_ < it2->end_) + ++it1; + else + ++it2; + } + } +} + +// Complements the set w.r.t [0, maxval). Requires intervals be normalized. +// The result is normalized. +template <typename T> +void IntervalSet<T>::Complement(T maxval, IntervalSet<T> *oset) const { + vector<Interval> *ointervals = oset->Intervals(); + ointervals->clear(); + oset->count_ = 0; + + Interval interval; + interval.begin_ = 0; + for (typename vector<Interval>::const_iterator it = intervals_.begin(); + it != intervals_.end(); + ++it) { + interval.end_ = min(it->begin_, maxval); + if (interval.begin_ < interval.end_) { + ointervals->push_back(interval); + oset->count_ += interval.end_ - interval.begin_; + } + interval.begin_ = it->end_; + } + interval.end_ = maxval; + if (interval.begin_ < interval.end_) { + ointervals->push_back(interval); + oset->count_ += interval.end_ - interval.begin_; + } +} + +// Subtract an interval set from the set. Requires intervals be normalized. +// The result is normalized. +template <typename T> +void IntervalSet<T>::Difference(const IntervalSet<T> &iset, + IntervalSet<T> *oset) const { + if (intervals_.empty()) { + oset->Intervals()->clear(); + oset->count_ = 0; + } else { + IntervalSet<T> cset; + iset.Complement(intervals_.back().end_, &cset); + Intersect(cset, oset); + } +} + +// Determines if an interval set overlaps with the set. Requires +// intervals be normalized. +template <typename T> +bool IntervalSet<T>::Overlaps(const IntervalSet<T> &iset) const { + const vector<Interval> *intervals = iset.Intervals(); + typename vector<Interval>::const_iterator it1 = intervals_.begin(); + typename vector<Interval>::const_iterator it2 = intervals->begin(); + + while (it1 != intervals_.end() && it2 != intervals->end()) { + if (it1->end_ <= it2->begin_) { + ++it1; + } else if (it2->end_ <= it1->begin_) { + ++it2; + } else { + return true; + } + } + return false; +} + +// Determines if an interval set overlaps with the set but neither +// is contained in the other. Requires intervals be normalized. +template <typename T> +bool IntervalSet<T>::StrictlyOverlaps(const IntervalSet<T> &iset) const { + const vector<Interval> *intervals = iset.Intervals(); + typename vector<Interval>::const_iterator it1 = intervals_.begin(); + typename vector<Interval>::const_iterator it2 = intervals->begin(); + bool only1 = false; // point in intervals_ but not intervals + bool only2 = false; // point in intervals but not intervals_ + bool overlap = false; // point in both intervals_ and intervals + + while (it1 != intervals_.end() && it2 != intervals->end()) { + if (it1->end_ <= it2->begin_) { // no overlap - it1 first + only1 = true; + ++it1; + } else if (it2->end_ <= it1->begin_) { // no overlap - it2 first + only2 = true; + ++it2; + } else if (it2->begin_ == it1->begin_ && it2->end_ == it1->end_) { // equals + overlap = true; + ++it1; + ++it2; + } else if (it2->begin_ <= it1->begin_ && it2->end_ >= it1->end_) { // 1 c 2 + only2 = true; + overlap = true; + ++it1; + } else if (it1->begin_ <= it2->begin_ && it1->end_ >= it2->end_) { // 2 c 1 + only1 = true; + overlap = true; + ++it2; + } else { // strict overlap + only1 = true; + only2 = true; + overlap = true; + } + if (only1 == true && only2 == true && overlap == true) + return true; + } + if (it1 != intervals_.end()) + only1 = true; + if (it2 != intervals->end()) + only2 = true; + + return only1 == true && only2 == true && overlap == true; +} + +// Determines if an interval set is contained within the set. Requires +// intervals be normalized. +template <typename T> +bool IntervalSet<T>::Contains(const IntervalSet<T> &iset) const { + if (iset.Count() > Count()) + return false; + + const vector<Interval> *intervals = iset.Intervals(); + typename vector<Interval>::const_iterator it1 = intervals_.begin(); + typename vector<Interval>::const_iterator it2 = intervals->begin(); + + while (it1 != intervals_.end() && it2 != intervals->end()) { + if (it1->end_ <= it2->begin_) { // no overlap - it1 first + ++it1; + } else if (it2->begin_ < it1->begin_ || it2->end_ > it1->end_) { // no C + return false; + } else if (it2->end_ == it1->end_) { + ++it1; + ++it2; + } else { + ++it2; + } + } + return it2 == intervals->end(); +} + +template <typename T> +ostream &operator<<(ostream &strm, const IntervalSet<T> &s) { + typedef typename IntervalSet<T>::Interval Interval; + const vector<Interval> *intervals = s.Intervals(); + strm << "{"; + for (typename vector<Interval>::const_iterator it = intervals->begin(); + it != intervals->end(); + ++it) { + if (it != intervals->begin()) + strm << ","; + strm << "[" << it->begin_ << "," << it->end_ << ")"; + } + strm << "}"; + return strm; +} + +} // namespace fst + +#endif // FST_LIB_INTERVAL_SET_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/invert.h b/kaldi_io/src/tools/openfst/include/fst/invert.h new file mode 100644 index 0000000..bc83a5d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/invert.h @@ -0,0 +1,125 @@ +// invert.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Functions and classes to invert an Fst. + +#ifndef FST_LIB_INVERT_H__ +#define FST_LIB_INVERT_H__ + +#include <fst/arc-map.h> +#include <fst/mutable-fst.h> + + +namespace fst { + +// Mapper to implement inversion of an arc. +template <class A> struct InvertMapper { + InvertMapper() {} + + A operator()(const A &arc) { + return A(arc.olabel, arc.ilabel, arc.weight, arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;} + + uint64 Properties(uint64 props) { return InvertProperties(props); } +}; + + +// Inverts the transduction corresponding to an FST by exchanging the +// FST's input and output labels. This version modifies its input. +// +// Complexity: +// - Time: O(V + E) +// - Space: O(1) +// where V = # of states and E = # of arcs. +template<class Arc> inline +void Invert(MutableFst<Arc> *fst) { + SymbolTable *input = fst->InputSymbols() ? fst->InputSymbols()->Copy() : 0; + SymbolTable *output = fst->OutputSymbols() ? fst->OutputSymbols()->Copy() : 0; + ArcMap(fst, InvertMapper<Arc>()); + fst->SetInputSymbols(output); + fst->SetOutputSymbols(input); + delete input; + delete output; +} + + +// Inverts the transduction corresponding to an FST by exchanging the +// FST's input and output labels. This version is a delayed Fst. +// +// Complexity: +// - Time: O(v + e) +// - Space: O(1) +// where v = # of states visited, e = # of arcs visited. Constant +// time and to visit an input state or arc is assumed and exclusive +// of caching. +template <class A> +class InvertFst : public ArcMapFst<A, A, InvertMapper<A> > { + public: + typedef A Arc; + typedef InvertMapper<A> C; + typedef ArcMapFstImpl< A, A, InvertMapper<A> > Impl; + using ImplToFst<Impl>::GetImpl; + + explicit InvertFst(const Fst<A> &fst) : ArcMapFst<A, A, C>(fst, C()) { + GetImpl()->SetOutputSymbols(fst.InputSymbols()); + GetImpl()->SetInputSymbols(fst.OutputSymbols()); + } + + // See Fst<>::Copy() for doc. + InvertFst(const InvertFst<A> &fst, bool safe = false) + : ArcMapFst<A, A, C>(fst, safe) {} + + // Get a copy of this InvertFst. See Fst<>::Copy() for further doc. + virtual InvertFst<A> *Copy(bool safe = false) const { + return new InvertFst(*this, safe); + } +}; + + +// Specialization for InvertFst. +template <class A> +class StateIterator< InvertFst<A> > + : public StateIterator< ArcMapFst<A, A, InvertMapper<A> > > { + public: + explicit StateIterator(const InvertFst<A> &fst) + : StateIterator< ArcMapFst<A, A, InvertMapper<A> > >(fst) {} +}; + + +// Specialization for InvertFst. +template <class A> +class ArcIterator< InvertFst<A> > + : public ArcIterator< ArcMapFst<A, A, InvertMapper<A> > > { + public: + ArcIterator(const InvertFst<A> &fst, typename A::StateId s) + : ArcIterator< ArcMapFst<A, A, InvertMapper<A> > >(fst, s) {} +}; + + +// Useful alias when using StdArc. +typedef InvertFst<StdArc> StdInvertFst; + +} // namespace fst + +#endif // FST_LIB_INVERT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/label-reachable.h b/kaldi_io/src/tools/openfst/include/fst/label-reachable.h new file mode 100644 index 0000000..af06eef --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/label-reachable.h @@ -0,0 +1,565 @@ +// label_reachable.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to determine if a non-epsilon label can be read as the +// first non-epsilon symbol along some path from a given state. + + +#ifndef FST_LIB_LABEL_REACHABLE_H__ +#define FST_LIB_LABEL_REACHABLE_H__ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <vector> +using std::vector; + +#include <fst/accumulator.h> +#include <fst/arcsort.h> +#include <fst/interval-set.h> +#include <fst/state-reachable.h> +#include <fst/vector-fst.h> + + +namespace fst { + +// Stores shareable data for label reachable class copies. +template <typename L> +class LabelReachableData { + public: + typedef L Label; + typedef typename IntervalSet<L>::Interval Interval; + + explicit LabelReachableData(bool reach_input, bool keep_relabel_data = true) + : reach_input_(reach_input), + keep_relabel_data_(keep_relabel_data), + have_relabel_data_(true), + final_label_(kNoLabel) {} + + ~LabelReachableData() {} + + bool ReachInput() const { return reach_input_; } + + vector< IntervalSet<L> > *IntervalSets() { return &isets_; } + + unordered_map<L, L> *Label2Index() { + if (!have_relabel_data_) + FSTERROR() << "LabelReachableData: no relabeling data"; + return &label2index_; + } + + Label FinalLabel() { + if (final_label_ == kNoLabel) + final_label_ = label2index_[kNoLabel]; + return final_label_; + } + + static LabelReachableData<L> *Read(istream &istrm) { + LabelReachableData<L> *data = new LabelReachableData<L>(); + + ReadType(istrm, &data->reach_input_); + ReadType(istrm, &data->keep_relabel_data_); + data->have_relabel_data_ = data->keep_relabel_data_; + if (data->keep_relabel_data_) + ReadType(istrm, &data->label2index_); + ReadType(istrm, &data->final_label_); + ReadType(istrm, &data->isets_); + return data; + } + + bool Write(ostream &ostrm) { + WriteType(ostrm, reach_input_); + WriteType(ostrm, keep_relabel_data_); + if (keep_relabel_data_) + WriteType(ostrm, label2index_); + WriteType(ostrm, FinalLabel()); + WriteType(ostrm, isets_); + return true; + } + + int RefCount() const { return ref_count_.count(); } + int IncrRefCount() { return ref_count_.Incr(); } + int DecrRefCount() { return ref_count_.Decr(); } + + private: + LabelReachableData() {} + + bool reach_input_; // Input or output labels considered? + bool keep_relabel_data_; // Save label2index_ to file? + bool have_relabel_data_; // Using label2index_? + Label final_label_; // Final label + RefCounter ref_count_; // Reference count. + unordered_map<L, L> label2index_; // Finds index for a label. + vector<IntervalSet <L> > isets_; // Interval sets per state. + + DISALLOW_COPY_AND_ASSIGN(LabelReachableData); +}; + + +// Tests reachability of labels from a given state. If reach_input = +// true, then input labels are considered, o.w. output labels are +// considered. To test for reachability from a state s, first do +// SetState(s). Then a label l can be reached from state s of FST f +// iff Reach(r) is true where r = Relabel(l). The relabeling is +// required to ensure a compact representation of the reachable +// labels. + +// The whole FST can be relabeled instead with Relabel(&f, +// reach_input) so that the test Reach(r) applies directly to the +// labels of the transformed FST f. The relabeled FST will also be +// sorted appropriately for composition. +// +// Reachablity of a final state from state s (via an epsilon path) +// can be tested with ReachFinal(); +// +// Reachability can also be tested on the set of labels specified by +// an arc iterator, useful for FST composition. In particular, +// Reach(aiter, ...) is true if labels on the input (output) side of +// the transitions of the arc iterator, when iter_input is true +// (false), can be reached from the state s. The iterator labels must +// have already been relabeled. +// +// With the arc iterator test of reachability, the begin position, end +// position and accumulated arc weight of the matches can be +// returned. The optional template argument controls how reachable arc +// weights are accumulated. The default uses the semiring +// Plus(). Alternative ones can be used to distribute the weights in +// composition in various ways. +template <class A, class S = DefaultAccumulator<A> > +class LabelReachable { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename IntervalSet<Label>::Interval Interval; + + LabelReachable(const Fst<A> &fst, bool reach_input, S *s = 0, + bool keep_relabel_data = true) + : fst_(new VectorFst<Arc>(fst)), + s_(kNoStateId), + data_(new LabelReachableData<Label>(reach_input, keep_relabel_data)), + accumulator_(s ? s : new S()), + ncalls_(0), + nintervals_(0), + error_(false) { + StateId ins = fst_->NumStates(); + TransformFst(); + FindIntervals(ins); + delete fst_; + } + + explicit LabelReachable(LabelReachableData<Label> *data, S *s = 0) + : fst_(0), + s_(kNoStateId), + data_(data), + accumulator_(s ? s : new S()), + ncalls_(0), + nintervals_(0), + error_(false) { + data_->IncrRefCount(); + } + + LabelReachable(const LabelReachable<A, S> &reachable) : + fst_(0), + s_(kNoStateId), + data_(reachable.data_), + accumulator_(new S(*reachable.accumulator_)), + ncalls_(0), + nintervals_(0), + error_(reachable.error_) { + data_->IncrRefCount(); + } + + ~LabelReachable() { + if (!data_->DecrRefCount()) + delete data_; + delete accumulator_; + if (ncalls_ > 0) { + VLOG(2) << "# of calls: " << ncalls_; + VLOG(2) << "# of intervals/call: " << (nintervals_ / ncalls_); + } + } + + // Relabels w.r.t labels that give compact label sets. + Label Relabel(Label label) { + if (label == 0 || error_) + return label; + unordered_map<Label, Label> &label2index = *data_->Label2Index(); + Label &relabel = label2index[label]; + if (!relabel) // Add new label + relabel = label2index.size() + 1; + return relabel; + } + + // Relabels Fst w.r.t to labels that give compact label sets. + void Relabel(MutableFst<Arc> *fst, bool relabel_input) { + for (StateIterator< MutableFst<Arc> > siter(*fst); + !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + for (MutableArcIterator< MutableFst<Arc> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + if (relabel_input) + arc.ilabel = Relabel(arc.ilabel); + else + arc.olabel = Relabel(arc.olabel); + aiter.SetValue(arc); + } + } + if (relabel_input) { + ArcSort(fst, ILabelCompare<Arc>()); + fst->SetInputSymbols(0); + } else { + ArcSort(fst, OLabelCompare<Arc>()); + fst->SetOutputSymbols(0); + } + } + + // Returns relabeling pairs (cf. relabel.h::Relabel()). + // If 'avoid_collisions' is true, extra pairs are added to + // ensure no collisions when relabeling automata that have + // labels unseen here. + void RelabelPairs(vector<pair<Label, Label> > *pairs, + bool avoid_collisions = false) { + pairs->clear(); + unordered_map<Label, Label> &label2index = *data_->Label2Index(); + // Maps labels to their new values in [1, label2index().size()] + for (typename unordered_map<Label, Label>::const_iterator + it = label2index.begin(); it != label2index.end(); ++it) + if (it->second != data_->FinalLabel()) + pairs->push_back(pair<Label, Label>(it->first, it->second)); + if (avoid_collisions) { + // Ensures any label in [1, label2index().size()] is mapped either + // by the above step or to label2index() + 1 (to avoid collisions). + for (int i = 1; i <= label2index.size(); ++i) { + typename unordered_map<Label, Label>::const_iterator + it = label2index.find(i); + if (it == label2index.end() || it->second == data_->FinalLabel()) + pairs->push_back(pair<Label, Label>(i, label2index.size() + 1)); + } + } + } + + // Set current state. Optionally set state associated + // with arc iterator to be passed to Reach. + void SetState(StateId s, StateId aiter_s = kNoStateId) { + s_ = s; + if (aiter_s != kNoStateId) { + accumulator_->SetState(aiter_s); + if (accumulator_->Error()) error_ = true; + } + } + + // Can reach this label from current state? + // Original labels must be transformed by the Relabel methods above. + bool Reach(Label label) { + if (label == 0 || error_) + return false; + vector< IntervalSet<Label> > &isets = *data_->IntervalSets(); + return isets[s_].Member(label); + + } + + // Can reach final state (via epsilon transitions) from this state? + bool ReachFinal() { + if (error_) return false; + vector< IntervalSet<Label> > &isets = *data_->IntervalSets(); + return isets[s_].Member(data_->FinalLabel()); + } + + // Initialize with secondary FST to be used with Reach(Iterator,...). + // If copy is true, then 'fst' is a copy of the FST used in the + // previous call to this method (useful to avoid unnecessary updates). + template <class F> + void ReachInit(const F &fst, bool copy = false) { + accumulator_->Init(fst, copy); + if (accumulator_->Error()) error_ = true; + } + + // Can reach any arc iterator label between iterator positions + // aiter_begin and aiter_end? If aiter_input = true, then iterator + // input labels are considered, o.w. output labels are considered. + // Arc iterator labels must be transformed by the Relabel methods + // above. If compute_weight is true, user may call ReachWeight(). + template <class Iterator> + bool Reach(Iterator *aiter, ssize_t aiter_begin, + ssize_t aiter_end, bool aiter_input, bool compute_weight) { + if (error_) return false; + vector< IntervalSet<Label> > &isets = *data_->IntervalSets(); + const vector<Interval> *intervals = isets[s_].Intervals(); + ++ncalls_; + nintervals_ += intervals->size(); + + reach_begin_ = -1; + reach_end_ = -1; + reach_weight_ = Weight::Zero(); + + uint32 flags = aiter->Flags(); // save flags to restore them on exit + aiter->SetFlags(kArcNoCache, kArcNoCache); // make caching optional + aiter->Seek(aiter_begin); + + if (2 * (aiter_end - aiter_begin) < intervals->size()) { + // Check each arc against intervals. + // Set arc iterator flags to only compute the ilabel or olabel values, + // since they are the only values required for most of the arcs processed. + aiter->SetFlags(aiter_input ? kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + Label reach_label = kNoLabel; + for (ssize_t aiter_pos = aiter_begin; + aiter_pos < aiter_end; aiter->Next(), ++aiter_pos) { + const A &arc = aiter->Value(); + Label label = aiter_input ? arc.ilabel : arc.olabel; + if (label == reach_label || Reach(label)) { + reach_label = label; + if (reach_begin_ < 0) + reach_begin_ = aiter_pos; + reach_end_ = aiter_pos + 1; + if (compute_weight) { + if (!(aiter->Flags() & kArcWeightValue)) { + // If the 'arc.weight' wasn't computed by the call + // to 'aiter->Value()' above, we need to call + // 'aiter->Value()' again after having set the arc iterator + // flags to compute the arc weight value. + aiter->SetFlags(kArcWeightValue, kArcValueFlags); + const A &arcb = aiter->Value(); + // Call the accumulator. + reach_weight_ = accumulator_->Sum(reach_weight_, arcb.weight); + // Only ilabel or olabel required to process the following + // arcs. + aiter->SetFlags(aiter_input ? kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + } else { + // Call the accumulator. + reach_weight_ = accumulator_->Sum(reach_weight_, arc.weight); + } + } + } + } + } else { + // Check each interval against arcs + ssize_t begin_low, end_low = aiter_begin; + for (typename vector<Interval>::const_iterator + iiter = intervals->begin(); + iiter != intervals->end(); ++iiter) { + begin_low = LowerBound(aiter, end_low, aiter_end, + aiter_input, iiter->begin); + end_low = LowerBound(aiter, begin_low, aiter_end, + aiter_input, iiter->end); + if (end_low - begin_low > 0) { + if (reach_begin_ < 0) + reach_begin_ = begin_low; + reach_end_ = end_low; + if (compute_weight) { + aiter->SetFlags(kArcWeightValue, kArcValueFlags); + reach_weight_ = accumulator_->Sum(reach_weight_, aiter, + begin_low, end_low); + } + } + } + } + + aiter->SetFlags(flags, kArcFlags); // restore original flag values + return reach_begin_ >= 0; + } + + // Returns iterator position of first matching arc. + ssize_t ReachBegin() const { return reach_begin_; } + + // Returns iterator position one past last matching arc. + ssize_t ReachEnd() const { return reach_end_; } + + // Return the sum of the weights for matching arcs. + // Valid only if compute_weight was true in Reach() call. + Weight ReachWeight() const { return reach_weight_; } + + // Access to the relabeling map. Excludes epsilon (0) label but + // includes kNoLabel that is used internally for super-final + // transitons. + const unordered_map<Label, Label>& Label2Index() const { + return *data_->Label2Index(); + } + + LabelReachableData<Label> *GetData() const { return data_; } + + bool Error() const { return error_ || accumulator_->Error(); } + + private: + // Redirects labeled arcs (input or output labels determined by + // ReachInput()) to new label-specific final states. Each original + // final state is redirected via a transition labeled with kNoLabel + // to a new kNoLabel-specific final state. Creates super-initial + // state for all states with zero in-degree. + void TransformFst() { + StateId ins = fst_->NumStates(); + StateId ons = ins; + + vector<ssize_t> indeg(ins, 0); + + // Redirects labeled arcs to new final states. + for (StateId s = 0; s < ins; ++s) { + for (MutableArcIterator< VectorFst<Arc> > aiter(fst_, s); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + Label label = data_->ReachInput() ? arc.ilabel : arc.olabel; + if (label) { + if (label2state_.find(label) == label2state_.end()) { + label2state_[label] = ons; + indeg.push_back(0); + ++ons; + } + arc.nextstate = label2state_[label]; + aiter.SetValue(arc); + } + ++indeg[arc.nextstate]; // Finds in-degrees for next step. + } + + // Redirects final weights to new final state. + Weight final = fst_->Final(s); + if (final != Weight::Zero()) { + if (label2state_.find(kNoLabel) == label2state_.end()) { + label2state_[kNoLabel] = ons; + indeg.push_back(0); + ++ons; + } + Arc arc(kNoLabel, kNoLabel, final, label2state_[kNoLabel]); + fst_->AddArc(s, arc); + ++indeg[arc.nextstate]; // Finds in-degrees for next step. + + fst_->SetFinal(s, Weight::Zero()); + } + } + + // Add new final states to Fst. + while (fst_->NumStates() < ons) { + StateId s = fst_->AddState(); + fst_->SetFinal(s, Weight::One()); + } + + // Creates a super-initial state for all states with zero in-degree. + StateId start = fst_->AddState(); + fst_->SetStart(start); + for (StateId s = 0; s < start; ++s) { + if (indeg[s] == 0) { + Arc arc(0, 0, Weight::One(), s); + fst_->AddArc(start, arc); + } + } + } + + void FindIntervals(StateId ins) { + StateReachable<A, Label> state_reachable(*fst_); + if (state_reachable.Error()) { + error_ = true; + return; + } + + vector<Label> &state2index = state_reachable.State2Index(); + vector< IntervalSet<Label> > &isets = *data_->IntervalSets(); + isets = state_reachable.IntervalSets(); + isets.resize(ins); + + unordered_map<Label, Label> &label2index = *data_->Label2Index(); + for (typename unordered_map<Label, StateId>::const_iterator + it = label2state_.begin(); + it != label2state_.end(); + ++it) { + Label l = it->first; + StateId s = it->second; + Label i = state2index[s]; + label2index[l] = i; + } + label2state_.clear(); + + double nintervals = 0; + ssize_t non_intervals = 0; + for (ssize_t s = 0; s < ins; ++s) { + nintervals += isets[s].Size(); + if (isets[s].Size() > 1) { + ++non_intervals; + VLOG(3) << "state: " << s << " # of intervals: " << isets[s].Size(); + } + } + VLOG(2) << "# of states: " << ins; + VLOG(2) << "# of intervals: " << nintervals; + VLOG(2) << "# of intervals/state: " << nintervals/ins; + VLOG(2) << "# of non-interval states: " << non_intervals; + } + + template <class Iterator> + ssize_t LowerBound(Iterator *aiter, ssize_t aiter_begin, + ssize_t aiter_end, bool aiter_input, + Label match_label) const { + // Only need to compute the ilabel or olabel of arcs when + // performing the binary search. + aiter->SetFlags(aiter_input ? kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + ssize_t low = aiter_begin; + ssize_t high = aiter_end; + while (low < high) { + ssize_t mid = (low + high) / 2; + aiter->Seek(mid); + Label label = aiter_input ? + aiter->Value().ilabel : aiter->Value().olabel; + if (label > match_label) { + high = mid; + } else if (label < match_label) { + low = mid + 1; + } else { + // Find first matching label (when non-deterministic) + for (ssize_t i = mid; i > low; --i) { + aiter->Seek(i - 1); + label = aiter_input ? aiter->Value().ilabel : aiter->Value().olabel; + if (label != match_label) { + aiter->Seek(i); + aiter->SetFlags(kArcValueFlags, kArcValueFlags); + return i; + } + } + aiter->SetFlags(kArcValueFlags, kArcValueFlags); + return low; + } + } + aiter->Seek(low); + aiter->SetFlags(kArcValueFlags, kArcValueFlags); + return low; + } + + VectorFst<Arc> *fst_; + StateId s_; // Current state + unordered_map<Label, StateId> label2state_; // Finds final state for a label + + ssize_t reach_begin_; // Iterator pos of first match + ssize_t reach_end_; // Iterator pos after last match + Weight reach_weight_; // Gives weight sum of arc iterator + // arcs with reachable labels. + LabelReachableData<Label> *data_; // Shareable data between copies + S *accumulator_; // Sums arc weights + + double ncalls_; + double nintervals_; + bool error_; + + void operator=(const LabelReachable<A, S> &); // Disallow +}; + +} // namespace fst + +#endif // FST_LIB_LABEL_REACHABLE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/lexicographic-weight.h b/kaldi_io/src/tools/openfst/include/fst/lexicographic-weight.h new file mode 100644 index 0000000..4b55c50 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/lexicographic-weight.h @@ -0,0 +1,151 @@ +// lexicographic-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Richard Sproat) +// +// \file +// Lexicographic weight set and associated semiring operation definitions. +// +// A lexicographic weight is a sequence of weights, each of which must have the +// path property and Times() must be (strongly) cancellative +// (for all a,b,c != Zero(): Times(c, a) = Times(c, b) => a = b, +// Times(a, c) = Times(b, c) => a = b). +// The + operation on two weights a and b is the lexicographically +// prior of a and b. + +#ifndef FST_LIB_LEXICOGRAPHIC_WEIGHT_H__ +#define FST_LIB_LEXICOGRAPHIC_WEIGHT_H__ + +#include <string> + +#include <fst/pair-weight.h> +#include <fst/weight.h> + + +namespace fst { + +template<class W1, class W2> +class LexicographicWeight : public PairWeight<W1, W2> { + public: + using PairWeight<W1, W2>::Value1; + using PairWeight<W1, W2>::Value2; + using PairWeight<W1, W2>::SetValue1; + using PairWeight<W1, W2>::SetValue2; + using PairWeight<W1, W2>::Zero; + using PairWeight<W1, W2>::One; + using PairWeight<W1, W2>::NoWeight; + using PairWeight<W1, W2>::Quantize; + using PairWeight<W1, W2>::Reverse; + + typedef LexicographicWeight<typename W1::ReverseWeight, + typename W2::ReverseWeight> + ReverseWeight; + + LexicographicWeight() {} + + LexicographicWeight(const PairWeight<W1, W2>& w) + : PairWeight<W1, W2>(w) {} + + LexicographicWeight(W1 w1, W2 w2) : PairWeight<W1, W2>(w1, w2) { + uint64 props = kPath; + if ((W1::Properties() & props) != props) { + FSTERROR() << "LexicographicWeight must " + << "have the path property: " << W1::Type(); + SetValue1(W1::NoWeight()); + } + if ((W2::Properties() & props) != props) { + FSTERROR() << "LexicographicWeight must " + << "have the path property: " << W2::Type(); + SetValue2(W2::NoWeight()); + } + } + + static const LexicographicWeight<W1, W2> &Zero() { + static const LexicographicWeight<W1, W2> zero(PairWeight<W1, W2>::Zero()); + return zero; + } + + static const LexicographicWeight<W1, W2> &One() { + static const LexicographicWeight<W1, W2> one(PairWeight<W1, W2>::One()); + return one; + } + + static const LexicographicWeight<W1, W2> &NoWeight() { + static const LexicographicWeight<W1, W2> no_weight( + PairWeight<W1, W2>::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string type = W1::Type() + "_LT_" + W2::Type(); + return type; + } + + bool Member() const { + if (!Value1().Member() || !Value2().Member()) return false; + // Lexicographic weights cannot mix zeroes and non-zeroes. + if (Value1() == W1::Zero() && Value2() == W2::Zero()) return true; + if (Value1() != W1::Zero() && Value2() != W2::Zero()) return true; + return false; + } + + LexicographicWeight<W1, W2> Quantize(float delta = kDelta) const { + return PairWeight<W1, W2>::Quantize(); + } + + ReverseWeight Reverse() const { + return PairWeight<W1, W2>::Reverse(); + } + + static uint64 Properties() { + uint64 props1 = W1::Properties(); + uint64 props2 = W2::Properties(); + return props1 & props2 & (kLeftSemiring | kRightSemiring | kPath | + kIdempotent | kCommutative); + } +}; + +template <class W1, class W2> +inline LexicographicWeight<W1, W2> Plus(const LexicographicWeight<W1, W2> &w, + const LexicographicWeight<W1, W2> &v) { + if (!w.Member() || !v.Member()) + return LexicographicWeight<W1, W2>::NoWeight(); + NaturalLess<W1> less1; + NaturalLess<W2> less2; + if (less1(w.Value1(), v.Value1())) return w; + if (less1(v.Value1(), w.Value1())) return v; + if (less2(w.Value2(), v.Value2())) return w; + if (less2(v.Value2(), w.Value2())) return v; + return w; +} + +template <class W1, class W2> +inline LexicographicWeight<W1, W2> Times(const LexicographicWeight<W1, W2> &w, + const LexicographicWeight<W1, W2> &v) { + return LexicographicWeight<W1, W2>(Times(w.Value1(), v.Value1()), + Times(w.Value2(), v.Value2())); +} + +template <class W1, class W2> +inline LexicographicWeight<W1, W2> Divide(const LexicographicWeight<W1, W2> &w, + const LexicographicWeight<W1, W2> &v, + DivideType typ = DIVIDE_ANY) { + return LexicographicWeight<W1, W2>(Divide(w.Value1(), v.Value1(), typ), + Divide(w.Value2(), v.Value2(), typ)); +} + +} // namespace fst + +#endif // FST_LIB_LEXICOGRAPHIC_WEIGHT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/lock.h b/kaldi_io/src/tools/openfst/include/fst/lock.h new file mode 100644 index 0000000..58cb22a --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/lock.h @@ -0,0 +1,100 @@ +// lock.h +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: [email protected] (Michael Riley) +// +// \file +// Google-compatibility locking declarations and inline definitions +// +// Classes and functions here are no-ops (by design); proper locking requires +// actual implementation. + +#ifndef FST_LIB_LOCK_H__ +#define FST_LIB_LOCK_H__ + +#include <fst/compat.h> // for DISALLOW_COPY_AND_ASSIGN + +namespace fst { + +using namespace std; + +// +// Single initialization - single-thread implementation +// + +typedef int FstOnceType; + +static const int FST_ONCE_INIT = 1; + +inline int FstOnceInit(FstOnceType *once, void (*init)(void)) { + if (*once) + (*init)(); + *once = 0; + return 0; +} + +// +// Thread locking - single-thread (non-)implementation +// + +class Mutex { + public: + Mutex() {} + + private: + DISALLOW_COPY_AND_ASSIGN(Mutex); +}; + +class MutexLock { + public: + MutexLock(Mutex *) {} + + private: + DISALLOW_COPY_AND_ASSIGN(MutexLock); +}; + +class ReaderMutexLock { + public: + ReaderMutexLock(Mutex *) {} + + private: + DISALLOW_COPY_AND_ASSIGN(ReaderMutexLock); +}; + +// Reference counting - single-thread implementation +class RefCounter { + public: + RefCounter() : count_(1) {} + + int count() const { return count_; } + +// below lines are modifications of openfst for multi-thrads support, +// from tools/extras/openfst_gcc41up.patch, applied by tools/Makefile, +// applicable to gcc 4.1 or above + // int Incr() const { return ++count_; } + // int Decr() const { return --count_; } + + int Incr() const { return __sync_add_and_fetch(&count_, 1); } + int Decr() const { return __sync_sub_and_fetch(&count_, 1); } +// end modifications + + private: + mutable int count_; + + DISALLOW_COPY_AND_ASSIGN(RefCounter); +}; + +} // namespace fst + +#endif // FST_LIB_LOCK_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/log.h b/kaldi_io/src/tools/openfst/include/fst/log.h new file mode 100644 index 0000000..d1492cd --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/log.h @@ -0,0 +1,66 @@ +// log.h +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: [email protected] (Michael Riley) +// +// \file +// Google-style logging declarations and inline definitions. + +#ifndef FST_LIB_LOG_H__ +#define FST_LIB_LOG_H__ + +#include <cassert> +#include <iostream> +#include <string> + +#include <fst/types.h> +#include <fst/flags.h> + +using std::string; + +DECLARE_int32(v); + +class LogMessage { + public: + LogMessage(const string &type) : fatal_(type == "FATAL") { + std::cerr << type << ": "; + } + ~LogMessage() { + std::cerr << std::endl; + if(fatal_) + exit(1); + } + std::ostream &stream() { return std::cerr; } + + private: + bool fatal_; +}; + +#define LOG(type) LogMessage(#type).stream() +#define VLOG(level) if ((level) <= FLAGS_v) LOG(INFO) + +// Checks +inline void CHECK(bool x) { assert(x); } + +#define CHECK_EQ(x, y) CHECK((x) == (y)) +#define CHECK_LT(x, y) CHECK((x) < (y)) +#define CHECK_GT(x, y) CHECK((x) > (y)) +#define CHECK_LE(x, y) CHECK((x) <= (y)) +#define CHECK_GE(x, y) CHECK((x) >= (y)) +#define CHECK_NE(x, y) CHECK((x) != (y)) + +// Ports +#define ATTRIBUTE_DEPRECATED __attribute__((deprecated)) + +#endif // FST_LIB_LOG_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/lookahead-filter.h b/kaldi_io/src/tools/openfst/include/fst/lookahead-filter.h new file mode 100644 index 0000000..e11c1bb --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/lookahead-filter.h @@ -0,0 +1,698 @@ +// lookahead-filter.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Composition filters to support lookahead matchers, useful for improving +// composition efficiency with certain inputs. + +#ifndef FST_LIB_LOOKAHEAD_FILTER_H__ +#define FST_LIB_LOOKAHEAD_FILTER_H__ + +#include <vector> +using std::vector; + +#include <fst/fst.h> +#include <fst/lookahead-matcher.h> + + +namespace fst { + +// Identifies and verifies the capabilities of the matcher to be used for +// lookahead with the composition filters below. This version is passed +// the matchers. +template <class M1, class M2> +MatchType LookAheadMatchType(const M1 &m1, const M2 &m2) { + MatchType type1 = m1.Type(false); + MatchType type2 = m2.Type(false); + if (type1 == MATCH_OUTPUT && + m1.Flags() & kOutputLookAheadMatcher) + return MATCH_OUTPUT; + else if (type2 == MATCH_INPUT && + m2.Flags() & kInputLookAheadMatcher) + return MATCH_INPUT; + else if (m1.Flags() & kOutputLookAheadMatcher && + m1.Type(true) == MATCH_OUTPUT) + return MATCH_OUTPUT; + else if (m2.Flags() & kInputLookAheadMatcher && + m2.Type(true) == MATCH_INPUT) + return MATCH_INPUT; + else + return MATCH_NONE; +} + +// Identifies and verifies the capabilities of the matcher to be used for +// lookahead with the composition filters below. This version uses the +// Fst's default matchers. +template <class Arc> +MatchType LookAheadMatchType(const Fst<Arc> &fst1, const Fst<Arc> &fst2) { + LookAheadMatcher< Fst <Arc> > matcher1(fst1, MATCH_OUTPUT); + LookAheadMatcher< Fst <Arc> > matcher2(fst2, MATCH_INPUT); + return LookAheadMatchType(matcher1, matcher2); +} + +// +// LookAheadSelector - a helper class for selecting among possibly +// distinct FST and matcher types w/o using a common base class. This +// lets us avoid virtual function calls. +// + +// Stores and returns the appropriate FST and matcher for lookahead. +// It is templated on the matcher types. General case has no methods +// since not currently supported. +template <class M1, class M2, MatchType MT> +class LookAheadSelector { +}; + +// Stores and returns the appropriate FST and matcher for lookahead. +// Specialized for two matchers of same type with the (match) 'type' +// arg determining which is used for lookahead. +template <class M, MatchType MT> +class LookAheadSelector<M, M, MT> { + public: + typedef typename M::Arc Arc; + typedef typename M::FST F; + + LookAheadSelector(M *lmatcher1, M *lmatcher2, MatchType type) + : lmatcher1_(lmatcher1->Copy()), + lmatcher2_(lmatcher2->Copy()), + type_(type) {} + + LookAheadSelector(const LookAheadSelector<M, M, MT> &selector) + : lmatcher1_(selector.lmatcher1_->Copy()), + lmatcher2_(selector.lmatcher2_->Copy()), + type_(selector.type_) {} + + ~LookAheadSelector() { + delete lmatcher1_; + delete lmatcher2_; + } + + const F &GetFst() const { + return type_ == MATCH_OUTPUT ? lmatcher2_->GetFst() : + lmatcher1_->GetFst(); + } + + M *GetMatcher() const { + return type_ == MATCH_OUTPUT ? lmatcher1_ : lmatcher2_; + } + + private: + M *lmatcher1_; + M *lmatcher2_; + MatchType type_; + + void operator=(const LookAheadSelector<M, M, MT> &); // disallow +}; + +// Stores and returns the appropriate FST and matcher for lookahead. +// Specialized for lookahead on input labels. +template <class M1, class M2> +class LookAheadSelector<M1, M2, MATCH_INPUT> { + public: + typedef typename M1::FST F1; + + LookAheadSelector(M1 *lmatcher1, M2 *lmatcher2, MatchType) + : fst_(lmatcher1->GetFst().Copy()), + lmatcher_(lmatcher2->Copy()) {} + + LookAheadSelector(const LookAheadSelector<M1, M2, MATCH_INPUT> &selector) + : fst_(selector.fst_->Copy()), + lmatcher_(selector.lmatcher_->Copy()) {} + + ~LookAheadSelector() { + delete lmatcher_; + delete fst_; + } + + const F1 &GetFst() const { return *fst_; } + + M2 *GetMatcher() const { return lmatcher_; } + + private: + const F1 *fst_; + M2 *lmatcher_; + + void operator=(const LookAheadSelector<M1, M2, MATCH_INPUT> &); // disallow +}; + + +// Stores and returns the appropriate FST and matcher for lookahead. +// Specialized for lookahead on output labels. +template <class M1, class M2> +class LookAheadSelector<M1, M2, MATCH_OUTPUT> { + public: + typedef typename M2::FST F2; + + LookAheadSelector(M1 *lmatcher1, M2 *lmatcher2, MatchType) + : fst_(lmatcher2->GetFst().Copy()), + lmatcher_(lmatcher1->Copy()) {} + + LookAheadSelector(const LookAheadSelector<M1, M2, MATCH_OUTPUT> &selector) + : fst_(selector.fst_->Copy()), + lmatcher_(selector.lmatcher_->Copy()) {} + + ~LookAheadSelector() { + delete lmatcher_; + delete fst_; + } + + const F2 &GetFst() const { return *fst_; } + + M1 *GetMatcher() const { return lmatcher_; } + + private: + const F2 *fst_; + M1 *lmatcher_; + + void operator=(const LookAheadSelector<M1, M2, MATCH_OUTPUT> &); // disallow +}; + +// This filter uses a lookahead matcher in FilterArc(arc1, arc2) to +// examine the future of the composition state (arc1.nextstate, +// arc2.nextstate), blocking moving forward when its determined to be +// non-coaccessible. It is templated on an underlying filter, +// typically the epsilon filter. Which matcher is the lookahead +// matcher is determined by the template argument MT unless it is +// MATCH_BOTH. In that case, both matcher arguments must be lookahead +// matchers of the same type and one will be selected by +// LookAheadMatchType() based on their capability. +template <class F, + class M1 = LookAheadMatcher<typename F::FST1>, + class M2 = M1, + MatchType MT = MATCH_BOTH> +class LookAheadComposeFilter { + public: + typedef typename F::FST1 FST1; + typedef typename F::FST2 FST2; + typedef typename F::Arc Arc; + typedef typename F::Matcher1 Matcher1; + typedef typename F::Matcher2 Matcher2; + typedef typename F::FilterState FilterState; + typedef LookAheadComposeFilter<F, M1, M2, MT> Filter; + + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + LookAheadComposeFilter(const FST1 &fst1, const FST2 &fst2, + M1 *matcher1, M2 *matcher2) + : filter_(fst1, fst2, matcher1, matcher2), + lookahead_type_(MT == MATCH_BOTH ? + LookAheadMatchType(*filter_.GetMatcher1(), + *filter_.GetMatcher2()) : MT), + selector_(filter_.GetMatcher1(), filter_.GetMatcher2(), + lookahead_type_), + flags_(lookahead_type_ == MATCH_OUTPUT ? + filter_.GetMatcher1()->Flags() : + filter_.GetMatcher2()->Flags()) { + if (lookahead_type_ == MATCH_NONE) { + FSTERROR() << "LookAheadComposeFilter: 1st argument cannot " + << "match/look-ahead on output labels and 2nd argument " + << "cannot match/look-ahead on input labels."; + } + selector_.GetMatcher()->InitLookAheadFst(selector_.GetFst()); + } + + LookAheadComposeFilter(const LookAheadComposeFilter<F, M1, M2, MT> &filter, + bool safe = false) + : filter_(filter.filter_, safe), + lookahead_type_(filter.lookahead_type_), + selector_(filter_.GetMatcher1(), filter_.GetMatcher2(), + lookahead_type_), + flags_(filter.flags_) { + selector_.GetMatcher()->InitLookAheadFst(selector_.GetFst(), true); + } + + FilterState Start() const { + return filter_.Start(); + } + + void SetState(StateId s1, StateId s2, const FilterState &f) { + filter_.SetState(s1, s2, f); + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + lookahead_arc_ = false; + + const FilterState &f = filter_.FilterArc(arc1, arc2); + if (f == FilterState::NoState()) + return FilterState::NoState(); + + return LookAheadOutput() ? LookAheadFilterArc(arc1, arc2, f) : + LookAheadFilterArc(arc2, arc1, f); + } + + void FilterFinal(Weight *weight1, Weight *weight2) const { + filter_.FilterFinal(weight1, weight2); + } + + // Return resp matchers. Ownership stays with filter. + Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); } + Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); } + + const LookAheadSelector<Matcher1, Matcher2, MT> &Selector() const { + return selector_; + } + + uint64 Properties(uint64 inprops) const { + uint64 outprops = filter_.Properties(inprops); + if (lookahead_type_ == MATCH_NONE) + outprops |= kError; + return outprops; + } + + uint32 LookAheadFlags() const { return flags_; } + + bool LookAheadArc() const { return lookahead_arc_; } + + bool LookAheadOutput() const { + if (MT == MATCH_OUTPUT) + return true; + else if (MT == MATCH_INPUT) + return false; + else if (lookahead_type_ == MATCH_OUTPUT) + return true; + else + return false; + } + + private: + FilterState LookAheadFilterArc(Arc *arca, Arc *arcb, + const FilterState &f) const { + Label &labela = LookAheadOutput() ? arca->olabel : arca->ilabel; + + if (labela != 0 && !(flags_ & kLookAheadNonEpsilons)) + return f; + if (labela == 0 && !(flags_ & kLookAheadEpsilons)) + return f; + + lookahead_arc_ = true; + selector_.GetMatcher()->SetState(arca->nextstate); + + return selector_.GetMatcher()->LookAheadFst(selector_.GetFst(), + arcb->nextstate) ? f : + FilterState::NoState(); + } + + F filter_; // Underlying filter + MatchType lookahead_type_; // Lookahead match type + LookAheadSelector<Matcher1, Matcher2, MT> selector_; + uint32 flags_; // Lookahead flags + mutable bool lookahead_arc_; // Look-ahead performed at last FilterArc()? + + void operator=(const LookAheadComposeFilter<F, M1, M2> &); // disallow +}; + + +// This filter adds weight-pushing to a lookahead composition filter +// using the LookAheadWeight() method of matcher argument. It is +// templated on an underlying lookahead filter, typically the basic +// lookahead filter. Weight-pushing in composition brings weights +// forward as much as possible based on the lookahead information. +template <class F, + class M1 = LookAheadMatcher<typename F::FST1>, + class M2 = M1, + MatchType MT = MATCH_BOTH> +class PushWeightsComposeFilter { + public: + typedef typename F::FST1 FST1; + typedef typename F::FST2 FST2; + typedef typename F::Arc Arc; + typedef typename F::Matcher1 Matcher1; + typedef typename F::Matcher2 Matcher2; + typedef typename F::FilterState FilterState1; + typedef WeightFilterState<typename Arc::Weight> FilterState2; + typedef PairFilterState<FilterState1, FilterState2> FilterState; + + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + PushWeightsComposeFilter(const FST1 &fst1, const FST2 &fst2, + M1 *matcher1, M2 *matcher2) + : filter_(fst1, fst2, matcher1, matcher2), + f_(FilterState::NoState()) {} + + PushWeightsComposeFilter(const PushWeightsComposeFilter<F, M1, M2, MT> + &filter, + bool safe = false) + : filter_(filter.filter_, safe), + f_(FilterState::NoState()) {} + + FilterState Start() const { + return FilterState(filter_.Start(), FilterState2(Weight::One())); + } + + void SetState(StateId s1, StateId s2, const FilterState &f) { + f_ = f; + filter_.SetState(s1, s2, f.GetState1()); + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + const FilterState1 &f1 = filter_.FilterArc(arc1, arc2); + if (f1 == FilterState1::NoState()) + return FilterState::NoState(); + + if (!(LookAheadFlags() & kLookAheadWeight)) + return FilterState(f1, FilterState2(Weight::One())); + + const Weight &lweight = filter_.LookAheadArc() ? + Selector().GetMatcher()->LookAheadWeight() : Weight::One(); + const FilterState2 &f2 = f_.GetState2(); + const Weight &fweight = f2.GetWeight(); + + arc2->weight = Divide(Times(arc2->weight, lweight), fweight); + return FilterState(f1, FilterState2(lweight)); + } + + void FilterFinal(Weight *weight1, Weight *weight2) const { + filter_.FilterFinal(weight1, weight2); + if (!(LookAheadFlags() & kLookAheadWeight) || *weight1 == Weight::Zero()) + return; + + const FilterState2 &f2 = f_.GetState2(); + const Weight &fweight = f2.GetWeight(); + *weight1 = Divide(*weight1, fweight); + } + // Return resp matchers. Ownership states with filter. + Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); } + Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); } + + const LookAheadSelector<Matcher1, Matcher2, MT> &Selector() const { + return filter_.Selector(); + } + + uint32 LookAheadFlags() const { return filter_.LookAheadFlags(); } + bool LookAheadArc() const { return filter_.LookAheadArc(); } + bool LookAheadOutput() const { return filter_.LookAheadOutput(); } + + uint64 Properties(uint64 props) const { + return filter_.Properties(props) & kWeightInvariantProperties; + } + + private: + F filter_; // Underlying filter + FilterState f_; // Current filter state + + void operator=(const PushWeightsComposeFilter<F, M1, M2, MT> &); // disallow +}; + +// This filter adds label-pushing to a lookahead composition filter +// using the LookAheadPrefix() method of the matcher argument. It is +// templated on an underlying filter, typically the basic lookahead +// or weight-pushing lookahead filter. Label-pushing in composition +// matches labels as early as possible based on the lookahead +// information. +template <class F, + class M1 = LookAheadMatcher<typename F::FST1>, + class M2 = M1, + MatchType MT = MATCH_BOTH> +class PushLabelsComposeFilter { + public: + typedef typename F::FST1 FST1; + typedef typename F::FST2 FST2; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + typedef MultiEpsMatcher<typename F::Matcher1> Matcher1; + typedef MultiEpsMatcher<typename F::Matcher2> Matcher2; + typedef typename F::FilterState FilterState1; + typedef IntegerFilterState<typename Arc::Label> FilterState2; + typedef PairFilterState<FilterState1, FilterState2> FilterState; + + PushLabelsComposeFilter(const FST1 &fst1, const FST2 &fst2, + M1 *matcher1, M2 *matcher2) + : filter_(fst1, fst2, matcher1, matcher2), + f_(FilterState::NoState()), + fst1_(filter_.GetMatcher1()->GetFst()), + fst2_(filter_.GetMatcher2()->GetFst()), + matcher1_(fst1_, MATCH_OUTPUT, + filter_.LookAheadOutput() ? kMultiEpsList : kMultiEpsLoop, + filter_.GetMatcher1(), + false), + matcher2_(fst2_, MATCH_INPUT, + filter_.LookAheadOutput() ? kMultiEpsLoop : kMultiEpsList, + filter_.GetMatcher2(), + false) {} + + PushLabelsComposeFilter(const PushLabelsComposeFilter<F, M1, M2, MT> &filter, + bool safe = false) + : filter_(filter.filter_, safe), + f_(FilterState::NoState()), + fst1_(filter_.GetMatcher1()->GetFst()), + fst2_(filter_.GetMatcher2()->GetFst()), + matcher1_(fst1_, MATCH_OUTPUT, + filter_.LookAheadOutput() ? kMultiEpsList : kMultiEpsLoop, + filter_.GetMatcher1(), + false), + matcher2_(fst2_, MATCH_INPUT, + filter_.LookAheadOutput() ? kMultiEpsLoop : kMultiEpsList, + filter_.GetMatcher2(), + false) { + } + + FilterState Start() const { + return FilterState(filter_.Start(), FilterState2(kNoLabel)); + } + + void SetState(StateId s1, StateId s2, const FilterState &f) { + f_ = f; + filter_.SetState(s1, s2, f.GetState1()); + if (!(LookAheadFlags() & kLookAheadPrefix)) + return; + + narcsa_ = LookAheadOutput() ? internal::NumArcs(fst1_, s1) + : internal::NumArcs(fst2_, s2); + + const FilterState2 &f2 = f_.GetState2(); + const Label &flabel = f2.GetState(); + + GetMatcher1()->ClearMultiEpsLabels(); + GetMatcher2()->ClearMultiEpsLabels(); + if (flabel != kNoLabel) { // Have a lookahead label? + GetMatcher1()->AddMultiEpsLabel(flabel); // Yes, make it a multi-epsilon + GetMatcher2()->AddMultiEpsLabel(flabel); // label so that it matches the + } // implicit epsilon arc to be + } // modified below when pushing. + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + if (!(LookAheadFlags() & kLookAheadPrefix)) + return FilterState(filter_.FilterArc(arc1, arc2), + FilterState2(kNoLabel)); + + const FilterState2 &f2 = f_.GetState2(); + const Label &flabel = f2.GetState(); + if (flabel != kNoLabel) // Have a lookahead label? + return LookAheadOutput() ? PushedLabelFilterArc(arc1, arc2, flabel) : + PushedLabelFilterArc(arc2, arc1, flabel); + + const FilterState1 &f1 = filter_.FilterArc(arc1, arc2); + if (f1 == FilterState1::NoState()) + return FilterState::NoState(); + + if (!filter_.LookAheadArc()) + return FilterState(f1, FilterState2(kNoLabel)); + + return LookAheadOutput() ? PushLabelFilterArc(arc1, arc2, f1) : + PushLabelFilterArc(arc2, arc1, f1); + } + + void FilterFinal(Weight *weight1, Weight *weight2) const { + filter_.FilterFinal(weight1, weight2); + if (!(LookAheadFlags() & kLookAheadPrefix) || + *weight1 == Weight::Zero()) + return; + + const FilterState2 &f2 = f_.GetState2(); + const Label &flabel = f2.GetState(); + if (flabel != kNoLabel) + *weight1 = Weight::Zero(); + } + + // Return resp matchers. Ownership states with filter. + Matcher1 *GetMatcher1() { return &matcher1_; } + Matcher2 *GetMatcher2() { return &matcher2_; } + + uint64 Properties(uint64 iprops) const { + uint64 oprops = filter_.Properties(iprops); + if (LookAheadOutput()) + return oprops & kOLabelInvariantProperties; + else + return oprops & kILabelInvariantProperties; + } + + private: + const LookAheadSelector<typename F::Matcher1, typename F::Matcher2, MT> + &Selector() const { + return filter_.Selector(); + } + + // Consumes an already pushed label. + FilterState PushedLabelFilterArc(Arc *arca, Arc *arcb, + Label flabel) const { + Label &labela = LookAheadOutput() ? arca->olabel : arca->ilabel; + const Label &labelb = LookAheadOutput() ? arcb->ilabel : arcb->olabel; + + if (labelb != kNoLabel) { + return FilterState::NoState(); // Block non- (multi-) epsilon label + } else if (labela == flabel) { + labela = 0; // Convert match to multi-eps to eps + return Start(); + } else if (labela == 0) { + if (narcsa_ == 1) + return f_; // Take eps; keep state w/ label + Selector().GetMatcher()->SetState(arca->nextstate); + if (Selector().GetMatcher()->LookAheadLabel(flabel)) + return f_; // Take eps; keep state w/ label + else + return FilterState::NoState(); // Block non-coaccessible path + } else { + return FilterState::NoState(); // Block mismatch to multi-eps label + } + } + + // Pushes a label forward when possible. + FilterState PushLabelFilterArc(Arc *arca, Arc *arcb, + const FilterState1 &f1) const { + Label &labela = LookAheadOutput() ? arca->olabel : arca->ilabel; + const Label &labelb = LookAheadOutput() ? arcb->olabel : arcb->ilabel; + + if (labelb != 0) // No place to push. + return FilterState(f1, FilterState2(kNoLabel)); + if (labela != 0 && // Wrong lookahead prefix type? + LookAheadFlags() & kLookAheadNonEpsilonPrefix) + return FilterState(f1, FilterState2(kNoLabel)); + + Arc larc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); + + if (Selector().GetMatcher()->LookAheadPrefix(&larc)) { // Have prefix arc? + labela = LookAheadOutput() ? larc.ilabel : larc.olabel; + arcb->ilabel = larc.ilabel; // Yes, go forward on that arc, + arcb->olabel = larc.olabel; // thus pushing the label. + arcb->weight = Times(arcb->weight, larc.weight); + arcb->nextstate = larc.nextstate; + return FilterState(f1, FilterState2(labela)); + } else { + return FilterState(f1, FilterState2(kNoLabel)); + } + } + + uint32 LookAheadFlags() const { return filter_.LookAheadFlags(); } + bool LookAheadArc() const { return filter_.LookAheadArc(); } + bool LookAheadOutput() const { return filter_.LookAheadOutput(); } + + F filter_; // Underlying filter + FilterState f_ ; // Current filter state + const FST1 &fst1_; + const FST2 &fst2_; + Matcher1 matcher1_; // Multi-epsilon matcher for fst1 + Matcher2 matcher2_; // Multi-epsilon matcher for fst2 + ssize_t narcsa_; // Number of arcs leaving look-ahead match FST + + void operator=(const PushLabelsComposeFilter<F, M1, M2, MT> &); // disallow +}; + +// +// CONVENIENCE CLASS useful for setting up composition with a default +// look-ahead matcher and filter. +// + +template <class A, MatchType type> // MATCH_NONE +class DefaultLookAhead { + public: + typedef Matcher< Fst<A> > M; + typedef SequenceComposeFilter<M> ComposeFilter; + typedef M FstMatcher; +}; + +// Specializes for MATCH_INPUT to allow lookahead. +template <class A> +class DefaultLookAhead<A, MATCH_INPUT> { + public: + typedef LookAheadMatcher< Fst<A> > M; + typedef SequenceComposeFilter<M> SF; + typedef LookAheadComposeFilter<SF, M> ComposeFilter; + typedef M FstMatcher; +}; + +// Specializes for MATCH_OUTPUT to allow lookahead. +template <class A> +class DefaultLookAhead<A, MATCH_OUTPUT> { + public: + typedef LookAheadMatcher< Fst<A> > M; + typedef AltSequenceComposeFilter<M> SF; + typedef LookAheadComposeFilter<SF, M> ComposeFilter; + typedef M FstMatcher; +}; + +// Specializes for StdArc to allow weight and label pushing. +template <> +class DefaultLookAhead<StdArc, MATCH_INPUT> { + public: + typedef StdArc A; + typedef LookAheadMatcher< Fst<A> > M; + typedef SequenceComposeFilter<M> SF; + typedef LookAheadComposeFilter<SF, M> LF; + typedef PushWeightsComposeFilter<LF, M> WF; + typedef PushLabelsComposeFilter<WF, M> ComposeFilter; + typedef M FstMatcher; +}; + +// Specializes for StdArc to allow weight and label pushing. +template <> +class DefaultLookAhead<StdArc, MATCH_OUTPUT> { + public: + typedef StdArc A; + typedef LookAheadMatcher< Fst<A> > M; + typedef AltSequenceComposeFilter<M> SF; + typedef LookAheadComposeFilter<SF, M> LF; + typedef PushWeightsComposeFilter<LF, M> WF; + typedef PushLabelsComposeFilter<WF, M> ComposeFilter; + typedef M FstMatcher; +}; + +// Specializes for LogArc to allow weight and label pushing. +template <> +class DefaultLookAhead<LogArc, MATCH_INPUT> { + public: + typedef LogArc A; + typedef LookAheadMatcher< Fst<A> > M; + typedef SequenceComposeFilter<M> SF; + typedef LookAheadComposeFilter<SF, M> LF; + typedef PushWeightsComposeFilter<LF, M> WF; + typedef PushLabelsComposeFilter<WF, M> ComposeFilter; + typedef M FstMatcher; +}; + +// Specializes for LogArc to allow weight and label pushing. +template <> +class DefaultLookAhead<LogArc, MATCH_OUTPUT> { + public: + typedef LogArc A; + typedef LookAheadMatcher< Fst<A> > M; + typedef AltSequenceComposeFilter<M> SF; + typedef LookAheadComposeFilter<SF, M> LF; + typedef PushWeightsComposeFilter<LF, M> WF; + typedef PushLabelsComposeFilter<WF, M> ComposeFilter; + typedef M FstMatcher; +}; + +} // namespace fst + +#endif // FST_LIB_LOOKAHEAD_FILTER_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h b/kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h new file mode 100644 index 0000000..f927d65 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/lookahead-matcher.h @@ -0,0 +1,812 @@ +// lookahead-matcher.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes to add lookahead to FST matchers, useful e.g. for improving +// composition efficiency with certain inputs. + +#ifndef FST_LIB_LOOKAHEAD_MATCHER_H__ +#define FST_LIB_LOOKAHEAD_MATCHER_H__ + +#include <fst/add-on.h> +#include <fst/const-fst.h> +#include <fst/fst.h> +#include <fst/label-reachable.h> +#include <fst/matcher.h> + + +DECLARE_string(save_relabel_ipairs); +DECLARE_string(save_relabel_opairs); + +namespace fst { + +// LOOKAHEAD MATCHERS - these have the interface of Matchers (see +// matcher.h) and these additional methods: +// +// template <class F> +// class LookAheadMatcher { +// public: +// typedef F FST; +// typedef F::Arc Arc; +// typedef typename Arc::StateId StateId; +// typedef typename Arc::Label Label; +// typedef typename Arc::Weight Weight; +// +// // Required constructors. +// LookAheadMatcher(const F &fst, MatchType match_type); +// // If safe=true, the copy is thread-safe (except the lookahead Fst is +// // preserved). See Fst<>::Cop() for further doc. +// LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false); +// +// Below are methods for looking ahead for a match to a label and +// more generally, to a rational set. Each returns false if there is +// definitely not a match and returns true if there possibly is a +// match. + +// // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state +// // after possibly following epsilon transitions? +// bool LookAheadLabel(Label label) const; +// +// // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an +// // arbitrary rational set of strings, specified by an FST and a state +// // from which to begin the matching. If the lookahead FST is a +// // transducer, this looks on the side different from the matcher +// // 'match_type' (cf. composition). +// +// // Are there paths P from 's' in the lookahead FST that can be read from +// // the cur. matcher state? +// bool LookAheadFst(const Fst<Arc>& fst, StateId s); +// +// // Gives an estimate of the combined weight of the paths P in the +// // lookahead and matcher FSTs for the last call to LookAheadFst. +// // A trivial implementation returns Weight::One(). Non-trivial +// // implementations are useful for weight-pushing in composition. +// Weight LookAheadWeight() const; +// +// // Is there is a single non-epsilon arc found in the lookahead FST +// // that begins P (after possibly following any epsilons) in the last +// // call LookAheadFst? If so, return true and copy it to '*arc', o.w. +// // return false. A trivial implementation returns false. Non-trivial +// // implementations are useful for label-pushing in composition. +// bool LookAheadPrefix(Arc *arc); +// +// // Optionally pre-specifies the lookahead FST that will be passed +// // to LookAheadFst() for possible precomputation. If copy is true, +// // then 'fst' is a copy of the FST used in the previous call to +// // this method (useful to avoid unnecessary updates). +// void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false); +// +// }; + +// +// LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h): +// +// Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT. +const uint32 kInputLookAheadMatcher = 0x00000010; + +// Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT. +const uint32 kOutputLookAheadMatcher = 0x00000020; + +// A non-trivial implementation of LookAheadWeight() method defined and +// should be used? +const uint32 kLookAheadWeight = 0x00000040; + +// A non-trivial implementation of LookAheadPrefix() method defined and +// should be used? +const uint32 kLookAheadPrefix = 0x00000080; + +// Look-ahead of matcher FST non-epsilon arcs? +const uint32 kLookAheadNonEpsilons = 0x00000100; + +// Look-ahead of matcher FST epsilon arcs? +const uint32 kLookAheadEpsilons = 0x00000200; + +// Ignore epsilon paths for the lookahead prefix? Note this gives +// correct results in composition only with an appropriate composition +// filter since it depends on the filter blocking the ignored paths. +const uint32 kLookAheadNonEpsilonPrefix = 0x00000400; + +// For LabelLookAheadMatcher, save relabeling data to file +const uint32 kLookAheadKeepRelabelData = 0x00000800; + +// Flags used for lookahead matchers. +const uint32 kLookAheadFlags = 0x00000ff0; + +// LookAhead Matcher interface, templated on the Arc definition; used +// for lookahead matcher specializations that are returned by the +// InitMatcher() Fst method. +template <class A> +class LookAheadMatcherBase : public MatcherBase<A> { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + LookAheadMatcherBase() + : weight_(Weight::One()), + prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {} + + virtual ~LookAheadMatcherBase() {} + + bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); } + + bool LookAheadFst(const Fst<Arc> &fst, StateId s) { + return LookAheadFst_(fst, s); + } + + Weight LookAheadWeight() const { return weight_; } + + bool LookAheadPrefix(Arc *arc) const { + if (prefix_arc_.nextstate != kNoStateId) { + *arc = prefix_arc_; + return true; + } else { + return false; + } + } + + virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0; + + protected: + void SetLookAheadWeight(const Weight &w) { weight_ = w; } + + void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; } + + void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; } + + private: + virtual bool LookAheadLabel_(Label label) const = 0; + virtual bool LookAheadFst_(const Fst<Arc> &fst, + StateId s) = 0; // This must set l.a. weight and + // prefix if non-trivial. + Weight weight_; // Look-ahead weight + Arc prefix_arc_; // Look-ahead prefix arc +}; + + +// Don't really lookahead, just declare future looks good regardless. +template <class M> +class TrivialLookAheadMatcher + : public LookAheadMatcherBase<typename M::FST::Arc> { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + TrivialLookAheadMatcher(const FST &fst, MatchType match_type) + : matcher_(fst, match_type) {} + + TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher, + bool safe = false) + : matcher_(lmatcher.matcher_, safe) {} + + // General matcher methods + TrivialLookAheadMatcher<M> *Copy(bool safe = false) const { + return new TrivialLookAheadMatcher<M>(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + void SetState(StateId s) { return matcher_.SetState(s); } + bool Find(Label label) { return matcher_.Find(label); } + bool Done() const { return matcher_.Done(); } + const Arc& Value() const { return matcher_.Value(); } + void Next() { matcher_.Next(); } + virtual const FST &GetFst() const { return matcher_.GetFst(); } + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + uint32 Flags() const { + return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher; + } + + // Look-ahead methods. + bool LookAheadLabel(Label label) const { return true; } + bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; } + Weight LookAheadWeight() const { return Weight::One(); } + bool LookAheadPrefix(Arc *arc) const { return false; } + void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {} + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } + + bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { + return LookAheadFst(fst, s); + } + + Weight LookAheadWeight_() const { return LookAheadWeight(); } + bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); } + + M matcher_; +}; + +// Look-ahead of one transition. Template argument F accepts flags to +// control behavior. +template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons | + kLookAheadWeight | kLookAheadPrefix> +class ArcLookAheadMatcher + : public LookAheadMatcherBase<typename M::FST::Arc> { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef NullAddOn MatcherData; + + using LookAheadMatcherBase<Arc>::LookAheadWeight; + using LookAheadMatcherBase<Arc>::SetLookAheadPrefix; + using LookAheadMatcherBase<Arc>::SetLookAheadWeight; + using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix; + + ArcLookAheadMatcher(const FST &fst, MatchType match_type, + MatcherData *data = 0) + : matcher_(fst, match_type), + fst_(matcher_.GetFst()), + lfst_(0), + s_(kNoStateId) {} + + ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher, + bool safe = false) + : matcher_(lmatcher.matcher_, safe), + fst_(matcher_.GetFst()), + lfst_(lmatcher.lfst_), + s_(kNoStateId) {} + + // General matcher methods + ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const { + return new ArcLookAheadMatcher<M, F>(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + + void SetState(StateId s) { + s_ = s; + matcher_.SetState(s); + } + + bool Find(Label label) { return matcher_.Find(label); } + bool Done() const { return matcher_.Done(); } + const Arc& Value() const { return matcher_.Value(); } + void Next() { matcher_.Next(); } + const FST &GetFst() const { return fst_; } + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + uint32 Flags() const { + return matcher_.Flags() | kInputLookAheadMatcher | + kOutputLookAheadMatcher | F; + } + + // Writable matcher methods + MatcherData *GetData() const { return 0; } + + // Look-ahead methods. + bool LookAheadLabel(Label label) const { return matcher_.Find(label); } + + // Checks if there is a matching (possibly super-final) transition + // at (s_, s). + bool LookAheadFst(const Fst<Arc> &fst, StateId s); + + void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { + lfst_ = &fst; + } + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } + bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { + return LookAheadFst(fst, s); + } + + mutable M matcher_; + const FST &fst_; // Matcher FST + const Fst<Arc> *lfst_; // Look-ahead FST + StateId s_; // Matcher state +}; + +template <class M, uint32 F> +bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) { + if (&fst != lfst_) + InitLookAheadFst(fst); + + bool ret = false; + ssize_t nprefix = 0; + if (F & kLookAheadWeight) + SetLookAheadWeight(Weight::Zero()); + if (F & kLookAheadPrefix) + ClearLookAheadPrefix(); + if (fst_.Final(s_) != Weight::Zero() && + lfst_->Final(s) != Weight::Zero()) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + ++nprefix; + if (F & kLookAheadWeight) + SetLookAheadWeight(Plus(LookAheadWeight(), + Times(fst_.Final(s_), lfst_->Final(s)))); + ret = true; + } + if (matcher_.Find(kNoLabel)) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + ++nprefix; + if (F & kLookAheadWeight) + for (; !matcher_.Done(); matcher_.Next()) + SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight)); + ret = true; + } + for (ArcIterator< Fst<Arc> > aiter(*lfst_, s); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + Label label = kNoLabel; + switch (matcher_.Type(false)) { + case MATCH_INPUT: + label = arc.olabel; + break; + case MATCH_OUTPUT: + label = arc.ilabel; + break; + default: + FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type"; + return true; + } + if (label == 0) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + if (!(F & kLookAheadNonEpsilonPrefix)) + ++nprefix; + if (F & kLookAheadWeight) + SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight)); + ret = true; + } else if (matcher_.Find(label)) { + if (!(F & (kLookAheadWeight | kLookAheadPrefix))) + return true; + for (; !matcher_.Done(); matcher_.Next()) { + ++nprefix; + if (F & kLookAheadWeight) + SetLookAheadWeight(Plus(LookAheadWeight(), + Times(arc.weight, + matcher_.Value().weight))); + if ((F & kLookAheadPrefix) && nprefix == 1) + SetLookAheadPrefix(arc); + } + ret = true; + } + } + if (F & kLookAheadPrefix) { + if (nprefix == 1) + SetLookAheadWeight(Weight::One()); // Avoids double counting. + else + ClearLookAheadPrefix(); + } + return ret; +} + + +// Template argument F accepts flags to control behavior. +// It must include precisely one of KInputLookAheadMatcher or +// KOutputLookAheadMatcher. +template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight | + kLookAheadPrefix | kLookAheadNonEpsilonPrefix | + kLookAheadKeepRelabelData, + class S = DefaultAccumulator<typename M::Arc> > +class LabelLookAheadMatcher + : public LookAheadMatcherBase<typename M::FST::Arc> { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef LabelReachableData<Label> MatcherData; + + using LookAheadMatcherBase<Arc>::LookAheadWeight; + using LookAheadMatcherBase<Arc>::SetLookAheadPrefix; + using LookAheadMatcherBase<Arc>::SetLookAheadWeight; + using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix; + + LabelLookAheadMatcher(const FST &fst, MatchType match_type, + MatcherData *data = 0, S *s = 0) + : matcher_(fst, match_type), + lfst_(0), + label_reachable_(0), + s_(kNoStateId), + error_(false) { + if (!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) { + FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F; + error_ = true; + } + bool reach_input = match_type == MATCH_INPUT; + if (data) { + if (reach_input == data->ReachInput()) + label_reachable_ = new LabelReachable<Arc, S>(data, s); + } else if ((reach_input && (F & kInputLookAheadMatcher)) || + (!reach_input && (F & kOutputLookAheadMatcher))) { + label_reachable_ = new LabelReachable<Arc, S>( + fst, reach_input, s, F & kLookAheadKeepRelabelData); + } + } + + LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher, + bool safe = false) + : matcher_(lmatcher.matcher_, safe), + lfst_(lmatcher.lfst_), + label_reachable_( + lmatcher.label_reachable_ ? + new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0), + s_(kNoStateId), + error_(lmatcher.error_) {} + + ~LabelLookAheadMatcher() { + delete label_reachable_; + } + + // General matcher methods + LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const { + return new LabelLookAheadMatcher<M, F, S>(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + + void SetState(StateId s) { + if (s_ == s) + return; + s_ = s; + match_set_state_ = false; + reach_set_state_ = false; + } + + bool Find(Label label) { + if (!match_set_state_) { + matcher_.SetState(s_); + match_set_state_ = true; + } + return matcher_.Find(label); + } + + bool Done() const { return matcher_.Done(); } + const Arc& Value() const { return matcher_.Value(); } + void Next() { matcher_.Next(); } + const FST &GetFst() const { return matcher_.GetFst(); } + + uint64 Properties(uint64 inprops) const { + uint64 outprops = matcher_.Properties(inprops); + if (error_ || (label_reachable_ && label_reachable_->Error())) + outprops |= kError; + return outprops; + } + + uint32 Flags() const { + if (label_reachable_ && label_reachable_->GetData()->ReachInput()) + return matcher_.Flags() | F | kInputLookAheadMatcher; + else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) + return matcher_.Flags() | F | kOutputLookAheadMatcher; + else + return matcher_.Flags(); + } + + // Writable matcher methods + MatcherData *GetData() const { + return label_reachable_ ? label_reachable_->GetData() : 0; + }; + + // Look-ahead methods. + bool LookAheadLabel(Label label) const { + if (label == 0) + return true; + + if (label_reachable_) { + if (!reach_set_state_) { + label_reachable_->SetState(s_); + reach_set_state_ = true; + } + return label_reachable_->Reach(label); + } else { + return true; + } + } + + // Checks if there is a matching (possibly super-final) transition + // at (s_, s). + template <class L> + bool LookAheadFst(const L &fst, StateId s); + + void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { + lfst_ = &fst; + if (label_reachable_) + label_reachable_->ReachInit(fst, copy); + } + + template <class L> + void InitLookAheadFst(const L& fst, bool copy = false) { + lfst_ = static_cast<const Fst<Arc> *>(&fst); + if (label_reachable_) + label_reachable_->ReachInit(fst, copy); + } + + private: + // This allows base class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } + bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { + return LookAheadFst(fst, s); + } + + mutable M matcher_; + const Fst<Arc> *lfst_; // Look-ahead FST + LabelReachable<Arc, S> *label_reachable_; // Label reachability info + StateId s_; // Matcher state + bool match_set_state_; // matcher_.SetState called? + mutable bool reach_set_state_; // reachable_.SetState called? + bool error_; +}; + +template <class M, uint32 F, class S> +template <class L> inline +bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) { + if (static_cast<const Fst<Arc> *>(&fst) != lfst_) + InitLookAheadFst(fst); + + SetLookAheadWeight(Weight::One()); + ClearLookAheadPrefix(); + + if (!label_reachable_) + return true; + + label_reachable_->SetState(s_, s); + reach_set_state_ = true; + + bool compute_weight = F & kLookAheadWeight; + bool compute_prefix = F & kLookAheadPrefix; + + bool reach_input = Type(false) == MATCH_OUTPUT; + ArcIterator<L> aiter(fst, s); + bool reach_arc = label_reachable_->Reach(&aiter, 0, + internal::NumArcs(*lfst_, s), + reach_input, compute_weight); + Weight lfinal = internal::Final(*lfst_, s); + bool reach_final = lfinal != Weight::Zero() && label_reachable_->ReachFinal(); + if (reach_arc) { + ssize_t begin = label_reachable_->ReachBegin(); + ssize_t end = label_reachable_->ReachEnd(); + if (compute_prefix && end - begin == 1 && !reach_final) { + aiter.Seek(begin); + SetLookAheadPrefix(aiter.Value()); + compute_weight = false; + } else if (compute_weight) { + SetLookAheadWeight(label_reachable_->ReachWeight()); + } + } + if (reach_final && compute_weight) + SetLookAheadWeight(reach_arc ? + Plus(LookAheadWeight(), lfinal) : lfinal); + + return reach_arc || reach_final; +} + + +// Label-lookahead relabeling class. +template <class A> +class LabelLookAheadRelabeler { + public: + typedef typename A::Label Label; + typedef LabelReachableData<Label> MatcherData; + typedef AddOnPair<MatcherData, MatcherData> D; + + // Relabels matcher Fst - initialization function object. + template <typename I> + LabelLookAheadRelabeler(I **impl); + + // Relabels arbitrary Fst. Class L should be a label-lookahead Fst. + template <class L> + static void Relabel(MutableFst<A> *fst, const L &mfst, + bool relabel_input) { + typename L::Impl *impl = mfst.GetImpl(); + D *data = impl->GetAddOn(); + LabelReachable<A> reachable(data->First() ? + data->First() : data->Second()); + reachable.Relabel(fst, relabel_input); + } + + // Returns relabeling pairs (cf. relabel.h::Relabel()). + // Class L should be a label-lookahead Fst. + // If 'avoid_collisions' is true, extra pairs are added to + // ensure no collisions when relabeling automata that have + // labels unseen here. + template <class L> + static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs, + bool avoid_collisions = false) { + typename L::Impl *impl = mfst.GetImpl(); + D *data = impl->GetAddOn(); + LabelReachable<A> reachable(data->First() ? + data->First() : data->Second()); + reachable.RelabelPairs(pairs, avoid_collisions); + } +}; + +template <class A> +template <typename I> inline +LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) { + Fst<A> &fst = (*impl)->GetFst(); + D *data = (*impl)->GetAddOn(); + const string name = (*impl)->Type(); + bool is_mutable = fst.Properties(kMutable, false); + MutableFst<A> *mfst = 0; + if (is_mutable) { + mfst = static_cast<MutableFst<A> *>(&fst); + } else { + mfst = new VectorFst<A>(fst); + data->IncrRefCount(); + delete *impl; + } + if (data->First()) { // reach_input + LabelReachable<A> reachable(data->First()); + reachable.Relabel(mfst, true); + if (!FLAGS_save_relabel_ipairs.empty()) { + vector<pair<Label, Label> > pairs; + reachable.RelabelPairs(&pairs, true); + WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs); + } + } else { + LabelReachable<A> reachable(data->Second()); + reachable.Relabel(mfst, false); + if (!FLAGS_save_relabel_opairs.empty()) { + vector<pair<Label, Label> > pairs; + reachable.RelabelPairs(&pairs, true); + WriteLabelPairs(FLAGS_save_relabel_opairs, pairs); + } + } + if (!is_mutable) { + *impl = new I(*mfst, name); + (*impl)->SetAddOn(data); + delete mfst; + data->DecrRefCount(); + } +} + + +// Generic lookahead matcher, templated on the FST definition +// - a wrapper around pointer to specific one. +template <class F> +class LookAheadMatcher { + public: + typedef F FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef LookAheadMatcherBase<Arc> LBase; + + LookAheadMatcher(const F &fst, MatchType match_type) { + base_ = fst.InitMatcher(match_type); + if (!base_) + base_ = new SortedMatcher<F>(fst, match_type); + lookahead_ = false; + } + + LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) { + base_ = matcher.base_->Copy(safe); + lookahead_ = matcher.lookahead_; + } + + ~LookAheadMatcher() { delete base_; } + + // General matcher methods + LookAheadMatcher<F> *Copy(bool safe = false) const { + return new LookAheadMatcher<F>(*this, safe); + } + + MatchType Type(bool test) const { return base_->Type(test); } + void SetState(StateId s) { base_->SetState(s); } + bool Find(Label label) { return base_->Find(label); } + bool Done() const { return base_->Done(); } + const Arc& Value() const { return base_->Value(); } + void Next() { base_->Next(); } + const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); } + + uint64 Properties(uint64 props) const { return base_->Properties(props); } + + uint32 Flags() const { return base_->Flags(); } + + // Look-ahead methods + bool LookAheadLabel(Label label) const { + if (LookAheadCheck()) { + LBase *lbase = static_cast<LBase *>(base_); + return lbase->LookAheadLabel(label); + } else { + return true; + } + } + + bool LookAheadFst(const Fst<Arc> &fst, StateId s) { + if (LookAheadCheck()) { + LBase *lbase = static_cast<LBase *>(base_); + return lbase->LookAheadFst(fst, s); + } else { + return true; + } + } + + Weight LookAheadWeight() const { + if (LookAheadCheck()) { + LBase *lbase = static_cast<LBase *>(base_); + return lbase->LookAheadWeight(); + } else { + return Weight::One(); + } + } + + bool LookAheadPrefix(Arc *arc) const { + if (LookAheadCheck()) { + LBase *lbase = static_cast<LBase *>(base_); + return lbase->LookAheadPrefix(arc); + } else { + return false; + } + } + + void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { + if (LookAheadCheck()) { + LBase *lbase = static_cast<LBase *>(base_); + lbase->InitLookAheadFst(fst, copy); + } + } + + private: + bool LookAheadCheck() const { + if (!lookahead_) { + lookahead_ = base_->Flags() & + (kInputLookAheadMatcher | kOutputLookAheadMatcher); + if (!lookahead_) { + FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined"; + } + } + return lookahead_; + } + + MatcherBase<Arc> *base_; + mutable bool lookahead_; + + void operator=(const LookAheadMatcher<Arc> &); // disallow +}; + +} // namespace fst + +#endif // FST_LIB_LOOKAHEAD_MATCHER_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/map.h b/kaldi_io/src/tools/openfst/include/fst/map.h new file mode 100644 index 0000000..419cac4 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/map.h @@ -0,0 +1,121 @@ +// map.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Compatability file for old-style Map() functions and MapFst class +// that have been renamed to ArcMap (cf. StateMap). + +#ifndef FST_LIB_MAP_H__ +#define FST_LIB_MAP_H__ + + +#include <fst/arc-map.h> + + +namespace fst { + +template<class A, class C> +void Map(MutableFst<A> *fst, C* mapper) { + ArcMap(fst, mapper); +} + +template<class A, class C> +void Map(MutableFst<A> *fst, C mapper) { + ArcMap(fst, mapper); +} + +template<class A, class B, class C> +void Map(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) { + ArcMap(ifst, ofst, mapper); +} + +template<class A, class B, class C> +void Map(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) { + ArcMap(ifst, ofst, mapper); +} + +typedef ArcMapFstOptions MapFstOptions; + +template <class A, class B, class C> +class MapFst : public ArcMapFst<A, B, C> { + public: + typedef B Arc; + typedef typename B::Weight Weight; + typedef typename B::StateId StateId; + typedef CacheState<B> State; + + MapFst(const Fst<A> &fst, const C &mapper, const MapFstOptions& opts) + : ArcMapFst<A, B, C>(fst, mapper, opts) {} + + MapFst(const Fst<A> &fst, C* mapper, const MapFstOptions& opts) + : ArcMapFst<A, B, C>(fst, mapper, opts) {} + + MapFst(const Fst<A> &fst, const C &mapper) + : ArcMapFst<A, B, C>(fst, mapper) {} + + MapFst(const Fst<A> &fst, C* mapper) : ArcMapFst<A, B, C>(fst, mapper) {} + + // See Fst<>::Copy() for doc. + MapFst(const ArcMapFst<A, B, C> &fst, bool safe = false) + : ArcMapFst<A, B, C>(fst, safe) {} + + // Get a copy of this MapFst. See Fst<>::Copy() for further doc. +virtual MapFst<A, B, C> *Copy(bool safe = false) const { + return new MapFst(*this, safe); + } +}; + + +// Specialization for MapFst. +template <class A, class B, class C> +class StateIterator< MapFst<A, B, C> > + : public StateIterator< ArcMapFst<A, B, C> > { + public: + explicit StateIterator(const ArcMapFst<A, B, C> &fst) + : StateIterator< ArcMapFst<A, B, C> >(fst) {} +}; + + +// Specialization for MapFst. +template <class A, class B, class C> +class ArcIterator< MapFst<A, B, C> > + : public ArcIterator< ArcMapFst<A, B, C> > { + public: + ArcIterator(const ArcMapFst<A, B, C> &fst, typename A::StateId s) + : ArcIterator< ArcMapFst<A, B, C> >(fst, s) {} +}; + + +template <class A> +struct IdentityMapper { + typedef A FromArc; + typedef A ToArc; + + A operator()(const A &arc) const { return arc; } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { return props; } +}; + +} // namespace fst + +#endif // FST_LIB_MAP_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/mapped-file.h b/kaldi_io/src/tools/openfst/include/fst/mapped-file.h new file mode 100644 index 0000000..d61bc14 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/mapped-file.h @@ -0,0 +1,83 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jeffrey Sorensen) + +#ifndef FST_LIB_MAPPED_FILE_H_ +#define FST_LIB_MAPPED_FILE_H_ + +#include <unistd.h> +#include <sys/mman.h> + +#include <fst/fst.h> +#include <iostream> +#include <fstream> +#include <sstream> + +DECLARE_int32(fst_arch_alignment); // defined in mapped-file.h + +namespace fst { + +// A memory region is a simple abstraction for allocated memory or data from +// mmap'ed files. If mmap equals NULL, then data represents an owned region of +// size bytes. Otherwise, mmap and size refer to the mapping and data is a +// casted pointer to a region contained within [mmap, mmap + size). +// If size is 0, then mmap refers and data refer to a block of memory managed +// externally by some other allocator. +struct MemoryRegion { + void *data; + void *mmap; + size_t size; +}; + +class MappedFile { + public: + virtual ~MappedFile(); + + void* mutable_data() const { + return reinterpret_cast<void*>(region_.data); + } + + const void* data() const { + return reinterpret_cast<void*>(region_.data); + } + + // Returns a MappedFile object that contains the contents of the input + // stream s starting from the current file position with size bytes. + // The file name must also be provided in the FstReadOptions as opts.source + // or else mapping will fail. If mapping is not possible, then a MappedFile + // object with a new[]'ed block of memory will be created. + static MappedFile* Map(istream* s, const FstReadOptions& opts, size_t size); + + // Creates a MappedFile object with a new[]'ed block of memory of size. + // RECOMMENDED FOR INTERNAL USE ONLY, may change in future releases. + static MappedFile* Allocate(size_t size); + + // Creates a MappedFile object pointing to a borrowed reference to data. + // This block of memory is not owned by the MappedFile object and will not + // be freed. + // RECOMMENDED FOR INTERNAL USE ONLY, may change in future releases. + static MappedFile* Borrow(void *data); + + static const int kArchAlignment; + + private: + explicit MappedFile(const MemoryRegion ®ion); + + MemoryRegion region_; + DISALLOW_COPY_AND_ASSIGN(MappedFile); +}; +} // namespace fst + +#endif // FST_LIB_MAPPED_FILE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/matcher-fst.h b/kaldi_io/src/tools/openfst/include/fst/matcher-fst.h new file mode 100644 index 0000000..73e64ad --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/matcher-fst.h @@ -0,0 +1,359 @@ +// matcher-fst.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to add a matcher to an FST. + +#ifndef FST_LIB_MATCHER_FST_FST_H__ +#define FST_LIB_MATCHER_FST_FST_H__ + +#include <fst/add-on.h> +#include <fst/const-fst.h> +#include <fst/lookahead-matcher.h> + + +namespace fst { + +// WRITABLE MATCHERS - these have the interface of Matchers (see +// matcher.h) and these additional methods: +// +// template <class F> +// class Matcher { +// public: +// typedef ... MatcherData; // Initialization data +// ... +// // Constructor with additional argument for external initialization +// // data; matcher increments its reference count on construction and +// // decrements the reference count, and if 0 deletes, on destruction. +// Matcher(const F &fst, MatchType type, MatcherData *data); +// +// // Returns pointer to initialization data that can be +// // passed to a Matcher constructor. +// MatcherData *GetData() const; +// }; + +// The matcher initialization data class must have the form: +// class MatcherData { +// public: +// // Required copy constructor. +// MatcherData(const MatcherData &); +// // +// // Required I/O methods. +// static MatcherData *Read(istream &istrm); +// bool Write(ostream &ostrm); +// +// // Required reference counting. +// int RefCount() const; +// int IncrRefCount(); +// int DecrRefCount(); +// }; + +// Default MatcherFst initializer - does nothing. +template <class M> +class NullMatcherFstInit { + public: + typedef AddOnPair<typename M::MatcherData, typename M::MatcherData> D; + typedef AddOnImpl<typename M::FST, D> Impl; + NullMatcherFstInit(Impl **) {} +}; + +// Class to add a matcher M to an Fst F. Creates a new Fst of type name N. +// Optional function object I can be used to initialize the Fst. +template <class F, class M, const char* N, + class I = NullMatcherFstInit<M> > +class MatcherFst + : public ImplToExpandedFst< + AddOnImpl<F, + AddOnPair<typename M::MatcherData, + typename M::MatcherData> > > { + public: + friend class StateIterator< MatcherFst<F, M, N, I> >; + friend class ArcIterator< MatcherFst<F, M, N, I> >; + + typedef F FST; + typedef M FstMatcher; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef AddOnPair<typename M::MatcherData, typename M::MatcherData> D; + typedef AddOnImpl<F, D> Impl; + + MatcherFst() : ImplToExpandedFst<Impl>(new Impl(F(), N)) {} + + explicit MatcherFst(const F &fst) + : ImplToExpandedFst<Impl>(CreateImpl(fst, N)) {} + + explicit MatcherFst(const Fst<Arc> &fst) + : ImplToExpandedFst<Impl>(CreateImpl(fst, N)) {} + + // See Fst<>::Copy() for doc. + MatcherFst(const MatcherFst<F, M, N, I> &fst, bool safe = false) + : ImplToExpandedFst<Impl>(fst, safe) {} + + // Get a copy of this MatcherFst. See Fst<>::Copy() for further doc. + virtual MatcherFst<F, M, N, I> *Copy(bool safe = false) const { + return new MatcherFst<F, M, N, I>(*this, safe); + } + + // Read a MatcherFst from an input stream; return NULL on error + static MatcherFst<F, M, N, I> *Read(istream &strm, + const FstReadOptions &opts) { + Impl *impl = Impl::Read(strm, opts); + return impl ? new MatcherFst<F, M, N, I>(impl) : 0; + } + + // Read a MatcherFst from a file; return NULL on error + // Empty filename reads from standard input + static MatcherFst<F, M, N, I> *Read(const string &filename) { + Impl *impl = ImplToExpandedFst<Impl>::Read(filename); + return impl ? new MatcherFst<F, M, N, I>(impl) : 0; + } + + virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { + return GetImpl()->Write(strm, opts); + } + + virtual bool Write(const string &filename) const { + return Fst<Arc>::WriteFile(filename); + } + + virtual void InitStateIterator(StateIteratorData<Arc> *data) const { + return GetImpl()->InitStateIterator(data); + } + + virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + return GetImpl()->InitArcIterator(s, data); + } + + virtual M *InitMatcher(MatchType match_type) const { + return new M(GetFst(), match_type, GetData(match_type)); + } + + // Allows access to MatcherFst components. + Impl *GetImpl() const { + return ImplToFst<Impl, ExpandedFst<Arc> >::GetImpl(); + } + + F& GetFst() const { return GetImpl()->GetFst(); } + + typename M::MatcherData *GetData(MatchType match_type) const { + D *data = GetImpl()->GetAddOn(); + return match_type == MATCH_INPUT ? data->First() : data->Second(); + } + + private: + static Impl *CreateImpl(const F &fst, const string &name) { + M imatcher(fst, MATCH_INPUT); + M omatcher(fst, MATCH_OUTPUT); + D *data = new D(imatcher.GetData(), omatcher.GetData()); + Impl *impl = new Impl(fst, name); + impl->SetAddOn(data); + I init(&impl); + data->DecrRefCount(); + return impl; + } + + static Impl *CreateImpl(const Fst<Arc> &fst, const string &name) { + F ffst(fst); + return CreateImpl(ffst, name); + } + + explicit MatcherFst(Impl *impl) : ImplToExpandedFst<Impl>(impl) {} + + // Makes visible to friends. + void SetImpl(Impl *impl, bool own_impl = true) { + ImplToFst< Impl, ExpandedFst<Arc> >::SetImpl(impl, own_impl); + } + + void operator=(const MatcherFst<F, M, N, I> &fst); // disallow +}; + + +// Specialization fo MatcherFst. +template <class F, class M, const char* N, class I> +class StateIterator< MatcherFst<F, M, N, I> > : public StateIterator<F> { + public: + explicit StateIterator(const MatcherFst<F, M, N, I> &fst) : + StateIterator<F>(fst.GetImpl()->GetFst()) {} +}; + + +// Specialization for MatcherFst. +template <class F, class M, const char* N, class I> +class ArcIterator< MatcherFst<F, M, N, I> > : public ArcIterator<F> { + public: + ArcIterator(const MatcherFst<F, M, N, I> &fst, typename F::Arc::StateId s) + : ArcIterator<F>(fst.GetImpl()->GetFst(), s) {} +}; + + +// Specialization for MatcherFst +template <class F, class M, const char* N, class I> +class Matcher< MatcherFst<F, M, N, I> > { + public: + typedef MatcherFst<F, M, N, I> FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + + Matcher(const FST &fst, MatchType match_type) { + matcher_ = fst.InitMatcher(match_type); + } + + Matcher(const Matcher<FST> &matcher) { + matcher_ = matcher.matcher_->Copy(); + } + + ~Matcher() { delete matcher_; } + + Matcher<FST> *Copy() const { + return new Matcher<FST>(*this); + } + + MatchType Type(bool test) const { return matcher_->Type(test); } + void SetState(StateId s) { matcher_->SetState(s); } + bool Find(Label label) { return matcher_->Find(label); } + bool Done() const { return matcher_->Done(); } + const Arc& Value() const { return matcher_->Value(); } + void Next() { matcher_->Next(); } + uint64 Properties(uint64 props) const { return matcher_->Properties(props); } + uint32 Flags() const { return matcher_->Flags(); } + + private: + M *matcher_; + + void operator=(const Matcher<Arc> &); // disallow +}; + + +// Specialization for MatcherFst +template <class F, class M, const char* N, class I> +class LookAheadMatcher< MatcherFst<F, M, N, I> > { + public: + typedef MatcherFst<F, M, N, I> FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + LookAheadMatcher(const FST &fst, MatchType match_type) { + matcher_ = fst.InitMatcher(match_type); + } + + LookAheadMatcher(const LookAheadMatcher<FST> &matcher, bool safe = false) { + matcher_ = matcher.matcher_->Copy(safe); + } + + ~LookAheadMatcher() { delete matcher_; } + + // General matcher methods + LookAheadMatcher<FST> *Copy(bool safe = false) const { + return new LookAheadMatcher<FST>(*this, safe); + } + + MatchType Type(bool test) const { return matcher_->Type(test); } + void SetState(StateId s) { matcher_->SetState(s); } + bool Find(Label label) { return matcher_->Find(label); } + bool Done() const { return matcher_->Done(); } + const Arc& Value() const { return matcher_->Value(); } + void Next() { matcher_->Next(); } + const FST &GetFst() const { return matcher_->GetFst(); } + uint64 Properties(uint64 props) const { return matcher_->Properties(props); } + uint32 Flags() const { return matcher_->Flags(); } + + // Look-ahead methods + bool LookAheadLabel(Label label) const { + return matcher_->LookAheadLabel(label); + } + + bool LookAheadFst(const Fst<Arc> &fst, StateId s) { + return matcher_->LookAheadFst(fst, s); + } + + Weight LookAheadWeight() const { return matcher_->LookAheadWeight(); } + + bool LookAheadPrefix(Arc *arc) const { + return matcher_->LookAheadPrefix(arc); + } + + void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { + matcher_->InitLookAheadFst(fst, copy); + } + + private: + M *matcher_; + + void operator=(const LookAheadMatcher<FST> &); // disallow +}; + +// +// Useful aliases when using StdArc and LogArc. +// + +// Arc look-ahead matchers +extern const char arc_lookahead_fst_type[]; + +typedef MatcherFst<ConstFst<StdArc>, + ArcLookAheadMatcher<SortedMatcher<ConstFst<StdArc> > >, + arc_lookahead_fst_type> StdArcLookAheadFst; + +typedef MatcherFst<ConstFst<LogArc>, + ArcLookAheadMatcher<SortedMatcher<ConstFst<LogArc> > >, + arc_lookahead_fst_type> LogArcLookAheadFst; + + +// Label look-ahead matchers +extern const char ilabel_lookahead_fst_type[]; +extern const char olabel_lookahead_fst_type[]; + +static const uint32 ilabel_lookahead_flags = kInputLookAheadMatcher | + kLookAheadWeight | kLookAheadPrefix | + kLookAheadEpsilons | kLookAheadNonEpsilonPrefix; +static const uint32 olabel_lookahead_flags = kOutputLookAheadMatcher | + kLookAheadWeight | kLookAheadPrefix | + kLookAheadEpsilons | kLookAheadNonEpsilonPrefix; + +typedef MatcherFst<ConstFst<StdArc>, + LabelLookAheadMatcher<SortedMatcher<ConstFst<StdArc> >, + ilabel_lookahead_flags, + FastLogAccumulator<StdArc> >, + ilabel_lookahead_fst_type, + LabelLookAheadRelabeler<StdArc> > StdILabelLookAheadFst; + +typedef MatcherFst<ConstFst<LogArc>, + LabelLookAheadMatcher<SortedMatcher<ConstFst<LogArc> >, + ilabel_lookahead_flags, + FastLogAccumulator<LogArc> >, + ilabel_lookahead_fst_type, + LabelLookAheadRelabeler<LogArc> > LogILabelLookAheadFst; + +typedef MatcherFst<ConstFst<StdArc>, + LabelLookAheadMatcher<SortedMatcher<ConstFst<StdArc> >, + olabel_lookahead_flags, + FastLogAccumulator<StdArc> >, + olabel_lookahead_fst_type, + LabelLookAheadRelabeler<StdArc> > StdOLabelLookAheadFst; + +typedef MatcherFst<ConstFst<LogArc>, + LabelLookAheadMatcher<SortedMatcher<ConstFst<LogArc> >, + olabel_lookahead_flags, + FastLogAccumulator<LogArc> >, + olabel_lookahead_fst_type, + LabelLookAheadRelabeler<LogArc> > LogOLabelLookAheadFst; + +} // namespace fst + +#endif // FST_LIB_MATCHER_FST_FST_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/matcher.h b/kaldi_io/src/tools/openfst/include/fst/matcher.h new file mode 100644 index 0000000..89ed9be --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/matcher.h @@ -0,0 +1,1205 @@ +// matcher.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes to allow matching labels leaving FST states. + +#ifndef FST_LIB_MATCHER_H__ +#define FST_LIB_MATCHER_H__ + +#include <algorithm> +#include <set> + +#include <fst/mutable-fst.h> // for all internal FST accessors + + +namespace fst { + +// MATCHERS - these can find and iterate through requested labels at +// FST states. In the simplest form, these are just some associative +// map or search keyed on labels. More generally, they may +// implement matching special labels that represent sets of labels +// such as 'sigma' (all), 'rho' (rest), or 'phi' (fail). +// The Matcher interface is: +// +// template <class F> +// class Matcher { +// public: +// typedef F FST; +// typedef F::Arc Arc; +// typedef typename Arc::StateId StateId; +// typedef typename Arc::Label Label; +// typedef typename Arc::Weight Weight; +// +// // Required constructors. +// Matcher(const F &fst, MatchType type); +// // If safe=true, the copy is thread-safe. See Fst<>::Copy() +// // for further doc. +// Matcher(const Matcher &matcher, bool safe = false); +// +// // If safe=true, the copy is thread-safe. See Fst<>::Copy() +// // for further doc. +// Matcher<F> *Copy(bool safe = false) const; +// +// // Returns the match type that can be provided (depending on +// // compatibility of the input FST). It is either +// // the requested match type, MATCH_NONE, or MATCH_UNKNOWN. +// // If 'test' is false, a constant time test is performed, but +// // MATCH_UNKNOWN may be returned. If 'test' is true, +// // a definite answer is returned, but may involve more costly +// // computation (e.g., visiting the Fst). +// MatchType Type(bool test) const; +// // Specifies the current state. +// void SetState(StateId s); +// +// // This finds matches to a label at the current state. +// // Returns true if a match found. kNoLabel matches any +// // 'non-consuming' transitions, e.g., epsilon transitions, +// // which do not require a matching symbol. +// bool Find(Label label); +// // These iterate through any matches found: +// bool Done() const; // No more matches. +// const A& Value() const; // Current arc (when !Done) +// void Next(); // Advance to next arc (when !Done) +// // Initially and after SetState() the iterator methods +// // have undefined behavior until Find() is called. +// +// // Return matcher FST. +// const F& GetFst() const; +// // This specifies the known Fst properties as viewed from this +// // matcher. It takes as argument the input Fst's known properties. +// uint64 Properties(uint64 props) const; +// }; + +// +// MATCHER FLAGS (see also kLookAheadFlags in lookahead-matcher.h) +// +// Matcher prefers being used as the matching side in composition. +const uint32 kPreferMatch = 0x00000001; + +// Matcher needs to be used as the matching side in composition. +const uint32 kRequireMatch = 0x00000002; + +// Flags used for basic matchers (see also lookahead.h). +const uint32 kMatcherFlags = kPreferMatch | kRequireMatch; + +// Matcher interface, templated on the Arc definition; used +// for matcher specializations that are returned by the +// InitMatcher Fst method. +template <class A> +class MatcherBase { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + virtual ~MatcherBase() {} + + virtual MatcherBase<A> *Copy(bool safe = false) const = 0; + virtual MatchType Type(bool test) const = 0; + void SetState(StateId s) { SetState_(s); } + bool Find(Label label) { return Find_(label); } + bool Done() const { return Done_(); } + const A& Value() const { return Value_(); } + void Next() { Next_(); } + virtual const Fst<A> &GetFst() const = 0; + virtual uint64 Properties(uint64 props) const = 0; + virtual uint32 Flags() const { return 0; } + private: + virtual void SetState_(StateId s) = 0; + virtual bool Find_(Label label) = 0; + virtual bool Done_() const = 0; + virtual const A& Value_() const = 0; + virtual void Next_() = 0; +}; + + +// A matcher that expects sorted labels on the side to be matched. +// If match_type == MATCH_INPUT, epsilons match the implicit self loop +// Arc(kNoLabel, 0, Weight::One(), current_state) as well as any +// actual epsilon transitions. If match_type == MATCH_OUTPUT, then +// Arc(0, kNoLabel, Weight::One(), current_state) is instead matched. +template <class F> +class SortedMatcher : public MatcherBase<typename F::Arc> { + public: + typedef F FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + // Labels >= binary_label will be searched for by binary search, + // o.w. linear search is used. + SortedMatcher(const F &fst, MatchType match_type, + Label binary_label = 1) + : fst_(fst.Copy()), + s_(kNoStateId), + aiter_(0), + match_type_(match_type), + binary_label_(binary_label), + match_label_(kNoLabel), + narcs_(0), + loop_(kNoLabel, 0, Weight::One(), kNoStateId), + error_(false) { + switch(match_type_) { + case MATCH_INPUT: + case MATCH_NONE: + break; + case MATCH_OUTPUT: + swap(loop_.ilabel, loop_.olabel); + break; + default: + FSTERROR() << "SortedMatcher: bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + } + + SortedMatcher(const SortedMatcher<F> &matcher, bool safe = false) + : fst_(matcher.fst_->Copy(safe)), + s_(kNoStateId), + aiter_(0), + match_type_(matcher.match_type_), + binary_label_(matcher.binary_label_), + match_label_(kNoLabel), + narcs_(0), + loop_(matcher.loop_), + error_(matcher.error_) {} + + virtual ~SortedMatcher() { + if (aiter_) + delete aiter_; + delete fst_; + } + + virtual SortedMatcher<F> *Copy(bool safe = false) const { + return new SortedMatcher<F>(*this, safe); + } + + virtual MatchType Type(bool test) const { + if (match_type_ == MATCH_NONE) + return match_type_; + + uint64 true_prop = match_type_ == MATCH_INPUT ? + kILabelSorted : kOLabelSorted; + uint64 false_prop = match_type_ == MATCH_INPUT ? + kNotILabelSorted : kNotOLabelSorted; + uint64 props = fst_->Properties(true_prop | false_prop, test); + + if (props & true_prop) + return match_type_; + else if (props & false_prop) + return MATCH_NONE; + else + return MATCH_UNKNOWN; + } + + void SetState(StateId s) { + if (s_ == s) + return; + s_ = s; + if (match_type_ == MATCH_NONE) { + FSTERROR() << "SortedMatcher: bad match type"; + error_ = true; + } + if (aiter_) + delete aiter_; + aiter_ = new ArcIterator<F>(*fst_, s); + aiter_->SetFlags(kArcNoCache, kArcNoCache); + narcs_ = internal::NumArcs(*fst_, s); + loop_.nextstate = s; + } + + bool Find(Label match_label) { + exact_match_ = true; + if (error_) { + current_loop_ = false; + match_label_ = kNoLabel; + return false; + } + current_loop_ = match_label == 0; + match_label_ = match_label == kNoLabel ? 0 : match_label; + if (Search()) { + return true; + } else { + return current_loop_; + } + } + + // Positions matcher to the first position where inserting + // match_label would maintain the sort order. + void LowerBound(Label match_label) { + exact_match_ = false; + current_loop_ = false; + if (error_) { + match_label_ = kNoLabel; + return; + } + match_label_ = match_label; + Search(); + } + + // After Find(), returns false if no more exact matches. + // After LowerBound(), returns false if no more arcs. + bool Done() const { + if (current_loop_) + return false; + if (aiter_->Done()) + return true; + if (!exact_match_) + return false; + aiter_->SetFlags( + match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + Label label = match_type_ == MATCH_INPUT ? + aiter_->Value().ilabel : aiter_->Value().olabel; + return label != match_label_; + } + + const Arc& Value() const { + if (current_loop_) { + return loop_; + } + aiter_->SetFlags(kArcValueFlags, kArcValueFlags); + return aiter_->Value(); + } + + void Next() { + if (current_loop_) + current_loop_ = false; + else + aiter_->Next(); + } + + virtual const F &GetFst() const { return *fst_; } + + virtual uint64 Properties(uint64 inprops) const { + uint64 outprops = inprops; + if (error_) outprops |= kError; + return outprops; + } + + size_t Position() const { return aiter_ ? aiter_->Position() : 0; } + + private: + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + bool Search(); + + const F *fst_; + StateId s_; // Current state + ArcIterator<F> *aiter_; // Iterator for current state + MatchType match_type_; // Type of match to perform + Label binary_label_; // Least label for binary search + Label match_label_; // Current label to be matched + size_t narcs_; // Current state arc count + Arc loop_; // For non-consuming symbols + bool current_loop_; // Current arc is the implicit loop + bool exact_match_; // Exact match or lower bound? + bool error_; // Error encountered + + void operator=(const SortedMatcher<F> &); // Disallow +}; + +// Returns true iff match to match_label_. Positions arc iterator at +// lower bound regardless. +template <class F> inline +bool SortedMatcher<F>::Search() { + aiter_->SetFlags( + match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + if (match_label_ >= binary_label_) { + // Binary search for match. + size_t low = 0; + size_t high = narcs_; + while (low < high) { + size_t mid = (low + high) / 2; + aiter_->Seek(mid); + Label label = match_type_ == MATCH_INPUT ? + aiter_->Value().ilabel : aiter_->Value().olabel; + if (label > match_label_) { + high = mid; + } else if (label < match_label_) { + low = mid + 1; + } else { + // find first matching label (when non-determinism) + for (size_t i = mid; i > low; --i) { + aiter_->Seek(i - 1); + label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel : + aiter_->Value().olabel; + if (label != match_label_) { + aiter_->Seek(i); + return true; + } + } + return true; + } + } + aiter_->Seek(low); + return false; + } else { + // Linear search for match. + for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) { + Label label = match_type_ == MATCH_INPUT ? + aiter_->Value().ilabel : aiter_->Value().olabel; + if (label == match_label_) { + return true; + } + if (label > match_label_) + break; + } + return false; + } +} + + +// Specifies whether during matching we rewrite both the input and output sides. +enum MatcherRewriteMode { + MATCHER_REWRITE_AUTO = 0, // Rewrites both sides iff acceptor. + MATCHER_REWRITE_ALWAYS, + MATCHER_REWRITE_NEVER +}; + + +// For any requested label that doesn't match at a state, this matcher +// considers all transitions that match the label 'rho_label' (rho = +// 'rest'). Each such rho transition found is returned with the +// rho_label rewritten as the requested label (both sides if an +// acceptor, or if 'rewrite_both' is true and both input and output +// labels of the found transition are 'rho_label'). If 'rho_label' is +// kNoLabel, this special matching is not done. RhoMatcher is +// templated itself on a matcher, which is used to perform the +// underlying matching. By default, the underlying matcher is +// constructed by RhoMatcher. The user can instead pass in this +// object; in that case, RhoMatcher takes its ownership. +template <class M> +class RhoMatcher : public MatcherBase<typename M::Arc> { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + RhoMatcher(const FST &fst, + MatchType match_type, + Label rho_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = 0) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + rho_label_(rho_label), + error_(false) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "RhoMatcher: bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + if (rho_label == 0) { + FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label"; + rho_label_ = kNoLabel; + error_ = true; + } + + if (rewrite_mode == MATCHER_REWRITE_AUTO) + rewrite_both_ = fst.Properties(kAcceptor, true); + else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) + rewrite_both_ = true; + else + rewrite_both_ = false; + } + + RhoMatcher(const RhoMatcher<M> &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + rho_label_(matcher.rho_label_), + rewrite_both_(matcher.rewrite_both_), + error_(matcher.error_) {} + + virtual ~RhoMatcher() { + delete matcher_; + } + + virtual RhoMatcher<M> *Copy(bool safe = false) const { + return new RhoMatcher<M>(*this, safe); + } + + virtual MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { + matcher_->SetState(s); + has_rho_ = rho_label_ != kNoLabel; + } + + bool Find(Label match_label) { + if (match_label == rho_label_ && rho_label_ != kNoLabel) { + FSTERROR() << "RhoMatcher::Find: bad label (rho)"; + error_ = true; + return false; + } + if (matcher_->Find(match_label)) { + rho_match_ = kNoLabel; + return true; + } else if (has_rho_ && match_label != 0 && match_label != kNoLabel && + (has_rho_ = matcher_->Find(rho_label_))) { + rho_match_ = match_label; + return true; + } else { + return false; + } + } + + bool Done() const { return matcher_->Done(); } + + const Arc& Value() const { + if (rho_match_ == kNoLabel) { + return matcher_->Value(); + } else { + rho_arc_ = matcher_->Value(); + if (rewrite_both_) { + if (rho_arc_.ilabel == rho_label_) + rho_arc_.ilabel = rho_match_; + if (rho_arc_.olabel == rho_label_) + rho_arc_.olabel = rho_match_; + } else if (match_type_ == MATCH_INPUT) { + rho_arc_.ilabel = rho_match_; + } else { + rho_arc_.olabel = rho_match_; + } + return rho_arc_; + } + } + + void Next() { matcher_->Next(); } + + virtual const FST &GetFst() const { return matcher_->GetFst(); } + + virtual uint64 Properties(uint64 props) const; + + virtual uint32 Flags() const { + if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) + return matcher_->Flags(); + return matcher_->Flags() | kRequireMatch; + } + + private: + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + M *matcher_; + MatchType match_type_; // Type of match requested + Label rho_label_; // Label that represents the rho transition + bool rewrite_both_; // Rewrite both sides when both are 'rho_label_' + bool has_rho_; // Are there possibly rhos at the current state? + Label rho_match_; // Current label that matches rho transition + mutable Arc rho_arc_; // Arc to return when rho match + bool error_; // Error encountered + + void operator=(const RhoMatcher<M> &); // Disallow +}; + +template <class M> inline +uint64 RhoMatcher<M>::Properties(uint64 inprops) const { + uint64 outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (match_type_ == MATCH_INPUT) { + if (rewrite_both_) { + return outprops & ~(kODeterministic | kNonODeterministic | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & ~(kODeterministic | kAcceptor | kString | + kILabelSorted | kNotILabelSorted); + } + } else if (match_type_ == MATCH_OUTPUT) { + if (rewrite_both_) { + return outprops & ~(kIDeterministic | kNonIDeterministic | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & ~(kIDeterministic | kAcceptor | kString | + kOLabelSorted | kNotOLabelSorted); + } + } else { + // Shouldn't ever get here. + FSTERROR() << "RhoMatcher:: bad match type: " << match_type_; + return 0; + } +} + + +// For any requested label, this matcher considers all transitions +// that match the label 'sigma_label' (sigma = "any"), and this in +// additions to transitions with the requested label. Each such sigma +// transition found is returned with the sigma_label rewritten as the +// requested label (both sides if an acceptor, or if 'rewrite_both' is +// true and both input and output labels of the found transition are +// 'sigma_label'). If 'sigma_label' is kNoLabel, this special +// matching is not done. SigmaMatcher is templated itself on a +// matcher, which is used to perform the underlying matching. By +// default, the underlying matcher is constructed by SigmaMatcher. +// The user can instead pass in this object; in that case, +// SigmaMatcher takes its ownership. +template <class M> +class SigmaMatcher : public MatcherBase<typename M::Arc> { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + SigmaMatcher(const FST &fst, + MatchType match_type, + Label sigma_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = 0) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + sigma_label_(sigma_label), + error_(false) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "SigmaMatcher: bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + if (sigma_label == 0) { + FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label"; + sigma_label_ = kNoLabel; + error_ = true; + } + + if (rewrite_mode == MATCHER_REWRITE_AUTO) + rewrite_both_ = fst.Properties(kAcceptor, true); + else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) + rewrite_both_ = true; + else + rewrite_both_ = false; + } + + SigmaMatcher(const SigmaMatcher<M> &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + sigma_label_(matcher.sigma_label_), + rewrite_both_(matcher.rewrite_both_), + error_(matcher.error_) {} + + virtual ~SigmaMatcher() { + delete matcher_; + } + + virtual SigmaMatcher<M> *Copy(bool safe = false) const { + return new SigmaMatcher<M>(*this, safe); + } + + virtual MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { + matcher_->SetState(s); + has_sigma_ = + sigma_label_ != kNoLabel ? matcher_->Find(sigma_label_) : false; + } + + bool Find(Label match_label) { + match_label_ = match_label; + if (match_label == sigma_label_ && sigma_label_ != kNoLabel) { + FSTERROR() << "SigmaMatcher::Find: bad label (sigma)"; + error_ = true; + return false; + } + if (matcher_->Find(match_label)) { + sigma_match_ = kNoLabel; + return true; + } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel && + matcher_->Find(sigma_label_)) { + sigma_match_ = match_label; + return true; + } else { + return false; + } + } + + bool Done() const { + return matcher_->Done(); + } + + const Arc& Value() const { + if (sigma_match_ == kNoLabel) { + return matcher_->Value(); + } else { + sigma_arc_ = matcher_->Value(); + if (rewrite_both_) { + if (sigma_arc_.ilabel == sigma_label_) + sigma_arc_.ilabel = sigma_match_; + if (sigma_arc_.olabel == sigma_label_) + sigma_arc_.olabel = sigma_match_; + } else if (match_type_ == MATCH_INPUT) { + sigma_arc_.ilabel = sigma_match_; + } else { + sigma_arc_.olabel = sigma_match_; + } + return sigma_arc_; + } + } + + void Next() { + matcher_->Next(); + if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) && + (match_label_ > 0)) { + matcher_->Find(sigma_label_); + sigma_match_ = match_label_; + } + } + + virtual const FST &GetFst() const { return matcher_->GetFst(); } + + virtual uint64 Properties(uint64 props) const; + + virtual uint32 Flags() const { + if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) + return matcher_->Flags(); + // kRequireMatch temporarily disabled until issues + // in //speech/gaudi/annotation/util/denorm are resolved. + // return matcher_->Flags() | kRequireMatch; + return matcher_->Flags(); + } + +private: + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + M *matcher_; + MatchType match_type_; // Type of match requested + Label sigma_label_; // Label that represents the sigma transition + bool rewrite_both_; // Rewrite both sides when both are 'sigma_label_' + bool has_sigma_; // Are there sigmas at the current state? + Label sigma_match_; // Current label that matches sigma transition + mutable Arc sigma_arc_; // Arc to return when sigma match + Label match_label_; // Label being matched + bool error_; // Error encountered + + void operator=(const SigmaMatcher<M> &); // disallow +}; + +template <class M> inline +uint64 SigmaMatcher<M>::Properties(uint64 inprops) const { + uint64 outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (rewrite_both_) { + return outprops & ~(kIDeterministic | kNonIDeterministic | + kODeterministic | kNonODeterministic | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted | + kString); + } else if (match_type_ == MATCH_INPUT) { + return outprops & ~(kIDeterministic | kNonIDeterministic | + kODeterministic | kNonODeterministic | + kILabelSorted | kNotILabelSorted | + kString | kAcceptor); + } else if (match_type_ == MATCH_OUTPUT) { + return outprops & ~(kIDeterministic | kNonIDeterministic | + kODeterministic | kNonODeterministic | + kOLabelSorted | kNotOLabelSorted | + kString | kAcceptor); + } else { + // Shouldn't ever get here. + FSTERROR() << "SigmaMatcher:: bad match type: " << match_type_; + return 0; + } +} + + +// For any requested label that doesn't match at a state, this matcher +// considers the *unique* transition that matches the label 'phi_label' +// (phi = 'fail'), and recursively looks for a match at its +// destination. When 'phi_loop' is true, if no match is found but a +// phi self-loop is found, then the phi transition found is returned +// with the phi_label rewritten as the requested label (both sides if +// an acceptor, or if 'rewrite_both' is true and both input and output +// labels of the found transition are 'phi_label'). If 'phi_label' is +// kNoLabel, this special matching is not done. PhiMatcher is +// templated itself on a matcher, which is used to perform the +// underlying matching. By default, the underlying matcher is +// constructed by PhiMatcher. The user can instead pass in this +// object; in that case, PhiMatcher takes its ownership. +// Warning: phi non-determinism not supported (for simplicity). +template <class M> +class PhiMatcher : public MatcherBase<typename M::Arc> { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + PhiMatcher(const FST &fst, + MatchType match_type, + Label phi_label = kNoLabel, + bool phi_loop = true, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = 0) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + phi_label_(phi_label), + state_(kNoStateId), + phi_loop_(phi_loop), + error_(false) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "PhiMatcher: bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + + if (rewrite_mode == MATCHER_REWRITE_AUTO) + rewrite_both_ = fst.Properties(kAcceptor, true); + else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) + rewrite_both_ = true; + else + rewrite_both_ = false; + } + + PhiMatcher(const PhiMatcher<M> &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + phi_label_(matcher.phi_label_), + rewrite_both_(matcher.rewrite_both_), + state_(kNoStateId), + phi_loop_(matcher.phi_loop_), + error_(matcher.error_) {} + + virtual ~PhiMatcher() { + delete matcher_; + } + + virtual PhiMatcher<M> *Copy(bool safe = false) const { + return new PhiMatcher<M>(*this, safe); + } + + virtual MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { + matcher_->SetState(s); + state_ = s; + has_phi_ = phi_label_ != kNoLabel; + } + + bool Find(Label match_label); + + bool Done() const { return matcher_->Done(); } + + const Arc& Value() const { + if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) { + return matcher_->Value(); + } else if (phi_match_ == 0) { // Virtual epsilon loop + phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_); + if (match_type_ == MATCH_OUTPUT) + swap(phi_arc_.ilabel, phi_arc_.olabel); + return phi_arc_; + } else { + phi_arc_ = matcher_->Value(); + phi_arc_.weight = Times(phi_weight_, phi_arc_.weight); + if (phi_match_ != kNoLabel) { // Phi loop match + if (rewrite_both_) { + if (phi_arc_.ilabel == phi_label_) + phi_arc_.ilabel = phi_match_; + if (phi_arc_.olabel == phi_label_) + phi_arc_.olabel = phi_match_; + } else if (match_type_ == MATCH_INPUT) { + phi_arc_.ilabel = phi_match_; + } else { + phi_arc_.olabel = phi_match_; + } + } + return phi_arc_; + } + } + + void Next() { matcher_->Next(); } + + virtual const FST &GetFst() const { return matcher_->GetFst(); } + + virtual uint64 Properties(uint64 props) const; + + virtual uint32 Flags() const { + if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) + return matcher_->Flags(); + return matcher_->Flags() | kRequireMatch; + } + +private: + virtual void SetState_(StateId s) { SetState(s); } + virtual bool Find_(Label label) { return Find(label); } + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + + M *matcher_; + MatchType match_type_; // Type of match requested + Label phi_label_; // Label that represents the phi transition + bool rewrite_both_; // Rewrite both sides when both are 'phi_label_' + bool has_phi_; // Are there possibly phis at the current state? + Label phi_match_; // Current label that matches phi loop + mutable Arc phi_arc_; // Arc to return + StateId state_; // State where looking for matches + Weight phi_weight_; // Product of the weights of phi transitions taken + bool phi_loop_; // When true, phi self-loop are allowed and treated + // as rho (required for Aho-Corasick) + bool error_; // Error encountered + + void operator=(const PhiMatcher<M> &); // disallow +}; + +template <class M> inline +bool PhiMatcher<M>::Find(Label match_label) { + if (match_label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) { + FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_; + error_ = true; + return false; + } + matcher_->SetState(state_); + phi_match_ = kNoLabel; + phi_weight_ = Weight::One(); + if (phi_label_ == 0) { // When 'phi_label_ == 0', + if (match_label == kNoLabel) // there are no more true epsilon arcs, + return false; + if (match_label == 0) { // but virtual eps loop need to be returned + if (!matcher_->Find(kNoLabel)) { + return matcher_->Find(0); + } else { + phi_match_ = 0; + return true; + } + } + } + if (!has_phi_ || match_label == 0 || match_label == kNoLabel) + return matcher_->Find(match_label); + StateId state = state_; + while (!matcher_->Find(match_label)) { + // Look for phi transition (if phi_label_ == 0, we need to look + // for -1 to avoid getting the virtual self-loop) + if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) + return false; + if (phi_loop_ && matcher_->Value().nextstate == state) { + phi_match_ = match_label; + return true; + } + phi_weight_ = Times(phi_weight_, matcher_->Value().weight); + state = matcher_->Value().nextstate; + matcher_->Next(); + if (!matcher_->Done()) { + FSTERROR() << "PhiMatcher: phi non-determinism not supported"; + error_ = true; + } + matcher_->SetState(state); + } + return true; +} + +template <class M> inline +uint64 PhiMatcher<M>::Properties(uint64 inprops) const { + uint64 outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (match_type_ == MATCH_INPUT) { + if (phi_label_ == 0) { + outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; + outprops |= kNoEpsilons | kNoIEpsilons; + } + if (rewrite_both_) { + return outprops & ~(kODeterministic | kNonODeterministic | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & ~(kODeterministic | kAcceptor | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } + } else if (match_type_ == MATCH_OUTPUT) { + if (phi_label_ == 0) { + outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; + outprops |= kNoEpsilons | kNoOEpsilons; + } + if (rewrite_both_) { + return outprops & ~(kIDeterministic | kNonIDeterministic | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & ~(kIDeterministic | kAcceptor | kString | + kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted); + } + } else { + // Shouldn't ever get here. + FSTERROR() << "PhiMatcher:: bad match type: " << match_type_; + return 0; + } +} + + +// +// MULTI-EPS MATCHER FLAGS +// + +// Return multi-epsilon arcs for Find(kNoLabel). +const uint32 kMultiEpsList = 0x00000001; + +// Return a kNolabel loop for Find(multi_eps). +const uint32 kMultiEpsLoop = 0x00000002; + +// MultiEpsMatcher: allows treating multiple non-0 labels as +// non-consuming labels in addition to 0 that is always +// non-consuming. Precise behavior controlled by 'flags' argument. By +// default, the underlying matcher is constructed by +// MultiEpsMatcher. The user can instead pass in this object; in that +// case, MultiEpsMatcher takes its ownership iff 'own_matcher' is +// true. +template <class M> +class MultiEpsMatcher { + public: + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + MultiEpsMatcher(const FST &fst, MatchType match_type, + uint32 flags = (kMultiEpsLoop | kMultiEpsList), + M *matcher = 0, bool own_matcher = true) + : matcher_(matcher ? matcher : new M(fst, match_type)), + flags_(flags), + own_matcher_(matcher ? own_matcher : true) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + MultiEpsMatcher(const MultiEpsMatcher<M> &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + flags_(matcher.flags_), + own_matcher_(true), + multi_eps_labels_(matcher.multi_eps_labels_), + loop_(matcher.loop_) { + loop_.nextstate = kNoStateId; + } + + ~MultiEpsMatcher() { + if (own_matcher_) + delete matcher_; + } + + MultiEpsMatcher<M> *Copy(bool safe = false) const { + return new MultiEpsMatcher<M>(*this, safe); + } + + MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { + matcher_->SetState(s); + loop_.nextstate = s; + } + + bool Find(Label match_label); + + bool Done() const { + return done_; + } + + const Arc& Value() const { + return current_loop_ ? loop_ : matcher_->Value(); + } + + void Next() { + if (!current_loop_) { + matcher_->Next(); + done_ = matcher_->Done(); + if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) { + ++multi_eps_iter_; + while ((multi_eps_iter_ != multi_eps_labels_.End()) && + !matcher_->Find(*multi_eps_iter_)) + ++multi_eps_iter_; + if (multi_eps_iter_ != multi_eps_labels_.End()) + done_ = false; + else + done_ = !matcher_->Find(kNoLabel); + + } + } else { + done_ = true; + } + } + + const FST &GetFst() const { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_->Properties(props); } + + uint32 Flags() const { return matcher_->Flags(); } + + void AddMultiEpsLabel(Label label) { + if (label == 0) { + FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; + } else { + multi_eps_labels_.Insert(label); + } + } + + void RemoveMultiEpsLabel(Label label) { + if (label == 0) { + FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; + } else { + multi_eps_labels_.Erase(label); + } + } + + void ClearMultiEpsLabels() { + multi_eps_labels_.Clear(); + } + +private: + M *matcher_; + uint32 flags_; + bool own_matcher_; // Does this class delete the matcher? + + // Multi-eps label set + CompactSet<Label, kNoLabel> multi_eps_labels_; + typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_; + + bool current_loop_; // Current arc is the implicit loop + mutable Arc loop_; // For non-consuming symbols + bool done_; // Matching done + + void operator=(const MultiEpsMatcher<M> &); // Disallow +}; + +template <class M> inline +bool MultiEpsMatcher<M>::Find(Label match_label) { + multi_eps_iter_ = multi_eps_labels_.End(); + current_loop_ = false; + bool ret; + if (match_label == 0) { + ret = matcher_->Find(0); + } else if (match_label == kNoLabel) { + if (flags_ & kMultiEpsList) { + // return all non-consuming arcs (incl. epsilon) + multi_eps_iter_ = multi_eps_labels_.Begin(); + while ((multi_eps_iter_ != multi_eps_labels_.End()) && + !matcher_->Find(*multi_eps_iter_)) + ++multi_eps_iter_; + if (multi_eps_iter_ != multi_eps_labels_.End()) + ret = true; + else + ret = matcher_->Find(kNoLabel); + } else { + // return all epsilon arcs + ret = matcher_->Find(kNoLabel); + } + } else if ((flags_ & kMultiEpsLoop) && + multi_eps_labels_.Find(match_label) != multi_eps_labels_.End()) { + // return 'implicit' loop + current_loop_ = true; + ret = true; + } else { + ret = matcher_->Find(match_label); + } + done_ = !ret; + return ret; +} + + +// Generic matcher, templated on the FST definition +// - a wrapper around pointer to specific one. +// Here is a typical use: \code +// Matcher<StdFst> matcher(fst, MATCH_INPUT); +// matcher.SetState(state); +// if (matcher.Find(label)) +// for (; !matcher.Done(); matcher.Next()) { +// StdArc &arc = matcher.Value(); +// ... +// } \endcode +template <class F> +class Matcher { + public: + typedef F FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + Matcher(const F &fst, MatchType match_type) { + base_ = fst.InitMatcher(match_type); + if (!base_) + base_ = new SortedMatcher<F>(fst, match_type); + } + + Matcher(const Matcher<F> &matcher, bool safe = false) { + base_ = matcher.base_->Copy(safe); + } + + // Takes ownership of the provided matcher + Matcher(MatcherBase<Arc>* base_matcher) { base_ = base_matcher; } + + ~Matcher() { delete base_; } + + Matcher<F> *Copy(bool safe = false) const { + return new Matcher<F>(*this, safe); + } + + MatchType Type(bool test) const { return base_->Type(test); } + void SetState(StateId s) { base_->SetState(s); } + bool Find(Label label) { return base_->Find(label); } + bool Done() const { return base_->Done(); } + const Arc& Value() const { return base_->Value(); } + void Next() { base_->Next(); } + const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); } + uint64 Properties(uint64 props) const { return base_->Properties(props); } + uint32 Flags() const { return base_->Flags() & kMatcherFlags; } + + private: + MatcherBase<Arc> *base_; + + void operator=(const Matcher<Arc> &); // disallow +}; + +} // namespace fst + + + +#endif // FST_LIB_MATCHER_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/minimize.h b/kaldi_io/src/tools/openfst/include/fst/minimize.h new file mode 100644 index 0000000..6e9dd3d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/minimize.h @@ -0,0 +1,591 @@ +// minimize.h +// minimize.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Johan Schalkwyk) +// +// \file Functions and classes to minimize a finite state acceptor +// + +#ifndef FST_LIB_MINIMIZE_H__ +#define FST_LIB_MINIMIZE_H__ + +#include <cmath> + +#include <algorithm> +#include <map> +#include <queue> +#include <vector> +using std::vector; + +#include <fst/arcsort.h> +#include <fst/connect.h> +#include <fst/dfs-visit.h> +#include <fst/encode.h> +#include <fst/factor-weight.h> +#include <fst/fst.h> +#include <fst/mutable-fst.h> +#include <fst/partition.h> +#include <fst/push.h> +#include <fst/queue.h> +#include <fst/reverse.h> +#include <fst/state-map.h> + + +namespace fst { + +// comparator for creating partition based on sorting on +// - states +// - final weight +// - out degree, +// - (input label, output label, weight, destination_block) +template <class A> +class StateComparator { + public: + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + static const uint32 kCompareFinal = 0x00000001; + static const uint32 kCompareOutDegree = 0x00000002; + static const uint32 kCompareArcs = 0x00000004; + static const uint32 kCompareAll = 0x00000007; + + StateComparator(const Fst<A>& fst, + const Partition<typename A::StateId>& partition, + uint32 flags = kCompareAll) + : fst_(fst), partition_(partition), flags_(flags) {} + + // compare state x with state y based on sort criteria + bool operator()(const StateId x, const StateId y) const { + // check for final state equivalence + if (flags_ & kCompareFinal) { + const size_t xfinal = fst_.Final(x).Hash(); + const size_t yfinal = fst_.Final(y).Hash(); + if (xfinal < yfinal) return true; + else if (xfinal > yfinal) return false; + } + + if (flags_ & kCompareOutDegree) { + // check for # arcs + if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true; + if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false; + + if (flags_ & kCompareArcs) { + // # arcs are equal, check for arc match + for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y); + !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) { + const A& arc1 = aiter1.Value(); + const A& arc2 = aiter2.Value(); + if (arc1.ilabel < arc2.ilabel) return true; + if (arc1.ilabel > arc2.ilabel) return false; + + if (partition_.class_id(arc1.nextstate) < + partition_.class_id(arc2.nextstate)) return true; + if (partition_.class_id(arc1.nextstate) > + partition_.class_id(arc2.nextstate)) return false; + } + } + } + + return false; + } + + private: + const Fst<A>& fst_; + const Partition<typename A::StateId>& partition_; + const uint32 flags_; +}; + +template <class A> const uint32 StateComparator<A>::kCompareFinal; +template <class A> const uint32 StateComparator<A>::kCompareOutDegree; +template <class A> const uint32 StateComparator<A>::kCompareArcs; +template <class A> const uint32 StateComparator<A>::kCompareAll; + + +// Computes equivalence classes for cyclic Fsts. For cyclic minimization +// we use the classic HopCroft minimization algorithm, which is of +// +// O(E)log(N), +// +// where E is the number of edges in the machine and N is number of states. +// +// The following paper describes the original algorithm +// An N Log N algorithm for minimizing states in a finite automaton +// by John HopCroft, January 1971 +// +template <class A, class Queue> +class CyclicMinimizer { + public: + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::StateId ClassId; + typedef typename A::Weight Weight; + typedef ReverseArc<A> RevA; + + CyclicMinimizer(const ExpandedFst<A>& fst): + // tell the Partition data-member to expect multiple repeated + // calls to SplitOn with the same element if we are non-deterministic. + P_(fst.Properties(kIDeterministic, true) == 0) { + if(fst.Properties(kIDeterministic, true) == 0) + CHECK(Weight::Properties() & kIdempotent); // this minimization + // algorithm for non-deterministic FSTs can only work with idempotent + // semirings. + Initialize(fst); + Compute(fst); + } + + ~CyclicMinimizer() { + delete aiter_queue_; + } + + const Partition<StateId>& partition() const { + return P_; + } + + // helper classes + private: + typedef ArcIterator<Fst<RevA> > ArcIter; + class ArcIterCompare { + public: + ArcIterCompare(const Partition<StateId>& partition) + : partition_(partition) {} + + ArcIterCompare(const ArcIterCompare& comp) + : partition_(comp.partition_) {} + + // compare two iterators based on there input labels, and proto state + // (partition class Ids) + bool operator()(const ArcIter* x, const ArcIter* y) const { + const RevA& xarc = x->Value(); + const RevA& yarc = y->Value(); + return (xarc.ilabel > yarc.ilabel); + } + + private: + const Partition<StateId>& partition_; + }; + + typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare> + ArcIterQueue; + + // helper methods + private: + // prepartitions the space into equivalence classes with + // same final weight + // same # arcs per state + // same outgoing arcs + void PrePartition(const Fst<A>& fst) { + VLOG(5) << "PrePartition"; + + typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap; + StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal); + EquivalenceMap equiv_map(comp); + + StateIterator<Fst<A> > siter(fst); + StateId class_id = P_.AddClass(); + P_.Add(siter.Value(), class_id); + equiv_map[siter.Value()] = class_id; + L_.Enqueue(class_id); + for (siter.Next(); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + typename EquivalenceMap::const_iterator it = equiv_map.find(s); + if (it == equiv_map.end()) { + class_id = P_.AddClass(); + P_.Add(s, class_id); + equiv_map[s] = class_id; + L_.Enqueue(class_id); + } else { + P_.Add(s, it->second); + equiv_map[s] = it->second; + } + } + + VLOG(5) << "Initial Partition: " << P_.num_classes(); + } + + // - Create inverse transition Tr_ = rev(fst) + // - loop over states in fst and split on final, creating two blocks + // in the partition corresponding to final, non-final + void Initialize(const Fst<A>& fst) { + // construct Tr + Reverse(fst, &Tr_); + ILabelCompare<RevA> ilabel_comp; + ArcSort(&Tr_, ilabel_comp); + + // initial split (F, S - F) + P_.Initialize(Tr_.NumStates() - 1); + + // prep partition + PrePartition(fst); + + // allocate arc iterator queue + ArcIterCompare comp(P_); + aiter_queue_ = new ArcIterQueue(comp); + } + + // partition all classes with destination C + void Split(ClassId C) { + // Prep priority queue. Open arc iterator for each state in C, and + // insert into priority queue. + for (PartitionIterator<StateId> siter(P_, C); + !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + if (Tr_.NumArcs(s + 1)) + aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1)); + } + + // Now pop arc iterator from queue, split entering equivalence class + // re-insert updated iterator into queue. + Label prev_label = -1; + while (!aiter_queue_->empty()) { + ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top(); + aiter_queue_->pop(); + if (aiter->Done()) { + delete aiter; + continue; + } + + const RevA& arc = aiter->Value(); + StateId from_state = aiter->Value().nextstate - 1; + Label from_label = arc.ilabel; + if (prev_label != from_label) + P_.FinalizeSplit(&L_); + + StateId from_class = P_.class_id(from_state); + if (P_.class_size(from_class) > 1) + P_.SplitOn(from_state); + + prev_label = from_label; + aiter->Next(); + if (aiter->Done()) + delete aiter; + else + aiter_queue_->push(aiter); + } + P_.FinalizeSplit(&L_); + } + + // Main loop for hopcroft minimization. + void Compute(const Fst<A>& fst) { + // process active classes (FIFO, or FILO) + while (!L_.Empty()) { + ClassId C = L_.Head(); + L_.Dequeue(); + + // split on C, all labels in C + Split(C); + } + } + + // helper data + private: + // Partioning of states into equivalence classes + Partition<StateId> P_; + + // L = set of active classes to be processed in partition P + Queue L_; + + // reverse transition function + VectorFst<RevA> Tr_; + + // Priority queue of open arc iterators for all states in the 'splitter' + // equivalence class + ArcIterQueue* aiter_queue_; +}; + + +// Computes equivalence classes for acyclic Fsts. The implementation details +// for this algorithms is documented by the following paper. +// +// Minimization of acyclic deterministic automata in linear time +// Dominque Revuz +// +// Complexity O(|E|) +// +template <class A> +class AcyclicMinimizer { + public: + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::StateId ClassId; + typedef typename A::Weight Weight; + + AcyclicMinimizer(const ExpandedFst<A>& fst): + // tell the Partition data-member to expect multiple repeated + // calls to SplitOn with the same element if we are non-deterministic. + partition_(fst.Properties(kIDeterministic, true) == 0) { + if(fst.Properties(kIDeterministic, true) == 0) + CHECK(Weight::Properties() & kIdempotent); // minimization for + // non-deterministic FSTs can only work with idempotent semirings. + Initialize(fst); + Refine(fst); + } + + const Partition<StateId>& partition() { + return partition_; + } + + // helper classes + private: + // DFS visitor to compute the height (distance) to final state. + class HeightVisitor { + public: + HeightVisitor() : max_height_(0), num_states_(0) { } + + // invoked before dfs visit + void InitVisit(const Fst<A>& fst) {} + + // invoked when state is discovered (2nd arg is DFS tree root) + bool InitState(StateId s, StateId root) { + // extend height array and initialize height (distance) to 0 + for (size_t i = height_.size(); i <= s; ++i) + height_.push_back(-1); + + if (s >= num_states_) num_states_ = s + 1; + return true; + } + + // invoked when tree arc examined (to undiscoverted state) + bool TreeArc(StateId s, const A& arc) { + return true; + } + + // invoked when back arc examined (to unfinished state) + bool BackArc(StateId s, const A& arc) { + return true; + } + + // invoked when forward or cross arc examined (to finished state) + bool ForwardOrCrossArc(StateId s, const A& arc) { + if (height_[arc.nextstate] + 1 > height_[s]) + height_[s] = height_[arc.nextstate] + 1; + return true; + } + + // invoked when state finished (parent is kNoStateId for tree root) + void FinishState(StateId s, StateId parent, const A* parent_arc) { + if (height_[s] == -1) height_[s] = 0; + StateId h = height_[s] + 1; + if (parent >= 0) { + if (h > height_[parent]) height_[parent] = h; + if (h > max_height_) max_height_ = h; + } + } + + // invoked after DFS visit + void FinishVisit() {} + + size_t max_height() const { return max_height_; } + + const vector<StateId>& height() const { return height_; } + + const size_t num_states() const { return num_states_; } + + private: + vector<StateId> height_; + size_t max_height_; + size_t num_states_; + }; + + // helper methods + private: + // cluster states according to height (distance to final state) + void Initialize(const Fst<A>& fst) { + // compute height (distance to final state) + HeightVisitor hvisitor; + DfsVisit(fst, &hvisitor); + + // create initial partition based on height + partition_.Initialize(hvisitor.num_states()); + partition_.AllocateClasses(hvisitor.max_height() + 1); + const vector<StateId>& hstates = hvisitor.height(); + for (size_t s = 0; s < hstates.size(); ++s) + partition_.Add(s, hstates[s]); + } + + // refine states based on arc sort (out degree, arc equivalence) + void Refine(const Fst<A>& fst) { + typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap; + StateComparator<A> comp(fst, partition_); + + // start with tail (height = 0) + size_t height = partition_.num_classes(); + for (size_t h = 0; h < height; ++h) { + EquivalenceMap equiv_classes(comp); + + // sort states within equivalence class + PartitionIterator<StateId> siter(partition_, h); + equiv_classes[siter.Value()] = h; + for (siter.Next(); !siter.Done(); siter.Next()) { + const StateId s = siter.Value(); + typename EquivalenceMap::const_iterator it = equiv_classes.find(s); + if (it == equiv_classes.end()) + equiv_classes[s] = partition_.AddClass(); + else + equiv_classes[s] = it->second; + } + + // create refined partition + for (siter.Reset(); !siter.Done();) { + const StateId s = siter.Value(); + const StateId old_class = partition_.class_id(s); + const StateId new_class = equiv_classes[s]; + + // a move operation can invalidate the iterator, so + // we first update the iterator to the next element + // before we move the current element out of the list + siter.Next(); + if (old_class != new_class) + partition_.Move(s, new_class); + } + } + } + + private: + Partition<StateId> partition_; +}; + + +// Given a partition and a mutable fst, merge states of Fst inplace +// (i.e. destructively). Merging works by taking the first state in +// a class of the partition to be the representative state for the class. +// Each arc is then reconnected to this state. All states in the class +// are merged by adding there arcs to the representative state. +template <class A> +void MergeStates( + const Partition<typename A::StateId>& partition, MutableFst<A>* fst) { + typedef typename A::StateId StateId; + + vector<StateId> state_map(partition.num_classes()); + for (size_t i = 0; i < partition.num_classes(); ++i) { + PartitionIterator<StateId> siter(partition, i); + state_map[i] = siter.Value(); // first state in partition; + } + + // relabel destination states + for (size_t c = 0; c < partition.num_classes(); ++c) { + for (PartitionIterator<StateId> siter(partition, c); + !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + for (MutableArcIterator<MutableFst<A> > aiter(fst, s); + !aiter.Done(); aiter.Next()) { + A arc = aiter.Value(); + arc.nextstate = state_map[partition.class_id(arc.nextstate)]; + + if (s == state_map[c]) // first state just set destination + aiter.SetValue(arc); + else + fst->AddArc(state_map[c], arc); + } + } + } + fst->SetStart(state_map[partition.class_id(fst->Start())]); + + Connect(fst); +} + +template <class A> +void AcceptorMinimize(MutableFst<A>* fst) { + typedef typename A::StateId StateId; + if (!(fst->Properties(kAcceptor | kUnweighted, true))) { + FSTERROR() << "FST is not an unweighted acceptor"; + fst->SetProperties(kError, kError); + return; + } + + // connect fst before minimization, handles disconnected states + Connect(fst); + if (fst->NumStates() == 0) return; + + if (fst->Properties(kAcyclic, true)) { + // Acyclic minimization (revuz) + VLOG(2) << "Acyclic Minimization"; + ArcSort(fst, ILabelCompare<A>()); + AcyclicMinimizer<A> minimizer(*fst); + MergeStates(minimizer.partition(), fst); + + } else { + // Cyclic minimizaton (hopcroft) + VLOG(2) << "Cyclic Minimization"; + CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst); + MergeStates(minimizer.partition(), fst); + } + + // Merge in appropriate semiring + ArcUniqueMapper<A> mapper(*fst); + StateMap(fst, mapper); +} + + +// In place minimization of deterministic weighted automata and transducers. +// For transducers, then the 'sfst' argument is not null, the algorithm +// produces a compact factorization of the minimal transducer. +// +// In the acyclic case, we use an algorithm from Dominique Revuz that +// is linear in the number of arcs (edges) in the machine. +// Complexity = O(E) +// +// In the cyclic case, we use the classical hopcroft minimization. +// Complexity = O(|E|log(|N|) +// +template <class A> +void Minimize(MutableFst<A>* fst, + MutableFst<A>* sfst = 0, + float delta = kDelta) { + uint64 props = fst->Properties(kAcceptor | kWeighted | kUnweighted, true); + + if (!(props & kAcceptor)) { // weighted transducer + VectorFst< GallicArc<A, STRING_LEFT> > gfst; + ArcMap(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>()); + fst->DeleteStates(); + gfst.SetProperties(kAcceptor, kAcceptor); + Push(&gfst, REWEIGHT_TO_INITIAL, delta); + ArcMap(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >(delta)); + EncodeMapper< GallicArc<A, STRING_LEFT> > + encoder(kEncodeLabels | kEncodeWeights, ENCODE); + Encode(&gfst, &encoder); + AcceptorMinimize(&gfst); + Decode(&gfst, encoder); + + if (sfst == 0) { + FactorWeightFst< GallicArc<A, STRING_LEFT>, + GallicFactor<typename A::Label, + typename A::Weight, STRING_LEFT> > fwfst(gfst); + SymbolTable *osyms = fst->OutputSymbols() ? + fst->OutputSymbols()->Copy() : 0; + ArcMap(fwfst, fst, FromGallicMapper<A, STRING_LEFT>()); + fst->SetOutputSymbols(osyms); + delete osyms; + } else { + sfst->SetOutputSymbols(fst->OutputSymbols()); + GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst); + ArcMap(gfst, fst, &mapper); + fst->SetOutputSymbols(sfst->InputSymbols()); + } + } else if (props & kWeighted) { // weighted acceptor + Push(fst, REWEIGHT_TO_INITIAL, delta); + ArcMap(fst, QuantizeMapper<A>(delta)); + EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE); + Encode(fst, &encoder); + AcceptorMinimize(fst); + Decode(fst, encoder); + } else { // unweighted acceptor + AcceptorMinimize(fst); + } +} + +} // namespace fst + +#endif // FST_LIB_MINIMIZE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/mutable-fst.h b/kaldi_io/src/tools/openfst/include/fst/mutable-fst.h new file mode 100644 index 0000000..09eb237 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/mutable-fst.h @@ -0,0 +1,378 @@ +// mutable-fst.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Expanded FST augmented with mutators - interface class definition +// and mutable arc iterator interface. +// + +#ifndef FST_LIB_MUTABLE_FST_H__ +#define FST_LIB_MUTABLE_FST_H__ + +#include <stddef.h> +#include <sys/types.h> +#include <string> +#include <vector> +using std::vector; + +#include <fst/expanded-fst.h> + + +namespace fst { + +template <class A> class MutableArcIteratorData; + +// An expanded FST plus mutators (use MutableArcIterator to modify arcs). +template <class A> +class MutableFst : public ExpandedFst<A> { + public: + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + virtual MutableFst<A> &operator=(const Fst<A> &fst) = 0; + + MutableFst<A> &operator=(const MutableFst<A> &fst) { + return operator=(static_cast<const Fst<A> &>(fst)); + } + + virtual void SetStart(StateId) = 0; // Set the initial state + virtual void SetFinal(StateId, Weight) = 0; // Set a state's final weight + virtual void SetProperties(uint64 props, + uint64 mask) = 0; // Set property bits wrt mask + + virtual StateId AddState() = 0; // Add a state, return its ID + virtual void AddArc(StateId, const A &arc) = 0; // Add an arc to state + + virtual void DeleteStates(const vector<StateId>&) = 0; // Delete some states + virtual void DeleteStates() = 0; // Delete all states + virtual void DeleteArcs(StateId, size_t n) = 0; // Delete some arcs at state + virtual void DeleteArcs(StateId) = 0; // Delete all arcs at state + + virtual void ReserveStates(StateId n) { } // Optional, best effort only. + virtual void ReserveArcs(StateId s, size_t n) { } // Optional, Best effort. + + // Return input label symbol table; return NULL if not specified + virtual const SymbolTable* InputSymbols() const = 0; + // Return output label symbol table; return NULL if not specified + virtual const SymbolTable* OutputSymbols() const = 0; + + // Return input label symbol table; return NULL if not specified + virtual SymbolTable* MutableInputSymbols() = 0; + // Return output label symbol table; return NULL if not specified + virtual SymbolTable* MutableOutputSymbols() = 0; + + // Set input label symbol table; NULL signifies not unspecified + virtual void SetInputSymbols(const SymbolTable* isyms) = 0; + // Set output label symbol table; NULL signifies not unspecified + virtual void SetOutputSymbols(const SymbolTable* osyms) = 0; + + // Get a copy of this MutableFst. See Fst<>::Copy() for further doc. + virtual MutableFst<A> *Copy(bool safe = false) const = 0; + + // Read an MutableFst from an input stream; return NULL on error. + static MutableFst<A> *Read(istream &strm, const FstReadOptions &opts) { + FstReadOptions ropts(opts); + FstHeader hdr; + if (ropts.header) + hdr = *opts.header; + else { + if (!hdr.Read(strm, opts.source)) + return 0; + ropts.header = &hdr; + } + if (!(hdr.Properties() & kMutable)) { + LOG(ERROR) << "MutableFst::Read: Not an MutableFst: " << ropts.source; + return 0; + } + FstRegister<A> *registr = FstRegister<A>::GetRegister(); + const typename FstRegister<A>::Reader reader = + registr->GetReader(hdr.FstType()); + if (!reader) { + LOG(ERROR) << "MutableFst::Read: Unknown FST type \"" << hdr.FstType() + << "\" (arc type = \"" << A::Type() + << "\"): " << ropts.source; + return 0; + } + Fst<A> *fst = reader(strm, ropts); + if (!fst) return 0; + return static_cast<MutableFst<A> *>(fst); + } + + // Read a MutableFst from a file; return NULL on error. + // Empty filename reads from standard input. If 'convert' is true, + // convert to a mutable FST of type 'convert_type' if file is + // a non-mutable FST. + static MutableFst<A> *Read(const string &filename, bool convert = false, + const string &convert_type = "vector") { + if (convert == false) { + if (!filename.empty()) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "MutableFst::Read: Can't open file: " << filename; + return 0; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(cin, FstReadOptions("standard input")); + } + } else { // Converts to 'convert_type' if not mutable. + Fst<A> *ifst = Fst<A>::Read(filename); + if (!ifst) return 0; + if (ifst->Properties(kMutable, false)) { + return static_cast<MutableFst *>(ifst); + } else { + Fst<A> *ofst = Convert(*ifst, convert_type); + delete ifst; + if (!ofst) return 0; + if (!ofst->Properties(kMutable, false)) + LOG(ERROR) << "MutableFst: bad convert type: " << convert_type; + return static_cast<MutableFst *>(ofst); + } + } + } + + // For generic mutuble arc iterator construction; not normally called + // directly by users. + virtual void InitMutableArcIterator(StateId s, + MutableArcIteratorData<A> *) = 0; +}; + +// Mutable arc iterator interface, templated on the Arc definition; used +// for mutable Arc iterator specializations that are returned by +// the InitMutableArcIterator MutableFst method. +template <class A> +class MutableArcIteratorBase : public ArcIteratorBase<A> { + public: + typedef A Arc; + + void SetValue(const A &arc) { SetValue_(arc); } // Set current arc's content + + private: + virtual void SetValue_(const A &arc) = 0; +}; + +template <class A> +struct MutableArcIteratorData { + MutableArcIteratorBase<A> *base; // Specific iterator +}; + +// Generic mutable arc iterator, templated on the FST definition +// - a wrapper around pointer to specific one. +// Here is a typical use: \code +// for (MutableArcIterator<StdFst> aiter(&fst, s)); +// !aiter.Done(); +// aiter.Next()) { +// StdArc arc = aiter.Value(); +// arc.ilabel = 7; +// aiter.SetValue(arc); +// ... +// } \endcode +// This version requires function calls. +template <class F> +class MutableArcIterator { + public: + typedef F FST; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + + MutableArcIterator(F *fst, StateId s) { + fst->InitMutableArcIterator(s, &data_); + } + ~MutableArcIterator() { delete data_.base; } + + bool Done() const { return data_.base->Done(); } + const Arc& Value() const { return data_.base->Value(); } + void Next() { data_.base->Next(); } + size_t Position() const { return data_.base->Position(); } + void Reset() { data_.base->Reset(); } + void Seek(size_t a) { data_.base->Seek(a); } + void SetValue(const Arc &a) { data_.base->SetValue(a); } + uint32 Flags() const { return data_.base->Flags(); } + void SetFlags(uint32 f, uint32 m) { + return data_.base->SetFlags(f, m); + } + + private: + MutableArcIteratorData<Arc> data_; + DISALLOW_COPY_AND_ASSIGN(MutableArcIterator); +}; + + +namespace internal { + +// MutableFst<A> case - abstract methods. +template <class A> inline +typename A::Weight Final(const MutableFst<A> &fst, typename A::StateId s) { + return fst.Final(s); +} + +template <class A> inline +ssize_t NumArcs(const MutableFst<A> &fst, typename A::StateId s) { + return fst.NumArcs(s); +} + +template <class A> inline +ssize_t NumInputEpsilons(const MutableFst<A> &fst, typename A::StateId s) { + return fst.NumInputEpsilons(s); +} + +template <class A> inline +ssize_t NumOutputEpsilons(const MutableFst<A> &fst, typename A::StateId s) { + return fst.NumOutputEpsilons(s); +} + +} // namespace internal + + +// A useful alias when using StdArc. +typedef MutableFst<StdArc> StdMutableFst; + + +// This is a helper class template useful for attaching a MutableFst +// interface to its implementation, handling reference counting and +// copy-on-write. +template <class I, class F = MutableFst<typename I::Arc> > +class ImplToMutableFst : public ImplToExpandedFst<I, F> { + public: + typedef typename I::Arc Arc; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + using ImplToFst<I, F>::GetImpl; + using ImplToFst<I, F>::SetImpl; + + virtual void SetStart(StateId s) { + MutateCheck(); + GetImpl()->SetStart(s); + } + + virtual void SetFinal(StateId s, Weight w) { + MutateCheck(); + GetImpl()->SetFinal(s, w); + } + + virtual void SetProperties(uint64 props, uint64 mask) { + // Can skip mutate check if extrinsic properties don't change, + // since it is then safe to update all (shallow) copies + uint64 exprops = kExtrinsicProperties & mask; + if (GetImpl()->Properties(exprops) != (props & exprops)) + MutateCheck(); + GetImpl()->SetProperties(props, mask); + } + + virtual StateId AddState() { + MutateCheck(); + return GetImpl()->AddState(); + } + + virtual void AddArc(StateId s, const Arc &arc) { + MutateCheck(); + GetImpl()->AddArc(s, arc); + } + + virtual void DeleteStates(const vector<StateId> &dstates) { + MutateCheck(); + GetImpl()->DeleteStates(dstates); + } + + virtual void DeleteStates() { + MutateCheck(); + GetImpl()->DeleteStates(); + } + + virtual void DeleteArcs(StateId s, size_t n) { + MutateCheck(); + GetImpl()->DeleteArcs(s, n); + } + + virtual void DeleteArcs(StateId s) { + MutateCheck(); + GetImpl()->DeleteArcs(s); + } + + virtual void ReserveStates(StateId s) { + MutateCheck(); + GetImpl()->ReserveStates(s); + } + + virtual void ReserveArcs(StateId s, size_t n) { + MutateCheck(); + GetImpl()->ReserveArcs(s, n); + } + + virtual const SymbolTable* InputSymbols() const { + return GetImpl()->InputSymbols(); + } + + virtual const SymbolTable* OutputSymbols() const { + return GetImpl()->OutputSymbols(); + } + + virtual SymbolTable* MutableInputSymbols() { + MutateCheck(); + return GetImpl()->InputSymbols(); + } + + virtual SymbolTable* MutableOutputSymbols() { + MutateCheck(); + return GetImpl()->OutputSymbols(); + } + + virtual void SetInputSymbols(const SymbolTable* isyms) { + MutateCheck(); + GetImpl()->SetInputSymbols(isyms); + } + + virtual void SetOutputSymbols(const SymbolTable* osyms) { + MutateCheck(); + GetImpl()->SetOutputSymbols(osyms); + } + + protected: + ImplToMutableFst() : ImplToExpandedFst<I, F>() {} + + ImplToMutableFst(I *impl) : ImplToExpandedFst<I, F>(impl) {} + + + ImplToMutableFst(const ImplToMutableFst<I, F> &fst) + : ImplToExpandedFst<I, F>(fst) {} + + ImplToMutableFst(const ImplToMutableFst<I, F> &fst, bool safe) + : ImplToExpandedFst<I, F>(fst, safe) {} + + void MutateCheck() { + // Copy on write + if (GetImpl()->RefCount() > 1) + SetImpl(new I(*this)); + } + + private: + // Disallow + ImplToMutableFst<I, F> &operator=(const ImplToMutableFst<I, F> &fst); + + ImplToMutableFst<I, F> &operator=(const Fst<Arc> &fst) { + FSTERROR() << "ImplToMutableFst: Assignment operator disallowed"; + GetImpl()->SetProperties(kError, kError); + return *this; + } +}; + + +} // namespace fst + +#endif // FST_LIB_MUTABLE_FST_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/pair-weight.h b/kaldi_io/src/tools/openfst/include/fst/pair-weight.h new file mode 100644 index 0000000..7d8aa11 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/pair-weight.h @@ -0,0 +1,280 @@ +// pair-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Masha Maria Shugrina) +// +// \file +// Pair weight templated base class for weight classes that +// contain two weights (e.g. Product, Lexicographic) + +#ifndef FST_LIB_PAIR_WEIGHT_H_ +#define FST_LIB_PAIR_WEIGHT_H_ + +#include <climits> +#include <stack> +#include <string> + +#include <fst/weight.h> + + +DECLARE_string(fst_weight_parentheses); +DECLARE_string(fst_weight_separator); + +namespace fst { + +template<class W1, class W2> class PairWeight; +template <class W1, class W2> +istream &operator>>(istream &strm, PairWeight<W1, W2> &w); + +template<class W1, class W2> +class PairWeight { + public: + friend istream &operator>><W1, W2>(istream&, PairWeight<W1, W2>&); + + typedef PairWeight<typename W1::ReverseWeight, + typename W2::ReverseWeight> + ReverseWeight; + + PairWeight() {} + + PairWeight(const PairWeight& w) : value1_(w.value1_), value2_(w.value2_) {} + + PairWeight(W1 w1, W2 w2) : value1_(w1), value2_(w2) {} + + static const PairWeight<W1, W2> &Zero() { + static const PairWeight<W1, W2> zero(W1::Zero(), W2::Zero()); + return zero; + } + + static const PairWeight<W1, W2> &One() { + static const PairWeight<W1, W2> one(W1::One(), W2::One()); + return one; + } + + static const PairWeight<W1, W2> &NoWeight() { + static const PairWeight<W1, W2> no_weight(W1::NoWeight(), W2::NoWeight()); + return no_weight; + } + + istream &Read(istream &strm) { + value1_.Read(strm); + return value2_.Read(strm); + } + + ostream &Write(ostream &strm) const { + value1_.Write(strm); + return value2_.Write(strm); + } + + PairWeight<W1, W2> &operator=(const PairWeight<W1, W2> &w) { + value1_ = w.Value1(); + value2_ = w.Value2(); + return *this; + } + + bool Member() const { return value1_.Member() && value2_.Member(); } + + size_t Hash() const { + size_t h1 = value1_.Hash(); + size_t h2 = value2_.Hash(); + const int lshift = 5; + const int rshift = CHAR_BIT * sizeof(size_t) - 5; + return h1 << lshift ^ h1 >> rshift ^ h2; + } + + PairWeight<W1, W2> Quantize(float delta = kDelta) const { + return PairWeight<W1, W2>(value1_.Quantize(delta), + value2_.Quantize(delta)); + } + + ReverseWeight Reverse() const { + return ReverseWeight(value1_.Reverse(), value2_.Reverse()); + } + + const W1& Value1() const { return value1_; } + + const W2& Value2() const { return value2_; } + + protected: + void SetValue1(const W1 &w) { value1_ = w; } + void SetValue2(const W2 &w) { value2_ = w; } + + // Reads PairWeight when there are not parentheses around pair terms + inline static istream &ReadNoParen( + istream &strm, PairWeight<W1, W2>& w, char separator) { + int c; + do { + c = strm.get(); + } while (isspace(c)); + + string s1; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s1 += c; + c = strm.get(); + } + istringstream strm1(s1); + W1 w1 = W1::Zero(); + strm1 >> w1; + + // read second element + W2 w2 = W2::Zero(); + strm >> w2; + + w = PairWeight<W1, W2>(w1, w2); + return strm; + } + + // Reads PairWeight when there are parentheses around pair terms + inline static istream &ReadWithParen( + istream &strm, PairWeight<W1, W2>& w, + char separator, char open_paren, char close_paren) { + int c; + do { + c = strm.get(); + } while (isspace(c)); + if (c != open_paren) { + FSTERROR() << " is fst_weight_parentheses flag set correcty? "; + strm.clear(std::ios::failbit); + return strm; + } + c = strm.get(); + + // read first element + stack<int> parens; + string s1; + while (c != separator || !parens.empty()) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s1 += c; + // if parens encountered before separator, they must be matched + if (c == open_paren) { + parens.push(1); + } else if (c == close_paren) { + // Fail for mismatched parens + if (parens.empty()) { + strm.clear(std::ios::failbit); + return strm; + } + parens.pop(); + } + c = strm.get(); + } + istringstream strm1(s1); + W1 w1 = W1::Zero(); + strm1 >> w1; + + // read second element + string s2; + c = strm.get(); + while (c != EOF) { + s2 += c; + c = strm.get(); + } + if (s2.empty() || (s2[s2.size() - 1] != close_paren)) { + FSTERROR() << " is fst_weight_parentheses flag set correcty? "; + strm.clear(std::ios::failbit); + return strm; + } + + s2.erase(s2.size() - 1, 1); + istringstream strm2(s2); + W2 w2 = W2::Zero(); + strm2 >> w2; + + w = PairWeight<W1, W2>(w1, w2); + return strm; + } + + private: + W1 value1_; + W2 value2_; + +}; + +template <class W1, class W2> +inline bool operator==(const PairWeight<W1, W2> &w, + const PairWeight<W1, W2> &v) { + return w.Value1() == v.Value1() && w.Value2() == v.Value2(); +} + +template <class W1, class W2> +inline bool operator!=(const PairWeight<W1, W2> &w1, + const PairWeight<W1, W2> &w2) { + return w1.Value1() != w2.Value1() || w1.Value2() != w2.Value2(); +} + + +template <class W1, class W2> +inline bool ApproxEqual(const PairWeight<W1, W2> &w1, + const PairWeight<W1, W2> &w2, + float delta = kDelta) { + return ApproxEqual(w1.Value1(), w2.Value1(), delta) && + ApproxEqual(w1.Value2(), w2.Value2(), delta); +} + +template <class W1, class W2> +inline ostream &operator<<(ostream &strm, const PairWeight<W1, W2> &w) { + if(FLAGS_fst_weight_separator.size() != 1) { + FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1"; + strm.clear(std::ios::badbit); + return strm; + } + char separator = FLAGS_fst_weight_separator[0]; + if (FLAGS_fst_weight_parentheses.empty()) + return strm << w.Value1() << separator << w.Value2(); + + if (FLAGS_fst_weight_parentheses.size() != 2) { + FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2"; + strm.clear(std::ios::badbit); + return strm; + } + char open_paren = FLAGS_fst_weight_parentheses[0]; + char close_paren = FLAGS_fst_weight_parentheses[1]; + return strm << open_paren << w.Value1() << separator + << w.Value2() << close_paren ; +} + +template <class W1, class W2> +inline istream &operator>>(istream &strm, PairWeight<W1, W2> &w) { + if(FLAGS_fst_weight_separator.size() != 1) { + FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1"; + strm.clear(std::ios::badbit); + return strm; + } + char separator = FLAGS_fst_weight_separator[0]; + bool read_parens = !FLAGS_fst_weight_parentheses.empty(); + if (read_parens) { + if (FLAGS_fst_weight_parentheses.size() != 2) { + FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2"; + strm.clear(std::ios::badbit); + return strm; + } + return PairWeight<W1, W2>::ReadWithParen( + strm, w, separator, FLAGS_fst_weight_parentheses[0], + FLAGS_fst_weight_parentheses[1]); + } else { + return PairWeight<W1, W2>::ReadNoParen(strm, w, separator); + } +} + +} // namespace fst + +#endif // FST_LIB_PAIR_WEIGHT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/partition.h b/kaldi_io/src/tools/openfst/include/fst/partition.h new file mode 100644 index 0000000..40b849a --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/partition.h @@ -0,0 +1,305 @@ +// partition.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Johan Schalkwyk) +// +// \file Functions and classes to create a partition of states +// + +#ifndef FST_LIB_PARTITION_H__ +#define FST_LIB_PARTITION_H__ + +#include <vector> +using std::vector; +#include <algorithm> + + +#include <fst/queue.h> + + + +namespace fst { + +template <typename T> class PartitionIterator; + +// \class Partition +// \brief Defines a partitioning of states. Typically used to represent +// equivalence classes for Fst operations like minimization. +// +template <typename T> +class Partition { + friend class PartitionIterator<T>; + + struct Element { + Element() : value(0), next(0), prev(0) {} + Element(T v) : value(v), next(0), prev(0) {} + + T value; + Element* next; + Element* prev; + }; + + public: + Partition(bool allow_repeated_split): + allow_repeated_split_(allow_repeated_split) {} + + Partition(bool allow_repeated_split, T num_states): + allow_repeated_split_(allow_repeated_split) { + Initialize(num_states); + } + + ~Partition() { + for (size_t i = 0; i < elements_.size(); ++i) + delete elements_[i]; + } + + // Create an empty partition for num_states. At initialization time + // all elements are not assigned to a class (i.e class_index = -1). + // Initialize just creates num_states of elements. All element + // operations are then done by simply disconnecting the element from + // it current class and placing it at the head of the next class. + void Initialize(size_t num_states) { + for (size_t i = 0; i < elements_.size(); ++i) + delete elements_[i]; + elements_.clear(); + classes_.clear(); + class_index_.clear(); + + elements_.resize(num_states); + class_index_.resize(num_states, -1); + class_size_.reserve(num_states); + for (size_t i = 0; i < num_states; ++i) + elements_[i] = new Element(i); + num_states_ = num_states; + } + + // Add a class, resize classes_ and class_size_ resource by 1. + size_t AddClass() { + size_t num_classes = classes_.size(); + classes_.resize(num_classes + 1, 0); + class_size_.resize(num_classes + 1, 0); + class_split_.resize(num_classes + 1, 0); + split_size_.resize(num_classes + 1, 0); + return num_classes; + } + + void AllocateClasses(T num_classes) { + size_t n = classes_.size() + num_classes; + classes_.resize(n, 0); + class_size_.resize(n, 0); + class_split_.resize(n, 0); + split_size_.resize(n, 0); + } + + // Add element_id to class_id. The Add method is used to initialize + // partition. Once elements have been added to a class, you need to + // use the Move() method move an element from once class to another. + void Add(T element_id, T class_id) { + Element* element = elements_[element_id]; + + if (classes_[class_id]) + classes_[class_id]->prev = element; + element->next = classes_[class_id]; + element->prev = 0; + classes_[class_id] = element; + + class_index_[element_id] = class_id; + class_size_[class_id]++; + } + + // Move and element_id to class_id. Disconnects (removes) element + // from it current class and + void Move(T element_id, T class_id) { + T old_class_id = class_index_[element_id]; + + Element* element = elements_[element_id]; + if (element->next) element->next->prev = element->prev; + if (element->prev) element->prev->next = element->next; + else classes_[old_class_id] = element->next; + + Add(element_id, class_id); + class_size_[old_class_id]--; + } + + // split class on the element_id + void SplitOn(T element_id) { + T class_id = class_index_[element_id]; + if (class_size_[class_id] == 1) return; + + // first time class is split + if (split_size_[class_id] == 0) { + visited_classes_.push_back(class_id); + class_split_[class_id] = classes_[class_id]; + } + // increment size of split (set of element at head of chain) + split_size_[class_id]++; + + // update split point + if (class_split_[class_id] != 0 + && class_split_[class_id] == elements_[element_id]) + class_split_[class_id] = elements_[element_id]->next; + + // move to head of chain in same class + Move(element_id, class_id); + } + + // Finalize class_id, split if required, and update class_splits, + // class indices of the newly created class. Returns the new_class id + // or -1 if no new class was created. + T SplitRefine(T class_id) { + + Element* split_el = class_split_[class_id]; + // only split if necessary + //if (class_size_[class_id] == split_size_[class_id]) { + if(split_el == NULL) { // we split on everything... + split_size_[class_id] = 0; + return -1; + } else { + T new_class = AddClass(); + + if(allow_repeated_split_) { // split_size_ is possibly + // inaccurate, so work it out exactly. + size_t split_count; Element *e; + for(split_count=0,e=classes_[class_id]; + e != split_el; split_count++, e=e->next); + split_size_[class_id] = split_count; + } + size_t remainder = class_size_[class_id] - split_size_[class_id]; + if (remainder < split_size_[class_id]) { // add smaller + classes_[new_class] = split_el; + split_el->prev->next = 0; + split_el->prev = 0; + class_size_[class_id] = split_size_[class_id]; + class_size_[new_class] = remainder; + } else { + classes_[new_class] = classes_[class_id]; + class_size_[class_id] = remainder; + class_size_[new_class] = split_size_[class_id]; + split_el->prev->next = 0; + split_el->prev = 0; + classes_[class_id] = split_el; + } + + // update class index for element in new class + for (Element* el = classes_[new_class]; el; el = el->next) + class_index_[el->value] = new_class; + + class_split_[class_id] = 0; + split_size_[class_id] = 0; + + return new_class; + } + } + + // Once all states have been processed for a particular class C, we + // can finalize the split. FinalizeSplit() will update each block in the + // partition, create new once and update the queue of active classes + // that require further refinement. + template <class Queue> + void FinalizeSplit(Queue* L) { + for (size_t i = 0; i < visited_classes_.size(); ++i) { + T new_class = SplitRefine(visited_classes_[i]); + if (new_class != -1 && L) + L->Enqueue(new_class); + } + visited_classes_.clear(); + } + + + const T class_id(T element_id) const { + return class_index_[element_id]; + } + + const vector<T>& class_sizes() const { + return class_size_; + } + + const size_t class_size(T class_id) const { + return class_size_[class_id]; + } + + const T num_classes() const { + return classes_.size(); + } + + + private: + int num_states_; + + // container of all elements (owner of ptrs) + vector<Element*> elements_; + + // linked list of elements belonging to class + vector<Element*> classes_; + + // pointer to split point for each class + vector<Element*> class_split_; + + // class index of element + vector<T> class_index_; + + // class sizes + vector<T> class_size_; + + // size of split for each class + // in the nondeterministic case, split_size_ is actually an upper + // bound on the size of split for each class. + vector<T> split_size_; + + // set of visited classes to be used in split refine + vector<T> visited_classes_; + + // true if input fst was deterministic: we can make + // certain assumptions in this case that speed up the algorithm. + bool allow_repeated_split_; +}; + + +// iterate over members of a class in a partition +template <typename T> +class PartitionIterator { + typedef typename Partition<T>::Element Element; + public: + PartitionIterator(const Partition<T>& partition, T class_id) + : p_(partition), + element_(p_.classes_[class_id]), + class_id_(class_id) {} + + bool Done() { + return (element_ == 0); + } + + const T Value() { + return (element_->value); + } + + void Next() { + element_ = element_->next; + } + + void Reset() { + element_ = p_.classes_[class_id_]; + } + + private: + const Partition<T>& p_; + + const Element* element_; + + T class_id_; +}; +} // namespace fst + +#endif // FST_LIB_PARTITION_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/power-weight.h b/kaldi_io/src/tools/openfst/include/fst/power-weight.h new file mode 100644 index 0000000..256928d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/power-weight.h @@ -0,0 +1,159 @@ +// power-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Cartesian power weight semiring operation definitions. + +#ifndef FST_LIB_POWER_WEIGHT_H__ +#define FST_LIB_POWER_WEIGHT_H__ + +#include <fst/tuple-weight.h> +#include <fst/weight.h> + + +namespace fst { + +// Cartesian power semiring: W ^ n +// Forms: +// - a left semimodule when W is a left semiring, +// - a right semimodule when W is a right semiring, +// - a bisemimodule when W is a semiring, +// the free semimodule of rank n over W +// The Times operation is overloaded to provide the +// left and right scalar products. +template <class W, unsigned int n> +class PowerWeight : public TupleWeight<W, n> { + public: + using TupleWeight<W, n>::Zero; + using TupleWeight<W, n>::One; + using TupleWeight<W, n>::NoWeight; + using TupleWeight<W, n>::Quantize; + using TupleWeight<W, n>::Reverse; + + typedef PowerWeight<typename W::ReverseWeight, n> ReverseWeight; + + PowerWeight() {} + + PowerWeight(const TupleWeight<W, n> &w) : TupleWeight<W, n>(w) {} + + template <class Iterator> + PowerWeight(Iterator begin, Iterator end) : TupleWeight<W, n>(begin, end) {} + + static const PowerWeight<W, n> &Zero() { + static const PowerWeight<W, n> zero(TupleWeight<W, n>::Zero()); + return zero; + } + + static const PowerWeight<W, n> &One() { + static const PowerWeight<W, n> one(TupleWeight<W, n>::One()); + return one; + } + + static const PowerWeight<W, n> &NoWeight() { + static const PowerWeight<W, n> no_weight(TupleWeight<W, n>::NoWeight()); + return no_weight; + } + + static const string &Type() { + static string type; + if (type.empty()) { + string power; + Int64ToStr(n, &power); + type = W::Type() + "_^" + power; + } + return type; + } + + static uint64 Properties() { + uint64 props = W::Properties(); + return props & (kLeftSemiring | kRightSemiring | + kCommutative | kIdempotent); + } + + PowerWeight<W, n> Quantize(float delta = kDelta) const { + return TupleWeight<W, n>::Quantize(delta); + } + + ReverseWeight Reverse() const { + return TupleWeight<W, n>::Reverse(); + } +}; + + +// Semiring plus operation +template <class W, unsigned int n> +inline PowerWeight<W, n> Plus(const PowerWeight<W, n> &w1, + const PowerWeight<W, n> &w2) { + PowerWeight<W, n> w; + for (size_t i = 0; i < n; ++i) + w.SetValue(i, Plus(w1.Value(i), w2.Value(i))); + return w; +} + +// Semiring times operation +template <class W, unsigned int n> +inline PowerWeight<W, n> Times(const PowerWeight<W, n> &w1, + const PowerWeight<W, n> &w2) { + PowerWeight<W, n> w; + for (size_t i = 0; i < n; ++i) + w.SetValue(i, Times(w1.Value(i), w2.Value(i))); + return w; +} + +// Semiring divide operation +template <class W, unsigned int n> +inline PowerWeight<W, n> Divide(const PowerWeight<W, n> &w1, + const PowerWeight<W, n> &w2, + DivideType type = DIVIDE_ANY) { + PowerWeight<W, n> w; + for (size_t i = 0; i < n; ++i) + w.SetValue(i, Divide(w1.Value(i), w2.Value(i), type)); + return w; +} + +// Semimodule left scalar product +template <class W, unsigned int n> +inline PowerWeight<W, n> Times(const W &s, const PowerWeight<W, n> &w) { + PowerWeight<W, n> sw; + for (size_t i = 0; i < n; ++i) + sw.SetValue(i, Times(s, w.Value(i))); + return w; +} + +// Semimodule right scalar product +template <class W, unsigned int n> +inline PowerWeight<W, n> Times(const PowerWeight<W, n> &w, const W &s) { + PowerWeight<W, n> ws; + for (size_t i = 0; i < n; ++i) + ws.SetValue(i, Times(w.Value(i), s)); + return w; +} + +// Semimodule dot product +template <class W, unsigned int n> +inline W DotProduct(const PowerWeight<W, n> &w1, + const PowerWeight<W, n> &w2) { + W w = W::Zero(); + for (size_t i = 0; i < n; ++i) + w = Plus(w, Times(w1.Value(i), w2.Value(i))); + return w; +} + + +} // namespace fst + +#endif // FST_LIB_POWER_WEIGHT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/product-weight.h b/kaldi_io/src/tools/openfst/include/fst/product-weight.h new file mode 100644 index 0000000..16dede8 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/product-weight.h @@ -0,0 +1,115 @@ +// product-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Product weight set and associated semiring operation definitions. + +#ifndef FST_LIB_PRODUCT_WEIGHT_H__ +#define FST_LIB_PRODUCT_WEIGHT_H__ + +#include <stack> +#include <string> + +#include <fst/pair-weight.h> +#include <fst/weight.h> + + +namespace fst { + +// Product semiring: W1 * W2 +template<class W1, class W2> +class ProductWeight : public PairWeight<W1, W2> { + public: + using PairWeight<W1, W2>::Zero; + using PairWeight<W1, W2>::One; + using PairWeight<W1, W2>::NoWeight; + using PairWeight<W1, W2>::Quantize; + using PairWeight<W1, W2>::Reverse; + + typedef ProductWeight<typename W1::ReverseWeight, typename W2::ReverseWeight> + ReverseWeight; + + ProductWeight() {} + + ProductWeight(const PairWeight<W1, W2>& w) : PairWeight<W1, W2>(w) {} + + ProductWeight(W1 w1, W2 w2) : PairWeight<W1, W2>(w1, w2) {} + + static const ProductWeight<W1, W2> &Zero() { + static const ProductWeight<W1, W2> zero(PairWeight<W1, W2>::Zero()); + return zero; + } + + static const ProductWeight<W1, W2> &One() { + static const ProductWeight<W1, W2> one(PairWeight<W1, W2>::One()); + return one; + } + + static const ProductWeight<W1, W2> &NoWeight() { + static const ProductWeight<W1, W2> no_weight( + PairWeight<W1, W2>::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string type = W1::Type() + "_X_" + W2::Type(); + return type; + } + + static uint64 Properties() { + uint64 props1 = W1::Properties(); + uint64 props2 = W2::Properties(); + return props1 & props2 & (kLeftSemiring | kRightSemiring | + kCommutative | kIdempotent); + } + + ProductWeight<W1, W2> Quantize(float delta = kDelta) const { + return PairWeight<W1, W2>::Quantize(delta); + } + + ReverseWeight Reverse() const { + return PairWeight<W1, W2>::Reverse(); + } + + +}; + +template <class W1, class W2> +inline ProductWeight<W1, W2> Plus(const ProductWeight<W1, W2> &w, + const ProductWeight<W1, W2> &v) { + return ProductWeight<W1, W2>(Plus(w.Value1(), v.Value1()), + Plus(w.Value2(), v.Value2())); +} + +template <class W1, class W2> +inline ProductWeight<W1, W2> Times(const ProductWeight<W1, W2> &w, + const ProductWeight<W1, W2> &v) { + return ProductWeight<W1, W2>(Times(w.Value1(), v.Value1()), + Times(w.Value2(), v.Value2())); +} + +template <class W1, class W2> +inline ProductWeight<W1, W2> Divide(const ProductWeight<W1, W2> &w, + const ProductWeight<W1, W2> &v, + DivideType typ = DIVIDE_ANY) { + return ProductWeight<W1, W2>(Divide(w.Value1(), v.Value1(), typ), + Divide(w.Value2(), v.Value2(), typ)); +} + +} // namespace fst + +#endif // FST_LIB_PRODUCT_WEIGHT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/project.h b/kaldi_io/src/tools/openfst/include/fst/project.h new file mode 100644 index 0000000..07946c3 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/project.h @@ -0,0 +1,148 @@ +// project.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Functions and classes to project an Fst on to its domain or range. + +#ifndef FST_LIB_PROJECT_H__ +#define FST_LIB_PROJECT_H__ + +#include <fst/arc-map.h> +#include <fst/mutable-fst.h> + + +namespace fst { + +// This specifies whether to project on input or output. +enum ProjectType { PROJECT_INPUT = 1, PROJECT_OUTPUT = 2 }; + + +// Mapper to implement projection per arc. +template <class A> class ProjectMapper { + public: + explicit ProjectMapper(ProjectType project_type) + : project_type_(project_type) {} + + A operator()(const A &arc) { + typename A::Label label = project_type_ == PROJECT_INPUT + ? arc.ilabel : arc.olabel; + return A(label, label, arc.weight, arc.nextstate); + } + + MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + MapSymbolsAction InputSymbolsAction() const { + return project_type_ == PROJECT_INPUT ? MAP_COPY_SYMBOLS : + MAP_CLEAR_SYMBOLS; + } + + MapSymbolsAction OutputSymbolsAction() const { + return project_type_ == PROJECT_OUTPUT ? MAP_COPY_SYMBOLS : + MAP_CLEAR_SYMBOLS; + } + + uint64 Properties(uint64 props) { + return ProjectProperties(props, project_type_ == PROJECT_INPUT); + } + + + private: + ProjectType project_type_; +}; + + +// Projects an FST onto its domain or range by either copying each arcs' +// input label to the output label or vice versa. This version modifies +// its input. +// +// Complexity: +// - Time: O(V + E) +// - Space: O(1) +// where V = # of states and E = # of arcs. +template<class Arc> inline +void Project(MutableFst<Arc> *fst, ProjectType project_type) { + ArcMap(fst, ProjectMapper<Arc>(project_type)); + if (project_type == PROJECT_INPUT) + fst->SetOutputSymbols(fst->InputSymbols()); + if (project_type == PROJECT_OUTPUT) + fst->SetInputSymbols(fst->OutputSymbols()); +} + + +// Projects an FST onto its domain or range by either copying each arc's +// input label to the output label or vice versa. This version is a delayed +// Fst. +// +// Complexity: +// - Time: O(v + e) +// - Space: O(1) +// where v = # of states visited, e = # of arcs visited. Constant +// time and to visit an input state or arc is assumed and exclusive +// of caching. +template <class A> +class ProjectFst : public ArcMapFst<A, A, ProjectMapper<A> > { + public: + typedef A Arc; + typedef ProjectMapper<A> C; + typedef ArcMapFstImpl< A, A, ProjectMapper<A> > Impl; + using ImplToFst<Impl>::GetImpl; + + ProjectFst(const Fst<A> &fst, ProjectType project_type) + : ArcMapFst<A, A, C>(fst, C(project_type)) { + if (project_type == PROJECT_INPUT) + GetImpl()->SetOutputSymbols(fst.InputSymbols()); + if (project_type == PROJECT_OUTPUT) + GetImpl()->SetInputSymbols(fst.OutputSymbols()); + } + + // See Fst<>::Copy() for doc. + ProjectFst(const ProjectFst<A> &fst, bool safe = false) + : ArcMapFst<A, A, C>(fst, safe) {} + + // Get a copy of this ProjectFst. See Fst<>::Copy() for further doc. + virtual ProjectFst<A> *Copy(bool safe = false) const { + return new ProjectFst(*this, safe); + } +}; + + +// Specialization for ProjectFst. +template <class A> +class StateIterator< ProjectFst<A> > + : public StateIterator< ArcMapFst<A, A, ProjectMapper<A> > > { + public: + explicit StateIterator(const ProjectFst<A> &fst) + : StateIterator< ArcMapFst<A, A, ProjectMapper<A> > >(fst) {} +}; + + +// Specialization for ProjectFst. +template <class A> +class ArcIterator< ProjectFst<A> > + : public ArcIterator< ArcMapFst<A, A, ProjectMapper<A> > > { + public: + ArcIterator(const ProjectFst<A> &fst, typename A::StateId s) + : ArcIterator< ArcMapFst<A, A, ProjectMapper<A> > >(fst, s) {} +}; + + +// Useful alias when using StdArc. +typedef ProjectFst<StdArc> StdProjectFst; + +} // namespace fst + +#endif // FST_LIB_PROJECT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/properties.h b/kaldi_io/src/tools/openfst/include/fst/properties.h new file mode 100644 index 0000000..8fab16f --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/properties.h @@ -0,0 +1,460 @@ +// properties.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: Michael Riley <[email protected]> +// \file +// FST property bits. + +#ifndef FST_LIB_PROPERTIES_H__ +#define FST_LIB_PROPERTIES_H__ + +#include <sys/types.h> +#include <vector> +using std::vector; + +#include <fst/compat.h> + +namespace fst { + +// The property bits here assert facts about an FST. If individual +// bits are added, then the composite properties below, the property +// functions and property names in properties.cc, and +// TestProperties() in test-properties.h should be updated. + +// +// BINARY PROPERTIES +// +// For each property below, there is a single bit. If it is set, +// the property is true. If it is not set, the property is false. +// + +// The Fst is an ExpandedFst +const uint64 kExpanded = 0x0000000000000001ULL; + +// The Fst is a MutableFst +const uint64 kMutable = 0x0000000000000002ULL; + +// An error was detected while constructing/using the FST +const uint64 kError = 0x0000000000000004ULL; + +// +// TRINARY PROPERTIES +// +// For each of these properties below there is a pair of property bits +// - one positive and one negative. If the positive bit is set, the +// property is true. If the negative bit is set, the property is +// false. If neither is set, the property has unknown value. Both +// should never be simultaneously set. The individual positive and +// negative bit pairs should be adjacent with the positive bit +// at an odd and lower position. + +// ilabel == olabel for each arc +const uint64 kAcceptor = 0x0000000000010000ULL; +// ilabel != olabel for some arc +const uint64 kNotAcceptor = 0x0000000000020000ULL; + +// ilabels unique leaving each state +const uint64 kIDeterministic = 0x0000000000040000ULL; +// ilabels not unique leaving some state +const uint64 kNonIDeterministic = 0x0000000000080000ULL; + +// olabels unique leaving each state +const uint64 kODeterministic = 0x0000000000100000ULL; +// olabels not unique leaving some state +const uint64 kNonODeterministic = 0x0000000000200000ULL; + +// FST has input/output epsilons +const uint64 kEpsilons = 0x0000000000400000ULL; +// FST has no input/output epsilons +const uint64 kNoEpsilons = 0x0000000000800000ULL; + +// FST has input epsilons +const uint64 kIEpsilons = 0x0000000001000000ULL; +// FST has no input epsilons +const uint64 kNoIEpsilons = 0x0000000002000000ULL; + +// FST has output epsilons +const uint64 kOEpsilons = 0x0000000004000000ULL; +// FST has no output epsilons +const uint64 kNoOEpsilons = 0x0000000008000000ULL; + +// ilabels sorted wrt < for each state +const uint64 kILabelSorted = 0x0000000010000000ULL; +// ilabels not sorted wrt < for some state +const uint64 kNotILabelSorted = 0x0000000020000000ULL; + +// olabels sorted wrt < for each state +const uint64 kOLabelSorted = 0x0000000040000000ULL; +// olabels not sorted wrt < for some state +const uint64 kNotOLabelSorted = 0x0000000080000000ULL; + +// Non-trivial arc or final weights +const uint64 kWeighted = 0x0000000100000000ULL; +// Only trivial arc and final weights +const uint64 kUnweighted = 0x0000000200000000ULL; + +// FST has cycles +const uint64 kCyclic = 0x0000000400000000ULL; +// FST has no cycles +const uint64 kAcyclic = 0x0000000800000000ULL; + +// FST has cycles containing the initial state +const uint64 kInitialCyclic = 0x0000001000000000ULL; +// FST has no cycles containing the initial state +const uint64 kInitialAcyclic = 0x0000002000000000ULL; + +// FST is topologically sorted +const uint64 kTopSorted = 0x0000004000000000ULL; +// FST is not topologically sorted +const uint64 kNotTopSorted = 0x0000008000000000ULL; + +// All states reachable from the initial state +const uint64 kAccessible = 0x0000010000000000ULL; +// Not all states reachable from the initial state +const uint64 kNotAccessible = 0x0000020000000000ULL; + +// All states can reach a final state +const uint64 kCoAccessible = 0x0000040000000000ULL; +// Not all states can reach a final state +const uint64 kNotCoAccessible = 0x0000080000000000ULL; + +// If NumStates() > 0, then state 0 is initial, state NumStates()-1 is +// final, there is a transition from each non-final state i to +// state i+1, and there are no other transitions. +const uint64 kString = 0x0000100000000000ULL; + +// Not a string FST +const uint64 kNotString = 0x0000200000000000ULL; + +// +// COMPOSITE PROPERTIES +// + +// Properties of an empty machine +const uint64 kNullProperties + = kAcceptor | kIDeterministic | kODeterministic | kNoEpsilons | + kNoIEpsilons | kNoOEpsilons | kILabelSorted | kOLabelSorted | + kUnweighted | kAcyclic | kInitialAcyclic | kTopSorted | + kAccessible | kCoAccessible | kString; + +// Properties that are preserved when an FST is copied +const uint64 kCopyProperties + = kError | kAcceptor | kNotAcceptor | kIDeterministic | kNonIDeterministic | + kODeterministic | kNonODeterministic | kEpsilons | kNoEpsilons | + kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | + kNotOLabelSorted | kWeighted | kUnweighted | kCyclic | kAcyclic | + kInitialCyclic | kInitialAcyclic | kTopSorted | kNotTopSorted | + kAccessible | kNotAccessible | kCoAccessible | kNotCoAccessible | + kString | kNotString; + +// Properites that are intrinsic to the FST +const uint64 kIntrinsicProperties + = kExpanded | kMutable | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | + kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | + kNoOEpsilons | kILabelSorted | kNotILabelSorted | kOLabelSorted | + kNotOLabelSorted | kWeighted | kUnweighted | kCyclic | kAcyclic | + kInitialCyclic | kInitialAcyclic | kTopSorted | kNotTopSorted | + kAccessible | kNotAccessible | kCoAccessible | kNotCoAccessible | + kString | kNotString; + +// Properites that are (potentially) extrinsic to the FST +const uint64 kExtrinsicProperties = kError; + +// Properties that are preserved when an FST start state is set +const uint64 kSetStartProperties + = kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | + kIDeterministic | kNonIDeterministic | kODeterministic | + kNonODeterministic | kEpsilons | kNoEpsilons | kIEpsilons | + kNoIEpsilons | kOEpsilons | kNoOEpsilons | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | kWeighted | + kUnweighted | kCyclic | kAcyclic | kTopSorted | kNotTopSorted | + kCoAccessible | kNotCoAccessible; + +// Properties that are preserved when an FST final weight is set +const uint64 kSetFinalProperties + = kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | + kIDeterministic | kNonIDeterministic | kODeterministic | + kNonODeterministic | kEpsilons | kNoEpsilons | kIEpsilons | + kNoIEpsilons | kOEpsilons | kNoOEpsilons | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | kCyclic | + kAcyclic | kInitialCyclic | kInitialAcyclic | kTopSorted | + kNotTopSorted | kAccessible | kNotAccessible; + +// Properties that are preserved when an FST state is added +const uint64 kAddStateProperties + = kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | + kIDeterministic | kNonIDeterministic | kODeterministic | + kNonODeterministic | kEpsilons | kNoEpsilons | kIEpsilons | + kNoIEpsilons | kOEpsilons | kNoOEpsilons | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | kWeighted | + kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kTopSorted | kNotTopSorted | kNotAccessible | + kNotCoAccessible | kNotString; + +// Properties that are preserved when an FST arc is added +const uint64 kAddArcProperties = kExpanded | kMutable | kError | kNotAcceptor | + kNonIDeterministic | kNonODeterministic | kEpsilons | kIEpsilons | + kOEpsilons | kNotILabelSorted | kNotOLabelSorted | kWeighted | + kCyclic | kInitialCyclic | kNotTopSorted | kAccessible | kCoAccessible; + +// Properties that are preserved when an FST arc is set +const uint64 kSetArcProperties = kExpanded | kMutable | kError; + +// Properties that are preserved when FST states are deleted +const uint64 kDeleteStatesProperties + = kExpanded | kMutable | kError | kAcceptor | kIDeterministic | + kODeterministic | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kUnweighted | kAcyclic | + kInitialAcyclic | kTopSorted; + +// Properties that are preserved when FST arcs are deleted +const uint64 kDeleteArcsProperties + = kExpanded | kMutable | kError | kAcceptor | kIDeterministic | + kODeterministic | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kUnweighted | kAcyclic | + kInitialAcyclic | kTopSorted | kNotAccessible | kNotCoAccessible; + +// Properties that are preserved when an FST's states are reordered +const uint64 kStateSortProperties = kExpanded | kMutable | kError | kAcceptor | + kNotAcceptor | kIDeterministic | kNonIDeterministic | + kODeterministic | kNonODeterministic | kEpsilons | kNoEpsilons | + kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted + | kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible; + +// Properties that are preserved when an FST's arcs are reordered +const uint64 kArcSortProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | + kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | + kNoOEpsilons | kWeighted | kUnweighted | kCyclic | kAcyclic | + kInitialCyclic | kInitialAcyclic | kTopSorted | kNotTopSorted | + kAccessible | kNotAccessible | kCoAccessible | kNotCoAccessible | + kString | kNotString; + +// Properties that are preserved when an FST's input labels are changed. +const uint64 kILabelInvariantProperties = + kExpanded | kMutable | kError | kODeterministic | kNonODeterministic | + kOEpsilons | kNoOEpsilons | kOLabelSorted | kNotOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kTopSorted | kNotTopSorted | kAccessible | + kNotAccessible | kCoAccessible | kNotCoAccessible | kString | kNotString; + +// Properties that are preserved when an FST's output labels are changed. +const uint64 kOLabelInvariantProperties = + kExpanded | kMutable | kError | kIDeterministic | kNonIDeterministic | + kIEpsilons | kNoIEpsilons | kILabelSorted | kNotILabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kTopSorted | kNotTopSorted | kAccessible | + kNotAccessible | kCoAccessible | kNotCoAccessible | kString | kNotString; + +// Properties that are preserved when an FST's weights are changed. +// This assumes that the set of states that are non-final is not changed. +const uint64 kWeightInvariantProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | + kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | + kNoOEpsilons | kILabelSorted | kNotILabelSorted | kOLabelSorted | + kNotOLabelSorted | kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | + kTopSorted | kNotTopSorted | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kString | kNotString; + +// Properties that are preserved when a superfinal state is added +// and an FSTs final weights are directed to it via new transitions. +const uint64 kAddSuperFinalProperties = kExpanded | kMutable | kError | + kAcceptor | kNotAcceptor | kNonIDeterministic | kNonODeterministic | + kEpsilons | kIEpsilons | kOEpsilons | kNotILabelSorted | kNotOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kNotTopSorted | kNotAccessible | kCoAccessible | + kNotCoAccessible | kNotString; + +// Properties that are preserved when a superfinal state is removed +// and the epsilon transitions directed to it are made final weights. +const uint64 kRmSuperFinalProperties = kExpanded | kMutable | kError | + kAcceptor | kNotAcceptor | kIDeterministic | kODeterministic | + kNoEpsilons | kNoIEpsilons | kNoOEpsilons | kILabelSorted | kOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kTopSorted | kAccessible | kCoAccessible | + kNotCoAccessible | kString; + +// All binary properties +const uint64 kBinaryProperties = 0x0000000000000007ULL; + +// All trinary properties +const uint64 kTrinaryProperties = 0x00003fffffff0000ULL; + +// +// COMPUTED PROPERTIES +// + +// 1st bit of trinary properties +const uint64 kPosTrinaryProperties = + kTrinaryProperties & 0x5555555555555555ULL; + +// 2nd bit of trinary properties +const uint64 kNegTrinaryProperties = + kTrinaryProperties & 0xaaaaaaaaaaaaaaaaULL; + +// All properties +const uint64 kFstProperties = kBinaryProperties | kTrinaryProperties; + +// +// PROPERTY FUNCTIONS and STRING NAMES (defined in properties.cc) +// + +// Below are functions for getting property bit vectors when executing +// mutating fst operations. +inline uint64 SetStartProperties(uint64 inprops); +template <typename Weight> +uint64 SetFinalProperties(uint64 inprops, Weight old_weight, + Weight new_weight); +inline uint64 AddStateProperties(uint64 inprops); +template <typename A> +uint64 AddArcProperties(uint64 inprops, typename A::StateId s, const A &arc, + const A *prev_arc); +inline uint64 DeleteStatesProperties(uint64 inprops); +inline uint64 DeleteAllStatesProperties(uint64 inprops, uint64 staticProps); +inline uint64 DeleteArcsProperties(uint64 inprops); + +uint64 ClosureProperties(uint64 inprops, bool star, bool delayed = false); +uint64 ComplementProperties(uint64 inprops); +uint64 ComposeProperties(uint64 inprops1, uint64 inprops2); +uint64 ConcatProperties(uint64 inprops1, uint64 inprops2, + bool delayed = false); +uint64 DeterminizeProperties(uint64 inprops, bool has_subsequential_label); +uint64 FactorWeightProperties(uint64 inprops); +uint64 InvertProperties(uint64 inprops); +uint64 ProjectProperties(uint64 inprops, bool project_input); +uint64 RandGenProperties(uint64 inprops, bool weighted); +uint64 RelabelProperties(uint64 inprops); +uint64 ReplaceProperties(const vector<uint64>& inprops, + ssize_t root, + bool epsilon_on_replace, + bool no_empty_fst); +uint64 ReverseProperties(uint64 inprops); +uint64 ReweightProperties(uint64 inprops); +uint64 RmEpsilonProperties(uint64 inprops, bool delayed = false); +uint64 ShortestPathProperties(uint64 props); +uint64 SynchronizeProperties(uint64 inprops); +uint64 UnionProperties(uint64 inprops1, uint64 inprops2, bool delayed = false); + +// Definitions of inlined functions. + +uint64 SetStartProperties(uint64 inprops) { + uint64 outprops = inprops & kSetStartProperties; + if (inprops & kAcyclic) { + outprops |= kInitialAcyclic; + } + return outprops; +} + +uint64 AddStateProperties(uint64 inprops) { + return inprops & kAddStateProperties; +} + +uint64 DeleteStatesProperties(uint64 inprops) { + return inprops & kDeleteStatesProperties; +} + +uint64 DeleteAllStatesProperties(uint64 inprops, uint64 staticprops) { + uint64 outprops = inprops & kError; + return outprops | kNullProperties | staticprops; +} + +uint64 DeleteArcsProperties(uint64 inprops) { + return inprops & kDeleteArcsProperties; +} + +// Definitions of template functions. + +// +template <typename Weight> +uint64 SetFinalProperties(uint64 inprops, Weight old_weight, + Weight new_weight) { + uint64 outprops = inprops; + if (old_weight != Weight::Zero() && old_weight != Weight::One()) { + outprops &= ~kWeighted; + } + if (new_weight != Weight::Zero() && new_weight != Weight::One()) { + outprops |= kWeighted; + outprops &= ~kUnweighted; + } + outprops &= kSetFinalProperties | kWeighted | kUnweighted; + return outprops; +} + +/// Gets the properties for the MutableFst::AddArc method. +/// +/// \param inprops the current properties of the fst +/// \param s the id of the state to which an arc is being added +/// \param arc the arc being added to the state with the specified id +/// \param prev_arc the previously-added (or "last") arc of state s, or NULL if +/// s currently has no arcs +template <typename A> +uint64 AddArcProperties(uint64 inprops, typename A::StateId s, + const A &arc, const A *prev_arc) { + uint64 outprops = inprops; + if (arc.ilabel != arc.olabel) { + outprops |= kNotAcceptor; + outprops &= ~kAcceptor; + } + if (arc.ilabel == 0) { + outprops |= kIEpsilons; + outprops &= ~kNoIEpsilons; + if (arc.olabel == 0) { + outprops |= kEpsilons; + outprops &= ~kNoEpsilons; + } + } + if (arc.olabel == 0) { + outprops |= kOEpsilons; + outprops &= ~kNoOEpsilons; + } + if (prev_arc != 0) { + if (prev_arc->ilabel > arc.ilabel) { + outprops |= kNotILabelSorted; + outprops &= ~kILabelSorted; + } + if (prev_arc->olabel > arc.olabel) { + outprops |= kNotOLabelSorted; + outprops &= ~kOLabelSorted; + } + } + if (arc.weight != A::Weight::Zero() && arc.weight != A::Weight::One()) { + outprops |= kWeighted; + outprops &= ~kUnweighted; + } + if (arc.nextstate <= s) { + outprops |= kNotTopSorted; + outprops &= ~kTopSorted; + } + outprops &= kAddArcProperties | kAcceptor | + kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kUnweighted | kTopSorted; + if (outprops & kTopSorted) { + outprops |= kAcyclic | kInitialAcyclic; + } + return outprops; +} + +extern const char *PropertyNames[]; + +} // namespace fst + +#endif // FST_LIB_PROPERTIES_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/prune.h b/kaldi_io/src/tools/openfst/include/fst/prune.h new file mode 100644 index 0000000..5ea5b4d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/prune.h @@ -0,0 +1,339 @@ +// prune.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Functions implementing pruning. + +#ifndef FST_LIB_PRUNE_H__ +#define FST_LIB_PRUNE_H__ + +#include <vector> +using std::vector; + +#include <fst/arcfilter.h> +#include <fst/heap.h> +#include <fst/shortest-distance.h> + + +namespace fst { + +template <class A, class ArcFilter> +class PruneOptions { + public: + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + // Pruning weight threshold. + Weight weight_threshold; + // Pruning state threshold. + StateId state_threshold; + // Arc filter. + ArcFilter filter; + // If non-zero, passes in pre-computed shortest distance to final states. + const vector<Weight> *distance; + // Determines the degree of convergence required when computing shortest + // distances. + float delta; + + explicit PruneOptions(const Weight& w, StateId s, ArcFilter f, + vector<Weight> *d = 0, float e = kDelta) + : weight_threshold(w), + state_threshold(s), + filter(f), + distance(d), + delta(e) {} + private: + PruneOptions(); // disallow +}; + + +template <class S, class W> +class PruneCompare { + public: + typedef S StateId; + typedef W Weight; + + PruneCompare(const vector<Weight> &idistance, + const vector<Weight> &fdistance) + : idistance_(idistance), fdistance_(fdistance) {} + + bool operator()(const StateId x, const StateId y) const { + Weight wx = Times(x < idistance_.size() ? idistance_[x] : Weight::Zero(), + x < fdistance_.size() ? fdistance_[x] : Weight::Zero()); + Weight wy = Times(y < idistance_.size() ? idistance_[y] : Weight::Zero(), + y < fdistance_.size() ? fdistance_[y] : Weight::Zero()); + return less_(wx, wy); + } + + private: + const vector<Weight> &idistance_; + const vector<Weight> &fdistance_; + NaturalLess<Weight> less_; +}; + + + +// Pruning algorithm: this version modifies its input and it takes an +// options class as an argment. Delete states and arcs in 'fst' that +// do not belong to a successful path whose weight is no more than +// the weight of the shortest path Times() 'opts.weight_threshold'. +// When 'opts.state_threshold != kNoStateId', the resulting transducer +// will restricted further to have at most 'opts.state_threshold' +// states. Weights need to be commutative and have the path +// property. The weight 'w' of any cycle needs to be bounded, i.e., +// 'Plus(w, W::One()) = One()'. +template <class Arc, class ArcFilter> +void Prune(MutableFst<Arc> *fst, + const PruneOptions<Arc, ArcFilter> &opts) { + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + if ((Weight::Properties() & (kPath | kCommutative)) + != (kPath | kCommutative)) { + FSTERROR() << "Prune: Weight needs to have the path property and" + << " be commutative: " + << Weight::Type(); + fst->SetProperties(kError, kError); + return; + } + StateId ns = fst->NumStates(); + if (ns == 0) return; + vector<Weight> idistance(ns, Weight::Zero()); + vector<Weight> tmp; + if (!opts.distance) { + tmp.reserve(ns); + ShortestDistance(*fst, &tmp, true, opts.delta); + } + const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp; + + if ((opts.state_threshold == 0) || + (fdistance->size() <= fst->Start()) || + ((*fdistance)[fst->Start()] == Weight::Zero())) { + fst->DeleteStates(); + return; + } + PruneCompare<StateId, Weight> compare(idistance, *fdistance); + Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare); + vector<bool> visited(ns, false); + vector<size_t> enqueued(ns, kNoKey); + vector<StateId> dead; + dead.push_back(fst->AddState()); + NaturalLess<Weight> less; + Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold); + + StateId num_visited = 0; + StateId s = fst->Start(); + if (!less(limit, (*fdistance)[s])) { + idistance[s] = Weight::One(); + enqueued[s] = heap.Insert(s); + ++num_visited; + } + + while (!heap.Empty()) { + s = heap.Top(); + heap.Pop(); + enqueued[s] = kNoKey; + visited[s] = true; + if (less(limit, Times(idistance[s], fst->Final(s)))) + fst->SetFinal(s, Weight::Zero()); + for (MutableArcIterator< MutableFst<Arc> > ait(fst, s); + !ait.Done(); + ait.Next()) { + Arc arc = ait.Value(); + if (!opts.filter(arc)) continue; + Weight weight = Times(Times(idistance[s], arc.weight), + arc.nextstate < fdistance->size() + ? (*fdistance)[arc.nextstate] + : Weight::Zero()); + if (less(limit, weight)) { + arc.nextstate = dead[0]; + ait.SetValue(arc); + continue; + } + if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) + idistance[arc.nextstate] = Times(idistance[s], arc.weight); + if (visited[arc.nextstate]) continue; + if ((opts.state_threshold != kNoStateId) && + (num_visited >= opts.state_threshold)) + continue; + if (enqueued[arc.nextstate] == kNoKey) { + enqueued[arc.nextstate] = heap.Insert(arc.nextstate); + ++num_visited; + } else { + heap.Update(enqueued[arc.nextstate], arc.nextstate); + } + } + } + for (size_t i = 0; i < visited.size(); ++i) + if (!visited[i]) dead.push_back(i); + fst->DeleteStates(dead); +} + + +// Pruning algorithm: this version modifies its input and simply takes +// the pruning threshold as an argument. Delete states and arcs in +// 'fst' that do not belong to a successful path whose weight is no +// more than the weight of the shortest path Times() +// 'weight_threshold'. When 'state_threshold != kNoStateId', the +// resulting transducer will be restricted further to have at most +// 'opts.state_threshold' states. Weights need to be commutative and +// have the path property. The weight 'w' of any cycle needs to be +// bounded, i.e., 'Plus(w, W::One()) = One()'. +template <class Arc> +void Prune(MutableFst<Arc> *fst, + typename Arc::Weight weight_threshold, + typename Arc::StateId state_threshold = kNoStateId, + double delta = kDelta) { + PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold, + AnyArcFilter<Arc>(), 0, delta); + Prune(fst, opts); +} + + +// Pruning algorithm: this version writes the pruned input Fst to an +// output MutableFst and it takes an options class as an argument. +// 'ofst' contains states and arcs that belong to a successful path in +// 'ifst' whose weight is no more than the weight of the shortest path +// Times() 'opts.weight_threshold'. When 'opts.state_threshold != +// kNoStateId', 'ofst' will be restricted further to have at most +// 'opts.state_threshold' states. Weights need to be commutative and +// have the path property. The weight 'w' of any cycle needs to be +// bounded, i.e., 'Plus(w, W::One()) = One()'. +template <class Arc, class ArcFilter> +void Prune(const Fst<Arc> &ifst, + MutableFst<Arc> *ofst, + const PruneOptions<Arc, ArcFilter> &opts) { + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + if ((Weight::Properties() & (kPath | kCommutative)) + != (kPath | kCommutative)) { + FSTERROR() << "Prune: Weight needs to have the path property and" + << " be commutative: " + << Weight::Type(); + ofst->SetProperties(kError, kError); + return; + } + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst.InputSymbols()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Start() == kNoStateId) + return; + NaturalLess<Weight> less; + if (less(opts.weight_threshold, Weight::One()) || + (opts.state_threshold == 0)) + return; + vector<Weight> idistance; + vector<Weight> tmp; + if (!opts.distance) + ShortestDistance(ifst, &tmp, true, opts.delta); + const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp; + + if ((fdistance->size() <= ifst.Start()) || + ((*fdistance)[ifst.Start()] == Weight::Zero())) { + return; + } + PruneCompare<StateId, Weight> compare(idistance, *fdistance); + Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare); + vector<StateId> copy; + vector<size_t> enqueued; + vector<bool> visited; + + StateId s = ifst.Start(); + Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(), + opts.weight_threshold); + while (copy.size() <= s) + copy.push_back(kNoStateId); + copy[s] = ofst->AddState(); + ofst->SetStart(copy[s]); + while (idistance.size() <= s) + idistance.push_back(Weight::Zero()); + idistance[s] = Weight::One(); + while (enqueued.size() <= s) { + enqueued.push_back(kNoKey); + visited.push_back(false); + } + enqueued[s] = heap.Insert(s); + + while (!heap.Empty()) { + s = heap.Top(); + heap.Pop(); + enqueued[s] = kNoKey; + visited[s] = true; + if (!less(limit, Times(idistance[s], ifst.Final(s)))) + ofst->SetFinal(copy[s], ifst.Final(s)); + for (ArcIterator< Fst<Arc> > ait(ifst, s); + !ait.Done(); + ait.Next()) { + const Arc &arc = ait.Value(); + if (!opts.filter(arc)) continue; + Weight weight = Times(Times(idistance[s], arc.weight), + arc.nextstate < fdistance->size() + ? (*fdistance)[arc.nextstate] + : Weight::Zero()); + if (less(limit, weight)) continue; + if ((opts.state_threshold != kNoStateId) && + (ofst->NumStates() >= opts.state_threshold)) + continue; + while (idistance.size() <= arc.nextstate) + idistance.push_back(Weight::Zero()); + if (less(Times(idistance[s], arc.weight), + idistance[arc.nextstate])) + idistance[arc.nextstate] = Times(idistance[s], arc.weight); + while (copy.size() <= arc.nextstate) + copy.push_back(kNoStateId); + if (copy[arc.nextstate] == kNoStateId) + copy[arc.nextstate] = ofst->AddState(); + ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight, + copy[arc.nextstate])); + while (enqueued.size() <= arc.nextstate) { + enqueued.push_back(kNoKey); + visited.push_back(false); + } + if (visited[arc.nextstate]) continue; + if (enqueued[arc.nextstate] == kNoKey) + enqueued[arc.nextstate] = heap.Insert(arc.nextstate); + else + heap.Update(enqueued[arc.nextstate], arc.nextstate); + } + } +} + + +// Pruning algorithm: this version writes the pruned input Fst to an +// output MutableFst and simply takes the pruning threshold as an +// argument. 'ofst' contains states and arcs that belong to a +// successful path in 'ifst' whose weight is no more than +// the weight of the shortest path Times() 'weight_threshold'. When +// 'state_threshold != kNoStateId', 'ofst' will be restricted further +// to have at most 'opts.state_threshold' states. Weights need to be +// commutative and have the path property. The weight 'w' of any cycle +// needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'. +template <class Arc> +void Prune(const Fst<Arc> &ifst, + MutableFst<Arc> *ofst, + typename Arc::Weight weight_threshold, + typename Arc::StateId state_threshold = kNoStateId, + float delta = kDelta) { + PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold, + AnyArcFilter<Arc>(), 0, delta); + Prune(ifst, ofst, opts); +} + +} // namespace fst + +#endif // FST_LIB_PRUNE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/push.h b/kaldi_io/src/tools/openfst/include/fst/push.h new file mode 100644 index 0000000..1f7a8fa --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/push.h @@ -0,0 +1,175 @@ +// push.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Class to reweight/push an FST. + +#ifndef FST_LIB_PUSH_H__ +#define FST_LIB_PUSH_H__ + +#include <vector> +using std::vector; + +#include <fst/factor-weight.h> +#include <fst/fst.h> +#include <fst/arc-map.h> +#include <fst/reweight.h> +#include <fst/shortest-distance.h> + + +namespace fst { + +// Private helper functions for Push +namespace internal { + +// Compute the total weight (sum of the weights of all accepting paths) from +// the output of ShortestDistance. 'distance' is the shortest distance from the +// initial state when 'reverse == false' and to the final states when +// 'reverse == true'. +template <class Arc> +typename Arc::Weight ComputeTotalWeight( + const Fst<Arc> &fst, + const vector<typename Arc::Weight> &distance, + bool reverse) { + if (reverse) + return fst.Start() < distance.size() ? + distance[fst.Start()] : Arc::Weight::Zero(); + + typename Arc::Weight sum = Arc::Weight::Zero(); + for (typename Arc::StateId s = 0; s < distance.size(); ++s) + sum = Plus(sum, Times(distance[s], fst.Final(s))); + return sum; +} + +// Divide the weight of every accepting path by 'w'. The weight 'w' is +// divided at the final states if 'at_final == true' and at the +// initial state otherwise. +template <class Arc> +void RemoveWeight(MutableFst<Arc> *fst, typename Arc::Weight w, bool at_final) { + if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero())) + return; + + if (at_final) { + // Remove 'w' from the final states + for (StateIterator< MutableFst<Arc> > sit(*fst); + !sit.Done(); + sit.Next()) + fst->SetFinal(sit.Value(), + Divide(fst->Final(sit.Value()), w, DIVIDE_RIGHT)); + } else { // at_final == false + // Remove 'w' from the initial state + typename Arc::StateId start = fst->Start(); + for (MutableArcIterator<MutableFst<Arc> > ait(fst, start); + !ait.Done(); + ait.Next()) { + Arc arc = ait.Value(); + arc.weight = Divide(arc.weight, w, DIVIDE_LEFT); + ait.SetValue(arc); + } + fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT)); + } +} +} // namespace internal + +// Pushes the weights in FST in the direction defined by TYPE. If +// pushing towards the initial state, the sum of the weight of the +// outgoing transitions and final weight at a non-initial state is +// equal to One() in the resulting machine. If pushing towards the +// final state, the same property holds on the reverse machine. +// +// Weight needs to be left distributive when pushing towards the +// initial state and right distributive when pushing towards the final +// states. +template <class Arc> +void Push(MutableFst<Arc> *fst, + ReweightType type, + float delta = kDelta, + bool remove_total_weight = false) { + vector<typename Arc::Weight> distance; + ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta); + typename Arc::Weight total_weight = Arc::Weight::One(); + if (remove_total_weight) + total_weight = internal::ComputeTotalWeight(*fst, distance, + type == REWEIGHT_TO_INITIAL); + Reweight(fst, distance, type); + if (remove_total_weight) + internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL); +} + +const uint32 kPushWeights = 0x0001; +const uint32 kPushLabels = 0x0002; +const uint32 kPushRemoveTotalWeight = 0x0004; +const uint32 kPushRemoveCommonAffix = 0x0008; + +// OFST obtained from IFST by pushing weights and/or labels according +// to PTYPE in the direction defined by RTYPE. Weight needs to be +// left distributive when pushing weights towards the initial state +// and right distributive when pushing weights towards the final +// states. +template <class Arc, ReweightType rtype> +void Push(const Fst<Arc> &ifst, + MutableFst<Arc> *ofst, + uint32 ptype, + float delta = kDelta) { + + if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) { + *ofst = ifst; + Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight); + } else if (ptype & kPushLabels) { + const StringType stype = rtype == REWEIGHT_TO_INITIAL + ? STRING_LEFT + : STRING_RIGHT; + vector<typename GallicArc<Arc, stype>::Weight> gdistance; + VectorFst<GallicArc<Arc, stype> > gfst; + ArcMap(ifst, &gfst, ToGallicMapper<Arc, stype>()); + if (ptype & kPushWeights ) { + ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); + } else { + ArcMapFst<Arc, Arc, RmWeightMapper<Arc> > + uwfst(ifst, RmWeightMapper<Arc>()); + ArcMapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> > + guwfst(uwfst, ToGallicMapper<Arc, stype>()); + ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); + } + typename GallicArc<Arc, stype>::Weight total_weight = + GallicArc<Arc, stype>::Weight::One(); + if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { + total_weight = internal::ComputeTotalWeight( + gfst, gdistance, rtype == REWEIGHT_TO_INITIAL); + total_weight = typename GallicArc<Arc, stype>::Weight( + ptype & kPushRemoveCommonAffix ? total_weight.Value1() + : StringWeight<typename Arc::Label, stype>::One(), + ptype & kPushRemoveTotalWeight ? total_weight.Value2() + : Arc::Weight::One()); + } + Reweight(&gfst, gdistance, rtype); + if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) + internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL); + FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label, + typename Arc::Weight, stype> > fwfst(gfst); + ArcMap(fwfst, ofst, FromGallicMapper<Arc, stype>()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + } else { + LOG(WARNING) << "Push: pushing type is set to 0: " + << "pushing neither labels nor weights."; + *ofst = ifst; + } +} + +} // namespace fst + +#endif /* FST_LIB_PUSH_H_ */ diff --git a/kaldi_io/src/tools/openfst/include/fst/queue.h b/kaldi_io/src/tools/openfst/include/fst/queue.h new file mode 100644 index 0000000..95a082d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/queue.h @@ -0,0 +1,938 @@ +// queue.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Functions and classes for various Fst state queues with +// a unified interface. + +#ifndef FST_LIB_QUEUE_H__ +#define FST_LIB_QUEUE_H__ + +#include <deque> +using std::deque; +#include <vector> +using std::vector; + +#include <fst/arcfilter.h> +#include <fst/connect.h> +#include <fst/heap.h> +#include <fst/topsort.h> + + +namespace fst { + +// template <class S> +// class Queue { +// public: +// typedef typename S StateId; +// +// // Ctr: may need args (e.g., Fst, comparator) for some queues +// Queue(...); +// // Returns the head of the queue +// StateId Head() const; +// // Inserts a state +// void Enqueue(StateId s); +// // Removes the head of the queue +// void Dequeue(); +// // Updates ordering of state s when weight changes, if necessary +// void Update(StateId s); +// // Does the queue contain no elements? +// bool Empty() const; +// // Remove all states from queue +// void Clear(); +// }; + +// State queue types. +enum QueueType { + TRIVIAL_QUEUE = 0, // Single state queue + FIFO_QUEUE = 1, // First-in, first-out queue + LIFO_QUEUE = 2, // Last-in, first-out queue + SHORTEST_FIRST_QUEUE = 3, // Shortest-first queue + TOP_ORDER_QUEUE = 4, // Topologically-ordered queue + STATE_ORDER_QUEUE = 5, // State-ID ordered queue + SCC_QUEUE = 6, // Component graph top-ordered meta-queue + AUTO_QUEUE = 7, // Auto-selected queue + OTHER_QUEUE = 8 + }; + + +// QueueBase, templated on the StateId, is the base class shared by the +// queues considered by AutoQueue. +template <class S> +class QueueBase { + public: + typedef S StateId; + + QueueBase(QueueType type) : queue_type_(type), error_(false) {} + virtual ~QueueBase() {} + StateId Head() const { return Head_(); } + void Enqueue(StateId s) { Enqueue_(s); } + void Dequeue() { Dequeue_(); } + void Update(StateId s) { Update_(s); } + bool Empty() const { return Empty_(); } + void Clear() { Clear_(); } + QueueType Type() { return queue_type_; } + bool Error() const { return error_; } + void SetError(bool error) { error_ = error; } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const = 0; + virtual void Enqueue_(StateId s) = 0; + virtual void Dequeue_() = 0; + virtual void Update_(StateId s) = 0; + virtual bool Empty_() const = 0; + virtual void Clear_() = 0; + + QueueType queue_type_; + bool error_; +}; + + +// Trivial queue discipline, templated on the StateId. You may enqueue +// at most one state at a time. It is used for strongly connected components +// with only one state and no self loops. +template <class S> +class TrivialQueue : public QueueBase<S> { +public: + typedef S StateId; + + TrivialQueue() : QueueBase<S>(TRIVIAL_QUEUE), front_(kNoStateId) {} + StateId Head() const { return front_; } + void Enqueue(StateId s) { front_ = s; } + void Dequeue() { front_ = kNoStateId; } + void Update(StateId s) {} + bool Empty() const { return front_ == kNoStateId; } + void Clear() { front_ = kNoStateId; } + + +private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + + StateId front_; +}; + + +// First-in, first-out queue discipline, templated on the StateId. +template <class S> +class FifoQueue : public QueueBase<S>, public deque<S> { + public: + using deque<S>::back; + using deque<S>::push_front; + using deque<S>::pop_back; + using deque<S>::empty; + using deque<S>::clear; + + typedef S StateId; + + FifoQueue() : QueueBase<S>(FIFO_QUEUE) {} + StateId Head() const { return back(); } + void Enqueue(StateId s) { push_front(s); } + void Dequeue() { pop_back(); } + void Update(StateId s) {} + bool Empty() const { return empty(); } + void Clear() { clear(); } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } +}; + + +// Last-in, first-out queue discipline, templated on the StateId. +template <class S> +class LifoQueue : public QueueBase<S>, public deque<S> { + public: + using deque<S>::front; + using deque<S>::push_front; + using deque<S>::pop_front; + using deque<S>::empty; + using deque<S>::clear; + + typedef S StateId; + + LifoQueue() : QueueBase<S>(LIFO_QUEUE) {} + StateId Head() const { return front(); } + void Enqueue(StateId s) { push_front(s); } + void Dequeue() { pop_front(); } + void Update(StateId s) {} + bool Empty() const { return empty(); } + void Clear() { clear(); } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } +}; + + +// Shortest-first queue discipline, templated on the StateId and +// comparison function object. Comparison function object COMP is +// used to compare two StateIds. If a (single) state's order changes, +// it can be reordered in the queue with a call to Update(). +// If 'update == false', call to Update() does not reorder the queue. +template <typename S, typename C, bool update = true> +class ShortestFirstQueue : public QueueBase<S> { + public: + typedef S StateId; + typedef C Compare; + + ShortestFirstQueue(C comp) + : QueueBase<S>(SHORTEST_FIRST_QUEUE), heap_(comp) {} + + StateId Head() const { return heap_.Top(); } + + void Enqueue(StateId s) { + if (update) { + for (StateId i = key_.size(); i <= s; ++i) + key_.push_back(kNoKey); + key_[s] = heap_.Insert(s); + } else { + heap_.Insert(s); + } + } + + void Dequeue() { + if (update) + key_[heap_.Pop()] = kNoKey; + else + heap_.Pop(); + } + + void Update(StateId s) { + if (!update) + return; + if (s >= key_.size() || key_[s] == kNoKey) { + Enqueue(s); + } else { + heap_.Update(key_[s], s); + } + } + + bool Empty() const { return heap_.Empty(); } + + void Clear() { + heap_.Clear(); + if (update) key_.clear(); + } + + private: + Heap<S, C, false> heap_; + vector<ssize_t> key_; + + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } +}; + + +// Given a vector that maps from states to weights and a Less +// comparison function object between weights, this class defines a +// comparison function object between states. +template <typename S, typename L> +class StateWeightCompare { + public: + typedef L Less; + typedef typename L::Weight Weight; + typedef S StateId; + + StateWeightCompare(const vector<Weight>& weights, const L &less) + : weights_(weights), less_(less) {} + + bool operator()(const S x, const S y) const { + return less_(weights_[x], weights_[y]); + } + + private: + const vector<Weight>& weights_; + L less_; +}; + + +// Shortest-first queue discipline, templated on the StateId and Weight, is +// specialized to use the weight's natural order for the comparison function. +template <typename S, typename W> +class NaturalShortestFirstQueue : + public ShortestFirstQueue<S, StateWeightCompare<S, NaturalLess<W> > > { + public: + typedef StateWeightCompare<S, NaturalLess<W> > C; + + NaturalShortestFirstQueue(const vector<W> &distance) : + ShortestFirstQueue<S, C>(C(distance, less_)) {} + + private: + NaturalLess<W> less_; +}; + +// Topological-order queue discipline, templated on the StateId. +// States are ordered in the queue topologically. The FST must be acyclic. +template <class S> +class TopOrderQueue : public QueueBase<S> { + public: + typedef S StateId; + + // This constructor computes the top. order. It accepts an arc filter + // to limit the transitions considered in that computation (e.g., only + // the epsilon graph). + template <class Arc, class ArcFilter> + TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter) + : QueueBase<S>(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), + order_(0), state_(0) { + bool acyclic; + TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic); + DfsVisit(fst, &top_order_visitor, filter); + if (!acyclic) { + FSTERROR() << "TopOrderQueue: fst is not acyclic."; + QueueBase<S>::SetError(true); + } + state_.resize(order_.size(), kNoStateId); + } + + // This constructor is passed the top. order, useful when we know it + // beforehand. + TopOrderQueue(const vector<StateId> &order) + : QueueBase<S>(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), + order_(order), state_(order.size(), kNoStateId) {} + + StateId Head() const { return state_[front_]; } + + void Enqueue(StateId s) { + if (front_ > back_) front_ = back_ = order_[s]; + else if (order_[s] > back_) back_ = order_[s]; + else if (order_[s] < front_) front_ = order_[s]; + state_[order_[s]] = s; + } + + void Dequeue() { + state_[front_] = kNoStateId; + while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_; + } + + void Update(StateId s) {} + + bool Empty() const { return front_ > back_; } + + void Clear() { + for (StateId i = front_; i <= back_; ++i) state_[i] = kNoStateId; + back_ = kNoStateId; + front_ = 0; + } + + private: + StateId front_; + StateId back_; + vector<StateId> order_; + vector<StateId> state_; + + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } +}; + + +// State order queue discipline, templated on the StateId. +// States are ordered in the queue by state Id. +template <class S> +class StateOrderQueue : public QueueBase<S> { +public: + typedef S StateId; + + StateOrderQueue() + : QueueBase<S>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {} + + StateId Head() const { return front_; } + + void Enqueue(StateId s) { + if (front_ > back_) front_ = back_ = s; + else if (s > back_) back_ = s; + else if (s < front_) front_ = s; + while (enqueued_.size() <= s) enqueued_.push_back(false); + enqueued_[s] = true; + } + + void Dequeue() { + enqueued_[front_] = false; + while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_; + } + + void Update(StateId s) {} + + bool Empty() const { return front_ > back_; } + + void Clear() { + for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false; + front_ = 0; + back_ = kNoStateId; + } + +private: + StateId front_; + StateId back_; + vector<bool> enqueued_; + + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + +}; + + +// SCC topological-order meta-queue discipline, templated on the StateId S +// and a queue Q, which is used inside each SCC. It visits the SCC's +// of an FST in topological order. Its constructor is passed the queues to +// to use within an SCC. +template <class S, class Q> +class SccQueue : public QueueBase<S> { + public: + typedef S StateId; + typedef Q Queue; + + // Constructor takes a vector specifying the SCC number per state + // and a vector giving the queue to use per SCC number. + SccQueue(const vector<StateId> &scc, vector<Queue*> *queue) + : QueueBase<S>(SCC_QUEUE), queue_(queue), scc_(scc), front_(0), + back_(kNoStateId) {} + + StateId Head() const { + while ((front_ <= back_) && + (((*queue_)[front_] && (*queue_)[front_]->Empty()) + || (((*queue_)[front_] == 0) && + ((front_ >= trivial_queue_.size()) + || (trivial_queue_[front_] == kNoStateId))))) + ++front_; + if ((*queue_)[front_]) + return (*queue_)[front_]->Head(); + else + return trivial_queue_[front_]; + } + + void Enqueue(StateId s) { + if (front_ > back_) front_ = back_ = scc_[s]; + else if (scc_[s] > back_) back_ = scc_[s]; + else if (scc_[s] < front_) front_ = scc_[s]; + if ((*queue_)[scc_[s]]) { + (*queue_)[scc_[s]]->Enqueue(s); + } else { + while (trivial_queue_.size() <= scc_[s]) + trivial_queue_.push_back(kNoStateId); + trivial_queue_[scc_[s]] = s; + } + } + + void Dequeue() { + if ((*queue_)[front_]) + (*queue_)[front_]->Dequeue(); + else if (front_ < trivial_queue_.size()) + trivial_queue_[front_] = kNoStateId; + } + + void Update(StateId s) { + if ((*queue_)[scc_[s]]) + (*queue_)[scc_[s]]->Update(s); + } + + bool Empty() const { + if (front_ < back_) // Queue scc # back_ not empty unless back_==front_ + return false; + else if (front_ > back_) + return true; + else if ((*queue_)[front_]) + return (*queue_)[front_]->Empty(); + else + return (front_ >= trivial_queue_.size()) + || (trivial_queue_[front_] == kNoStateId); + } + + void Clear() { + for (StateId i = front_; i <= back_; ++i) + if ((*queue_)[i]) + (*queue_)[i]->Clear(); + else if (i < trivial_queue_.size()) + trivial_queue_[i] = kNoStateId; + front_ = 0; + back_ = kNoStateId; + } + +private: + vector<Queue*> *queue_; + const vector<StateId> &scc_; + mutable StateId front_; + StateId back_; + vector<StateId> trivial_queue_; + + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + + DISALLOW_COPY_AND_ASSIGN(SccQueue); +}; + + +// Automatic queue discipline, templated on the StateId. It selects a +// queue discipline for a given FST based on its properties. +template <class S> +class AutoQueue : public QueueBase<S> { +public: + typedef S StateId; + + // This constructor takes a state distance vector that, if non-null and if + // the Weight type has the path property, will entertain the + // shortest-first queue using the natural order w.r.t to the distance. + template <class Arc, class ArcFilter> + AutoQueue(const Fst<Arc> &fst, const vector<typename Arc::Weight> *distance, + ArcFilter filter) : QueueBase<S>(AUTO_QUEUE) { + typedef typename Arc::Weight Weight; + typedef StateWeightCompare< StateId, NaturalLess<Weight> > Compare; + + // First check if the FST is known to have these properties. + uint64 props = fst.Properties(kAcyclic | kCyclic | + kTopSorted | kUnweighted, false); + if ((props & kTopSorted) || fst.Start() == kNoStateId) { + queue_ = new StateOrderQueue<StateId>(); + VLOG(2) << "AutoQueue: using state-order discipline"; + } else if (props & kAcyclic) { + queue_ = new TopOrderQueue<StateId>(fst, filter); + VLOG(2) << "AutoQueue: using top-order discipline"; + } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) { + queue_ = new LifoQueue<StateId>(); + VLOG(2) << "AutoQueue: using LIFO discipline"; + } else { + uint64 properties; + // Decompose into strongly-connected components. + SccVisitor<Arc> scc_visitor(&scc_, 0, 0, &properties); + DfsVisit(fst, &scc_visitor, filter); + StateId nscc = *max_element(scc_.begin(), scc_.end()) + 1; + vector<QueueType> queue_types(nscc); + NaturalLess<Weight> *less = 0; + Compare *comp = 0; + if (distance && (Weight::Properties() & kPath)) { + less = new NaturalLess<Weight>; + comp = new Compare(*distance, *less); + } + // Find the queue type to use per SCC. + bool unweighted; + bool all_trivial; + SccQueueType(fst, scc_, &queue_types, filter, less, &all_trivial, + &unweighted); + // If unweighted and semiring is idempotent, use lifo queue. + if (unweighted) { + queue_ = new LifoQueue<StateId>(); + VLOG(2) << "AutoQueue: using LIFO discipline"; + delete comp; + delete less; + return; + } + // If all the scc are trivial, FST is acyclic and the scc# gives + // the topological order. + if (all_trivial) { + queue_ = new TopOrderQueue<StateId>(scc_); + VLOG(2) << "AutoQueue: using top-order discipline"; + delete comp; + delete less; + return; + } + VLOG(2) << "AutoQueue: using SCC meta-discipline"; + queues_.resize(nscc); + for (StateId i = 0; i < nscc; ++i) { + switch(queue_types[i]) { + case TRIVIAL_QUEUE: + queues_[i] = 0; + VLOG(3) << "AutoQueue: SCC #" << i + << ": using trivial discipline"; + break; + case SHORTEST_FIRST_QUEUE: + queues_[i] = new ShortestFirstQueue<StateId, Compare, false>(*comp); + VLOG(3) << "AutoQueue: SCC #" << i << + ": using shortest-first discipline"; + break; + case LIFO_QUEUE: + queues_[i] = new LifoQueue<StateId>(); + VLOG(3) << "AutoQueue: SCC #" << i + << ": using LIFO disciplle"; + break; + case FIFO_QUEUE: + default: + queues_[i] = new FifoQueue<StateId>(); + VLOG(3) << "AutoQueue: SCC #" << i + << ": using FIFO disciplle"; + break; + } + } + queue_ = new SccQueue< StateId, QueueBase<StateId> >(scc_, &queues_); + delete comp; + delete less; + } + } + + ~AutoQueue() { + for (StateId i = 0; i < queues_.size(); ++i) + delete queues_[i]; + delete queue_; + } + + StateId Head() const { return queue_->Head(); } + + void Enqueue(StateId s) { queue_->Enqueue(s); } + + void Dequeue() { queue_->Dequeue(); } + + void Update(StateId s) { queue_->Update(s); } + + bool Empty() const { return queue_->Empty(); } + + void Clear() { queue_->Clear(); } + + + private: + QueueBase<StateId> *queue_; + vector< QueueBase<StateId>* > queues_; + vector<StateId> scc_; + + template <class Arc, class ArcFilter, class Less> + static void SccQueueType(const Fst<Arc> &fst, + const vector<StateId> &scc, + vector<QueueType> *queue_types, + ArcFilter filter, Less *less, + bool *all_trivial, bool *unweighted); + + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + + virtual void Enqueue_(StateId s) { Enqueue(s); } + + virtual void Dequeue_() { Dequeue(); } + + virtual void Update_(StateId s) { Update(s); } + + virtual bool Empty_() const { return Empty(); } + + virtual void Clear_() { return Clear(); } + + DISALLOW_COPY_AND_ASSIGN(AutoQueue); +}; + + +// Examines the states in an Fst's strongly connected components and +// determines which type of queue to use per SCC. Stores result in +// vector QUEUE_TYPES, which is assumed to have length equal to the +// number of SCCs. An arc filter is used to limit the transitions +// considered (e.g., only the epsilon graph). ALL_TRIVIAL is set +// to true if every queue is the trivial queue. UNWEIGHTED is set to +// true if the semiring is idempotent and all the arc weights are equal to +// Zero() or One(). +template <class StateId> +template <class A, class ArcFilter, class Less> +void AutoQueue<StateId>::SccQueueType(const Fst<A> &fst, + const vector<StateId> &scc, + vector<QueueType> *queue_type, + ArcFilter filter, Less *less, + bool *all_trivial, bool *unweighted) { + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + *all_trivial = true; + *unweighted = true; + + for (StateId i = 0; i < queue_type->size(); ++i) + (*queue_type)[i] = TRIVIAL_QUEUE; + + for (StateIterator< Fst<Arc> > sit(fst); !sit.Done(); sit.Next()) { + StateId state = sit.Value(); + for (ArcIterator< Fst<Arc> > ait(fst, state); + !ait.Done(); + ait.Next()) { + const Arc &arc = ait.Value(); + if (!filter(arc)) continue; + if (scc[state] == scc[arc.nextstate]) { + QueueType &type = (*queue_type)[scc[state]]; + if (!less || ((*less)(arc.weight, Weight::One()))) + type = FIFO_QUEUE; + else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) { + if (!(Weight::Properties() & kIdempotent) || + (arc.weight != Weight::Zero() && arc.weight != Weight::One())) + type = SHORTEST_FIRST_QUEUE; + else + type = LIFO_QUEUE; + } + if (type != TRIVIAL_QUEUE) *all_trivial = false; + } + if (!(Weight::Properties() & kIdempotent) || + (arc.weight != Weight::Zero() && arc.weight != Weight::One())) + *unweighted = false; + } + } +} + + +// An A* estimate is a function object that maps from a state ID to a +// an estimate of the shortest distance to the final states. +// The trivial A* estimate is always One(). +template <typename S, typename W> +struct TrivialAStarEstimate { + W operator()(S s) const { return W::One(); } +}; + + +// Given a vector that maps from states to weights representing the +// shortest distance from the initial state, a Less comparison +// function object between weights, and an estimate E of the +// shortest distance to the final states, this class defines a +// comparison function object between states. +template <typename S, typename L, typename E> +class AStarWeightCompare { + public: + typedef L Less; + typedef typename L::Weight Weight; + typedef S StateId; + + AStarWeightCompare(const vector<Weight>& weights, const L &less, + const E &estimate) + : weights_(weights), less_(less), estimate_(estimate) {} + + bool operator()(const S x, const S y) const { + Weight wx = Times(weights_[x], estimate_(x)); + Weight wy = Times(weights_[y], estimate_(y)); + return less_(wx, wy); + } + + private: + const vector<Weight>& weights_; + L less_; + const E &estimate_; +}; + + +// A* queue discipline, templated on the StateId, Weight and an +// estimate E of the shortest distance to the final states, is specialized +// to use the weight's natural order for the comparison function. +template <typename S, typename W, typename E> +class NaturalAStarQueue : + public ShortestFirstQueue<S, AStarWeightCompare<S, NaturalLess<W>, E> > { + public: + typedef AStarWeightCompare<S, NaturalLess<W>, E> C; + + NaturalAStarQueue(const vector<W> &distance, const E &estimate) : + ShortestFirstQueue<S, C>(C(distance, less_, estimate)) {} + + private: + NaturalLess<W> less_; +}; + + +// A state equivalence class is a function object that +// maps from a state ID to an equivalence class (state) ID. +// The trivial equivalence class maps a state to itself. +template <typename S> +struct TrivialStateEquivClass { + S operator()(S s) const { return s; } +}; + + +// Distance-based pruning queue discipline: Enqueues a state 's' +// only when its shortest distance (so far), as specified by +// 'distance', is less than (as specified by 'comp') the shortest +// distance Times() the 'threshold' to any state in the same +// equivalence class, as specified by the function object +// 'class_func'. The underlying queue discipline is specified by +// 'queue'. The ownership of 'queue' is given to this class. +template <typename Q, typename L, typename C> +class PruneQueue : public QueueBase<typename Q::StateId> { + public: + typedef typename Q::StateId StateId; + typedef typename L::Weight Weight; + + PruneQueue(const vector<Weight> &distance, Q *queue, L comp, + const C &class_func, Weight threshold) + : QueueBase<StateId>(OTHER_QUEUE), + distance_(distance), + queue_(queue), + less_(comp), + class_func_(class_func), + threshold_(threshold) {} + + ~PruneQueue() { delete queue_; } + + StateId Head() const { return queue_->Head(); } + + void Enqueue(StateId s) { + StateId c = class_func_(s); + if (c >= class_distance_.size()) + class_distance_.resize(c + 1, Weight::Zero()); + if (less_(distance_[s], class_distance_[c])) + class_distance_[c] = distance_[s]; + + // Enqueue only if below threshold limit + Weight limit = Times(class_distance_[c], threshold_); + if (less_(distance_[s], limit)) + queue_->Enqueue(s); + } + + void Dequeue() { queue_->Dequeue(); } + + void Update(StateId s) { + StateId c = class_func_(s); + if (less_(distance_[s], class_distance_[c])) + class_distance_[c] = distance_[s]; + queue_->Update(s); + } + + bool Empty() const { return queue_->Empty(); } + void Clear() { queue_->Clear(); } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + + const vector<Weight> &distance_; // shortest distance to state + Q *queue_; + L less_; + const C &class_func_; // eqv. class function object + Weight threshold_; // pruning weight threshold + vector<Weight> class_distance_; // shortest distance to class + + DISALLOW_COPY_AND_ASSIGN(PruneQueue); +}; + + +// Pruning queue discipline (see above) using the weight's natural +// order for the comparison function. The ownership of 'queue' is +// given to this class. +template <typename Q, typename W, typename C> +class NaturalPruneQueue : + public PruneQueue<Q, NaturalLess<W>, C> { + public: + typedef typename Q::StateId StateId; + typedef W Weight; + + NaturalPruneQueue(const vector<W> &distance, Q *queue, + const C &class_func_, Weight threshold) : + PruneQueue<Q, NaturalLess<W>, C>(distance, queue, less_, + class_func_, threshold) {} + + private: + NaturalLess<W> less_; +}; + + +// Filter-based pruning queue discipline: Enqueues a state 's' only +// if allowed by the filter, specified by the function object 'state_filter'. +// The underlying queue discipline is specified by 'queue'. The ownership +// of 'queue' is given to this class. +template <typename Q, typename F> +class FilterQueue : public QueueBase<typename Q::StateId> { + public: + typedef typename Q::StateId StateId; + + FilterQueue(Q *queue, const F &state_filter) + : QueueBase<StateId>(OTHER_QUEUE), + queue_(queue), + state_filter_(state_filter) {} + + ~FilterQueue() { delete queue_; } + + StateId Head() const { return queue_->Head(); } + + // Enqueues only if allowed by state filter. + void Enqueue(StateId s) { + if (state_filter_(s)) { + queue_->Enqueue(s); + } + } + + void Dequeue() { queue_->Dequeue(); } + + void Update(StateId s) {} + bool Empty() const { return queue_->Empty(); } + void Clear() { queue_->Clear(); } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + + Q *queue_; + const F &state_filter_; // Filter to prune states + + DISALLOW_COPY_AND_ASSIGN(FilterQueue); +}; + +} // namespace fst + +#endif diff --git a/kaldi_io/src/tools/openfst/include/fst/randequivalent.h b/kaldi_io/src/tools/openfst/include/fst/randequivalent.h new file mode 100644 index 0000000..1aaccf7 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/randequivalent.h @@ -0,0 +1,135 @@ +// randequivalent.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Tests if two FSTS are equivalent by checking if random +// strings from one FST are transduced the same by both FSTs. + +#ifndef FST_RANDEQUIVALENT_H__ +#define FST_RANDEQUIVALENT_H__ + +#include <fst/arcsort.h> +#include <fst/compose.h> +#include <fst/project.h> +#include <fst/randgen.h> +#include <fst/shortest-distance.h> +#include <fst/vector-fst.h> + + +namespace fst { + +// Test if two FSTs are equivalent by randomly generating 'num_paths' +// paths (as specified by the RandGenOptions 'opts') in these FSTs. +// +// For each randomly generated path, the algorithm computes for each +// of the two FSTs the sum of the weights of all the successful paths +// sharing the same input and output labels as the considered randomly +// generated path and checks that these two values are within +// 'delta'. Returns optional error value (when FLAGS_error_fatal = false). +template<class Arc, class ArcSelector> +bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, + ssize_t num_paths, float delta, + const RandGenOptions<ArcSelector> &opts, + bool *error = 0) { + typedef typename Arc::Weight Weight; + if (error) *error = false; + + // Check that the symbol table are compatible + if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) { + FSTERROR() << "RandEquivalent: input/output symbol tables of 1st " + << "argument do not match input/output symbol tables of 2nd " + << "argument"; + if (error) *error = true; + return false; + } + + ILabelCompare<Arc> icomp; + OLabelCompare<Arc> ocomp; + VectorFst<Arc> sfst1(fst1); + VectorFst<Arc> sfst2(fst2); + Connect(&sfst1); + Connect(&sfst2); + ArcSort(&sfst1, icomp); + ArcSort(&sfst2, icomp); + + bool ret = true; + for (ssize_t n = 0; n < num_paths; ++n) { + VectorFst<Arc> path; + const Fst<Arc> &fst = rand() % 2 ? sfst1 : sfst2; + RandGen(fst, &path, opts); + + VectorFst<Arc> ipath(path); + VectorFst<Arc> opath(path); + Project(&ipath, PROJECT_INPUT); + Project(&opath, PROJECT_OUTPUT); + + VectorFst<Arc> cfst1, pfst1; + Compose(ipath, sfst1, &cfst1); + ArcSort(&cfst1, ocomp); + Compose(cfst1, opath, &pfst1); + // Give up if there are epsilon cycles in a non-idempotent semiring + if (!(Weight::Properties() & kIdempotent) && + pfst1.Properties(kCyclic, true)) + continue; + Weight sum1 = ShortestDistance(pfst1); + + VectorFst<Arc> cfst2, pfst2; + Compose(ipath, sfst2, &cfst2); + ArcSort(&cfst2, ocomp); + Compose(cfst2, opath, &pfst2); + // Give up if there are epsilon cycles in a non-idempotent semiring + if (!(Weight::Properties() & kIdempotent) && + pfst2.Properties(kCyclic, true)) + continue; + Weight sum2 = ShortestDistance(pfst2); + + if (!ApproxEqual(sum1, sum2, delta)) { + VLOG(1) << "Sum1 = " << sum1; + VLOG(1) << "Sum2 = " << sum2; + ret = false; + break; + } + } + + if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) { + if (error) *error = true; + return false; + } + + return ret; +} + + +// Test if two FSTs are equivalent by randomly generating 'num_paths' paths +// of length no more than 'path_length' using the seed 'seed' in these FSTs. +// Returns optional error value (when FLAGS_error_fatal = false). +template <class Arc> +bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, + ssize_t num_paths, float delta = kDelta, + int seed = time(0), int path_length = INT_MAX, + bool *error = 0) { + UniformArcSelector<Arc> uniform_selector(seed); + RandGenOptions< UniformArcSelector<Arc> > + opts(uniform_selector, path_length); + return RandEquivalent(fst1, fst2, num_paths, delta, opts, error); +} + + +} // namespace fst + +#endif // FST_LIB_RANDEQUIVALENT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/randgen.h b/kaldi_io/src/tools/openfst/include/fst/randgen.h new file mode 100644 index 0000000..82ddffa --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/randgen.h @@ -0,0 +1,712 @@ +// randgen.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes and functions to generate random paths through an FST. + +#ifndef FST_LIB_RANDGEN_H__ +#define FST_LIB_RANDGEN_H__ + +#include <cmath> +#include <cstdlib> +#include <ctime> +#include <map> + +#include <fst/accumulator.h> +#include <fst/cache.h> +#include <fst/dfs-visit.h> +#include <fst/mutable-fst.h> + +namespace fst { + +// +// ARC SELECTORS - these function objects are used to select a random +// transition to take from an FST's state. They should return a number +// N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th +// transition is selected. If N == NumArcs(), then the final weight at +// that state is selected (i.e., the 'super-final' transition is selected). +// It can be assumed these will not be called unless either there +// are transitions leaving the state and/or the state is final. +// + +// Randomly selects a transition using the uniform distribution. +template <class A> +struct UniformArcSelector { + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + UniformArcSelector(int seed = time(0)) { srand(seed); } + + size_t operator()(const Fst<A> &fst, StateId s) const { + double r = rand()/(RAND_MAX + 1.0); + size_t n = fst.NumArcs(s); + if (fst.Final(s) != Weight::Zero()) + ++n; + return static_cast<size_t>(r * n); + } +}; + + +// Randomly selects a transition w.r.t. the weights treated as negative +// log probabilities after normalizing for the total weight leaving +// the state. Weight::zero transitions are disregarded. +// Assumes Weight::Value() accesses the floating point +// representation of the weight. +template <class A> +class LogProbArcSelector { + public: + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + LogProbArcSelector(int seed = time(0)) { srand(seed); } + + size_t operator()(const Fst<A> &fst, StateId s) const { + // Find total weight leaving state + double sum = 0.0; + for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + sum += exp(-to_log_weight_(arc.weight).Value()); + } + sum += exp(-to_log_weight_(fst.Final(s)).Value()); + + double r = rand()/(RAND_MAX + 1.0); + double p = 0.0; + int n = 0; + for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); + aiter.Next(), ++n) { + const A &arc = aiter.Value(); + p += exp(-to_log_weight_(arc.weight).Value()); + if (p > r * sum) return n; + } + return n; + } + + private: + WeightConvert<Weight, Log64Weight> to_log_weight_; +}; + +// Convenience definitions +typedef LogProbArcSelector<StdArc> StdArcSelector; +typedef LogProbArcSelector<LogArc> LogArcSelector; + + +// Same as LogProbArcSelector but use CacheLogAccumulator to cache +// the cummulative weight computations. +template <class A> +class FastLogProbArcSelector : public LogProbArcSelector<A> { + public: + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + using LogProbArcSelector<A>::operator(); + + FastLogProbArcSelector(int seed = time(0)) + : LogProbArcSelector<A>(seed), + seed_(seed) {} + + size_t operator()(const Fst<A> &fst, StateId s, + CacheLogAccumulator<A> *accumulator) const { + accumulator->SetState(s); + ArcIterator< Fst<A> > aiter(fst, s); + // Find total weight leaving state + double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0, + fst.NumArcs(s))).Value(); + double r = -log(rand()/(RAND_MAX + 1.0)); + return accumulator->LowerBound(r + sum, &aiter); + } + + int Seed() const { return seed_; } + private: + int seed_; + WeightConvert<Weight, Log64Weight> to_log_weight_; +}; + +// Random path state info maintained by RandGenFst and passed to samplers. +template <typename A> +struct RandState { + typedef typename A::StateId StateId; + + StateId state_id; // current input FST state + size_t nsamples; // # of samples to be sampled at this state + size_t length; // length of path to this random state + size_t select; // previous sample arc selection + const RandState<A> *parent; // previous random state on this path + + RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p) + : state_id(s), nsamples(n), length(l), select(k), parent(p) {} + + RandState() + : state_id(kNoStateId), nsamples(0), length(0), select(0), parent(0) {} +}; + +// This class, given an arc selector, samples, with raplacement, +// multiple random transitions from an FST's state. This is a generic +// version with a straight-forward use of the arc selector. +// Specializations may be defined for arc selectors for greater +// efficiency or special behavior. +template <class A, class S> +class ArcSampler { + public: + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + // The 'max_length' may be interpreted (including ignored) by a + // sampler as it chooses. This generic version interprets this literally. + ArcSampler(const Fst<A> &fst, const S &arc_selector, + int max_length = INT_MAX) + : fst_(fst), + arc_selector_(arc_selector), + max_length_(max_length) {} + + // Allow updating Fst argument; pass only if changed. + ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0) + : fst_(fst ? *fst : sampler.fst_), + arc_selector_(sampler.arc_selector_), + max_length_(sampler.max_length_) { + Reset(); + } + + // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is + // the length of the path to 'rstate'. Returns true if samples were + // collected. No samples may be collected if either there are no (including + // 'super-final') transitions leaving that state or if the + // 'max_length' has been deemed reached. Use the iterator members to + // read the samples. The samples will be in their original order. + bool Sample(const RandState<A> &rstate) { + sample_map_.clear(); + if ((fst_.NumArcs(rstate.state_id) == 0 && + fst_.Final(rstate.state_id) == Weight::Zero()) || + rstate.length == max_length_) { + Reset(); + return false; + } + + for (size_t i = 0; i < rstate.nsamples; ++i) + ++sample_map_[arc_selector_(fst_, rstate.state_id)]; + Reset(); + return true; + } + + // More samples? + bool Done() const { return sample_iter_ == sample_map_.end(); } + + // Gets the next sample. + void Next() { ++sample_iter_; } + + // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples. + // If N < NumArcs(s), then the N-th transition is specified. + // If N == NumArcs(s), then the final weight at that state is + // specified (i.e., the 'super-final' transition is specified). + // For the specified transition, K repetitions have been sampled. + pair<size_t, size_t> Value() const { return *sample_iter_; } + + void Reset() { sample_iter_ = sample_map_.begin(); } + + bool Error() const { return false; } + + private: + const Fst<A> &fst_; + const S &arc_selector_; + int max_length_; + + // Stores (N, K) as described for Value(). + map<size_t, size_t> sample_map_; + map<size_t, size_t>::const_iterator sample_iter_; + + // disallow + ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s); +}; + + +// Specialization for FastLogProbArcSelector. +template <class A> +class ArcSampler<A, FastLogProbArcSelector<A> > { + public: + typedef FastLogProbArcSelector<A> S; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + typedef CacheLogAccumulator<A> C; + + ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX) + : fst_(fst), + arc_selector_(arc_selector), + max_length_(max_length), + accumulator_(new C()) { + accumulator_->Init(fst); + } + + ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0) + : fst_(fst ? *fst : sampler.fst_), + arc_selector_(sampler.arc_selector_), + max_length_(sampler.max_length_) { + if (fst) { + accumulator_ = new C(); + accumulator_->Init(*fst); + } else { // shallow copy + accumulator_ = new C(*sampler.accumulator_); + } + } + + ~ArcSampler() { + delete accumulator_; + } + + bool Sample(const RandState<A> &rstate) { + sample_map_.clear(); + if ((fst_.NumArcs(rstate.state_id) == 0 && + fst_.Final(rstate.state_id) == Weight::Zero()) || + rstate.length == max_length_) { + Reset(); + return false; + } + + for (size_t i = 0; i < rstate.nsamples; ++i) + ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)]; + Reset(); + return true; + } + + bool Done() const { return sample_iter_ == sample_map_.end(); } + void Next() { ++sample_iter_; } + pair<size_t, size_t> Value() const { return *sample_iter_; } + void Reset() { sample_iter_ = sample_map_.begin(); } + + bool Error() const { return accumulator_->Error(); } + + private: + const Fst<A> &fst_; + const S &arc_selector_; + int max_length_; + + // Stores (N, K) as described for Value(). + map<size_t, size_t> sample_map_; + map<size_t, size_t>::const_iterator sample_iter_; + C *accumulator_; + + // disallow + ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s); +}; + + +// Options for random path generation with RandGenFst. The template argument +// is an arc sampler, typically class 'ArcSampler' above. Ownership of +// the sampler is taken by RandGenFst. +template <class S> +struct RandGenFstOptions : public CacheOptions { + S *arc_sampler; // How to sample transitions at a state + size_t npath; // # of paths to generate + bool weighted; // Output tree weighted by path count; o.w. + // output unweighted DAG + bool remove_total_weight; // Remove total weight when output is weighted. + + RandGenFstOptions(const CacheOptions &copts, S *samp, + size_t n = 1, bool w = true, bool rw = false) + : CacheOptions(copts), + arc_sampler(samp), + npath(n), + weighted(w), + remove_total_weight(rw) {} +}; + + +// Implementation of RandGenFst. +template <class A, class B, class S> +class RandGenFstImpl : public CacheImpl<B> { + public: + using FstImpl<B>::SetType; + using FstImpl<B>::SetProperties; + using FstImpl<B>::SetInputSymbols; + using FstImpl<B>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<B> >::AddArc; + using CacheBaseImpl< CacheState<B> >::HasArcs; + using CacheBaseImpl< CacheState<B> >::HasFinal; + using CacheBaseImpl< CacheState<B> >::HasStart; + using CacheBaseImpl< CacheState<B> >::SetArcs; + using CacheBaseImpl< CacheState<B> >::SetFinal; + using CacheBaseImpl< CacheState<B> >::SetStart; + + typedef B Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + RandGenFstImpl(const Fst<A> &fst, const RandGenFstOptions<S> &opts) + : CacheImpl<B>(opts), + fst_(fst.Copy()), + arc_sampler_(opts.arc_sampler), + npath_(opts.npath), + weighted_(opts.weighted), + remove_total_weight_(opts.remove_total_weight), + superfinal_(kNoLabel) { + SetType("randgen"); + + uint64 props = fst.Properties(kFstProperties, false); + SetProperties(RandGenProperties(props, weighted_), kCopyProperties); + + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + RandGenFstImpl(const RandGenFstImpl &impl) + : CacheImpl<B>(impl), + fst_(impl.fst_->Copy(true)), + arc_sampler_(new S(*impl.arc_sampler_, fst_)), + npath_(impl.npath_), + weighted_(impl.weighted_), + superfinal_(kNoLabel) { + SetType("randgen"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~RandGenFstImpl() { + for (int i = 0; i < state_table_.size(); ++i) + delete state_table_[i]; + delete fst_; + delete arc_sampler_; + } + + StateId Start() { + if (!HasStart()) { + StateId s = fst_->Start(); + if (s == kNoStateId) + return kNoStateId; + StateId start = state_table_.size(); + SetStart(start); + RandState<A> *rstate = new RandState<A>(s, npath_, 0, 0, 0); + state_table_.push_back(rstate); + } + return CacheImpl<B>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + Expand(s); + } + return CacheImpl<B>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl<B>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<B>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<B>::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && + (fst_->Properties(kError, false) || arc_sampler_->Error())) { + SetProperties(kError, kError); + } + return FstImpl<Arc>::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData<B> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<B>::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void Expand(StateId s) { + if (s == superfinal_) { + SetFinal(s, Weight::One()); + SetArcs(s); + return; + } + + SetFinal(s, Weight::Zero()); + const RandState<A> &rstate = *state_table_[s]; + arc_sampler_->Sample(rstate); + ArcIterator< Fst<A> > aiter(*fst_, rstate.state_id); + size_t narcs = fst_->NumArcs(rstate.state_id); + for (;!arc_sampler_->Done(); arc_sampler_->Next()) { + const pair<size_t, size_t> &sample_pair = arc_sampler_->Value(); + size_t pos = sample_pair.first; + size_t count = sample_pair.second; + double prob = static_cast<double>(count)/rstate.nsamples; + if (pos < narcs) { // regular transition + aiter.Seek(sample_pair.first); + const A &aarc = aiter.Value(); + Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One(); + B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size()); + AddArc(s, barc); + RandState<A> *nrstate = + new RandState<A>(aarc.nextstate, count, rstate.length + 1, + pos, &rstate); + state_table_.push_back(nrstate); + } else { // super-final transition + if (weighted_) { + Weight weight = remove_total_weight_ ? + to_weight_(-log(prob)) : to_weight_(-log(prob * npath_)); + SetFinal(s, weight); + } else { + if (superfinal_ == kNoLabel) { + superfinal_ = state_table_.size(); + RandState<A> *nrstate = new RandState<A>(kNoStateId, 0, 0, 0, 0); + state_table_.push_back(nrstate); + } + for (size_t n = 0; n < count; ++n) { + B barc(0, 0, Weight::One(), superfinal_); + AddArc(s, barc); + } + } + } + } + SetArcs(s); + } + + private: + Fst<A> *fst_; + S *arc_sampler_; + size_t npath_; + vector<RandState<A> *> state_table_; + bool weighted_; + bool remove_total_weight_; + StateId superfinal_; + WeightConvert<Log64Weight, Weight> to_weight_; + + void operator=(const RandGenFstImpl<A, B, S> &); // disallow +}; + + +// Fst class to randomly generate paths through an FST; details controlled +// by RandGenOptionsFst. Output format is a tree weighted by the +// path count. +template <class A, class B, class S> +class RandGenFst : public ImplToFst< RandGenFstImpl<A, B, S> > { + public: + friend class ArcIterator< RandGenFst<A, B, S> >; + friend class StateIterator< RandGenFst<A, B, S> >; + typedef B Arc; + typedef S Sampler; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<B> State; + typedef RandGenFstImpl<A, B, S> Impl; + + RandGenFst(const Fst<A> &fst, const RandGenFstOptions<S> &opts) + : ImplToFst<Impl>(new Impl(fst, opts)) {} + + // See Fst<>::Copy() for doc. + RandGenFst(const RandGenFst<A, B, S> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc. + virtual RandGenFst<A, B, S> *Copy(bool safe = false) const { + return new RandGenFst<A, B, S>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<B> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const RandGenFst<A, B, S> &fst); // Disallow +}; + + + +// Specialization for RandGenFst. +template <class A, class B, class S> +class StateIterator< RandGenFst<A, B, S> > + : public CacheStateIterator< RandGenFst<A, B, S> > { + public: + explicit StateIterator(const RandGenFst<A, B, S> &fst) + : CacheStateIterator< RandGenFst<A, B, S> >(fst, fst.GetImpl()) {} + + private: + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + + +// Specialization for RandGenFst. +template <class A, class B, class S> +class ArcIterator< RandGenFst<A, B, S> > + : public CacheArcIterator< RandGenFst<A, B, S> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const RandGenFst<A, B, S> &fst, StateId s) + : CacheArcIterator< RandGenFst<A, B, S> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + + +template <class A, class B, class S> inline +void RandGenFst<A, B, S>::InitStateIterator(StateIteratorData<B> *data) const +{ + data->base = new StateIterator< RandGenFst<A, B, S> >(*this); +} + +// Options for random path generation. +template <class S> +struct RandGenOptions { + const S &arc_selector; // How an arc is selected at a state + int max_length; // Maximum path length + size_t npath; // # of paths to generate + bool weighted; // Output is tree weighted by path count; o.w. + // output unweighted union of paths. + bool remove_total_weight; // Remove total weight when output is weighted. + + RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1, + bool w = false, bool rw = false) + : arc_selector(sel), + max_length(len), + npath(n), + weighted(w), + remove_total_weight(rw) {} +}; + + +template <class IArc, class OArc> +class RandGenVisitor { + public: + typedef typename IArc::Weight Weight; + typedef typename IArc::StateId StateId; + + RandGenVisitor(MutableFst<OArc> *ofst) : ofst_(ofst) {} + + void InitVisit(const Fst<IArc> &ifst) { + ifst_ = &ifst; + + ofst_->DeleteStates(); + ofst_->SetInputSymbols(ifst.InputSymbols()); + ofst_->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Properties(kError, false)) + ofst_->SetProperties(kError, kError); + path_.clear(); + } + + bool InitState(StateId s, StateId root) { return true; } + + bool TreeArc(StateId s, const IArc &arc) { + if (ifst_->Final(arc.nextstate) == Weight::Zero()) { + path_.push_back(arc); + } else { + OutputPath(); + } + return true; + } + + bool BackArc(StateId s, const IArc &arc) { + FSTERROR() << "RandGenVisitor: cyclic input"; + ofst_->SetProperties(kError, kError); + return false; + } + + bool ForwardOrCrossArc(StateId s, const IArc &arc) { + OutputPath(); + return true; + } + + void FinishState(StateId s, StateId p, const IArc *) { + if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) + path_.pop_back(); + } + + void FinishVisit() {} + + private: + void OutputPath() { + if (ofst_->Start() == kNoStateId) { + StateId start = ofst_->AddState(); + ofst_->SetStart(start); + } + + StateId src = ofst_->Start(); + for (size_t i = 0; i < path_.size(); ++i) { + StateId dest = ofst_->AddState(); + OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest); + ofst_->AddArc(src, arc); + src = dest; + } + ofst_->SetFinal(src, Weight::One()); + } + + const Fst<IArc> *ifst_; + MutableFst<OArc> *ofst_; + vector<OArc> path_; + + DISALLOW_COPY_AND_ASSIGN(RandGenVisitor); +}; + + +// Randomly generate paths through an FST; details controlled by +// RandGenOptions. +template<class IArc, class OArc, class Selector> +void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst, + const RandGenOptions<Selector> &opts) { + typedef ArcSampler<IArc, Selector> Sampler; + typedef RandGenFst<IArc, OArc, Sampler> RandFst; + typedef typename OArc::StateId StateId; + typedef typename OArc::Weight Weight; + + Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length); + RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), arc_sampler, + opts.npath, opts.weighted, + opts.remove_total_weight); + RandFst rfst(ifst, fopts); + if (opts.weighted) { + *ofst = rfst; + } else { + RandGenVisitor<IArc, OArc> rand_visitor(ofst); + DfsVisit(rfst, &rand_visitor); + } +} + +// Randomly generate a path through an FST with the uniform distribution +// over the transitions. +template<class IArc, class OArc> +void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst) { + UniformArcSelector<IArc> uniform_selector; + RandGenOptions< UniformArcSelector<IArc> > opts(uniform_selector); + RandGen(ifst, ofst, opts); +} + +} // namespace fst + +#endif // FST_LIB_RANDGEN_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/random-weight.h b/kaldi_io/src/tools/openfst/include/fst/random-weight.h new file mode 100644 index 0000000..0ccd95d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/random-weight.h @@ -0,0 +1,348 @@ +// random-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Function objects to generate random weights in various semirings +// for testing purposes. + +#ifndef FST_LIB_RANDOM_WEIGHT_H__ +#define FST_LIB_RANDOM_WEIGHT_H__ + +#include <cstdlib> +#include <ctime> +#include <vector> +using std::vector; + + +#include <fst/float-weight.h> +#include <fst/product-weight.h> +#include <fst/string-weight.h> +#include <fst/lexicographic-weight.h> +#include <fst/power-weight.h> +#include <fst/signed-log-weight.h> +#include <fst/sparse-power-weight.h> + + +namespace fst { + +// The boolean 'allow_zero' below determines whether Zero() and zero +// divisors should be returned in the random weight generation. + +// This function object returns TropicalWeightTpl<T>'s that are random integers +// chosen from [0, kNumRandomWeights). +template <class T> +class TropicalWeightGenerator_ { + public: + typedef TropicalWeightTpl<T> Weight; + + TropicalWeightGenerator_(int seed = time(0), bool allow_zero = true) + : allow_zero_(allow_zero) { + srand(seed); + } + + Weight operator() () const { + int n = rand() % (kNumRandomWeights + allow_zero_); + if (allow_zero_ && n == kNumRandomWeights) + return Weight::Zero(); + + return Weight(static_cast<T>(n)); + } + + private: + // The number of alternative random weights. + static const int kNumRandomWeights = 5; + + bool allow_zero_; // permit Zero() and zero divisors +}; + +template <class T> const int TropicalWeightGenerator_<T>::kNumRandomWeights; + +typedef TropicalWeightGenerator_<float> TropicalWeightGenerator; + + +// This function object returns LogWeightTpl<T>'s that are random integers +// chosen from [0, kNumRandomWeights). +template <class T> +class LogWeightGenerator_ { + public: + typedef LogWeightTpl<T> Weight; + + LogWeightGenerator_(int seed = time(0), bool allow_zero = true) + : allow_zero_(allow_zero) { + srand(seed); + } + + Weight operator() () const { + int n = rand() % (kNumRandomWeights + allow_zero_); + if (allow_zero_ && n == kNumRandomWeights) + return Weight::Zero(); + + return Weight(static_cast<T>(n)); + } + + private: + // Number of alternative random weights. + static const int kNumRandomWeights = 5; + + bool allow_zero_; // permit Zero() and zero divisors +}; + +template <class T> const int LogWeightGenerator_<T>::kNumRandomWeights; + +typedef LogWeightGenerator_<float> LogWeightGenerator; + + +// This function object returns MinMaxWeightTpl<T>'s that are random integers +// chosen from (-kNumRandomWeights, kNumRandomWeights) in addition to +// One(), and Zero() if zero is allowed. +template <class T> +class MinMaxWeightGenerator_ { + public: + typedef MinMaxWeightTpl<T> Weight; + + MinMaxWeightGenerator_(int seed = time(0), bool allow_zero = true) + : allow_zero_(allow_zero) { + srand(seed); + } + + Weight operator() () const { + int n = (rand() % (2*kNumRandomWeights + allow_zero_)) - kNumRandomWeights; + if (allow_zero_ && n == kNumRandomWeights) + return Weight::Zero(); + else if (n == -kNumRandomWeights) + return Weight::One(); + + return Weight(static_cast<T>(n)); + } + + private: + // Parameters controlling the number of alternative random weights. + static const int kNumRandomWeights = 5; + + bool allow_zero_; // permit Zero() and zero divisors +}; + +template <class T> const int MinMaxWeightGenerator_<T>::kNumRandomWeights; + +typedef MinMaxWeightGenerator_<float> MinMaxWeightGenerator; + + +// This function object returns StringWeights that are random integer +// strings chosen from {1,...,kAlphabetSize}^{0,kMaxStringLength} U { Zero } +template <typename L, StringType S = STRING_LEFT> +class StringWeightGenerator { + public: + typedef StringWeight<L, S> Weight; + + StringWeightGenerator(int seed = time(0), bool allow_zero = true) + : allow_zero_(allow_zero) { + srand(seed); + } + + Weight operator() () const { + int n = rand() % (kMaxStringLength + allow_zero_); + if (allow_zero_ && n == kMaxStringLength) + return Weight::Zero(); + + vector<L> v; + for (int i = 0; i < n; ++i) + v.push_back(rand() % kAlphabetSize + 1); + return Weight(v.begin(), v.end()); + } + + private: + // Alphabet size for random weights. + static const int kAlphabetSize = 5; + // Number of alternative random weights. + static const int kMaxStringLength = 5; + + bool allow_zero_; // permit Zero() and zero +}; + +template <typename L, StringType S> +const int StringWeightGenerator<L, S>::kAlphabetSize; +template <typename L, StringType S> +const int StringWeightGenerator<L, S>::kMaxStringLength; + + +// This function object returns a weight generator over the product of the +// weights (by default) for the generators G1 and G2. +template <class G1, class G2, + class W = ProductWeight<typename G1::Weight, typename G2::Weight> > +class ProductWeightGenerator { + public: + typedef typename G1::Weight W1; + typedef typename G2::Weight W2; + typedef W Weight; + + ProductWeightGenerator(int seed = time(0), bool allow_zero = true) + : generator1_(seed, allow_zero), generator2_(seed, allow_zero) {} + + Weight operator() () const { + W1 w1 = generator1_(); + W2 w2 = generator2_(); + return Weight(w1, w2); + } + + private: + G1 generator1_; + G2 generator2_; +}; + + +// This function object returns a weight generator for a lexicographic weight +// composed out of weights for the generators G1 and G2. For lexicographic +// weights, we cannot generate zeroes for the two subweights separately: +// weights are members iff both members are zero or both members are non-zero. +template <class G1, class G2> +class LexicographicWeightGenerator { + public: + typedef typename G1::Weight W1; + typedef typename G2::Weight W2; + typedef LexicographicWeight<W1, W2> Weight; + + LexicographicWeightGenerator(int seed = time(0), bool allow_zero = true) + : generator1_(seed, false), generator2_(seed, false), + allow_zero_(allow_zero) {} + + Weight operator() () const { + if (allow_zero_) { + int n = rand() % (kNumRandomWeights + allow_zero_); + if (n == kNumRandomWeights) + return Weight(W1::Zero(), W2::Zero()); + } + W1 w1 = generator1_(); + W2 w2 = generator2_(); + return Weight(w1, w2); + } + + private: + G1 generator1_; + G2 generator2_; + static const int kNumRandomWeights = 5; + bool allow_zero_; +}; + +template <class G1, class G2> +const int LexicographicWeightGenerator<G1, G2>::kNumRandomWeights; + + +// Product generator of a string weight generator and an +// arbitrary weight generator. +template <class L, class G, StringType S = STRING_LEFT> +class GallicWeightGenerator + : public ProductWeightGenerator<StringWeightGenerator<L, S>, G> { + + public: + typedef ProductWeightGenerator<StringWeightGenerator<L, S>, G> PG; + typedef typename G::Weight W; + typedef GallicWeight<L, W, S> Weight; + + GallicWeightGenerator(int seed = time(0), bool allow_zero = true) + : PG(seed, allow_zero) {} + + GallicWeightGenerator(const PG &pg) : PG(pg) {} +}; + +// This function object returms a weight generator over the catersian power +// of rank n of the weights for the generator G. +template <class G, unsigned int n> +class PowerWeightGenerator { + public: + typedef typename G::Weight W; + typedef PowerWeight<W, n> Weight; + + PowerWeightGenerator(int seed = time(0), bool allow_zero = true) + : generator_(seed, allow_zero) {} + + Weight operator()() const { + Weight w; + for (size_t i = 0; i < n; ++i) { + W r = generator_(); + w.SetValue(i, r); + } + return w; + } + + private: + G generator_; +}; + +// This function object returns SignedLogWeightTpl<T>'s that are +// random integers chosen from [0, kNumRandomWeights). +// The sign is randomly chosen as well. +template <class T> +class SignedLogWeightGenerator_ { + public: + typedef SignedLogWeightTpl<T> Weight; + + SignedLogWeightGenerator_(int seed = time(0), bool allow_zero = true) + : allow_zero_(allow_zero) { + srand(seed); + } + + Weight operator() () const { + int m = rand() % 2; + int n = rand() % (kNumRandomWeights + allow_zero_); + + return SignedLogWeightTpl<T>( + (m == 0) ? + TropicalWeight(-1.0) : + TropicalWeight(1.0), + (allow_zero_ && n == kNumRandomWeights) ? + LogWeightTpl<T>::Zero() : + LogWeightTpl<T>(static_cast<T>(n))); + } + + private: + // Number of alternative random weights. + static const int kNumRandomWeights = 5; + bool allow_zero_; // permit Zero() and zero divisors +}; + +template <class T> const int SignedLogWeightGenerator_<T>::kNumRandomWeights; + +typedef SignedLogWeightGenerator_<float> SignedLogWeightGenerator; + +// This function object returms a weight generator over the catersian power +// of rank n of the weights for the generator G. +template <class G, class K, unsigned int n> +class SparsePowerWeightGenerator { + public: + typedef typename G::Weight W; + typedef SparsePowerWeight<W, K> Weight; + + SparsePowerWeightGenerator(int seed = time(0), bool allow_zero = true) + : generator_(seed, allow_zero) {} + + Weight operator()() const { + Weight w; + for (size_t i = 1; i <= n; ++i) { + W r = generator_(); + K p = i; + w.Push(p, r, true); + } + return w; + } + + private: + G generator_; +}; + +} // namespace fst + +#endif // FST_LIB_RANDOM_WEIGHT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/rational.h b/kaldi_io/src/tools/openfst/include/fst/rational.h new file mode 100644 index 0000000..96aa00d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/rational.h @@ -0,0 +1,330 @@ +// rational.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// An Fst implementation and base interface for delayed unions, +// concatenations and closures. + +#ifndef FST_LIB_RATIONAL_H__ +#define FST_LIB_RATIONAL_H__ + +#include <algorithm> +#include <string> +#include <vector> +using std::vector; + +#include <fst/mutable-fst.h> +#include <fst/replace.h> +#include <fst/test-properties.h> + + +namespace fst { + +typedef CacheOptions RationalFstOptions; + +// This specifies whether to add the empty string. +enum ClosureType { CLOSURE_STAR = 0, // T* -> add the empty string + CLOSURE_PLUS = 1 }; // T+ -> don't add the empty string + +template <class A> class RationalFst; +template <class A> void Union(RationalFst<A> *fst1, const Fst<A> &fst2); +template <class A> void Concat(RationalFst<A> *fst1, const Fst<A> &fst2); +template <class A> void Concat(const Fst<A> &fst1, RationalFst<A> *fst2); +template <class A> void Closure(RationalFst<A> *fst, ClosureType closure_type); + + +// Implementation class for delayed unions, concatenations and closures. +template<class A> +class RationalFstImpl : public FstImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::WriteHeader; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + + explicit RationalFstImpl(const RationalFstOptions &opts) + : nonterminals_(0), + replace_(0), + replace_options_(opts, 0) { + SetType("rational"); + fst_tuples_.push_back(pair<Label, const Fst<A>*>(0, 0)); + } + + RationalFstImpl(const RationalFstImpl<A> &impl) + : rfst_(impl.rfst_), + nonterminals_(impl.nonterminals_), + + replace_(impl.replace_ ? impl.replace_->Copy(true) : 0), + replace_options_(impl.replace_options_) { + SetType("rational"); + fst_tuples_.reserve(impl.fst_tuples_.size()); + for (size_t i = 0; i < impl.fst_tuples_.size(); ++i) + fst_tuples_.push_back(make_pair(impl.fst_tuples_[i].first, + impl.fst_tuples_[i].second + ? impl.fst_tuples_[i].second->Copy(true) + : 0)); + } + + virtual ~RationalFstImpl() { + for (size_t i = 0; i < fst_tuples_.size(); ++i) + if (fst_tuples_[i].second) + delete fst_tuples_[i].second; + if (replace_) + delete replace_; + } + + StateId Start() { return Replace()->Start(); } + + Weight Final(StateId s) { return Replace()->Final(s); } + + size_t NumArcs(StateId s) { return Replace()->NumArcs(s); } + + size_t NumInputEpsilons(StateId s) { + return Replace()->NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + return Replace()->NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && Replace()->Properties(kError, false)) + SetProperties(kError, kError); + return FstImpl<Arc>::Properties(mask); + } + + // Implementation of UnionFst(fst1,fst2) + void InitUnion(const Fst<A> &fst1, const Fst<A> &fst2) { + if (replace_) + delete replace_; + uint64 props1 = fst1.Properties(kFstProperties, false); + uint64 props2 = fst2.Properties(kFstProperties, false); + SetInputSymbols(fst1.InputSymbols()); + SetOutputSymbols(fst1.OutputSymbols()); + rfst_.AddState(); + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(1, Weight::One()); + rfst_.SetInputSymbols(fst1.InputSymbols()); + rfst_.SetOutputSymbols(fst1.OutputSymbols()); + nonterminals_ = 2; + rfst_.AddArc(0, A(0, -1, Weight::One(), 1)); + rfst_.AddArc(0, A(0, -2, Weight::One(), 1)); + fst_tuples_.push_back(make_pair(-1, fst1.Copy())); + fst_tuples_.push_back(make_pair(-2, fst2.Copy())); + SetProperties(UnionProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of ConcatFst(fst1,fst2) + void InitConcat(const Fst<A> &fst1, const Fst<A> &fst2) { + if (replace_) + delete replace_; + uint64 props1 = fst1.Properties(kFstProperties, false); + uint64 props2 = fst2.Properties(kFstProperties, false); + SetInputSymbols(fst1.InputSymbols()); + SetOutputSymbols(fst1.OutputSymbols()); + rfst_.AddState(); + rfst_.AddState(); + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(2, Weight::One()); + rfst_.SetInputSymbols(fst1.InputSymbols()); + rfst_.SetOutputSymbols(fst1.OutputSymbols()); + nonterminals_ = 2; + rfst_.AddArc(0, A(0, -1, Weight::One(), 1)); + rfst_.AddArc(1, A(0, -2, Weight::One(), 2)); + fst_tuples_.push_back(make_pair(-1, fst1.Copy())); + fst_tuples_.push_back(make_pair(-2, fst2.Copy())); + SetProperties(ConcatProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of ClosureFst(fst, closure_type) + void InitClosure(const Fst<A> &fst, ClosureType closure_type) { + if (replace_) + delete replace_; + uint64 props = fst.Properties(kFstProperties, false); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + if (closure_type == CLOSURE_STAR) { + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(0, Weight::One()); + rfst_.AddArc(0, A(0, -1, Weight::One(), 0)); + } else { + rfst_.AddState(); + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(1, Weight::One()); + rfst_.AddArc(0, A(0, -1, Weight::One(), 1)); + rfst_.AddArc(1, A(0, 0, Weight::One(), 0)); + } + rfst_.SetInputSymbols(fst.InputSymbols()); + rfst_.SetOutputSymbols(fst.OutputSymbols()); + fst_tuples_.push_back(make_pair(-1, fst.Copy())); + nonterminals_ = 1; + SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true), + kCopyProperties); + } + + // Implementation of Union(Fst &, RationalFst *) + void AddUnion(const Fst<A> &fst) { + if (replace_) + delete replace_; + uint64 props1 = FstImpl<A>::Properties(); + uint64 props2 = fst.Properties(kFstProperties, false); + VectorFst<A> afst; + afst.AddState(); + afst.AddState(); + afst.SetStart(0); + afst.SetFinal(1, Weight::One()); + ++nonterminals_; + afst.AddArc(0, A(0, -nonterminals_, Weight::One(), 1)); + Union(&rfst_, afst); + fst_tuples_.push_back(make_pair(-nonterminals_, fst.Copy())); + SetProperties(UnionProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of Concat(Fst &, RationalFst *) + void AddConcat(const Fst<A> &fst, bool append) { + if (replace_) + delete replace_; + uint64 props1 = FstImpl<A>::Properties(); + uint64 props2 = fst.Properties(kFstProperties, false); + VectorFst<A> afst; + afst.AddState(); + afst.AddState(); + afst.SetStart(0); + afst.SetFinal(1, Weight::One()); + ++nonterminals_; + afst.AddArc(0, A(0, -nonterminals_, Weight::One(), 1)); + if (append) + Concat(&rfst_, afst); + else + Concat(afst, &rfst_); + fst_tuples_.push_back(make_pair(-nonterminals_, fst.Copy())); + SetProperties(ConcatProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of Closure(RationalFst *, closure_type) + void AddClosure(ClosureType closure_type) { + if (replace_) + delete replace_; + uint64 props = FstImpl<A>::Properties(); + Closure(&rfst_, closure_type); + SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true), + kCopyProperties); + } + + // Returns the underlying ReplaceFst. + ReplaceFst<A> *Replace() const { + if (!replace_) { + fst_tuples_[0].second = rfst_.Copy(); + replace_ = new ReplaceFst<A>(fst_tuples_, replace_options_); + } + return replace_; + } + + private: + VectorFst<A> rfst_; // rational topology machine; uses neg. nonterminals + Label nonterminals_; // # of nonterminals used + // Contains the nonterminals and their corresponding FSTs. + mutable vector<pair<Label, const Fst<A>*> > fst_tuples_; + mutable ReplaceFst<A> *replace_; // Underlying ReplaceFst + ReplaceFstOptions<A> replace_options_; // Options for creating 'replace_' + + void operator=(const RationalFstImpl<A> &impl); // disallow +}; + +// Parent class for the delayed rational operations - delayed union, +// concatenation, and closure. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class RationalFst : public ImplToFst< RationalFstImpl<A> > { + public: + friend class StateIterator< RationalFst<A> >; + friend class ArcIterator< RationalFst<A> >; + friend void Union<>(RationalFst<A> *fst1, const Fst<A> &fst2); + friend void Concat<>(RationalFst<A> *fst1, const Fst<A> &fst2); + friend void Concat<>(const Fst<A> &fst1, RationalFst<A> *fst2); + friend void Closure<>(RationalFst<A> *fst, ClosureType closure_type); + + typedef A Arc; + typedef typename A::StateId StateId; + typedef RationalFstImpl<A> Impl; + + virtual void InitStateIterator(StateIteratorData<A> *data) const { + GetImpl()->Replace()->InitStateIterator(data); + } + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + GetImpl()->Replace()->InitArcIterator(s, data); + } + + protected: + RationalFst() + : ImplToFst<Impl>(new Impl(RationalFstOptions())) {} + + explicit RationalFst(const RationalFstOptions &opts) + : ImplToFst<Impl>(new Impl(opts)) {} + + // See Fst<>::Copy() for doc. + RationalFst(const RationalFst<A> &fst , bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const RationalFst<A> &fst); // disallow +}; + + +// Specialization for RationalFst. +template <class A> +class StateIterator< RationalFst<A> > + : public StateIterator< ReplaceFst<A> > { + public: + explicit StateIterator(const RationalFst<A> &fst) + : StateIterator< ReplaceFst<A> >(*(fst.GetImpl()->Replace())) {} +}; + + +// Specialization for RationalFst. +template <class A> +class ArcIterator< RationalFst<A> > + : public CacheArcIterator< ReplaceFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const RationalFst<A> &fst, StateId s) + : ArcIterator< ReplaceFst<A> >(*(fst.GetImpl()->Replace()), s) {} +}; + +} // namespace fst + +#endif // FST_LIB_RATIONAL_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/register.h b/kaldi_io/src/tools/openfst/include/fst/register.h new file mode 100644 index 0000000..ea3f4d8 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/register.h @@ -0,0 +1,133 @@ +// register.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley), [email protected] (Jake Ratkiewicz) +// +// \file +// Classes for registering derived Fsts for generic reading +// + +#ifndef FST_LIB_REGISTER_H__ +#define FST_LIB_REGISTER_H__ + +#include <string> + + +#include <fst/compat.h> +#include <iostream> +#include <fstream> +#include <sstream> +#include <fst/util.h> +#include <fst/generic-register.h> + + +#include <fst/types.h> + +namespace fst { + +template <class A> class Fst; +struct FstReadOptions; + +// This class represents a single entry in a FstRegister +template<class A> +struct FstRegisterEntry { + typedef Fst<A> *(*Reader)(istream &strm, const FstReadOptions &opts); + typedef Fst<A> *(*Converter)(const Fst<A> &fst); + + Reader reader; + Converter converter; + FstRegisterEntry() : reader(0), converter(0) {} + FstRegisterEntry(Reader r, Converter c) : reader(r), converter(c) { } +}; + +// This class maintains the correspondence between a string describing +// an FST type, and its reader and converter. +template<class A> +class FstRegister : public GenericRegister<string, FstRegisterEntry<A>, + FstRegister<A> > { + public: + typedef typename FstRegisterEntry<A>::Reader Reader; + typedef typename FstRegisterEntry<A>::Converter Converter; + + const Reader GetReader(const string &type) const { + return this->GetEntry(type).reader; + } + + const Converter GetConverter(const string &type) const { + return this->GetEntry(type).converter; + } + + protected: + virtual string ConvertKeyToSoFilename(const string& key) const { + string legal_type(key); + + ConvertToLegalCSymbol(&legal_type); + + return legal_type + "-fst.so"; + } +}; + + +// This class registers an Fst type for generic reading and creating. +// The Fst type must have a default constructor and a copy constructor +// from 'Fst<Arc>' for this to work. +template <class F> +class FstRegisterer + : public GenericRegisterer<FstRegister<typename F::Arc> > { + public: + typedef typename F::Arc Arc; + typedef typename FstRegister<Arc>::Entry Entry; + typedef typename FstRegister<Arc>::Reader Reader; + + FstRegisterer() : + GenericRegisterer<FstRegister<typename F::Arc> >( + F().Type(), BuildEntry()) { } + + private: + Entry BuildEntry() { + F *(*reader)(istream &strm, + const FstReadOptions &opts) = &F::Read; + + return Entry(reinterpret_cast<Reader>(reader), + &FstRegisterer<F>::Convert); + } + + static Fst<Arc> *Convert(const Fst<Arc> &fst) { return new F(fst); } +}; + + +// Convenience macro to generate static FstRegisterer instance. +#define REGISTER_FST(F, A) \ +static fst::FstRegisterer< F<A> > F ## _ ## A ## _registerer + + +// Converts an fst to type 'type'. +template <class A> +Fst<A> *Convert(const Fst<A> &fst, const string &ftype) { + FstRegister<A> *registr = FstRegister<A>::GetRegister(); + const typename FstRegister<A>::Converter + converter = registr->GetConverter(ftype); + if (!converter) { + string atype = A::Type(); + LOG(ERROR) << "Fst::Convert: Unknown FST type \"" << ftype + << "\" (arc type = \"" << atype << "\")"; + return 0; + } + return converter(fst); +} + +} // namespace fst + +#endif // FST_LIB_REGISTER_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/relabel.h b/kaldi_io/src/tools/openfst/include/fst/relabel.h new file mode 100644 index 0000000..dc675b6 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/relabel.h @@ -0,0 +1,528 @@ +// relabel.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Johan Schalkwyk) +// +// \file +// Functions and classes to relabel an Fst (either on input or output) +// +#ifndef FST_LIB_RELABEL_H__ +#define FST_LIB_RELABEL_H__ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/test-properties.h> + + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + +namespace fst { + +// +// Relabels either the input labels or output labels. The old to +// new labels are specified using a vector of pair<Label,Label>. +// Any label associations not specified are assumed to be identity +// mapping. +// +// \param fst input fst, must be mutable +// \param ipairs vector of input label pairs indicating old to new mapping +// \param opairs vector of output label pairs indicating old to new mapping +// +template <class A> +void Relabel( + MutableFst<A> *fst, + const vector<pair<typename A::Label, typename A::Label> >& ipairs, + const vector<pair<typename A::Label, typename A::Label> >& opairs) { + typedef typename A::StateId StateId; + typedef typename A::Label Label; + + uint64 props = fst->Properties(kFstProperties, false); + + // construct label to label hash. + unordered_map<Label, Label> input_map; + for (size_t i = 0; i < ipairs.size(); ++i) { + input_map[ipairs[i].first] = ipairs[i].second; + } + + unordered_map<Label, Label> output_map; + for (size_t i = 0; i < opairs.size(); ++i) { + output_map[opairs[i].first] = opairs[i].second; + } + + for (StateIterator<MutableFst<A> > siter(*fst); + !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + for (MutableArcIterator<MutableFst<A> > aiter(fst, s); + !aiter.Done(); aiter.Next()) { + A arc = aiter.Value(); + + // relabel input + // only relabel if relabel pair defined + typename unordered_map<Label, Label>::iterator it = + input_map.find(arc.ilabel); + if (it != input_map.end()) { + if (it->second == kNoLabel) { + FSTERROR() << "Input symbol id " << arc.ilabel + << " missing from target vocabulary"; + fst->SetProperties(kError, kError); + return; + } + arc.ilabel = it->second; + } + + // relabel output + it = output_map.find(arc.olabel); + if (it != output_map.end()) { + if (it->second == kNoLabel) { + FSTERROR() << "Output symbol id " << arc.olabel + << " missing from target vocabulary"; + fst->SetProperties(kError, kError); + return; + } + arc.olabel = it->second; + } + + aiter.SetValue(arc); + } + } + + fst->SetProperties(RelabelProperties(props), kFstProperties); +} + +// +// Relabels either the input labels or output labels. The old to +// new labels mappings are specified using an input Symbol set. +// Any label associations not specified are assumed to be identity +// mapping. +// +// \param fst input fst, must be mutable +// \param new_isymbols symbol set indicating new mapping of input symbols +// \param new_osymbols symbol set indicating new mapping of output symbols +// +template<class A> +void Relabel(MutableFst<A> *fst, + const SymbolTable* new_isymbols, + const SymbolTable* new_osymbols) { + Relabel(fst, + fst->InputSymbols(), new_isymbols, true, + fst->OutputSymbols(), new_osymbols, true); +} + +template<class A> +void Relabel(MutableFst<A> *fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + bool attach_new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols, + bool attach_new_osymbols) { + typedef typename A::StateId StateId; + typedef typename A::Label Label; + + vector<pair<Label, Label> > ipairs; + if (old_isymbols && new_isymbols) { + for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); + syms_iter.Next()) { + string isymbol = syms_iter.Symbol(); + int isymbol_val = syms_iter.Value(); + int new_isymbol_val = new_isymbols->Find(isymbol); + ipairs.push_back(make_pair(isymbol_val, new_isymbol_val)); + } + if (attach_new_isymbols) + fst->SetInputSymbols(new_isymbols); + } + + vector<pair<Label, Label> > opairs; + if (old_osymbols && new_osymbols) { + for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); + syms_iter.Next()) { + string osymbol = syms_iter.Symbol(); + int osymbol_val = syms_iter.Value(); + int new_osymbol_val = new_osymbols->Find(osymbol); + opairs.push_back(make_pair(osymbol_val, new_osymbol_val)); + } + if (attach_new_osymbols) + fst->SetOutputSymbols(new_osymbols); + } + + // call relabel using vector of relabel pairs. + Relabel(fst, ipairs, opairs); +} + + +typedef CacheOptions RelabelFstOptions; + +template <class A> class RelabelFst; + +// +// \class RelabelFstImpl +// \brief Implementation for delayed relabeling +// +// Relabels an FST from one symbol set to another. Relabeling +// can either be on input or output space. RelabelFst implements +// a delayed version of the relabel. Arcs are relabeled on the fly +// and not cached. I.e each request is recomputed. +// +template<class A> +class RelabelFstImpl : public CacheImpl<A> { + friend class StateIterator< RelabelFst<A> >; + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::WriteHeader; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheImpl<A>::PushArc; + using CacheImpl<A>::HasArcs; + using CacheImpl<A>::HasFinal; + using CacheImpl<A>::HasStart; + using CacheImpl<A>::SetArcs; + using CacheImpl<A>::SetFinal; + using CacheImpl<A>::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + + RelabelFstImpl(const Fst<A>& fst, + const vector<pair<Label, Label> >& ipairs, + const vector<pair<Label, Label> >& opairs, + const RelabelFstOptions &opts) + : CacheImpl<A>(opts), fst_(fst.Copy()), + relabel_input_(false), relabel_output_(false) { + uint64 props = fst.Properties(kCopyProperties, false); + SetProperties(RelabelProperties(props)); + SetType("relabel"); + + // create input label map + if (ipairs.size() > 0) { + for (size_t i = 0; i < ipairs.size(); ++i) { + input_map_[ipairs[i].first] = ipairs[i].second; + } + relabel_input_ = true; + } + + // create output label map + if (opairs.size() > 0) { + for (size_t i = 0; i < opairs.size(); ++i) { + output_map_[opairs[i].first] = opairs[i].second; + } + relabel_output_ = true; + } + } + + RelabelFstImpl(const Fst<A>& fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols, + const RelabelFstOptions &opts) + : CacheImpl<A>(opts), fst_(fst.Copy()), + relabel_input_(false), relabel_output_(false) { + SetType("relabel"); + + uint64 props = fst.Properties(kCopyProperties, false); + SetProperties(RelabelProperties(props)); + SetInputSymbols(old_isymbols); + SetOutputSymbols(old_osymbols); + + if (old_isymbols && new_isymbols && + old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) { + for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); + syms_iter.Next()) { + input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol()); + } + SetInputSymbols(new_isymbols); + relabel_input_ = true; + } + + if (old_osymbols && new_osymbols && + old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) { + for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); + syms_iter.Next()) { + output_map_[syms_iter.Value()] = + new_osymbols->Find(syms_iter.Symbol()); + } + SetOutputSymbols(new_osymbols); + relabel_output_ = true; + } + } + + RelabelFstImpl(const RelabelFstImpl<A>& impl) + : CacheImpl<A>(impl), + fst_(impl.fst_->Copy(true)), + input_map_(impl.input_map_), + output_map_(impl.output_map_), + relabel_input_(impl.relabel_input_), + relabel_output_(impl.relabel_output_) { + SetType("relabel"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~RelabelFstImpl() { delete fst_; } + + StateId Start() { + if (!HasStart()) { + StateId s = fst_->Start(); + SetStart(s); + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + SetFinal(s, fst_->Final(s)); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) { + Expand(s); + } + return CacheImpl<A>::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && fst_->Properties(kError, false)) + SetProperties(kError, kError); + return FstImpl<Arc>::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData<A>* data) { + if (!HasArcs(s)) { + Expand(s); + } + CacheImpl<A>::InitArcIterator(s, data); + } + + void Expand(StateId s) { + for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) { + A arc = aiter.Value(); + + // relabel input + if (relabel_input_) { + typename unordered_map<Label, Label>::iterator it = + input_map_.find(arc.ilabel); + if (it != input_map_.end()) { arc.ilabel = it->second; } + } + + // relabel output + if (relabel_output_) { + typename unordered_map<Label, Label>::iterator it = + output_map_.find(arc.olabel); + if (it != output_map_.end()) { arc.olabel = it->second; } + } + + PushArc(s, arc); + } + SetArcs(s); + } + + + private: + const Fst<A> *fst_; + + unordered_map<Label, Label> input_map_; + unordered_map<Label, Label> output_map_; + bool relabel_input_; + bool relabel_output_; + + void operator=(const RelabelFstImpl<A> &); // disallow +}; + + +// +// \class RelabelFst +// \brief Delayed implementation of arc relabeling +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class RelabelFst : public ImplToFst< RelabelFstImpl<A> > { + public: + friend class ArcIterator< RelabelFst<A> >; + friend class StateIterator< RelabelFst<A> >; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef RelabelFstImpl<A> Impl; + + RelabelFst(const Fst<A>& fst, + const vector<pair<Label, Label> >& ipairs, + const vector<pair<Label, Label> >& opairs) + : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {} + + RelabelFst(const Fst<A>& fst, + const vector<pair<Label, Label> >& ipairs, + const vector<pair<Label, Label> >& opairs, + const RelabelFstOptions &opts) + : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {} + + RelabelFst(const Fst<A>& fst, + const SymbolTable* new_isymbols, + const SymbolTable* new_osymbols) + : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, + fst.OutputSymbols(), new_osymbols, + RelabelFstOptions())) {} + + RelabelFst(const Fst<A>& fst, + const SymbolTable* new_isymbols, + const SymbolTable* new_osymbols, + const RelabelFstOptions &opts) + : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, + fst.OutputSymbols(), new_osymbols, opts)) {} + + RelabelFst(const Fst<A>& fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols) + : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, + new_osymbols, RelabelFstOptions())) {} + + RelabelFst(const Fst<A>& fst, + const SymbolTable* old_isymbols, + const SymbolTable* new_isymbols, + const SymbolTable* old_osymbols, + const SymbolTable* new_osymbols, + const RelabelFstOptions &opts) + : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, + new_osymbols, opts)) {} + + // See Fst<>::Copy() for doc. + RelabelFst(const RelabelFst<A> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc. + virtual RelabelFst<A> *Copy(bool safe = false) const { + return new RelabelFst<A>(*this, safe); + } + + virtual void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + return GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const RelabelFst<A> &fst); // disallow +}; + +// Specialization for RelabelFst. +template<class A> +class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const RelabelFst<A> &fst) + : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {} + + bool Done() const { return siter_.Done(); } + + StateId Value() const { return s_; } + + void Next() { + if (!siter_.Done()) { + ++s_; + siter_.Next(); + } + } + + void Reset() { + s_ = 0; + siter_.Reset(); + } + + private: + bool Done_() const { return Done(); } + StateId Value_() const { return Value(); } + void Next_() { Next(); } + void Reset_() { Reset(); } + + const RelabelFstImpl<A> *impl_; + StateIterator< Fst<A> > siter_; + StateId s_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + + +// Specialization for RelabelFst. +template <class A> +class ArcIterator< RelabelFst<A> > + : public CacheArcIterator< RelabelFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const RelabelFst<A> &fst, StateId s) + : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +template <class A> inline +void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const { + data->base = new StateIterator< RelabelFst<A> >(*this); +} + +// Useful alias when using StdArc. +typedef RelabelFst<StdArc> StdRelabelFst; + +} // namespace fst + +#endif // FST_LIB_RELABEL_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/replace-util.h b/kaldi_io/src/tools/openfst/include/fst/replace-util.h new file mode 100644 index 0000000..d58cb15 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/replace-util.h @@ -0,0 +1,550 @@ +// replace-util.h + + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// + +// \file +// Utility classes for the recursive replacement of Fsts (RTNs). + +#ifndef FST_LIB_REPLACE_UTIL_H__ +#define FST_LIB_REPLACE_UTIL_H__ + +#include <vector> +using std::vector; +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; +#include <map> + +#include <fst/connect.h> +#include <fst/mutable-fst.h> +#include <fst/topsort.h> + + +namespace fst { + +template <class Arc> +void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&, + MutableFst<Arc> *, typename Arc::Label, bool); + + +// Utility class for the recursive replacement of Fsts (RTNs). The +// user provides a set of Label, Fst pairs at construction. These are +// used by methods for testing cyclic dependencies and connectedness +// and doing RTN connection and specific Fst replacement by label or +// for various optimization properties. The modified results can be +// obtained with the GetFstPairs() or GetMutableFstPairs() methods. +template <class Arc> +class ReplaceUtil { + public: + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + typedef pair<Label, const Fst<Arc>*> FstPair; + typedef pair<Label, MutableFst<Arc>*> MutableFstPair; + typedef unordered_map<Label, Label> NonTerminalHash; + + // Constructs from mutable Fsts; Fst ownership given to ReplaceUtil. + ReplaceUtil(const vector<MutableFstPair> &fst_pairs, + Label root_label, bool epsilon_on_replace = false); + + // Constructs from Fsts; Fst ownership retained by caller. + ReplaceUtil(const vector<FstPair> &fst_pairs, + Label root_label, bool epsilon_on_replace = false); + + // Constructs from ReplaceFst internals; ownership retained by caller. + ReplaceUtil(const vector<const Fst<Arc> *> &fst_array, + const NonTerminalHash &nonterminal_hash, Label root_fst, + bool epsilon_on_replace = false); + + ~ReplaceUtil() { + for (Label i = 0; i < fst_array_.size(); ++i) + delete fst_array_[i]; + } + + // True if the non-terminal dependencies are cyclic. Cyclic + // dependencies will result in an unexpandable replace fst. + bool CyclicDependencies() const { + GetDependencies(false); + return depprops_ & kCyclic; + } + + // Returns true if no useless Fsts, states or transitions. + bool Connected() const { + GetDependencies(false); + uint64 props = kAccessible | kCoAccessible; + for (Label i = 0; i < fst_array_.size(); ++i) { + if (!fst_array_[i]) + continue; + if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i]) + return false; + } + return true; + } + + // Removes useless Fsts, states and transitions. + void Connect(); + + // Replaces Fsts specified by labels. + // Does nothing if there are cyclic dependencies. + void ReplaceLabels(const vector<Label> &labels); + + // Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and + // 'nnonterm' non-terminals (updating in reverse dependency order). + // Does nothing if there are cyclic dependencies. + void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms); + + // Replaces singleton Fsts. + // Does nothing if there are cyclic dependencies. + void ReplaceTrivial() { ReplaceBySize(2, 1, 1); } + + // Replaces non-terminals that have at most 'ninstances' instances + // (updating in dependency order). + // Does nothing if there are cyclic dependencies. + void ReplaceByInstances(size_t ninstances); + + // Replaces non-terminals that have only one instance. + // Does nothing if there are cyclic dependencies. + void ReplaceUnique() { ReplaceByInstances(1); } + + // Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil. + void GetFstPairs(vector<FstPair> *fst_pairs); + + // Returns Label, MutableFst pairs; Fst ownership given to caller. + void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs); + + private: + // Per Fst statistics + struct ReplaceStats { + StateId nstates; // # of states + StateId nfinal; // # of final states + size_t narcs; // # of arcs + Label nnonterms; // # of non-terminals in Fst + size_t nref; // # of non-terminal instances referring to this Fst + + // # of times that ith Fst references this Fst + map<Label, size_t> inref; + // # of times that this Fst references the ith Fst + map<Label, size_t> outref; + + ReplaceStats() + : nstates(0), + nfinal(0), + narcs(0), + nnonterms(0), + nref(0) {} + }; + + // Check Mutable Fsts exist o.w. create them. + void CheckMutableFsts(); + + // Computes the dependency graph of the replace Fsts. + // If 'stats' is true, dependency statistics computed as well. + void GetDependencies(bool stats) const; + + void ClearDependencies() const { + depfst_.DeleteStates(); + stats_.clear(); + depprops_ = 0; + have_stats_ = false; + } + + // Get topological order of dependencies. Returns false with cyclic input. + bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const; + + // Update statistics assuming that jth Fst will be replaced. + void UpdateStats(Label j); + + Label root_label_; // root non-terminal + Label root_fst_; // root Fst ID + bool epsilon_on_replace_; // see Replace() + vector<const Fst<Arc> *> fst_array_; // Fst per ID + vector<MutableFst<Arc> *> mutable_fst_array_; // MutableFst per ID + vector<Label> nonterminal_array_; // Fst ID to non-terminal + NonTerminalHash nonterminal_hash_; // non-terminal to Fst ID + mutable VectorFst<Arc> depfst_; // Fst ID dependencies + mutable vector<bool> depaccess_; // Fst ID accessibility + mutable uint64 depprops_; // dependency Fst props + mutable bool have_stats_; // have dependency statistics + mutable vector<ReplaceStats> stats_; // Per Fst statistics + DISALLOW_COPY_AND_ASSIGN(ReplaceUtil); +}; + +template <class Arc> +ReplaceUtil<Arc>::ReplaceUtil( + const vector<MutableFstPair> &fst_pairs, + Label root_label, bool epsilon_on_replace) + : root_label_(root_label), + epsilon_on_replace_(epsilon_on_replace), + depprops_(0), + have_stats_(false) { + fst_array_.push_back(0); + mutable_fst_array_.push_back(0); + nonterminal_array_.push_back(kNoLabel); + for (Label i = 0; i < fst_pairs.size(); ++i) { + Label label = fst_pairs[i].first; + MutableFst<Arc> *fst = fst_pairs[i].second; + nonterminal_hash_[label] = fst_array_.size(); + nonterminal_array_.push_back(label); + fst_array_.push_back(fst); + mutable_fst_array_.push_back(fst); + } + root_fst_ = nonterminal_hash_[root_label_]; + if (!root_fst_) + FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_; +} + +template <class Arc> +ReplaceUtil<Arc>::ReplaceUtil( + const vector<FstPair> &fst_pairs, + Label root_label, bool epsilon_on_replace) + : root_label_(root_label), + epsilon_on_replace_(epsilon_on_replace), + depprops_(0), + have_stats_(false) { + fst_array_.push_back(0); + nonterminal_array_.push_back(kNoLabel); + for (Label i = 0; i < fst_pairs.size(); ++i) { + Label label = fst_pairs[i].first; + const Fst<Arc> *fst = fst_pairs[i].second; + nonterminal_hash_[label] = fst_array_.size(); + nonterminal_array_.push_back(label); + fst_array_.push_back(fst->Copy()); + } + root_fst_ = nonterminal_hash_[root_label]; + if (!root_fst_) + FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_; +} + +template <class Arc> +ReplaceUtil<Arc>::ReplaceUtil( + const vector<const Fst<Arc> *> &fst_array, + const NonTerminalHash &nonterminal_hash, Label root_fst, + bool epsilon_on_replace) + : root_fst_(root_fst), + epsilon_on_replace_(epsilon_on_replace), + nonterminal_array_(fst_array.size()), + nonterminal_hash_(nonterminal_hash), + depprops_(0), + have_stats_(false) { + fst_array_.push_back(0); + for (Label i = 1; i < fst_array.size(); ++i) + fst_array_.push_back(fst_array[i]->Copy()); + for (typename NonTerminalHash::const_iterator it = + nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it) + nonterminal_array_[it->second] = it->first; + root_label_ = nonterminal_array_[root_fst_]; +} + +template <class Arc> +void ReplaceUtil<Arc>::GetDependencies(bool stats) const { + if (depfst_.NumStates() > 0) { + if (stats && !have_stats_) + ClearDependencies(); + else + return; + } + + have_stats_ = stats; + if (have_stats_) + stats_.reserve(fst_array_.size()); + + for (Label i = 0; i < fst_array_.size(); ++i) { + depfst_.AddState(); + depfst_.SetFinal(i, Weight::One()); + if (have_stats_) + stats_.push_back(ReplaceStats()); + } + depfst_.SetStart(root_fst_); + + // An arc from each state (representing the fst) to the + // state representing the fst being replaced + for (Label i = 0; i < fst_array_.size(); ++i) { + const Fst<Arc> *ifst = fst_array_[i]; + if (!ifst) + continue; + for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + if (have_stats_) { + ++stats_[i].nstates; + if (ifst->Final(s) != Weight::Zero()) + ++stats_[i].nfinal; + } + for (ArcIterator<Fst<Arc> > aiter(*ifst, s); + !aiter.Done(); aiter.Next()) { + if (have_stats_) + ++stats_[i].narcs; + const Arc& arc = aiter.Value(); + + typename NonTerminalHash::const_iterator it = + nonterminal_hash_.find(arc.olabel); + if (it != nonterminal_hash_.end()) { + Label j = it->second; + depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j)); + if (have_stats_) { + ++stats_[i].nnonterms; + ++stats_[j].nref; + ++stats_[j].inref[i]; + ++stats_[i].outref[j]; + } + } + } + } + } + + // Gets accessibility info + SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_); + DfsVisit(depfst_, &scc_visitor); +} + +template <class Arc> +void ReplaceUtil<Arc>::UpdateStats(Label j) { + if (!have_stats_) { + FSTERROR() << "ReplaceUtil::UpdateStats: stats not available"; + return; + } + + if (j == root_fst_) // can't replace root + return; + + typedef typename map<Label, size_t>::iterator Iter; + for (Iter in = stats_[j].inref.begin(); + in != stats_[j].inref.end(); + ++in) { + Label i = in->first; + size_t ni = in->second; + stats_[i].nstates += stats_[j].nstates * ni; + stats_[i].narcs += (stats_[j].narcs + 1) * ni; // narcs - 1 + 2 (eps) + stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni; + stats_[i].outref.erase(stats_[i].outref.find(j)); + for (Iter out = stats_[j].outref.begin(); + out != stats_[j].outref.end(); + ++out) { + Label k = out->first; + size_t nk = out->second; + stats_[i].outref[k] += ni * nk; + } + } + + for (Iter out = stats_[j].outref.begin(); + out != stats_[j].outref.end(); + ++out) { + Label k = out->first; + size_t nk = out->second; + stats_[k].nref -= nk; + stats_[k].inref.erase(stats_[k].inref.find(j)); + for (Iter in = stats_[j].inref.begin(); + in != stats_[j].inref.end(); + ++in) { + Label i = in->first; + size_t ni = in->second; + stats_[k].inref[i] += ni * nk; + stats_[k].nref += ni * nk; + } + } +} + +template <class Arc> +void ReplaceUtil<Arc>::CheckMutableFsts() { + if (mutable_fst_array_.size() == 0) { + for (Label i = 0; i < fst_array_.size(); ++i) { + if (!fst_array_[i]) { + mutable_fst_array_.push_back(0); + } else { + mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i])); + delete fst_array_[i]; + fst_array_[i] = mutable_fst_array_[i]; + } + } + } +} + +template <class Arc> +void ReplaceUtil<Arc>::Connect() { + CheckMutableFsts(); + uint64 props = kAccessible | kCoAccessible; + for (Label i = 0; i < mutable_fst_array_.size(); ++i) { + if (!mutable_fst_array_[i]) + continue; + if (mutable_fst_array_[i]->Properties(props, false) != props) + fst::Connect(mutable_fst_array_[i]); + } + GetDependencies(false); + for (Label i = 0; i < mutable_fst_array_.size(); ++i) { + MutableFst<Arc> *fst = mutable_fst_array_[i]; + if (fst && !depaccess_[i]) { + delete fst; + fst_array_[i] = 0; + mutable_fst_array_[i] = 0; + } + } + ClearDependencies(); +} + +template <class Arc> +bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst, + vector<Label> *toporder) const { + // Finds topological order of dependencies. + vector<StateId> order; + bool acyclic = false; + + TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); + DfsVisit(fst, &top_order_visitor); + if (!acyclic) { + LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies"; + return false; + } + + toporder->resize(order.size()); + for (Label i = 0; i < order.size(); ++i) + (*toporder)[order[i]] = i; + + return true; +} + +template <class Arc> +void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) { + CheckMutableFsts(); + unordered_set<Label> label_set; + for (Label i = 0; i < labels.size(); ++i) + if (labels[i] != root_label_) // can't replace root + label_set.insert(labels[i]); + + // Finds Fst dependencies restricted to the labels requested. + GetDependencies(false); + VectorFst<Arc> pfst(depfst_); + for (StateId i = 0; i < pfst.NumStates(); ++i) { + vector<Arc> arcs; + for (ArcIterator< VectorFst<Arc> > aiter(pfst, i); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + Label label = nonterminal_array_[arc.nextstate]; + if (label_set.count(label) > 0) + arcs.push_back(arc); + } + pfst.DeleteArcs(i); + for (size_t j = 0; j < arcs.size(); ++j) + pfst.AddArc(i, arcs[j]); + } + + vector<Label> toporder; + if (!GetTopOrder(pfst, &toporder)) { + ClearDependencies(); + return; + } + + // Visits Fsts in reverse topological order of dependencies and + // performs replacements. + for (Label o = toporder.size() - 1; o >= 0; --o) { + vector<FstPair> fst_pairs; + StateId s = toporder[o]; + for (ArcIterator< VectorFst<Arc> > aiter(pfst, s); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + Label label = nonterminal_array_[arc.nextstate]; + const Fst<Arc> *fst = fst_array_[arc.nextstate]; + fst_pairs.push_back(make_pair(label, fst)); + } + if (fst_pairs.empty()) + continue; + Label label = nonterminal_array_[s]; + const Fst<Arc> *fst = fst_array_[s]; + fst_pairs.push_back(make_pair(label, fst)); + + Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_); + } + ClearDependencies(); +} + +template <class Arc> +void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs, + size_t nnonterms) { + vector<Label> labels; + GetDependencies(true); + + vector<Label> toporder; + if (!GetTopOrder(depfst_, &toporder)) { + ClearDependencies(); + return; + } + + for (Label o = toporder.size() - 1; o >= 0; --o) { + Label j = toporder[o]; + if (stats_[j].nstates <= nstates && + stats_[j].narcs <= narcs && + stats_[j].nnonterms <= nnonterms) { + labels.push_back(nonterminal_array_[j]); + UpdateStats(j); + } + } + ReplaceLabels(labels); +} + +template <class Arc> +void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) { + vector<Label> labels; + GetDependencies(true); + + vector<Label> toporder; + if (!GetTopOrder(depfst_, &toporder)) { + ClearDependencies(); + return; + } + for (Label o = 0; o < toporder.size(); ++o) { + Label j = toporder[o]; + if (stats_[j].nref <= ninstances) { + labels.push_back(nonterminal_array_[j]); + UpdateStats(j); + } + } + ReplaceLabels(labels); +} + +template <class Arc> +void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) { + CheckMutableFsts(); + fst_pairs->clear(); + for (Label i = 0; i < fst_array_.size(); ++i) { + Label label = nonterminal_array_[i]; + const Fst<Arc> *fst = fst_array_[i]; + if (!fst) + continue; + fst_pairs->push_back(make_pair(label, fst)); + } +} + +template <class Arc> +void ReplaceUtil<Arc>::GetMutableFstPairs( + vector<MutableFstPair> *mutable_fst_pairs) { + CheckMutableFsts(); + mutable_fst_pairs->clear(); + for (Label i = 0; i < mutable_fst_array_.size(); ++i) { + Label label = nonterminal_array_[i]; + MutableFst<Arc> *fst = mutable_fst_array_[i]; + if (!fst) + continue; + mutable_fst_pairs->push_back(make_pair(label, fst->Copy())); + } +} + +} // namespace fst + +#endif // FST_LIB_REPLACE_UTIL_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/replace.h b/kaldi_io/src/tools/openfst/include/fst/replace.h new file mode 100644 index 0000000..ef5f6cc --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/replace.h @@ -0,0 +1,1453 @@ +// replace.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Johan Schalkwyk) +// +// \file +// Functions and classes for the recursive replacement of Fsts. +// + +#ifndef FST_LIB_REPLACE_H__ +#define FST_LIB_REPLACE_H__ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <set> +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/expanded-fst.h> +#include <fst/fst.h> +#include <fst/matcher.h> +#include <fst/replace-util.h> +#include <fst/state-table.h> +#include <fst/test-properties.h> + +namespace fst { + +// +// REPLACE STATE TUPLES AND TABLES +// +// The replace state table has the form +// +// template <class A, class P> +// class ReplaceStateTable { +// public: +// typedef A Arc; +// typedef P PrefixId; +// typedef typename A::StateId StateId; +// typedef ReplaceStateTuple<StateId, PrefixId> StateTuple; +// typedef typename A::Label Label; +// +// // Required constuctor +// ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples, +// Label root); +// +// // Required copy constructor that does not copy state +// ReplaceStateTable(const ReplaceStateTable<A,P> &table); +// +// // Lookup state ID by tuple. If it doesn't exist, then add it. +// StateId FindState(const StateTuple &tuple); +// +// // Lookup state tuple by ID. +// const StateTuple &Tuple(StateId id) const; +// }; + + +// \struct ReplaceStateTuple +// \brief Tuple of information that uniquely defines a state in replace +template <class S, class P> +struct ReplaceStateTuple { + typedef S StateId; + typedef P PrefixId; + + ReplaceStateTuple() + : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {} + + ReplaceStateTuple(PrefixId p, StateId f, StateId s) + : prefix_id(p), fst_id(f), fst_state(s) {} + + PrefixId prefix_id; // index in prefix table + StateId fst_id; // current fst being walked + StateId fst_state; // current state in fst being walked, not to be + // confused with the state_id of the combined fst +}; + + +// Equality of replace state tuples. +template <class S, class P> +inline bool operator==(const ReplaceStateTuple<S, P>& x, + const ReplaceStateTuple<S, P>& y) { + return x.prefix_id == y.prefix_id && + x.fst_id == y.fst_id && + x.fst_state == y.fst_state; +} + + +// \class ReplaceRootSelector +// Functor returning true for tuples corresponding to states in the root FST +template <class S, class P> +class ReplaceRootSelector { + public: + bool operator()(const ReplaceStateTuple<S, P> &tuple) const { + return tuple.prefix_id == 0; + } +}; + + +// \class ReplaceFingerprint +// Fingerprint for general replace state tuples. +template <class S, class P> +class ReplaceFingerprint { + public: + ReplaceFingerprint(const vector<uint64> *size_array) + : cumulative_size_array_(size_array) {} + + uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const { + return tuple.prefix_id * (cumulative_size_array_->back()) + + cumulative_size_array_->at(tuple.fst_id - 1) + + tuple.fst_state; + } + + private: + const vector<uint64> *cumulative_size_array_; +}; + + +// \class ReplaceFstStateFingerprint +// Useful when the fst_state uniquely define the tuple. +template <class S, class P> +class ReplaceFstStateFingerprint { + public: + uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const { + return tuple.fst_state; + } +}; + + +// \class ReplaceHash +// A generic hash function for replace state tuples. +template <typename S, typename P> +class ReplaceHash { + public: + size_t operator()(const ReplaceStateTuple<S, P>& t) const { + return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1; + } + private: + static const size_t kPrime0; + static const size_t kPrime1; +}; + +template <typename S, typename P> +const size_t ReplaceHash<S, P>::kPrime0 = 7853; + +template <typename S, typename P> +const size_t ReplaceHash<S, P>::kPrime1 = 7867; + +template <class A, class T> class ReplaceFstMatcher; + + +// \class VectorHashReplaceStateTable +// A two-level state table for replace. +// Warning: calls CountStates to compute the number of states of each +// component Fst. +template <class A, class P = ssize_t> +class VectorHashReplaceStateTable { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef P PrefixId; + typedef ReplaceStateTuple<StateId, P> StateTuple; + typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>, + ReplaceRootSelector<StateId, P>, + ReplaceFstStateFingerprint<StateId, P>, + ReplaceFingerprint<StateId, P> > StateTable; + + VectorHashReplaceStateTable( + const vector<pair<Label, const Fst<A>*> > &fst_tuples, + Label root) : root_size_(0) { + cumulative_size_array_.push_back(0); + for (size_t i = 0; i < fst_tuples.size(); ++i) { + if (fst_tuples[i].first == root) { + root_size_ = CountStates(*(fst_tuples[i].second)); + cumulative_size_array_.push_back(cumulative_size_array_.back()); + } else { + cumulative_size_array_.push_back(cumulative_size_array_.back() + + CountStates(*(fst_tuples[i].second))); + } + } + state_table_ = new StateTable( + new ReplaceRootSelector<StateId, P>, + new ReplaceFstStateFingerprint<StateId, P>, + new ReplaceFingerprint<StateId, P>(&cumulative_size_array_), + root_size_, + root_size_ + cumulative_size_array_.back()); + } + + VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table) + : root_size_(table.root_size_), + cumulative_size_array_(table.cumulative_size_array_) { + state_table_ = new StateTable( + new ReplaceRootSelector<StateId, P>, + new ReplaceFstStateFingerprint<StateId, P>, + new ReplaceFingerprint<StateId, P>(&cumulative_size_array_), + root_size_, + root_size_ + cumulative_size_array_.back()); + } + + ~VectorHashReplaceStateTable() { + delete state_table_; + } + + StateId FindState(const StateTuple &tuple) { + return state_table_->FindState(tuple); + } + + const StateTuple &Tuple(StateId id) const { + return state_table_->Tuple(id); + } + + private: + StateId root_size_; + vector<uint64> cumulative_size_array_; + StateTable *state_table_; +}; + + +// \class DefaultReplaceStateTable +// Default replace state table +template <class A, class P = ssize_t> +class DefaultReplaceStateTable : public CompactHashStateTable< + ReplaceStateTuple<typename A::StateId, P>, + ReplaceHash<typename A::StateId, P> > { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef P PrefixId; + typedef ReplaceStateTuple<StateId, P> StateTuple; + typedef CompactHashStateTable<StateTuple, + ReplaceHash<StateId, PrefixId> > StateTable; + + using StateTable::FindState; + using StateTable::Tuple; + + DefaultReplaceStateTable( + const vector<pair<Label, const Fst<A>*> > &fst_tuples, + Label root) {} + + DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table) + : StateTable() {} +}; + +// +// REPLACE FST CLASS +// + +// By default ReplaceFst will copy the input label of the 'replace arc'. +// For acceptors we do not want this behaviour. Instead we need to +// create an epsilon arc when recursing into the appropriate Fst. +// The 'epsilon_on_replace' option can be used to toggle this behaviour. +template <class A, class T = DefaultReplaceStateTable<A> > +struct ReplaceFstOptions : CacheOptions { + int64 root; // root rule for expansion + bool epsilon_on_replace; + bool take_ownership; // take ownership of input Fst(s) + T* state_table; + + ReplaceFstOptions(const CacheOptions &opts, int64 r) + : CacheOptions(opts), + root(r), + epsilon_on_replace(false), + take_ownership(false), + state_table(0) {} + explicit ReplaceFstOptions(int64 r) + : root(r), + epsilon_on_replace(false), + take_ownership(false), + state_table(0) {} + ReplaceFstOptions(int64 r, bool epsilon_replace_arc) + : root(r), + epsilon_on_replace(epsilon_replace_arc), + take_ownership(false), + state_table(0) {} + ReplaceFstOptions() + : root(kNoLabel), + epsilon_on_replace(false), + take_ownership(false), + state_table(0) {} +}; + + +// \class ReplaceFstImpl +// \brief Implementation class for replace class Fst +// +// The replace implementation class supports a dynamic +// expansion of a recursive transition network represented as Fst +// with dynamic replacable arcs. +// +template <class A, class T> +class ReplaceFstImpl : public CacheImpl<A> { + friend class ReplaceFstMatcher<A, T>; + + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::WriteHeader; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + using FstImpl<A>::InputSymbols; + using FstImpl<A>::OutputSymbols; + + using CacheImpl<A>::PushArc; + using CacheImpl<A>::HasArcs; + using CacheImpl<A>::HasFinal; + using CacheImpl<A>::HasStart; + using CacheImpl<A>::SetArcs; + using CacheImpl<A>::SetFinal; + using CacheImpl<A>::SetStart; + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef A Arc; + typedef unordered_map<Label, Label> NonTerminalHash; + + typedef T StateTable; + typedef typename T::PrefixId PrefixId; + typedef ReplaceStateTuple<StateId, PrefixId> StateTuple; + + // constructor for replace class implementation. + // \param fst_tuples array of label/fst tuples, one for each non-terminal + ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples, + const ReplaceFstOptions<A, T> &opts) + : CacheImpl<A>(opts), + epsilon_on_replace_(opts.epsilon_on_replace), + state_table_(opts.state_table ? opts.state_table : + new StateTable(fst_tuples, opts.root)) { + + SetType("replace"); + + if (fst_tuples.size() > 0) { + SetInputSymbols(fst_tuples[0].second->InputSymbols()); + SetOutputSymbols(fst_tuples[0].second->OutputSymbols()); + } + + bool all_negative = true; // all nonterminals are negative? + bool dense_range = true; // all nonterminals are positive + // and form a dense range containing 1? + for (size_t i = 0; i < fst_tuples.size(); ++i) { + Label nonterminal = fst_tuples[i].first; + if (nonterminal >= 0) + all_negative = false; + if (nonterminal > fst_tuples.size() || nonterminal <= 0) + dense_range = false; + } + + vector<uint64> inprops; + bool all_ilabel_sorted = true; + bool all_olabel_sorted = true; + bool all_non_empty = true; + fst_array_.push_back(0); + for (size_t i = 0; i < fst_tuples.size(); ++i) { + Label label = fst_tuples[i].first; + const Fst<A> *fst = fst_tuples[i].second; + nonterminal_hash_[label] = fst_array_.size(); + nonterminal_set_.insert(label); + fst_array_.push_back(opts.take_ownership ? fst : fst->Copy()); + if (fst->Start() == kNoStateId) + all_non_empty = false; + if(!fst->Properties(kILabelSorted, false)) + all_ilabel_sorted = false; + if(!fst->Properties(kOLabelSorted, false)) + all_olabel_sorted = false; + inprops.push_back(fst->Properties(kCopyProperties, false)); + if (i) { + if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) { + FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i + << " does not match input symbols of base Fst (0'th fst)"; + SetProperties(kError, kError); + } + if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) { + FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i + << " does not match output symbols of base Fst " + << "(0'th fst)"; + SetProperties(kError, kError); + } + } + } + Label nonterminal = nonterminal_hash_[opts.root]; + if ((nonterminal == 0) && (fst_array_.size() > 1)) { + FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '" + << opts.root << "' in the input tuple vector"; + SetProperties(kError, kError); + } + root_ = (nonterminal > 0) ? nonterminal : 1; + + SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_, + all_non_empty)); + // We assume that all terminals are positive. The resulting + // ReplaceFst is known to be kILabelSorted when all sub-FSTs are + // kILabelSorted and one of the 3 following conditions is satisfied: + // 1. 'epsilon_on_replace' is false, or + // 2. all non-terminals are negative, or + // 3. all non-terninals are positive and form a dense range containing 1. + if (all_ilabel_sorted && + (!epsilon_on_replace_ || all_negative || dense_range)) + SetProperties(kILabelSorted, kILabelSorted); + // Similarly, the resulting ReplaceFst is known to be + // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of + // the 2 following conditions is satisfied: + // 1. all non-terminals are negative, or + // 2. all non-terninals are positive and form a dense range containing 1. + if (all_olabel_sorted && (all_negative || dense_range)) + SetProperties(kOLabelSorted, kOLabelSorted); + + // Enable optional caching as long as sorted and all non empty. + if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty) + always_cache_ = false; + else + always_cache_ = true; + VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = " + << (always_cache_ ? "true" : "false"); + } + + ReplaceFstImpl(const ReplaceFstImpl& impl) + : CacheImpl<A>(impl), + epsilon_on_replace_(impl.epsilon_on_replace_), + always_cache_(impl.always_cache_), + state_table_(new StateTable(*(impl.state_table_))), + nonterminal_set_(impl.nonterminal_set_), + nonterminal_hash_(impl.nonterminal_hash_), + root_(impl.root_) { + SetType("replace"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + fst_array_.reserve(impl.fst_array_.size()); + fst_array_.push_back(0); + for (size_t i = 1; i < impl.fst_array_.size(); ++i) { + fst_array_.push_back(impl.fst_array_[i]->Copy(true)); + } + } + + ~ReplaceFstImpl() { + VLOG(2) << "~ReplaceFstImpl: gc = " + << (CacheImpl<A>::GetCacheGc() ? "true" : "false") + << ", gc_size = " << CacheImpl<A>::GetCacheSize() + << ", gc_limit = " << CacheImpl<A>::GetCacheLimit(); + + delete state_table_; + for (size_t i = 1; i < fst_array_.size(); ++i) { + delete fst_array_[i]; + } + } + + // Computes the dependency graph of the replace class and returns + // true if the dependencies are cyclic. Cyclic dependencies will result + // in an un-expandable replace fst. + bool CyclicDependencies() const { + ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_); + return replace_util.CyclicDependencies(); + } + + // Return or compute start state of replace fst + StateId Start() { + if (!HasStart()) { + if (fst_array_.size() == 1) { // no fsts defined for replace + SetStart(kNoStateId); + return kNoStateId; + } else { + const Fst<A>* fst = fst_array_[root_]; + StateId fst_start = fst->Start(); + if (fst_start == kNoStateId) // root Fst is empty + return kNoStateId; + + PrefixId prefix = GetPrefixId(StackPrefix()); + StateId start = state_table_->FindState( + StateTuple(prefix, root_, fst_start)); + SetStart(start); + return start; + } + } else { + return CacheImpl<A>::Start(); + } + } + + // return final weight of state (kInfWeight means state is not final) + Weight Final(StateId s) { + if (!HasFinal(s)) { + const StateTuple& tuple = state_table_->Tuple(s); + const StackPrefix& stack = stackprefix_array_[tuple.prefix_id]; + const Fst<A>* fst = fst_array_[tuple.fst_id]; + StateId fst_state = tuple.fst_state; + + if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0) + SetFinal(s, fst->Final(fst_state)); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (HasArcs(s)) { // If state cached, use the cached value. + return CacheImpl<A>::NumArcs(s); + } else if (always_cache_) { // If always caching, expand and cache state. + Expand(s); + return CacheImpl<A>::NumArcs(s); + } else { // Otherwise compute the number of arcs without expanding. + StateTuple tuple = state_table_->Tuple(s); + if (tuple.fst_state == kNoStateId) + return 0; + + const Fst<A>* fst = fst_array_[tuple.fst_id]; + size_t num_arcs = fst->NumArcs(tuple.fst_state); + if (ComputeFinalArc(tuple, 0)) + num_arcs++; + + return num_arcs; + } + } + + // Returns whether a given label is a non terminal + bool IsNonTerminal(Label l) const { + // TODO(allauzen): be smarter and take advantage of + // all_dense or all_negative. + // Use also in ComputeArc, this would require changes to replace + // so that recursing into an empty fst lead to a non co-accessible + // state instead of deleting the arc as done currently. + // Current use correct, since i/olabel sorted iff all_non_empty. + typename NonTerminalHash::const_iterator it = + nonterminal_hash_.find(l); + return it != nonterminal_hash_.end(); + } + + size_t NumInputEpsilons(StateId s) { + if (HasArcs(s)) { + // If state cached, use the cached value. + return CacheImpl<A>::NumInputEpsilons(s); + } else if (always_cache_ || !Properties(kILabelSorted)) { + // If always caching or if the number of input epsilons is too expensive + // to compute without caching (i.e. not ilabel sorted), + // then expand and cache state. + Expand(s); + return CacheImpl<A>::NumInputEpsilons(s); + } else { + // Otherwise, compute the number of input epsilons without caching. + StateTuple tuple = state_table_->Tuple(s); + if (tuple.fst_state == kNoStateId) + return 0; + const Fst<A>* fst = fst_array_[tuple.fst_id]; + size_t num = 0; + if (!epsilon_on_replace_) { + // If epsilon_on_replace is false, all input epsilon arcs + // are also input epsilons arcs in the underlying machine. + fst->NumInputEpsilons(tuple.fst_state); + } else { + // Otherwise, one need to consider that all non-terminal arcs + // in the underlying machine also become input epsilon arc. + ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state); + for (; !aiter.Done() && + ((aiter.Value().ilabel == 0) || + IsNonTerminal(aiter.Value().olabel)); + aiter.Next()) + ++num; + } + if (ComputeFinalArc(tuple, 0)) + num++; + return num; + } + } + + size_t NumOutputEpsilons(StateId s) { + if (HasArcs(s)) { + // If state cached, use the cached value. + return CacheImpl<A>::NumOutputEpsilons(s); + } else if(always_cache_ || !Properties(kOLabelSorted)) { + // If always caching or if the number of output epsilons is too expensive + // to compute without caching (i.e. not olabel sorted), + // then expand and cache state. + Expand(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } else { + // Otherwise, compute the number of output epsilons without caching. + StateTuple tuple = state_table_->Tuple(s); + if (tuple.fst_state == kNoStateId) + return 0; + const Fst<A>* fst = fst_array_[tuple.fst_id]; + size_t num = 0; + ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state); + for (; !aiter.Done() && + ((aiter.Value().olabel == 0) || + IsNonTerminal(aiter.Value().olabel)); + aiter.Next()) + ++num; + if (ComputeFinalArc(tuple, 0)) + num++; + return num; + } + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if (mask & kError) { + for (size_t i = 1; i < fst_array_.size(); ++i) { + if (fst_array_[i]->Properties(kError, false)) + SetProperties(kError, kError); + } + } + return FstImpl<Arc>::Properties(mask); + } + + // return the base arc iterator, if arcs have not been computed yet, + // extend/recurse for new arcs. + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<A>::InitArcIterator(s, data); + // TODO(allauzen): Set behaviour of generic iterator + // Warning: ArcIterator<ReplaceFst<A> >::InitCache() + // relies on current behaviour. + } + + + // Extend current state (walk arcs one level deep) + void Expand(StateId s) { + StateTuple tuple = state_table_->Tuple(s); + + // If local fst is empty + if (tuple.fst_state == kNoStateId) { + SetArcs(s); + return; + } + + ArcIterator< Fst<A> > aiter( + *(fst_array_[tuple.fst_id]), tuple.fst_state); + Arc arc; + + // Create a final arc when needed + if (ComputeFinalArc(tuple, &arc)) + PushArc(s, arc); + + // Expand all arcs leaving the state + for (;!aiter.Done(); aiter.Next()) { + if (ComputeArc(tuple, aiter.Value(), &arc)) + PushArc(s, arc); + } + + SetArcs(s); + } + + void Expand(StateId s, const StateTuple &tuple, + const ArcIteratorData<A> &data) { + // If local fst is empty + if (tuple.fst_state == kNoStateId) { + SetArcs(s); + return; + } + + ArcIterator< Fst<A> > aiter(data); + Arc arc; + + // Create a final arc when needed + if (ComputeFinalArc(tuple, &arc)) + AddArc(s, arc); + + // Expand all arcs leaving the state + for (; !aiter.Done(); aiter.Next()) { + if (ComputeArc(tuple, aiter.Value(), &arc)) + AddArc(s, arc); + } + + SetArcs(s); + } + + // If arcp == 0, only returns if a final arc is required, does not + // actually compute it. + bool ComputeFinalArc(const StateTuple &tuple, A* arcp, + uint32 flags = kArcValueFlags) { + const Fst<A>* fst = fst_array_[tuple.fst_id]; + StateId fst_state = tuple.fst_state; + if (fst_state == kNoStateId) + return false; + + // if state is final, pop up stack + const StackPrefix& stack = stackprefix_array_[tuple.prefix_id]; + if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) { + if (arcp) { + arcp->ilabel = 0; + arcp->olabel = 0; + if (flags & kArcNextStateValue) { + PrefixId prefix_id = PopPrefix(stack); + const PrefixTuple& top = stack.Top(); + arcp->nextstate = state_table_->FindState( + StateTuple(prefix_id, top.fst_id, top.nextstate)); + } + if (flags & kArcWeightValue) + arcp->weight = fst->Final(fst_state); + } + return true; + } else { + return false; + } + } + + // Compute the arc in the replace fst corresponding to a given + // in the underlying machine. Returns false if the underlying arc + // corresponds to no arc in the replace. + bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp, + uint32 flags = kArcValueFlags) { + if (!epsilon_on_replace_ && + (flags == (flags & (kArcILabelValue | kArcWeightValue)))) { + *arcp = arc; + return true; + } + + if (arc.olabel == 0) { // expand local fst + StateId nextstate = flags & kArcNextStateValue + ? state_table_->FindState( + StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)) + : kNoStateId; + *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate); + } else { + // check for non terminal + typename NonTerminalHash::const_iterator it = + nonterminal_hash_.find(arc.olabel); + if (it != nonterminal_hash_.end()) { // recurse into non terminal + Label nonterminal = it->second; + const Fst<A>* nt_fst = fst_array_[nonterminal]; + PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id], + tuple.fst_id, arc.nextstate); + + // if start state is valid replace, else arc is implicitly + // deleted + StateId nt_start = nt_fst->Start(); + if (nt_start != kNoStateId) { + StateId nt_nextstate = flags & kArcNextStateValue + ? state_table_->FindState( + StateTuple(nt_prefix, nonterminal, nt_start)) + : kNoStateId; + Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel; + *arcp = A(ilabel, 0, arc.weight, nt_nextstate); + } else { + return false; + } + } else { + StateId nextstate = flags & kArcNextStateValue + ? state_table_->FindState( + StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)) + : kNoStateId; + *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate); + } + } + return true; + } + + // Returns the arc iterator flags supported by this Fst. + uint32 ArcIteratorFlags() const { + uint32 flags = kArcValueFlags; + if (!always_cache_) + flags |= kArcNoCache; + return flags; + } + + T* GetStateTable() const { + return state_table_; + } + + const Fst<A>* GetFst(Label fst_id) const { + return fst_array_[fst_id]; + } + + bool EpsilonOnReplace() const { return epsilon_on_replace_; } + + // private helper classes + private: + static const size_t kPrime0; + + // \class PrefixTuple + // \brief Tuple of fst_id and destination state (entry in stack prefix) + struct PrefixTuple { + PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {} + + Label fst_id; + StateId nextstate; + }; + + // \class StackPrefix + // \brief Container for stack prefix. + class StackPrefix { + public: + StackPrefix() {} + + // copy constructor + StackPrefix(const StackPrefix& x) : + prefix_(x.prefix_) { + } + + void Push(StateId fst_id, StateId nextstate) { + prefix_.push_back(PrefixTuple(fst_id, nextstate)); + } + + void Pop() { + prefix_.pop_back(); + } + + const PrefixTuple& Top() const { + return prefix_[prefix_.size()-1]; + } + + size_t Depth() const { + return prefix_.size(); + } + + public: + vector<PrefixTuple> prefix_; + }; + + + // \class StackPrefixEqual + // \brief Compare two stack prefix classes for equality + class StackPrefixEqual { + public: + bool operator()(const StackPrefix& x, const StackPrefix& y) const { + if (x.prefix_.size() != y.prefix_.size()) return false; + for (size_t i = 0; i < x.prefix_.size(); ++i) { + if (x.prefix_[i].fst_id != y.prefix_[i].fst_id || + x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false; + } + return true; + } + }; + + // + // \class StackPrefixKey + // \brief Hash function for stack prefix to prefix id + class StackPrefixKey { + public: + size_t operator()(const StackPrefix& x) const { + size_t sum = 0; + for (size_t i = 0; i < x.prefix_.size(); ++i) { + sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0; + } + return sum; + } + }; + + typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual> + StackPrefixHash; + + // private methods + private: + // hash stack prefix (return unique index into stackprefix array) + PrefixId GetPrefixId(const StackPrefix& prefix) { + typename StackPrefixHash::iterator it = prefix_hash_.find(prefix); + if (it == prefix_hash_.end()) { + PrefixId prefix_id = stackprefix_array_.size(); + stackprefix_array_.push_back(prefix); + prefix_hash_[prefix] = prefix_id; + return prefix_id; + } else { + return it->second; + } + } + + // prefix id after a stack pop + PrefixId PopPrefix(StackPrefix prefix) { + prefix.Pop(); + return GetPrefixId(prefix); + } + + // prefix id after a stack push + PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) { + prefix.Push(fst_id, nextstate); + return GetPrefixId(prefix); + } + + + // private data + private: + // runtime options + bool epsilon_on_replace_; + bool always_cache_; // Optionally caching arc iterator disabled when true + + // state table + StateTable *state_table_; + + // cross index of unique stack prefix + // could potentially have one copy of prefix array + StackPrefixHash prefix_hash_; + vector<StackPrefix> stackprefix_array_; + + set<Label> nonterminal_set_; + NonTerminalHash nonterminal_hash_; + vector<const Fst<A>*> fst_array_; + Label root_; + + void operator=(const ReplaceFstImpl<A, T> &); // disallow +}; + + +template <class A, class T> +const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853; + +// +// \class ReplaceFst +// \brief Recursivively replaces arcs in the root Fst with other Fsts. +// This version is a delayed Fst. +// +// ReplaceFst supports dynamic replacement of arcs in one Fst with +// another Fst. This replacement is recursive. ReplaceFst can be used +// to support a variety of delayed constructions such as recursive +// transition networks, union, or closure. It is constructed with an +// array of Fst(s). One Fst represents the root (or topology) +// machine. The root Fst refers to other Fsts by recursively replacing +// arcs labeled as non-terminals with the matching non-terminal +// Fst. Currently the ReplaceFst uses the output symbols of the arcs +// to determine whether the arc is a non-terminal arc or not. A +// non-terminal can be any label that is not a non-zero terminal label +// in the output alphabet. +// +// Note that the constructor uses a vector of pair<>. These correspond +// to the tuple of non-terminal Label and corresponding Fst. For example +// to implement the closure operation we need 2 Fsts. The first root +// Fst is a single Arc on the start State that self loops, it references +// the particular machine for which we are performing the closure operation. +// +// The ReplaceFst class supports an optionally caching arc iterator: +// ArcIterator< ReplaceFst<A> > +// The ReplaceFst need to be built such that it is known to be ilabel +// or olabel sorted (see usage below). +// +// Observe that Matcher<Fst<A> > will use the optionally caching arc +// iterator when available (Fst is ilabel sorted and matching on the +// input, or Fst is olabel sorted and matching on the output). +// In order to obtain the most efficient behaviour, it is recommended +// to set 'epsilon_on_replace' to false (this means constructing acceptors +// as transducers with epsilons on the input side of nonterminal arcs) +// and matching on the input side. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A, class T = DefaultReplaceStateTable<A> > +class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > { + public: + friend class ArcIterator< ReplaceFst<A, T> >; + friend class StateIterator< ReplaceFst<A, T> >; + friend class ReplaceFstMatcher<A, T>; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef ReplaceFstImpl<A, T> Impl; + + using ImplToFst<Impl>::Properties; + + ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array, + Label root) + : ImplToFst<Impl>(new Impl(fst_array, ReplaceFstOptions<A, T>(root))) {} + + ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array, + const ReplaceFstOptions<A, T> &opts) + : ImplToFst<Impl>(new Impl(fst_array, opts)) {} + + // See Fst<>::Copy() for doc. + ReplaceFst(const ReplaceFst<A, T>& fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc. + virtual ReplaceFst<A, T> *Copy(bool safe = false) const { + return new ReplaceFst<A, T>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + virtual MatcherBase<A> *InitMatcher(MatchType match_type) const { + if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) && + ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) || + (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) { + return new ReplaceFstMatcher<A, T>(*this, match_type); + } + else { + VLOG(2) << "Not using replace matcher"; + return 0; + } + } + + bool CyclicDependencies() const { + return GetImpl()->CyclicDependencies(); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const ReplaceFst<A> &fst); // disallow +}; + + +// Specialization for ReplaceFst. +template<class A, class T> +class StateIterator< ReplaceFst<A, T> > + : public CacheStateIterator< ReplaceFst<A, T> > { + public: + explicit StateIterator(const ReplaceFst<A, T> &fst) + : CacheStateIterator< ReplaceFst<A, T> >(fst, fst.GetImpl()) {} + + private: + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + + +// Specialization for ReplaceFst. +// Implements optional caching. It can be used as follows: +// +// ReplaceFst<A> replace; +// ArcIterator< ReplaceFst<A> > aiter(replace, s); +// // Note: ArcIterator< Fst<A> > is always a caching arc iterator. +// aiter.SetFlags(kArcNoCache, kArcNoCache); +// // Use the arc iterator, no arc will be cached, no state will be expanded. +// // The varied 'kArcValueFlags' can be used to decide which part +// // of arc values needs to be computed. +// aiter.SetFlags(kArcILabelValue, kArcValueFlags); +// // Only want the ilabel for this arc +// aiter.Value(); // Does not compute the destination state. +// aiter.Next(); +// aiter.SetFlags(kArcNextStateValue, kArcNextStateValue); +// // Want both ilabel and nextstate for that arc +// aiter.Value(); // Does compute the destination state and inserts it +// // in the replace state table. +// // No Arc has been cached at that point. +// +template <class A, class T> +class ArcIterator< ReplaceFst<A, T> > { + public: + typedef A Arc; + typedef typename A::StateId StateId; + + ArcIterator(const ReplaceFst<A, T> &fst, StateId s) + : fst_(fst), state_(s), pos_(0), offset_(0), flags_(0), arcs_(0), + data_flags_(0), final_flags_(0) { + cache_data_.ref_count = 0; + local_data_.ref_count = 0; + + // If FST does not support optional caching, force caching. + if(!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) && + !(fst_.GetImpl()->HasArcs(state_))) + fst_.GetImpl()->Expand(state_); + + // If state is already cached, use cached arcs array. + if (fst_.GetImpl()->HasArcs(state_)) { + (fst_.GetImpl())->template CacheImpl<A>::InitArcIterator(state_, + &cache_data_); + num_arcs_ = cache_data_.narcs; + arcs_ = cache_data_.arcs; // 'arcs_' is a ptr to the cached arcs. + data_flags_ = kArcValueFlags; // All the arc member values are valid. + } else { // Otherwise delay decision until Value() is called. + tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_); + if (tuple_.fst_state == kNoStateId) { + num_arcs_ = 0; + } else { + // The decision to cache or not to cache has been defered + // until Value() or SetFlags() is called. However, the arc + // iterator is set up now to be ready for non-caching in order + // to keep the Value() method simple and efficient. + const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id); + fst->InitArcIterator(tuple_.fst_state, &local_data_); + // 'arcs_' is a pointer to the arcs in the underlying machine. + arcs_ = local_data_.arcs; + // Compute the final arc (but not its destination state) + // if a final arc is required. + bool has_final_arc = fst_.GetImpl()->ComputeFinalArc( + tuple_, + &final_arc_, + kArcValueFlags & ~kArcNextStateValue); + // Set the arc value flags that hold for 'final_arc_'. + final_flags_ = kArcValueFlags & ~kArcNextStateValue; + // Compute the number of arcs. + num_arcs_ = local_data_.narcs; + if (has_final_arc) + ++num_arcs_; + // Set the offset between the underlying arc positions and + // the positions in the arc iterator. + offset_ = num_arcs_ - local_data_.narcs; + // Defers the decision to cache or not until Value() or + // SetFlags() is called. + data_flags_ = 0; + } + } + } + + ~ArcIterator() { + if (cache_data_.ref_count) + --(*cache_data_.ref_count); + if (local_data_.ref_count) + --(*local_data_.ref_count); + } + + void ExpandAndCache() const { + // TODO(allauzen): revisit this + // fst_.GetImpl()->Expand(state_, tuple_, local_data_); + // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_, + // &cache_data_); + // + fst_.InitArcIterator(state_, &cache_data_); // Expand and cache state. + arcs_ = cache_data_.arcs; // 'arcs_' is a pointer to the cached arcs. + data_flags_ = kArcValueFlags; // All the arc member values are valid. + offset_ = 0; // No offset + + } + + void Init() { + if (flags_ & kArcNoCache) { // If caching is disabled + // 'arcs_' is a pointer to the arcs in the underlying machine. + arcs_ = local_data_.arcs; + // Set the arcs value flags that hold for 'arcs_'. + data_flags_ = kArcWeightValue; + if (!fst_.GetImpl()->EpsilonOnReplace()) + data_flags_ |= kArcILabelValue; + // Set the offset between the underlying arc positions and + // the positions in the arc iterator. + offset_ = num_arcs_ - local_data_.narcs; + } else { // Otherwise, expand and cache + ExpandAndCache(); + } + } + + bool Done() const { return pos_ >= num_arcs_; } + + const A& Value() const { + // If 'data_flags_' was set to 0, non-caching was not requested + if (!data_flags_) { + // TODO(allauzen): revisit this. + if (flags_ & kArcNoCache) { + // Should never happen. + FSTERROR() << "ReplaceFst: inconsistent arc iterator flags"; + } + ExpandAndCache(); // Expand and cache. + } + + if (pos_ - offset_ >= 0) { // The requested arc is not the 'final' arc. + const A& arc = arcs_[pos_ - offset_]; + if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) { + // If the value flags for 'arc' match the recquired value flags + // then return 'arc'. + return arc; + } else { + // Otherwise, compute the corresponding arc on-the-fly. + fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags); + return arc_; + } + } else { // The requested arc is the 'final' arc. + if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) { + // If the arc value flags that hold for the final arc + // do not match the requested value flags, then + // 'final_arc_' needs to be updated. + fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_, + flags_ & kArcValueFlags); + final_flags_ = flags_ & kArcValueFlags; + } + return final_arc_; + } + } + + void Next() { ++pos_; } + + size_t Position() const { return pos_; } + + void Reset() { pos_ = 0; } + + void Seek(size_t pos) { pos_ = pos; } + + uint32 Flags() const { return flags_; } + + void SetFlags(uint32 f, uint32 mask) { + // Update the flags taking into account what flags are supported + // by the Fst. + flags_ &= ~mask; + flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags()); + // If non-caching is not requested (and caching has not already + // been performed), then flush 'data_flags_' to request caching + // during the next call to Value(). + if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) { + if (!fst_.GetImpl()->HasArcs(state_)) + data_flags_ = 0; + } + // If 'data_flags_' has been flushed but non-caching is requested + // before calling Value(), then set up the iterator for non-caching. + if ((f & kArcNoCache) && (!data_flags_)) + Init(); + } + + private: + const ReplaceFst<A, T> &fst_; // Reference to the FST + StateId state_; // State in the FST + mutable typename T::StateTuple tuple_; // Tuple corresponding to state_ + + ssize_t pos_; // Current position + mutable ssize_t offset_; // Offset between position in iterator and in arcs_ + ssize_t num_arcs_; // Number of arcs at state_ + uint32 flags_; // Behavorial flags for the arc iterator + mutable Arc arc_; // Memory to temporarily store computed arcs + + mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache + mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local fst + + mutable const A* arcs_; // Array of arcs + mutable uint32 data_flags_; // Arc value flags valid for data in arcs_ + mutable Arc final_arc_; // Final arc (when required) + mutable uint32 final_flags_; // Arc value flags valid for final_arc_ + + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + + +template <class A, class T> +class ReplaceFstMatcher : public MatcherBase<A> { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher; + + ReplaceFstMatcher(const ReplaceFst<A, T> &fst, fst::MatchType match_type) + : fst_(fst), + impl_(fst_.GetImpl()), + s_(fst::kNoStateId), + match_type_(match_type), + current_loop_(false), + final_arc_(false), + loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) { + if (match_type_ == fst::MATCH_OUTPUT) + swap(loop_.ilabel, loop_.olabel); + InitMatchers(); + } + + ReplaceFstMatcher(const ReplaceFstMatcher<A, T> &matcher, bool safe = false) + : fst_(matcher.fst_), + impl_(fst_.GetImpl()), + s_(fst::kNoStateId), + match_type_(matcher.match_type_), + current_loop_(false), + loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) { + if (match_type_ == fst::MATCH_OUTPUT) + swap(loop_.ilabel, loop_.olabel); + InitMatchers(); + } + + // Create a local matcher for each component Fst of replace. + // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsilonMatcher + // is used to match each non-terminal arc, since these non-terminal + // turn into epsilons on recursion. + void InitMatchers() { + const vector<const Fst<A>*>& fst_array = impl_->fst_array_; + matcher_.resize(fst_array.size(), 0); + for (size_t i = 0; i < fst_array.size(); ++i) { + if (fst_array[i]) { + matcher_[i] = + new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList); + + typename set<Label>::iterator it = impl_->nonterminal_set_.begin(); + for (; it != impl_->nonterminal_set_.end(); ++it) { + matcher_[i]->AddMultiEpsLabel(*it); + } + } + } + } + + virtual ReplaceFstMatcher<A, T> *Copy(bool safe = false) const { + return new ReplaceFstMatcher<A, T>(*this, safe); + } + + virtual ~ReplaceFstMatcher() { + for (size_t i = 0; i < matcher_.size(); ++i) + delete matcher_[i]; + } + + virtual MatchType Type(bool test) const { + if (match_type_ == MATCH_NONE) + return match_type_; + + uint64 true_prop = match_type_ == MATCH_INPUT ? + kILabelSorted : kOLabelSorted; + uint64 false_prop = match_type_ == MATCH_INPUT ? + kNotILabelSorted : kNotOLabelSorted; + uint64 props = fst_.Properties(true_prop | false_prop, test); + + if (props & true_prop) + return match_type_; + else if (props & false_prop) + return MATCH_NONE; + else + return MATCH_UNKNOWN; + } + + virtual const Fst<A> &GetFst() const { + return fst_; + } + + virtual uint64 Properties(uint64 props) const { + return props; + } + + private: + // Set the sate from which our matching happens. + virtual void SetState_(StateId s) { + if (s_ == s) return; + + s_ = s; + tuple_ = impl_->GetStateTable()->Tuple(s_); + if (tuple_.fst_state == kNoStateId) { + done_ = true; + return; + } + // Get current matcher. Used for non epsilon matching + current_matcher_ = matcher_[tuple_.fst_id]; + current_matcher_->SetState(tuple_.fst_state); + loop_.nextstate = s_; + + final_arc_ = false; + } + + // Search for label, from previous set state. If label == 0, first + // hallucinate and epsilon loop, else use the underlying matcher to + // search for the label or epsilons. + // - Note since the ReplaceFST recursion on non-terminal arcs causes + // epsilon transitions to be created we use the MultiEpsilonMatcher + // to search for possible matches of non terminals. + // - If the component Fst reaches a final state we also need to add + // the exiting final arc. + virtual bool Find_(Label label) { + bool found = false; + label_ = label; + if (label_ == 0 || label_ == kNoLabel) { + // Compute loop directly, saving Replace::ComputeArc + if (label_ == 0) { + current_loop_ = true; + found = true; + } + // Search for matching multi epsilons + final_arc_ = impl_->ComputeFinalArc(tuple_, 0); + found = current_matcher_->Find(kNoLabel) || final_arc_ || found; + } else { + // Search on sub machine directly using sub machine matcher. + found = current_matcher_->Find(label_); + } + return found; + } + + virtual bool Done_() const { + return !current_loop_ && !final_arc_ && current_matcher_->Done(); + } + + virtual const Arc& Value_() const { + if (current_loop_) { + return loop_; + } + if (final_arc_) { + impl_->ComputeFinalArc(tuple_, &arc_); + return arc_; + } + const Arc& component_arc = current_matcher_->Value(); + impl_->ComputeArc(tuple_, component_arc, &arc_); + return arc_; + } + + virtual void Next_() { + if (current_loop_) { + current_loop_ = false; + return; + } + if (final_arc_) { + final_arc_ = false; + return; + } + current_matcher_->Next(); + } + + const ReplaceFst<A, T>& fst_; + ReplaceFstImpl<A, T> *impl_; + LocalMatcher* current_matcher_; + vector<LocalMatcher*> matcher_; + + StateId s_; // Current state + Label label_; // Current label + + MatchType match_type_; // Supplied by caller + mutable bool done_; + mutable bool current_loop_; // Current arc is the implicit loop + mutable bool final_arc_; // Current arc for exiting recursion + mutable typename T::StateTuple tuple_; // Tuple corresponding to state_ + mutable Arc arc_; + Arc loop_; +}; + +template <class A, class T> inline +void ReplaceFst<A, T>::InitStateIterator(StateIteratorData<A> *data) const { + data->base = new StateIterator< ReplaceFst<A, T> >(*this); +} + +typedef ReplaceFst<StdArc> StdReplaceFst; + + +// // Recursivively replaces arcs in the root Fst with other Fsts. +// This version writes the result of replacement to an output MutableFst. +// +// Replace supports replacement of arcs in one Fst with another +// Fst. This replacement is recursive. Replace takes an array of +// Fst(s). One Fst represents the root (or topology) machine. The root +// Fst refers to other Fsts by recursively replacing arcs labeled as +// non-terminals with the matching non-terminal Fst. Currently Replace +// uses the output symbols of the arcs to determine whether the arc is +// a non-terminal arc or not. A non-terminal can be any label that is +// not a non-zero terminal label in the output alphabet. Note that +// input argument is a vector of pair<>. These correspond to the tuple +// of non-terminal Label and corresponding Fst. +template<class Arc> +void Replace(const vector<pair<typename Arc::Label, + const Fst<Arc>* > >& ifst_array, + MutableFst<Arc> *ofst, typename Arc::Label root, + bool epsilon_on_replace) { + ReplaceFstOptions<Arc> opts(root, epsilon_on_replace); + opts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = ReplaceFst<Arc>(ifst_array, opts); +} + +template<class Arc> +void Replace(const vector<pair<typename Arc::Label, + const Fst<Arc>* > >& ifst_array, + MutableFst<Arc> *ofst, typename Arc::Label root) { + Replace(ifst_array, ofst, root, false); +} + +} // namespace fst + +#endif // FST_LIB_REPLACE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/reverse.h b/kaldi_io/src/tools/openfst/include/fst/reverse.h new file mode 100644 index 0000000..4d4c75c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/reverse.h @@ -0,0 +1,91 @@ +// reverse.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Functions and classes to sort arcs in an FST. + +#ifndef FST_LIB_REVERSE_H__ +#define FST_LIB_REVERSE_H__ + +#include <algorithm> +#include <vector> +using std::vector; + +#include <fst/cache.h> + + +namespace fst { + +// Reverses an FST. The reversed result is written to an output +// MutableFst. If A transduces string x to y with weight a, then the +// reverse of A transduces the reverse of x to the reverse of y with +// weight a.Reverse(). +// +// Typically, a = a.Reverse() and Arc = RevArc (e.g. for +// TropicalWeight or LogWeight). In general, e.g. when the weights +// only form a left or right semiring, the output arc type must match +// the input arc type except having the reversed Weight type. +template<class Arc, class RevArc> +void Reverse(const Fst<Arc> &ifst, MutableFst<RevArc> *ofst) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename RevArc::Weight RevWeight; + + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst.InputSymbols()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Properties(kExpanded, false)) + ofst->ReserveStates(CountStates(ifst) + 1); + StateId istart = ifst.Start(); + StateId ostart = ofst->AddState(); + ofst->SetStart(ostart); + + for (StateIterator< Fst<Arc> > siter(ifst); + !siter.Done(); + siter.Next()) { + StateId is = siter.Value(); + StateId os = is + 1; + while (ofst->NumStates() <= os) + ofst->AddState(); + if (is == istart) + ofst->SetFinal(os, RevWeight::One()); + + Weight final = ifst.Final(is); + if (final != Weight::Zero()) { + RevArc oarc(0, 0, final.Reverse(), os); + ofst->AddArc(0, oarc); + } + + for (ArcIterator< Fst<Arc> > aiter(ifst, is); + !aiter.Done(); + aiter.Next()) { + const Arc &iarc = aiter.Value(); + RevArc oarc(iarc.ilabel, iarc.olabel, iarc.weight.Reverse(), os); + StateId nos = iarc.nextstate + 1; + while (ofst->NumStates() <= nos) + ofst->AddState(); + ofst->AddArc(nos, oarc); + } + } + uint64 iprops = ifst.Properties(kCopyProperties, false); + uint64 oprops = ofst->Properties(kFstProperties, false); + ofst->SetProperties(ReverseProperties(iprops) | oprops, kFstProperties); +} + +} // namespace fst + +#endif // FST_LIB_REVERSE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/reweight.h b/kaldi_io/src/tools/openfst/include/fst/reweight.h new file mode 100644 index 0000000..c051c2a --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/reweight.h @@ -0,0 +1,146 @@ +// reweight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Function to reweight an FST. + +#ifndef FST_LIB_REWEIGHT_H__ +#define FST_LIB_REWEIGHT_H__ + +#include <vector> +using std::vector; + +#include <fst/mutable-fst.h> + + +namespace fst { + +enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL }; + +// Reweight FST according to the potentials defined by the POTENTIAL +// vector in the direction defined by TYPE. Weight needs to be left +// distributive when reweighting towards the initial state and right +// distributive when reweighting towards the final states. +// +// An arc of weight w, with an origin state of potential p and +// destination state of potential q, is reweighted by p\wq when +// reweighting towards the initial state and by pw/q when reweighting +// towards the final states. +template <class Arc> +void Reweight(MutableFst<Arc> *fst, + const vector<typename Arc::Weight> &potential, + ReweightType type) { + typedef typename Arc::Weight Weight; + + if (fst->NumStates() == 0) + return; + + if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) { + FSTERROR() << "Reweight: Reweighting to the final states requires " + << "Weight to be right distributive: " + << Weight::Type(); + fst->SetProperties(kError, kError); + return; + } + + if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) { + FSTERROR() << "Reweight: Reweighting to the initial state requires " + << "Weight to be left distributive: " + << Weight::Type(); + fst->SetProperties(kError, kError); + return; + } + + StateIterator< MutableFst<Arc> > sit(*fst); + for (; !sit.Done(); sit.Next()) { + typename Arc::StateId state = sit.Value(); + if (state == potential.size()) + break; + typename Arc::Weight weight = potential[state]; + if (weight != Weight::Zero()) { + for (MutableArcIterator< MutableFst<Arc> > ait(fst, state); + !ait.Done(); + ait.Next()) { + Arc arc = ait.Value(); + if (arc.nextstate >= potential.size()) + continue; + typename Arc::Weight nextweight = potential[arc.nextstate]; + if (nextweight == Weight::Zero()) + continue; + if (type == REWEIGHT_TO_INITIAL) + arc.weight = Divide(Times(arc.weight, nextweight), weight, + DIVIDE_LEFT); + if (type == REWEIGHT_TO_FINAL) + arc.weight = Divide(Times(weight, arc.weight), nextweight, + DIVIDE_RIGHT); + ait.SetValue(arc); + } + if (type == REWEIGHT_TO_INITIAL) + fst->SetFinal(state, Divide(fst->Final(state), weight, DIVIDE_LEFT)); + } + if (type == REWEIGHT_TO_FINAL) + fst->SetFinal(state, Times(weight, fst->Final(state))); + } + + // This handles elements past the end of the potentials array. + for (; !sit.Done(); sit.Next()) { + typename Arc::StateId state = sit.Value(); + if (type == REWEIGHT_TO_FINAL) + fst->SetFinal(state, Times(Weight::Zero(), fst->Final(state))); + } + + typename Arc::Weight startweight = fst->Start() < potential.size() ? + potential[fst->Start()] : Weight::Zero(); + if ((startweight != Weight::One()) && (startweight != Weight::Zero())) { + if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) { + typename Arc::StateId state = fst->Start(); + for (MutableArcIterator< MutableFst<Arc> > ait(fst, state); + !ait.Done(); + ait.Next()) { + Arc arc = ait.Value(); + if (type == REWEIGHT_TO_INITIAL) + arc.weight = Times(startweight, arc.weight); + else + arc.weight = Times( + Divide(Weight::One(), startweight, DIVIDE_RIGHT), + arc.weight); + ait.SetValue(arc); + } + if (type == REWEIGHT_TO_INITIAL) + fst->SetFinal(state, Times(startweight, fst->Final(state))); + else + fst->SetFinal(state, Times(Divide(Weight::One(), startweight, + DIVIDE_RIGHT), + fst->Final(state))); + } else { + typename Arc::StateId state = fst->AddState(); + Weight w = type == REWEIGHT_TO_INITIAL ? startweight : + Divide(Weight::One(), startweight, DIVIDE_RIGHT); + Arc arc(0, 0, w, fst->Start()); + fst->AddArc(state, arc); + fst->SetStart(state); + } + } + + fst->SetProperties(ReweightProperties( + fst->Properties(kFstProperties, false)), + kFstProperties); +} + +} // namespace fst + +#endif // FST_LIB_REWEIGHT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/rmepsilon.h b/kaldi_io/src/tools/openfst/include/fst/rmepsilon.h new file mode 100644 index 0000000..89b8178 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/rmepsilon.h @@ -0,0 +1,600 @@ +// rmepsilon.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Functions and classes that implemement epsilon-removal. + +#ifndef FST_LIB_RMEPSILON_H__ +#define FST_LIB_RMEPSILON_H__ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <fst/slist.h> +#include <stack> +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/arcfilter.h> +#include <fst/cache.h> +#include <fst/connect.h> +#include <fst/factor-weight.h> +#include <fst/invert.h> +#include <fst/prune.h> +#include <fst/queue.h> +#include <fst/shortest-distance.h> +#include <fst/topsort.h> + + +namespace fst { + +template <class Arc, class Queue> +class RmEpsilonOptions + : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > { + public: + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + bool connect; // Connect output + Weight weight_threshold; // Pruning weight threshold. + StateId state_threshold; // Pruning state threshold. + + explicit RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true, + Weight w = Weight::Zero(), + StateId n = kNoStateId) + : ShortestDistanceOptions< Arc, Queue, EpsilonArcFilter<Arc> >( + q, EpsilonArcFilter<Arc>(), kNoStateId, d), + connect(c), weight_threshold(w), state_threshold(n) {} + private: + RmEpsilonOptions(); // disallow +}; + +// Computation state of the epsilon-removal algorithm. +template <class Arc, class Queue> +class RmEpsilonState { + public: + typedef typename Arc::Label Label; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + RmEpsilonState(const Fst<Arc> &fst, + vector<Weight> *distance, + const RmEpsilonOptions<Arc, Queue> &opts) + : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true), + expand_id_(0) {} + + // Compute arcs and final weight for state 's' + void Expand(StateId s); + + // Returns arcs of expanded state. + vector<Arc> &Arcs() { return arcs_; } + + // Returns final weight of expanded state. + const Weight &Final() const { return final_; } + + // Return true if an error has occured. + bool Error() const { return sd_state_.Error(); } + + private: + static const size_t kPrime0 = 7853; + static const size_t kPrime1 = 7867; + + struct Element { + Label ilabel; + Label olabel; + StateId nextstate; + + Element() {} + + Element(Label i, Label o, StateId s) + : ilabel(i), olabel(o), nextstate(s) {} + }; + + class ElementKey { + public: + size_t operator()(const Element& e) const { + return static_cast<size_t>(e.nextstate + + e.ilabel * kPrime0 + + e.olabel * kPrime1); + } + + private: + }; + + class ElementEqual { + public: + bool operator()(const Element &e1, const Element &e2) const { + return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) + && (e1.nextstate == e2.nextstate); + } + }; + + typedef unordered_map<Element, pair<StateId, size_t>, + ElementKey, ElementEqual> ElementMap; + + const Fst<Arc> &fst_; + // Distance from state being expanded in epsilon-closure. + vector<Weight> *distance_; + // Shortest distance algorithm computation state. + ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_; + // Maps an element 'e' to a pair 'p' corresponding to a position + // in the arcs vector of the state being expanded. 'e' corresponds + // to the position 'p.second' in the 'arcs_' vector if 'p.first' is + // equal to the state being expanded. + ElementMap element_map_; + EpsilonArcFilter<Arc> eps_filter_; + stack<StateId> eps_queue_; // Queue used to visit the epsilon-closure + vector<bool> visited_; // '[i] = true' if state 'i' has been visited + slist<StateId> visited_states_; // List of visited states + vector<Arc> arcs_; // Arcs of state being expanded + Weight final_; // Final weight of state being expanded + StateId expand_id_; // Unique ID for each call to Expand + + DISALLOW_COPY_AND_ASSIGN(RmEpsilonState); +}; + +template <class Arc, class Queue> +const size_t RmEpsilonState<Arc, Queue>::kPrime0; +template <class Arc, class Queue> +const size_t RmEpsilonState<Arc, Queue>::kPrime1; + + +template <class Arc, class Queue> +void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) { + final_ = Weight::Zero(); + arcs_.clear(); + sd_state_.ShortestDistance(source); + if (sd_state_.Error()) + return; + eps_queue_.push(source); + + while (!eps_queue_.empty()) { + StateId state = eps_queue_.top(); + eps_queue_.pop(); + + while (visited_.size() <= state) visited_.push_back(false); + if (visited_[state]) continue; + visited_[state] = true; + visited_states_.push_front(state); + + for (ArcIterator< Fst<Arc> > ait(fst_, state); + !ait.Done(); + ait.Next()) { + Arc arc = ait.Value(); + arc.weight = Times((*distance_)[state], arc.weight); + + if (eps_filter_(arc)) { + while (visited_.size() <= arc.nextstate) + visited_.push_back(false); + if (!visited_[arc.nextstate]) + eps_queue_.push(arc.nextstate); + } else { + Element element(arc.ilabel, arc.olabel, arc.nextstate); + typename ElementMap::iterator it = element_map_.find(element); + if (it == element_map_.end()) { + element_map_.insert( + pair<Element, pair<StateId, size_t> > + (element, pair<StateId, size_t>(expand_id_, arcs_.size()))); + arcs_.push_back(arc); + } else { + if (((*it).second).first == expand_id_) { + Weight &w = arcs_[((*it).second).second].weight; + w = Plus(w, arc.weight); + } else { + ((*it).second).first = expand_id_; + ((*it).second).second = arcs_.size(); + arcs_.push_back(arc); + } + } + } + } + final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state))); + } + + while (!visited_states_.empty()) { + visited_[visited_states_.front()] = false; + visited_states_.pop_front(); + } + ++expand_id_; +} + +// Removes epsilon-transitions (when both the input and output label +// are an epsilon) from a transducer. The result will be an equivalent +// FST that has no such epsilon transitions. This version modifies +// its input. It allows fine control via the options argument; see +// below for a simpler interface. +// +// The vector 'distance' will be used to hold the shortest distances +// during the epsilon-closure computation. The state queue discipline +// and convergence delta are taken in the options argument. +template <class Arc, class Queue> +void RmEpsilon(MutableFst<Arc> *fst, + vector<typename Arc::Weight> *distance, + const RmEpsilonOptions<Arc, Queue> &opts) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + if (fst->Start() == kNoStateId) { + return; + } + + // 'noneps_in[s]' will be set to true iff 's' admits a non-epsilon + // incoming transition or is the start state. + vector<bool> noneps_in(fst->NumStates(), false); + noneps_in[fst->Start()] = true; + for (StateId i = 0; i < fst->NumStates(); ++i) { + for (ArcIterator<Fst<Arc> > aiter(*fst, i); + !aiter.Done(); + aiter.Next()) { + if (aiter.Value().ilabel != 0 || aiter.Value().olabel != 0) + noneps_in[aiter.Value().nextstate] = true; + } + } + + // States sorted in topological order when (acyclic) or generic + // topological order (cyclic). + vector<StateId> states; + states.reserve(fst->NumStates()); + + if (fst->Properties(kTopSorted, false) & kTopSorted) { + for (StateId i = 0; i < fst->NumStates(); i++) + states.push_back(i); + } else if (fst->Properties(kAcyclic, false) & kAcyclic) { + vector<StateId> order; + bool acyclic; + TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); + DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>()); + // Sanity check: should be acyclic if property bit is set. + if(!acyclic) { + FSTERROR() << "RmEpsilon: inconsistent acyclic property bit"; + fst->SetProperties(kError, kError); + return; + } + states.resize(order.size()); + for (StateId i = 0; i < order.size(); i++) + states[order[i]] = i; + } else { + uint64 props; + vector<StateId> scc; + SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props); + DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>()); + vector<StateId> first(scc.size(), kNoStateId); + vector<StateId> next(scc.size(), kNoStateId); + for (StateId i = 0; i < scc.size(); i++) { + if (first[scc[i]] != kNoStateId) + next[i] = first[scc[i]]; + first[scc[i]] = i; + } + for (StateId i = 0; i < first.size(); i++) + for (StateId j = first[i]; j != kNoStateId; j = next[j]) + states.push_back(j); + } + + RmEpsilonState<Arc, Queue> + rmeps_state(*fst, distance, opts); + + while (!states.empty()) { + StateId state = states.back(); + states.pop_back(); + if (!noneps_in[state]) + continue; + rmeps_state.Expand(state); + fst->SetFinal(state, rmeps_state.Final()); + fst->DeleteArcs(state); + vector<Arc> &arcs = rmeps_state.Arcs(); + fst->ReserveArcs(state, arcs.size()); + while (!arcs.empty()) { + fst->AddArc(state, arcs.back()); + arcs.pop_back(); + } + } + + for (StateId s = 0; s < fst->NumStates(); ++s) { + if (!noneps_in[s]) + fst->DeleteArcs(s); + } + + if(rmeps_state.Error()) + fst->SetProperties(kError, kError); + fst->SetProperties( + RmEpsilonProperties(fst->Properties(kFstProperties, false)), + kFstProperties); + + if (opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) + Prune(fst, opts.weight_threshold, opts.state_threshold); + if (opts.connect && (opts.weight_threshold == Weight::Zero() || + opts.state_threshold != kNoStateId)) + Connect(fst); +} + +// Removes epsilon-transitions (when both the input and output label +// are an epsilon) from a transducer. The result will be an equivalent +// FST that has no such epsilon transitions. This version modifies its +// input. It has a simplified interface; see above for a version that +// allows finer control. +// +// Complexity: +// - Time: +// - Unweighted: O(V2 + V E) +// - Acyclic: O(V2 + V E) +// - Tropical semiring: O(V2 log V + V E) +// - General: exponential +// - Space: O(V E) +// where V = # of states visited, E = # of arcs. +// +// References: +// - Mehryar Mohri. Generic Epsilon-Removal and Input +// Epsilon-Normalization Algorithms for Weighted Transducers, +// "International Journal of Computer Science", 13(1):129-143 (2002). +template <class Arc> +void RmEpsilon(MutableFst<Arc> *fst, + bool connect = true, + typename Arc::Weight weight_threshold = Arc::Weight::Zero(), + typename Arc::StateId state_threshold = kNoStateId, + float delta = kDelta) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + vector<Weight> distance; + AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>()); + RmEpsilonOptions<Arc, AutoQueue<StateId> > + opts(&state_queue, delta, connect, weight_threshold, state_threshold); + + RmEpsilon(fst, &distance, opts); +} + + +struct RmEpsilonFstOptions : CacheOptions { + float delta; + + RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta) + : CacheOptions(opts), delta(delta) {} + + explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {} +}; + + +// Implementation of delayed RmEpsilonFst. +template <class A> +class RmEpsilonFstImpl : public CacheImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<A> >::PushArc; + using CacheBaseImpl< CacheState<A> >::HasArcs; + using CacheBaseImpl< CacheState<A> >::HasFinal; + using CacheBaseImpl< CacheState<A> >::HasStart; + using CacheBaseImpl< CacheState<A> >::SetArcs; + using CacheBaseImpl< CacheState<A> >::SetFinal; + using CacheBaseImpl< CacheState<A> >::SetStart; + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + + RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts) + : CacheImpl<A>(opts), + fst_(fst.Copy()), + delta_(opts.delta), + rmeps_state_( + *fst_, + &distance_, + RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) { + SetType("rmepsilon"); + uint64 props = fst.Properties(kFstProperties, false); + SetProperties(RmEpsilonProperties(props, true), kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + RmEpsilonFstImpl(const RmEpsilonFstImpl &impl) + : CacheImpl<A>(impl), + fst_(impl.fst_->Copy(true)), + delta_(impl.delta_), + rmeps_state_( + *fst_, + &distance_, + RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) { + SetType("rmepsilon"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~RmEpsilonFstImpl() { + delete fst_; + } + + StateId Start() { + if (!HasStart()) { + SetStart(fst_->Start()); + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + Expand(s); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && + (fst_->Properties(kError, false) || rmeps_state_.Error())) + SetProperties(kError, kError); + return FstImpl<A>::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<A>::InitArcIterator(s, data); + } + + void Expand(StateId s) { + rmeps_state_.Expand(s); + SetFinal(s, rmeps_state_.Final()); + vector<A> &arcs = rmeps_state_.Arcs(); + while (!arcs.empty()) { + PushArc(s, arcs.back()); + arcs.pop_back(); + } + SetArcs(s); + } + + private: + const Fst<A> *fst_; + float delta_; + vector<Weight> distance_; + FifoQueue<StateId> queue_; + RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_; + + void operator=(const RmEpsilonFstImpl<A> &); // disallow +}; + + +// Removes epsilon-transitions (when both the input and output label +// are an epsilon) from a transducer. The result will be an equivalent +// FST that has no such epsilon transitions. This version is a +// delayed Fst. +// +// Complexity: +// - Time: +// - Unweighted: O(v^2 + v e) +// - General: exponential +// - Space: O(v e) +// where v = # of states visited, e = # of arcs visited. Constant time +// to visit an input state or arc is assumed and exclusive of caching. +// +// References: +// - Mehryar Mohri. Generic Epsilon-Removal and Input +// Epsilon-Normalization Algorithms for Weighted Transducers, +// "International Journal of Computer Science", 13(1):129-143 (2002). +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class RmEpsilonFst : public ImplToFst< RmEpsilonFstImpl<A> > { + public: + friend class ArcIterator< RmEpsilonFst<A> >; + friend class StateIterator< RmEpsilonFst<A> >; + + typedef A Arc; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef RmEpsilonFstImpl<A> Impl; + + RmEpsilonFst(const Fst<A> &fst) + : ImplToFst<Impl>(new Impl(fst, RmEpsilonFstOptions())) {} + + RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts) + : ImplToFst<Impl>(new Impl(fst, opts)) {} + + // See Fst<>::Copy() for doc. + RmEpsilonFst(const RmEpsilonFst<A> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc. + virtual RmEpsilonFst<A> *Copy(bool safe = false) const { + return new RmEpsilonFst<A>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const RmEpsilonFst<A> &fst); // disallow +}; + +// Specialization for RmEpsilonFst. +template<class A> +class StateIterator< RmEpsilonFst<A> > + : public CacheStateIterator< RmEpsilonFst<A> > { + public: + explicit StateIterator(const RmEpsilonFst<A> &fst) + : CacheStateIterator< RmEpsilonFst<A> >(fst, fst.GetImpl()) {} +}; + + +// Specialization for RmEpsilonFst. +template <class A> +class ArcIterator< RmEpsilonFst<A> > + : public CacheArcIterator< RmEpsilonFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const RmEpsilonFst<A> &fst, StateId s) + : CacheArcIterator< RmEpsilonFst<A> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + + +template <class A> inline +void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const { + data->base = new StateIterator< RmEpsilonFst<A> >(*this); +} + + +// Useful alias when using StdArc. +typedef RmEpsilonFst<StdArc> StdRmEpsilonFst; + +} // namespace fst + +#endif // FST_LIB_RMEPSILON_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/rmfinalepsilon.h b/kaldi_io/src/tools/openfst/include/fst/rmfinalepsilon.h new file mode 100644 index 0000000..eb0f937 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/rmfinalepsilon.h @@ -0,0 +1,107 @@ +// rmfinalepsilon.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Johan Schalkwyk) +// +// \file +// Function to remove of final states that have epsilon only input arcs. + +#ifndef FST_LIB_RMFINALEPSILON_H__ +#define FST_LIB_RMFINALEPSILON_H__ + +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; +#include <vector> +using std::vector; + +#include <fst/connect.h> +#include <fst/mutable-fst.h> + + +namespace fst { + +template<class A> +void RmFinalEpsilon(MutableFst<A>* fst) { + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + // Determine the coaccesibility of states. + vector<bool> access; + vector<bool> coaccess; + uint64 props = 0; + SccVisitor<A> scc_visitor(0, &access, &coaccess, &props); + DfsVisit(*fst, &scc_visitor); + + // Find potential list of removable final states. These are final states + // that have no outgoing transitions or final states that have a + // non-coaccessible future. Complexity O(S) + unordered_set<StateId> finals; + for (StateIterator<Fst<A> > siter(*fst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + if (fst->Final(s) != Weight::Zero()) { + bool future_coaccess = false; + for (ArcIterator<Fst<A> > aiter(*fst, s); !aiter.Done(); aiter.Next()) { + const A& arc = aiter.Value(); + if (coaccess[arc.nextstate]) { + future_coaccess = true; + break; + } + } + if (!future_coaccess) { + finals.insert(s); + } + } + } + + // Move the final weight. Complexity O(E) + vector<A> arcs; + for (StateIterator<Fst<A> > siter(*fst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + Weight w(fst->Final(s)); + + arcs.clear(); + for (ArcIterator<Fst<A> > aiter(*fst, s); !aiter.Done(); aiter.Next()) { + const A& arc = aiter.Value(); + // is next state in the list of finals + if (finals.find(arc.nextstate) != finals.end()) { + // sum up all epsilon arcs + if (arc.ilabel == 0 && arc.olabel == 0) { + w = Plus(Times(fst->Final(arc.nextstate), arc.weight), w); + } else { + arcs.push_back(arc); + } + } else { + arcs.push_back(arc); + } + } + + // If some arcs (epsilon arcs) were deleted, delete all + // arcs and add back only the non epsilon arcs + if (arcs.size() < fst->NumArcs(s)) { + fst->DeleteArcs(s); + fst->SetFinal(s, w); + for (size_t i = 0; i < arcs.size(); ++i) { + fst->AddArc(s, arcs[i]); + } + } + } + + Connect(fst); +} + +} // namespace fst + +#endif // FST_LIB_RMFINALEPSILON_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/arcsort.h b/kaldi_io/src/tools/openfst/include/fst/script/arcsort.h new file mode 100644 index 0000000..4277332 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/arcsort.h @@ -0,0 +1,49 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_ARCSORT_H_ +#define FST_SCRIPT_ARCSORT_H_ + +#include <fst/arcsort.h> +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> + +namespace fst { +namespace script { + +enum ArcSortType { ILABEL_COMPARE, OLABEL_COMPARE }; + +typedef args::Package<MutableFstClass*, const ArcSortType> ArcSortArgs; + +template<class Arc> +void ArcSort(ArcSortArgs *args) { + MutableFst<Arc> *fst = args->arg1->GetMutableFst<Arc>(); + + if (args->arg2 == ILABEL_COMPARE) { + ILabelCompare<Arc> icomp; + ArcSort(fst, icomp); + } else { // OLABEL_COMPARE + OLabelCompare<Arc> ocomp; + ArcSort(fst, ocomp); + } +} + +void ArcSort(MutableFstClass *ofst, ArcSortType sort_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ARCSORT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/arg-packs.h b/kaldi_io/src/tools/openfst/include/fst/script/arg-packs.h new file mode 100644 index 0000000..8ebf8d8 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/arg-packs.h @@ -0,0 +1,240 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +// Convenience templates for defining arg packs for the FstClass operations. + +// See operation-templates.h for a discussion about why these are needed; the +// short story is that all FstClass operations must be implemented by a version +// that takes one argument, most likely a struct bundling all the +// logical arguments together. These template structs provide convenient ways +// to specify these bundles (e.g. by means of appropriate typedefs). + +// The ArgPack template is sufficient for bundling together all the args for +// a particular function. The function is assumed to be void-returning. If +// you want a space for a return value, use the WithReturnValue template +// as follows: + +// WithReturnValue<bool, ArgPack<...> > + +#ifndef FST_SCRIPT_ARG_PACKS_H_ +#define FST_SCRIPT_ARG_PACKS_H_ + +namespace fst { +namespace script { +namespace args { + +// Sentinel value that means "no arg here." +class none_type { }; + +// Base arg pack template class. Specializations follow that allow +// fewer numbers of arguments (down to 2). If the maximum number of arguments +// increases, you will need to change three things: +// 1) Add more template parameters to this template +// 2) Add more specializations to allow fewer numbers of parameters than +// the new max. +// 3) Add extra none_types to all existing specializations to fill +// the new slots. + + +// 9 args (max) +template<class T1, + class T2 = none_type, + class T3 = none_type, + class T4 = none_type, + class T5 = none_type, + class T6 = none_type, + class T7 = none_type, + class T8 = none_type, + class T9 = none_type> +struct Package { + T1 arg1; + T2 arg2; + T3 arg3; + T4 arg4; + T5 arg5; + T6 arg6; + T7 arg7; + T8 arg8; + T9 arg9; + + Package(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6, + T7 arg7, T8 arg8, T9 arg9) : + arg1(arg1), arg2(arg2), arg3(arg3), arg4(arg4), arg5(arg5), + arg6(arg6), arg7(arg7), arg8(arg8), arg9(arg9) { } +}; + +// 8 args +template<class T1, + class T2, + class T3, + class T4, + class T5, + class T6, + class T7, + class T8> +struct Package<T1, T2, T3, T4, T5, T6, T7, T8, none_type> { + T1 arg1; + T2 arg2; + T3 arg3; + T4 arg4; + T5 arg5; + T6 arg6; + T7 arg7; + T8 arg8; + + Package(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6, + T7 arg7, T8 arg8) : + arg1(arg1), arg2(arg2), arg3(arg3), arg4(arg4), arg5(arg5), + arg6(arg6), arg7(arg7), arg8(arg8) { } +}; + +// 7 args +template<class T1, + class T2, + class T3, + class T4, + class T5, + class T6, + class T7> +struct Package<T1, T2, T3, T4, T5, T6, T7, + none_type, none_type> { + T1 arg1; + T2 arg2; + T3 arg3; + T4 arg4; + T5 arg5; + T6 arg6; + T7 arg7; + + Package(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6, + T7 arg7) : + arg1(arg1), arg2(arg2), arg3(arg3), arg4(arg4), arg5(arg5), + arg6(arg6), arg7(arg7) { } +}; + +// 6 args +template<class T1, + class T2, + class T3, + class T4, + class T5, + class T6> +struct Package<T1, T2, T3, T4, T5, T6, none_type, + none_type, none_type> { + T1 arg1; + T2 arg2; + T3 arg3; + T4 arg4; + T5 arg5; + T6 arg6; + + Package(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6) : + arg1(arg1), arg2(arg2), arg3(arg3), arg4(arg4), arg5(arg5), + arg6(arg6) { } +}; + +// 5 args +template<class T1, + class T2, + class T3, + class T4, + class T5> +struct Package<T1, T2, T3, T4, T5, none_type, none_type, + none_type, none_type> { + T1 arg1; + T2 arg2; + T3 arg3; + T4 arg4; + T5 arg5; + + Package(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5) : + arg1(arg1), arg2(arg2), arg3(arg3), arg4(arg4), arg5(arg5) { } +}; + +// 4 args +template<class T1, + class T2, + class T3, + class T4> +struct Package<T1, T2, T3, T4, none_type, none_type, + none_type, none_type, none_type> { + T1 arg1; + T2 arg2; + T3 arg3; + T4 arg4; + + Package(T1 arg1, T2 arg2, T3 arg3, T4 arg4) : + arg1(arg1), arg2(arg2), arg3(arg3), arg4(arg4) { } +}; + +// 3 args +template<class T1, + class T2, + class T3> +struct Package<T1, T2, T3, none_type, none_type, + none_type, none_type, none_type, + none_type> { + T1 arg1; + T2 arg2; + T3 arg3; + + Package(T1 arg1, T2 arg2, T3 arg3) : + arg1(arg1), arg2(arg2), arg3(arg3) { } +}; + +// 2 args (minimum) +template<class T1, + class T2> +struct Package<T1, T2, none_type, none_type, + none_type, none_type, none_type, + none_type, none_type> { + T1 arg1; + T2 arg2; + + Package(T1 arg1, T2 arg2) : + arg1(arg1), arg2(arg2) { } +}; + +// Tack this on to an existing arg pack to add a return value. +// The syntax for accessing the args is then slightly more stilted, +// as you must do an extra member access (since the args are stored +// as a member of this class). +// The alternative is to declare another slew of templates for functions +// that return a value, analogous to the above. + +template<class Retval, class ArgPackage> +struct WithReturnValue { + Retval retval; + const ArgPackage &args; + + explicit WithReturnValue(const ArgPackage &args) : args(args) { } +}; + +// We don't want to store a reference to a reference, if ArgPackage is +// already some reference type. +template<class Retval, class ArgPackage> +struct WithReturnValue<Retval, ArgPackage&> { + Retval retval; + const ArgPackage &args; + + explicit WithReturnValue(const ArgPackage &args) : args(args) { } +}; + +} // namespace args +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ARG_PACKS_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/closure.h b/kaldi_io/src/tools/openfst/include/fst/script/closure.h new file mode 100644 index 0000000..93b5ec3 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/closure.h @@ -0,0 +1,41 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_CLOSURE_H_ +#define FST_SCRIPT_CLOSURE_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/closure.h> + +namespace fst { +namespace script { + +typedef args::Package<MutableFstClass*, const ClosureType> ClosureArgs; + +template<class Arc> +void Closure(ClosureArgs *args) { + MutableFst<Arc> *fst = args->arg1->GetMutableFst<Arc>(); + + Closure(fst, args->arg2); +} + +void Closure(MutableFstClass *ofst, ClosureType closure_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_CLOSURE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/compile-impl.h b/kaldi_io/src/tools/openfst/include/fst/script/compile-impl.h new file mode 100644 index 0000000..68f37c3 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/compile-impl.h @@ -0,0 +1,216 @@ +// compile.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to to compile a binary Fst from textual input. + +#ifndef FST_SCRIPT_COMPILE_IMPL_H_ +#define FST_SCRIPT_COMPILE_IMPL_H_ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <sstream> +#include <string> +#include <vector> +using std::vector; + +#include <iostream> +#include <fstream> +#include <sstream> +#include <fst/fst.h> +#include <fst/util.h> +#include <fst/vector-fst.h> + +DECLARE_string(fst_field_separator); + +namespace fst { + +// Compile a binary Fst from textual input, helper class for fstcompile.cc +// WARNING: Stand-alone use of this class not recommended, most code should +// read/write using the binary format which is much more efficient. +template <class A> class FstCompiler { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + // WARNING: use of 'allow_negative_labels = true' not recommended; may + // cause conflicts + FstCompiler(istream &istrm, const string &source, + const SymbolTable *isyms, const SymbolTable *osyms, + const SymbolTable *ssyms, bool accep, bool ikeep, + bool okeep, bool nkeep, bool allow_negative_labels = false) + : nline_(0), source_(source), + isyms_(isyms), osyms_(osyms), ssyms_(ssyms), + nstates_(0), keep_state_numbering_(nkeep), + allow_negative_labels_(allow_negative_labels) { + char line[kLineLen]; + while (istrm.getline(line, kLineLen)) { + ++nline_; + vector<char *> col; + string separator = FLAGS_fst_field_separator + "\n"; + SplitToVector(line, separator.c_str(), &col, true); + if (col.size() == 0 || col[0][0] == '\0') // empty line + continue; + if (col.size() > 5 || + (col.size() > 4 && accep) || + (col.size() == 3 && !accep)) { + FSTERROR() << "FstCompiler: Bad number of columns, source = " + << source_ + << ", line = " << nline_; + fst_.SetProperties(kError, kError); + return; + } + StateId s = StrToStateId(col[0]); + while (s >= fst_.NumStates()) + fst_.AddState(); + if (nline_ == 1) + fst_.SetStart(s); + + Arc arc; + StateId d = s; + switch (col.size()) { + case 1: + fst_.SetFinal(s, Weight::One()); + break; + case 2: + fst_.SetFinal(s, StrToWeight(col[1], true)); + break; + case 3: + arc.nextstate = d = StrToStateId(col[1]); + arc.ilabel = StrToILabel(col[2]); + arc.olabel = arc.ilabel; + arc.weight = Weight::One(); + fst_.AddArc(s, arc); + break; + case 4: + arc.nextstate = d = StrToStateId(col[1]); + arc.ilabel = StrToILabel(col[2]); + if (accep) { + arc.olabel = arc.ilabel; + arc.weight = StrToWeight(col[3], false); + } else { + arc.olabel = StrToOLabel(col[3]); + arc.weight = Weight::One(); + } + fst_.AddArc(s, arc); + break; + case 5: + arc.nextstate = d = StrToStateId(col[1]); + arc.ilabel = StrToILabel(col[2]); + arc.olabel = StrToOLabel(col[3]); + arc.weight = StrToWeight(col[4], false); + fst_.AddArc(s, arc); + } + while (d >= fst_.NumStates()) + fst_.AddState(); + } + if (ikeep) + fst_.SetInputSymbols(isyms); + if (okeep) + fst_.SetOutputSymbols(osyms); + } + + const VectorFst<A> &Fst() const { + return fst_; + } + + private: + // Maximum line length in text file. + static const int kLineLen = 8096; + + int64 StrToId(const char *s, const SymbolTable *syms, + const char *name, bool allow_negative = false) const { + int64 n = 0; + + if (syms) { + n = syms->Find(s); + if (n == -1 || (!allow_negative && n < 0)) { + FSTERROR() << "FstCompiler: Symbol \"" << s + << "\" is not mapped to any integer " << name + << ", symbol table = " << syms->Name() + << ", source = " << source_ << ", line = " << nline_; + fst_.SetProperties(kError, kError); + } + } else { + char *p; + n = strtoll(s, &p, 10); + if (p < s + strlen(s) || (!allow_negative && n < 0)) { + FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s + << "\", source = " << source_ << ", line = " << nline_; + fst_.SetProperties(kError, kError); + } + } + return n; + } + + StateId StrToStateId(const char *s) { + StateId n = StrToId(s, ssyms_, "state ID"); + + if (keep_state_numbering_) + return n; + + // remap state IDs to make dense set + typename unordered_map<StateId, StateId>::const_iterator it = states_.find(n); + if (it == states_.end()) { + states_[n] = nstates_; + return nstates_++; + } else { + return it->second; + } + } + + StateId StrToILabel(const char *s) const { + return StrToId(s, isyms_, "arc ilabel", allow_negative_labels_); + } + + StateId StrToOLabel(const char *s) const { + return StrToId(s, osyms_, "arc olabel", allow_negative_labels_); + } + + Weight StrToWeight(const char *s, bool allow_zero) const { + Weight w; + istringstream strm(s); + strm >> w; + if (!strm || (!allow_zero && w == Weight::Zero())) { + FSTERROR() << "FstCompiler: Bad weight = \"" << s + << "\", source = " << source_ << ", line = " << nline_; + fst_.SetProperties(kError, kError); + w = Weight::NoWeight(); + } + return w; + } + + mutable VectorFst<A> fst_; + size_t nline_; + string source_; // text FST source name + const SymbolTable *isyms_; // ilabel symbol table + const SymbolTable *osyms_; // olabel symbol table + const SymbolTable *ssyms_; // slabel symbol table + unordered_map<StateId, StateId> states_; // state ID map + StateId nstates_; // number of seen states + bool keep_state_numbering_; + bool allow_negative_labels_; // not recommended; may cause conflicts + + DISALLOW_COPY_AND_ASSIGN(FstCompiler); +}; + +} // namespace fst + +#endif // FST_SCRIPT_COMPILE_IMPL_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/compile.h b/kaldi_io/src/tools/openfst/include/fst/script/compile.h new file mode 100644 index 0000000..bb6ea56 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/compile.h @@ -0,0 +1,92 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_COMPILE_H_ +#define FST_SCRIPT_COMPILE_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/compile-impl.h> + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! +struct FstCompileArgs { + fst::istream &istrm; + const string &source; + const string &dest; + const string &fst_type; + const fst::SymbolTable *isyms; + const fst::SymbolTable *osyms; + const fst::SymbolTable *ssyms; + const bool accep; + const bool ikeep; + const bool okeep; + const bool nkeep; + const bool allow_negative_labels; + + FstCompileArgs(istream &istrm, const string &source, const string &dest, + const string &fst_type, const fst::SymbolTable *isyms, + const fst::SymbolTable *osyms, + const fst::SymbolTable *ssyms, + bool accep, bool ikeep, bool okeep, bool nkeep, + bool allow_negative_labels = false) : + istrm(istrm), source(source), dest(dest), fst_type(fst_type), + isyms(isyms), osyms(osyms), ssyms(ssyms), accep(accep), ikeep(ikeep), + okeep(okeep), nkeep(nkeep), + allow_negative_labels(allow_negative_labels) { } +}; + +template<class Arc> +void CompileFst(FstCompileArgs *args) { + using fst::FstCompiler; + using fst::Convert; + using fst::Fst; + + FstCompiler<Arc> fstcompiler(args->istrm, args->source, args->isyms, + args->osyms, args->ssyms, + args->accep, args->ikeep, + args->okeep, args->nkeep, + args->allow_negative_labels); + + const Fst<Arc> *fst = &fstcompiler.Fst(); + if (args->fst_type != "vector") { + fst = Convert<Arc>(*fst, args->fst_type); + if (!fst) { + FSTERROR() << "Failed to convert FST to desired type: " + << args->fst_type; + return; + } + } + + fst->Write(args->dest); +} + +void CompileFst(istream &istrm, const string &source, const string &dest, + const string &fst_type, const string &arc_type, + const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, + bool accep, bool ikeep, bool okeep, bool nkeep, + bool allow_negative_labels); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_COMPILE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/compose.h b/kaldi_io/src/tools/openfst/include/fst/script/compose.h new file mode 100644 index 0000000..96375f7 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/compose.h @@ -0,0 +1,63 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_COMPOSE_H_ +#define FST_SCRIPT_COMPOSE_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/compose.h> + +namespace fst { +namespace script { + +typedef args::Package<const FstClass&, const FstClass&, + MutableFstClass*, ComposeFilter> ComposeArgs1; + +template<class Arc> +void Compose(ComposeArgs1 *args) { + const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>()); + const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); + + Compose(ifst1, ifst2, ofst, args->arg4); +} + +typedef fst::ComposeOptions ComposeOptions; + +typedef args::Package<const FstClass&, const FstClass&, + MutableFstClass*, const ComposeOptions &> ComposeArgs2; + +template<class Arc> +void Compose(ComposeArgs2 *args) { + const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>()); + const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); + + Compose(ifst1, ifst2, ofst, args->arg4); +} + +void Compose(const FstClass &ifst1, const FstClass &ifst2, + MutableFstClass *ofst, + const ComposeOptions &opts = fst::script::ComposeOptions()); + +void Compose(const FstClass &ifst1, const FstClass &ifst2, + MutableFstClass *ofst, ComposeFilter compose_filter); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_COMPOSE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/concat.h b/kaldi_io/src/tools/openfst/include/fst/script/concat.h new file mode 100644 index 0000000..46c4407 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/concat.h @@ -0,0 +1,54 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_CONCAT_H_ +#define FST_SCRIPT_CONCAT_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/concat.h> + +namespace fst { +namespace script { + +typedef args::Package<MutableFstClass*, const FstClass&> ConcatArgs1; +typedef args::Package<const FstClass&, MutableFstClass*> ConcatArgs2; + +template<class Arc> +void Concat(ConcatArgs1 *args) { + MutableFst<Arc> *ofst = args->arg1->GetMutableFst<Arc>(); + const Fst<Arc> &ifst = *(args->arg2.GetFst<Arc>()); + + Concat(ofst, ifst); +} + +template<class Arc> +void Concat(ConcatArgs2 *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + + Concat(ifst, ofst); +} + +void Concat(MutableFstClass *ofst, const FstClass &ifst); +void Concat(const FstClass &ifst, MutableFstClass *ofst); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_CONCAT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/connect.h b/kaldi_io/src/tools/openfst/include/fst/script/connect.h new file mode 100644 index 0000000..19c4390 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/connect.h @@ -0,0 +1,45 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_CONNECT_H_ +#define FST_SCRIPT_CONNECT_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/dfs-visit.h> +#include <fst/connect.h> + +namespace fst { +namespace script { + +// This function confuses SWIG, because both versions have the same args +#ifndef SWIG +template<class Arc> +void Connect(MutableFstClass *fst) { + MutableFst<Arc> *typed_fst = fst->GetMutableFst<Arc>(); + + Connect(typed_fst); +} +#endif + +void Connect(MutableFstClass *fst); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_CONNECT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/convert.h b/kaldi_io/src/tools/openfst/include/fst/script/convert.h new file mode 100644 index 0000000..4a3ce6b --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/convert.h @@ -0,0 +1,49 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_CONVERT_H_ +#define FST_SCRIPT_CONVERT_H_ + +#include <string> + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> + +namespace fst { +namespace script { + +typedef args::Package<const FstClass&, const string&> ConvertInnerArgs; +typedef args::WithReturnValue<FstClass*, ConvertInnerArgs> ConvertArgs; + +template<class Arc> +void Convert(ConvertArgs *args) { + const Fst<Arc> &fst = *(args->args.arg1.GetFst<Arc>()); + const string &new_type = args->args.arg2; + + Fst<Arc> *result = Convert(fst, new_type); + args->retval = new FstClass(*result); + delete result; +} + +#ifdef SWIG +%newobject Convert; +#endif +FstClass *Convert(const FstClass& f, const string &new_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_CONVERT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/decode.h b/kaldi_io/src/tools/openfst/include/fst/script/decode.h new file mode 100644 index 0000000..1064ad5 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/decode.h @@ -0,0 +1,46 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_DECODE_H_ +#define FST_SCRIPT_DECODE_H_ + +#include <string> + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/encode.h> + +namespace fst { +namespace script { + +typedef args::Package<MutableFstClass*, const string&> DecodeArgs; + +template<class Arc> +void Decode(DecodeArgs *args) { + MutableFst<Arc> *ofst = args->arg1->GetMutableFst<Arc>(); + + EncodeMapper<Arc> *decoder = EncodeMapper<Arc>::Read(args->arg2, DECODE); + Decode(ofst, *decoder); + + delete decoder; +} + +void Decode(MutableFstClass *fst, const string &coder_fname); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DECODE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/determinize.h b/kaldi_io/src/tools/openfst/include/fst/script/determinize.h new file mode 100644 index 0000000..38fd7ad --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/determinize.h @@ -0,0 +1,68 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_DETERMINIZE_H_ +#define FST_SCRIPT_DETERMINIZE_H_ + +#include <fst/determinize.h> +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> + +namespace fst { +namespace script { + +struct DeterminizeOptions { + float delta; + WeightClass weight_threshold; + int64 state_threshold; + int64 subsequential_label; + + explicit DeterminizeOptions(float d = fst::kDelta, + WeightClass w = + fst::script::WeightClass::Zero(), + int64 n = fst::kNoStateId, int64 l = 0) + : delta(d), weight_threshold(w), state_threshold(n), + subsequential_label(l) {} +}; + +typedef args::Package<const FstClass&, MutableFstClass*, + const DeterminizeOptions &> DeterminizeArgs; + +template<class Arc> +void Determinize(DeterminizeArgs *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + const DeterminizeOptions &opts = args->arg3; + + fst::DeterminizeOptions<Arc> detargs; + detargs.delta = opts.delta; + detargs.weight_threshold = + *(opts.weight_threshold.GetWeight<typename Arc::Weight>()); + detargs.state_threshold = opts.state_threshold; + detargs.subsequential_label = opts.subsequential_label; + + Determinize(ifst, ofst, detargs); +} + +void Determinize(const FstClass &ifst, MutableFstClass *ofst, + const DeterminizeOptions &opts = + fst::script::DeterminizeOptions()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DETERMINIZE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/difference.h b/kaldi_io/src/tools/openfst/include/fst/script/difference.h new file mode 100644 index 0000000..76490d4 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/difference.h @@ -0,0 +1,67 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_DIFFERENCE_H_ +#define FST_SCRIPT_DIFFERENCE_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/compose.h> // for ComposeFilter +#include <fst/difference.h> + +namespace fst { +namespace script { + +typedef args::Package<const FstClass&, const FstClass&, + MutableFstClass*, ComposeFilter> DifferenceArgs1; + +template<class Arc> +void Difference(DifferenceArgs1 *args) { + const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>()); + const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); + + Difference(ifst1, ifst2, ofst, args->arg4); +} + +typedef args::Package<const FstClass&, const FstClass&, + MutableFstClass*, const ComposeOptions &> DifferenceArgs2; + +template<class Arc> +void Difference(DifferenceArgs2 *args) { + const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>()); + const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); + + Difference(ifst1, ifst2, ofst, args->arg4); +} + + +void Difference(const FstClass &ifst1, const FstClass &ifst2, + MutableFstClass *ofst, + ComposeFilter compose_filter); + +void Difference(const FstClass &ifst1, const FstClass &ifst2, + MutableFstClass *ofst, + const ComposeOptions &opts = fst::script::ComposeOptions()); + + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_DIFFERENCE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/disambiguate.h b/kaldi_io/src/tools/openfst/include/fst/script/disambiguate.h new file mode 100644 index 0000000..e42a9c2 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/disambiguate.h @@ -0,0 +1,68 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_DISAMBIGUATE_H_ +#define FST_SCRIPT_DISAMBIGUATE_H_ + +#include <fst/disambiguate.h> +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> + +namespace fst { +namespace script { + +struct DisambiguateOptions { + float delta; + WeightClass weight_threshold; + int64 state_threshold; + int64 subsequential_label; + + explicit DisambiguateOptions(float d = fst::kDelta, + WeightClass w = + fst::script::WeightClass::Zero(), + int64 n = fst::kNoStateId, int64 l = 0) + : delta(d), weight_threshold(w), state_threshold(n), + subsequential_label(l) {} +}; + +typedef args::Package<const FstClass&, MutableFstClass*, + const DisambiguateOptions &> DisambiguateArgs; + +template<class Arc> +void Disambiguate(DisambiguateArgs *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + const DisambiguateOptions &opts = args->arg3; + + fst::DisambiguateOptions<Arc> detargs; + detargs.delta = opts.delta; + detargs.weight_threshold = + *(opts.weight_threshold.GetWeight<typename Arc::Weight>()); + detargs.state_threshold = opts.state_threshold; + detargs.subsequential_label = opts.subsequential_label; + + Disambiguate(ifst, ofst, detargs); +} + +void Disambiguate(const FstClass &ifst, MutableFstClass *ofst, + const DisambiguateOptions &opts = + fst::script::DisambiguateOptions()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DISAMBIGUATE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/draw-impl.h b/kaldi_io/src/tools/openfst/include/fst/script/draw-impl.h new file mode 100644 index 0000000..893e258 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/draw-impl.h @@ -0,0 +1,234 @@ +// draw.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Class to draw a binary FST by producing a text file in dot format, +// helper class to fstdraw.cc + +#ifndef FST_SCRIPT_DRAW_IMPL_H_ +#define FST_SCRIPT_DRAW_IMPL_H_ + +#include <sstream> +#include <string> + +#include <fst/script/fst-class.h> +#include <fst/fst.h> +#include <fst/util.h> + +namespace fst { + +// Print a binary Fst in the dot textual format, helper class for fstdraw.cc +// WARNING: Stand-alone use not recommend. +template <class A> class FstDrawer { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + FstDrawer(const Fst<A> &fst, + const SymbolTable *isyms, + const SymbolTable *osyms, + const SymbolTable *ssyms, + bool accep, + string title, + float width, + float height, + bool portrait, + bool vertical, + float ranksep, + float nodesep, + int fontsize, + int precision, + bool show_weight_one) + : fst_(fst), isyms_(isyms), osyms_(osyms), ssyms_(ssyms), + accep_(accep && fst.Properties(kAcceptor, true)), ostrm_(0), + title_(title), width_(width), height_(height), portrait_(portrait), + vertical_(vertical), ranksep_(ranksep), nodesep_(nodesep), + fontsize_(fontsize), precision_(precision), + show_weight_one_(show_weight_one) {} + + // Draw Fst to an output buffer (or stdout if buf = 0) + void Draw(ostream *strm, const string &dest) { + ostrm_ = strm; + dest_ = dest; + StateId start = fst_.Start(); + if (start == kNoStateId) + return; + + PrintString("digraph FST {\n"); + if (vertical_) + PrintString("rankdir = BT;\n"); + else + PrintString("rankdir = LR;\n"); + PrintString("size = \""); + Print(width_); + PrintString(","); + Print(height_); + PrintString("\";\n"); + if (!dest_.empty()) + PrintString("label = \"" + title_ + "\";\n"); + PrintString("center = 1;\n"); + if (portrait_) + PrintString("orientation = Portrait;\n"); + else + PrintString("orientation = Landscape;\n"); + PrintString("ranksep = \""); + Print(ranksep_); + PrintString("\";\n"); + PrintString("nodesep = \""); + Print(nodesep_); + PrintString("\";\n"); + // initial state first + DrawState(start); + for (StateIterator< Fst<A> > siter(fst_); + !siter.Done(); + siter.Next()) { + StateId s = siter.Value(); + if (s != start) + DrawState(s); + } + PrintString("}\n"); + } + + private: + // Maximum line length in text file. + static const int kLineLen = 8096; + + void PrintString(const string &s) const { + *ostrm_ << s; + } + + // Escapes backslash and double quote if these occur in the string. Dot will + // not deal gracefully with these if they are not escaped. + inline void EscapeChars(const string &s, string* ns) const { + const char* c = s.c_str(); + while (*c) { + if (*c == '\\' || *c == '"') ns->push_back('\\'); + ns->push_back(*c); + ++c; + } + } + + void PrintId(int64 id, const SymbolTable *syms, + const char *name) const { + if (syms) { + string symbol = syms->Find(id); + if (symbol == "") { + FSTERROR() << "FstDrawer: Integer " << id + << " is not mapped to any textual symbol" + << ", symbol table = " << syms->Name() + << ", destination = " << dest_; + symbol = "?"; + } + string nsymbol; + EscapeChars(symbol, &nsymbol); + PrintString(nsymbol); + } else { + string idstr; + Int64ToStr(id, &idstr); + PrintString(idstr); + } + } + + void PrintStateId(StateId s) const { + PrintId(s, ssyms_, "state ID"); + } + + void PrintILabel(Label l) const { + PrintId(l, isyms_, "arc input label"); + } + + void PrintOLabel(Label l) const { + PrintId(l, osyms_, "arc output label"); + } + + template <class T> + void Print(T t) const { + *ostrm_ << t; + } + + void DrawState(StateId s) const { + Print(s); + PrintString(" [label = \""); + PrintStateId(s); + Weight final = fst_.Final(s); + if (final != Weight::Zero()) { + if (show_weight_one_ || (final != Weight::One())) { + PrintString("/"); + Print(final); + } + PrintString("\", shape = doublecircle,"); + } else { + PrintString("\", shape = circle,"); + } + if (s == fst_.Start()) + PrintString(" style = bold,"); + else + PrintString(" style = solid,"); + PrintString(" fontsize = "); + Print(fontsize_); + PrintString("]\n"); + for (ArcIterator< Fst<A> > aiter(fst_, s); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + PrintString("\t"); + Print(s); + PrintString(" -> "); + Print(arc.nextstate); + PrintString(" [label = \""); + PrintILabel(arc.ilabel); + if (!accep_) { + PrintString(":"); + PrintOLabel(arc.olabel); + } + if (show_weight_one_ || (arc.weight != Weight::One())) { + PrintString("/"); + Print(arc.weight); + } + PrintString("\", fontsize = "); + Print(fontsize_); + PrintString("];\n"); + } + } + + const Fst<A> &fst_; + const SymbolTable *isyms_; // ilabel symbol table + const SymbolTable *osyms_; // olabel symbol table + const SymbolTable *ssyms_; // slabel symbol table + bool accep_; // print as acceptor when possible + ostream *ostrm_; // drawn FST destination + string dest_; // drawn FST destination name + + string title_; + float width_; + float height_; + bool portrait_; + bool vertical_; + float ranksep_; + float nodesep_; + int fontsize_; + int precision_; + bool show_weight_one_; + + DISALLOW_COPY_AND_ASSIGN(FstDrawer); +}; + +} // namespace fst + +#endif // FST_SCRIPT_DRAW_IMPL_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/draw.h b/kaldi_io/src/tools/openfst/include/fst/script/draw.h new file mode 100644 index 0000000..2b66373 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/draw.h @@ -0,0 +1,114 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_DRAW_H_ +#define FST_SCRIPT_DRAW_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/draw-impl.h> +#include <iostream> +#include <fstream> +#include <sstream> + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! +struct FstDrawerArgs { + const FstClass &fst; + const SymbolTable *isyms; + const SymbolTable *osyms; + const SymbolTable *ssyms; + const bool accep; + const string& title; + const float width; + const float height; + const bool portrait; + const bool vertical; + const float ranksep; + const float nodesep; + const int fontsize; + const int precision; + const bool show_weight_one; + ostream *ostrm; + const string &dest; + + FstDrawerArgs(const FstClass &fst, + const SymbolTable *isyms, + const SymbolTable *osyms, + const SymbolTable *ssyms, + bool accep, + const string &title, + float width, + float height, + bool portrait, + bool vertical, + float ranksep, + float nodesep, + int fontsize, + int precision, + bool show_weight_one, + ostream *ostrm, + const string &dest) : + fst(fst), isyms(isyms), osyms(osyms), ssyms(ssyms), accep(accep), + title(title), width(width), height(height), portrait(portrait), + vertical(vertical), ranksep(ranksep), nodesep(nodesep), + fontsize(fontsize), precision(precision), + show_weight_one(show_weight_one), ostrm(ostrm), dest(dest) { } +}; + + +template<class Arc> +void DrawFst(FstDrawerArgs *args) { + const Fst<Arc> &fst = *(args->fst.GetFst<Arc>()); + + FstDrawer<Arc> fstdrawer(fst, args->isyms, args->osyms, args->ssyms, + args->accep, args->title, args->width, + args->height, args->portrait, + args->vertical, args->ranksep, + args->nodesep, args->fontsize, + args->precision, args->show_weight_one); + fstdrawer.Draw(args->ostrm, args->dest); +} + +void DrawFst(const FstClass &fst, + const SymbolTable *isyms, + const SymbolTable *osyms, + const SymbolTable *ssyms, + bool accep, + const string &title, + float width, + float height, + bool portrait, + bool vertical, + float ranksep, + float nodesep, + int fontsize, + int precision, + bool show_weight_one, + ostream *ostrm, + const string &dest); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_DRAW_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/encode.h b/kaldi_io/src/tools/openfst/include/fst/script/encode.h new file mode 100644 index 0000000..dc1a290 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/encode.h @@ -0,0 +1,58 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_ENCODE_H_ +#define FST_SCRIPT_ENCODE_H_ + +#include <string> + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/encode.h> + +namespace fst { +namespace script { + +typedef args::Package<MutableFstClass*, uint32, bool, + const string &> EncodeArgs; + +template<class Arc> +void Encode(EncodeArgs *args) { + MutableFst<Arc> *ofst = args->arg1->GetMutableFst<Arc>(); + bool reuse_encoder = args->arg3; + const string &coder_fname = args->arg4; + uint32 flags = args->arg2; + + EncodeMapper<Arc> *encoder = reuse_encoder + ? EncodeMapper<Arc>::Read(coder_fname, ENCODE) + : new EncodeMapper<Arc>(flags, ENCODE); + + Encode(ofst, encoder); + if (!args->arg3) + encoder->Write(coder_fname); + + delete encoder; +} + +void Encode(MutableFstClass *fst, uint32 flags, bool reuse_encoder, + const string &coder_fname); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_ENCODE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/epsnormalize.h b/kaldi_io/src/tools/openfst/include/fst/script/epsnormalize.h new file mode 100644 index 0000000..50b12da --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/epsnormalize.h @@ -0,0 +1,44 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_EPSNORMALIZE_H_ +#define FST_SCRIPT_EPSNORMALIZE_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/epsnormalize.h> + +namespace fst { +namespace script { + +typedef args::Package<const FstClass&, MutableFstClass*, + EpsNormalizeType> EpsNormalizeArgs; + +template<class Arc> +void EpsNormalize(EpsNormalizeArgs *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + + EpsNormalize(ifst, ofst, args->arg3); +} + +void EpsNormalize(const FstClass &ifst, MutableFstClass *ofst, + EpsNormalizeType norm_type = EPS_NORM_INPUT); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_EPSNORMALIZE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/equal.h b/kaldi_io/src/tools/openfst/include/fst/script/equal.h new file mode 100644 index 0000000..9fb2d3c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/equal.h @@ -0,0 +1,45 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_EQUAL_H_ +#define FST_SCRIPT_EQUAL_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/equal.h> + +namespace fst { +namespace script { + +typedef args::Package<const FstClass&, const FstClass&, float> EqualInnerArgs; +typedef args::WithReturnValue<bool, EqualInnerArgs> EqualArgs; + +template<class Arc> +void Equal(EqualArgs *args) { + const Fst<Arc> &fst1 = *(args->args.arg1.GetFst<Arc>()); + const Fst<Arc> &fst2 = *(args->args.arg2.GetFst<Arc>()); + + args->retval = Equal(fst1, fst2, args->args.arg3); +} + +bool Equal(const FstClass &fst1, const FstClass &fst2, + float delta = kDelta); + +} // namespace script +} // namespace fst + + +#endif // FST_SCRIPT_EQUAL_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/equivalent.h b/kaldi_io/src/tools/openfst/include/fst/script/equivalent.h new file mode 100644 index 0000000..43460c6 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/equivalent.h @@ -0,0 +1,47 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_EQUIVALENT_H_ +#define FST_SCRIPT_EQUIVALENT_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/equivalent.h> + +namespace fst { +namespace script { + +typedef args::Package<const FstClass &, const FstClass &, + float> EquivalentInnerArgs; +typedef args::WithReturnValue<bool, EquivalentInnerArgs> EquivalentArgs; + +template<class Arc> +void Equivalent(EquivalentArgs *args) { + const Fst<Arc> &fst1 = *(args->args.arg1.GetFst<Arc>()); + const Fst<Arc> &fst2 = *(args->args.arg2.GetFst<Arc>()); + + args->retval = Equivalent(fst1, fst2, args->args.arg3); +} + +bool Equivalent(const FstClass &fst1, const FstClass &fst2, + float delta = kDelta); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_EQUIVALENT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/fst-class.h b/kaldi_io/src/tools/openfst/include/fst/script/fst-class.h new file mode 100644 index 0000000..fe2cf53 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/fst-class.h @@ -0,0 +1,382 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_FST_CLASS_H_ +#define FST_SCRIPT_FST_CLASS_H_ + +#include <string> + +#include <fst/fst.h> +#include <fst/mutable-fst.h> +#include <fst/vector-fst.h> +#include <iostream> +#include <fstream> +#include <sstream> + +// Classes to support "boxing" all existing types of FST arcs in a single +// FstClass which hides the arc types. This allows clients to load +// and work with FSTs without knowing the arc type. + +// These classes are only recommended for use in high-level scripting +// applications. Most users should use the lower-level templated versions +// corresponding to these classes. + +namespace fst { +namespace script { + +// +// Abstract base class defining the set of functionalities implemented +// in all impls, and passed through by all bases Below FstClassBase +// the class hierarchy bifurcates; FstClassImplBase serves as the base +// class for all implementations (of which FstClassImpl is currently +// the only one) and FstClass serves as the base class for all +// interfaces. +// +class FstClassBase { + public: + virtual const string &ArcType() const = 0; + virtual const string &FstType() const = 0; + virtual const string &WeightType() const = 0; + virtual const SymbolTable *InputSymbols() const = 0; + virtual const SymbolTable *OutputSymbols() const = 0; + virtual bool Write(const string& fname) const = 0; + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const = 0; + virtual uint64 Properties(uint64 mask, bool test) const = 0; + virtual ~FstClassBase() { } +}; + +class FstClassImplBase : public FstClassBase { + public: + virtual FstClassImplBase *Copy() = 0; + virtual void SetInputSymbols(SymbolTable *is) = 0; + virtual void SetOutputSymbols(SymbolTable *is) = 0; + virtual ~FstClassImplBase() { } +}; + + +// +// CONTAINER CLASS +// Wraps an Fst<Arc>, hiding its arc type. Whether this Fst<Arc> +// pointer refers to a special kind of FST (e.g. a MutableFst) is +// known by the type of interface class that owns the pointer to this +// container. +// + +template<class Arc> +class FstClassImpl : public FstClassImplBase { + public: + explicit FstClassImpl(Fst<Arc> *impl, + bool should_own = false) : + impl_(should_own ? impl : impl->Copy()) { } + + explicit FstClassImpl(const Fst<Arc> &impl) : impl_(impl.Copy()) { } + + virtual const string &ArcType() const { + return Arc::Type(); + } + + virtual const string &FstType() const { + return impl_->Type(); + } + + virtual const string &WeightType() const { + return Arc::Weight::Type(); + } + + virtual const SymbolTable *InputSymbols() const { + return impl_->InputSymbols(); + } + + virtual const SymbolTable *OutputSymbols() const { + return impl_->OutputSymbols(); + } + + // Warning: calling this method casts the FST to a mutable FST. + virtual void SetInputSymbols(SymbolTable *is) { + static_cast<MutableFst<Arc> *>(impl_)->SetInputSymbols(is); + } + + // Warning: calling this method casts the FST to a mutable FST. + virtual void SetOutputSymbols(SymbolTable *os) { + static_cast<MutableFst<Arc> *>(impl_)->SetOutputSymbols(os); + } + + virtual bool Write(const string &fname) const { + return impl_->Write(fname); + } + + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { + return impl_->Write(ostr, opts); + } + + virtual uint64 Properties(uint64 mask, bool test) const { + return impl_->Properties(mask, test); + } + + virtual ~FstClassImpl() { delete impl_; } + + Fst<Arc> *GetImpl() const { return impl_; } + + Fst<Arc> *GetImpl() { return impl_; } + + virtual FstClassImpl *Copy() { + return new FstClassImpl<Arc>(impl_); + } + + private: + Fst<Arc> *impl_; +}; + +// +// BASE CLASS DEFINITIONS +// + +class MutableFstClass; + +class FstClass : public FstClassBase { + public: + template<class Arc> + static FstClass *Read(istream &stream, + const FstReadOptions &opts) { + if (!opts.header) { + FSTERROR() << "FstClass::Read: options header not specified"; + return 0; + } + const FstHeader &hdr = *opts.header; + + if (hdr.Properties() & kMutable) { + return ReadTypedFst<MutableFstClass, MutableFst<Arc> >(stream, opts); + } else { + return ReadTypedFst<FstClass, Fst<Arc> >(stream, opts); + } + } + + FstClass() : impl_(NULL) { + } + + template<class Arc> + explicit FstClass(const Fst<Arc> &fst) : impl_(new FstClassImpl<Arc>(fst)) { + } + + FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { } + + FstClass &operator=(const FstClass &other) { + delete impl_; + impl_ = other.impl_->Copy(); + return *this; + } + + static FstClass *Read(const string &fname); + + static FstClass *Read(istream &istr, const string &source); + + virtual const string &ArcType() const { + return impl_->ArcType(); + } + + virtual const string& FstType() const { + return impl_->FstType(); + } + + virtual const SymbolTable *InputSymbols() const { + return impl_->InputSymbols(); + } + + virtual const SymbolTable *OutputSymbols() const { + return impl_->OutputSymbols(); + } + + virtual const string& WeightType() const { + return impl_->WeightType(); + } + + virtual bool Write(const string &fname) const { + return impl_->Write(fname); + } + + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { + return impl_->Write(ostr, opts); + } + + virtual uint64 Properties(uint64 mask, bool test) const { + return impl_->Properties(mask, test); + } + + template<class Arc> + const Fst<Arc> *GetFst() const { + if (Arc::Type() != ArcType()) { + return NULL; + } else { + FstClassImpl<Arc> *typed_impl = static_cast<FstClassImpl<Arc> *>(impl_); + return typed_impl->GetImpl(); + } + } + + virtual ~FstClass() { delete impl_; } + + // These methods are required by IO registration + template<class Arc> + static FstClassImplBase *Convert(const FstClass &other) { + LOG(ERROR) << "Doesn't make sense to convert any class to type FstClass."; + return 0; + } + + template<class Arc> + static FstClassImplBase *Create() { + LOG(ERROR) << "Doesn't make sense to create an FstClass with a " + << "particular arc type."; + return 0; + } + + + protected: + explicit FstClass(FstClassImplBase *impl) : impl_(impl) { } + + // Generic template method for reading an arc-templated FST of type + // UnderlyingT, and returning it wrapped as FstClassT, with appropriate + // error checking. Called from arc-templated Read() static methods. + template<class FstClassT, class UnderlyingT> + static FstClassT* ReadTypedFst(istream &stream, + const FstReadOptions &opts) { + UnderlyingT *u = UnderlyingT::Read(stream, opts); + if (!u) { + return 0; + } else { + FstClassT *r = new FstClassT(*u); + delete u; + return r; + } + } + + FstClassImplBase *GetImpl() const { return impl_; } + + FstClassImplBase *GetImpl() { return impl_; } + +// friend ostream &operator<<(ostream&, const FstClass&); + + private: + FstClassImplBase *impl_; +}; + +// +// Specific types of FstClass with special properties +// + +class MutableFstClass : public FstClass { + public: + template<class Arc> + explicit MutableFstClass(const MutableFst<Arc> &fst) : + FstClass(fst) { } + + template<class Arc> + MutableFst<Arc> *GetMutableFst() { + Fst<Arc> *fst = const_cast<Fst<Arc> *>(this->GetFst<Arc>()); + MutableFst<Arc> *mfst = static_cast<MutableFst<Arc> *>(fst); + + return mfst; + } + + template<class Arc> + static MutableFstClass *Read(istream &stream, + const FstReadOptions &opts) { + MutableFst<Arc> *mfst = MutableFst<Arc>::Read(stream, opts); + if (!mfst) { + return 0; + } else { + MutableFstClass *retval = new MutableFstClass(*mfst); + delete mfst; + return retval; + } + } + + virtual bool Write(const string &fname) const { + return GetImpl()->Write(fname); + } + + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { + return GetImpl()->Write(ostr, opts); + } + + static MutableFstClass *Read(const string &fname, bool convert = false); + + virtual void SetInputSymbols(SymbolTable *is) { + GetImpl()->SetInputSymbols(is); + } + + virtual void SetOutputSymbols(SymbolTable *os) { + GetImpl()->SetOutputSymbols(os); + } + + // These methods are required by IO registration + template<class Arc> + static FstClassImplBase *Convert(const FstClass &other) { + LOG(ERROR) << "Doesn't make sense to convert any class to type " + << "MutableFstClass."; + return 0; + } + + template<class Arc> + static FstClassImplBase *Create() { + LOG(ERROR) << "Doesn't make sense to create a MutableFstClass with a " + << "particular arc type."; + return 0; + } + + protected: + explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) { } +}; + + +class VectorFstClass : public MutableFstClass { + public: + explicit VectorFstClass(const FstClass &other); + explicit VectorFstClass(const string &arc_type); + + template<class Arc> + explicit VectorFstClass(const VectorFst<Arc> &fst) : + MutableFstClass(fst) { } + + template<class Arc> + static VectorFstClass *Read(istream &stream, + const FstReadOptions &opts) { + VectorFst<Arc> *vfst = VectorFst<Arc>::Read(stream, opts); + if (!vfst) { + return 0; + } else { + VectorFstClass *retval = new VectorFstClass(*vfst); + delete vfst; + return retval; + } + } + + static VectorFstClass *Read(const string &fname); + + // Converter / creator for known arc types + template<class Arc> + static FstClassImplBase *Convert(const FstClass &other) { + return new FstClassImpl<Arc>(new VectorFst<Arc>( + *other.GetFst<Arc>()), true); + } + + template<class Arc> + static FstClassImplBase *Create() { + return new FstClassImpl<Arc>(new VectorFst<Arc>(), true); + } +}; + +} // namespace script +} // namespace fst +#endif // FST_SCRIPT_FST_CLASS_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/fstscript-decl.h b/kaldi_io/src/tools/openfst/include/fst/script/fstscript-decl.h new file mode 100644 index 0000000..fee813e --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/fstscript-decl.h @@ -0,0 +1,35 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +// Forward declarations for the FST and FST-script classes. + +#ifndef FST_SCRIPT_FSTSCRIPT_DECL_H_ +#define FST_SCRIPT_FSTSCRIPT_DECL_H_ + +#include <fst/fst-decl.h> + +namespace fst { +namespace script { + +class FstClass; +class MutableFstClass; +class VectorFstClass; +class WeightClass; + +} // namespace script +} // namespace fst; + +#endif // FST_SCRIPT_FSTSCRIPT_DECL_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/fstscript.h b/kaldi_io/src/tools/openfst/include/fst/script/fstscript.h new file mode 100644 index 0000000..90e1e75 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/fstscript.h @@ -0,0 +1,154 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +// Convenience file that includes all FstScript functionality + +#ifndef FST_SCRIPT_FSTSCRIPT_H_ +#define FST_SCRIPT_FSTSCRIPT_H_ + +// Major classes +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> +#include <fst/script/text-io.h> + +// Templates like Operation< >, Apply< > +#include <fst/script/script-impl.h> + +// Operations +#include <fst/script/arcsort.h> +#include <fst/script/closure.h> +#include <fst/script/compile.h> +#include <fst/script/compose.h> +#include <fst/script/concat.h> +#include <fst/script/connect.h> +#include <fst/script/convert.h> +#include <fst/script/decode.h> +#include <fst/script/determinize.h> +#include <fst/script/difference.h> +#include <fst/script/draw.h> +#include <fst/script/encode.h> +#include <fst/script/epsnormalize.h> +#include <fst/script/equal.h> +#include <fst/script/equivalent.h> +#include <fst/script/info.h> +#include <fst/script/intersect.h> +#include <fst/script/invert.h> +#include <fst/script/map.h> +#include <fst/script/minimize.h> +#include <fst/script/print.h> +#include <fst/script/project.h> +#include <fst/script/prune.h> +#include <fst/script/push.h> +#include <fst/script/randequivalent.h> +#include <fst/script/randgen.h> +#include <fst/script/relabel.h> +#include <fst/script/replace.h> +#include <fst/script/reverse.h> +#include <fst/script/reweight.h> +#include <fst/script/rmepsilon.h> +#include <fst/script/shortest-distance.h> +#include <fst/script/shortest-path.h> +#include <fst/script/symbols.h> +#include <fst/script/synchronize.h> +#include <fst/script/topsort.h> +#include <fst/script/union.h> +#include <fst/script/verify.h> + +// +// REGISTER OPERATIONS +// + + +// This class is necessary because registering each of the operations +// separately overfills the stack, as there's so many of them. +namespace fst { +namespace script { +template<class Arc> +class AllFstOperationsRegisterer { + public: + AllFstOperationsRegisterer() { + RegisterBatch1(); + RegisterBatch2(); + } + + private: + void RegisterBatch1() { + REGISTER_FST_OPERATION(ArcSort, Arc, ArcSortArgs); + REGISTER_FST_OPERATION(Closure, Arc, ClosureArgs); + REGISTER_FST_OPERATION(CompileFst, Arc, FstCompileArgs); + REGISTER_FST_OPERATION(Compose, Arc, ComposeArgs1); + REGISTER_FST_OPERATION(Compose, Arc, ComposeArgs2); + REGISTER_FST_OPERATION(Concat, Arc, ConcatArgs1); + REGISTER_FST_OPERATION(Concat, Arc, ConcatArgs2); + REGISTER_FST_OPERATION(Connect, Arc, MutableFstClass); + REGISTER_FST_OPERATION(Convert, Arc, ConvertArgs); + REGISTER_FST_OPERATION(Decode, Arc, DecodeArgs); + REGISTER_FST_OPERATION(Determinize, Arc, DeterminizeArgs); + REGISTER_FST_OPERATION(Difference, Arc, DifferenceArgs1); + REGISTER_FST_OPERATION(Difference, Arc, DifferenceArgs2); + REGISTER_FST_OPERATION(DrawFst, Arc, FstDrawerArgs); + REGISTER_FST_OPERATION(Encode, Arc, EncodeArgs); + REGISTER_FST_OPERATION(EpsNormalize, Arc, EpsNormalizeArgs); + REGISTER_FST_OPERATION(Equal, Arc, EqualArgs); + REGISTER_FST_OPERATION(Equivalent, Arc, EquivalentArgs); + REGISTER_FST_OPERATION(PrintFstInfo, Arc, InfoArgs); + REGISTER_FST_OPERATION(Intersect, Arc, IntersectArgs1); + REGISTER_FST_OPERATION(Intersect, Arc, IntersectArgs2); + REGISTER_FST_OPERATION(Invert, Arc, MutableFstClass); + REGISTER_FST_OPERATION(Map, Arc, MapArgs); + REGISTER_FST_OPERATION(Minimize, Arc, MinimizeArgs); + } + + void RegisterBatch2() { + REGISTER_FST_OPERATION(PrintFst, Arc, FstPrinterArgs); + REGISTER_FST_OPERATION(Project, Arc, ProjectArgs); + REGISTER_FST_OPERATION(Prune, Arc, PruneArgs1); + REGISTER_FST_OPERATION(Prune, Arc, PruneArgs2); + REGISTER_FST_OPERATION(Prune, Arc, PruneArgs3); + REGISTER_FST_OPERATION(Prune, Arc, PruneArgs4); + REGISTER_FST_OPERATION(Push, Arc, PushArgs1); + REGISTER_FST_OPERATION(Push, Arc, PushArgs2); + REGISTER_FST_OPERATION(RandEquivalent, Arc, RandEquivalentArgs1); + REGISTER_FST_OPERATION(RandEquivalent, Arc, RandEquivalentArgs2); + REGISTER_FST_OPERATION(RandGen, Arc, RandGenArgs); + REGISTER_FST_OPERATION(Relabel, Arc, RelabelArgs1); + REGISTER_FST_OPERATION(Relabel, Arc, RelabelArgs2); + REGISTER_FST_OPERATION(Relabel, Arc, RelabelArgs3); + REGISTER_FST_OPERATION(Replace, Arc, ReplaceArgs); + REGISTER_FST_OPERATION(Reverse, Arc, ReverseArgs); + REGISTER_FST_OPERATION(Reweight, Arc, ReweightArgs); + REGISTER_FST_OPERATION(RmEpsilon, Arc, RmEpsilonArgs1); + REGISTER_FST_OPERATION(RmEpsilon, Arc, RmEpsilonArgs2); + REGISTER_FST_OPERATION(RmEpsilon, Arc, RmEpsilonArgs3); + REGISTER_FST_OPERATION(ShortestDistance, Arc, ShortestDistanceArgs1); + REGISTER_FST_OPERATION(ShortestDistance, Arc, ShortestDistanceArgs2); + REGISTER_FST_OPERATION(ShortestDistance, Arc, ShortestDistanceArgs3); + REGISTER_FST_OPERATION(ShortestPath, Arc, ShortestPathArgs1); + REGISTER_FST_OPERATION(ShortestPath, Arc, ShortestPathArgs2); + REGISTER_FST_OPERATION(Synchronize, Arc, SynchronizeArgs); + REGISTER_FST_OPERATION(TopSort, Arc, TopSortArgs); + REGISTER_FST_OPERATION(Union, Arc, UnionArgs); + REGISTER_FST_OPERATION(Verify, Arc, VerifyArgs); + } +}; +} // namespace script +} // namespace fst + + +#define REGISTER_FST_OPERATIONS(Arc) \ + AllFstOperationsRegisterer<Arc> register_all_fst_operations ## Arc; + +#endif // FST_SCRIPT_FSTSCRIPT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/info-impl.h b/kaldi_io/src/tools/openfst/include/fst/script/info-impl.h new file mode 100644 index 0000000..408fbcd --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/info-impl.h @@ -0,0 +1,325 @@ +// info.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to compute various information about FSTs, helper class for fstinfo.cc + +#ifndef FST_SCRIPT_INFO_IMPL_H_ +#define FST_SCRIPT_INFO_IMPL_H_ + +#include <string> +#include <vector> +using std::vector; + +#include <fst/connect.h> +#include <fst/dfs-visit.h> +#include <fst/fst.h> +#include <fst/lookahead-matcher.h> +#include <fst/matcher.h> +#include <fst/queue.h> +#include <fst/test-properties.h> +#include <fst/verify.h> +#include <fst/visit.h> + +namespace fst { + +// Compute various information about FSTs, helper class for fstinfo.cc. +// WARNING: Stand-alone use of this class is not recommended, most code +// should call directly the relevant library functions: Fst<A>::NumStates, +// Fst<A>::NumArcs, TestProperties, ... +template <class A> class FstInfo { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + // When info_type is "short" (or "auto" and not an ExpandedFst) + // then only minimal info is computed and can be requested. + FstInfo(const Fst<A> &fst, bool test_properties, + const string &arc_filter_type = "any", + string info_type = "auto", bool verify = true) + : fst_type_(fst.Type()), + input_symbols_(fst.InputSymbols() ? + fst.InputSymbols()->Name() : "none"), + output_symbols_(fst.OutputSymbols() ? + fst.OutputSymbols()->Name() : "none"), + nstates_(0), narcs_(0), start_(kNoStateId), nfinal_(0), + nepsilons_(0), niepsilons_(0), noepsilons_(0), + naccess_(0), ncoaccess_(0), nconnect_(0), ncc_(0), nscc_(0), + input_match_type_(MATCH_NONE), output_match_type_(MATCH_NONE), + input_lookahead_(false), output_lookahead_(false), + properties_(0), arc_filter_type_(arc_filter_type), long_info_(true) { + if (info_type == "long") { + long_info_ = true; + } else if (info_type == "short") { + long_info_ = false; + } else if (info_type == "auto") { + long_info_ = fst.Properties(kExpanded, false); + } else { + FSTERROR() << "Bad info type: " << info_type; + return; + } + + if (!long_info_) + return; + + // If the FST is not sane, we return. + if (verify && !Verify(fst)) { + FSTERROR() << "FstInfo: Verify: FST not well-formed."; + return; + } + + start_ = fst.Start(); + properties_ = fst.Properties(kFstProperties, test_properties); + + for (StateIterator< Fst<A> > siter(fst); + !siter.Done(); + siter.Next()) { + ++nstates_; + StateId s = siter.Value(); + if (fst.Final(s) != Weight::Zero()) + ++nfinal_; + for (ArcIterator< Fst<A> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + ++narcs_; + if (arc.ilabel == 0 && arc.olabel == 0) + ++nepsilons_; + if (arc.ilabel == 0) + ++niepsilons_; + if (arc.olabel == 0) + ++noepsilons_; + } + } + + { + vector<StateId> cc; + CcVisitor<Arc> cc_visitor(&cc); + FifoQueue<StateId> fifo_queue; + if (arc_filter_type == "any") { + Visit(fst, &cc_visitor, &fifo_queue); + } else if (arc_filter_type == "epsilon") { + Visit(fst, &cc_visitor, &fifo_queue, EpsilonArcFilter<Arc>()); + } else if (arc_filter_type == "iepsilon") { + Visit(fst, &cc_visitor, &fifo_queue, InputEpsilonArcFilter<Arc>()); + } else if (arc_filter_type == "oepsilon") { + Visit(fst, &cc_visitor, &fifo_queue, OutputEpsilonArcFilter<Arc>()); + } else { + FSTERROR() << "Bad arc filter type: " << arc_filter_type; + return; + } + + for (StateId s = 0; s < cc.size(); ++s) { + if (cc[s] >= ncc_) + ncc_ = cc[s] + 1; + } + } + + { + vector<StateId> scc; + vector<bool> access, coaccess; + uint64 props = 0; + SccVisitor<Arc> scc_visitor(&scc, &access, &coaccess, &props); + if (arc_filter_type == "any") { + DfsVisit(fst, &scc_visitor); + } else if (arc_filter_type == "epsilon") { + DfsVisit(fst, &scc_visitor, EpsilonArcFilter<Arc>()); + } else if (arc_filter_type == "iepsilon") { + DfsVisit(fst, &scc_visitor, InputEpsilonArcFilter<Arc>()); + } else if (arc_filter_type == "oepsilon") { + DfsVisit(fst, &scc_visitor, OutputEpsilonArcFilter<Arc>()); + } else { + FSTERROR() << "Bad arc filter type: " << arc_filter_type; + return; + } + + for (StateId s = 0; s < scc.size(); ++s) { + if (access[s]) + ++naccess_; + if (coaccess[s]) + ++ncoaccess_; + if (access[s] && coaccess[s]) + ++nconnect_; + if (scc[s] >= nscc_) + nscc_ = scc[s] + 1; + } + } + + LookAheadMatcher< Fst<A> > imatcher(fst, MATCH_INPUT); + input_match_type_ = imatcher.Type(test_properties); + input_lookahead_ = imatcher.Flags() & kInputLookAheadMatcher; + + LookAheadMatcher< Fst<A> > omatcher(fst, MATCH_OUTPUT); + output_match_type_ = omatcher.Type(test_properties); + output_lookahead_ = omatcher.Flags() & kOutputLookAheadMatcher; + } + + // Short info + const string& FstType() const { return fst_type_; } + const string& ArcType() const { return A::Type(); } + const string& InputSymbols() const { return input_symbols_; } + const string& OutputSymbols() const { return output_symbols_; } + const bool LongInfo() const { return long_info_; } + const string& ArcFilterType() const { return arc_filter_type_; } + + // Long info + MatchType InputMatchType() const { CheckLong(); return input_match_type_; } + MatchType OutputMatchType() const { CheckLong(); return output_match_type_; } + bool InputLookAhead() const { CheckLong(); return input_lookahead_; } + bool OutputLookAhead() const { CheckLong(); return output_lookahead_; } + int64 NumStates() const { CheckLong(); return nstates_; } + int64 NumArcs() const { CheckLong(); return narcs_; } + int64 Start() const { CheckLong(); return start_; } + int64 NumFinal() const { CheckLong(); return nfinal_; } + int64 NumEpsilons() const { CheckLong(); return nepsilons_; } + int64 NumInputEpsilons() const { CheckLong(); return niepsilons_; } + int64 NumOutputEpsilons() const { CheckLong(); return noepsilons_; } + int64 NumAccessible() const { CheckLong(); return naccess_; } + int64 NumCoAccessible() const { CheckLong(); return ncoaccess_; } + int64 NumConnected() const { CheckLong(); return nconnect_; } + int64 NumCc() const { CheckLong(); return ncc_; } + int64 NumScc() const { CheckLong(); return nscc_; } + uint64 Properties() const { CheckLong(); return properties_; } + + private: + void CheckLong() const { + if (!long_info_) + FSTERROR() << "FstInfo: method only available with long info version"; + } + + string fst_type_; + string input_symbols_; + string output_symbols_; + int64 nstates_; + int64 narcs_; + int64 start_; + int64 nfinal_; + int64 nepsilons_; + int64 niepsilons_; + int64 noepsilons_; + int64 naccess_; + int64 ncoaccess_; + int64 nconnect_; + int64 ncc_; + int64 nscc_; + MatchType input_match_type_; + MatchType output_match_type_; + bool input_lookahead_; + bool output_lookahead_; + uint64 properties_; + string arc_filter_type_; + bool long_info_; + DISALLOW_COPY_AND_ASSIGN(FstInfo); +}; + +template <class A> +void PrintFstInfo(const FstInfo<A> &fstinfo, bool pipe = false) { + ostream &os = pipe ? cerr : cout; + + ios_base::fmtflags old = os.setf(ios::left); + os.width(50); + os << "fst type" << fstinfo.FstType() << endl; + os.width(50); + os << "arc type" << fstinfo.ArcType() << endl; + os.width(50); + os << "input symbol table" << fstinfo.InputSymbols() << endl; + os.width(50); + os << "output symbol table" << fstinfo.OutputSymbols() << endl; + + if (!fstinfo.LongInfo()) { + os.setf(old); + return; + } + + os.width(50); + os << "# of states" << fstinfo.NumStates() << endl; + os.width(50); + os << "# of arcs" << fstinfo.NumArcs() << endl; + os.width(50); + os << "initial state" << fstinfo.Start() << endl; + os.width(50); + os << "# of final states" << fstinfo.NumFinal() << endl; + os.width(50); + os << "# of input/output epsilons" << fstinfo.NumEpsilons() << endl; + os.width(50); + os << "# of input epsilons" << fstinfo.NumInputEpsilons() << endl; + os.width(50); + os << "# of output epsilons" << fstinfo.NumOutputEpsilons() << endl; + os.width(50); + + string arc_type = ""; + if (fstinfo.ArcFilterType() == "epsilon") + arc_type = "epsilon "; + else if (fstinfo.ArcFilterType() == "iepsilon") + arc_type = "input-epsilon "; + else if (fstinfo.ArcFilterType() == "oepsilon") + arc_type = "output-epsilon "; + + string accessible_label = "# of " + arc_type + "accessible states"; + os.width(50); + os << accessible_label << fstinfo.NumAccessible() << endl; + string coaccessible_label = "# of " + arc_type + "coaccessible states"; + os.width(50); + os << coaccessible_label << fstinfo.NumCoAccessible() << endl; + string connected_label = "# of " + arc_type + "connected states"; + os.width(50); + os << connected_label << fstinfo.NumConnected() << endl; + string numcc_label = "# of " + arc_type + "connected components"; + os.width(50); + os << numcc_label << fstinfo.NumCc() << endl; + string numscc_label = "# of " + arc_type + "strongly conn components"; + os.width(50); + os << numscc_label << fstinfo.NumScc() << endl; + + os.width(50); + os << "input matcher" + << (fstinfo.InputMatchType() == MATCH_INPUT ? 'y' : + fstinfo.InputMatchType() == MATCH_NONE ? 'n' : '?') << endl; + os.width(50); + os << "output matcher" + << (fstinfo.OutputMatchType() == MATCH_OUTPUT ? 'y' : + fstinfo.OutputMatchType() == MATCH_NONE ? 'n' : '?') << endl; + os.width(50); + os << "input lookahead" + << (fstinfo.InputLookAhead() ? 'y' : 'n') << endl; + os.width(50); + os << "output lookahead" + << (fstinfo.OutputLookAhead() ? 'y' : 'n') << endl; + + uint64 prop = 1; + for (int i = 0; i < 64; ++i, prop <<= 1) { + if (prop & kBinaryProperties) { + char value = 'n'; + if (fstinfo.Properties() & prop) value = 'y'; + os.width(50); + os << PropertyNames[i] << value << endl; + } else if (prop & kPosTrinaryProperties) { + char value = '?'; + if (fstinfo.Properties() & prop) value = 'y'; + else if (fstinfo.Properties() & prop << 1) value = 'n'; + os.width(50); + os << PropertyNames[i] << value << endl; + } + } + os.setf(old); +} + +} // namespace fst + +#endif // FST_SCRIPT_INFO_IMPL_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/info.h b/kaldi_io/src/tools/openfst/include/fst/script/info.h new file mode 100644 index 0000000..f434bd5 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/info.h @@ -0,0 +1,48 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_INFO_H_ +#define FST_SCRIPT_INFO_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/info-impl.h> + +namespace fst { +namespace script { + +typedef args::Package<const FstClass&, bool, const string&, + const string&, bool, bool> InfoArgs; + +template<class Arc> +void PrintFstInfo(InfoArgs *args) { + const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); + FstInfo<Arc> fstinfo(fst, args->arg2, args->arg3, + args->arg4, args->arg5); + PrintFstInfo(fstinfo, args->arg6); + + if (args->arg6) + fst.Write(""); +} + +void PrintFstInfo(const FstClass &f, bool test_properties, + const string &arc_filter, const string &info_type, + bool pipe, bool verify); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_INFO_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/intersect.h b/kaldi_io/src/tools/openfst/include/fst/script/intersect.h new file mode 100644 index 0000000..8011024 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/intersect.h @@ -0,0 +1,65 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_INTERSECT_H_ +#define FST_SCRIPT_INTERSECT_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/intersect.h> +#include <fst/script/compose.h> // for ComposeOptions, ComposeFilter + +namespace fst { +namespace script { + +typedef args::Package<const FstClass&, const FstClass&, + MutableFstClass*, ComposeFilter> IntersectArgs1; + +template<class Arc> +void Intersect(IntersectArgs1 *args) { + const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>()); + const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); + + Intersect(ifst1, ifst2, ofst, args->arg4); +} + +typedef args::Package<const FstClass&, const FstClass&, + MutableFstClass*, const ComposeOptions &> IntersectArgs2; + +template<class Arc> +void Intersect(IntersectArgs2 *args) { + const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>()); + const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); + + Intersect(ifst1, ifst2, ofst, args->arg4); +} + +void Intersect(const FstClass &ifst1, const FstClass &ifst2, + MutableFstClass *ofst, + ComposeFilter compose_filter); + +void Intersect(const FstClass &ifst, const FstClass &ifst2, + MutableFstClass *ofst, + const ComposeOptions &opts = fst::script::ComposeOptions()); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_INTERSECT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/invert.h b/kaldi_io/src/tools/openfst/include/fst/script/invert.h new file mode 100644 index 0000000..1befd9f --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/invert.h @@ -0,0 +1,43 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_INVERT_H_ +#define FST_SCRIPT_INVERT_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/invert.h> + +namespace fst { +namespace script { + +// The following confuses swig, because it has the same arguments +// as the non-templated version +#ifndef SWIG +template<class Arc> +void Invert(MutableFstClass *fst) { + MutableFst<Arc> *typed_fst = fst->GetMutableFst<Arc>(); + + Invert(typed_fst); +} +#endif + +void Invert(MutableFstClass *fst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_INVERT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/map.h b/kaldi_io/src/tools/openfst/include/fst/script/map.h new file mode 100644 index 0000000..3caaa9f --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/map.h @@ -0,0 +1,123 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_MAP_H_ +#define FST_SCRIPT_MAP_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> +#include <fst/arc-map.h> +#include <fst/state-map.h> + +namespace fst { +namespace script { + +template <class M> +Fst<typename M::ToArc> *ArcMap(const Fst<typename M::FromArc> &fst, + const M &mapper) { + typedef typename M::ToArc ToArc; + VectorFst<ToArc> *ofst = new VectorFst<ToArc>; + ArcMap(fst, ofst, mapper); + return ofst; +} + +template <class M> +Fst<typename M::ToArc> *StateMap(const Fst<typename M::FromArc> &fst, + const M &mapper) { + typedef typename M::ToArc ToArc; + VectorFst<ToArc> *ofst = new VectorFst<ToArc>; + StateMap(fst, ofst, mapper); + return ofst; +} + +enum MapType { ARC_SUM_MAPPER, IDENTITY_MAPPER, INVERT_MAPPER, PLUS_MAPPER, + QUANTIZE_MAPPER, RMWEIGHT_MAPPER, SUPERFINAL_MAPPER, + TIMES_MAPPER, TO_LOG_MAPPER, TO_LOG64_MAPPER, TO_STD_MAPPER }; + +typedef args::Package<const FstClass&, MapType, float, + const WeightClass &> MapInnerArgs; +typedef args::WithReturnValue<FstClass*, MapInnerArgs> MapArgs; + +template <class Arc> +void Map(MapArgs *args) { + const Fst<Arc> &ifst = *(args->args.arg1.GetFst<Arc>()); + MapType map_type = args->args.arg2; + float delta = args->args.arg3; + typename Arc::Weight w = *(args->args.arg4.GetWeight<typename Arc::Weight>()); + + Fst<Arc> *fst = NULL; + Fst<LogArc> *lfst = NULL; + Fst<Log64Arc> *l64fst = NULL; + Fst<StdArc> *sfst = NULL; + if (map_type == ARC_SUM_MAPPER) { + args->retval = new FstClass(*(fst = + script::StateMap(ifst, ArcSumMapper<Arc>(ifst)))); + } else if (map_type == IDENTITY_MAPPER) { + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, IdentityArcMapper<Arc>()))); + } else if (map_type == INVERT_MAPPER) { + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, InvertWeightMapper<Arc>()))); + } else if (map_type == PLUS_MAPPER) { + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, PlusMapper<Arc>(w)))); + } else if (map_type == QUANTIZE_MAPPER) { + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, QuantizeMapper<Arc>(delta)))); + } else if (map_type == RMWEIGHT_MAPPER) { + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, RmWeightMapper<Arc>()))); + } else if (map_type == SUPERFINAL_MAPPER) { + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, SuperFinalMapper<Arc>()))); + } else if (map_type == TIMES_MAPPER) { + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, TimesMapper<Arc>(w)))); + } else if (map_type == TO_LOG_MAPPER) { + args->retval = new FstClass(*(lfst = + script::ArcMap(ifst, WeightConvertMapper<Arc, LogArc>()))); + } else if (map_type == TO_LOG64_MAPPER) { + args->retval = new FstClass(*(l64fst = + script::ArcMap(ifst, WeightConvertMapper<Arc, Log64Arc>()))); + } else if (map_type == TO_STD_MAPPER) { + args->retval = new FstClass(*(sfst = + script::ArcMap(ifst, WeightConvertMapper<Arc, StdArc>()))); + } else { + FSTERROR() << "Error: unknown/unsupported mapper type: " + << map_type; + VectorFst<Arc> *ofst = new VectorFst<Arc>; + ofst->SetProperties(kError, kError); + args->retval = new FstClass(*(fst =ofst)); + } + delete sfst; + delete l64fst; + delete lfst; + delete fst; +} + + +#ifdef SWIG +%newobject Map; +#endif +FstClass *Map(const FstClass& f, MapType map_type, + float delta = fst::kDelta, + const WeightClass &w = fst::script::WeightClass::Zero()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_MAP_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/minimize.h b/kaldi_io/src/tools/openfst/include/fst/script/minimize.h new file mode 100644 index 0000000..f250d03 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/minimize.h @@ -0,0 +1,45 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_MINIMIZE_H_ +#define FST_SCRIPT_MINIMIZE_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/minimize.h> + +namespace fst { +namespace script { + +typedef args::Package<MutableFstClass*, MutableFstClass*, float> MinimizeArgs; + +template<class Arc> +void Minimize(MinimizeArgs *args) { + MutableFst<Arc> *ofst1 = args->arg1->GetMutableFst<Arc>(); + MutableFst<Arc> *ofst2 = args->arg2 ? args->arg2->GetMutableFst<Arc>() : 0; + + Minimize(ofst1, ofst2, args->arg3); +} + +void Minimize(MutableFstClass *ofst1, MutableFstClass *ofst2 = 0, + float delta = kDelta); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_MINIMIZE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/print-impl.h b/kaldi_io/src/tools/openfst/include/fst/script/print-impl.h new file mode 100644 index 0000000..1433a29 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/print-impl.h @@ -0,0 +1,149 @@ +// print.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Stand-alone class to print out binary FSTs in the AT&T format, +// helper class for fstprint.cc + +#ifndef FST_SCRIPT_PRINT_IMPL_H_ +#define FST_SCRIPT_PRINT_IMPL_H_ + +#include <sstream> +#include <string> + +#include <fst/fst.h> +#include <fst/util.h> + +DECLARE_string(fst_field_separator); + +namespace fst { + +// Print a binary Fst in textual format, helper class for fstprint.cc +// WARNING: Stand-alone use of this class not recommended, most code should +// read/write using the binary format which is much more efficient. +template <class A> class FstPrinter { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + FstPrinter(const Fst<A> &fst, + const SymbolTable *isyms, + const SymbolTable *osyms, + const SymbolTable *ssyms, + bool accep, + bool show_weight_one) + : fst_(fst), isyms_(isyms), osyms_(osyms), ssyms_(ssyms), + accep_(accep && fst.Properties(kAcceptor, true)), ostrm_(0), + show_weight_one_(show_weight_one) {} + + // Print Fst to an output stream + void Print(ostream *ostrm, const string &dest) { + ostrm_ = ostrm; + dest_ = dest; + StateId start = fst_.Start(); + if (start == kNoStateId) + return; + // initial state first + PrintState(start); + for (StateIterator< Fst<A> > siter(fst_); + !siter.Done(); + siter.Next()) { + StateId s = siter.Value(); + if (s != start) + PrintState(s); + } + } + + private: + // Maximum line length in text file. + static const int kLineLen = 8096; + + void PrintId(int64 id, const SymbolTable *syms, + const char *name) const { + if (syms) { + string symbol = syms->Find(id); + if (symbol == "") { + FSTERROR() << "FstPrinter: Integer " << id + << " is not mapped to any textual symbol" + << ", symbol table = " << syms->Name() + << ", destination = " << dest_; + symbol = "?"; + } + *ostrm_ << symbol; + } else { + *ostrm_ << id; + } + } + + void PrintStateId(StateId s) const { + PrintId(s, ssyms_, "state ID"); + } + + void PrintILabel(Label l) const { + PrintId(l, isyms_, "arc input label"); + } + + void PrintOLabel(Label l) const { + PrintId(l, osyms_, "arc output label"); + } + + void PrintState(StateId s) const { + bool output = false; + for (ArcIterator< Fst<A> > aiter(fst_, s); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + PrintStateId(s); + *ostrm_ << FLAGS_fst_field_separator[0]; + PrintStateId(arc.nextstate); + *ostrm_ << FLAGS_fst_field_separator[0]; + PrintILabel(arc.ilabel); + if (!accep_) { + *ostrm_ << FLAGS_fst_field_separator[0]; + PrintOLabel(arc.olabel); + } + if (show_weight_one_ || arc.weight != Weight::One()) + *ostrm_ << FLAGS_fst_field_separator[0] << arc.weight; + *ostrm_ << "\n"; + output = true; + } + Weight final = fst_.Final(s); + if (final != Weight::Zero() || !output) { + PrintStateId(s); + if (show_weight_one_ || final != Weight::One()) { + *ostrm_ << FLAGS_fst_field_separator[0] << final; + } + *ostrm_ << "\n"; + } + } + + const Fst<A> &fst_; + const SymbolTable *isyms_; // ilabel symbol table + const SymbolTable *osyms_; // olabel symbol table + const SymbolTable *ssyms_; // slabel symbol table + bool accep_; // print as acceptor when possible + ostream *ostrm_; // text FST destination + string dest_; // text FST destination name + bool show_weight_one_; // print weights equal to Weight::One() + DISALLOW_COPY_AND_ASSIGN(FstPrinter); +}; + +} // namespace fst + +#endif // FST_SCRIPT_PRINT_IMPL_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/print.h b/kaldi_io/src/tools/openfst/include/fst/script/print.h new file mode 100644 index 0000000..f82b19b --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/print.h @@ -0,0 +1,86 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_PRINT_H_ +#define FST_SCRIPT_PRINT_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/print-impl.h> + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! +struct FstPrinterArgs { + const FstClass &fst; + const SymbolTable *isyms; + const SymbolTable *osyms; + const SymbolTable *ssyms; + const bool accept; + const bool show_weight_one; + ostream *ostrm; + const string &dest; + + FstPrinterArgs(const FstClass &fst, + const SymbolTable *isyms, + const SymbolTable *osyms, + const SymbolTable *ssyms, + bool accept, + bool show_weight_one, + ostream *ostrm, + const string &dest) : + fst(fst), isyms(isyms), osyms(osyms), ssyms(ssyms), accept(accept), + show_weight_one(show_weight_one), ostrm(ostrm), dest(dest) { } +}; + +template<class Arc> +void PrintFst(FstPrinterArgs *args) { + const Fst<Arc> &fst = *(args->fst.GetFst<Arc>()); + + fst::FstPrinter<Arc> fstprinter(fst, args->isyms, args->osyms, + args->ssyms, args->accept, + args->show_weight_one); + fstprinter.Print(args->ostrm, args->dest); +} + +void PrintFst(const FstClass &fst, ostream &ostrm, const string &dest, + const SymbolTable *isyms, + const SymbolTable *osyms, + const SymbolTable *ssyms, + bool accept, bool show_weight_one); + + +// Below are two printing methods with useful defaults for a few of +// the fst printer arguments. +template <class Arc> +void PrintFst(const Fst<Arc> &fst, ostream &os, const string dest = "", + const SymbolTable *isyms = NULL, + const SymbolTable *osyms = NULL, + const SymbolTable *ssyms = NULL) { + fst::FstPrinter<Arc> fstprinter(fst, isyms, osyms, ssyms, true, true); + fstprinter.Print(&os, dest); +} + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_PRINT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/project.h b/kaldi_io/src/tools/openfst/include/fst/script/project.h new file mode 100644 index 0000000..12ee890 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/project.h @@ -0,0 +1,43 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_PROJECT_H_ +#define FST_SCRIPT_PROJECT_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/project.h> // for ProjectType + +namespace fst { +namespace script { + +typedef args::Package<MutableFstClass*, ProjectType> ProjectArgs; + +template<class Arc> +void Project(ProjectArgs *args) { + MutableFst<Arc> *ofst = args->arg1->GetMutableFst<Arc>(); + + Project(ofst, args->arg2); +} + +void Project(MutableFstClass *ofst, ProjectType project_type); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_PROJECT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/prune.h b/kaldi_io/src/tools/openfst/include/fst/script/prune.h new file mode 100644 index 0000000..7118ff1 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/prune.h @@ -0,0 +1,153 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_PRUNE_H_ +#define FST_SCRIPT_PRUNE_H_ + +#include <vector> +using std::vector; + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> +#include <fst/prune.h> +#include <fst/arcfilter.h> + +namespace fst { +namespace script { + +struct PruneOptions { + WeightClass weight_threshold; + int64 state_threshold; + const vector<WeightClass> *distance; + float delta; + + explicit PruneOptions(const WeightClass& w, int64 s, + vector<WeightClass> *d = 0, float e = kDelta) + : weight_threshold(w), + state_threshold(s), + distance(d), + delta(e) {} + private: + PruneOptions(); // disallow +}; + +// converts a script::PruneOptions into a fst::PruneOptions. +// Notes: +// If the original opts.distance is not NULL, a new distance will be +// created with new; it's the client's responsibility to delete this. + +template<class A> +fst::PruneOptions<A, AnyArcFilter<A> > ConvertPruneOptions( + const PruneOptions &opts) { + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + Weight weight_threshold = *(opts.weight_threshold.GetWeight<Weight>()); + StateId state_threshold = opts.state_threshold; + vector<Weight> *distance = 0; + + if (opts.distance) { + distance = new vector<Weight>(opts.distance->size()); + for (unsigned i = 0; i < opts.distance->size(); ++i) { + (*distance)[i] = *((*opts.distance)[i].GetWeight<Weight>()); + } + } + + return fst::PruneOptions<A, AnyArcFilter<A> >( + weight_threshold, state_threshold, AnyArcFilter<A>(), distance, + opts.delta); +} + +// 1 +typedef args::Package<MutableFstClass *, const PruneOptions &> PruneArgs1; + +template<class Arc> +void Prune(PruneArgs1 *args) { + MutableFst<Arc> *ofst = args->arg1->GetMutableFst<Arc>(); + + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + fst::PruneOptions<Arc, AnyArcFilter<Arc> > opts = + ConvertPruneOptions<Arc>(args->arg2); + Prune(ofst, opts); + delete opts.distance; +} + +// 2 +typedef args::Package<const FstClass &, MutableFstClass *, + const PruneOptions &> PruneArgs2; + +template<class Arc> +void Prune(PruneArgs2 *args) { + const Fst<Arc>& ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + + fst::PruneOptions<Arc, AnyArcFilter<Arc> > opts = + ConvertPruneOptions<Arc>(args->arg3); + Prune(ifst, ofst, opts); + delete opts.distance; +} + +// 3 +typedef args::Package<const FstClass &, + MutableFstClass *, + const WeightClass &, int64, float> PruneArgs3; + +template<class Arc> +void Prune(PruneArgs3 *args) { + const Fst<Arc>& ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + typename Arc::Weight w = *(args->arg3.GetWeight<typename Arc::Weight>()); + + Prune(ifst, ofst, w, args->arg4, args->arg5); +} + +// 4 +typedef args::Package<MutableFstClass *, const WeightClass&, + int64, float> PruneArgs4; +template<class Arc> +void Prune(PruneArgs4 *args) { + MutableFst<Arc> *fst = args->arg1->GetMutableFst<Arc>(); + typename Arc::Weight w = *(args->arg2.GetWeight<typename Arc::Weight>()); + Prune(fst, w, args->arg3, args->arg4); +} + + +// 1 +void Prune(MutableFstClass *fst, const PruneOptions &opts); + +// 2 +void Prune(const FstClass &ifst, MutableFstClass *fst, + const PruneOptions &opts); + +// 3 +void Prune(const FstClass &ifst, MutableFstClass *ofst, + const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, + float delta = kDelta); + +// 4 +void Prune(MutableFstClass *fst, const WeightClass& weight_threshold, + int64 state_threshold, float delta); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_PRUNE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/push.h b/kaldi_io/src/tools/openfst/include/fst/script/push.h new file mode 100644 index 0000000..cebd655 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/push.h @@ -0,0 +1,70 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_PUSH_H_ +#define FST_SCRIPT_PUSH_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/push.h> + +namespace fst { +namespace script { + +// 1 +typedef args::Package<MutableFstClass*, ReweightType, float, bool> PushArgs1; + +template<class Arc> +void Push(PushArgs1 *args) { + MutableFst<Arc> *ofst = args->arg1->GetMutableFst<Arc>(); + + if (args->arg2 == REWEIGHT_TO_FINAL) { + fst::Push(ofst, REWEIGHT_TO_FINAL, args->arg3, args->arg4); + } else { + fst::Push(ofst, REWEIGHT_TO_INITIAL, args->arg3, args->arg4); + } +} + +// 2 +typedef args::Package<const FstClass &, MutableFstClass *, uint32, + ReweightType, float> PushArgs2; + +template<class Arc> +void Push(PushArgs2 *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + + if (args->arg4 == REWEIGHT_TO_FINAL) { + fst::Push<Arc, REWEIGHT_TO_FINAL>(ifst, ofst, args->arg3, args->arg5); + } else { + fst::Push<Arc, REWEIGHT_TO_INITIAL>(ifst, ofst, args->arg3, args->arg5); + } +} + +// 1 +void Push(MutableFstClass *ofst, ReweightType type, float delta = kDelta, + bool remove_total_weight = false); + +// 2 +void Push(const FstClass &ifst, MutableFstClass *ofst, uint32 flags, + ReweightType dir, float delta); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_PUSH_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/randequivalent.h b/kaldi_io/src/tools/openfst/include/fst/script/randequivalent.h new file mode 100644 index 0000000..b929683 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/randequivalent.h @@ -0,0 +1,105 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_RANDEQUIVALENT_H_ +#define FST_SCRIPT_RANDEQUIVALENT_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/randgen.h> // for RandArcSelection +#include <fst/randequivalent.h> + +namespace fst { +namespace script { + +// 1 +typedef args::Package<const FstClass&, const FstClass&, + int32, float, int, int> RandEquivalentInnerArgs1; +typedef args::WithReturnValue<bool, + RandEquivalentInnerArgs1> RandEquivalentArgs1; + +template<class Arc> +void RandEquivalent(RandEquivalentArgs1 *args) { + const Fst<Arc> &fst1 = *(args->args.arg1.GetFst<Arc>()); + const Fst<Arc> &fst2 = *(args->args.arg2.GetFst<Arc>()); + + args->retval = RandEquivalent(fst1, fst2, args->args.arg3, args->args.arg4, + args->args.arg5, args->args.arg6); +} + +// 2 +typedef args::Package<const FstClass &, const FstClass &, int32, + ssize_t, float, + const RandGenOptions<RandArcSelection> &> + RandEquivalentInnerArgs2; + +typedef args::WithReturnValue<bool, + RandEquivalentInnerArgs2> RandEquivalentArgs2; + +template<class Arc> +void RandEquivalent(RandEquivalentArgs2 *args) { + const Fst<Arc> &fst1 = *(args->args.arg1.GetFst<Arc>()); + const Fst<Arc> &fst2 = *(args->args.arg2.GetFst<Arc>()); + const RandGenOptions<RandArcSelection> &opts = args->args.arg6; + int32 seed = args->args.arg3; + + if (opts.arc_selector == UNIFORM_ARC_SELECTOR) { + UniformArcSelector<Arc> arc_selector(seed); + RandGenOptions< UniformArcSelector<Arc> > + ropts(arc_selector, opts.max_length, opts.npath); + + args->retval = RandEquivalent(fst1, fst2, args->args.arg4, + args->args.arg5, ropts); + } else if (opts.arc_selector == FAST_LOG_PROB_ARC_SELECTOR) { + FastLogProbArcSelector<Arc> arc_selector(seed); + RandGenOptions< FastLogProbArcSelector<Arc> > + ropts(arc_selector, opts.max_length, opts.npath); + + args->retval = RandEquivalent(fst1, fst2, args->args.arg4, + args->args.arg5, ropts); + } else { + LogProbArcSelector<Arc> arc_selector(seed); + RandGenOptions< LogProbArcSelector<Arc> > + ropts(arc_selector, opts.max_length, opts.npath); + args->retval = RandEquivalent(fst1, fst2, args->args.arg4, + args->args.arg5, ropts); + } +} + + +// 1 +bool RandEquivalent(const FstClass &fst1, + const FstClass &fst2, + int32 seed = time(0), + ssize_t num_paths = 1, + float delta = fst::kDelta, + int path_length = INT_MAX); + +// 2 +bool RandEquivalent(const FstClass &fst1, + const FstClass &fst2, + int32 seed, + ssize_t num_paths, + float delta, + const fst::RandGenOptions< + fst::script::RandArcSelection> &opts); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_RANDEQUIVALENT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/randgen.h b/kaldi_io/src/tools/openfst/include/fst/script/randgen.h new file mode 100644 index 0000000..817f9c1 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/randgen.h @@ -0,0 +1,76 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_RANDGEN_H_ +#define FST_SCRIPT_RANDGEN_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/randgen.h> + +namespace fst { +namespace script { + +enum RandArcSelection { + UNIFORM_ARC_SELECTOR, + LOG_PROB_ARC_SELECTOR, + FAST_LOG_PROB_ARC_SELECTOR +}; + +typedef args::Package<const FstClass &, MutableFstClass*, int32, + const RandGenOptions<RandArcSelection> &> RandGenArgs; + +template<class Arc> +void RandGen(RandGenArgs *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + int32 seed = args->arg3; + const RandGenOptions<RandArcSelection> &opts = args->arg4; + + if (opts.arc_selector == UNIFORM_ARC_SELECTOR) { + UniformArcSelector<Arc> arc_selector(seed); + RandGenOptions< UniformArcSelector<Arc> > + ropts(arc_selector, opts.max_length, + opts.npath, opts.weighted); + RandGen(ifst, ofst, ropts); + } else if (opts.arc_selector == FAST_LOG_PROB_ARC_SELECTOR) { + FastLogProbArcSelector<Arc> arc_selector(seed); + RandGenOptions< FastLogProbArcSelector<Arc> > + ropts(arc_selector, opts.max_length, + opts.npath, opts.weighted); + RandGen(ifst, ofst, ropts); + } else { + LogProbArcSelector<Arc> arc_selector(seed); + RandGenOptions< LogProbArcSelector<Arc> > + ropts(arc_selector, opts.max_length, + opts.npath, opts.weighted); + RandGen(ifst, ofst, ropts); + } +} + + +// Client-facing prototype +void RandGen(const FstClass &ifst, MutableFstClass *ofst, int32 seed = time(0), + const RandGenOptions<RandArcSelection> &opts = + fst::RandGenOptions<fst::script::RandArcSelection>( + fst::script::UNIFORM_ARC_SELECTOR)); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_RANDGEN_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/register.h b/kaldi_io/src/tools/openfst/include/fst/script/register.h new file mode 100644 index 0000000..03e0e36 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/register.h @@ -0,0 +1,120 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_REGISTER_H_ +#define FST_SCRIPT_REGISTER_H_ + +#include <string> + +#include <fst/generic-register.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> + +// Holds methods and classes responsible for maintaining +// the register for FstClass arc types. + +namespace fst { +namespace script { + +// +// Registers for reading and converting various kinds of FST classes. +// + +// This class definition is to avoid a nested class definition inside +// the IORegistration struct. +template<class Reader, class Creator, class Converter> +struct FstClassRegEntry { + Reader reader; + Creator creator; + Converter converter; + + FstClassRegEntry(Reader r, Creator cr, Converter co) : + reader(r), creator(cr), converter(co) { } + FstClassRegEntry() : reader(0), creator(0), converter(0) { } +}; + +template<class Reader, class Creator, class Converter> +class FstClassIORegister + : public GenericRegister<string, + FstClassRegEntry<Reader, Creator, Converter>, + FstClassIORegister<Reader, Creator, + Converter> > { + public: + Reader GetReader(const string &arc_type) const { + return this->GetEntry(arc_type).reader; + } + + Creator GetCreator(const string &arc_type) const { + return this->GetEntry(arc_type).creator; + } + + Converter GetConverter(const string &arc_type) const { + return this->GetEntry(arc_type).converter; + } + + protected: + virtual string ConvertKeyToSoFilename( + const string& key) const { + string legal_type(key); + ConvertToLegalCSymbol(&legal_type); + + return legal_type + "-arc.so"; + } +}; + +// +// Struct containing everything needed to register a particular type +// of FST class (e.g. a plain FstClass, or a MutableFstClass, etc) +// +template<class FstClassType> +struct IORegistration { + typedef FstClassType *(*Reader)(istream &stream, + const FstReadOptions &opts); + + typedef FstClassImplBase *(*Creator)(); + typedef FstClassImplBase *(*Converter)(const FstClass &other); + + typedef FstClassRegEntry<Reader, Creator, Converter> Entry; + + // FST class Register + typedef FstClassIORegister<Reader, Creator, Converter> Register; + + // FST class Register-er + typedef GenericRegisterer<FstClassIORegister<Reader, Creator, Converter> > + Registerer; +}; + + +// +// REGISTRATION MACROS +// + +#define REGISTER_FST_CLASS(Class, Arc) \ + static IORegistration<Class>::Registerer Class ## _ ## Arc ## _registerer( \ + Arc::Type(), \ + IORegistration<Class>::Entry(Class::Read<Arc>, \ + Class::Create<Arc>, \ + Class::Convert<Arc>)) + +#define REGISTER_FST_CLASSES(Arc) \ + REGISTER_FST_CLASS(FstClass, Arc); \ + REGISTER_FST_CLASS(MutableFstClass, Arc); \ + REGISTER_FST_CLASS(VectorFstClass, Arc); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REGISTER_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/relabel.h b/kaldi_io/src/tools/openfst/include/fst/script/relabel.h new file mode 100644 index 0000000..6bbb4c5 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/relabel.h @@ -0,0 +1,102 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_RELABEL_H_ +#define FST_SCRIPT_RELABEL_H_ + +#include <utility> +using std::pair; using std::make_pair; +#include <algorithm> +#include <vector> +using std::vector; + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/relabel.h> + +namespace fst { +namespace script { + +// 1 +typedef args::Package<MutableFstClass *, + const SymbolTable *, const SymbolTable *, bool, + const SymbolTable *, const SymbolTable *, + bool> RelabelArgs1; + +template<class Arc> +void Relabel(RelabelArgs1 *args) { + MutableFst<Arc> *ofst = args->arg1->GetMutableFst<Arc>(); + + Relabel(ofst, args->arg2, args->arg3, args->arg4, + args->arg5, args->arg6, args->arg7); +} + +// 2 +typedef args::Package<MutableFstClass*, + const vector<pair<int64, int64> > &, + const vector<pair<int64, int64> > > RelabelArgs2; + +template<class Arc> +void Relabel(RelabelArgs2 *args) { + MutableFst<Arc> *ofst = args->arg1->GetMutableFst<Arc>(); + + // In case int64 is not the same as Arc::Label, + // copy the reassignments + typedef typename Arc::Label Label; + + vector<pair<Label, Label> > converted_ipairs(args->arg2.size()); + copy(args->arg2.begin(), args->arg2.end(), converted_ipairs.begin()); + + vector<pair<Label, Label> > converted_opairs(args->arg3.size()); + copy(args->arg3.begin(), args->arg3.end(), converted_opairs.begin()); + + Relabel(ofst, converted_ipairs, converted_opairs); +} + +// 3 +typedef args::Package<MutableFstClass*, const SymbolTable*, + const SymbolTable*> RelabelArgs3; +template<class Arc> +void Relabel(args::Package<MutableFstClass*, const SymbolTable*, + const SymbolTable*> *args) { + MutableFst<Arc> *fst = args->arg1->GetMutableFst<Arc>(); + Relabel(fst, args->arg2, args->arg3); +} + + +// 1 +void Relabel(MutableFstClass *ofst, + const SymbolTable *old_isyms, const SymbolTable *relabel_isyms, + bool attach_new_isyms, + const SymbolTable *old_osyms, const SymbolTable *relabel_osyms, + bool attch_new_osyms); + +// 2 +void Relabel(MutableFstClass *ofst, + const vector<pair<int64, int64> > &ipairs, + const vector<pair<int64, int64> > &opairs); + + +// 3 +void Relabel(MutableFstClass *fst, + const SymbolTable *new_isymbols, + const SymbolTable *new_osymbols); + + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_RELABEL_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/replace.h b/kaldi_io/src/tools/openfst/include/fst/script/replace.h new file mode 100644 index 0000000..5eaf5bf --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/replace.h @@ -0,0 +1,62 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_REPLACE_H_ +#define FST_SCRIPT_REPLACE_H_ + +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/replace.h> + +namespace fst { +namespace script { + +typedef args::Package<const vector<pair<int64, const FstClass *> > &, + MutableFstClass *, const int64, bool> ReplaceArgs; + +template<class Arc> +void Replace(ReplaceArgs *args) { + // Now that we know the arc type, we construct a vector of + // pair<real label, real fst> that the real Replace will use + const vector<pair<int64, const FstClass *> >& untyped_tuples = + args->arg1; + + vector<pair<typename Arc::Label, const Fst<Arc> *> > fst_tuples( + untyped_tuples.size()); + + for (unsigned i = 0; i < untyped_tuples.size(); ++i) { + fst_tuples[i].first = untyped_tuples[i].first; // convert label + fst_tuples[i].second = untyped_tuples[i].second->GetFst<Arc>(); + } + + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + + Replace(fst_tuples, ofst, args->arg3, args->arg4); +} + +void Replace(const vector<pair<int64, const FstClass *> > &tuples, + MutableFstClass *ofst, const int64 &root, + bool epsilon_on_replace = false); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REPLACE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/reverse.h b/kaldi_io/src/tools/openfst/include/fst/script/reverse.h new file mode 100644 index 0000000..3930875 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/reverse.h @@ -0,0 +1,42 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_REVERSE_H_ +#define FST_SCRIPT_REVERSE_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/reverse.h> + +namespace fst { +namespace script { + +typedef args::Package<const FstClass &, MutableFstClass *> ReverseArgs; + +template<class Arc> +void Reverse(ReverseArgs *args) { + const Fst<Arc> &fst1 = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *fst2 = args->arg2->GetMutableFst<Arc>(); + + Reverse(fst1, fst2); +} + +void Reverse(const FstClass &fst1, MutableFstClass *fst2); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REVERSE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/reweight.h b/kaldi_io/src/tools/openfst/include/fst/script/reweight.h new file mode 100644 index 0000000..7bce839 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/reweight.h @@ -0,0 +1,53 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_REWEIGHT_H_ +#define FST_SCRIPT_REWEIGHT_H_ + +#include <vector> +using std::vector; + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> +#include <fst/reweight.h> + +namespace fst { +namespace script { + +typedef args::Package<MutableFstClass *, const vector<WeightClass> &, + ReweightType> ReweightArgs; + +template<class Arc> +void Reweight(ReweightArgs *args) { + MutableFst<Arc> *fst = args->arg1->GetMutableFst<Arc>(); + typedef typename Arc::Weight Weight; + vector<Weight> potentials(args->arg2.size()); + + for (unsigned i = 0; i < args->arg2.size(); ++i) { + potentials[i] = *(args->arg2[i].GetWeight<Weight>()); + } + + Reweight(fst, potentials, args->arg3); +} + +void Reweight(MutableFstClass *fst, const vector<WeightClass> &potential, + ReweightType reweight_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REWEIGHT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/rmepsilon.h b/kaldi_io/src/tools/openfst/include/fst/script/rmepsilon.h new file mode 100644 index 0000000..62fed03 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/rmepsilon.h @@ -0,0 +1,211 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_RMEPSILON_H_ +#define FST_SCRIPT_RMEPSILON_H_ + +#include <vector> +using std::vector; + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> +#include <fst/script/shortest-distance.h> // for ShortestDistanceOptions +#include <fst/rmepsilon.h> +#include <fst/queue.h> + +// the following is necessary, or SWIG complains mightily about +// shortestdistanceoptions not being defined before being used as a base. +#ifdef SWIG +%include "nlp/fst/script/shortest-distance.h" +#endif + + +namespace fst { +namespace script { + +// +// OPTIONS +// + +struct RmEpsilonOptions : public fst::script::ShortestDistanceOptions { + bool connect; + WeightClass weight_threshold; + int64 state_threshold; + + RmEpsilonOptions(QueueType qt = AUTO_QUEUE, float d = kDelta, bool c = true, + WeightClass w = fst::script::WeightClass::Zero(), + int64 n = kNoStateId) + : ShortestDistanceOptions(qt, EPSILON_ARC_FILTER, + kNoStateId, d), + connect(c), weight_threshold(w), state_threshold(n) { } +}; + + +// +// TEMPLATES +// + +// this function takes care of transforming a script-land RmEpsilonOptions +// into a lib-land RmEpsilonOptions +template<class Arc> +void RmEpsilonHelper(MutableFst<Arc> *fst, + vector<typename Arc::Weight> *distance, + const RmEpsilonOptions &opts) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + typename Arc::Weight weight_thresh = + *(opts.weight_threshold.GetWeight<Weight>()); + + switch (opts.queue_type) { + case AUTO_QUEUE: { + AutoQueue<StateId> queue(*fst, distance, EpsilonArcFilter<Arc>()); + fst::RmEpsilonOptions<Arc, AutoQueue<StateId> > ropts( + &queue, opts.delta, opts.connect, weight_thresh, + opts.state_threshold); + RmEpsilon(fst, distance, ropts); + break; + } + case FIFO_QUEUE: { + FifoQueue<StateId> queue; + fst::RmEpsilonOptions<Arc, FifoQueue<StateId> > ropts( + &queue, opts.delta, opts.connect, weight_thresh, + opts.state_threshold); + RmEpsilon(fst, distance, ropts); + break; + } + case LIFO_QUEUE: { + LifoQueue<StateId> queue; + fst::RmEpsilonOptions<Arc, LifoQueue<StateId> > ropts( + &queue, opts.delta, opts.connect, weight_thresh, + opts.state_threshold); + RmEpsilon(fst, distance, ropts); + break; + } + case SHORTEST_FIRST_QUEUE: { + NaturalShortestFirstQueue<StateId, Weight> queue(*distance); + fst::RmEpsilonOptions<Arc, NaturalShortestFirstQueue<StateId, + Weight> > ropts( + &queue, opts.delta, opts.connect, weight_thresh, + opts.state_threshold); + RmEpsilon(fst, distance, ropts); + break; + } + case STATE_ORDER_QUEUE: { + StateOrderQueue<StateId> queue; + fst::RmEpsilonOptions<Arc, StateOrderQueue<StateId> > ropts( + &queue, opts.delta, opts.connect, weight_thresh, + opts.state_threshold); + RmEpsilon(fst, distance, ropts); + break; + } + case TOP_ORDER_QUEUE: { + TopOrderQueue<StateId> queue(*fst, EpsilonArcFilter<Arc>()); + fst::RmEpsilonOptions<Arc, TopOrderQueue<StateId> > ropts( + &queue, opts.delta, opts.connect, weight_thresh, + opts.state_threshold); + RmEpsilon(fst, distance, ropts); + break; + } + default: + FSTERROR() << "Unknown or unsupported queue type: " << opts.queue_type; + fst->SetProperties(kError, kError); + } +} + +// 1 +typedef args::Package<const FstClass &, MutableFstClass *, + bool, const RmEpsilonOptions &> RmEpsilonArgs1; + +template<class Arc> +void RmEpsilon(RmEpsilonArgs1 *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + vector<typename Arc::Weight> distance; + bool reverse = args->arg3; + + if (reverse) { + VectorFst<Arc> rfst; + Reverse(ifst, &rfst); + RmEpsilonHelper(&rfst, &distance, args->arg4); + Reverse(rfst, ofst); + } else { + *ofst = ifst; + } + RmEpsilonHelper(ofst, &distance, args->arg4); +} + +// 2 +typedef args::Package<MutableFstClass *, bool, + const WeightClass, int64, + float> RmEpsilonArgs2; + +template<class Arc> +void RmEpsilon(RmEpsilonArgs2 *args) { + MutableFst<Arc> *fst = args->arg1->GetMutableFst<Arc>(); + typename Arc::Weight w = *(args->arg3.GetWeight<typename Arc::Weight>()); + + RmEpsilon(fst, args->arg2, w, args->arg4, args->arg5); +} + +// 3 +typedef args::Package<MutableFstClass *, vector<WeightClass> *, + const RmEpsilonOptions &> RmEpsilonArgs3; + +template<class Arc> +void RmEpsilon(RmEpsilonArgs3 *args) { + MutableFst<Arc> *fst = args->arg1->GetMutableFst<Arc>(); + const RmEpsilonOptions &opts = args->arg3; + + vector<typename Arc::Weight> weights; + + RmEpsilonHelper(fst, &weights, opts); + + // Copy the weights back + args->arg2->resize(weights.size()); + for (unsigned i = 0; i < weights.size(); ++i) { + (*args->arg2)[i] = WeightClass(weights[i]); + } +} + +// +// PROTOTYPES +// + +// 1 +void RmEpsilon(const FstClass &ifst, MutableFstClass *ofst, + bool reverse = false, + const RmEpsilonOptions& opts = + fst::script::RmEpsilonOptions()); + +// 2 +void RmEpsilon(MutableFstClass *arc, bool connect = true, + const WeightClass &weight_threshold = + fst::script::WeightClass::Zero(), + int64 state_threshold = fst::kNoStateId, + float delta = fst::kDelta); + +// 3 +void RmEpsilon(MutableFstClass *fst, vector<WeightClass> *distance, + const RmEpsilonOptions &opts); + + +} // namespace script +} // namespace fst + + +#endif // FST_SCRIPT_RMEPSILON_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/script-impl.h b/kaldi_io/src/tools/openfst/include/fst/script/script-impl.h new file mode 100644 index 0000000..452c7c5 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/script-impl.h @@ -0,0 +1,206 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +// This file defines the registration mechanism for new operations. +// These operations are designed to enable scripts to work with FST classes +// at a high level. + +// If you have a new arc type and want these operations to work with FSTs +// with that arc type, see below for the registration steps +// you must take. + +// These methods are only recommended for use in high-level scripting +// applications. Most users should use the lower-level templated versions +// corresponding to these. + +// If you have a new arc type you'd like these operations to work with, +// use the REGISTER_FST_OPERATIONS macro defined in fstcsript.h + +// If you have a custom operation you'd like to define, you need four +// components. In the following, assume you want to create a new operation +// with the signature +// +// void Foo(const FstClass &ifst, MutableFstClass *ofst); +// +// You need: +// +// 1) A way to bundle the args that your new Foo operation will take, as +// a single struct. The template structs in arg-packs.h provide a handy +// way to do this. In Foo's case, that might look like this: +// +// typedef args::Package<const FstClass &, +// MutableFstClass *> FooArgs; +// +// Note: this package of args is going to be passed by non-const pointer. +// +// 2) A function template that is able to perform Foo, given the args and +// arc type. Yours might look like this: +// +// template<class Arc> +// void Foo(FooArgs *args) { +// // Pull out the actual, arc-templated FSTs +// const Fst<Arc> &ifst = args->arg1.GetFst<Arc>(); +// MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); +// +// // actually perform foo on ifst and ofst... +// } +// +// 3) a client-facing function for your operation. This would look like +// the following: +// +// void Foo(const FstClass &ifst, MutableFstClass *ofst) { +// // Check that the arc types of the FSTs match +// if (!ArcTypesMatch(ifst, *ofst, "Foo")) return; +// // package the args +// FooArgs args(ifst, ofst); +// // Finally, call the operation +// Apply<Operation<FooArgs> >("Foo", ifst->ArcType(), &args); +// } +// +// The Apply<> function template takes care of the link between 2 and 3, +// provided you also have: +// +// 4) A registration for your new operation, on the arc types you care about. +// This can be provided easily by the REGISTER_FST_OPERATION macro in +// operations.h: +// +// REGISTER_FST_OPERATION(Foo, StdArc, FooArgs); +// REGISTER_FST_OPERATION(Foo, MyArc, FooArgs); +// // .. etc +// +// +// That's it! Now when you call Foo(const FstClass &, MutableFstClass *), +// it dispatches (in #3) via the Apply<> function to the correct +// instantiation of the template function in #2. +// + + +#ifndef FST_SCRIPT_SCRIPT_IMPL_H_ +#define FST_SCRIPT_SCRIPT_IMPL_H_ + +// +// This file contains general-purpose templates which are used in the +// implementation of the operations. +// + +#include <utility> +using std::pair; using std::make_pair; +#include <string> + +#include <fst/script/fst-class.h> +#include <fst/generic-register.h> +#include <fst/script/arg-packs.h> + +#include <fst/types.h> + +namespace fst { +namespace script { + +// +// A generic register for operations with various kinds of signatures. +// Needed since every function signature requires a new registration class. +// The pair<string, string> is understood to be the operation name and arc +// type; subclasses (or typedefs) need only provide the operation signature. +// + +template<class OperationSignature> +class GenericOperationRegister + : public GenericRegister<pair<string, string>, + OperationSignature, + GenericOperationRegister<OperationSignature> > { + public: + void RegisterOperation(const string &operation_name, + const string &arc_type, + OperationSignature op) { + this->SetEntry(make_pair(operation_name, arc_type), op); + } + + OperationSignature GetOperation( + const string &operation_name, const string &arc_type) { + return this->GetEntry(make_pair(operation_name, arc_type)); + } + + protected: + virtual string ConvertKeyToSoFilename( + const pair<string, string>& key) const { + // Just use the old-style FST for now. + string legal_type(key.second); // the arc type + ConvertToLegalCSymbol(&legal_type); + + return legal_type + "-arc.so"; + } +}; + + +// Operation package - everything you need to register a new type of operation + +// The ArgPack should be the type that's passed into each wrapped function - +// for instance, it might be a struct containing all the args. +// It's always passed by pointer, so const members should be used to enforce +// constness where it's needed. Return values should be implemented as a +// member of ArgPack as well. + +template<class ArgPack> +struct Operation { + typedef ArgPack Args; + typedef void (*OpType)(ArgPack *args); + + // The register (hash) type + typedef GenericOperationRegister<OpType> Register; + + // The register-er type + typedef GenericRegisterer<Register> Registerer; +}; + + +// Macro for registering new types of operations. + +#define REGISTER_FST_OPERATION(Op, Arc, ArgPack) \ + static fst::script::Operation<ArgPack>::Registerer \ + arc_dispatched_operation_ ## ArgPack ## Op ## Arc ## _registerer( \ + make_pair(#Op, Arc::Type()), Op<Arc>) + + +// +// Template function to apply an operation by name +// + +template<class OpReg> +void Apply(const string &op_name, const string &arc_type, + typename OpReg::Args *args) { + typename OpReg::Register *reg = OpReg::Register::GetRegister(); + + typename OpReg::OpType op = reg->GetOperation(op_name, arc_type); + + if (op == 0) { + FSTERROR() << "No operation found for \"" << op_name << "\" on " + << "arc type " << arc_type; + return; + } + + op(args); +} + + +// Helper that logs to ERROR if the arc types of a and b don't match. +// The op_name is also printed. +bool ArcTypesMatch(const FstClass &a, const FstClass &b, + const string &op_name); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_SCRIPT_IMPL_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/shortest-distance.h b/kaldi_io/src/tools/openfst/include/fst/script/shortest-distance.h new file mode 100644 index 0000000..5fc2976 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/shortest-distance.h @@ -0,0 +1,250 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_ +#define FST_SCRIPT_SHORTEST_DISTANCE_H_ + +#include <vector> +using std::vector; + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> +#include <fst/script/prune.h> // for ArcFilterType +#include <fst/queue.h> // for QueueType +#include <fst/shortest-distance.h> + +namespace fst { +namespace script { + +enum ArcFilterType { ANY_ARC_FILTER, EPSILON_ARC_FILTER, + INPUT_EPSILON_ARC_FILTER, OUTPUT_EPSILON_ARC_FILTER }; + +// See nlp/fst/lib/shortest-distance.h for the template options class +// that this one shadows +struct ShortestDistanceOptions { + const QueueType queue_type; + const ArcFilterType arc_filter_type; + const int64 source; + const float delta; + const bool first_path; + + ShortestDistanceOptions(QueueType qt, ArcFilterType aft, int64 s, + float d) + : queue_type(qt), arc_filter_type(aft), source(s), delta(d), + first_path(false) { } +}; + + + +// 1 +typedef args::Package<const FstClass &, vector<WeightClass> *, + const ShortestDistanceOptions &> ShortestDistanceArgs1; + +template<class Queue, class Arc, class ArcFilter> +struct QueueConstructor { + // template<class Arc, class ArcFilter> + static Queue *Construct(const Fst<Arc> &, + const vector<typename Arc::Weight> *) { + return new Queue(); + } +}; + +// Specializations to deal with AutoQueue, NaturalShortestFirstQueue, +// and TopOrderQueue's different constructors +template<class Arc, class ArcFilter> +struct QueueConstructor<AutoQueue<typename Arc::StateId>, Arc, ArcFilter> { + // template<class Arc, class ArcFilter> + static AutoQueue<typename Arc::StateId> *Construct( + const Fst<Arc> &fst, + const vector<typename Arc::Weight> *distance) { + return new AutoQueue<typename Arc::StateId>(fst, distance, ArcFilter()); + } +}; + +template<class Arc, class ArcFilter> +struct QueueConstructor<NaturalShortestFirstQueue<typename Arc::StateId, + typename Arc::Weight>, + Arc, ArcFilter> { + // template<class Arc, class ArcFilter> + static NaturalShortestFirstQueue<typename Arc::StateId, typename Arc::Weight> + *Construct(const Fst<Arc> &fst, + const vector<typename Arc::Weight> *distance) { + return new NaturalShortestFirstQueue<typename Arc::StateId, + typename Arc::Weight>(*distance); + } +}; + +template<class Arc, class ArcFilter> +struct QueueConstructor<TopOrderQueue<typename Arc::StateId>, Arc, ArcFilter> { + // template<class Arc, class ArcFilter> + static TopOrderQueue<typename Arc::StateId> *Construct( + const Fst<Arc> &fst, const vector<typename Arc::Weight> *weights) { + return new TopOrderQueue<typename Arc::StateId>(fst, ArcFilter()); + } +}; + + +template<class Arc, class Queue> +void ShortestDistanceHelper(ShortestDistanceArgs1 *args) { + const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); + const ShortestDistanceOptions &opts = args->arg3; + + vector<typename Arc::Weight> weights; + + switch (opts.arc_filter_type) { + case ANY_ARC_FILTER: { + Queue *queue = + QueueConstructor<Queue, Arc, AnyArcFilter<Arc> >::Construct( + fst, &weights); + fst::ShortestDistanceOptions<Arc, Queue, AnyArcFilter<Arc> > sdopts( + queue, AnyArcFilter<Arc>(), opts.source, opts.delta); + ShortestDistance(fst, &weights, sdopts); + delete queue; + break; + } + case EPSILON_ARC_FILTER: { + Queue *queue = + QueueConstructor<Queue, Arc, AnyArcFilter<Arc> >::Construct( + fst, &weights); + fst::ShortestDistanceOptions<Arc, Queue, + EpsilonArcFilter<Arc> > sdopts( + queue, EpsilonArcFilter<Arc>(), opts.source, opts.delta); + ShortestDistance(fst, &weights, sdopts); + delete queue; + break; + } + case INPUT_EPSILON_ARC_FILTER: { + Queue *queue = + QueueConstructor<Queue, Arc, InputEpsilonArcFilter<Arc> >::Construct( + fst, &weights); + fst::ShortestDistanceOptions<Arc, Queue, + InputEpsilonArcFilter<Arc> > sdopts( + queue, InputEpsilonArcFilter<Arc>(), opts.source, opts.delta); + ShortestDistance(fst, &weights, sdopts); + delete queue; + break; + } + case OUTPUT_EPSILON_ARC_FILTER: { + Queue *queue = + QueueConstructor<Queue, Arc, + OutputEpsilonArcFilter<Arc> >::Construct( + fst, &weights); + fst::ShortestDistanceOptions<Arc, Queue, + OutputEpsilonArcFilter<Arc> > sdopts( + queue, OutputEpsilonArcFilter<Arc>(), opts.source, opts.delta); + ShortestDistance(fst, &weights, sdopts); + delete queue; + break; + } + } + + // Copy the weights back + args->arg2->resize(weights.size()); + for (unsigned i = 0; i < weights.size(); ++i) { + (*args->arg2)[i] = WeightClass(weights[i]); + } +} + +template<class Arc> +void ShortestDistance(ShortestDistanceArgs1 *args) { + const ShortestDistanceOptions &opts = args->arg3; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + // Must consider (opts.queue_type x opts.filter_type) options + switch (opts.queue_type) { + default: + FSTERROR() << "Unknown queue type." << opts.queue_type; + + case AUTO_QUEUE: + ShortestDistanceHelper<Arc, AutoQueue<StateId> >(args); + return; + + case FIFO_QUEUE: + ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args); + return; + + case LIFO_QUEUE: + ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args); + return; + + case SHORTEST_FIRST_QUEUE: + ShortestDistanceHelper<Arc, + NaturalShortestFirstQueue<StateId, Weight> >(args); + return; + + case STATE_ORDER_QUEUE: + ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args); + return; + + case TOP_ORDER_QUEUE: + ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args); + return; + } +} + +// 2 +typedef args::Package<const FstClass&, vector<WeightClass>*, + bool, double> ShortestDistanceArgs2; + +template<class Arc> +void ShortestDistance(ShortestDistanceArgs2 *args) { + const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); + vector<typename Arc::Weight> distance; + + ShortestDistance(fst, &distance, args->arg3, args->arg4); + + // convert the typed weights back into weightclass + vector<WeightClass> *retval = args->arg2; + retval->resize(distance.size()); + + for (unsigned i = 0; i < distance.size(); ++i) { + (*retval)[i] = WeightClass(distance[i]); + } +} + +// 3 +typedef args::WithReturnValue<WeightClass, + const FstClass &> ShortestDistanceArgs3; + +template<class Arc> +void ShortestDistance(ShortestDistanceArgs3 *args) { + const Fst<Arc> &fst = *(args->args.GetFst<Arc>()); + + args->retval = WeightClass(ShortestDistance(fst)); +} + + +// 1 +void ShortestDistance(const FstClass &fst, vector<WeightClass> *distance, + const ShortestDistanceOptions &opts); + +// 2 +void ShortestDistance(const FstClass &ifst, vector<WeightClass> *distance, + bool reverse = false, double delta = fst::kDelta); + +#ifndef SWIG +// 3 +WeightClass ShortestDistance(const FstClass &ifst); +#endif + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_SHORTEST_DISTANCE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/shortest-path.h b/kaldi_io/src/tools/openfst/include/fst/script/shortest-path.h new file mode 100644 index 0000000..b3a3eb9 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/shortest-path.h @@ -0,0 +1,190 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_SHORTEST_PATH_H_ +#define FST_SCRIPT_SHORTEST_PATH_H_ + +#include <vector> +using std::vector; + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> +#include <fst/shortest-path.h> +#include <fst/script/shortest-distance.h> // for ShortestDistanceOptions + +namespace fst { +namespace script { + +struct ShortestPathOptions + : public fst::script::ShortestDistanceOptions { + const size_t nshortest; + const bool unique; + const bool has_distance; + const bool first_path; + const WeightClass weight_threshold; + const int64 state_threshold; + + ShortestPathOptions(QueueType qt, size_t n = 1, + bool u = false, bool hasdist = false, + float d = fst::kDelta, bool fp = false, + WeightClass w = fst::script::WeightClass::Zero(), + int64 s = fst::kNoStateId) + : ShortestDistanceOptions(qt, ANY_ARC_FILTER, kNoStateId, d), + nshortest(n), unique(u), has_distance(hasdist), first_path(fp), + weight_threshold(w), state_threshold(s) { } +}; + +typedef args::Package<const FstClass &, MutableFstClass *, + vector<WeightClass> *, const ShortestPathOptions &> + ShortestPathArgs1; + + +template<class Arc> +void ShortestPath(ShortestPathArgs1 *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + const ShortestPathOptions &opts = args->arg4; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef AnyArcFilter<Arc> ArcFilter; + + vector<typename Arc::Weight> weights; + typename Arc::Weight weight_threshold = + *(opts.weight_threshold.GetWeight<Weight>()); + + switch (opts.queue_type) { + case AUTO_QUEUE: { + typedef AutoQueue<StateId> Queue; + Queue *queue = QueueConstructor<Queue, Arc, + ArcFilter>::Construct(ifst, &weights); + fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts( + queue, ArcFilter(), opts.nshortest, opts.unique, + opts.has_distance, opts.delta, opts.first_path, + weight_threshold, opts.state_threshold); + ShortestPath(ifst, ofst, &weights, spopts); + delete queue; + return; + } + case FIFO_QUEUE: { + typedef FifoQueue<StateId> Queue; + Queue *queue = QueueConstructor<Queue, Arc, + ArcFilter>::Construct(ifst, &weights); + fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts( + queue, ArcFilter(), opts.nshortest, opts.unique, + opts.has_distance, opts.delta, opts.first_path, + weight_threshold, opts.state_threshold); + ShortestPath(ifst, ofst, &weights, spopts); + delete queue; + return; + } + case LIFO_QUEUE: { + typedef LifoQueue<StateId> Queue; + Queue *queue = QueueConstructor<Queue, Arc, + ArcFilter >::Construct(ifst, &weights); + fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts( + queue, ArcFilter(), opts.nshortest, opts.unique, + opts.has_distance, opts.delta, opts.first_path, + weight_threshold, opts.state_threshold); + ShortestPath(ifst, ofst, &weights, spopts); + delete queue; + return; + } + case SHORTEST_FIRST_QUEUE: { + typedef NaturalShortestFirstQueue<StateId, Weight> Queue; + Queue *queue = QueueConstructor<Queue, Arc, + ArcFilter>::Construct(ifst, &weights); + fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts( + queue, ArcFilter(), opts.nshortest, opts.unique, + opts.has_distance, opts.delta, opts.first_path, + weight_threshold, opts.state_threshold); + ShortestPath(ifst, ofst, &weights, spopts); + delete queue; + return; + } + case STATE_ORDER_QUEUE: { + typedef StateOrderQueue<StateId> Queue; + Queue *queue = QueueConstructor<Queue, Arc, + ArcFilter>::Construct(ifst, &weights); + fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts( + queue, ArcFilter(), opts.nshortest, opts.unique, + opts.has_distance, opts.delta, opts.first_path, + weight_threshold, opts.state_threshold); + ShortestPath(ifst, ofst, &weights, spopts); + delete queue; + return; + } + case TOP_ORDER_QUEUE: { + typedef TopOrderQueue<StateId> Queue; + Queue *queue = QueueConstructor<Queue, Arc, + ArcFilter>::Construct(ifst, &weights); + fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts( + queue, ArcFilter(), opts.nshortest, opts.unique, + opts.has_distance, opts.delta, opts.first_path, + weight_threshold, opts.state_threshold); + ShortestPath(ifst, ofst, &weights, spopts); + delete queue; + return; + } + default: + FSTERROR() << "Unknown queue type: " << opts.queue_type; + ofst->SetProperties(kError, kError); + } + + // Copy the weights back + args->arg3->resize(weights.size()); + for (unsigned i = 0; i < weights.size(); ++i) { + (*args->arg3)[i] = WeightClass(weights[i]); + } +} + +// 2 +typedef args::Package<const FstClass &, MutableFstClass *, + size_t, bool, bool, WeightClass, + int64> ShortestPathArgs2; + +template<class Arc> +void ShortestPath(ShortestPathArgs2 *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + typename Arc::Weight weight_threshold = + *(args->arg6.GetWeight<typename Arc::Weight>()); + + ShortestPath(ifst, ofst, args->arg3, args->arg4, args->arg5, + weight_threshold, args->arg7); +} + + +// 1 +void ShortestPath(const FstClass &ifst, MutableFstClass *ofst, + vector<WeightClass> *distance, + const ShortestPathOptions &opts); + + +// 2 +void ShortestPath(const FstClass &ifst, MutableFstClass *ofst, + size_t n = 1, bool unique = false, + bool first_path = false, + WeightClass weight_threshold = + fst::script::WeightClass::Zero(), + int64 state_threshold = fst::kNoStateId); + +} // namespace script +} // namespace fst + + + +#endif // FST_SCRIPT_SHORTEST_PATH_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/symbols.h b/kaldi_io/src/tools/openfst/include/fst/script/symbols.h new file mode 100644 index 0000000..927600a --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/symbols.h @@ -0,0 +1,20 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_SYMBOLS_H_ +#define FST_SCRIPT_SYMBOLS_H_ + +#endif // FST_SCRIPT_SYMBOLS_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/synchronize.h b/kaldi_io/src/tools/openfst/include/fst/script/synchronize.h new file mode 100644 index 0000000..3c0c905 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/synchronize.h @@ -0,0 +1,42 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_SYNCHRONIZE_H_ +#define FST_SCRIPT_SYNCHRONIZE_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/synchronize.h> + +namespace fst { +namespace script { + +typedef args::Package<const FstClass &, MutableFstClass *> SynchronizeArgs; + +template<class Arc> +void Synchronize(SynchronizeArgs *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + + Synchronize(ifst, ofst); +} + +void Synchronize(const FstClass &ifst, MutableFstClass *ofst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_SYNCHRONIZE_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/text-io.h b/kaldi_io/src/tools/openfst/include/fst/script/text-io.h new file mode 100644 index 0000000..d97a007 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/text-io.h @@ -0,0 +1,51 @@ +// text-io.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// Modified: [email protected] (Jake Ratkiewicz) to work with generic WeightClass +// +// \file +// Utilities for reading and writing textual strings representing +// states, labels, and weights and files specifying label-label pairs +// and potentials (state-weight pairs). +// + +#ifndef FST_SCRIPT_TEXT_IO_H__ +#define FST_SCRIPT_TEXT_IO_H__ + +#include <string> +#include <vector> +using std::vector; + + +#include <iostream> +#include <fstream> +#include <sstream> +#include <fst/script/weight-class.h> + +namespace fst { +namespace script { + +bool ReadPotentials(const string &weight_type, + const string& filename, + vector<WeightClass>* potential); + +bool WritePotentials(const string& filename, + const vector<WeightClass>& potential); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_TEXT_IO_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/topsort.h b/kaldi_io/src/tools/openfst/include/fst/script/topsort.h new file mode 100644 index 0000000..4e27e48 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/topsort.h @@ -0,0 +1,40 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_TOPSORT_H_ +#define FST_SCRIPT_TOPSORT_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/topsort.h> + +namespace fst { +namespace script { + +typedef args::WithReturnValue<bool, MutableFstClass*> TopSortArgs; + +template<class Arc> +void TopSort(TopSortArgs *args) { + MutableFst<Arc> *fst = args->args->GetMutableFst<Arc>(); + args->retval = TopSort(fst); +} + +bool TopSort(MutableFstClass *fst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_TOPSORT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/union.h b/kaldi_io/src/tools/openfst/include/fst/script/union.h new file mode 100644 index 0000000..780e484 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/union.h @@ -0,0 +1,42 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_UNION_H_ +#define FST_SCRIPT_UNION_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/union.h> + +namespace fst { +namespace script { + +typedef args::Package<MutableFstClass *, const FstClass &> UnionArgs; + +template<class Arc> +void Union(UnionArgs *args) { + MutableFst<Arc> *fst1 = args->arg1->GetMutableFst<Arc>(); + const Fst<Arc> &fst2 = *(args->arg2.GetFst<Arc>()); + + Union(fst1, fst2); +} + +void Union(MutableFstClass *fst1, const FstClass &fst2); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_UNION_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/verify.h b/kaldi_io/src/tools/openfst/include/fst/script/verify.h new file mode 100644 index 0000000..6904003 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/verify.h @@ -0,0 +1,40 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jeffrey Sorensen) + +#ifndef FST_SCRIPT_VERIFY_H_ +#define FST_SCRIPT_VERIFY_H_ + +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/verify.h> + +namespace fst { +namespace script { + +typedef args::WithReturnValue<bool, const FstClass *> VerifyArgs; + +template<class Arc> +void Verify(VerifyArgs *args) { + const Fst<Arc> *fst = args->args->GetFst<Arc>(); + args->retval = Verify(*fst); +} + +bool Verify(const FstClass &fst1); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_VERIFY_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/script/weight-class.h b/kaldi_io/src/tools/openfst/include/fst/script/weight-class.h new file mode 100644 index 0000000..b9f7ddf --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/script/weight-class.h @@ -0,0 +1,223 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jake Ratkiewicz) + +// Represents a generic weight in an FST -- that is, represents a specific +// type of weight underneath while hiding that type from a client. + + +#ifndef FST_SCRIPT_WEIGHT_CLASS_H_ +#define FST_SCRIPT_WEIGHT_CLASS_H_ + +#include <string> + +#include <fst/generic-register.h> +#include <fst/util.h> + +namespace fst { +namespace script { + +class WeightImplBase { + public: + virtual WeightImplBase *Copy() const = 0; + virtual void Print(ostream *o) const = 0; + virtual const string &Type() const = 0; + virtual string to_string() const = 0; + virtual bool operator == (const WeightImplBase &other) const = 0; + virtual ~WeightImplBase() { } +}; + +template<class W> +struct WeightClassImpl : public WeightImplBase { + W weight; + + explicit WeightClassImpl(const W& weight) : weight(weight) { } + + virtual WeightClassImpl<W> *Copy() const { + return new WeightClassImpl<W>(weight); + } + + virtual const string &Type() const { return W::Type(); } + + virtual void Print(ostream *o) const { + *o << weight; + } + + virtual string to_string() const { + string str; + WeightToStr(weight, &str); + return str; + } + + virtual bool operator == (const WeightImplBase &other) const { + if (Type() != other.Type()) { + return false; + } else { + const WeightClassImpl<W> *typed_other = + static_cast<const WeightClassImpl<W> *>(&other); + + return typed_other->weight == weight; + } + } +}; + + +class WeightClass { + public: + WeightClass() : element_type_(ZERO), impl_(0) { } + + template<class W> + explicit WeightClass(const W& weight) + : element_type_(OTHER), impl_(new WeightClassImpl<W>(weight)) { } + + WeightClass(const string &weight_type, const string &weight_str); + + WeightClass(const WeightClass &other) : + element_type_(other.element_type_), + impl_(other.impl_ ? other.impl_->Copy() : 0) { } + + WeightClass &operator = (const WeightClass &other) { + if (impl_) delete impl_; + impl_ = other.impl_ ? other.impl_->Copy() : 0; + element_type_ = other.element_type_; + return *this; + } + + template<class W> + const W* GetWeight() const; + + string to_string() const { + switch (element_type_) { + case ZERO: + return "ZERO"; + case ONE: + return "ONE"; + default: + case OTHER: + return impl_->to_string(); + } + } + + bool operator == (const WeightClass &other) const { + return element_type_ == other.element_type_ && + ((impl_ && other.impl_ && (*impl_ == *other.impl_)) || + (impl_ == 0 && other.impl_ == 0)); + } + + static const WeightClass &Zero() { + static WeightClass w(ZERO); + + return w; + } + + static const WeightClass &One() { + static WeightClass w(ONE); + + return w; + } + + const string &Type() const { + if (impl_) return impl_->Type(); + static const string no_type = "none"; + return no_type; + } + + + ~WeightClass() { if (impl_) delete impl_; } + private: + enum ElementType { ZERO, ONE, OTHER }; + ElementType element_type_; + + WeightImplBase *impl_; + + explicit WeightClass(ElementType et) : element_type_(et), impl_(0) { } + + friend ostream &operator << (ostream &o, const WeightClass &c); +}; + +template<class W> +const W* WeightClass::GetWeight() const { + // We need to store zero and one as statics, because the weight type + // W might return them as temporaries. We're returning a pointer, + // and it won't do to get the address of a temporary. + static const W zero = W::Zero(); + static const W one = W::One(); + + if (element_type_ == ZERO) { + return &zero; + } else if (element_type_ == ONE) { + return &one; + } else { + if (W::Type() != impl_->Type()) { + return NULL; + } else { + WeightClassImpl<W> *typed_impl = + static_cast<WeightClassImpl<W> *>(impl_); + return &typed_impl->weight; + } + } +} + +// +// Registration for generic weight types. +// + +typedef WeightImplBase* (*StrToWeightImplBaseT)(const string &str, + const string &src, + size_t nline); + +template<class W> +WeightImplBase* StrToWeightImplBase(const string &str, + const string &src, size_t nline) { + return new WeightClassImpl<W>(StrToWeight<W>(str, src, nline)); +} + +// The following confuses swig, and doesn't need to be wrapped anyway. +#ifndef SWIG +ostream& operator << (ostream &o, const WeightClass &c); + +class WeightClassRegister : public GenericRegister<string, + StrToWeightImplBaseT, + WeightClassRegister> { + protected: + virtual string ConvertKeyToSoFilename(const string &key) const { + return key + ".so"; + } +}; + +typedef GenericRegisterer<WeightClassRegister> WeightClassRegisterer; +#endif + +// internal version, needs to be called by wrapper in order for +// macro args to expand +#define REGISTER_FST_WEIGHT__(Weight, line) \ + static WeightClassRegisterer weight_registerer ## _ ## line( \ + Weight::Type(), \ + StrToWeightImplBase<Weight>) + +// This layer is where __FILE__ and __LINE__ are expanded +#define REGISTER_FST_WEIGHT_EXPANDER(Weight, line) \ + REGISTER_FST_WEIGHT__(Weight, line) + +// +// Macro for registering new weight types. Clients call this. +// +#define REGISTER_FST_WEIGHT(Weight) \ + REGISTER_FST_WEIGHT_EXPANDER(Weight, __LINE__) + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_WEIGHT_CLASS_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/shortest-distance.h b/kaldi_io/src/tools/openfst/include/fst/shortest-distance.h new file mode 100644 index 0000000..ec47a14 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/shortest-distance.h @@ -0,0 +1,348 @@ +// shortest-distance.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Functions and classes to find shortest distance in an FST. + +#ifndef FST_LIB_SHORTEST_DISTANCE_H__ +#define FST_LIB_SHORTEST_DISTANCE_H__ + +#include <deque> +using std::deque; +#include <vector> +using std::vector; + +#include <fst/arcfilter.h> +#include <fst/cache.h> +#include <fst/queue.h> +#include <fst/reverse.h> +#include <fst/test-properties.h> + + +namespace fst { + +template <class Arc, class Queue, class ArcFilter> +struct ShortestDistanceOptions { + typedef typename Arc::StateId StateId; + + Queue *state_queue; // Queue discipline used; owned by caller + ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph) + StateId source; // If kNoStateId, use the Fst's initial state + float delta; // Determines the degree of convergence required + bool first_path; // For a semiring with the path property (o.w. + // undefined), compute the shortest-distances along + // along the first path to a final state found + // by the algorithm. That path is the shortest-path + // only if the FST has a unique final state (or all + // the final states have the same final weight), the + // queue discipline is shortest-first and all the + // weights in the FST are between One() and Zero() + // according to NaturalLess. + + ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId, + float d = kDelta) + : state_queue(q), arc_filter(filt), source(src), delta(d), + first_path(false) {} +}; + + +// Computation state of the shortest-distance algorithm. Reusable +// information is maintained across calls to member function +// ShortestDistance(source) when 'retain' is true for improved +// efficiency when calling multiple times from different source states +// (e.g., in epsilon removal). Contrary to usual conventions, 'fst' +// may not be freed before this class. Vector 'distance' should not be +// modified by the user between these calls. +// The Error() method returns true if an error was encountered. +template<class Arc, class Queue, class ArcFilter> +class ShortestDistanceState { + public: + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + ShortestDistanceState( + const Fst<Arc> &fst, + vector<Weight> *distance, + const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, + bool retain) + : fst_(fst), distance_(distance), state_queue_(opts.state_queue), + arc_filter_(opts.arc_filter), delta_(opts.delta), + first_path_(opts.first_path), retain_(retain), source_id_(0), + error_(false) { + distance_->clear(); + } + + ~ShortestDistanceState() {} + + void ShortestDistance(StateId source); + + bool Error() const { return error_; } + + private: + const Fst<Arc> &fst_; + vector<Weight> *distance_; + Queue *state_queue_; + ArcFilter arc_filter_; + float delta_; + bool first_path_; + bool retain_; // Retain and reuse information across calls + + vector<Weight> rdistance_; // Relaxation distance. + vector<bool> enqueued_; // Is state enqueued? + vector<StateId> sources_; // Source ID for ith state in 'distance_', + // 'rdistance_', and 'enqueued_' if retained. + StateId source_id_; // Unique ID characterizing each call to SD + + bool error_; +}; + +// Compute the shortest distance. If 'source' is kNoStateId, use +// the initial state of the Fst. +template <class Arc, class Queue, class ArcFilter> +void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance( + StateId source) { + if (fst_.Start() == kNoStateId) { + if (fst_.Properties(kError, false)) error_ = true; + return; + } + + if (!(Weight::Properties() & kRightSemiring)) { + FSTERROR() << "ShortestDistance: Weight needs to be right distributive: " + << Weight::Type(); + error_ = true; + return; + } + + if (first_path_ && !(Weight::Properties() & kPath)) { + FSTERROR() << "ShortestDistance: first_path option disallowed when " + << "Weight does not have the path property: " + << Weight::Type(); + error_ = true; + return; + } + + state_queue_->Clear(); + + if (!retain_) { + distance_->clear(); + rdistance_.clear(); + enqueued_.clear(); + } + + if (source == kNoStateId) + source = fst_.Start(); + + while (distance_->size() <= source) { + distance_->push_back(Weight::Zero()); + rdistance_.push_back(Weight::Zero()); + enqueued_.push_back(false); + } + if (retain_) { + while (sources_.size() <= source) + sources_.push_back(kNoStateId); + sources_[source] = source_id_; + } + (*distance_)[source] = Weight::One(); + rdistance_[source] = Weight::One(); + enqueued_[source] = true; + + state_queue_->Enqueue(source); + + while (!state_queue_->Empty()) { + StateId s = state_queue_->Head(); + state_queue_->Dequeue(); + while (distance_->size() <= s) { + distance_->push_back(Weight::Zero()); + rdistance_.push_back(Weight::Zero()); + enqueued_.push_back(false); + } + if (first_path_ && (fst_.Final(s) != Weight::Zero())) + break; + enqueued_[s] = false; + Weight r = rdistance_[s]; + rdistance_[s] = Weight::Zero(); + for (ArcIterator< Fst<Arc> > aiter(fst_, s); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (!arc_filter_(arc)) + continue; + while (distance_->size() <= arc.nextstate) { + distance_->push_back(Weight::Zero()); + rdistance_.push_back(Weight::Zero()); + enqueued_.push_back(false); + } + if (retain_) { + while (sources_.size() <= arc.nextstate) + sources_.push_back(kNoStateId); + if (sources_[arc.nextstate] != source_id_) { + (*distance_)[arc.nextstate] = Weight::Zero(); + rdistance_[arc.nextstate] = Weight::Zero(); + enqueued_[arc.nextstate] = false; + sources_[arc.nextstate] = source_id_; + } + } + Weight &nd = (*distance_)[arc.nextstate]; + Weight &nr = rdistance_[arc.nextstate]; + Weight w = Times(r, arc.weight); + if (!ApproxEqual(nd, Plus(nd, w), delta_)) { + nd = Plus(nd, w); + nr = Plus(nr, w); + if (!nd.Member() || !nr.Member()) { + error_ = true; + return; + } + if (!enqueued_[arc.nextstate]) { + state_queue_->Enqueue(arc.nextstate); + enqueued_[arc.nextstate] = true; + } else { + state_queue_->Update(arc.nextstate); + } + } + } + } + ++source_id_; + if (fst_.Properties(kError, false)) error_ = true; +} + + +// Shortest-distance algorithm: this version allows fine control +// via the options argument. See below for a simpler interface. +// +// This computes the shortest distance from the 'opts.source' state to +// each visited state S and stores the value in the 'distance' vector. +// An unvisited state S has distance Zero(), which will be stored in +// the 'distance' vector if S is less than the maximum visited state. +// The state queue discipline, arc filter, and convergence delta are +// taken in the options argument. +// The 'distance' vector will contain a unique element for which +// Member() is false if an error was encountered. +// +// The weights must must be right distributive and k-closed (i.e., 1 + +// x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k). +// +// The algorithm is from Mohri, "Semiring Framweork and Algorithms for +// Shortest-Distance Problems", Journal of Automata, Languages and +// Combinatorics 7(3):321-350, 2002. The complexity of algorithm +// depends on the properties of the semiring and the queue discipline +// used. Refer to the paper for more details. +template<class Arc, class Queue, class ArcFilter> +void ShortestDistance( + const Fst<Arc> &fst, + vector<typename Arc::Weight> *distance, + const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) { + + ShortestDistanceState<Arc, Queue, ArcFilter> + sd_state(fst, distance, opts, false); + sd_state.ShortestDistance(opts.source); + if (sd_state.Error()) { + distance->clear(); + distance->resize(1, Arc::Weight::NoWeight()); + } +} + +// Shortest-distance algorithm: simplified interface. See above for a +// version that allows finer control. +// +// If 'reverse' is false, this computes the shortest distance from the +// initial state to each state S and stores the value in the +// 'distance' vector. If 'reverse' is true, this computes the shortest +// distance from each state to the final states. An unvisited state S +// has distance Zero(), which will be stored in the 'distance' vector +// if S is less than the maximum visited state. The state queue +// discipline is automatically-selected. +// The 'distance' vector will contain a unique element for which +// Member() is false if an error was encountered. +// +// The weights must must be right (left) distributive if reverse is +// false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + +// x + x^2 + ... + x^k). +// +// The algorithm is from Mohri, "Semiring Framweork and Algorithms for +// Shortest-Distance Problems", Journal of Automata, Languages and +// Combinatorics 7(3):321-350, 2002. The complexity of algorithm +// depends on the properties of the semiring and the queue discipline +// used. Refer to the paper for more details. +template<class Arc> +void ShortestDistance(const Fst<Arc> &fst, + vector<typename Arc::Weight> *distance, + bool reverse = false, + float delta = kDelta) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + if (!reverse) { + AnyArcFilter<Arc> arc_filter; + AutoQueue<StateId> state_queue(fst, distance, arc_filter); + ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> > + opts(&state_queue, arc_filter); + opts.delta = delta; + ShortestDistance(fst, distance, opts); + } else { + typedef ReverseArc<Arc> ReverseArc; + typedef typename ReverseArc::Weight ReverseWeight; + AnyArcFilter<ReverseArc> rarc_filter; + VectorFst<ReverseArc> rfst; + Reverse(fst, &rfst); + vector<ReverseWeight> rdistance; + AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter); + ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>, + AnyArcFilter<ReverseArc> > + ropts(&state_queue, rarc_filter); + ropts.delta = delta; + ShortestDistance(rfst, &rdistance, ropts); + distance->clear(); + if (rdistance.size() == 1 && !rdistance[0].Member()) { + distance->resize(1, Arc::Weight::NoWeight()); + return; + } + while (distance->size() < rdistance.size() - 1) + distance->push_back(rdistance[distance->size() + 1].Reverse()); + } +} + + +// Return the sum of the weight of all successful paths in an FST, i.e., +// the shortest-distance from the initial state to the final states. +// Returns a weight such that Member() is false if an error was encountered. +template <class Arc> +typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kDelta) { + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + vector<Weight> distance; + if (Weight::Properties() & kRightSemiring) { + ShortestDistance(fst, &distance, false, delta); + if (distance.size() == 1 && !distance[0].Member()) + return Arc::Weight::NoWeight(); + Weight sum = Weight::Zero(); + for (StateId s = 0; s < distance.size(); ++s) + sum = Plus(sum, Times(distance[s], fst.Final(s))); + return sum; + } else { + ShortestDistance(fst, &distance, true, delta); + StateId s = fst.Start(); + if (distance.size() == 1 && !distance[0].Member()) + return Arc::Weight::NoWeight(); + return s != kNoStateId && s < distance.size() ? + distance[s] : Weight::Zero(); + } +} + + +} // namespace fst + +#endif // FST_LIB_SHORTEST_DISTANCE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/shortest-path.h b/kaldi_io/src/tools/openfst/include/fst/shortest-path.h new file mode 100644 index 0000000..9cd13d9 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/shortest-path.h @@ -0,0 +1,501 @@ +// shortest-path.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Functions to find shortest paths in an FST. + +#ifndef FST_LIB_SHORTEST_PATH_H__ +#define FST_LIB_SHORTEST_PATH_H__ + +#include <functional> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/determinize.h> +#include <fst/queue.h> +#include <fst/shortest-distance.h> +#include <fst/test-properties.h> + + +namespace fst { + +template <class Arc, class Queue, class ArcFilter> +struct ShortestPathOptions + : public ShortestDistanceOptions<Arc, Queue, ArcFilter> { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + size_t nshortest; // return n-shortest paths + bool unique; // only return paths with distinct input strings + bool has_distance; // distance vector already contains the + // shortest distance from the initial state + bool first_path; // Single shortest path stops after finding the first + // path to a final state. That path is the shortest path + // only when using the ShortestFirstQueue and + // only when all the weights in the FST are between + // One() and Zero() according to NaturalLess. + Weight weight_threshold; // pruning weight threshold. + StateId state_threshold; // pruning state threshold. + + ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false, + bool hasdist = false, float d = kDelta, + bool fp = false, Weight w = Weight::Zero(), + StateId s = kNoStateId) + : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d), + nshortest(n), unique(u), has_distance(hasdist), first_path(fp), + weight_threshold(w), state_threshold(s) {} +}; + + +// Shortest-path algorithm: normally not called directly; prefer +// 'ShortestPath' below with n=1. 'ofst' contains the shortest path in +// 'ifst'. 'distance' returns the shortest distances from the source +// state to each state in 'ifst'. 'opts' is used to specify options +// such as the queue discipline, the arc filter and delta. +// +// The shortest path is the lowest weight path w.r.t. the natural +// semiring order. +// +// The weights need to be right distributive and have the path (kPath) +// property. +template<class Arc, class Queue, class ArcFilter> +void SingleShortestPath(const Fst<Arc> &ifst, + MutableFst<Arc> *ofst, + vector<typename Arc::Weight> *distance, + ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst.InputSymbols()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + + if (ifst.Start() == kNoStateId) { + if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); + return; + } + + vector<bool> enqueued; + vector<StateId> parent; + vector<Arc> arc_parent; + + Queue *state_queue = opts.state_queue; + StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source; + Weight f_distance = Weight::Zero(); + StateId f_parent = kNoStateId; + + distance->clear(); + state_queue->Clear(); + if (opts.nshortest != 1) { + FSTERROR() << "SingleShortestPath: for nshortest > 1, use ShortestPath" + << " instead"; + ofst->SetProperties(kError, kError); + return; + } + if (opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) { + FSTERROR() << + "SingleShortestPath: weight and state thresholds not applicable"; + ofst->SetProperties(kError, kError); + return; + } + if ((Weight::Properties() & (kPath | kRightSemiring)) + != (kPath | kRightSemiring)) { + FSTERROR() << "SingleShortestPath: Weight needs to have the path" + << " property and be right distributive: " << Weight::Type(); + ofst->SetProperties(kError, kError); + return; + } + while (distance->size() < source) { + distance->push_back(Weight::Zero()); + enqueued.push_back(false); + parent.push_back(kNoStateId); + arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); + } + distance->push_back(Weight::One()); + parent.push_back(kNoStateId); + arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); + state_queue->Enqueue(source); + enqueued.push_back(true); + + while (!state_queue->Empty()) { + StateId s = state_queue->Head(); + state_queue->Dequeue(); + enqueued[s] = false; + Weight sd = (*distance)[s]; + if (ifst.Final(s) != Weight::Zero()) { + Weight w = Times(sd, ifst.Final(s)); + if (f_distance != Plus(f_distance, w)) { + f_distance = Plus(f_distance, w); + f_parent = s; + } + if (!f_distance.Member()) { + ofst->SetProperties(kError, kError); + return; + } + if (opts.first_path) + break; + } + for (ArcIterator< Fst<Arc> > aiter(ifst, s); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + while (distance->size() <= arc.nextstate) { + distance->push_back(Weight::Zero()); + enqueued.push_back(false); + parent.push_back(kNoStateId); + arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), + kNoStateId)); + } + Weight &nd = (*distance)[arc.nextstate]; + Weight w = Times(sd, arc.weight); + if (nd != Plus(nd, w)) { + nd = Plus(nd, w); + if (!nd.Member()) { + ofst->SetProperties(kError, kError); + return; + } + parent[arc.nextstate] = s; + arc_parent[arc.nextstate] = arc; + if (!enqueued[arc.nextstate]) { + state_queue->Enqueue(arc.nextstate); + enqueued[arc.nextstate] = true; + } else { + state_queue->Update(arc.nextstate); + } + } + } + } + + StateId s_p = kNoStateId, d_p = kNoStateId; + for (StateId s = f_parent, d = kNoStateId; + s != kNoStateId; + d = s, s = parent[s]) { + d_p = s_p; + s_p = ofst->AddState(); + if (d == kNoStateId) { + ofst->SetFinal(s_p, ifst.Final(f_parent)); + } else { + arc_parent[d].nextstate = d_p; + ofst->AddArc(s_p, arc_parent[d]); + } + } + ofst->SetStart(s_p); + if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); + ofst->SetProperties( + ShortestPathProperties(ofst->Properties(kFstProperties, false)), + kFstProperties); +} + + +template <class S, class W> +class ShortestPathCompare { + public: + typedef S StateId; + typedef W Weight; + typedef pair<StateId, Weight> Pair; + + ShortestPathCompare(const vector<Pair>& pairs, + const vector<Weight>& distance, + StateId sfinal, float d) + : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d) {} + + bool operator()(const StateId x, const StateId y) const { + const Pair &px = pairs_[x]; + const Pair &py = pairs_[y]; + Weight dx = px.first == superfinal_ ? Weight::One() : + px.first < distance_.size() ? distance_[px.first] : Weight::Zero(); + Weight dy = py.first == superfinal_ ? Weight::One() : + py.first < distance_.size() ? distance_[py.first] : Weight::Zero(); + Weight wx = Times(dx, px.second); + Weight wy = Times(dy, py.second); + // Penalize complete paths to ensure correct results with inexact weights. + // This forms a strict weak order so long as ApproxEqual(a, b) => + // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b). + if (px.first == superfinal_ && py.first != superfinal_) { + return less_(wy, wx) || ApproxEqual(wx, wy, delta_); + } else if (py.first == superfinal_ && px.first != superfinal_) { + return less_(wy, wx) && !ApproxEqual(wx, wy, delta_); + } else { + return less_(wy, wx); + } + } + + private: + const vector<Pair> &pairs_; + const vector<Weight> &distance_; + StateId superfinal_; + float delta_; + NaturalLess<Weight> less_; +}; + + +// N-Shortest-path algorithm: implements the core n-shortest path +// algorithm. The output is built REVERSED. See below for versions with +// more options and not reversed. +// +// 'ofst' contains the REVERSE of 'n'-shortest paths in 'ifst'. +// 'distance' must contain the shortest distance from each state to a final +// state in 'ifst'. 'delta' is the convergence delta. +// +// The n-shortest paths are the n-lowest weight paths w.r.t. the +// natural semiring order. The single path that can be read from the +// ith of at most n transitions leaving the initial state of 'ofst' is +// the ith shortest path. Disregarding the initial state and initial +// transitions, the n-shortest paths, in fact, form a tree rooted at +// the single final state. +// +// The weights need to be left and right distributive (kSemiring) and +// have the path (kPath) property. +// +// The algorithm is from Mohri and Riley, "An Efficient Algorithm for +// the n-best-strings problem", ICSLP 2002. The algorithm relies on +// the shortest-distance algorithm. There are some issues with the +// pseudo-code as written in the paper (viz., line 11). +// +// IMPLEMENTATION NOTE: The input fst 'ifst' can be a delayed fst and +// and at any state in its expansion the values of distance vector need only +// be defined at that time for the states that are known to exist. +template<class Arc, class RevArc> +void NShortestPath(const Fst<RevArc> &ifst, + MutableFst<Arc> *ofst, + const vector<typename Arc::Weight> &distance, + size_t n, + float delta = kDelta, + typename Arc::Weight weight_threshold = Arc::Weight::Zero(), + typename Arc::StateId state_threshold = kNoStateId) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef pair<StateId, Weight> Pair; + typedef typename RevArc::Weight RevWeight; + + if (n <= 0) return; + if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) { + FSTERROR() << "NShortestPath: Weight needs to have the " + << "path property and be distributive: " + << Weight::Type(); + ofst->SetProperties(kError, kError); + return; + } + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst.InputSymbols()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + // Each state in 'ofst' corresponds to a path with weight w from the + // initial state of 'ifst' to a state s in 'ifst', that can be + // characterized by a pair (s,w). The vector 'pairs' maps each + // state in 'ofst' to the corresponding pair maps states in OFST to + // the corresponding pair (s,w). + vector<Pair> pairs; + // The supefinal state is denoted by -1, 'compare' knows that the + // distance from 'superfinal' to the final state is 'Weight::One()', + // hence 'distance[superfinal]' is not needed. + StateId superfinal = -1; + ShortestPathCompare<StateId, Weight> + compare(pairs, distance, superfinal, delta); + vector<StateId> heap; + // 'r[s + 1]', 's' state in 'fst', is the number of states in 'ofst' + // which corresponding pair contains 's' ,i.e. , it is number of + // paths computed so far to 's'. Valid for 's == -1' (superfinal). + vector<int> r; + NaturalLess<Weight> less; + if (ifst.Start() == kNoStateId || + distance.size() <= ifst.Start() || + distance[ifst.Start()] == Weight::Zero() || + less(weight_threshold, Weight::One()) || + state_threshold == 0) { + if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); + return; + } + ofst->SetStart(ofst->AddState()); + StateId final = ofst->AddState(); + ofst->SetFinal(final, Weight::One()); + while (pairs.size() <= final) + pairs.push_back(Pair(kNoStateId, Weight::Zero())); + pairs[final] = Pair(ifst.Start(), Weight::One()); + heap.push_back(final); + Weight limit = Times(distance[ifst.Start()], weight_threshold); + + while (!heap.empty()) { + pop_heap(heap.begin(), heap.end(), compare); + StateId state = heap.back(); + Pair p = pairs[state]; + heap.pop_back(); + Weight d = p.first == superfinal ? Weight::One() : + p.first < distance.size() ? distance[p.first] : Weight::Zero(); + + if (less(limit, Times(d, p.second)) || + (state_threshold != kNoStateId && + ofst->NumStates() >= state_threshold)) + continue; + + while (r.size() <= p.first + 1) r.push_back(0); + ++r[p.first + 1]; + if (p.first == superfinal) + ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state)); + if ((p.first == superfinal) && (r[p.first + 1] == n)) break; + if (r[p.first + 1] > n) continue; + if (p.first == superfinal) continue; + + for (ArcIterator< Fst<RevArc> > aiter(ifst, p.first); + !aiter.Done(); + aiter.Next()) { + const RevArc &rarc = aiter.Value(); + Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate); + Weight w = Times(p.second, arc.weight); + StateId next = ofst->AddState(); + pairs.push_back(Pair(arc.nextstate, w)); + arc.nextstate = state; + ofst->AddArc(next, arc); + heap.push_back(next); + push_heap(heap.begin(), heap.end(), compare); + } + + Weight finalw = ifst.Final(p.first).Reverse(); + if (finalw != Weight::Zero()) { + Weight w = Times(p.second, finalw); + StateId next = ofst->AddState(); + pairs.push_back(Pair(superfinal, w)); + ofst->AddArc(next, Arc(0, 0, finalw, state)); + heap.push_back(next); + push_heap(heap.begin(), heap.end(), compare); + } + } + Connect(ofst); + if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); + ofst->SetProperties( + ShortestPathProperties(ofst->Properties(kFstProperties, false)), + kFstProperties); +} + + +// N-Shortest-path algorithm: this version allow fine control +// via the options argument. See below for a simpler interface. +// +// 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns +// the shortest distances from the source state to each state in +// 'ifst'. 'opts' is used to specify options such as the number of +// paths to return, whether they need to have distinct input +// strings, the queue discipline, the arc filter and the convergence +// delta. +// +// The n-shortest paths are the n-lowest weight paths w.r.t. the +// natural semiring order. The single path that can be read from the +// ith of at most n transitions leaving the initial state of 'ofst' is +// the ith shortest path. Disregarding the initial state and initial +// transitions, The n-shortest paths, in fact, form a tree rooted at +// the single final state. + +// The weights need to be right distributive and have the path (kPath) +// property. They need to be left distributive as well for nshortest +// > 1. +// +// The algorithm is from Mohri and Riley, "An Efficient Algorithm for +// the n-best-strings problem", ICSLP 2002. The algorithm relies on +// the shortest-distance algorithm. There are some issues with the +// pseudo-code as written in the paper (viz., line 11). +template<class Arc, class Queue, class ArcFilter> +void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, + vector<typename Arc::Weight> *distance, + ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef ReverseArc<Arc> ReverseArc; + + size_t n = opts.nshortest; + if (n == 1) { + SingleShortestPath(ifst, ofst, distance, opts); + return; + } + if (n <= 0) return; + if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) { + FSTERROR() << "ShortestPath: n-shortest: Weight needs to have the " + << "path property and be distributive: " + << Weight::Type(); + ofst->SetProperties(kError, kError); + return; + } + if (!opts.has_distance) { + ShortestDistance(ifst, distance, opts); + if (distance->size() == 1 && !(*distance)[0].Member()) { + ofst->SetProperties(kError, kError); + return; + } + } + // Algorithm works on the reverse of 'fst' : 'rfst', 'distance' is + // the distance to the final state in 'rfst', 'ofst' is built as the + // reverse of the tree of n-shortest path in 'rfst'. + VectorFst<ReverseArc> rfst; + Reverse(ifst, &rfst); + Weight d = Weight::Zero(); + for (ArcIterator< VectorFst<ReverseArc> > aiter(rfst, 0); + !aiter.Done(); aiter.Next()) { + const ReverseArc &arc = aiter.Value(); + StateId s = arc.nextstate - 1; + if (s < distance->size()) + d = Plus(d, Times(arc.weight.Reverse(), (*distance)[s])); + } + distance->insert(distance->begin(), d); + + if (!opts.unique) { + NShortestPath(rfst, ofst, *distance, n, opts.delta, + opts.weight_threshold, opts.state_threshold); + } else { + vector<Weight> ddistance; + DeterminizeFstOptions<ReverseArc> dopts(opts.delta); + DeterminizeFst<ReverseArc> dfst(rfst, distance, &ddistance, dopts); + NShortestPath(dfst, ofst, ddistance, n, opts.delta, + opts.weight_threshold, opts.state_threshold); + } + distance->erase(distance->begin()); +} + + +// Shortest-path algorithm: simplified interface. See above for a +// version that allows finer control. +// +// 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue +// discipline is automatically selected. When 'unique' == true, only +// paths with distinct input labels are returned. +// +// The n-shortest paths are the n-lowest weight paths w.r.t. the +// natural semiring order. The single path that can be read from the +// ith of at most n transitions leaving the initial state of 'ofst' is +// the ith best path. +// +// The weights need to be right distributive and have the path +// (kPath) property. +template<class Arc> +void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, + size_t n = 1, bool unique = false, + bool first_path = false, + typename Arc::Weight weight_threshold = Arc::Weight::Zero(), + typename Arc::StateId state_threshold = kNoStateId) { + vector<typename Arc::Weight> distance; + AnyArcFilter<Arc> arc_filter; + AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter); + ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>, + AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique, false, + kDelta, first_path, weight_threshold, + state_threshold); + ShortestPath(ifst, ofst, &distance, opts); +} + +} // namespace fst + +#endif // FST_LIB_SHORTEST_PATH_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/signed-log-weight.h b/kaldi_io/src/tools/openfst/include/fst/signed-log-weight.h new file mode 100644 index 0000000..61adefb --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/signed-log-weight.h @@ -0,0 +1,367 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Kasturi Rangan Raghavan) +// \file +// LogWeight along with sign information that represents the value X in the +// linear domain as <sign(X), -ln(|X|)> +// The sign is a TropicalWeight: +// positive, TropicalWeight.Value() > 0.0, recommended value 1.0 +// negative, TropicalWeight.Value() <= 0.0, recommended value -1.0 + +#ifndef FST_LIB_SIGNED_LOG_WEIGHT_H_ +#define FST_LIB_SIGNED_LOG_WEIGHT_H_ + +#include <fst/float-weight.h> +#include <fst/pair-weight.h> + + +namespace fst { +template <class T> +class SignedLogWeightTpl + : public PairWeight<TropicalWeight, LogWeightTpl<T> > { + public: + typedef TropicalWeight X1; + typedef LogWeightTpl<T> X2; + using PairWeight<X1, X2>::Value1; + using PairWeight<X1, X2>::Value2; + + using PairWeight<X1, X2>::Reverse; + using PairWeight<X1, X2>::Quantize; + using PairWeight<X1, X2>::Member; + + typedef SignedLogWeightTpl<T> ReverseWeight; + + SignedLogWeightTpl() : PairWeight<X1, X2>() {} + + SignedLogWeightTpl(const SignedLogWeightTpl<T>& w) + : PairWeight<X1, X2> (w) { } + + SignedLogWeightTpl(const PairWeight<X1, X2>& w) + : PairWeight<X1, X2> (w) { } + + SignedLogWeightTpl(const X1& x1, const X2& x2) + : PairWeight<X1, X2>(x1, x2) { } + + static const SignedLogWeightTpl<T> &Zero() { + static const SignedLogWeightTpl<T> zero(X1(1.0), X2::Zero()); + return zero; + } + + static const SignedLogWeightTpl<T> &One() { + static const SignedLogWeightTpl<T> one(X1(1.0), X2::One()); + return one; + } + + static const SignedLogWeightTpl<T> &NoWeight() { + static const SignedLogWeightTpl<T> no_weight(X1(1.0), X2::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string type = "signed_log_" + X1::Type() + "_" + X2::Type(); + return type; + } + + ProductWeight<X1, X2> Quantize(float delta = kDelta) const { + return PairWeight<X1, X2>::Quantize(); + } + + ReverseWeight Reverse() const { + return PairWeight<X1, X2>::Reverse(); + } + + bool Member() const { + return PairWeight<X1, X2>::Member(); + } + + static uint64 Properties() { + // not idempotent nor path + return kLeftSemiring | kRightSemiring | kCommutative; + } + + size_t Hash() const { + size_t h1; + if (Value2() == X2::Zero() || Value1().Value() > 0.0) + h1 = TropicalWeight(1.0).Hash(); + else + h1 = TropicalWeight(-1.0).Hash(); + size_t h2 = Value2().Hash(); + const int lshift = 5; + const int rshift = CHAR_BIT * sizeof(size_t) - 5; + return h1 << lshift ^ h1 >> rshift ^ h2; + } +}; + +template <class T> +inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return SignedLogWeightTpl<T>::NoWeight(); + bool s1 = w1.Value1().Value() > 0.0; + bool s2 = w2.Value1().Value() > 0.0; + T f1 = w1.Value2().Value(); + T f2 = w2.Value2().Value(); + if (f1 == FloatLimits<T>::PosInfinity()) + return w2; + else if (f2 == FloatLimits<T>::PosInfinity()) + return w1; + else if (f1 == f2) { + if (s1 == s2) + return SignedLogWeightTpl<T>(w1.Value1(), (f2 - log(2.0F))); + else + return SignedLogWeightTpl<T>::Zero(); + } else if (f1 > f2) { + if (s1 == s2) { + return SignedLogWeightTpl<T>( + w1.Value1(), (f2 - log(1.0F + exp(f2 - f1)))); + } else { + return SignedLogWeightTpl<T>( + w2.Value1(), (f2 - log(1.0F - exp(f2 - f1)))); + } + } else { + if (s2 == s1) { + return SignedLogWeightTpl<T>( + w2.Value1(), (f1 - log(1.0F + exp(f1 - f2)))); + } else { + return SignedLogWeightTpl<T>( + w1.Value1(), (f1 - log(1.0F - exp(f1 - f2)))); + } + } +} + +template <class T> +inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2) { + SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2()); + return Plus(w1, minus_w2); +} + +template <class T> +inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2) { + if (!w1.Member() || !w2.Member()) + return SignedLogWeightTpl<T>::NoWeight(); + bool s1 = w1.Value1().Value() > 0.0; + bool s2 = w2.Value1().Value() > 0.0; + T f1 = w1.Value2().Value(); + T f2 = w2.Value2().Value(); + if (s1 == s2) + return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 + f2)); + else + return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 + f2)); +} + +template <class T> +inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2, + DivideType typ = DIVIDE_ANY) { + if (!w1.Member() || !w2.Member()) + return SignedLogWeightTpl<T>::NoWeight(); + bool s1 = w1.Value1().Value() > 0.0; + bool s2 = w2.Value1().Value() > 0.0; + T f1 = w1.Value2().Value(); + T f2 = w2.Value2().Value(); + if (f2 == FloatLimits<T>::PosInfinity()) + return SignedLogWeightTpl<T>(TropicalWeight(1.0), + FloatLimits<T>::NumberBad()); + else if (f1 == FloatLimits<T>::PosInfinity()) + return SignedLogWeightTpl<T>(TropicalWeight(1.0), + FloatLimits<T>::PosInfinity()); + else if (s1 == s2) + return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 - f2)); + else + return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 - f2)); +} + +template <class T> +inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2, + float delta = kDelta) { + bool s1 = w1.Value1().Value() > 0.0; + bool s2 = w2.Value1().Value() > 0.0; + if (s1 == s2) { + return ApproxEqual(w1.Value2(), w2.Value2(), delta); + } else { + return w1.Value2() == LogWeightTpl<T>::Zero() + && w2.Value2() == LogWeightTpl<T>::Zero(); + } +} + +template <class T> +inline bool operator==(const SignedLogWeightTpl<T> &w1, + const SignedLogWeightTpl<T> &w2) { + bool s1 = w1.Value1().Value() > 0.0; + bool s2 = w2.Value1().Value() > 0.0; + if (s1 == s2) + return w1.Value2() == w2.Value2(); + else + return (w1.Value2() == LogWeightTpl<T>::Zero()) && + (w2.Value2() == LogWeightTpl<T>::Zero()); +} + + +// Single-precision signed-log weight +typedef SignedLogWeightTpl<float> SignedLogWeight; +// Double-precision signed-log weight +typedef SignedLogWeightTpl<double> SignedLog64Weight; + +// +// WEIGHT CONVERTER SPECIALIZATIONS. +// + +template <class W1, class W2> +bool SignedLogConvertCheck(W1 w) { + if (w.Value1().Value() < 0.0) { + FSTERROR() << "WeightConvert: can't convert weight from \"" + << W1::Type() << "\" to \"" << W2::Type(); + return false; + } + return true; +} + +// Convert to tropical +template <> +struct WeightConvert<SignedLogWeight, TropicalWeight> { + TropicalWeight operator()(SignedLogWeight w) const { + if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(w)) + return TropicalWeight::NoWeight(); + return w.Value2().Value(); + } +}; + +template <> +struct WeightConvert<SignedLog64Weight, TropicalWeight> { + TropicalWeight operator()(SignedLog64Weight w) const { + if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(w)) + return TropicalWeight::NoWeight(); + return w.Value2().Value(); + } +}; + +// Convert to log +template <> +struct WeightConvert<SignedLogWeight, LogWeight> { + LogWeight operator()(SignedLogWeight w) const { + if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(w)) + return LogWeight::NoWeight(); + return w.Value2().Value(); + } +}; + +template <> +struct WeightConvert<SignedLog64Weight, LogWeight> { + LogWeight operator()(SignedLog64Weight w) const { + if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(w)) + return LogWeight::NoWeight(); + return w.Value2().Value(); + } +}; + +// Convert to log64 +template <> +struct WeightConvert<SignedLogWeight, Log64Weight> { + Log64Weight operator()(SignedLogWeight w) const { + if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(w)) + return Log64Weight::NoWeight(); + return w.Value2().Value(); + } +}; + +template <> +struct WeightConvert<SignedLog64Weight, Log64Weight> { + Log64Weight operator()(SignedLog64Weight w) const { + if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(w)) + return Log64Weight::NoWeight(); + return w.Value2().Value(); + } +}; + +// Convert to signed log +template <> +struct WeightConvert<TropicalWeight, SignedLogWeight> { + SignedLogWeight operator()(TropicalWeight w) const { + TropicalWeight x1 = 1.0; + LogWeight x2 = w.Value(); + return SignedLogWeight(x1, x2); + } +}; + +template <> +struct WeightConvert<LogWeight, SignedLogWeight> { + SignedLogWeight operator()(LogWeight w) const { + TropicalWeight x1 = 1.0; + LogWeight x2 = w.Value(); + return SignedLogWeight(x1, x2); + } +}; + +template <> +struct WeightConvert<Log64Weight, SignedLogWeight> { + SignedLogWeight operator()(Log64Weight w) const { + TropicalWeight x1 = 1.0; + LogWeight x2 = w.Value(); + return SignedLogWeight(x1, x2); + } +}; + +template <> +struct WeightConvert<SignedLog64Weight, SignedLogWeight> { + SignedLogWeight operator()(SignedLog64Weight w) const { + TropicalWeight x1 = w.Value1(); + LogWeight x2 = w.Value2().Value(); + return SignedLogWeight(x1, x2); + } +}; + +// Convert to signed log64 +template <> +struct WeightConvert<TropicalWeight, SignedLog64Weight> { + SignedLog64Weight operator()(TropicalWeight w) const { + TropicalWeight x1 = 1.0; + Log64Weight x2 = w.Value(); + return SignedLog64Weight(x1, x2); + } +}; + +template <> +struct WeightConvert<LogWeight, SignedLog64Weight> { + SignedLog64Weight operator()(LogWeight w) const { + TropicalWeight x1 = 1.0; + Log64Weight x2 = w.Value(); + return SignedLog64Weight(x1, x2); + } +}; + +template <> +struct WeightConvert<Log64Weight, SignedLog64Weight> { + SignedLog64Weight operator()(Log64Weight w) const { + TropicalWeight x1 = 1.0; + Log64Weight x2 = w.Value(); + return SignedLog64Weight(x1, x2); + } +}; + +template <> +struct WeightConvert<SignedLogWeight, SignedLog64Weight> { + SignedLog64Weight operator()(SignedLogWeight w) const { + TropicalWeight x1 = w.Value1(); + Log64Weight x2 = w.Value2().Value(); + return SignedLog64Weight(x1, x2); + } +}; + +} // namespace fst + +#endif // FST_LIB_SIGNED_LOG_WEIGHT_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/slist.h b/kaldi_io/src/tools/openfst/include/fst/slist.h new file mode 100644 index 0000000..b800522 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/slist.h @@ -0,0 +1,61 @@ +// slist.h +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: [email protected] (Michael Riley) +// +// \file +// Includes slist definition or defines in terms of STL list as a fallback. + +#ifndef FST_LIB_SLIST_H__ +#define FST_LIB_SLIST_H__ + +#include <fst/config.h> + +#ifdef HAVE___GNU_CXX__SLIST_INT_ + +#include <ext/slist> + +namespace fst { + +using __gnu_cxx::slist; + +} + +#else + +#include <list> + +namespace fst { + +using std::list; + +template <typename T> class slist : public list<T> { + public: + typedef typename list<T>::iterator iterator; + typedef typename list<T>::const_iterator const_iterator; + + using list<T>::erase; + + iterator erase_after(iterator pos) { + iterator npos = pos; + erase(++npos); + return pos; + } +}; + +} // namespace fst + +#endif // HAVE___GNU_CXX__SLIST_INT_ + +#endif // FST_LIB_SLIST_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/sparse-power-weight.h b/kaldi_io/src/tools/openfst/include/fst/sparse-power-weight.h new file mode 100644 index 0000000..a1ff56a --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/sparse-power-weight.h @@ -0,0 +1,225 @@ +// sparse-power-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Kasturi Rangan Raghavan) +// Inspiration: [email protected] (Cyril Allauzen) +// +// \file +// Cartesian power weight semiring operation definitions. +// Uses SparseTupleWeight as underlying representation. + +#ifndef FST_LIB_SPARSE_POWER_WEIGHT_H__ +#define FST_LIB_SPARSE_POWER_WEIGHT_H__ + +#include<string> + +#include <fst/sparse-tuple-weight.h> +#include <fst/weight.h> + + +namespace fst { + +// Below SparseTupleWeight*Mapper are used in conjunction with +// SparseTupleWeightMap to compute the respective semiring operations +template<class W, class K> +struct SparseTupleWeightPlusMapper { + W Map(const K& k, const W& v1, const W& v2) const { + return Plus(v1, v2); + } +}; + +template<class W, class K> +struct SparseTupleWeightTimesMapper { + W Map(const K& k, const W& v1, const W& v2) const { + return Times(v1, v2); + } +}; + +template<class W, class K> +struct SparseTupleWeightDivideMapper { + SparseTupleWeightDivideMapper(DivideType divide_type) { + divide_type_ = divide_type; + } + W Map(const K& k, const W& v1, const W& v2) const { + return Divide(v1, v2, divide_type_); + } + DivideType divide_type_; +}; + +template<class W, class K> +struct SparseTupleWeightApproxMapper { + SparseTupleWeightApproxMapper(float delta) { delta_ = delta; } + W Map(const K& k, const W& v1, const W& v2) const { + return ApproxEqual(v1, v2, delta_) ? W::One() : W::Zero(); + } + float delta_; +}; + +// Sparse cartesian power semiring: W ^ n +// Forms: +// - a left semimodule when W is a left semiring, +// - a right semimodule when W is a right semiring, +// - a bisemimodule when W is a semiring, +// the free semimodule of rank n over W +// The Times operation is overloaded to provide the +// left and right scalar products. +// K is the key value type. kNoKey(-1) is reserved for internal use +template <class W, class K = int> +class SparsePowerWeight : public SparseTupleWeight<W, K> { + public: + using SparseTupleWeight<W, K>::Zero; + using SparseTupleWeight<W, K>::One; + using SparseTupleWeight<W, K>::NoWeight; + using SparseTupleWeight<W, K>::Quantize; + using SparseTupleWeight<W, K>::Reverse; + + typedef SparsePowerWeight<typename W::ReverseWeight, K> ReverseWeight; + + SparsePowerWeight() {} + + SparsePowerWeight(const SparseTupleWeight<W, K> &w) : + SparseTupleWeight<W, K>(w) { } + + template <class Iterator> + SparsePowerWeight(Iterator begin, Iterator end) : + SparseTupleWeight<W, K>(begin, end) { } + + SparsePowerWeight(const K &key, const W &w) : + SparseTupleWeight<W, K>(key, w) { } + + static const SparsePowerWeight<W, K> &Zero() { + static const SparsePowerWeight<W, K> zero(SparseTupleWeight<W, K>::Zero()); + return zero; + } + + static const SparsePowerWeight<W, K> &One() { + static const SparsePowerWeight<W, K> one(SparseTupleWeight<W, K>::One()); + return one; + } + + static const SparsePowerWeight<W, K> &NoWeight() { + static const SparsePowerWeight<W, K> no_weight( + SparseTupleWeight<W, K>::NoWeight()); + return no_weight; + } + + // Overide this: Overwrite the Type method to reflect the key type + // if using non-default key type. + static const string &Type() { + static string type; + if(type.empty()) { + type = W::Type() + "_^n"; + if(sizeof(K) != sizeof(uint32)) { + string size; + Int64ToStr(8 * sizeof(K), &size); + type += "_" + size; + } + } + return type; + } + + static uint64 Properties() { + uint64 props = W::Properties(); + return props & (kLeftSemiring | kRightSemiring | + kCommutative | kIdempotent); + } + + SparsePowerWeight<W, K> Quantize(float delta = kDelta) const { + return SparseTupleWeight<W, K>::Quantize(delta); + } + + ReverseWeight Reverse() const { + return SparseTupleWeight<W, K>::Reverse(); + } +}; + +// Semimodule plus operation +template <class W, class K> +inline SparsePowerWeight<W, K> Plus(const SparsePowerWeight<W, K> &w1, + const SparsePowerWeight<W, K> &w2) { + SparsePowerWeight<W, K> ret; + SparseTupleWeightPlusMapper<W, K> operator_mapper; + SparseTupleWeightMap(&ret, w1, w2, operator_mapper); + return ret; +} + +// Semimodule times operation +template <class W, class K> +inline SparsePowerWeight<W, K> Times(const SparsePowerWeight<W, K> &w1, + const SparsePowerWeight<W, K> &w2) { + SparsePowerWeight<W, K> ret; + SparseTupleWeightTimesMapper<W, K> operator_mapper; + SparseTupleWeightMap(&ret, w1, w2, operator_mapper); + return ret; +} + +// Semimodule divide operation +template <class W, class K> +inline SparsePowerWeight<W, K> Divide(const SparsePowerWeight<W, K> &w1, + const SparsePowerWeight<W, K> &w2, + DivideType type = DIVIDE_ANY) { + SparsePowerWeight<W, K> ret; + SparseTupleWeightDivideMapper<W, K> operator_mapper(type); + SparseTupleWeightMap(&ret, w1, w2, operator_mapper); + return ret; +} + +// Semimodule dot product +template <class W, class K> +inline const W& DotProduct(const SparsePowerWeight<W, K> &w1, + const SparsePowerWeight<W, K> &w2) { + const SparsePowerWeight<W, K>& product = Times(w1, w2); + W ret(W::Zero()); + for (SparseTupleWeightIterator<W, K> it(product); !it.Done(); it.Next()) { + ret = Plus(ret, it.Value().second); + } + return ret; +} + +template <class W, class K> +inline bool ApproxEqual(const SparsePowerWeight<W, K> &w1, + const SparsePowerWeight<W, K> &w2, + float delta = kDelta) { + SparseTupleWeight<W, K> ret; + SparseTupleWeightApproxMapper<W, K> operator_mapper(kDelta); + SparseTupleWeightMap(&ret, w1, w2, operator_mapper); + return ret == SparsePowerWeight<W, K>::One(); +} + +template <class W, class K> +inline SparsePowerWeight<W, K> Times(const W &k, + const SparsePowerWeight<W, K> &w2) { + SparsePowerWeight<W, K> w1(k); + return Times(w1, w2); +} + +template <class W, class K> +inline SparsePowerWeight<W, K> Times(const SparsePowerWeight<W, K> &w1, + const W &k) { + SparsePowerWeight<W, K> w2(k); + return Times(w1, w2); +} + +template <class W, class K> +inline SparsePowerWeight<W, K> Divide(const SparsePowerWeight<W, K> &w1, + const W &k, + DivideType divide_type = DIVIDE_ANY) { + SparsePowerWeight<W, K> w2(k); + return Divide(w1, w2, divide_type); +} + +} // namespace fst + +#endif // FST_LIB_SPARSE_POWER_WEIGHT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/sparse-tuple-weight.h b/kaldi_io/src/tools/openfst/include/fst/sparse-tuple-weight.h new file mode 100644 index 0000000..c12ef4f --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/sparse-tuple-weight.h @@ -0,0 +1,640 @@ +// sparse-tuple-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Kasturi Rangan Raghavan) +// Inspiration: [email protected] (Cyril Allauzen) +// \file +// Sparse version of tuple-weight, based on tuple-weight.h +// Internally stores sparse key, value pairs in linked list +// Default value elemnt is the assumed value of unset keys +// Internal singleton implementation that stores first key, +// value pair as a initialized member variable to avoide +// unnecessary allocation on heap. +// Use SparseTupleWeightIterator to iterate through the key,value pairs +// Note: this does NOT iterate through the default value. +// +// Sparse tuple weight set operation definitions. + +#ifndef FST_LIB_SPARSE_TUPLE_WEIGHT_H__ +#define FST_LIB_SPARSE_TUPLE_WEIGHT_H__ + +#include<string> +#include<list> +#include<stack> +#include<tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + +#include <fst/weight.h> + + +DECLARE_string(fst_weight_parentheses); +DECLARE_string(fst_weight_separator); + +namespace fst { + +template <class W, class K> class SparseTupleWeight; + +template<class W, class K> +class SparseTupleWeightIterator; + +template <class W, class K> +istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w); + +// Arbitrary dimension tuple weight, stored as a sorted linked-list +// W is any weight class, +// K is the key value type. kNoKey(-1) is reserved for internal use +template <class W, class K = int> +class SparseTupleWeight { + public: + typedef pair<K, W> Pair; + typedef SparseTupleWeight<typename W::ReverseWeight, K> ReverseWeight; + + const static K kNoKey = -1; + SparseTupleWeight() { + Init(); + } + + template <class Iterator> + SparseTupleWeight(Iterator begin, Iterator end) { + Init(); + // Assumes input iterator is sorted + for (Iterator it = begin; it != end; ++it) + Push(*it); + } + + + SparseTupleWeight(const K& key, const W &w) { + Init(); + Push(key, w); + } + + SparseTupleWeight(const W &w) { + Init(w); + } + + SparseTupleWeight(const SparseTupleWeight<W, K> &w) { + Init(w.DefaultValue()); + SetDefaultValue(w.DefaultValue()); + for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) { + Push(it.Value()); + } + } + + static const SparseTupleWeight<W, K> &Zero() { + static SparseTupleWeight<W, K> zero; + return zero; + } + + static const SparseTupleWeight<W, K> &One() { + static SparseTupleWeight<W, K> one(W::One()); + return one; + } + + static const SparseTupleWeight<W, K> &NoWeight() { + static SparseTupleWeight<W, K> no_weight(W::NoWeight()); + return no_weight; + } + + istream &Read(istream &strm) { + ReadType(strm, &default_); + ReadType(strm, &first_); + return ReadType(strm, &rest_); + } + + ostream &Write(ostream &strm) const { + WriteType(strm, default_); + WriteType(strm, first_); + return WriteType(strm, rest_); + } + + SparseTupleWeight<W, K> &operator=(const SparseTupleWeight<W, K> &w) { + if (this == &w) return *this; // check for w = w + Init(w.DefaultValue()); + for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) { + Push(it.Value()); + } + return *this; + } + + bool Member() const { + if (!DefaultValue().Member()) return false; + for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) { + if (!it.Value().second.Member()) return false; + } + return true; + } + + // Assumes H() function exists for the hash of the key value + size_t Hash() const { + uint64 h = 0; + std::tr1::hash<K> H; + for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) { + h = 5 * h + H(it.Value().first); + h = 13 * h + it.Value().second.Hash(); + } + return size_t(h); + } + + SparseTupleWeight<W, K> Quantize(float delta = kDelta) const { + SparseTupleWeight<W, K> w; + for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) { + w.Push(it.Value().first, it.Value().second.Quantize(delta)); + } + return w; + } + + ReverseWeight Reverse() const { + SparseTupleWeight<W, K> w; + for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) { + w.Push(it.Value().first, it.Value().second.Reverse()); + } + return w; + } + + // Common initializer among constructors. + void Init() { + Init(W::Zero()); + } + + void Init(const W& default_value) { + first_.first = kNoKey; + /* initialized to the reserved key value */ + default_ = default_value; + rest_.clear(); + } + + size_t Size() const { + if (first_.first == kNoKey) + return 0; + else + return rest_.size() + 1; + } + + inline void Push(const K &k, const W &w, bool default_value_check = true) { + Push(make_pair(k, w), default_value_check); + } + + inline void Push(const Pair &p, bool default_value_check = true) { + if (default_value_check && p.second == default_) return; + if (first_.first == kNoKey) { + first_ = p; + } else { + rest_.push_back(p); + } + } + + void SetDefaultValue(const W& val) { default_ = val; } + + const W& DefaultValue() const { return default_; } + + protected: + static istream& ReadNoParen( + istream&, SparseTupleWeight<W, K>&, char separator); + + static istream& ReadWithParen( + istream&, SparseTupleWeight<W, K>&, + char separator, char open_paren, char close_paren); + + private: + // Assumed default value of uninitialized keys, by default W::Zero() + W default_; + + // Key values pairs are first stored in first_, then fill rest_ + // this way we can avoid dynamic allocation in the common case + // where the weight is a single key,val pair. + Pair first_; + list<Pair> rest_; + + friend istream &operator>><W, K>(istream&, SparseTupleWeight<W, K>&); + friend class SparseTupleWeightIterator<W, K>; +}; + +template<class W, class K> +class SparseTupleWeightIterator { + public: + typedef typename SparseTupleWeight<W, K>::Pair Pair; + typedef typename list<Pair>::const_iterator const_iterator; + typedef typename list<Pair>::iterator iterator; + + explicit SparseTupleWeightIterator(const SparseTupleWeight<W, K>& w) + : first_(w.first_), rest_(w.rest_), init_(true), + iter_(rest_.begin()) {} + + bool Done() const { + if (init_) + return first_.first == SparseTupleWeight<W, K>::kNoKey; + else + return iter_ == rest_.end(); + } + + const Pair& Value() const { return init_ ? first_ : *iter_; } + + void Next() { + if (init_) + init_ = false; + else + ++iter_; + } + + void Reset() { + init_ = true; + iter_ = rest_.begin(); + } + + private: + const Pair &first_; + const list<Pair> & rest_; + bool init_; // in the initialized state? + typename list<Pair>::const_iterator iter_; + + DISALLOW_COPY_AND_ASSIGN(SparseTupleWeightIterator); +}; + +template<class W, class K, class M> +inline void SparseTupleWeightMap( + SparseTupleWeight<W, K>* ret, + const SparseTupleWeight<W, K>& w1, + const SparseTupleWeight<W, K>& w2, + const M& operator_mapper) { + SparseTupleWeightIterator<W, K> w1_it(w1); + SparseTupleWeightIterator<W, K> w2_it(w2); + const W& v1_def = w1.DefaultValue(); + const W& v2_def = w2.DefaultValue(); + ret->SetDefaultValue(operator_mapper.Map(0, v1_def, v2_def)); + while (!w1_it.Done() || !w2_it.Done()) { + const K& k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first; + const K& k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first; + const W& v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second; + const W& v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second; + if (k1 == k2) { + ret->Push(k1, operator_mapper.Map(k1, v1, v2)); + if (!w1_it.Done()) w1_it.Next(); + if (!w2_it.Done()) w2_it.Next(); + } else if (k1 < k2) { + ret->Push(k1, operator_mapper.Map(k1, v1, v2_def)); + w1_it.Next(); + } else { + ret->Push(k2, operator_mapper.Map(k2, v1_def, v2)); + w2_it.Next(); + } + } +} + +template <class W, class K> +inline bool operator==(const SparseTupleWeight<W, K> &w1, + const SparseTupleWeight<W, K> &w2) { + const W& v1_def = w1.DefaultValue(); + const W& v2_def = w2.DefaultValue(); + if (v1_def != v2_def) return false; + + SparseTupleWeightIterator<W, K> w1_it(w1); + SparseTupleWeightIterator<W, K> w2_it(w2); + while (!w1_it.Done() || !w2_it.Done()) { + const K& k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first; + const K& k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first; + const W& v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second; + const W& v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second; + if (k1 == k2) { + if (v1 != v2) return false; + if (!w1_it.Done()) w1_it.Next(); + if (!w2_it.Done()) w2_it.Next(); + } else if (k1 < k2) { + if (v1 != v2_def) return false; + w1_it.Next(); + } else { + if (v1_def != v2) return false; + w2_it.Next(); + } + } + return true; +} + +template <class W, class K> +inline bool operator!=(const SparseTupleWeight<W, K> &w1, + const SparseTupleWeight<W, K> &w2) { + return !(w1 == w2); +} + +template <class W, class K> +inline ostream &operator<<(ostream &strm, const SparseTupleWeight<W, K> &w) { + if(FLAGS_fst_weight_separator.size() != 1) { + FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1"; + strm.clear(std::ios::badbit); + return strm; + } + char separator = FLAGS_fst_weight_separator[0]; + bool write_parens = false; + if (!FLAGS_fst_weight_parentheses.empty()) { + if (FLAGS_fst_weight_parentheses.size() != 2) { + FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2"; + strm.clear(std::ios::badbit); + return strm; + } + write_parens = true; + } + + if (write_parens) + strm << FLAGS_fst_weight_parentheses[0]; + + strm << w.DefaultValue(); + strm << separator; + + size_t n = w.Size(); + strm << n; + strm << separator; + + for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) { + strm << it.Value().first; + strm << separator; + strm << it.Value().second; + strm << separator; + } + + if (write_parens) + strm << FLAGS_fst_weight_parentheses[1]; + + return strm; +} + +template <class W, class K> +inline istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w) { + if(FLAGS_fst_weight_separator.size() != 1) { + FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1"; + strm.clear(std::ios::badbit); + return strm; + } + char separator = FLAGS_fst_weight_separator[0]; + + if (!FLAGS_fst_weight_parentheses.empty()) { + if (FLAGS_fst_weight_parentheses.size() != 2) { + FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2"; + strm.clear(std::ios::badbit); + return strm; + } + return SparseTupleWeight<W, K>::ReadWithParen( + strm, w, separator, FLAGS_fst_weight_parentheses[0], + FLAGS_fst_weight_parentheses[1]); + } else { + return SparseTupleWeight<W, K>::ReadNoParen(strm, w, separator); + } +} + +// Reads SparseTupleWeight when there are no parentheses around tuple terms +template <class W, class K> +inline istream& SparseTupleWeight<W, K>::ReadNoParen( + istream &strm, + SparseTupleWeight<W, K> &w, + char separator) { + int c; + size_t n; + + do { + c = strm.get(); + } while (isspace(c)); + + + { // Read default weight + W default_value; + string s; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> default_value; + w.SetDefaultValue(default_value); + } + + c = strm.get(); + + { // Read n + string s; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> n; + } + + // Read n elements + for (size_t i = 0; i < n; ++i) { + // discard separator + c = strm.get(); + K p; + W r; + + { // read key + string s; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> p; + } + + c = strm.get(); + + { // read weight + string s; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> r; + } + + w.Push(p, r); + } + + c = strm.get(); + if (c != separator) { + strm.clear(std::ios::badbit); + } + + return strm; +} + +// Reads SparseTupleWeight when there are parentheses around tuple terms +template <class W, class K> +inline istream& SparseTupleWeight<W, K>::ReadWithParen( + istream &strm, + SparseTupleWeight<W, K> &w, + char separator, + char open_paren, + char close_paren) { + int c; + size_t n; + + do { + c = strm.get(); + } while (isspace(c)); + + if (c != open_paren) { + FSTERROR() << "is fst_weight_parentheses flag set correcty? "; + strm.clear(std::ios::badbit); + return strm; + } + + c = strm.get(); + + { // Read weight + W default_value; + stack<int> parens; + string s; + while (c != separator || !parens.empty()) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + // If parens encountered before separator, they must be matched + if (c == open_paren) { + parens.push(1); + } else if (c == close_paren) { + // Fail for mismatched parens + if (parens.empty()) { + strm.clear(std::ios::failbit); + return strm; + } + parens.pop(); + } + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> default_value; + w.SetDefaultValue(default_value); + } + + c = strm.get(); + + { // Read n + string s; + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> n; + } + + // Read n elements + for (size_t i = 0; i < n; ++i) { + // discard separator + c = strm.get(); + K p; + W r; + + { // Read key + stack<int> parens; + string s; + while (c != separator || !parens.empty()) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + // If parens encountered before separator, they must be matched + if (c == open_paren) { + parens.push(1); + } else if (c == close_paren) { + // Fail for mismatched parens + if (parens.empty()) { + strm.clear(std::ios::failbit); + return strm; + } + parens.pop(); + } + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> p; + } + + c = strm.get(); + + { // Read weight + stack<int> parens; + string s; + while (c != separator || !parens.empty()) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + // If parens encountered before separator, they must be matched + if (c == open_paren) { + parens.push(1); + } else if (c == close_paren) { + // Fail for mismatched parens + if (parens.empty()) { + strm.clear(std::ios::failbit); + return strm; + } + parens.pop(); + } + c = strm.get(); + } + istringstream sstrm(s); + sstrm >> r; + } + + w.Push(p, r); + } + + if (c != separator) { + FSTERROR() << " separator expected, not found! "; + strm.clear(std::ios::badbit); + return strm; + } + + c = strm.get(); + if (c != close_paren) { + FSTERROR() << " is fst_weight_parentheses flag set correcty? "; + strm.clear(std::ios::badbit); + return strm; + } + + return strm; +} + + + +} // namespace fst + +#endif // FST_LIB_SPARSE_TUPLE_WEIGHT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/state-map.h b/kaldi_io/src/tools/openfst/include/fst/state-map.h new file mode 100644 index 0000000..9d6db74 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/state-map.h @@ -0,0 +1,605 @@ +// map.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to map over/transform states e.g., sort transitions +// Consider using when operation does not change the number of states. + +#ifndef FST_LIB_STATE_MAP_H__ +#define FST_LIB_STATE_MAP_H__ + +#include <algorithm> +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <string> +#include <utility> +using std::pair; using std::make_pair; + +#include <fst/cache.h> +#include <fst/arc-map.h> +#include <fst/mutable-fst.h> + + +namespace fst { + +// StateMapper Interface - class determinies how states are mapped. +// Useful for implementing operations that do not change the number of states. +// +// class StateMapper { +// public: +// typedef A FromArc; +// typedef B ToArc; +// +// // Typical constructor +// StateMapper(const Fst<A> &fst); +// // Required copy constructor that allows updating Fst argument; +// // pass only if relevant and changed. +// StateMapper(const StateMapper &mapper, const Fst<A> *fst = 0); +// +// // Specifies initial state of result +// B::StateId Start() const; +// // Specifies state's final weight in result +// B::Weight Final(B::StateId s) const; +// +// // These methods iterate through a state's arcs in result +// // Specifies state to iterate over +// void SetState(B::StateId s); +// // End of arcs? +// bool Done() const; +// // Current arc + +// const B &Value() const; +// // Advance to next arc (when !Done) +// void Next(); +// +// // Specifies input symbol table action the mapper requires (see above). +// MapSymbolsAction InputSymbolsAction() const; +// // Specifies output symbol table action the mapper requires (see above). +// MapSymbolsAction OutputSymbolsAction() const; +// // This specifies the known properties of an Fst mapped by this +// // mapper. It takes as argument the input Fst's known properties. +// uint64 Properties(uint64 props) const; +// }; +// +// We include a various state map versions below. One dimension of +// variation is whether the mapping mutates its input, writes to a +// new result Fst, or is an on-the-fly Fst. Another dimension is how +// we pass the mapper. We allow passing the mapper by pointer +// for cases that we need to change the state of the user's mapper. +// We also include map versions that pass the mapper +// by value or const reference when this suffices. + +// Maps an arc type A using a mapper function object C, passed +// by pointer. This version modifies its Fst input. +template<class A, class C> +void StateMap(MutableFst<A> *fst, C* mapper) { + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) + fst->SetInputSymbols(0); + + if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) + fst->SetOutputSymbols(0); + + if (fst->Start() == kNoStateId) + return; + + uint64 props = fst->Properties(kFstProperties, false); + + fst->SetStart(mapper->Start()); + + for (StateId s = 0; s < fst->NumStates(); ++s) { + mapper->SetState(s); + fst->DeleteArcs(s); + for (; !mapper->Done(); mapper->Next()) + fst->AddArc(s, mapper->Value()); + fst->SetFinal(s, mapper->Final(s)); + } + + fst->SetProperties(mapper->Properties(props), kFstProperties); +} + +// Maps an arc type A using a mapper function object C, passed +// by value. This version modifies its Fst input. +template<class A, class C> +void StateMap(MutableFst<A> *fst, C mapper) { + StateMap(fst, &mapper); +} + + +// Maps an arc type A to an arc type B using mapper function +// object C, passed by pointer. This version writes the mapped +// input Fst to an output MutableFst. +template<class A, class B, class C> +void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) { + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + ofst->DeleteStates(); + + if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) + ofst->SetInputSymbols(ifst.InputSymbols()); + else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) + ofst->SetInputSymbols(0); + + if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) + ofst->SetOutputSymbols(ifst.OutputSymbols()); + else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) + ofst->SetOutputSymbols(0); + + uint64 iprops = ifst.Properties(kCopyProperties, false); + + if (ifst.Start() == kNoStateId) { + if (iprops & kError) ofst->SetProperties(kError, kError); + return; + } + + // Add all states. + if (ifst.Properties(kExpanded, false)) + ofst->ReserveStates(CountStates(ifst)); + for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) + ofst->AddState(); + + ofst->SetStart(mapper->Start()); + + for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + mapper->SetState(s); + for (; !mapper->Done(); mapper->Next()) + ofst->AddArc(s, mapper->Value()); + ofst->SetFinal(s, mapper->Final(s)); + } + + uint64 oprops = ofst->Properties(kFstProperties, false); + ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties); +} + +// Maps an arc type A to an arc type B using mapper function +// object C, passed by value. This version writes the mapped input +// Fst to an output MutableFst. +template<class A, class B, class C> +void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) { + StateMap(ifst, ofst, &mapper); +} + +typedef CacheOptions StateMapFstOptions; + +template <class A, class B, class C> class StateMapFst; + +// Implementation of delayed StateMapFst. +template <class A, class B, class C> +class StateMapFstImpl : public CacheImpl<B> { + public: + using FstImpl<B>::SetType; + using FstImpl<B>::SetProperties; + using FstImpl<B>::SetInputSymbols; + using FstImpl<B>::SetOutputSymbols; + + using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates; + + using CacheImpl<B>::PushArc; + using CacheImpl<B>::HasArcs; + using CacheImpl<B>::HasFinal; + using CacheImpl<B>::HasStart; + using CacheImpl<B>::SetArcs; + using CacheImpl<B>::SetFinal; + using CacheImpl<B>::SetStart; + + friend class StateIterator< StateMapFst<A, B, C> >; + + typedef B Arc; + typedef typename B::Weight Weight; + typedef typename B::StateId StateId; + + StateMapFstImpl(const Fst<A> &fst, const C &mapper, + const StateMapFstOptions& opts) + : CacheImpl<B>(opts), + fst_(fst.Copy()), + mapper_(new C(mapper, fst_)), + own_mapper_(true) { + Init(); + } + + StateMapFstImpl(const Fst<A> &fst, C *mapper, + const StateMapFstOptions& opts) + : CacheImpl<B>(opts), + fst_(fst.Copy()), + mapper_(mapper), + own_mapper_(false) { + Init(); + } + + StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl) + : CacheImpl<B>(impl), + fst_(impl.fst_->Copy(true)), + mapper_(new C(*impl.mapper_, fst_)), + own_mapper_(true) { + Init(); + } + + ~StateMapFstImpl() { + delete fst_; + if (own_mapper_) delete mapper_; + } + + StateId Start() { + if (!HasStart()) + SetStart(mapper_->Start()); + return CacheImpl<B>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) + SetFinal(s, mapper_->Final(s)); + return CacheImpl<B>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<B>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<B>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<B>::NumOutputEpsilons(s); + } + + void InitStateIterator(StateIteratorData<A> *data) const { + fst_->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData<B> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<B>::InitArcIterator(s, data); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && (fst_->Properties(kError, false) || + (mapper_->Properties(0) & kError))) + SetProperties(kError, kError); + return FstImpl<Arc>::Properties(mask); + } + + void Expand(StateId s) { + // Add exiting arcs. + for (mapper_->SetState(s); !mapper_->Done(); mapper_->Next()) + PushArc(s, mapper_->Value()); + SetArcs(s); + } + + const Fst<A> &GetFst() const { + return *fst_; + } + + private: + void Init() { + SetType("statemap"); + + if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) + SetInputSymbols(fst_->InputSymbols()); + else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) + SetInputSymbols(0); + + if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) + SetOutputSymbols(fst_->OutputSymbols()); + else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) + SetOutputSymbols(0); + + uint64 props = fst_->Properties(kCopyProperties, false); + SetProperties(mapper_->Properties(props)); + } + + const Fst<A> *fst_; + C* mapper_; + bool own_mapper_; + + void operator=(const StateMapFstImpl<A, B, C> &); // disallow +}; + + +// Maps an arc type A to an arc type B using Mapper function object +// C. This version is a delayed Fst. +template <class A, class B, class C> +class StateMapFst : public ImplToFst< StateMapFstImpl<A, B, C> > { + public: + friend class ArcIterator< StateMapFst<A, B, C> >; + + typedef B Arc; + typedef typename B::Weight Weight; + typedef typename B::StateId StateId; + typedef CacheState<B> State; + typedef StateMapFstImpl<A, B, C> Impl; + + StateMapFst(const Fst<A> &fst, const C &mapper, + const StateMapFstOptions& opts) + : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {} + + StateMapFst(const Fst<A> &fst, C* mapper, const StateMapFstOptions& opts) + : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {} + + StateMapFst(const Fst<A> &fst, const C &mapper) + : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {} + + StateMapFst(const Fst<A> &fst, C* mapper) + : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {} + + // See Fst<>::Copy() for doc. + StateMapFst(const StateMapFst<A, B, C> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc. + virtual StateMapFst<A, B, C> *Copy(bool safe = false) const { + return new StateMapFst<A, B, C>(*this, safe); + } + + virtual void InitStateIterator(StateIteratorData<A> *data) const { + GetImpl()->InitStateIterator(data); + } + + virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + protected: + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + private: + void operator=(const StateMapFst<A, B, C> &fst); // disallow +}; + + +// Specialization for StateMapFst. +template <class A, class B, class C> +class ArcIterator< StateMapFst<A, B, C> > + : public CacheArcIterator< StateMapFst<A, B, C> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const StateMapFst<A, B, C> &fst, StateId s) + : CacheArcIterator< StateMapFst<A, B, C> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +// +// Utility Mappers +// + +// Mapper that returns its input. +template <class A> +class IdentityStateMapper { + public: + typedef A FromArc; + typedef A ToArc; + + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + explicit IdentityStateMapper(const Fst<A> &fst) : fst_(fst), aiter_(0) {} + + // Allows updating Fst argument; pass only if changed. + IdentityStateMapper(const IdentityStateMapper<A> &mapper, + const Fst<A> *fst = 0) + : fst_(fst ? *fst : mapper.fst_), aiter_(0) {} + + ~IdentityStateMapper() { delete aiter_; } + + StateId Start() const { return fst_.Start(); } + + Weight Final(StateId s) const { return fst_.Final(s); } + + void SetState(StateId s) { + if (aiter_) delete aiter_; + aiter_ = new ArcIterator< Fst<A> >(fst_, s); + } + + bool Done() const { return aiter_->Done(); } + const A &Value() const { return aiter_->Value(); } + void Next() { aiter_->Next(); } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} + + uint64 Properties(uint64 props) const { return props; } + + private: + const Fst<A> &fst_; + ArcIterator< Fst<A> > *aiter_; +}; + +template <class A> +class ArcSumMapper { + public: + typedef A FromArc; + typedef A ToArc; + + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + explicit ArcSumMapper(const Fst<A> &fst) : fst_(fst), i_(0) {} + + // Allows updating Fst argument; pass only if changed. + ArcSumMapper(const ArcSumMapper<A> &mapper, + const Fst<A> *fst = 0) + : fst_(fst ? *fst : mapper.fst_), i_(0) {} + + StateId Start() const { return fst_.Start(); } + Weight Final(StateId s) const { return fst_.Final(s); } + + void SetState(StateId s) { + i_ = 0; + arcs_.clear(); + arcs_.reserve(fst_.NumArcs(s)); + for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) + arcs_.push_back(aiter.Value()); + + // First sorts the exiting arcs by input label, output label + // and destination state and then sums weights of arcs with + // the same input label, output label, and destination state. + sort(arcs_.begin(), arcs_.end(), comp_); + size_t narcs = 0; + for (size_t i = 0; i < arcs_.size(); ++i) { + if (narcs > 0 && equal_(arcs_[i], arcs_[narcs - 1])) { + arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight, + arcs_[i].weight); + } else { + arcs_[narcs++] = arcs_[i]; + } + } + arcs_.resize(narcs); + } + + bool Done() const { return i_ >= arcs_.size(); } + const A &Value() const { return arcs_[i_]; } + void Next() { ++i_; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + uint64 Properties(uint64 props) const { + return props & kArcSortProperties & + kDeleteArcsProperties & kWeightInvariantProperties; + } + + private: + struct Compare { + bool operator()(const A& x, const A& y) { + if (x.ilabel < y.ilabel) return true; + if (x.ilabel > y.ilabel) return false; + if (x.olabel < y.olabel) return true; + if (x.olabel > y.olabel) return false; + if (x.nextstate < y.nextstate) return true; + if (x.nextstate > y.nextstate) return false; + return false; + } + }; + + struct Equal { + bool operator()(const A& x, const A& y) { + return (x.ilabel == y.ilabel && + x.olabel == y.olabel && + x.nextstate == y.nextstate); + } + }; + + const Fst<A> &fst_; + Compare comp_; + Equal equal_; + vector<A> arcs_; + ssize_t i_; // current arc position + + void operator=(const ArcSumMapper<A> &); // disallow +}; + +template <class A> +class ArcUniqueMapper { + public: + typedef A FromArc; + typedef A ToArc; + + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {} + + // Allows updating Fst argument; pass only if changed. + ArcUniqueMapper(const ArcUniqueMapper<A> &mapper, + const Fst<A> *fst = 0) + : fst_(fst ? *fst : mapper.fst_), i_(0) {} + + StateId Start() const { return fst_.Start(); } + Weight Final(StateId s) const { return fst_.Final(s); } + + void SetState(StateId s) { + i_ = 0; + arcs_.clear(); + arcs_.reserve(fst_.NumArcs(s)); + for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) + arcs_.push_back(aiter.Value()); + + // First sorts the exiting arcs by input label, output label + // and destination state and then uniques identical arcs + sort(arcs_.begin(), arcs_.end(), comp_); + typename vector<A>::iterator unique_end = + unique(arcs_.begin(), arcs_.end(), equal_); + arcs_.resize(unique_end - arcs_.begin()); + } + + bool Done() const { return i_ >= arcs_.size(); } + const A &Value() const { return arcs_[i_]; } + void Next() { ++i_; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + uint64 Properties(uint64 props) const { + return props & kArcSortProperties & kDeleteArcsProperties; + } + + private: + struct Compare { + bool operator()(const A& x, const A& y) { + if (x.ilabel < y.ilabel) return true; + if (x.ilabel > y.ilabel) return false; + if (x.olabel < y.olabel) return true; + if (x.olabel > y.olabel) return false; + if (x.nextstate < y.nextstate) return true; + if (x.nextstate > y.nextstate) return false; + return false; + } + }; + + struct Equal { + bool operator()(const A& x, const A& y) { + return (x.ilabel == y.ilabel && + x.olabel == y.olabel && + x.nextstate == y.nextstate && + x.weight == y.weight); + } + }; + + const Fst<A> &fst_; + Compare comp_; + Equal equal_; + vector<A> arcs_; + ssize_t i_; // current arc position + + void operator=(const ArcUniqueMapper<A> &); // disallow +}; + + +} // namespace fst + +#endif // FST_LIB_STATE_MAP_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/state-reachable.h b/kaldi_io/src/tools/openfst/include/fst/state-reachable.h new file mode 100644 index 0000000..6d0c971 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/state-reachable.h @@ -0,0 +1,198 @@ +// state-reachable.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Class to determine whether a given (final) state can be reached from some +// other given state. + +#ifndef FST_LIB_STATE_REACHABLE_H__ +#define FST_LIB_STATE_REACHABLE_H__ + +#include <vector> +using std::vector; + +#include <fst/dfs-visit.h> +#include <fst/fst.h> +#include <fst/interval-set.h> + + +namespace fst { + +// Computes the (final) states reachable from a given state in an FST. +// After this visitor has been called, a final state f can be reached +// from a state s iff (*isets)[s].Member(state2index[f]) is true, where +// (*isets[s]) is a set of half-open inteval of final state indices +// and state2index[f] maps from a final state to its index. +// +// If state2index is empty, it is filled-in with suitable indices. +// If it is non-empty, those indices are used; in this case, the +// final states must have out-degree 0. +template <class A, typename I = typename A::StateId> +class IntervalReachVisitor { + public: + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename IntervalSet<I>::Interval Interval; + + IntervalReachVisitor(const Fst<A> &fst, + vector< IntervalSet<I> > *isets, + vector<I> *state2index) + : fst_(fst), + isets_(isets), + state2index_(state2index), + index_(state2index->empty() ? 1 : -1), + error_(false) { + isets_->clear(); + } + + void InitVisit(const Fst<A> &fst) { error_ = false; } + + bool InitState(StateId s, StateId r) { + while (isets_->size() <= s) + isets_->push_back(IntervalSet<Label>()); + while (state2index_->size() <= s) + state2index_->push_back(-1); + + if (fst_.Final(s) != Weight::Zero()) { + // Create tree interval + vector<Interval> *intervals = (*isets_)[s].Intervals(); + if (index_ < 0) { // Use state2index_ map to set index + if (fst_.NumArcs(s) > 0) { + FSTERROR() << "IntervalReachVisitor: state2index map must be empty " + << "for this FST"; + error_ = true; + return false; + } + I index = (*state2index_)[s]; + if (index < 0) { + FSTERROR() << "IntervalReachVisitor: state2index map incomplete"; + error_ = true; + return false; + } + intervals->push_back(Interval(index, index + 1)); + } else { // Use pre-order index + intervals->push_back(Interval(index_, index_ + 1)); + (*state2index_)[s] = index_++; + } + } + return true; + } + + bool TreeArc(StateId s, const A &arc) { + return true; + } + + bool BackArc(StateId s, const A &arc) { + FSTERROR() << "IntervalReachVisitor: cyclic input"; + error_ = true; + return false; + } + + bool ForwardOrCrossArc(StateId s, const A &arc) { + // Non-tree interval + (*isets_)[s].Union((*isets_)[arc.nextstate]); + return true; + } + + void FinishState(StateId s, StateId p, const A *arc) { + if (index_ >= 0 && fst_.Final(s) != Weight::Zero()) { + vector<Interval> *intervals = (*isets_)[s].Intervals(); + (*intervals)[0].end = index_; // Update tree interval end + } + (*isets_)[s].Normalize(); + if (p != kNoStateId) + (*isets_)[p].Union((*isets_)[s]); // Propagate intervals to parent + } + + void FinishVisit() {} + + bool Error() const { return error_; } + + private: + const Fst<A> &fst_; + vector< IntervalSet<I> > *isets_; + vector<I> *state2index_; + I index_; + bool error_; +}; + + +// Tests reachability of final states from a given state. To test for +// reachability from a state s, first do SetState(s). Then a final +// state f can be reached from state s of FST iff Reach(f) is true. +template <class A, typename I = typename A::StateId> +class StateReachable { + public: + typedef A Arc; + typedef I Index; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename IntervalSet<I>::Interval Interval; + + StateReachable(const Fst<A> &fst) + : error_(false) { + IntervalReachVisitor<Arc> reach_visitor(fst, &isets_, &state2index_); + DfsVisit(fst, &reach_visitor); + if (reach_visitor.Error()) error_ = true; + } + + StateReachable(const StateReachable<A> &reachable) { + FSTERROR() << "Copy constructor for state reachable class " + << "not yet implemented."; + error_ = true; + } + + // Set current state. + void SetState(StateId s) { s_ = s; } + + // Can reach this label from current state? + bool Reach(StateId s) { + if (s >= state2index_.size()) + return false; + + I i = state2index_[s]; + if (i < 0) { + FSTERROR() << "StateReachable: state non-final: " << s; + error_ = true; + return false; + } + return isets_[s_].Member(i); + } + + // Access to the state-to-index mapping. Unassigned states have index -1. + vector<I> &State2Index() { return state2index_; } + + // Access to the interval sets. These specify the reachability + // to the final states as intervals of the final state indices. + const vector< IntervalSet<I> > &IntervalSets() { return isets_; } + + bool Error() const { return error_; } + + private: + StateId s_; // Current state + vector< IntervalSet<I> > isets_; // Interval sets per state + vector<I> state2index_; // Finds index for a final state + bool error_; + + void operator=(const StateReachable<A> &); // Disallow +}; + +} // namespace fst + +#endif // FST_LIB_STATE_REACHABLE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/state-table.h b/kaldi_io/src/tools/openfst/include/fst/state-table.h new file mode 100644 index 0000000..d8107a1 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/state-table.h @@ -0,0 +1,481 @@ +// state-table.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Classes for representing the mapping between state tuples and state Ids. + +#ifndef FST_LIB_STATE_TABLE_H__ +#define FST_LIB_STATE_TABLE_H__ + +#include <deque> +using std::deque; +#include <vector> +using std::vector; + +#include <fst/bi-table.h> +#include <fst/expanded-fst.h> + + +namespace fst { + +// STATE TABLES - these determine the bijective mapping between state +// tuples (e.g. in composition triples of two FST states and a +// composition filter state) and their corresponding state IDs. +// They are classes, templated on state tuples, of the form: +// +// template <class T> +// class StateTable { +// public: +// typedef typename T StateTuple; +// +// // Required constructors. +// StateTable(); +// +// // Lookup state ID by tuple. If it doesn't exist, then add it. +// StateId FindState(const StateTuple &); +// // Lookup state tuple by state ID. +// const StateTuple<StateId> &Tuple(StateId) const; +// // # of stored tuples. +// StateId Size() const; +// }; +// +// A state tuple has the form: +// +// template <class S> +// struct StateTuple { +// typedef typename S StateId; +// +// // Required constructors. +// StateTuple(); +// StateTuple(const StateTuple &); +// }; + + +// An implementation using a hash map for the tuple to state ID mapping. +// The state tuple T must have == defined. H is the hash function. +template <class T, class H> +class HashStateTable : public HashBiTable<typename T::StateId, T, H> { + public: + typedef T StateTuple; + typedef typename StateTuple::StateId StateId; + using HashBiTable<StateId, T, H>::FindId; + using HashBiTable<StateId, T, H>::FindEntry; + using HashBiTable<StateId, T, H>::Size; + + HashStateTable() : HashBiTable<StateId, T, H>() {} + + // Reserves space for table_size elements. + explicit HashStateTable(size_t table_size) + : HashBiTable<StateId, T, H>(table_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + + +// An implementation using a hash map for the tuple to state ID mapping. +// The state tuple T must have == defined. H is the hash function. +template <class T, class H> +class CompactHashStateTable + : public CompactHashBiTable<typename T::StateId, T, H> { + public: + typedef T StateTuple; + typedef typename StateTuple::StateId StateId; + using CompactHashBiTable<StateId, T, H>::FindId; + using CompactHashBiTable<StateId, T, H>::FindEntry; + using CompactHashBiTable<StateId, T, H>::Size; + + CompactHashStateTable() : CompactHashBiTable<StateId, T, H>() {} + + // Reserves space for 'table_size' elements. + explicit CompactHashStateTable(size_t table_size) + : CompactHashBiTable<StateId, T, H>(table_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// An implementation using a vector for the tuple to state mapping. +// It is passed a function object FP that should fingerprint tuples +// uniquely to an integer that can used as a vector index. Normally, +// VectorStateTable constructs the FP object. The user can instead +// pass in this object; in that case, VectorStateTable takes its +// ownership. +template <class T, class FP> +class VectorStateTable + : public VectorBiTable<typename T::StateId, T, FP> { + public: + typedef T StateTuple; + typedef typename StateTuple::StateId StateId; + using VectorBiTable<StateId, T, FP>::FindId; + using VectorBiTable<StateId, T, FP>::FindEntry; + using VectorBiTable<StateId, T, FP>::Size; + using VectorBiTable<StateId, T, FP>::Fingerprint; + + // Reserves space for 'table_size' elements. + explicit VectorStateTable(FP *fp = 0, size_t table_size = 0) + : VectorBiTable<StateId, T, FP>(fp, table_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + + +// An implementation using a vector and a compact hash table. The +// selecting functor S returns true for tuples to be hashed in the +// vector. The fingerprinting functor FP returns a unique fingerprint +// for each tuple to be hashed in the vector (these need to be +// suitable for indexing in a vector). The hash functor H is used when +// hashing tuple into the compact hash table. +template <class T, class S, class FP, class H> +class VectorHashStateTable + : public VectorHashBiTable<typename T::StateId, T, S, FP, H> { + public: + typedef T StateTuple; + typedef typename StateTuple::StateId StateId; + using VectorHashBiTable<StateId, T, S, FP, H>::FindId; + using VectorHashBiTable<StateId, T, S, FP, H>::FindEntry; + using VectorHashBiTable<StateId, T, S, FP, H>::Size; + using VectorHashBiTable<StateId, T, S, FP, H>::Selector; + using VectorHashBiTable<StateId, T, S, FP, H>::Fingerprint; + using VectorHashBiTable<StateId, T, S, FP, H>::Hash; + + VectorHashStateTable(S *s, FP *fp, H *h, + size_t vector_size = 0, + size_t tuple_size = 0) + : VectorHashBiTable<StateId, T, S, FP, H>( + s, fp, h, vector_size, tuple_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + + +// An implementation using a hash map for the tuple to state ID +// mapping. This version permits erasing of states. The state tuple T +// must have == defined and its default constructor must produce a +// tuple that will never be seen. F is the hash function. +template <class T, class F> +class ErasableStateTable : public ErasableBiTable<typename T::StateId, T, F> { + public: + typedef T StateTuple; + typedef typename StateTuple::StateId StateId; + using ErasableBiTable<StateId, T, F>::FindId; + using ErasableBiTable<StateId, T, F>::FindEntry; + using ErasableBiTable<StateId, T, F>::Size; + using ErasableBiTable<StateId, T, F>::Erase; + + ErasableStateTable() : ErasableBiTable<StateId, T, F>() {} + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// +// COMPOSITION STATE TUPLES AND TABLES +// +// The composition state table has the form: +// +// template <class A, class F> +// class ComposeStateTable { +// public: +// typedef A Arc; +// typedef F FilterState; +// typedef typename A::StateId StateId; +// typedef ComposeStateTuple<StateId> StateTuple; +// +// // Required constructors. Copy constructor does not copy state. +// ComposeStateTable(const Fst<Arc> &fst1, const Fst<Arc> &fst2); +// ComposeStateTable(const ComposeStateTable<A, F> &table); +// // Lookup state ID by tuple. If it doesn't exist, then add it. +// StateId FindState(const StateTuple &); +// // Lookup state tuple by state ID. +// const StateTuple<StateId> &Tuple(StateId) const; +// // # of stored tuples. +// StateId Size() const; +// // Return true if error encountered +// bool Error() const; +// }; + +// Represents the composition state. +template <typename S, typename F> +struct ComposeStateTuple { + typedef S StateId; + typedef F FilterState; + + ComposeStateTuple() + : state_id1(kNoStateId), state_id2(kNoStateId), + filter_state(FilterState::NoState()) {} + + ComposeStateTuple(StateId s1, StateId s2, const FilterState &f) + : state_id1(s1), state_id2(s2), filter_state(f) {} + + StateId state_id1; // State Id on fst1 + StateId state_id2; // State Id on fst2 + FilterState filter_state; // State of composition filter +}; + +// Equality of composition state tuples. +template <typename S, typename F> +inline bool operator==(const ComposeStateTuple<S, F>& x, + const ComposeStateTuple<S, F>& y) { + if (&x == &y) + return true; + return x.state_id1 == y.state_id1 && + x.state_id2 == y.state_id2 && + x.filter_state == y.filter_state; +} + + +// Hashing of composition state tuples. +template <typename S, typename F> +class ComposeHash { + public: + size_t operator()(const ComposeStateTuple<S, F>& t) const { + return t.state_id1 + t.state_id2 * kPrime0 + + t.filter_state.Hash() * kPrime1; + } + private: + static const size_t kPrime0; + static const size_t kPrime1; +}; + +template <typename S, typename F> +const size_t ComposeHash<S, F>::kPrime0 = 7853; + +template <typename S, typename F> +const size_t ComposeHash<S, F>::kPrime1 = 7867; + + +// A HashStateTable over composition tuples. +template <typename A, + typename F, + typename H = + CompactHashStateTable<ComposeStateTuple<typename A::StateId, F>, + ComposeHash<typename A::StateId, F> > > +class GenericComposeStateTable : public H { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef F FilterState; + typedef ComposeStateTuple<StateId, F> StateTuple; + + GenericComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) {} + + // Reserves space for 'table_size' elements. + GenericComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2, + size_t table_size) : H(table_size) {} + + bool Error() const { return false; } + + private: + void operator=(const GenericComposeStateTable<A, F> &table); // disallow +}; + + +// Fingerprint for general composition tuples. +template <typename S, typename F> +class ComposeFingerprint { + public: + typedef S StateId; + typedef F FilterState; + typedef ComposeStateTuple<S, F> StateTuple; + + // Required but suboptimal constructor. + ComposeFingerprint() : mult1_(8192), mult2_(8192) { + LOG(WARNING) << "TupleFingerprint: # of FST states should be provided."; + } + + // Constructor is provided the sizes of the input FSTs + ComposeFingerprint(StateId nstates1, StateId nstates2) + : mult1_(nstates1), mult2_(nstates1 * nstates2) { } + + size_t operator()(const StateTuple &tuple) { + return tuple.state_id1 + tuple.state_id2 * mult1_ + + tuple.filter_state.Hash() * mult2_; + } + + private: + ssize_t mult1_; + ssize_t mult2_; +}; + + +// Useful when the first composition state determines the tuple. +template <typename S, typename F> +class ComposeState1Fingerprint { + public: + typedef S StateId; + typedef F FilterState; + typedef ComposeStateTuple<S, F> StateTuple; + + size_t operator()(const StateTuple &tuple) { return tuple.state_id1; } +}; + + +// Useful when the second composition state determines the tuple. +template <typename S, typename F> +class ComposeState2Fingerprint { + public: + typedef S StateId; + typedef F FilterState; + typedef ComposeStateTuple<S, F> StateTuple; + + size_t operator()(const StateTuple &tuple) { return tuple.state_id2; } +}; + + +// A VectorStateTable over composition tuples. This can be used when +// the product of number of states in FST1 and FST2 (and the +// composition filter state hash) is manageable. If the FSTs are not +// expanded Fsts, they will first have their states counted. +template <typename A, typename F> +class ProductComposeStateTable : public +VectorStateTable<ComposeStateTuple<typename A::StateId, F>, + ComposeFingerprint<typename A::StateId, F> > { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef F FilterState; + typedef ComposeStateTuple<StateId, F> StateTuple; + typedef VectorStateTable<StateTuple, + ComposeFingerprint<StateId, F> > StateTable; + + // Reserves space for 'table_size' elements. + ProductComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2, + size_t table_size = 0) + : StateTable(new ComposeFingerprint<StateId, F>(CountStates(fst1), + CountStates(fst2)), + table_size) {} + + ProductComposeStateTable(const ProductComposeStateTable<A, F> &table) + : StateTable(new ComposeFingerprint<StateId, F>(table.Fingerprint())) {} + + bool Error() const { return false; } + + private: + void operator=(const ProductComposeStateTable<A, F> &table); // disallow +}; + +// A VectorStateTable over composition tuples. This can be used when +// FST1 is a string (satisfies kStringProperties) and FST2 is +// epsilon-free and deterministic. It should be used with a +// composition filter that creates at most one filter state per tuple +// under these conditions (e.g. SequenceComposeFilter or +// MatchComposeFilter). +template <typename A, typename F> +class StringDetComposeStateTable : public +VectorStateTable<ComposeStateTuple<typename A::StateId, F>, + ComposeState1Fingerprint<typename A::StateId, F> > { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef F FilterState; + typedef ComposeStateTuple<StateId, F> StateTuple; + typedef VectorStateTable<StateTuple, + ComposeState1Fingerprint<StateId, F> > StateTable; + + StringDetComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) + : error_(false) { + uint64 props1 = kString; + uint64 props2 = kIDeterministic | kNoIEpsilons; + if (fst1.Properties(props1, true) != props1 || + fst2.Properties(props2, true) != props2) { + FSTERROR() << "StringDetComposeStateTable: fst1 not a string or" + << " fst2 not input deterministic and epsilon-free"; + error_ = true; + } + } + + StringDetComposeStateTable(const StringDetComposeStateTable<A, F> &table) + : StateTable(table), error_(table.error_) {} + + bool Error() const { return error_; } + + private: + bool error_; + + void operator=(const StringDetComposeStateTable<A, F> &table); // disallow +}; + + +// A VectorStateTable over composition tuples. This can be used when +// FST2 is a string (satisfies kStringProperties) and FST1 is +// epsilon-free and deterministic. It should be used with a +// composition filter that creates at most one filter state per tuple +// under these conditions (e.g. SequenceComposeFilter or +// MatchComposeFilter). +template <typename A, typename F> +class DetStringComposeStateTable : public +VectorStateTable<ComposeStateTuple<typename A::StateId, F>, + ComposeState2Fingerprint<typename A::StateId, F> > { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef F FilterState; + typedef ComposeStateTuple<StateId, F> StateTuple; + typedef VectorStateTable<StateTuple, + ComposeState2Fingerprint<StateId, F> > StateTable; + + DetStringComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) + :error_(false) { + uint64 props1 = kODeterministic | kNoOEpsilons; + uint64 props2 = kString; + if (fst1.Properties(props1, true) != props1 || + fst2.Properties(props2, true) != props2) { + FSTERROR() << "StringDetComposeStateTable: fst2 not a string or" + << " fst1 not output deterministic and epsilon-free"; + error_ = true; + } + } + + DetStringComposeStateTable(const DetStringComposeStateTable<A, F> &table) + : StateTable(table), error_(table.error_) {} + + bool Error() const { return error_; } + + private: + bool error_; + + void operator=(const DetStringComposeStateTable<A, F> &table); // disallow +}; + + +// An ErasableStateTable over composition tuples. The Erase(StateId) method +// can be called if the user either is sure that composition will never return +// to that tuple or doesn't care that if it does, it is assigned a new +// state ID. +template <typename A, typename F> +class ErasableComposeStateTable : public +ErasableStateTable<ComposeStateTuple<typename A::StateId, F>, + ComposeHash<typename A::StateId, F> > { + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef F FilterState; + typedef ComposeStateTuple<StateId, F> StateTuple; + + ErasableComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) {} + + bool Error() const { return false; } + + private: + void operator=(const ErasableComposeStateTable<A, F> &table); // disallow +}; + +} // namespace fst + +#endif // FST_LIB_STATE_TABLE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/statesort.h b/kaldi_io/src/tools/openfst/include/fst/statesort.h new file mode 100644 index 0000000..6f827f4 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/statesort.h @@ -0,0 +1,97 @@ +// statesort.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Function to sort states of an Fst. + +#ifndef FST_LIB_STATESORT_H__ +#define FST_LIB_STATESORT_H__ + +#include <vector> +using std::vector; +#include <algorithm> + +#include <fst/mutable-fst.h> + + +namespace fst { + +// Sorts the input states of an FST, modifying it. ORDER[i] gives the +// the state Id after sorting that corresponds to state Id i before +// sorting. ORDER must be a permutation of FST's states ID sequence: +// (0, 1, 2, ..., fst->NumStates() - 1). +template <class Arc> +void StateSort(MutableFst<Arc> *fst, + const vector<typename Arc::StateId> &order) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + if (order.size() != fst->NumStates()) { + FSTERROR() << "StateSort: bad order vector size: " << order.size(); + fst->SetProperties(kError, kError); + return; + } + + if (fst->Start() == kNoStateId) + return; + + uint64 props = fst->Properties(kStateSortProperties, false); + + vector<bool> done(order.size(), false); + vector<Arc> arcsa, arcsb; + vector<Arc> *arcs1 = &arcsa, *arcs2 = &arcsb; + + fst->SetStart(order[fst->Start()]); + + for (StateIterator< MutableFst<Arc> > siter(*fst); + !siter.Done(); + siter.Next()) { + StateId s1 = siter.Value(), s2; + if (done[s1]) + continue; + Weight final1 = fst->Final(s1), final2 = Weight::Zero(); + arcs1->clear(); + for (ArcIterator< MutableFst<Arc> > aiter(*fst, s1); + !aiter.Done(); + aiter.Next()) + arcs1->push_back(aiter.Value()); + for (; !done[s1]; s1 = s2, final1 = final2, swap(arcs1, arcs2)) { + s2 = order[s1]; + if (!done[s2]) { + final2 = fst->Final(s2); + arcs2->clear(); + for (ArcIterator< MutableFst<Arc> > aiter(*fst, s2); + !aiter.Done(); + aiter.Next()) + arcs2->push_back(aiter.Value()); + } + fst->SetFinal(s2, final1); + fst->DeleteArcs(s2); + for (size_t i = 0; i < arcs1->size(); ++i) { + Arc arc = (*arcs1)[i]; + arc.nextstate = order[arc.nextstate]; + fst->AddArc(s2, arc); + } + done[s1] = true; + } + } + fst->SetProperties(props, kFstProperties); +} + +} // namespace fst + +#endif // FST_LIB_STATESORT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/string-weight.h b/kaldi_io/src/tools/openfst/include/fst/string-weight.h new file mode 100644 index 0000000..1beeb33 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/string-weight.h @@ -0,0 +1,560 @@ +// string-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// String weight set and associated semiring operation definitions. + +#ifndef FST_LIB_STRING_WEIGHT_H__ +#define FST_LIB_STRING_WEIGHT_H__ + +#include <list> +#include <string> + +#include <fst/product-weight.h> +#include <fst/weight.h> + +namespace fst { + +const int kStringInfinity = -1; // Label for the infinite string +const int kStringBad = -2; // Label for a non-string +const char kStringSeparator = '_'; // Label separator in strings + +// Determines whether to use left or right string semiring. Includes +// restricted versions that signal an error if proper prefixes +// (suffixes) would otherwise be returned by Plus, useful with various +// algorithms that require functional transducer input with the +// string semirings. +enum StringType { STRING_LEFT = 0, STRING_RIGHT = 1 , + STRING_LEFT_RESTRICT = 2, STRING_RIGHT_RESTRICT }; + +#define REVERSE_STRING_TYPE(S) \ + ((S) == STRING_LEFT ? STRING_RIGHT : \ + ((S) == STRING_RIGHT ? STRING_LEFT : \ + ((S) == STRING_LEFT_RESTRICT ? STRING_RIGHT_RESTRICT : \ + STRING_LEFT_RESTRICT))) + +template <typename L, StringType S = STRING_LEFT> +class StringWeight; + +template <typename L, StringType S = STRING_LEFT> +class StringWeightIterator; + +template <typename L, StringType S = STRING_LEFT> +class StringWeightReverseIterator; + +template <typename L, StringType S> +bool operator==(const StringWeight<L, S> &, const StringWeight<L, S> &); + + +// String semiring: (longest_common_prefix/suffix, ., Infinity, Epsilon) +template <typename L, StringType S> +class StringWeight { + public: + typedef L Label; + typedef StringWeight<L, REVERSE_STRING_TYPE(S)> ReverseWeight; + + friend class StringWeightIterator<L, S>; + friend class StringWeightReverseIterator<L, S>; + friend bool operator==<>(const StringWeight<L, S> &, + const StringWeight<L, S> &); + + StringWeight() { Init(); } + + template <typename Iter> + StringWeight(const Iter &begin, const Iter &end) { + Init(); + for (Iter iter = begin; iter != end; ++iter) + PushBack(*iter); + } + + explicit StringWeight(L l) { Init(); PushBack(l); } + + static const StringWeight<L, S> &Zero() { + static const StringWeight<L, S> zero(kStringInfinity); + return zero; + } + + static const StringWeight<L, S> &One() { + static const StringWeight<L, S> one; + return one; + } + + static const StringWeight<L, S> &NoWeight() { + static const StringWeight<L, S> no_weight(kStringBad); + return no_weight; + } + + static const string &Type() { + static const string type = + S == STRING_LEFT ? "string" : + (S == STRING_RIGHT ? "right_string" : + (S == STRING_LEFT_RESTRICT ? "restricted_string" : + "right_restricted_string")); + return type; + } + + bool Member() const; + + istream &Read(istream &strm); + + ostream &Write(ostream &strm) const; + + size_t Hash() const; + + StringWeight<L, S> Quantize(float delta = kDelta) const { + return *this; + } + + ReverseWeight Reverse() const; + + static uint64 Properties() { + return (S == STRING_LEFT || S == STRING_LEFT_RESTRICT ? + kLeftSemiring : kRightSemiring) | kIdempotent; + } + + // NB: This needs to be uncommented only if default fails for this impl. + // StringWeight<L, S> &operator=(const StringWeight<L, S> &w); + + // These operations combined with the StringWeightIterator and + // StringWeightReverseIterator provide the access and mutation of + // the string internal elements. + + // Common initializer among constructors. + void Init() { first_ = 0; } + + // Clear existing StringWeight. + void Clear() { first_ = 0; rest_.clear(); } + + size_t Size() const { return first_ ? rest_.size() + 1 : 0; } + + void PushFront(L l) { + if (first_) + rest_.push_front(first_); + first_ = l; + } + + void PushBack(L l) { + if (!first_) + first_ = l; + else + rest_.push_back(l); + } + + private: + L first_; // first label in string (0 if empty) + list<L> rest_; // remaining labels in string +}; + + +// Traverses string in forward direction. +template <typename L, StringType S> +class StringWeightIterator { + public: + explicit StringWeightIterator(const StringWeight<L, S>& w) + : first_(w.first_), rest_(w.rest_), init_(true), + iter_(rest_.begin()) {} + + bool Done() const { + if (init_) return first_ == 0; + else return iter_ == rest_.end(); + } + + const L& Value() const { return init_ ? first_ : *iter_; } + + void Next() { + if (init_) init_ = false; + else ++iter_; + } + + void Reset() { + init_ = true; + iter_ = rest_.begin(); + } + + private: + const L &first_; + const list<L> &rest_; + bool init_; // in the initialized state? + typename list<L>::const_iterator iter_; + + DISALLOW_COPY_AND_ASSIGN(StringWeightIterator); +}; + + +// Traverses string in backward direction. +template <typename L, StringType S> +class StringWeightReverseIterator { + public: + explicit StringWeightReverseIterator(const StringWeight<L, S>& w) + : first_(w.first_), rest_(w.rest_), fin_(first_ == 0), + iter_(rest_.rbegin()) {} + + bool Done() const { return fin_; } + + const L& Value() const { return iter_ == rest_.rend() ? first_ : *iter_; } + + void Next() { + if (iter_ == rest_.rend()) fin_ = true; + else ++iter_; + } + + void Reset() { + fin_ = false; + iter_ = rest_.rbegin(); + } + + private: + const L &first_; + const list<L> &rest_; + bool fin_; // in the final state? + typename list<L>::const_reverse_iterator iter_; + + DISALLOW_COPY_AND_ASSIGN(StringWeightReverseIterator); +}; + + +// StringWeight member functions follow that require +// StringWeightIterator or StringWeightReverseIterator. + +template <typename L, StringType S> +inline istream &StringWeight<L, S>::Read(istream &strm) { + Clear(); + int32 size; + ReadType(strm, &size); + for (int i = 0; i < size; ++i) { + L label; + ReadType(strm, &label); + PushBack(label); + } + return strm; +} + +template <typename L, StringType S> +inline ostream &StringWeight<L, S>::Write(ostream &strm) const { + int32 size = Size(); + WriteType(strm, size); + for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next()) { + L label = iter.Value(); + WriteType(strm, label); + } + return strm; +} + +template <typename L, StringType S> +inline bool StringWeight<L, S>::Member() const { + if (Size() != 1) + return true; + StringWeightIterator<L, S> iter(*this); + return iter.Value() != kStringBad; +} + +template <typename L, StringType S> +inline typename StringWeight<L, S>::ReverseWeight +StringWeight<L, S>::Reverse() const { + ReverseWeight rw; + for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next()) + rw.PushFront(iter.Value()); + return rw; +} + +template <typename L, StringType S> +inline size_t StringWeight<L, S>::Hash() const { + size_t h = 0; + for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next()) + h ^= h<<1 ^ iter.Value(); + return h; +} + +// NB: This needs to be uncommented only if default fails for this the impl. +// +// template <typename L, StringType S> +// inline StringWeight<L, S> +// &StringWeight<L, S>::operator=(const StringWeight<L, S> &w) { +// if (this != &w) { +// Clear(); +// for (StringWeightIterator<L, S> iter(w); !iter.Done(); iter.Next()) +// PushBack(iter.Value()); +// } +// return *this; +// } + +template <typename L, StringType S> +inline bool operator==(const StringWeight<L, S> &w1, + const StringWeight<L, S> &w2) { + if (w1.Size() != w2.Size()) + return false; + + StringWeightIterator<L, S> iter1(w1); + StringWeightIterator<L, S> iter2(w2); + + for (; !iter1.Done() ; iter1.Next(), iter2.Next()) + if (iter1.Value() != iter2.Value()) + return false; + + return true; +} + +template <typename L, StringType S> +inline bool operator!=(const StringWeight<L, S> &w1, + const StringWeight<L, S> &w2) { + return !(w1 == w2); +} + +template <typename L, StringType S> +inline bool ApproxEqual(const StringWeight<L, S> &w1, + const StringWeight<L, S> &w2, + float delta = kDelta) { + return w1 == w2; +} + +template <typename L, StringType S> +inline ostream &operator<<(ostream &strm, const StringWeight<L, S> &w) { + StringWeightIterator<L, S> iter(w); + if (iter.Done()) + return strm << "Epsilon"; + else if (iter.Value() == kStringInfinity) + return strm << "Infinity"; + else if (iter.Value() == kStringBad) + return strm << "BadString"; + else + for (size_t i = 0; !iter.Done(); ++i, iter.Next()) { + if (i > 0) + strm << kStringSeparator; + strm << iter.Value(); + } + return strm; +} + +template <typename L, StringType S> +inline istream &operator>>(istream &strm, StringWeight<L, S> &w) { + string s; + strm >> s; + if (s == "Infinity") { + w = StringWeight<L, S>::Zero(); + } else if (s == "Epsilon") { + w = StringWeight<L, S>::One(); + } else { + w.Clear(); + char *p = 0; + for (const char *cs = s.c_str(); !p || *p != '\0'; cs = p + 1) { + int l = strtoll(cs, &p, 10); + if (p == cs || (*p != 0 && *p != kStringSeparator)) { + strm.clear(std::ios::badbit); + break; + } + w.PushBack(l); + } + } + return strm; +} + + +// Default is for the restricted left and right semirings. String +// equality is required (for non-Zero() input. This restriction +// is used in e.g. Determinize to ensure functional input. +template <typename L, StringType S> inline StringWeight<L, S> +Plus(const StringWeight<L, S> &w1, + const StringWeight<L, S> &w2) { + if (!w1.Member() || !w2.Member()) + return StringWeight<L, S>::NoWeight(); + if (w1 == StringWeight<L, S>::Zero()) + return w2; + if (w2 == StringWeight<L, S>::Zero()) + return w1; + + if (w1 != w2) { + FSTERROR() << "StringWeight::Plus: unequal arguments " + << "(non-functional FST?)" + << " w1 = " << w1 + << " w2 = " << w2; + return StringWeight<L, S>::NoWeight(); + } + + return w1; +} + + +// Longest common prefix for left string semiring. +template <typename L> inline StringWeight<L, STRING_LEFT> +Plus(const StringWeight<L, STRING_LEFT> &w1, + const StringWeight<L, STRING_LEFT> &w2) { + if (!w1.Member() || !w2.Member()) + return StringWeight<L, STRING_LEFT>::NoWeight(); + if (w1 == StringWeight<L, STRING_LEFT>::Zero()) + return w2; + if (w2 == StringWeight<L, STRING_LEFT>::Zero()) + return w1; + + StringWeight<L, STRING_LEFT> sum; + StringWeightIterator<L, STRING_LEFT> iter1(w1); + StringWeightIterator<L, STRING_LEFT> iter2(w2); + for (; !iter1.Done() && !iter2.Done() && iter1.Value() == iter2.Value(); + iter1.Next(), iter2.Next()) + sum.PushBack(iter1.Value()); + return sum; +} + + +// Longest common suffix for right string semiring. +template <typename L> inline StringWeight<L, STRING_RIGHT> +Plus(const StringWeight<L, STRING_RIGHT> &w1, + const StringWeight<L, STRING_RIGHT> &w2) { + if (!w1.Member() || !w2.Member()) + return StringWeight<L, STRING_RIGHT>::NoWeight(); + if (w1 == StringWeight<L, STRING_RIGHT>::Zero()) + return w2; + if (w2 == StringWeight<L, STRING_RIGHT>::Zero()) + return w1; + + StringWeight<L, STRING_RIGHT> sum; + StringWeightReverseIterator<L, STRING_RIGHT> iter1(w1); + StringWeightReverseIterator<L, STRING_RIGHT> iter2(w2); + for (; !iter1.Done() && !iter2.Done() && iter1.Value() == iter2.Value(); + iter1.Next(), iter2.Next()) + sum.PushFront(iter1.Value()); + return sum; +} + + +template <typename L, StringType S> +inline StringWeight<L, S> Times(const StringWeight<L, S> &w1, + const StringWeight<L, S> &w2) { + if (!w1.Member() || !w2.Member()) + return StringWeight<L, S>::NoWeight(); + if (w1 == StringWeight<L, S>::Zero() || w2 == StringWeight<L, S>::Zero()) + return StringWeight<L, S>::Zero(); + + StringWeight<L, S> prod(w1); + for (StringWeightIterator<L, S> iter(w2); !iter.Done(); iter.Next()) + prod.PushBack(iter.Value()); + + return prod; +} + + +// Default is for left division in the left string and the +// left restricted string semirings. +template <typename L, StringType S> inline StringWeight<L, S> +Divide(const StringWeight<L, S> &w1, + const StringWeight<L, S> &w2, + DivideType typ) { + + if (typ != DIVIDE_LEFT) { + FSTERROR() << "StringWeight::Divide: only left division is defined " + << "for the " << StringWeight<L, S>::Type() << " semiring"; + return StringWeight<L, S>::NoWeight(); + } + + if (!w1.Member() || !w2.Member()) + return StringWeight<L, S>::NoWeight(); + + if (w2 == StringWeight<L, S>::Zero()) + return StringWeight<L, S>(kStringBad); + else if (w1 == StringWeight<L, S>::Zero()) + return StringWeight<L, S>::Zero(); + + StringWeight<L, S> div; + StringWeightIterator<L, S> iter(w1); + for (int i = 0; !iter.Done(); iter.Next(), ++i) { + if (i >= w2.Size()) + div.PushBack(iter.Value()); + } + return div; +} + + +// Right division in the right string semiring. +template <typename L> inline StringWeight<L, STRING_RIGHT> +Divide(const StringWeight<L, STRING_RIGHT> &w1, + const StringWeight<L, STRING_RIGHT> &w2, + DivideType typ) { + + if (typ != DIVIDE_RIGHT) { + FSTERROR() << "StringWeight::Divide: only right division is defined " + << "for the right string semiring"; + return StringWeight<L, STRING_RIGHT>::NoWeight(); + } + + if (!w1.Member() || !w2.Member()) + return StringWeight<L, STRING_RIGHT>::NoWeight(); + + if (w2 == StringWeight<L, STRING_RIGHT>::Zero()) + return StringWeight<L, STRING_RIGHT>(kStringBad); + else if (w1 == StringWeight<L, STRING_RIGHT>::Zero()) + return StringWeight<L, STRING_RIGHT>::Zero(); + + StringWeight<L, STRING_RIGHT> div; + StringWeightReverseIterator<L, STRING_RIGHT> iter(w1); + for (int i = 0; !iter.Done(); iter.Next(), ++i) { + if (i >= w2.Size()) + div.PushFront(iter.Value()); + } + return div; +} + + +// Right division in the right restricted string semiring. +template <typename L> inline StringWeight<L, STRING_RIGHT_RESTRICT> +Divide(const StringWeight<L, STRING_RIGHT_RESTRICT> &w1, + const StringWeight<L, STRING_RIGHT_RESTRICT> &w2, + DivideType typ) { + + if (typ != DIVIDE_RIGHT) { + FSTERROR() << "StringWeight::Divide: only right division is defined " + << "for the right restricted string semiring"; + return StringWeight<L, STRING_RIGHT_RESTRICT>::NoWeight(); + } + + if (!w1.Member() || !w2.Member()) + return StringWeight<L, STRING_RIGHT_RESTRICT>::NoWeight(); + + if (w2 == StringWeight<L, STRING_RIGHT_RESTRICT>::Zero()) + return StringWeight<L, STRING_RIGHT_RESTRICT>(kStringBad); + else if (w1 == StringWeight<L, STRING_RIGHT_RESTRICT>::Zero()) + return StringWeight<L, STRING_RIGHT_RESTRICT>::Zero(); + + StringWeight<L, STRING_RIGHT_RESTRICT> div; + StringWeightReverseIterator<L, STRING_RIGHT_RESTRICT> iter(w1); + for (int i = 0; !iter.Done(); iter.Next(), ++i) { + if (i >= w2.Size()) + div.PushFront(iter.Value()); + } + return div; +} + + +// Product of string weight and an arbitray weight. +template <class L, class W, StringType S = STRING_LEFT> +struct GallicWeight : public ProductWeight<StringWeight<L, S>, W> { + typedef GallicWeight<L, typename W::ReverseWeight, REVERSE_STRING_TYPE(S)> + ReverseWeight; + + GallicWeight() {} + + GallicWeight(StringWeight<L, S> w1, W w2) + : ProductWeight<StringWeight<L, S>, W>(w1, w2) {} + + explicit GallicWeight(const string &s, int *nread = 0) + : ProductWeight<StringWeight<L, S>, W>(s, nread) {} + + GallicWeight(const ProductWeight<StringWeight<L, S>, W> &w) + : ProductWeight<StringWeight<L, S>, W>(w) {} +}; + +} // namespace fst + +#endif // FST_LIB_STRING_WEIGHT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/string.h b/kaldi_io/src/tools/openfst/include/fst/string.h new file mode 100644 index 0000000..9eaf7a3 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/string.h @@ -0,0 +1,271 @@ + +// string.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Utilities to convert strings into FSTs. +// + +#ifndef FST_LIB_STRING_H_ +#define FST_LIB_STRING_H_ + +#include <fst/compact-fst.h> +#include <fst/icu.h> +#include <fst/mutable-fst.h> + +DECLARE_string(fst_field_separator); + +namespace fst { + +// Functor compiling a string in an FST +template <class A> +class StringCompiler { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 }; + + StringCompiler(TokenType type, const SymbolTable *syms = 0, + Label unknown_label = kNoLabel, + bool allow_negative = false) + : token_type_(type), syms_(syms), unknown_label_(unknown_label), + allow_negative_(allow_negative) {} + + // Compile string 's' into FST 'fst'. + template <class F> + bool operator()(const string &s, F *fst) const { + vector<Label> labels; + if (!ConvertStringToLabels(s, &labels)) + return false; + Compile(labels, fst); + return true; + } + + template <class F> + bool operator()(const string &s, F *fst, Weight w) const { + vector<Label> labels; + if (!ConvertStringToLabels(s, &labels)) + return false; + Compile(labels, fst, w); + return true; + } + + private: + bool ConvertStringToLabels(const string &str, vector<Label> *labels) const { + labels->clear(); + if (token_type_ == BYTE) { + for (size_t i = 0; i < str.size(); ++i) + labels->push_back(static_cast<unsigned char>(str[i])); + } else if (token_type_ == UTF8) { + return UTF8StringToLabels(str, labels); + } else { + char *c_str = new char[str.size() + 1]; + str.copy(c_str, str.size()); + c_str[str.size()] = 0; + vector<char *> vec; + string separator = "\n" + FLAGS_fst_field_separator; + SplitToVector(c_str, separator.c_str(), &vec, true); + for (size_t i = 0; i < vec.size(); ++i) { + Label label; + if (!ConvertSymbolToLabel(vec[i], &label)) + return false; + labels->push_back(label); + } + delete[] c_str; + } + return true; + } + + void Compile(const vector<Label> &labels, MutableFst<A> *fst, + const Weight &weight = Weight::One()) const { + fst->DeleteStates(); + while (fst->NumStates() <= labels.size()) + fst->AddState(); + for (size_t i = 0; i < labels.size(); ++i) + fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1)); + fst->SetStart(0); + fst->SetFinal(labels.size(), weight); + } + + template <class Unsigned> + void Compile(const vector<Label> &labels, + CompactFst<A, StringCompactor<A>, Unsigned> *fst) const { + fst->SetCompactElements(labels.begin(), labels.end()); + } + + template <class Unsigned> + void Compile(const vector<Label> &labels, + CompactFst<A, WeightedStringCompactor<A>, Unsigned> *fst, + const Weight &weight = Weight::One()) const { + vector<pair<Label, Weight> > compacts; + compacts.reserve(labels.size()); + for (size_t i = 0; i < labels.size(); ++i) + compacts.push_back(make_pair(labels[i], Weight::One())); + compacts.back().second = weight; + fst->SetCompactElements(compacts.begin(), compacts.end()); + } + + bool ConvertSymbolToLabel(const char *s, Label* output) const { + int64 n; + if (syms_) { + n = syms_->Find(s); + if ((n == -1) && (unknown_label_ != kNoLabel)) + n = unknown_label_; + if (n == -1 || (!allow_negative_ && n < 0)) { + VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s + << "\" is not mapped to any integer label, symbol table = " + << syms_->Name(); + return false; + } + } else { + char *p; + n = strtoll(s, &p, 10); + if (p < s + strlen(s) || (!allow_negative_ && n < 0)) { + VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer " + << "= \"" << s << "\""; + return false; + } + } + *output = n; + return true; + } + + TokenType token_type_; // Token type: symbol, byte or utf8 encoded + const SymbolTable *syms_; // Symbol table used when token type is symbol + Label unknown_label_; // Label for token missing from symbol table + bool allow_negative_; // Negative labels allowed? + + DISALLOW_COPY_AND_ASSIGN(StringCompiler); +}; + +// Functor to print a string FST as a string. +template <class A> +class StringPrinter { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 }; + + StringPrinter(TokenType token_type, + const SymbolTable *syms = 0) + : token_type_(token_type), syms_(syms) {} + + // Convert the FST 'fst' into the string 'output' + bool operator()(const Fst<A> &fst, string *output) { + bool is_a_string = FstToLabels(fst); + if (!is_a_string) { + VLOG(1) << "StringPrinter::operator(): Fst is not a string."; + return false; + } + + output->clear(); + + if (token_type_ == SYMBOL) { + stringstream sstrm; + for (size_t i = 0; i < labels_.size(); ++i) { + if (i) + sstrm << *(FLAGS_fst_field_separator.rbegin()); + if (!PrintLabel(labels_[i], sstrm)) + return false; + } + *output = sstrm.str(); + } else if (token_type_ == BYTE) { + output->reserve(labels_.size()); + for (size_t i = 0; i < labels_.size(); ++i) { + output->push_back(labels_[i]); + } + } else if (token_type_ == UTF8) { + return LabelsToUTF8String(labels_, output); + } else { + VLOG(1) << "StringPrinter::operator(): Unknown token type: " + << token_type_; + return false; + } + return true; + } + + private: + bool FstToLabels(const Fst<A> &fst) { + labels_.clear(); + + StateId s = fst.Start(); + if (s == kNoStateId) { + VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for " + << "string fst."; + return false; + } + + while (fst.Final(s) == Weight::Zero()) { + ArcIterator<Fst<A> > aiter(fst, s); + if (aiter.Done()) { + VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does " + << "not reach final state."; + return false; + } + + const A& arc = aiter.Value(); + labels_.push_back(arc.olabel); + + s = arc.nextstate; + if (s == kNoStateId) { + VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid " + << "state."; + return false; + } + + aiter.Next(); + if (!aiter.Done()) { + VLOG(2) << "StringPrinter::FstToLabels: State with multiple " + << "outgoing arcs found."; + return false; + } + } + + return true; + } + + bool PrintLabel(Label lab, ostream& ostrm) { + if (syms_) { + string symbol = syms_->Find(lab); + if (symbol == "") { + VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not " + << "mapped to any textual symbol, symbol table = " + << syms_->Name(); + return false; + } + ostrm << symbol; + } else { + ostrm << lab; + } + return true; + } + + TokenType token_type_; // Token type: symbol, byte or utf8 encoded + const SymbolTable *syms_; // Symbol table used when token type is symbol + vector<Label> labels_; // Input FST labels. + + DISALLOW_COPY_AND_ASSIGN(StringPrinter); +}; + +} // namespace fst + +#endif // FST_LIB_STRING_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/symbol-table-ops.h b/kaldi_io/src/tools/openfst/include/fst/symbol-table-ops.h new file mode 100644 index 0000000..1f327da --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/symbol-table-ops.h @@ -0,0 +1,91 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Jeffrey Sorensen) + +#ifndef FST_LIB_SYMBOL_TABLE_OPS_H_ +#define FST_LIB_SYMBOL_TABLE_OPS_H_ + +#include <vector> +using std::vector; +#include <string> +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; + + +#include <fst/fst.h> +#include <fst/symbol-table.h> + + +namespace fst { + +// Returns a minimal symbol table containing only symbols referenced by the +// passed fst. Symbols preserve their original numbering, so fst does not +// require relabeling. +template<class Arc> +SymbolTable *PruneSymbolTable(const Fst<Arc> &fst, const SymbolTable &syms, + bool input) { + unordered_set<typename Arc::Label> seen; + seen.insert(0); // Always keep epslion + StateIterator<Fst<Arc> > siter(fst); + for (; !siter.Done(); siter.Next()) { + ArcIterator<Fst<Arc> > aiter(fst, siter.Value()); + for (; !aiter.Done(); aiter.Next()) { + typename Arc::Label sym = (input) ? aiter.Value().ilabel : + aiter.Value().olabel; + seen.insert(sym); + } + } + SymbolTable *pruned = new SymbolTable(syms.Name() + "_pruned"); + for (SymbolTableIterator stiter(syms); !stiter.Done(); stiter.Next()) { + typename Arc::Label label = stiter.Value(); + if (seen.find(label) != seen.end()) { + pruned->AddSymbol(stiter.Symbol(), stiter.Value()); + } + } + return pruned; +} + +// Relabels a symbol table to make it a contiguous mapping. +SymbolTable *CompactSymbolTable(const SymbolTable &syms); + +// Merges two SymbolTables, all symbols from left will be merged into right +// with the same ids. Symbols in right that have conflicting ids with those +// in left will be assigned to value assigned from the left SymbolTable. +// The returned symbol table will never modify symbol assignments from the left +// side, but may do so on the right. If right_relabel_output is non-NULL, it +// will be assigned true if the symbols from the right table needed to be +// reassigned. +// A potential use case is to Compose two Fst's that have different symbol +// tables. You can reconcile them in the following way: +// Fst<Arc> a, b; +// bool relabel; +// SymbolTable *bnew = MergeSymbolTable(a.OutputSymbols(), +// b.InputSymbols(), &relabel); +// if (relabel) { +// Relabel(b, bnew, NULL); +// } +// b.SetInputSymbols(bnew); +// delete bnew; +SymbolTable *MergeSymbolTable(const SymbolTable &left, const SymbolTable &right, + bool *right_relabel_output = 0); + +// Read the symbol table from any Fst::Read()able file, without loading the +// corresponding Fst. Returns NULL if the Fst does not contain a symbol table +// or the symbol table cannot be read. +SymbolTable *FstReadSymbols(const string &filename, bool input); + +} // namespace fst +#endif // FST_LIB_SYMBOL_TABLE_OPS_H_ diff --git a/kaldi_io/src/tools/openfst/include/fst/symbol-table.h b/kaldi_io/src/tools/openfst/include/fst/symbol-table.h new file mode 100644 index 0000000..6eb6c2d --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/symbol-table.h @@ -0,0 +1,537 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// All Rights Reserved. +// +// Author : Johan Schalkwyk +// +// \file +// Classes to provide symbol-to-integer and integer-to-symbol mappings. + +#ifndef FST_LIB_SYMBOL_TABLE_H__ +#define FST_LIB_SYMBOL_TABLE_H__ + +#include <cstring> +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + + +#include <fst/compat.h> +#include <iostream> +#include <fstream> +#include <sstream> + + +#include <map> + +DECLARE_bool(fst_compat_symbols); + +namespace fst { + +// WARNING: Reading via symbol table read options should +// not be used. This is a temporary work around for +// reading symbol ranges of previously stored symbol sets. +struct SymbolTableReadOptions { + SymbolTableReadOptions() { } + + SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_, + const string& source_) + : string_hash_ranges(string_hash_ranges_), + source(source_) { } + + vector<pair<int64, int64> > string_hash_ranges; + string source; +}; + +struct SymbolTableTextOptions { + SymbolTableTextOptions(); + + bool allow_negative; + string fst_field_separator; +}; + +class SymbolTableImpl { + public: + SymbolTableImpl(const string &name) + : name_(name), + available_key_(0), + dense_key_limit_(0), + check_sum_finalized_(false) {} + + explicit SymbolTableImpl(const SymbolTableImpl& impl) + : name_(impl.name_), + available_key_(0), + dense_key_limit_(0), + check_sum_finalized_(false) { + for (size_t i = 0; i < impl.symbols_.size(); ++i) { + AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i])); + } + } + + ~SymbolTableImpl() { + for (size_t i = 0; i < symbols_.size(); ++i) + delete[] symbols_[i]; + } + + // TODO(johans): Add flag to specify whether the symbol + // should be indexed as string or int or both. + int64 AddSymbol(const string& symbol, int64 key); + + int64 AddSymbol(const string& symbol) { + int64 key = Find(symbol); + return (key == -1) ? AddSymbol(symbol, available_key_++) : key; + } + + static SymbolTableImpl* ReadText( + istream &strm, const string &name, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()); + + static SymbolTableImpl* Read(istream &strm, + const SymbolTableReadOptions& opts); + + bool Write(ostream &strm) const; + + // + // Return the string associated with the key. If the key is out of + // range (<0, >max), return an empty string. + string Find(int64 key) const { + if (key >=0 && key < dense_key_limit_) + return string(symbols_[key]); + + map<int64, const char*>::const_iterator it = + key_map_.find(key); + if (it == key_map_.end()) { + return ""; + } + return string(it->second); + } + + // + // Return the key associated with the symbol. If the symbol + // does not exists, return SymbolTable::kNoSymbol. + int64 Find(const string& symbol) const { + return Find(symbol.c_str()); + } + + // + // Return the key associated with the symbol. If the symbol + // does not exists, return SymbolTable::kNoSymbol. + int64 Find(const char* symbol) const { + map<const char *, int64, StrCmp>::const_iterator it = + symbol_map_.find(symbol); + if (it == symbol_map_.end()) { + return -1; + } + return it->second; + } + + int64 GetNthKey(ssize_t pos) const { + if ((pos < 0) || (pos >= symbols_.size())) return -1; + else return Find(symbols_[pos]); + } + + const string& Name() const { return name_; } + + int IncrRefCount() const { + return ref_count_.Incr(); + } + int DecrRefCount() const { + return ref_count_.Decr(); + } + int RefCount() const { + return ref_count_.count(); + } + + string CheckSum() const { + MaybeRecomputeCheckSum(); + return check_sum_string_; + } + + string LabeledCheckSum() const { + MaybeRecomputeCheckSum(); + return labeled_check_sum_string_; + } + + int64 AvailableKey() const { + return available_key_; + } + + size_t NumSymbols() const { + return symbols_.size(); + } + + private: + // Recomputes the checksums (both of them) if we've had changes since the last + // computation (i.e., if check_sum_finalized_ is false). + // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon + // if the checksum is up-to-date (requiring no recomputation). + void MaybeRecomputeCheckSum() const; + + struct StrCmp { + bool operator()(const char *s1, const char *s2) const { + return strcmp(s1, s2) < 0; + } + }; + + string name_; + int64 available_key_; + int64 dense_key_limit_; + vector<const char *> symbols_; + map<int64, const char*> key_map_; + map<const char *, int64, StrCmp> symbol_map_; + + mutable RefCounter ref_count_; + mutable bool check_sum_finalized_; + mutable string check_sum_string_; + mutable string labeled_check_sum_string_; + mutable Mutex check_sum_mutex_; +}; + +// +// \class SymbolTable +// \brief Symbol (string) to int and reverse mapping +// +// The SymbolTable implements the mappings of labels to strings and reverse. +// SymbolTables are used to describe the alphabet of the input and output +// labels for arcs in a Finite State Transducer. +// +// SymbolTables are reference counted and can therefore be shared across +// multiple machines. For example a language model grammar G, with a +// SymbolTable for the words in the language model can share this symbol +// table with the lexical representation L o G. +// +class SymbolTable { + public: + static const int64 kNoSymbol = -1; + + // Construct symbol table with an unspecified name. + SymbolTable() : impl_(new SymbolTableImpl("<unspecified>")) {} + + // Construct symbol table with a unique name. + SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {} + + // Create a reference counted copy. + SymbolTable(const SymbolTable& table) : impl_(table.impl_) { + impl_->IncrRefCount(); + } + + // Derefence implentation object. When reference count hits 0, delete + // implementation. + virtual ~SymbolTable() { + if (!impl_->DecrRefCount()) delete impl_; + } + + // Copys the implemenation from one symbol table to another. + void operator=(const SymbolTable &st) { + if (impl_ != st.impl_) { + st.impl_->IncrRefCount(); + if (!impl_->DecrRefCount()) delete impl_; + impl_ = st.impl_; + } + } + + // Read an ascii representation of the symbol table from an istream. Pass a + // name to give the resulting SymbolTable. + static SymbolTable* ReadText( + istream &strm, const string& name, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()) { + SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, name, opts); + if (!impl) + return 0; + else + return new SymbolTable(impl); + } + + // read an ascii representation of the symbol table + static SymbolTable* ReadText(const string& filename, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()) { + ifstream strm(filename.c_str(), ifstream::in); + if (!strm) { + LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename; + return 0; + } + return ReadText(strm, filename, opts); + } + + + // WARNING: Reading via symbol table read options should + // not be used. This is a temporary work around. + static SymbolTable* Read(istream &strm, + const SymbolTableReadOptions& opts) { + SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts); + if (!impl) + return 0; + else + return new SymbolTable(impl); + } + + // read a binary dump of the symbol table from a stream + static SymbolTable* Read(istream &strm, const string& source) { + SymbolTableReadOptions opts; + opts.source = source; + return Read(strm, opts); + } + + // read a binary dump of the symbol table + static SymbolTable* Read(const string& filename) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename; + return 0; + } + return Read(strm, filename); + } + + //-------------------------------------------------------- + // Derivable Interface (final) + //-------------------------------------------------------- + // create a reference counted copy + virtual SymbolTable* Copy() const { + return new SymbolTable(*this); + } + + // Add a symbol with given key to table. A symbol table also + // keeps track of the last available key (highest key value in + // the symbol table). + virtual int64 AddSymbol(const string& symbol, int64 key) { + MutateCheck(); + return impl_->AddSymbol(symbol, key); + } + + // Add a symbol to the table. The associated value key is automatically + // assigned by the symbol table. + virtual int64 AddSymbol(const string& symbol) { + MutateCheck(); + return impl_->AddSymbol(symbol); + } + + // Add another symbol table to this table. All key values will be offset + // by the current available key (highest key value in the symbol table). + // Note string symbols with the same key value with still have the same + // key value after the symbol table has been merged, but a different + // value. Adding symbol tables do not result in changes in the base table. + virtual void AddTable(const SymbolTable& table); + + // return the name of the symbol table + virtual const string& Name() const { + return impl_->Name(); + } + + // Return the label-agnostic MD5 check-sum for this table. All new symbols + // added to the table will result in an updated checksum. + // DEPRECATED. + virtual string CheckSum() const { + return impl_->CheckSum(); + } + + // Same as CheckSum(), but this returns an label-dependent version. + virtual string LabeledCheckSum() const { + return impl_->LabeledCheckSum(); + } + + virtual bool Write(ostream &strm) const { + return impl_->Write(strm); + } + + bool Write(const string& filename) const { + ofstream strm(filename.c_str(), ofstream::out | ofstream::binary); + if (!strm) { + LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename; + return false; + } + return Write(strm); + } + + // Dump an ascii text representation of the symbol table via a stream + virtual bool WriteText( + ostream &strm, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const; + + // Dump an ascii text representation of the symbol table + bool WriteText(const string& filename) const { + ofstream strm(filename.c_str()); + if (!strm) { + LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename; + return false; + } + return WriteText(strm); + } + + // Return the string associated with the key. If the key is out of + // range (<0, >max), log error and return an empty string. + virtual string Find(int64 key) const { + return impl_->Find(key); + } + + // Return the key associated with the symbol. If the symbol + // does not exists, log error and return SymbolTable::kNoSymbol + virtual int64 Find(const string& symbol) const { + return impl_->Find(symbol); + } + + // Return the key associated with the symbol. If the symbol + // does not exists, log error and return SymbolTable::kNoSymbol + virtual int64 Find(const char* symbol) const { + return impl_->Find(symbol); + } + + // Return the current available key (i.e highest key number+1) in + // the symbol table + virtual int64 AvailableKey(void) const { + return impl_->AvailableKey(); + } + + // Return the current number of symbols in table (not necessarily + // equal to AvailableKey()) + virtual size_t NumSymbols(void) const { + return impl_->NumSymbols(); + } + + virtual int64 GetNthKey(ssize_t pos) const { + return impl_->GetNthKey(pos); + } + + private: + explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {} + + void MutateCheck() { + // Copy on write + if (impl_->RefCount() > 1) { + impl_->DecrRefCount(); + impl_ = new SymbolTableImpl(*impl_); + } + } + + const SymbolTableImpl* Impl() const { + return impl_; + } + + private: + SymbolTableImpl* impl_; +}; + + +// +// \class SymbolTableIterator +// \brief Iterator class for symbols in a symbol table +class SymbolTableIterator { + public: + SymbolTableIterator(const SymbolTable& table) + : table_(table), + pos_(0), + nsymbols_(table.NumSymbols()), + key_(table.GetNthKey(0)) { } + + ~SymbolTableIterator() { } + + // is iterator done + bool Done(void) { + return (pos_ == nsymbols_); + } + + // return the Value() of the current symbol (int64 key) + int64 Value(void) { + return key_; + } + + // return the string of the current symbol + string Symbol(void) { + return table_.Find(key_); + } + + // advance iterator forward + void Next(void) { + ++pos_; + if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_); + } + + // reset iterator + void Reset(void) { + pos_ = 0; + key_ = table_.GetNthKey(0); + } + + private: + const SymbolTable& table_; + ssize_t pos_; + size_t nsymbols_; + int64 key_; +}; + + +// Tests compatibilty between two sets of symbol tables +inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, + bool warning = true) { + if (!FLAGS_fst_compat_symbols) { + return true; + } else if (!syms1 && !syms2) { + return true; + } else if (syms1 && !syms2) { + if (warning) + LOG(WARNING) << + "CompatSymbols: first symbol table present but second missing"; + return false; + } else if (!syms1 && syms2) { + if (warning) + LOG(WARNING) << + "CompatSymbols: second symbol table present but first missing"; + return false; + } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) { + if (warning) + LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match"; + return false; + } else { + return true; + } +} + + +// Relabels a symbol table as specified by the input vector of pairs +// (old label, new label). The new symbol table only retains symbols +// for which a relabeling is *explicitely* specified. +// TODO(allauzen): consider adding options to allow for some form +// of implicit identity relabeling. +template <class Label> +SymbolTable *RelabelSymbolTable(const SymbolTable *table, + const vector<pair<Label, Label> > &pairs) { + SymbolTable *new_table = new SymbolTable( + table->Name().empty() ? string() : + (string("relabeled_") + table->Name())); + + for (size_t i = 0; i < pairs.size(); ++i) + new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second); + + return new_table; +} + +// Symbol Table Serialization +inline void SymbolTableToString(const SymbolTable *table, string *result) { + ostringstream ostrm; + table->Write(ostrm); + *result = ostrm.str(); +} + +inline SymbolTable *StringToSymbolTable(const string &s) { + istringstream istrm(s); + return SymbolTable::Read(istrm, SymbolTableReadOptions()); +} + + + +} // namespace fst + +#endif // FST_LIB_SYMBOL_TABLE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/synchronize.h b/kaldi_io/src/tools/openfst/include/fst/synchronize.h new file mode 100644 index 0000000..9582926 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/synchronize.h @@ -0,0 +1,457 @@ +// synchronize.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Cyril Allauzen) +// +// \file +// Synchronize an FST with bounded delay. + +#ifndef FST_LIB_SYNCHRONIZE_H__ +#define FST_LIB_SYNCHRONIZE_H__ + +#include <algorithm> +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; +#include <string> +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/cache.h> +#include <fst/test-properties.h> + + +namespace fst { + +typedef CacheOptions SynchronizeFstOptions; + + +// Implementation class for SynchronizeFst +template <class A> +class SynchronizeFstImpl + : public CacheImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<A> >::PushArc; + using CacheBaseImpl< CacheState<A> >::HasArcs; + using CacheBaseImpl< CacheState<A> >::HasFinal; + using CacheBaseImpl< CacheState<A> >::HasStart; + using CacheBaseImpl< CacheState<A> >::SetArcs; + using CacheBaseImpl< CacheState<A> >::SetFinal; + using CacheBaseImpl< CacheState<A> >::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + typedef basic_string<Label> String; + + struct Element { + Element() {} + + Element(StateId s, const String *i, const String *o) + : state(s), istring(i), ostring(o) {} + + StateId state; // Input state Id + const String *istring; // Residual input labels + const String *ostring; // Residual output labels + // Residual strings are represented by const pointers to + // basic_string<Label> and are stored in a hash_set. The pointed + // memory is owned by the hash_set string_set_. + }; + + SynchronizeFstImpl(const Fst<A> &fst, const SynchronizeFstOptions &opts) + : CacheImpl<A>(opts), fst_(fst.Copy()) { + SetType("synchronize"); + uint64 props = fst.Properties(kFstProperties, false); + SetProperties(SynchronizeProperties(props), kCopyProperties); + + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + SynchronizeFstImpl(const SynchronizeFstImpl &impl) + : CacheImpl<A>(impl), + fst_(impl.fst_->Copy(true)) { + SetType("synchronize"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~SynchronizeFstImpl() { + delete fst_; + // Extract pointers from the hash set + vector<const String*> strings; + typename StringSet::iterator it = string_set_.begin(); + for (; it != string_set_.end(); ++it) + strings.push_back(*it); + // Free the extracted pointers + for (size_t i = 0; i < strings.size(); ++i) + delete strings[i]; + } + + StateId Start() { + if (!HasStart()) { + StateId s = fst_->Start(); + if (s == kNoStateId) + return kNoStateId; + const String *empty = FindString(new String()); + StateId start = FindState(Element(fst_->Start(), empty, empty)); + SetStart(start); + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + const Element &e = elements_[s]; + Weight w = e.state == kNoStateId ? Weight::One() : fst_->Final(e.state); + if ((w != Weight::Zero()) && (e.istring)->empty() && (e.ostring)->empty()) + SetFinal(s, w); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + Expand(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } + + uint64 Properties() const { return Properties(kFstProperties); } + + // Set error if found; return FST impl properties. + uint64 Properties(uint64 mask) const { + if ((mask & kError) && fst_->Properties(kError, false)) + SetProperties(kError, kError); + return FstImpl<Arc>::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + Expand(s); + CacheImpl<A>::InitArcIterator(s, data); + } + + // Returns the first character of the string obtained by + // concatenating s and l. + Label Car(const String *s, Label l = 0) const { + if (!s->empty()) + return (*s)[0]; + else + return l; + } + + // Computes the residual string obtained by removing the first + // character in the concatenation of s and l. + const String *Cdr(const String *s, Label l = 0) { + String *r = new String(); + for (int i = 1; i < s->size(); ++i) + r->push_back((*s)[i]); + if (l && !(s->empty())) r->push_back(l); + return FindString(r); + } + + // Computes the concatenation of s and l. + const String *Concat(const String *s, Label l = 0) { + String *r = new String(); + for (int i = 0; i < s->size(); ++i) + r->push_back((*s)[i]); + if (l) r->push_back(l); + return FindString(r); + } + + // Tests if the concatenation of s and l is empty + bool Empty(const String *s, Label l = 0) const { + if (s->empty()) + return l == 0; + else + return false; + } + + // Finds the string pointed by s in the hash set. Transfers the + // pointer ownership to the hash set. + const String *FindString(const String *s) { + typename StringSet::iterator it = string_set_.find(s); + if (it != string_set_.end()) { + delete s; + return (*it); + } else { + string_set_.insert(s); + return s; + } + } + + // Finds state corresponding to an element. Creates new state + // if element not found. + StateId FindState(const Element &e) { + typename ElementMap::iterator eit = element_map_.find(e); + if (eit != element_map_.end()) { + return (*eit).second; + } else { + StateId s = elements_.size(); + elements_.push_back(e); + element_map_.insert(pair<const Element, StateId>(e, s)); + return s; + } + } + + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void Expand(StateId s) { + Element e = elements_[s]; + + if (e.state != kNoStateId) + for (ArcIterator< Fst<A> > ait(*fst_, e.state); + !ait.Done(); + ait.Next()) { + const A &arc = ait.Value(); + if (!Empty(e.istring, arc.ilabel) && !Empty(e.ostring, arc.olabel)) { + const String *istring = Cdr(e.istring, arc.ilabel); + const String *ostring = Cdr(e.ostring, arc.olabel); + StateId d = FindState(Element(arc.nextstate, istring, ostring)); + PushArc(s, Arc(Car(e.istring, arc.ilabel), + Car(e.ostring, arc.olabel), arc.weight, d)); + } else { + const String *istring = Concat(e.istring, arc.ilabel); + const String *ostring = Concat(e.ostring, arc.olabel); + StateId d = FindState(Element(arc.nextstate, istring, ostring)); + PushArc(s, Arc(0 , 0, arc.weight, d)); + } + } + + Weight w = e.state == kNoStateId ? Weight::One() : fst_->Final(e.state); + if ((w != Weight::Zero()) && + ((e.istring)->size() + (e.ostring)->size() > 0)) { + const String *istring = Cdr(e.istring); + const String *ostring = Cdr(e.ostring); + StateId d = FindState(Element(kNoStateId, istring, ostring)); + PushArc(s, Arc(Car(e.istring), Car(e.ostring), w, d)); + } + SetArcs(s); + } + + private: + // Equality function for Elements, assume strings have been hashed. + class ElementEqual { + public: + bool operator()(const Element &x, const Element &y) const { + return x.state == y.state && + x.istring == y.istring && + x.ostring == y.ostring; + } + }; + + // Hash function for Elements to Fst states. + class ElementKey { + public: + size_t operator()(const Element &x) const { + size_t key = x.state; + key = (key << 1) ^ (x.istring)->size(); + for (size_t i = 0; i < (x.istring)->size(); ++i) + key = (key << 1) ^ (*x.istring)[i]; + key = (key << 1) ^ (x.ostring)->size(); + for (size_t i = 0; i < (x.ostring)->size(); ++i) + key = (key << 1) ^ (*x.ostring)[i]; + return key; + } + }; + + // Equality function for strings + class StringEqual { + public: + bool operator()(const String * const &x, const String * const &y) const { + if (x->size() != y->size()) return false; + for (size_t i = 0; i < x->size(); ++i) + if ((*x)[i] != (*y)[i]) return false; + return true; + } + }; + + // Hash function for set of strings + class StringKey{ + public: + size_t operator()(const String * const & x) const { + size_t key = x->size(); + for (size_t i = 0; i < x->size(); ++i) + key = (key << 1) ^ (*x)[i]; + return key; + } + }; + + + typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap; + typedef unordered_set<const String*, StringKey, StringEqual> StringSet; + + const Fst<A> *fst_; + vector<Element> elements_; // mapping Fst state to Elements + ElementMap element_map_; // mapping Elements to Fst state + StringSet string_set_; + + void operator=(const SynchronizeFstImpl<A> &); // disallow +}; + + +// Synchronizes a transducer. This version is a delayed Fst. The +// result will be an equivalent FST that has the property that during +// the traversal of a path, the delay is either zero or strictly +// increasing, where the delay is the difference between the number of +// non-epsilon output labels and input labels along the path. +// +// For the algorithm to terminate, the input transducer must have +// bounded delay, i.e., the delay of every cycle must be zero. +// +// Complexity: +// - A has bounded delay: exponential +// - A does not have bounded delay: does not terminate +// +// References: +// - Mehryar Mohri. Edit-Distance of Weighted Automata: General +// Definitions and Algorithms, International Journal of Computer +// Science, 14(6): 957-982 (2003). +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class SynchronizeFst : public ImplToFst< SynchronizeFstImpl<A> > { + public: + friend class ArcIterator< SynchronizeFst<A> >; + friend class StateIterator< SynchronizeFst<A> >; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef CacheState<A> State; + typedef SynchronizeFstImpl<A> Impl; + + SynchronizeFst(const Fst<A> &fst) + : ImplToFst<Impl>(new Impl(fst, SynchronizeFstOptions())) {} + + SynchronizeFst(const Fst<A> &fst, const SynchronizeFstOptions &opts) + : ImplToFst<Impl>(new Impl(fst, opts)) {} + + // See Fst<>::Copy() for doc. + SynchronizeFst(const SynchronizeFst<A> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this SynchronizeFst. See Fst<>::Copy() for further doc. + virtual SynchronizeFst<A> *Copy(bool safe = false) const { + return new SynchronizeFst<A>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const SynchronizeFst<A> &fst); // Disallow +}; + + +// Specialization for SynchronizeFst. +template<class A> +class StateIterator< SynchronizeFst<A> > + : public CacheStateIterator< SynchronizeFst<A> > { + public: + explicit StateIterator(const SynchronizeFst<A> &fst) + : CacheStateIterator< SynchronizeFst<A> >(fst, fst.GetImpl()) {} +}; + + +// Specialization for SynchronizeFst. +template <class A> +class ArcIterator< SynchronizeFst<A> > + : public CacheArcIterator< SynchronizeFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const SynchronizeFst<A> &fst, StateId s) + : CacheArcIterator< SynchronizeFst<A> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->Expand(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + + +template <class A> inline +void SynchronizeFst<A>::InitStateIterator(StateIteratorData<A> *data) const +{ + data->base = new StateIterator< SynchronizeFst<A> >(*this); +} + + + +// Synchronizes a transducer. This version writes the synchronized +// result to a MutableFst. The result will be an equivalent FST that +// has the property that during the traversal of a path, the delay is +// either zero or strictly increasing, where the delay is the +// difference between the number of non-epsilon output labels and +// input labels along the path. +// +// For the algorithm to terminate, the input transducer must have +// bounded delay, i.e., the delay of every cycle must be zero. +// +// Complexity: +// - A has bounded delay: exponential +// - A does not have bounded delay: does not terminate +// +// References: +// - Mehryar Mohri. Edit-Distance of Weighted Automata: General +// Definitions and Algorithms, International Journal of Computer +// Science, 14(6): 957-982 (2003). +template<class Arc> +void Synchronize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) { + SynchronizeFstOptions opts; + opts.gc_limit = 0; // Cache only the last state for fastest copy. + *ofst = SynchronizeFst<Arc>(ifst, opts); +} + +} // namespace fst + +#endif // FST_LIB_SYNCHRONIZE_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/test-properties.h b/kaldi_io/src/tools/openfst/include/fst/test-properties.h new file mode 100644 index 0000000..80af593 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/test-properties.h @@ -0,0 +1,250 @@ +// test-properties.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Functions to manipulate and test property bits + +#ifndef FST_LIB_TEST_PROPERTIES_H__ +#define FST_LIB_TEST_PROPERTIES_H__ + +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; + +#include <fst/dfs-visit.h> +#include <fst/connect.h> + + +DECLARE_bool(fst_verify_properties); + +namespace fst { + +// For a binary property, the bit is always returned set. +// For a trinary (i.e. two-bit) property, both bits are +// returned set iff either corresponding input bit is set. +inline uint64 KnownProperties(uint64 props) { + return kBinaryProperties | (props & kTrinaryProperties) | + ((props & kPosTrinaryProperties) << 1) | + ((props & kNegTrinaryProperties) >> 1); +} + +// Tests compatibility between two sets of properties +inline bool CompatProperties(uint64 props1, uint64 props2) { + uint64 known_props1 = KnownProperties(props1); + uint64 known_props2 = KnownProperties(props2); + uint64 known_props = known_props1 & known_props2; + uint64 incompat_props = (props1 & known_props) ^ (props2 & known_props); + if (incompat_props) { + uint64 prop = 1; + for (int i = 0; i < 64; ++i, prop <<= 1) + if (prop & incompat_props) + LOG(ERROR) << "CompatProperties: mismatch: " << PropertyNames[i] + << ": props1 = " << (props1 & prop ? "true" : "false") + << ", props2 = " << (props2 & prop ? "true" : "false"); + return false; + } else { + return true; + } +} + +// Computes FST property values defined in properties.h. The value of +// each property indicated in the mask will be determined and returned +// (these will never be unknown here). In the course of determining +// the properties specifically requested in the mask, certain other +// properties may be determined (those with little additional expense) +// and their values will be returned as well. The complete set of +// known properties (whether true or false) determined by this +// operation will be assigned to the the value pointed to by KNOWN. +// If 'use_stored' is true, pre-computed FST properties may be used +// when possible. This routine is seldom called directly; instead it +// is used to implement fst.Properties(mask, true). +template<class Arc> +uint64 ComputeProperties(const Fst<Arc> &fst, uint64 mask, uint64 *known, + bool use_stored) { + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + uint64 fst_props = fst.Properties(kFstProperties, false); // Fst-stored + + // Check stored FST properties first if allowed. + if (use_stored) { + uint64 known_props = KnownProperties(fst_props); + // If FST contains required info, return it. + if ((known_props & mask) == mask) { + *known = known_props; + return fst_props; + } + } + + // Compute (trinary) properties explicitly. + + // Initialize with binary properties (already known). + uint64 comp_props = fst_props & kBinaryProperties; + + // Compute these trinary properties with a DFS. We compute only those + // that need a DFS here, since we otherwise would like to avoid a DFS + // since its stack could grow large. + uint64 dfs_props = kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | + kAccessible | kNotAccessible | + kCoAccessible | kNotCoAccessible; + if (mask & dfs_props) { + SccVisitor<Arc> scc_visitor(&comp_props); + DfsVisit(fst, &scc_visitor); + } + + // Compute any remaining trinary properties via a state and arcs iterations + if (mask & ~(kBinaryProperties | dfs_props)) { + comp_props |= kAcceptor | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kUnweighted | kTopSorted | kString; + if (mask & (kIDeterministic | kNonIDeterministic)) + comp_props |= kIDeterministic; + if (mask & (kODeterministic | kNonODeterministic)) + comp_props |= kODeterministic; + + unordered_set<Label> *ilabels = 0; + unordered_set<Label> *olabels = 0; + + StateId nfinal = 0; + for (StateIterator< Fst<Arc> > siter(fst); + !siter.Done(); + siter.Next()) { + StateId s = siter.Value(); + + Arc prev_arc; + // Create these only if we need to + if (mask & (kIDeterministic | kNonIDeterministic)) + ilabels = new unordered_set<Label>; + if (mask & (kODeterministic | kNonODeterministic)) + olabels = new unordered_set<Label>; + + bool first_arc = true; + for (ArcIterator< Fst<Arc> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) { + const Arc &arc =aiter.Value(); + + if (ilabels && ilabels->find(arc.ilabel) != ilabels->end()) { + comp_props |= kNonIDeterministic; + comp_props &= ~kIDeterministic; + } + if (olabels && olabels->find(arc.olabel) != olabels->end()) { + comp_props |= kNonODeterministic; + comp_props &= ~kODeterministic; + } + if (arc.ilabel != arc.olabel) { + comp_props |= kNotAcceptor; + comp_props &= ~kAcceptor; + } + if (arc.ilabel == 0 && arc.olabel == 0) { + comp_props |= kEpsilons; + comp_props &= ~kNoEpsilons; + } + if (arc.ilabel == 0) { + comp_props |= kIEpsilons; + comp_props &= ~kNoIEpsilons; + } + if (arc.olabel == 0) { + comp_props |= kOEpsilons; + comp_props &= ~kNoOEpsilons; + } + if (!first_arc) { + if (arc.ilabel < prev_arc.ilabel) { + comp_props |= kNotILabelSorted; + comp_props &= ~kILabelSorted; + } + if (arc.olabel < prev_arc.olabel) { + comp_props |= kNotOLabelSorted; + comp_props &= ~kOLabelSorted; + } + } + if (arc.weight != Weight::One() && arc.weight != Weight::Zero()) { + comp_props |= kWeighted; + comp_props &= ~kUnweighted; + } + if (arc.nextstate <= s) { + comp_props |= kNotTopSorted; + comp_props &= ~kTopSorted; + } + if (arc.nextstate != s + 1) { + comp_props |= kNotString; + comp_props &= ~kString; + } + prev_arc = arc; + first_arc = false; + if (ilabels) + ilabels->insert(arc.ilabel); + if (olabels) + olabels->insert(arc.olabel); + } + + if (nfinal > 0) { // final state not last + comp_props |= kNotString; + comp_props &= ~kString; + } + + Weight final = fst.Final(s); + + if (final != Weight::Zero()) { // final state + if (final != Weight::One()) { + comp_props |= kWeighted; + comp_props &= ~kUnweighted; + } + ++nfinal; + } else { // non-final state + if (fst.NumArcs(s) != 1) { + comp_props |= kNotString; + comp_props &= ~kString; + } + } + + delete ilabels; + delete olabels; + } + + if (fst.Start() != kNoStateId && fst.Start() != 0) { + comp_props |= kNotString; + comp_props &= ~kString; + } + } + + *known = KnownProperties(comp_props); + return comp_props; +} + +// This is a wrapper around ComputeProperties that will cause a fatal +// error if the stored properties and the computed properties are +// incompatible when 'FLAGS_fst_verify_properties' is true. This +// routine is seldom called directly; instead it is used to implement +// fst.Properties(mask, true). +template<class Arc> +uint64 TestProperties(const Fst<Arc> &fst, uint64 mask, uint64 *known) { + if (FLAGS_fst_verify_properties) { + uint64 stored_props = fst.Properties(kFstProperties, false); + uint64 computed_props = ComputeProperties(fst, mask, known, false); + if (!CompatProperties(stored_props, computed_props)) + LOG(FATAL) << "TestProperties: stored Fst properties incorrect" + << " (stored: props1, computed: props2)"; + return computed_props; + } else { + return ComputeProperties(fst, mask, known, true); + } +} + +} // namespace fst + +#endif // FST_LIB_TEST_PROPERTIES_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/topsort.h b/kaldi_io/src/tools/openfst/include/fst/topsort.h new file mode 100644 index 0000000..53735e5 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/topsort.h @@ -0,0 +1,112 @@ +// topsort.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Topological sort of FSTs + +#ifndef FST_LIB_TOPSORT_H__ +#define FST_LIB_TOPSORT_H__ + +#include <algorithm> +#include <vector> +using std::vector; + + +#include <fst/dfs-visit.h> +#include <fst/fst.h> +#include <fst/statesort.h> + + +namespace fst { + +// DFS visitor class to return topological ordering. +template <class A> +class TopOrderVisitor { + public: + typedef A Arc; + typedef typename A::StateId StateId; + + // If acyclic, ORDER[i] gives the topological position of state Id i; + // otherwise unchanged. ACYCLIC will be true iff the FST has + // no cycles. + TopOrderVisitor(vector<StateId> *order, bool *acyclic) + : order_(order), acyclic_(acyclic) {} + + void InitVisit(const Fst<A> &fst) { + finish_ = new vector<StateId>; + *acyclic_ = true; + } + + bool InitState(StateId s, StateId r) { return true; } + + bool TreeArc(StateId s, const A &arc) { return true; } + + bool BackArc(StateId s, const A &arc) { return (*acyclic_ = false); } + + bool ForwardOrCrossArc(StateId s, const A &arc) { return true; } + + void FinishState(StateId s, StateId p, const A *) { finish_->push_back(s); } + + void FinishVisit() { + if (*acyclic_) { + order_->clear(); + for (StateId s = 0; s < finish_->size(); ++s) + order_->push_back(kNoStateId); + for (StateId s = 0; s < finish_->size(); ++s) + (*order_)[(*finish_)[finish_->size() - s - 1]] = s; + } + delete finish_; + } + + private: + vector<StateId> *order_; + bool *acyclic_; + vector<StateId> *finish_; // states in finishing-time order +}; + + +// Topologically sorts its input if acyclic, modifying it. Otherwise, +// the input is unchanged. When sorted, all transitions are from +// lower to higher state IDs. +// +// Complexity: +// - Time: O(V + E) +// - Space: O(V + E) +// where V = # of states and E = # of arcs. +template <class Arc> +bool TopSort(MutableFst<Arc> *fst) { + typedef typename Arc::StateId StateId; + + vector<StateId> order; + bool acyclic; + + TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); + DfsVisit(*fst, &top_order_visitor); + + if (acyclic) { + StateSort(fst, order); + fst->SetProperties(kAcyclic | kInitialAcyclic | kTopSorted, + kAcyclic | kInitialAcyclic | kTopSorted); + } else { + fst->SetProperties(kCyclic | kNotTopSorted, kCyclic | kNotTopSorted); + } + return acyclic; +} + +} // namespace fst + +#endif // FST_LIB_TOPSORT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/tuple-weight.h b/kaldi_io/src/tools/openfst/include/fst/tuple-weight.h new file mode 100644 index 0000000..184026c --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/tuple-weight.h @@ -0,0 +1,332 @@ +// tuple-weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: allauzen@google (Cyril Allauzen) +// +// \file +// Tuple weight set operation definitions. + +#ifndef FST_LIB_TUPLE_WEIGHT_H__ +#define FST_LIB_TUPLE_WEIGHT_H__ + +#include <string> +#include <vector> +using std::vector; + +#include <fst/weight.h> + + +DECLARE_string(fst_weight_parentheses); +DECLARE_string(fst_weight_separator); + +namespace fst { + +template<class W, unsigned int n> class TupleWeight; +template <class W, unsigned int n> +istream &operator>>(istream &strm, TupleWeight<W, n> &w); + +// n-tuple weight, element of the n-th catersian power of W +template <class W, unsigned int n> +class TupleWeight { + public: + typedef TupleWeight<typename W::ReverseWeight, n> ReverseWeight; + + TupleWeight() {} + + TupleWeight(const TupleWeight &w) { + for (size_t i = 0; i < n; ++i) + values_[i] = w.values_[i]; + } + + template <class Iterator> + TupleWeight(Iterator begin, Iterator end) { + for (Iterator iter = begin; iter != end; ++iter) + values_[iter - begin] = *iter; + } + + TupleWeight(const W &w) { + for (size_t i = 0; i < n; ++i) + values_[i] = w; + } + + static const TupleWeight<W, n> &Zero() { + static const TupleWeight<W, n> zero(W::Zero()); + return zero; + } + + static const TupleWeight<W, n> &One() { + static const TupleWeight<W, n> one(W::One()); + return one; + } + + static const TupleWeight<W, n> &NoWeight() { + static const TupleWeight<W, n> no_weight(W::NoWeight()); + return no_weight; + } + + static unsigned int Length() { + return n; + } + + istream &Read(istream &strm) { + for (size_t i = 0; i < n; ++i) + values_[i].Read(strm); + return strm; + } + + ostream &Write(ostream &strm) const { + for (size_t i = 0; i < n; ++i) + values_[i].Write(strm); + return strm; + } + + TupleWeight<W, n> &operator=(const TupleWeight<W, n> &w) { + for (size_t i = 0; i < n; ++i) + values_[i] = w.values_[i]; + return *this; + } + + bool Member() const { + bool member = true; + for (size_t i = 0; i < n; ++i) + member = member && values_[i].Member(); + return member; + } + + size_t Hash() const { + uint64 hash = 0; + for (size_t i = 0; i < n; ++i) + hash = 5 * hash + values_[i].Hash(); + return size_t(hash); + } + + TupleWeight<W, n> Quantize(float delta = kDelta) const { + TupleWeight<W, n> w; + for (size_t i = 0; i < n; ++i) + w.values_[i] = values_[i].Quantize(delta); + return w; + } + + ReverseWeight Reverse() const { + TupleWeight<W, n> w; + for (size_t i = 0; i < n; ++i) + w.values_[i] = values_[i].Reverse(); + return w; + } + + const W& Value(size_t i) const { return values_[i]; } + + void SetValue(size_t i, const W &w) { values_[i] = w; } + + protected: + // Reads TupleWeight when there are no parentheses around tuple terms + inline static istream &ReadNoParen(istream &strm, + TupleWeight<W, n> &w, + char separator) { + int c; + do { + c = strm.get(); + } while (isspace(c)); + + for (size_t i = 0; i < n - 1; ++i) { + string s; + if (i) + c = strm.get(); + while (c != separator) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + c = strm.get(); + } + // read (i+1)-th element + istringstream sstrm(s); + W r = W::Zero(); + sstrm >> r; + w.SetValue(i, r); + } + + // read n-th element + W r = W::Zero(); + strm >> r; + w.SetValue(n - 1, r); + + return strm; + } + + // Reads TupleWeight when there are parentheses around tuple terms + inline static istream &ReadWithParen(istream &strm, + TupleWeight<W, n> &w, + char separator, + char open_paren, + char close_paren) { + int c; + do { + c = strm.get(); + } while (isspace(c)); + + if (c != open_paren) { + FSTERROR() << " is fst_weight_parentheses flag set correcty? "; + strm.clear(std::ios::badbit); + return strm; + } + + for (size_t i = 0; i < n - 1; ++i) { + // read (i+1)-th element + stack<int> parens; + string s; + c = strm.get(); + while (c != separator || !parens.empty()) { + if (c == EOF) { + strm.clear(std::ios::badbit); + return strm; + } + s += c; + // if parens encountered before separator, they must be matched + if (c == open_paren) { + parens.push(1); + } else if (c == close_paren) { + // Fail for mismatched parens + if (parens.empty()) { + strm.clear(std::ios::failbit); + return strm; + } + parens.pop(); + } + c = strm.get(); + } + istringstream sstrm(s); + W r = W::Zero(); + sstrm >> r; + w.SetValue(i, r); + } + + // read n-th element + string s; + c = strm.get(); + while (c != EOF) { + s += c; + c = strm.get(); + } + if (s.empty() || *s.rbegin() != close_paren) { + FSTERROR() << " is fst_weight_parentheses flag set correcty? "; + strm.clear(std::ios::failbit); + return strm; + } + s.erase(s.size() - 1, 1); + istringstream sstrm(s); + W r = W::Zero(); + sstrm >> r; + w.SetValue(n - 1, r); + + return strm; + } + + + private: + W values_[n]; + + friend istream &operator>><W, n>(istream&, TupleWeight<W, n>&); +}; + +template <class W, unsigned int n> +inline bool operator==(const TupleWeight<W, n> &w1, + const TupleWeight<W, n> &w2) { + bool equal = true; + for (size_t i = 0; i < n; ++i) + equal = equal && (w1.Value(i) == w2.Value(i)); + return equal; +} + +template <class W, unsigned int n> +inline bool operator!=(const TupleWeight<W, n> &w1, + const TupleWeight<W, n> &w2) { + bool not_equal = false; + for (size_t i = 0; (i < n) && !not_equal; ++i) + not_equal = not_equal || (w1.Value(i) != w2.Value(i)); + return not_equal; +} + +template <class W, unsigned int n> +inline bool ApproxEqual(const TupleWeight<W, n> &w1, + const TupleWeight<W, n> &w2, + float delta = kDelta) { + bool approx_equal = true; + for (size_t i = 0; i < n; ++i) + approx_equal = approx_equal && + ApproxEqual(w1.Value(i), w2.Value(i), delta); + return approx_equal; +} + +template <class W, unsigned int n> +inline ostream &operator<<(ostream &strm, const TupleWeight<W, n> &w) { + if(FLAGS_fst_weight_separator.size() != 1) { + FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1"; + strm.clear(std::ios::badbit); + return strm; + } + char separator = FLAGS_fst_weight_separator[0]; + bool write_parens = false; + if (!FLAGS_fst_weight_parentheses.empty()) { + if (FLAGS_fst_weight_parentheses.size() != 2) { + FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2"; + strm.clear(std::ios::badbit); + return strm; + } + write_parens = true; + } + + if (write_parens) + strm << FLAGS_fst_weight_parentheses[0]; + for (size_t i = 0; i < n; ++i) { + if(i) + strm << separator; + strm << w.Value(i); + } + if (write_parens) + strm << FLAGS_fst_weight_parentheses[1]; + + return strm; +} + +template <class W, unsigned int n> +inline istream &operator>>(istream &strm, TupleWeight<W, n> &w) { + if(FLAGS_fst_weight_separator.size() != 1) { + FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1"; + strm.clear(std::ios::badbit); + return strm; + } + char separator = FLAGS_fst_weight_separator[0]; + + if (!FLAGS_fst_weight_parentheses.empty()) { + if (FLAGS_fst_weight_parentheses.size() != 2) { + FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2"; + strm.clear(std::ios::badbit); + return strm; + } + return TupleWeight<W, n>::ReadWithParen( + strm, w, separator, FLAGS_fst_weight_parentheses[0], + FLAGS_fst_weight_parentheses[1]); + } else { + return TupleWeight<W, n>::ReadNoParen(strm, w, separator); + } +} + + + +} // namespace fst + +#endif // FST_LIB_TUPLE_WEIGHT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/types.h b/kaldi_io/src/tools/openfst/include/fst/types.h new file mode 100644 index 0000000..8c4367a --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/types.h @@ -0,0 +1,38 @@ +// types.h +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: [email protected] (Michael Riley) +// +// \file +// Various type definitions (mostly for Google compatibility). + +#include <cstdlib> // for ssize_t +#include <stdint.h> // *int*_t + +#include <fst/compat.h> // for DISALLOW_COPY_AND_ASSIGN + +#ifndef FST_LIB_TYPES_H__ +#define FST_LIB_TYPES_H__ + +typedef int8_t int8; +typedef int16_t int16; +typedef int32_t int32; +typedef int64_t int64; + +typedef uint8_t uint8; +typedef uint16_t uint16; +typedef uint32_t uint32; +typedef uint64_t uint64; + +#endif // FST_LIB_TYPES_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/union-find.h b/kaldi_io/src/tools/openfst/include/fst/union-find.h new file mode 100644 index 0000000..c8633e0 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/union-find.h @@ -0,0 +1,110 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Wojciech Skut) +// +// \file Union-Find algorithm for dense sets of non-negative +// integers. Implemented using disjoint tree forests with rank +// heuristics and path compression. + +#ifndef __fst_union_find_inl_h__ +#define __fst_union_find_inl_h__ + +#include <stack> +#include <vector> +using std::vector; +#include <fst/types.h> + +namespace fst { + +// Union-Find algorithm for dense sets of non-negative integers +// (exact type: T). +template <class T> +class UnionFind { + public: + // Ctor: creates a disjoint set forest for the range [0;max). + // 'fail' is a value indicating that an element hasn't been + // initialized using MakeSet(...). The upper bound of the range + // can be reset (increased) using MakeSet(...). + UnionFind(T max, T fail) + : parent_(max, fail), rank_(max), fail_(fail) { } + + // Finds the representative of the set 'item' belongs to. + // Performs path compression if needed. + T FindSet(T item) { + if (item >= parent_.size() + || item == fail_ + || parent_[item] == fail_) return fail_; + + T *p = &parent_[item]; + for (; *p != item; item = *p, p = &parent_[item]) { + exec_stack_.push(p); + } + for (; ! exec_stack_.empty(); exec_stack_.pop()) { + *exec_stack_.top() = *p; + } + return *p; + } + + // Creates the (destructive) union of the sets x and y belong to. + void Union(T x, T y) { + Link(FindSet(x), FindSet(y)); + } + + // Initialization of an element: creates a singleton set containing + // 'item'. The range [0;max) is reset if item >= max. + T MakeSet(T item) { + if (item >= parent_.size()) { + // New value in parent_ should be initialized to fail_ + size_t nitem = item > 0 ? 2 * item : 2; + parent_.resize(nitem, fail_); + rank_.resize(nitem); + } + parent_[item] = item; + return item; + } + + // Initialization of all elements starting from 0 to max - 1 to distinct sets + void MakeAllSet(T max) { + parent_.resize(max); + for (T item = 0; item < max; ++item) { + parent_[item] = item; + } + } + + private: + vector<T> parent_; // Parent nodes. + vector<int> rank_; // Rank of an element = min. depth in tree. + T fail_; // Value indicating lookup failure. + stack<T*> exec_stack_; // Used for path compression. + + // Links trees rooted in 'x' and 'y'. + void Link(T x, T y) { + if (x == y) return; + + if (rank_[x] > rank_[y]) { + parent_[y] = x; + } else { + parent_[x] = y; + if (rank_[x] == rank_[y]) { + ++rank_[y]; + } + } + } + DISALLOW_COPY_AND_ASSIGN(UnionFind); +}; + +} // namespace fst + +#endif // __fst_union_find_inl_h__ diff --git a/kaldi_io/src/tools/openfst/include/fst/union.h b/kaldi_io/src/tools/openfst/include/fst/union.h new file mode 100644 index 0000000..a2f97fb --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/union.h @@ -0,0 +1,185 @@ +// union.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Functions and classes to compute the union of two FSTs. + +#ifndef FST_LIB_UNION_H__ +#define FST_LIB_UNION_H__ + +#include <vector> +using std::vector; +#include <algorithm> + +#include <fst/mutable-fst.h> +#include <fst/rational.h> + + +namespace fst { + +// Computes the union (sum) of two FSTs. This version writes the +// union to an output MurableFst. If A transduces string x to y with +// weight a and B transduces string w to v with weight b, then their +// union transduces x to y with weight a and w to v with weight b. +// +// Complexity: +// - Time: (V2 + E2) +// - Space: O(V2 + E2) +// where Vi = # of states and Ei = # of arcs of the ith FST. +template <class Arc> +void Union(MutableFst<Arc> *fst1, const Fst<Arc> &fst2) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + // TODO(riley): restore when voice actions issues fixed + // Check that the symbol table are compatible + if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) { + LOG(ERROR) << "Union: input/output symbol tables of 1st argument " + << "do not match input/output symbol tables of 2nd argument"; + // fst1->SetProperties(kError, kError); + // return; + } + + StateId numstates1 = fst1->NumStates(); + bool initial_acyclic1 = fst1->Properties(kInitialAcyclic, true); + uint64 props1 = fst1->Properties(kFstProperties, false); + uint64 props2 = fst2.Properties(kFstProperties, false); + + StateId start2 = fst2.Start(); + if (start2 == kNoStateId) { + if (props2 & kError) fst1->SetProperties(kError, kError); + return; + } + + if (fst2.Properties(kExpanded, false)) { + fst1->ReserveStates( + numstates1 + CountStates(fst2) + (initial_acyclic1 ? 0 : 1)); + } + + for (StateIterator< Fst<Arc> > siter(fst2); + !siter.Done(); + siter.Next()) { + StateId s1 = fst1->AddState(); + StateId s2 = siter.Value(); + fst1->SetFinal(s1, fst2.Final(s2)); + fst1->ReserveArcs(s1, fst2.NumArcs(s2)); + for (ArcIterator< Fst<Arc> > aiter(fst2, s2); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + arc.nextstate += numstates1; + fst1->AddArc(s1, arc); + } + } + StateId start1 = fst1->Start(); + if (start1 == kNoStateId) { + fst1->SetStart(start2); + fst1->SetProperties(props2, kCopyProperties); + return; + } + + if (initial_acyclic1) { + fst1->AddArc(start1, Arc(0, 0, Weight::One(), start2 + numstates1)); + } else { + StateId nstart1 = fst1->AddState(); + fst1->SetStart(nstart1); + fst1->AddArc(nstart1, Arc(0, 0, Weight::One(), start1)); + fst1->AddArc(nstart1, Arc(0, 0, Weight::One(), start2 + numstates1)); + } + fst1->SetProperties(UnionProperties(props1, props2), kFstProperties); +} + + +// Computes the union of two FSTs; this version modifies its +// RationalFst argument. +template<class Arc> +void Union(RationalFst<Arc> *fst1, const Fst<Arc> &fst2) { + fst1->GetImpl()->AddUnion(fst2); +} + + +typedef RationalFstOptions UnionFstOptions; + + +// Computes the union (sum) of two FSTs. This version is a delayed +// Fst. If A transduces string x to y with weight a and B transduces +// string w to v with weight b, then their union transduces x to y +// with weight a and w to v with weight b. +// +// Complexity: +// - Time: O(v1 + e1 + v2 + e2) +// - Sapce: O(v1 + v2) +// where vi = # of states visited and ei = # of arcs visited of the +// ith FST. Constant time and space to visit an input state or arc +// is assumed and exclusive of caching. +template <class A> +class UnionFst : public RationalFst<A> { + public: + using ImplToFst< RationalFstImpl<A> >::GetImpl; + + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + UnionFst(const Fst<A> &fst1, const Fst<A> &fst2) { + GetImpl()->InitUnion(fst1, fst2); + } + + UnionFst(const Fst<A> &fst1, const Fst<A> &fst2, const UnionFstOptions &opts) + : RationalFst<A>(opts) { + GetImpl()->InitUnion(fst1, fst2); + } + + // See Fst<>::Copy() for doc. + UnionFst(const UnionFst<A> &fst, bool safe = false) + : RationalFst<A>(fst, safe) {} + + // Get a copy of this UnionFst. See Fst<>::Copy() for further doc. + virtual UnionFst<A> *Copy(bool safe = false) const { + return new UnionFst<A>(*this, safe); + } +}; + + +// Specialization for UnionFst. +template <class A> +class StateIterator< UnionFst<A> > : public StateIterator< RationalFst<A> > { + public: + explicit StateIterator(const UnionFst<A> &fst) + : StateIterator< RationalFst<A> >(fst) {} +}; + + +// Specialization for UnionFst. +template <class A> +class ArcIterator< UnionFst<A> > : public ArcIterator< RationalFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const UnionFst<A> &fst, StateId s) + : ArcIterator< RationalFst<A> >(fst, s) {} +}; + + +// Useful alias when using StdArc. +typedef UnionFst<StdArc> StdUnionFst; + +} // namespace fst + +#endif // FST_LIB_UNION_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/util.h b/kaldi_io/src/tools/openfst/include/fst/util.h new file mode 100644 index 0000000..57d7c4b --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/util.h @@ -0,0 +1,437 @@ +// util.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// FST utility inline definitions. + +#ifndef FST_LIB_UTIL_H__ +#define FST_LIB_UTIL_H__ + +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; +#include <list> +#include <map> +#include <set> +#include <sstream> +#include <string> +#include <vector> +using std::vector; + + +#include <fst/compat.h> +#include <fst/types.h> + +#include <iostream> +#include <fstream> +#include <sstream> + +// +// UTILITY FOR ERROR HANDLING +// + +DECLARE_bool(fst_error_fatal); + +#define FSTERROR() (FLAGS_fst_error_fatal ? LOG(FATAL) : LOG(ERROR)) + +namespace fst { + +// +// UTILITIES FOR TYPE I/O +// + +// Read some types from an input stream. + +// Generic case. +template <typename T> +inline istream &ReadType(istream &strm, T *t) { + return t->Read(strm); +} + +// Fixed size, contiguous memory read. +#define READ_POD_TYPE(T) \ +inline istream &ReadType(istream &strm, T *t) { \ + return strm.read(reinterpret_cast<char *>(t), sizeof(T)); \ +} + +READ_POD_TYPE(bool); +READ_POD_TYPE(char); +READ_POD_TYPE(signed char); +READ_POD_TYPE(unsigned char); +READ_POD_TYPE(short); +READ_POD_TYPE(unsigned short); +READ_POD_TYPE(int); +READ_POD_TYPE(unsigned int); +READ_POD_TYPE(long); +READ_POD_TYPE(unsigned long); +READ_POD_TYPE(long long); +READ_POD_TYPE(unsigned long long); +READ_POD_TYPE(float); +READ_POD_TYPE(double); + +// String case. +inline istream &ReadType(istream &strm, string *s) { + s->clear(); + int32 ns = 0; + strm.read(reinterpret_cast<char *>(&ns), sizeof(ns)); + for (int i = 0; i < ns; ++i) { + char c; + strm.read(&c, 1); + *s += c; + } + return strm; +} + +// Pair case. +template <typename S, typename T> +inline istream &ReadType(istream &strm, pair<S, T> *p) { + ReadType(strm, &p->first); + ReadType(strm, &p->second); + return strm; +} + +template <typename S, typename T> +inline istream &ReadType(istream &strm, pair<const S, T> *p) { + ReadType(strm, const_cast<S *>(&p->first)); + ReadType(strm, &p->second); + return strm; +} + +// General case - no-op. +template <typename C> +void StlReserve(C *c, int64 n) {} + +// Specialization for vectors. +template <typename S, typename T> +void StlReserve(vector<S, T> *c, int64 n) { + c->reserve(n); +} + +// STL sequence container. +#define READ_STL_SEQ_TYPE(C) \ +template <typename S, typename T> \ +inline istream &ReadType(istream &strm, C<S, T> *c) { \ + c->clear(); \ + int64 n = 0; \ + strm.read(reinterpret_cast<char *>(&n), sizeof(n)); \ + StlReserve(c, n); \ + for (ssize_t i = 0; i < n; ++i) { \ + typename C<S, T>::value_type value; \ + ReadType(strm, &value); \ + c->insert(c->end(), value); \ + } \ + return strm; \ +} + +READ_STL_SEQ_TYPE(vector); +READ_STL_SEQ_TYPE(list); + +// STL associative container. +#define READ_STL_ASSOC_TYPE(C) \ +template <typename S, typename T, typename U> \ +inline istream &ReadType(istream &strm, C<S, T, U> *c) { \ + c->clear(); \ + int64 n = 0; \ + strm.read(reinterpret_cast<char *>(&n), sizeof(n)); \ + for (ssize_t i = 0; i < n; ++i) { \ + typename C<S, T, U>::value_type value; \ + ReadType(strm, &value); \ + c->insert(value); \ + } \ + return strm; \ +} + +READ_STL_ASSOC_TYPE(set); +READ_STL_ASSOC_TYPE(unordered_set); +READ_STL_ASSOC_TYPE(map); +READ_STL_ASSOC_TYPE(unordered_map); + +// Write some types to an output stream. + +// Generic case. +template <typename T> +inline ostream &WriteType(ostream &strm, const T t) { + t.Write(strm); + return strm; +} + +// Fixed size, contiguous memory write. +#define WRITE_POD_TYPE(T) \ +inline ostream &WriteType(ostream &strm, const T t) { \ + return strm.write(reinterpret_cast<const char *>(&t), sizeof(T)); \ +} + +WRITE_POD_TYPE(bool); +WRITE_POD_TYPE(char); +WRITE_POD_TYPE(signed char); +WRITE_POD_TYPE(unsigned char); +WRITE_POD_TYPE(short); +WRITE_POD_TYPE(unsigned short); +WRITE_POD_TYPE(int); +WRITE_POD_TYPE(unsigned int); +WRITE_POD_TYPE(long); +WRITE_POD_TYPE(unsigned long); +WRITE_POD_TYPE(long long); +WRITE_POD_TYPE(unsigned long long); +WRITE_POD_TYPE(float); +WRITE_POD_TYPE(double); + +// String case. +inline ostream &WriteType(ostream &strm, const string &s) { + int32 ns = s.size(); + strm.write(reinterpret_cast<const char *>(&ns), sizeof(ns)); + return strm.write(s.data(), ns); +} + +// Pair case. +template <typename S, typename T> +inline ostream &WriteType(ostream &strm, const pair<S, T> &p) { + WriteType(strm, p.first); + WriteType(strm, p.second); + return strm; +} + +// STL sequence container. +#define WRITE_STL_SEQ_TYPE(C) \ +template <typename S, typename T> \ +inline ostream &WriteType(ostream &strm, const C<S, T> &c) { \ + int64 n = c.size(); \ + strm.write(reinterpret_cast<char *>(&n), sizeof(n)); \ + for (typename C<S, T>::const_iterator it = c.begin(); \ + it != c.end(); ++it) \ + WriteType(strm, *it); \ + return strm; \ +} + +WRITE_STL_SEQ_TYPE(vector); +WRITE_STL_SEQ_TYPE(list); + +// STL associative container. +#define WRITE_STL_ASSOC_TYPE(C) \ +template <typename S, typename T, typename U> \ +inline ostream &WriteType(ostream &strm, const C<S, T, U> &c) { \ + int64 n = c.size(); \ + strm.write(reinterpret_cast<char *>(&n), sizeof(n)); \ + for (typename C<S, T, U>::const_iterator it = c.begin(); \ + it != c.end(); ++it) \ + WriteType(strm, *it); \ + return strm; \ +} + +WRITE_STL_ASSOC_TYPE(set); +WRITE_STL_ASSOC_TYPE(unordered_set); +WRITE_STL_ASSOC_TYPE(map); +WRITE_STL_ASSOC_TYPE(unordered_map); + +// Utilities for converting between int64 or Weight and string. + +int64 StrToInt64(const string &s, const string &src, size_t nline, + bool allow_negative, bool *error = 0); + +template <typename Weight> +Weight StrToWeight(const string &s, const string &src, size_t nline) { + Weight w; + istringstream strm(s); + strm >> w; + if (!strm) { + FSTERROR() << "StrToWeight: Bad weight = \"" << s + << "\", source = " << src << ", line = " << nline; + return Weight::NoWeight(); + } + return w; +} + +void Int64ToStr(int64 n, string *s); + +template <typename Weight> +void WeightToStr(Weight w, string *s) { + ostringstream strm; + strm.precision(9); + strm << w; + s->append(strm.str().data(), strm.str().size()); +} + +// Utilities for reading/writing label pairs + +// Returns true on success +template <typename Label> +bool ReadLabelPairs(const string& filename, + vector<pair<Label, Label> >* pairs, + bool allow_negative = false) { + ifstream strm(filename.c_str()); + + if (!strm) { + LOG(ERROR) << "ReadLabelPairs: Can't open file: " << filename; + return false; + } + + const int kLineLen = 8096; + char line[kLineLen]; + size_t nline = 0; + + pairs->clear(); + while (strm.getline(line, kLineLen)) { + ++nline; + vector<char *> col; + SplitToVector(line, "\n\t ", &col, true); + if (col.size() == 0 || col[0][0] == '\0') // empty line + continue; + if (col.size() != 2) { + LOG(ERROR) << "ReadLabelPairs: Bad number of columns, " + << "file = " << filename << ", line = " << nline; + return false; + } + + bool err; + Label frmlabel = StrToInt64(col[0], filename, nline, allow_negative, &err); + if (err) return false; + Label tolabel = StrToInt64(col[1], filename, nline, allow_negative, &err); + if (err) return false; + pairs->push_back(make_pair(frmlabel, tolabel)); + } + return true; +} + +// Returns true on success +template <typename Label> +bool WriteLabelPairs(const string& filename, + const vector<pair<Label, Label> >& pairs) { + ostream *strm = &cout; + if (!filename.empty()) { + strm = new ofstream(filename.c_str()); + if (!*strm) { + LOG(ERROR) << "WriteLabelPairs: Can't open file: " << filename; + return false; + } + } + + for (ssize_t n = 0; n < pairs.size(); ++n) + *strm << pairs[n].first << "\t" << pairs[n].second << "\n"; + + if (!*strm) { + LOG(ERROR) << "WriteLabelPairs: Write failed: " + << (filename.empty() ? "standard output" : filename); + return false; + } + if (strm != &cout) + delete strm; + return true; +} + +// Utilities for converting a type name to a legal C symbol. + +void ConvertToLegalCSymbol(string *s); + + +// +// UTILITIES FOR STREAM I/O +// + +bool AlignInput(istream &strm); +bool AlignOutput(ostream &strm); + +// +// UTILITIES FOR PROTOCOL BUFFER I/O +// + + +// An associative container for which testing membership is +// faster than an STL set if members are restricted to an interval +// that excludes most non-members. A 'Key' must have ==, !=, and < defined. +// Element 'NoKey' should be a key that marks an uninitialized key and +// is otherwise unused. 'Find()' returns an STL const_iterator to the match +// found, otherwise it equals 'End()'. +template <class Key, Key NoKey> +class CompactSet { +public: + typedef typename set<Key>::const_iterator const_iterator; + + CompactSet() + : min_key_(NoKey), + max_key_(NoKey) { } + + CompactSet(const CompactSet<Key, NoKey> &compact_set) + : set_(compact_set.set_), + min_key_(compact_set.min_key_), + max_key_(compact_set.max_key_) { } + + void Insert(Key key) { + set_.insert(key); + if (min_key_ == NoKey || key < min_key_) + min_key_ = key; + if (max_key_ == NoKey || max_key_ < key) + max_key_ = key; + } + + void Erase(Key key) { + set_.erase(key); + if (set_.empty()) { + min_key_ = max_key_ = NoKey; + } else if (key == min_key_) { + ++min_key_; + } else if (key == max_key_) { + --max_key_; + } + } + + void Clear() { + set_.clear(); + min_key_ = max_key_ = NoKey; + } + + const_iterator Find(Key key) const { + if (min_key_ == NoKey || + key < min_key_ || max_key_ < key) + return set_.end(); + else + return set_.find(key); + } + + bool Member(Key key) const { + if (min_key_ == NoKey || key < min_key_ || max_key_ < key) { + return false; // out of range + } else if (min_key_ != NoKey && max_key_ + 1 == min_key_ + set_.size()) { + return true; // dense range + } else { + return set_.find(key) != set_.end(); + } + } + + const_iterator Begin() const { return set_.begin(); } + + const_iterator End() const { return set_.end(); } + + // All stored keys are greater than or equal to this value. + Key LowerBound() const { return min_key_; } + + // All stored keys are less than or equal to this value. + Key UpperBound() const { return max_key_; } + +private: + set<Key> set_; + Key min_key_; + Key max_key_; + + void operator=(const CompactSet<Key, NoKey> &); //disallow +}; + +} // namespace fst + +#endif // FST_LIB_UTIL_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/vector-fst.h b/kaldi_io/src/tools/openfst/include/fst/vector-fst.h new file mode 100644 index 0000000..8b80876 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/vector-fst.h @@ -0,0 +1,731 @@ +// vector-fst.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Simple concrete, mutable FST whose states and arcs are stored in STL +// vectors. + +#ifndef FST_LIB_VECTOR_FST_H__ +#define FST_LIB_VECTOR_FST_H__ + +#include <string> +#include <vector> +using std::vector; + +#include <fst/mutable-fst.h> +#include <fst/test-properties.h> + + +namespace fst { + +template <class A> class VectorFst; +template <class F, class G> void Cast(const F &, G *); + + +// States and arcs implemented by STL vectors, templated on the +// State definition. This does not manage the Fst properties. +template <class State> +class VectorFstBaseImpl : public FstImpl<typename State::Arc> { + public: + typedef typename State::Arc Arc; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + VectorFstBaseImpl() : start_(kNoStateId) {} + + ~VectorFstBaseImpl() { + for (StateId s = 0; s < states_.size(); ++s) + delete states_[s]; + } + + StateId Start() const { return start_; } + + Weight Final(StateId s) const { return states_[s]->final; } + + StateId NumStates() const { return states_.size(); } + + size_t NumArcs(StateId s) const { return states_[s]->arcs.size(); } + + void SetStart(StateId s) { start_ = s; } + + void SetFinal(StateId s, Weight w) { states_[s]->final = w; } + + StateId AddState() { + states_.push_back(new State); + return states_.size() - 1; + } + + StateId AddState(State *state) { + states_.push_back(state); + return states_.size() - 1; + } + + void AddArc(StateId s, const Arc &arc) { + states_[s]->arcs.push_back(arc); + } + + void DeleteStates(const vector<StateId>& dstates) { + vector<StateId> newid(states_.size(), 0); + for (size_t i = 0; i < dstates.size(); ++i) + newid[dstates[i]] = kNoStateId; + StateId nstates = 0; + for (StateId s = 0; s < states_.size(); ++s) { + if (newid[s] != kNoStateId) { + newid[s] = nstates; + if (s != nstates) + states_[nstates] = states_[s]; + ++nstates; + } else { + delete states_[s]; + } + } + states_.resize(nstates); + for (StateId s = 0; s < states_.size(); ++s) { + vector<Arc> &arcs = states_[s]->arcs; + size_t narcs = 0; + for (size_t i = 0; i < arcs.size(); ++i) { + StateId t = newid[arcs[i].nextstate]; + if (t != kNoStateId) { + arcs[i].nextstate = t; + if (i != narcs) + arcs[narcs] = arcs[i]; + ++narcs; + } else { + if (arcs[i].ilabel == 0) + --states_[s]->niepsilons; + if (arcs[i].olabel == 0) + --states_[s]->noepsilons; + } + } + arcs.resize(narcs); + } + if (Start() != kNoStateId) + SetStart(newid[Start()]); + } + + void DeleteStates() { + for (StateId s = 0; s < states_.size(); ++s) + delete states_[s]; + states_.clear(); + SetStart(kNoStateId); + } + + void DeleteArcs(StateId s, size_t n) { + states_[s]->arcs.resize(states_[s]->arcs.size() - n); + } + + void DeleteArcs(StateId s) { states_[s]->arcs.clear(); } + + State *GetState(StateId s) { return states_[s]; } + + const State *GetState(StateId s) const { return states_[s]; } + + void SetState(StateId s, State *state) { states_[s] = state; } + + void ReserveStates(StateId n) { states_.reserve(n); } + + void ReserveArcs(StateId s, size_t n) { states_[s]->arcs.reserve(n); } + + // Provide information needed for generic state iterator + void InitStateIterator(StateIteratorData<Arc> *data) const { + data->base = 0; + data->nstates = states_.size(); + } + + // Provide information needed for generic arc iterator + void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + data->base = 0; + data->narcs = states_[s]->arcs.size(); + data->arcs = data->narcs > 0 ? &states_[s]->arcs[0] : 0; + data->ref_count = 0; + } + + private: + vector<State *> states_; // States represenation. + StateId start_; // initial state + + DISALLOW_COPY_AND_ASSIGN(VectorFstBaseImpl); +}; + +// Arcs implemented by an STL vector per state. +template <class A> +struct VectorState { + typedef A Arc; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + VectorState() : final(Weight::Zero()), niepsilons(0), noepsilons(0) {} + + Weight final; // Final weight + vector<A> arcs; // Arcs represenation + size_t niepsilons; // # of input epsilons + size_t noepsilons; // # of output epsilons +}; + +// This is a VectorFstBaseImpl container that holds VectorState's. It +// manages Fst properties and the # of input and output epsilons. +template <class A> +class VectorFstImpl : public VectorFstBaseImpl< VectorState<A> > { + public: + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + + using VectorFstBaseImpl<VectorState<A> >::Start; + using VectorFstBaseImpl<VectorState<A> >::NumStates; + using VectorFstBaseImpl<VectorState<A> >::GetState; + using VectorFstBaseImpl<VectorState<A> >::ReserveArcs; + + friend class MutableArcIterator< VectorFst<A> >; + + typedef VectorFstBaseImpl< VectorState<A> > BaseImpl; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + + VectorFstImpl() { + SetType("vector"); + SetProperties(kNullProperties | kStaticProperties); + } + explicit VectorFstImpl(const Fst<A> &fst); + + static VectorFstImpl<A> *Read(istream &strm, const FstReadOptions &opts); + + size_t NumInputEpsilons(StateId s) const { return GetState(s)->niepsilons; } + + size_t NumOutputEpsilons(StateId s) const { return GetState(s)->noepsilons; } + + void SetStart(StateId s) { + BaseImpl::SetStart(s); + SetProperties(SetStartProperties(Properties())); + } + + void SetFinal(StateId s, Weight w) { + Weight ow = BaseImpl::Final(s); + BaseImpl::SetFinal(s, w); + SetProperties(SetFinalProperties(Properties(), ow, w)); + } + + StateId AddState() { + StateId s = BaseImpl::AddState(); + SetProperties(AddStateProperties(Properties())); + return s; + } + + void AddArc(StateId s, const A &arc) { + VectorState<A> *state = GetState(s); + if (arc.ilabel == 0) { + ++state->niepsilons; + } + if (arc.olabel == 0) { + ++state->noepsilons; + } + + const A *parc = state->arcs.empty() ? 0 : &(state->arcs.back()); + SetProperties(AddArcProperties(Properties(), s, arc, parc)); + + BaseImpl::AddArc(s, arc); + } + + void DeleteStates(const vector<StateId> &dstates) { + BaseImpl::DeleteStates(dstates); + SetProperties(DeleteStatesProperties(Properties())); + } + + void DeleteStates() { + BaseImpl::DeleteStates(); + SetProperties(DeleteAllStatesProperties(Properties(), + kStaticProperties)); + } + + void DeleteArcs(StateId s, size_t n) { + const vector<A> &arcs = GetState(s)->arcs; + for (size_t i = 0; i < n; ++i) { + size_t j = arcs.size() - i - 1; + if (arcs[j].ilabel == 0) + --GetState(s)->niepsilons; + if (arcs[j].olabel == 0) + --GetState(s)->noepsilons; + } + BaseImpl::DeleteArcs(s, n); + SetProperties(DeleteArcsProperties(Properties())); + } + + void DeleteArcs(StateId s) { + GetState(s)->niepsilons = 0; + GetState(s)->noepsilons = 0; + BaseImpl::DeleteArcs(s); + SetProperties(DeleteArcsProperties(Properties())); + } + + // Properties always true of this Fst class + static const uint64 kStaticProperties = kExpanded | kMutable; + + private: + // Current file format version + static const int kFileVersion = 2; + // Minimum file format version supported + static const int kMinFileVersion = 1; + + DISALLOW_COPY_AND_ASSIGN(VectorFstImpl); +}; + +template <class A> const uint64 VectorFstImpl<A>::kStaticProperties; +template <class A> const int VectorFstImpl<A>::kFileVersion; +template <class A> const int VectorFstImpl<A>::kMinFileVersion; + + +template <class A> +VectorFstImpl<A>::VectorFstImpl(const Fst<A> &fst) { + SetType("vector"); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + BaseImpl::SetStart(fst.Start()); + if (fst.Properties(kExpanded, false)) + BaseImpl::ReserveStates(CountStates(fst)); + + for (StateIterator< Fst<A> > siter(fst); + !siter.Done(); + siter.Next()) { + StateId s = siter.Value(); + BaseImpl::AddState(); + BaseImpl::SetFinal(s, fst.Final(s)); + ReserveArcs(s, fst.NumArcs(s)); + for (ArcIterator< Fst<A> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + BaseImpl::AddArc(s, arc); + if (arc.ilabel == 0) + ++GetState(s)->niepsilons; + if (arc.olabel == 0) + ++GetState(s)->noepsilons; + } + } + SetProperties(fst.Properties(kCopyProperties, false) | kStaticProperties); +} + +template <class A> +VectorFstImpl<A> *VectorFstImpl<A>::Read(istream &strm, + const FstReadOptions &opts) { + VectorFstImpl<A> *impl = new VectorFstImpl; + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) { + delete impl; + return 0; + } + impl->BaseImpl::SetStart(hdr.Start()); + if (hdr.NumStates() != kNoStateId) { + impl->ReserveStates(hdr.NumStates()); + } + + StateId s = 0; + for (;hdr.NumStates() == kNoStateId || s < hdr.NumStates(); ++s) { + typename A::Weight final; + if (!final.Read(strm)) break; + impl->BaseImpl::AddState(); + VectorState<A> *state = impl->GetState(s); + state->final = final; + int64 narcs; + ReadType(strm, &narcs); + if (!strm) { + LOG(ERROR) << "VectorFst::Read: read failed: " << opts.source; + delete impl; + return 0; + } + impl->ReserveArcs(s, narcs); + for (size_t j = 0; j < narcs; ++j) { + A arc; + ReadType(strm, &arc.ilabel); + ReadType(strm, &arc.olabel); + arc.weight.Read(strm); + ReadType(strm, &arc.nextstate); + if (!strm) { + LOG(ERROR) << "VectorFst::Read: read failed: " << opts.source; + delete impl; + return 0; + } + impl->BaseImpl::AddArc(s, arc); + if (arc.ilabel == 0) + ++state->niepsilons; + if (arc.olabel == 0) + ++state->noepsilons; + } + } + if (hdr.NumStates() != kNoStateId && s != hdr.NumStates()) { + LOG(ERROR) << "VectorFst::Read: unexpected end of file: " << opts.source; + delete impl; + return 0; + } + return impl; +} + +// Converts a string into a weight. +template <class W> class WeightFromString { + public: + W operator()(const string &s); +}; + +// Generic case fails. +template <class W> inline +W WeightFromString<W>::operator()(const string &s) { + FSTERROR() << "VectorFst::Read: Obsolete file format"; + return W::NoWeight(); +} + +// TropicalWeight version. +template <> inline +TropicalWeight WeightFromString<TropicalWeight>::operator()(const string &s) { + float f; + memcpy(&f, s.data(), sizeof(f)); + return TropicalWeight(f); +} + +// LogWeight version. +template <> inline +LogWeight WeightFromString<LogWeight>::operator()(const string &s) { + float f; + memcpy(&f, s.data(), sizeof(f)); + return LogWeight(f); +} + +// Simple concrete, mutable FST. This class attaches interface to +// implementation and handles reference counting, delegating most +// methods to ImplToMutableFst. Supports additional operations: +// ReserveStates and ReserveArcs (cf. STL vectors). +template <class A> +class VectorFst : public ImplToMutableFst< VectorFstImpl<A> > { + public: + friend class StateIterator< VectorFst<A> >; + friend class ArcIterator< VectorFst<A> >; + friend class MutableArcIterator< VectorFst<A> >; + template <class F, class G> friend void Cast(const F &, G *); + + typedef A Arc; + typedef typename A::StateId StateId; + typedef VectorFstImpl<A> Impl; + + VectorFst() : ImplToMutableFst<Impl>(new Impl) {} + + explicit VectorFst(const Fst<A> &fst) + : ImplToMutableFst<Impl>(new Impl(fst)) {} + + VectorFst(const VectorFst<A> &fst) : ImplToMutableFst<Impl>(fst) {} + + // Get a copy of this VectorFst. See Fst<>::Copy() for further doc. + virtual VectorFst<A> *Copy(bool safe = false) const { + return new VectorFst<A>(*this); + } + + VectorFst<A> &operator=(const VectorFst<A> &fst) { + SetImpl(fst.GetImpl(), false); + return *this; + } + + virtual VectorFst<A> &operator=(const Fst<A> &fst) { + if (this != &fst) SetImpl(new Impl(fst)); + return *this; + } + + // Read a VectorFst from an input stream; return NULL on error + static VectorFst<A> *Read(istream &strm, const FstReadOptions &opts) { + Impl* impl = Impl::Read(strm, opts); + return impl ? new VectorFst<A>(impl) : 0; + } + + // Read a VectorFst from a file; return NULL on error + // Empty filename reads from standard input + static VectorFst<A> *Read(const string &filename) { + Impl* impl = ImplToExpandedFst<Impl, MutableFst<A> >::Read(filename); + return impl ? new VectorFst<A>(impl) : 0; + } + + virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { + return WriteFst(*this, strm, opts); + } + + virtual bool Write(const string &filename) const { + return Fst<A>::WriteFile(filename); + } + + template <class F> + static bool WriteFst(const F &fst, ostream &strm, + const FstWriteOptions &opts); + + void ReserveStates(StateId n) { + MutateCheck(); + GetImpl()->ReserveStates(n); + } + + void ReserveArcs(StateId s, size_t n) { + MutateCheck(); + GetImpl()->ReserveArcs(s, n); + } + + virtual void InitStateIterator(StateIteratorData<Arc> *data) const { + GetImpl()->InitStateIterator(data); + } + + virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + virtual inline + void InitMutableArcIterator(StateId s, MutableArcIteratorData<A> *); + + private: + explicit VectorFst(Impl *impl) : ImplToMutableFst<Impl>(impl) {} + + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst< Impl, MutableFst<A> >::GetImpl(); } + + void SetImpl(Impl *impl, bool own_impl = true) { + ImplToFst< Impl, MutableFst<A> >::SetImpl(impl, own_impl); + } + + void MutateCheck() { return ImplToMutableFst<Impl>::MutateCheck(); } +}; + +// Specialization for VectorFst; see generic version in fst.h +// for sample usage (but use the VectorFst type!). This version +// should inline. +template <class A> +class StateIterator< VectorFst<A> > { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const VectorFst<A> &fst) + : nstates_(fst.GetImpl()->NumStates()), s_(0) {} + + bool Done() const { return s_ >= nstates_; } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + StateId nstates_; + StateId s_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; + +// Writes Fst to file, will call CountStates so may involve two passes if +// called from an Fst that is not derived from Expanded. +template <class A> +template <class F> +bool VectorFst<A>::WriteFst(const F &fst, ostream &strm, + const FstWriteOptions &opts) { + static const int kFileVersion = 2; + bool update_header = true; + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(kNoStateId); + size_t start_offset = 0; + if (fst.Properties(kExpanded, false) || (start_offset = strm.tellp()) != -1) { + hdr.SetNumStates(CountStates(fst)); + update_header = false; + } + uint64 properties = fst.Properties(kCopyProperties, false) | + VectorFstImpl<A>::kStaticProperties; + FstImpl<A>::WriteFstHeader(fst, strm, opts, kFileVersion, "vector", + properties, &hdr); + StateId num_states = 0; + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + typename A::StateId s = siter.Value(); + fst.Final(s).Write(strm); + int64 narcs = fst.NumArcs(s); + WriteType(strm, narcs); + for (ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { + const A &arc = aiter.Value(); + WriteType(strm, arc.ilabel); + WriteType(strm, arc.olabel); + arc.weight.Write(strm); + WriteType(strm, arc.nextstate); + } + num_states++; + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "VectorFst::Write: write failed: " << opts.source; + return false; + } + if (update_header) { + hdr.SetNumStates(num_states); + return FstImpl<A>::UpdateFstHeader(fst, strm, opts, kFileVersion, "vector", + properties, &hdr, start_offset); + } else { + if (num_states != hdr.NumStates()) { + LOG(ERROR) << "Inconsistent number of states observed during write"; + return false; + } + } + return true; +} + +// Specialization for VectorFst; see generic version in fst.h +// for sample usage (but use the VectorFst type!). This version +// should inline. +template <class A> +class ArcIterator< VectorFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const VectorFst<A> &fst, StateId s) + : arcs_(fst.GetImpl()->GetState(s)->arcs), i_(0) {} + + bool Done() const { return i_ >= arcs_.size(); } + + const A& Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + size_t Position() const { return i_; } + + uint32 Flags() const { + return kArcValueFlags; + } + + void SetFlags(uint32 f, uint32 m) {} + + private: + const vector<A>& arcs_; + size_t i_; + + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + +// Specialization for VectorFst; see generic version in fst.h +// for sample usage (but use the VectorFst type!). This version +// should inline. +template <class A> +class MutableArcIterator< VectorFst<A> > + : public MutableArcIteratorBase<A> { + public: + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + MutableArcIterator(VectorFst<A> *fst, StateId s) : i_(0) { + fst->MutateCheck(); + state_ = fst->GetImpl()->GetState(s); + properties_ = &fst->GetImpl()->properties_; + } + + bool Done() const { return i_ >= state_->arcs.size(); } + + const A& Value() const { return state_->arcs[i_]; } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + void SetValue(const A &arc) { + A& oarc = state_->arcs[i_]; + if (oarc.ilabel != oarc.olabel) + *properties_ &= ~kNotAcceptor; + if (oarc.ilabel == 0) { + --state_->niepsilons; + *properties_ &= ~kIEpsilons; + if (oarc.olabel == 0) + *properties_ &= ~kEpsilons; + } + if (oarc.olabel == 0) { + --state_->noepsilons; + *properties_ &= ~kOEpsilons; + } + if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) + *properties_ &= ~kWeighted; + oarc = arc; + if (arc.ilabel != arc.olabel) { + *properties_ |= kNotAcceptor; + *properties_ &= ~kAcceptor; + } + if (arc.ilabel == 0) { + ++state_->niepsilons; + *properties_ |= kIEpsilons; + *properties_ &= ~kNoIEpsilons; + if (arc.olabel == 0) { + *properties_ |= kEpsilons; + *properties_ &= ~kNoEpsilons; + } + } + if (arc.olabel == 0) { + ++state_->noepsilons; + *properties_ |= kOEpsilons; + *properties_ &= ~kNoOEpsilons; + } + if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) { + *properties_ |= kWeighted; + *properties_ &= ~kUnweighted; + } + *properties_ &= kSetArcProperties | kAcceptor | kNotAcceptor | + kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons | + kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted; + } + + uint32 Flags() const { + return kArcValueFlags; + } + + void SetFlags(uint32 f, uint32 m) {} + + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual bool Done_() const { return Done(); } + virtual const A& Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual size_t Position_() const { return Position(); } + virtual void Reset_() { Reset(); } + virtual void Seek_(size_t a) { Seek(a); } + virtual void SetValue_(const A &a) { SetValue(a); } + uint32 Flags_() const { return Flags(); } + void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } + + struct VectorState<A> *state_; + uint64 *properties_; + size_t i_; + + DISALLOW_COPY_AND_ASSIGN(MutableArcIterator); +}; + +// Provide information needed for the generic mutable arc iterator +template <class A> inline +void VectorFst<A>::InitMutableArcIterator( + StateId s, MutableArcIteratorData<A> *data) { + data->base = new MutableArcIterator< VectorFst<A> >(this, s); +} + +// A useful alias when using StdArc. +typedef VectorFst<StdArc> StdVectorFst; + +} // namespace fst + +#endif // FST_LIB_VECTOR_FST_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/verify.h b/kaldi_io/src/tools/openfst/include/fst/verify.h new file mode 100644 index 0000000..576cfca --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/verify.h @@ -0,0 +1,126 @@ +// verify.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Function to verify an Fst's contents + +#ifndef FST_LIB_VERIFY_H__ +#define FST_LIB_VERIFY_H__ + +#include <fst/fst.h> +#include <fst/test-properties.h> + + +namespace fst { + +// Verifies that an Fst's contents are sane. +template<class Arc> +bool Verify(const Fst<Arc> &fst, bool allow_negative_labels = false) { + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + + StateId start = fst.Start(); + const SymbolTable *isyms = fst.InputSymbols(); + const SymbolTable *osyms = fst.OutputSymbols(); + + // Count states + StateId ns = 0; + for (StateIterator< Fst<Arc> > siter(fst); + !siter.Done(); + siter.Next()) + ++ns; + + if (start == kNoStateId && ns > 0) { + LOG(ERROR) << "Verify: Fst start state ID unset"; + return false; + } else if (start >= ns) { + LOG(ERROR) << "Verify: Fst start state ID exceeds number of states"; + return false; + } + + for (StateIterator< Fst<Arc> > siter(fst); + !siter.Done(); + siter.Next()) { + StateId s = siter.Value(); + size_t na = 0; + for (ArcIterator< Fst<Arc> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) { + const Arc &arc =aiter.Value(); + if (!allow_negative_labels && arc.ilabel < 0) { + LOG(ERROR) << "Verify: Fst input label ID of arc at position " + << na << " of state " << s << " is negative"; + return false; + } else if (isyms && isyms->Find(arc.ilabel) == "") { + LOG(ERROR) << "Verify: Fst input label ID " << arc.ilabel + << " of arc at position " << na << " of state " << s + << " is missing from input symbol table \"" + << isyms->Name() << "\""; + return false; + } else if (!allow_negative_labels && arc.olabel < 0) { + LOG(ERROR) << "Verify: Fst output label ID of arc at position " + << na << " of state " << s << " is negative"; + return false; + } else if (osyms && osyms->Find(arc.olabel) == "") { + LOG(ERROR) << "Verify: Fst output label ID " << arc.olabel + << " of arc at position " << na << " of state " << s + << " is missing from output symbol table \"" + << osyms->Name() << "\""; + return false; + } else if (!arc.weight.Member() || arc.weight == Weight::Zero()) { + LOG(ERROR) << "Verify: Fst weight of arc at position " + << na << " of state " << s << " is invalid"; + return false; + } else if (arc.nextstate < 0) { + LOG(ERROR) << "Verify: Fst destination state ID of arc at position " + << na << " of state " << s << " is negative"; + return false; + } else if (arc.nextstate >= ns) { + LOG(ERROR) << "Verify: Fst destination state ID of arc at position " + << na << " of state " << s + << " exceeds number of states"; + return false; + } + ++na; + } + if (!fst.Final(s).Member()) { + LOG(ERROR) << "Verify: Fst final weight of state " << s << " is invalid"; + return false; + } + } + uint64 fst_props = fst.Properties(kFstProperties, false); + if (fst_props & kError) { + LOG(ERROR) << "Verify: Fst error property is set"; + return false; + } + + uint64 known_props; + uint64 test_props = ComputeProperties(fst, kFstProperties, &known_props, + false); + if (!CompatProperties(fst_props, test_props)) { + LOG(ERROR) << "Verify: stored Fst properties incorrect " + << "(props1 = stored props, props2 = tested)"; + return false; + } else { + return true; + } +} + +} // namespace fst + +#endif // FST_LIB_VERIFY_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/visit.h b/kaldi_io/src/tools/openfst/include/fst/visit.h new file mode 100644 index 0000000..5f5059a --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/visit.h @@ -0,0 +1,284 @@ +// visit.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// Queue-dependent visitation of finite-state transducers. See also +// dfs-visit.h. + +#ifndef FST_LIB_VISIT_H__ +#define FST_LIB_VISIT_H__ + + +#include <fst/arcfilter.h> +#include <fst/mutable-fst.h> + + +namespace fst { + +// Visitor Interface - class determines actions taken during a visit. +// If any of the boolean member functions return false, the visit is +// aborted by first calling FinishState() on all unfinished (grey) +// states and then calling FinishVisit(). +// +// Note this is more general than the visitor interface in +// dfs-visit.h but lacks some DFS-specific behavior. +// +// template <class Arc> +// class Visitor { +// public: +// typedef typename Arc::StateId StateId; +// +// Visitor(T *return_data); +// // Invoked before visit +// void InitVisit(const Fst<Arc> &fst); +// // Invoked when state discovered (2nd arg is visitation root) +// bool InitState(StateId s, StateId root); +// // Invoked when arc to white/undiscovered state examined +// bool WhiteArc(StateId s, const Arc &a); +// // Invoked when arc to grey/unfinished state examined +// bool GreyArc(StateId s, const Arc &a); +// // Invoked when arc to black/finished state examined +// bool BlackArc(StateId s, const Arc &a); +// // Invoked when state finished. +// void FinishState(StateId s); +// // Invoked after visit +// void FinishVisit(); +// }; + +// Performs queue-dependent visitation. Visitor class argument +// determines actions and contains any return data. ArcFilter +// determines arcs that are considered. +// +// Note this is more general than DfsVisit() in dfs-visit.h but lacks +// some DFS-specific Visitor behavior. +template <class Arc, class V, class Q, class ArcFilter> +void Visit(const Fst<Arc> &fst, V *visitor, Q *queue, ArcFilter filter) { + + typedef typename Arc::StateId StateId; + typedef ArcIterator< Fst<Arc> > AIterator; + + visitor->InitVisit(fst); + + StateId start = fst.Start(); + if (start == kNoStateId) { + visitor->FinishVisit(); + return; + } + + // An Fst state's visit color + const unsigned kWhiteState = 0x01; // Undiscovered + const unsigned kGreyState = 0x02; // Discovered & unfinished + const unsigned kBlackState = 0x04; // Finished + + // We destroy an iterator as soon as possible and mark it so + const unsigned kArcIterDone = 0x08; // Arc iterator done and destroyed + + vector<unsigned char> state_status; + vector<AIterator *> arc_iterator; + + StateId nstates = start + 1; // # of known states in general case + bool expanded = false; + if (fst.Properties(kExpanded, false)) { // tests if expanded case, then + nstates = CountStates(fst); // uses ExpandedFst::NumStates(). + expanded = true; + } + + state_status.resize(nstates, kWhiteState); + arc_iterator.resize(nstates); + StateIterator< Fst<Arc> > siter(fst); + + // Continues visit while true + bool visit = true; + + // Iterates over trees in visit forest. + for (StateId root = start; visit && root < nstates;) { + visit = visitor->InitState(root, root); + state_status[root] = kGreyState; + queue->Enqueue(root); + while (!queue->Empty()) { + StateId s = queue->Head(); + if (s >= state_status.size()) { + nstates = s + 1; + state_status.resize(nstates, kWhiteState); + arc_iterator.resize(nstates); + } + // Creates arc iterator if needed. + if (arc_iterator[s] == 0 && !(state_status[s] & kArcIterDone) && visit) + arc_iterator[s] = new AIterator(fst, s); + // Deletes arc iterator if done. + AIterator *aiter = arc_iterator[s]; + if ((aiter && aiter->Done()) || !visit) { + delete aiter; + arc_iterator[s] = 0; + state_status[s] |= kArcIterDone; + } + // Dequeues state and marks black if done + if (state_status[s] & kArcIterDone) { + queue->Dequeue(); + visitor->FinishState(s); + state_status[s] = kBlackState; + continue; + } + + const Arc &arc = aiter->Value(); + if (arc.nextstate >= state_status.size()) { + nstates = arc.nextstate + 1; + state_status.resize(nstates, kWhiteState); + arc_iterator.resize(nstates); + } + // Visits respective arc types + if (filter(arc)) { + // Enqueues destination state and marks grey if white + if (state_status[arc.nextstate] == kWhiteState) { + visit = visitor->WhiteArc(s, arc); + if (!visit) continue; + visit = visitor->InitState(arc.nextstate, root); + state_status[arc.nextstate] = kGreyState; + queue->Enqueue(arc.nextstate); + } else if (state_status[arc.nextstate] == kBlackState) { + visit = visitor->BlackArc(s, arc); + } else { + visit = visitor->GreyArc(s, arc); + } + } + aiter->Next(); + // Destroys an iterator ASAP for efficiency. + if (aiter->Done()) { + delete aiter; + arc_iterator[s] = 0; + state_status[s] |= kArcIterDone; + } + } + // Finds next tree root + for (root = root == start ? 0 : root + 1; + root < nstates && state_status[root] != kWhiteState; + ++root) { + } + + // Check for a state beyond the largest known state + if (!expanded && root == nstates) { + for (; !siter.Done(); siter.Next()) { + if (siter.Value() == nstates) { + ++nstates; + state_status.push_back(kWhiteState); + arc_iterator.push_back(0); + break; + } + } + } + } + visitor->FinishVisit(); +} + + +template <class Arc, class V, class Q> +inline void Visit(const Fst<Arc> &fst, V *visitor, Q* queue) { + Visit(fst, visitor, queue, AnyArcFilter<Arc>()); +} + +// Copies input FST to mutable FST following queue order. +template <class A> +class CopyVisitor { + public: + typedef A Arc; + typedef typename A::StateId StateId; + + CopyVisitor(MutableFst<Arc> *ofst) : ifst_(0), ofst_(ofst) {} + + void InitVisit(const Fst<A> &ifst) { + ifst_ = &ifst; + ofst_->DeleteStates(); + ofst_->SetStart(ifst_->Start()); + } + + bool InitState(StateId s, StateId) { + while (ofst_->NumStates() <= s) + ofst_->AddState(); + return true; + } + + bool WhiteArc(StateId s, const Arc &arc) { + ofst_->AddArc(s, arc); + return true; + } + + bool GreyArc(StateId s, const Arc &arc) { + ofst_->AddArc(s, arc); + return true; + } + + bool BlackArc(StateId s, const Arc &arc) { + ofst_->AddArc(s, arc); + return true; + } + + void FinishState(StateId s) { + ofst_->SetFinal(s, ifst_->Final(s)); + } + + void FinishVisit() {} + + private: + const Fst<Arc> *ifst_; + MutableFst<Arc> *ofst_; +}; + + +// Visits input FST up to a state limit following queue order. If +// 'access_only' is true, aborts on visiting first state not +// accessible from the initial state. +template <class A> +class PartialVisitor { + public: + typedef A Arc; + typedef typename A::StateId StateId; + + explicit PartialVisitor(StateId maxvisit, bool access_only = false) + : maxvisit_(maxvisit), + access_only_(access_only), + start_(kNoStateId) {} + + void InitVisit(const Fst<A> &ifst) { + nvisit_ = 0; + start_ = ifst.Start(); + } + + bool InitState(StateId s, StateId root) { + if (access_only_ && root != start_) + return false; + ++nvisit_; + return nvisit_ <= maxvisit_; + } + + bool WhiteArc(StateId s, const Arc &arc) { return true; } + bool GreyArc(StateId s, const Arc &arc) { return true; } + bool BlackArc(StateId s, const Arc &arc) { return true; } + void FinishState(StateId s) {} + void FinishVisit() {} + + private: + StateId maxvisit_; + bool access_only_; + StateId nvisit_; + StateId start_; + +}; + + +} // namespace fst + +#endif // FST_LIB_VISIT_H__ diff --git a/kaldi_io/src/tools/openfst/include/fst/weight.h b/kaldi_io/src/tools/openfst/include/fst/weight.h new file mode 100644 index 0000000..7eb4bb1 --- /dev/null +++ b/kaldi_io/src/tools/openfst/include/fst/weight.h @@ -0,0 +1,179 @@ +// weight.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: [email protected] (Michael Riley) +// +// \file +// General weight set and associated semiring operation definitions. +// +// A semiring is specified by two binary operations Plus and Times and +// two designated elements Zero and One with the following properties: +// Plus: associative, commutative, and has Zero as its identity. +// Times: associative and has identity One, distributes w.r.t. Plus, and +// has Zero as an annihilator: +// Times(Zero(), a) == Times(a, Zero()) = Zero(). +// +// A left semiring distributes on the left; a right semiring is +// similarly defined. +// +// A Weight class must have binary functions =Plus= and =Times= and +// static member functions =Zero()= and =One()= and these must form +// (at least) a left or right semiring. +// +// In addition, the following should be defined for a Weight: +// Member: predicate on set membership. +// NoWeight: static member function that returns an element that is +// not a set member; used to signal an error. +// >>: reads textual representation of a weight. +// <<: prints textual representation of a weight. +// Read(istream &strm): reads binary representation of a weight. +// Write(ostream &strm): writes binary representation of a weight. +// Hash: maps weight to size_t. +// ApproxEqual: approximate equality (for inexact weights) +// Quantize: quantizes wrt delta (for inexact weights) +// Divide: for all a,b,c s.t. Times(a, b) == c +// --> b' = Divide(c, a, DIVIDE_LEFT) if a left semiring, b'.Member() +// and Times(a, b') == c +// --> a' = Divide(c, b, DIVIDE_RIGHT) if a right semiring, a'.Member() +// and Times(a', b) == c +// --> b' = Divide(c, a) = Divide(c, a, DIVIDE_ANY) = +// Divide(c, a, DIVIDE_LEFT) = Divide(c, a, DIVIDE_RIGHT) if a +// commutative semiring, b'.Member() and Times(a, b') = Times(b', a) = c +// ReverseWeight: the type of the corresponding reverse weight. +// Typically the same type as Weight for a (both left and right) semiring. +// For the left string semiring, it is the right string semiring. +// Reverse: a mapping from Weight to ReverseWeight s.t. +// --> Reverse(Reverse(a)) = a +// --> Reverse(Plus(a, b)) = Plus(Reverse(a), Reverse(b)) +// --> Reverse(Times(a, b)) = Times(Reverse(b), Reverse(a)) +// Typically the identity mapping in a (both left and right) semiring. +// In the left string semiring, it maps to the reverse string +// in the right string semiring. +// Properties: specifies additional properties that hold: +// LeftSemiring: indicates weights form a left semiring. +// RightSemiring: indicates weights form a right semiring. +// Commutative: for all a,b: Times(a,b) == Times(b,a) +// Idempotent: for all a: Plus(a, a) == a. +// Path: for all a, b: Plus(a, b) == a or Plus(a, b) == b. + + +#ifndef FST_LIB_WEIGHT_H__ +#define FST_LIB_WEIGHT_H__ + +#include <cmath> +#include <cctype> +#include <iostream> +#include <sstream> + +#include <fst/compat.h> + +#include <fst/util.h> + + +namespace fst { + +// +// CONSTANT DEFINITIONS +// + +// A representable float near .001 +const float kDelta = 1.0F/1024.0F; + +// For all a,b,c: Times(c, Plus(a,b)) = Plus(Times(c,a), Times(c, b)) +const uint64 kLeftSemiring = 0x0000000000000001ULL; + +// For all a,b,c: Times(Plus(a,b), c) = Plus(Times(a,c), Times(b, c)) +const uint64 kRightSemiring = 0x0000000000000002ULL; + +const uint64 kSemiring = kLeftSemiring | kRightSemiring; + +// For all a,b: Times(a,b) = Times(b,a) +const uint64 kCommutative = 0x0000000000000004ULL; + +// For all a: Plus(a, a) = a +const uint64 kIdempotent = 0x0000000000000008ULL; + +// For all a,b: Plus(a,b) = a or Plus(a,b) = b +const uint64 kPath = 0x0000000000000010ULL; + + +// Determines direction of division. +enum DivideType { DIVIDE_LEFT, // left division + DIVIDE_RIGHT, // right division + DIVIDE_ANY }; // division in a commutative semiring + +// NATURAL ORDER +// +// By definition: +// a <= b iff a + b = a +// The natural order is a negative partial order iff the semiring is +// idempotent. It is trivially monotonic for plus. It is left +// (resp. right) monotonic for times iff the semiring is left +// (resp. right) distributive. It is a total order iff the semiring +// has the path property. See Mohri, "Semiring Framework and +// Algorithms for Shortest-Distance Problems", Journal of Automata, +// Languages and Combinatorics 7(3):321-350, 2002. We define the +// strict version of this order below. + +template <class W> +class NaturalLess { + public: + typedef W Weight; + + NaturalLess() { + if (!(W::Properties() & kIdempotent)) { + FSTERROR() << "NaturalLess: Weight type is not idempotent: " + << W::Type(); + } + } + + bool operator()(const W &w1, const W &w2) const { + return (Plus(w1, w2) == w1) && w1 != w2; + } +}; + + +// Power is the iterated product for arbitrary semirings such that +// Power(w, 0) is One() for the semiring, and +// Power(w, n) = Times(Power(w, n-1), w) + +template <class W> +W Power(W w, size_t n) { + W result = W::One(); + for (size_t i = 0; i < n; ++i) { + result = Times(result, w); + } + return result; +} + +// General weight converter - raises error. +template <class W1, class W2> +struct WeightConvert { + W2 operator()(W1 w1) const { + FSTERROR() << "WeightConvert: can't convert weight from \"" + << W1::Type() << "\" to \"" << W2::Type(); + return W2::NoWeight(); + } +}; + +// Specialized weight converter to self. +template <class W> +struct WeightConvert<W, W> { + W operator()(W w) const { return w; } +}; + +} // namespace fst + +#endif // FST_LIB_WEIGHT_H__ diff --git a/kaldi_io/tools/kaldi_to_nerv.cpp b/kaldi_io/tools/kaldi_to_nerv.cpp new file mode 100644 index 0000000..1edb0f2 --- /dev/null +++ b/kaldi_io/tools/kaldi_to_nerv.cpp @@ -0,0 +1,109 @@ +#include <cstdio> +#include <fstream> +#include <string> +#include <cstring> +#include <cassert> + +char token[1024]; +char output[1024]; +double mat[4096][4096]; +int main(int argc, char **argv) { + std::ofstream fout; + fout.open(argv[1]); + int cnt = 0; + bool shift; + while (scanf("%s", token) != EOF) + { + int nrow, ncol; + int i, j; + if (strcmp(token, "<AffineTransform>") == 0) + { + double lrate, blrate, mnorm; + scanf("%d %d", &ncol, &nrow); + scanf("%s %lf %s %lf %s %lf", + token, &lrate, token, &blrate, token, &mnorm); + scanf("%s", token); + assert(*token == '['); + printf("%d %d\n", nrow, ncol); + for (j = 0; j < ncol; j++) + for (i = 0; i < nrow; i++) + scanf("%lf", mat[i] + j); + long base = fout.tellp(); + sprintf(output, "%16d", 0); + fout << output; + sprintf(output, "{type=\"nerv.LinearTransParam\",id=\"affine%d_ltp\"}\n", + cnt); + fout << output; + sprintf(output, "%d %d\n", nrow, ncol); + fout << output; + for (i = 0; i < nrow; i++) + { + for (j = 0; j < ncol; j++) + fout << mat[i][j] << " "; + fout << std::endl; + } + long length = fout.tellp() - base; + fout.seekp(base); + sprintf(output, "[%13lu]\n", length); + fout << output; + fout.seekp(0, std::ios_base::end); + scanf("%s", token); + assert(*token == ']'); + if (scanf("%s", token) == 1 && *token == '[') + { + base = fout.tellp(); + for (j = 0; j < ncol; j++) + scanf("%lf", mat[0] + j); + sprintf(output, "%16d", 0); + fout << output; + sprintf(output, "{type=\"nerv.BiasParam\",id=\"affine%d_bp\"}\n", + cnt); + fout << output; + sprintf(output, "1 %d\n", ncol); + fout << output; + for (j = 0; j < ncol; j++) + fout << mat[0][j] << " "; + fout << std::endl; + length = fout.tellp() - base; + fout.seekp(base); + sprintf(output, "[%13lu]\n", length); + fout << output; + fout.seekp(0, std::ios_base::end); + cnt++; + } + } + else if ((shift = (strcmp(token, "<AddShift>") == 0)) || + strcmp(token, "<Rescale>") == 0) + { + double lrate, blrate, mnorm; + scanf("%d %d", &ncol, &ncol); + scanf("%s %lf", + token, &lrate); + scanf("%s", token); + assert(*token == '['); + printf("%d\n", ncol); + for (j = 0; j < ncol; j++) + scanf("%lf", mat[0] + j); + long base = fout.tellp(); + sprintf(output, "%16d", 0); + fout << output; + sprintf(output, "{type=\"nerv.BiasParam\",id=\"%s%d\"}\n", + shift ? "bias" : "window", + cnt); + fout << output; + sprintf(output, "%d %d\n", 1, ncol); + fout << output; + for (j = 0; j < ncol; j++) + fout << mat[0][j] << " "; + fout << std::endl; + long length = fout.tellp() - base; + fout.seekp(base); + sprintf(output, "[%13lu]\n", length); + fout << output; + fout.seekp(0, std::ios_base::end); + scanf("%s", token); + assert(*token == ']'); + } + } + return 0; +} |