summaryrefslogtreecommitdiff
path: root/kaldi_io/init.lua
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2016-01-13 12:05:48 +0800
committerDeterminant <[email protected]>2016-01-13 12:05:48 +0800
commit3f742c01540b8bad940d66331d562082a63d975b (patch)
tree25ed764f4ac9ee6694d8bff3745eafe5da5ae827 /kaldi_io/init.lua
parentf9b78f6bc35cd5c0c117a5e523ef6aab96dee5c3 (diff)
support pure CPU reader
Diffstat (limited to 'kaldi_io/init.lua')
-rw-r--r--kaldi_io/init.lua20
1 files changed, 16 insertions, 4 deletions
diff --git a/kaldi_io/init.lua b/kaldi_io/init.lua
index e538ee5..2173230 100644
--- a/kaldi_io/init.lua
+++ b/kaldi_io/init.lua
@@ -7,6 +7,11 @@ function KaldiReader:__init(global_conf, reader_conf)
self.frm_ext = reader_conf.frm_ext
self.need_key = reader_conf.need_key -- for sequence training
self.gconf = global_conf
+ if self.gconf.use_cpu then
+ self.mat_type = self.gconf.mmat_type
+ else
+ self.mat_type = self.gconf.cumat_type
+ end
self.debug = global_conf.debug
if self.debug == nil then
self.debug = false
@@ -41,12 +46,15 @@ function KaldiReader:get_data()
end
local res = {}
-- read Kaldi feature
- local raw = self.gconf.cumat_type.new_from_host(self.feat_repo:cur_utter(self.debug))
+ local raw = self.feat_repo:cur_utter(self.debug)
+ if not self.gconf.use_cpu then
+ raw = self.gconf.cumat_type.new_from_host(raw)
+ end
local rearranged
if self.frm_ext and self.frm_ext > 0 then
local step = self.frm_ext * 2 + 1
-- expand the feature
- local expanded = self.gconf.cumat_type(raw:nrow(), raw:ncol() * step)
+ local expanded = self.mat_type(raw:nrow(), raw:ncol() * step)
expanded:expand_frm(raw, self.frm_ext)
-- rearrange the feature (``transpose'' operation in TNet)
if self.gconf.rearrange then
@@ -63,8 +71,12 @@ function KaldiReader:get_data()
feat_utter = self.gconf.mmat_type(rearranged:nrow() - self.gconf.frm_trim * 2, rearranged:ncol())
rearranged:copy_toh(feat_utter, self.gconf.frm_trim, rearranged:nrow() - self.gconf.frm_trim)
else
- feat_utter = self.gconf.mmat_type(rearranged:nrow(), rearranged:ncol())
- rearranged:copy_toh(feat_utter)
+ if self.gconf.use_cpu then
+ feat_utter = rearranged
+ else
+ feat_utter = self.gconf.mmat_type(rearranged:nrow(), rearranged:ncol())
+ rearranged:copy_toh(feat_utter)
+ end
end
res[self.feat_id] = feat_utter
if self.need_key then