summaryrefslogtreecommitdiff
path: root/kaldi_io
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io')
-rw-r--r--kaldi_io/Makefile2
-rw-r--r--kaldi_io/init.lua20
2 files changed, 17 insertions, 5 deletions
diff --git a/kaldi_io/Makefile b/kaldi_io/Makefile
index 7b0c0bd..1066fc5 100644
--- a/kaldi_io/Makefile
+++ b/kaldi_io/Makefile
@@ -1,5 +1,5 @@
# Change KDIR to `kaldi-trunk' path (Kaldi must be compiled with --share)
-KDIR := /slfs6/users/ymz09/kaldi/
+KDIR := /home/stuymf/kaldi-trunk/
SHELL := /bin/bash
BUILD_DIR := $(CURDIR)/build
diff --git a/kaldi_io/init.lua b/kaldi_io/init.lua
index e538ee5..2173230 100644
--- a/kaldi_io/init.lua
+++ b/kaldi_io/init.lua
@@ -7,6 +7,11 @@ function KaldiReader:__init(global_conf, reader_conf)
self.frm_ext = reader_conf.frm_ext
self.need_key = reader_conf.need_key -- for sequence training
self.gconf = global_conf
+ if self.gconf.use_cpu then
+ self.mat_type = self.gconf.mmat_type
+ else
+ self.mat_type = self.gconf.cumat_type
+ end
self.debug = global_conf.debug
if self.debug == nil then
self.debug = false
@@ -41,12 +46,15 @@ function KaldiReader:get_data()
end
local res = {}
-- read Kaldi feature
- local raw = self.gconf.cumat_type.new_from_host(self.feat_repo:cur_utter(self.debug))
+ local raw = self.feat_repo:cur_utter(self.debug)
+ if not self.gconf.use_cpu then
+ raw = self.gconf.cumat_type.new_from_host(raw)
+ end
local rearranged
if self.frm_ext and self.frm_ext > 0 then
local step = self.frm_ext * 2 + 1
-- expand the feature
- local expanded = self.gconf.cumat_type(raw:nrow(), raw:ncol() * step)
+ local expanded = self.mat_type(raw:nrow(), raw:ncol() * step)
expanded:expand_frm(raw, self.frm_ext)
-- rearrange the feature (``transpose'' operation in TNet)
if self.gconf.rearrange then
@@ -63,8 +71,12 @@ function KaldiReader:get_data()
feat_utter = self.gconf.mmat_type(rearranged:nrow() - self.gconf.frm_trim * 2, rearranged:ncol())
rearranged:copy_toh(feat_utter, self.gconf.frm_trim, rearranged:nrow() - self.gconf.frm_trim)
else
- feat_utter = self.gconf.mmat_type(rearranged:nrow(), rearranged:ncol())
- rearranged:copy_toh(feat_utter)
+ if self.gconf.use_cpu then
+ feat_utter = rearranged
+ else
+ feat_utter = self.gconf.mmat_type(rearranged:nrow(), rearranged:ncol())
+ rearranged:copy_toh(feat_utter)
+ end
end
res[self.feat_id] = feat_utter
if self.need_key then