summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2016-01-13 12:05:48 +0800
committerDeterminant <ted.sybil@gmail.com>2016-01-13 12:05:48 +0800
commit3f742c01540b8bad940d66331d562082a63d975b (patch)
tree25ed764f4ac9ee6694d8bff3745eafe5da5ae827
parentf9b78f6bc35cd5c0c117a5e523ef6aab96dee5c3 (diff)
support pure CPU reader
-rw-r--r--htk_io/init.lua20
-rw-r--r--kaldi_io/Makefile2
-rw-r--r--kaldi_io/init.lua20
-rw-r--r--speech_utils/init.lua9
4 files changed, 40 insertions, 11 deletions
diff --git a/htk_io/init.lua b/htk_io/init.lua
index b360b67..1cdabf1 100644
--- a/htk_io/init.lua
+++ b/htk_io/init.lua
@@ -6,6 +6,11 @@ function TNetReader:__init(global_conf, reader_conf)
self.feat_id = reader_conf.id
self.frm_ext = reader_conf.frm_ext
self.gconf = global_conf
+ if self.gconf.use_cpu then
+ self.mat_type = self.gconf.mmat_type
+ else
+ self.mat_type = self.gconf.cumat_type
+ end
self.debug = global_conf.debug
if self.debug == nil then
self.debug = false
@@ -31,12 +36,15 @@ function TNetReader:get_data()
end
local res = {}
-- read HTK feature
- local raw = self.gconf.cumat_type.new_from_host(self.feat_repo:cur_utter(self.debug))
+ local raw = self.feat_repo:cur_utter(self.debug)
+ if not self.gconf.use_cpu then
+ raw = self.gconf.cumat_type.new_from_host(raw)
+ end
local rearranged
if self.frm_ext and self.frm_ext > 0 then
local step = self.frm_ext * 2 + 1
-- expand the feature
- local expanded = self.gconf.cumat_type(raw:nrow(), raw:ncol() * step)
+ local expanded = self.mat_type(raw:nrow(), raw:ncol() * step)
expanded:expand_frm(raw, self.frm_ext)
-- rearrange the feature (``transpose'' operation in TNet)
if self.gconf.rearrange then
@@ -53,8 +61,12 @@ function TNetReader:get_data()
feat_utter = self.gconf.mmat_type(rearranged:nrow() - self.gconf.frm_trim * 2, rearranged:ncol())
rearranged:copy_toh(feat_utter, self.gconf.frm_trim, rearranged:nrow() - self.gconf.frm_trim)
else
- feat_utter = self.gconf.mmat_type(rearranged:nrow(), rearranged:ncol())
- rearranged:copy_toh(feat_utter)
+ if self.gconf.use_cpu then
+ feat_utter = rearranged
+ else
+ feat_utter = self.gconf.mmat_type(rearranged:nrow(), rearranged:ncol())
+ rearranged:copy_toh(feat_utter)
+ end
end
res[self.feat_id] = feat_utter
-- add corresponding labels
diff --git a/kaldi_io/Makefile b/kaldi_io/Makefile
index 7b0c0bd..1066fc5 100644
--- a/kaldi_io/Makefile
+++ b/kaldi_io/Makefile
@@ -1,5 +1,5 @@
# Change KDIR to `kaldi-trunk' path (Kaldi must be compiled with --share)
-KDIR := /slfs6/users/ymz09/kaldi/
+KDIR := /home/stuymf/kaldi-trunk/
SHELL := /bin/bash
BUILD_DIR := $(CURDIR)/build
diff --git a/kaldi_io/init.lua b/kaldi_io/init.lua
index e538ee5..2173230 100644
--- a/kaldi_io/init.lua
+++ b/kaldi_io/init.lua
@@ -7,6 +7,11 @@ function KaldiReader:__init(global_conf, reader_conf)
self.frm_ext = reader_conf.frm_ext
self.need_key = reader_conf.need_key -- for sequence training
self.gconf = global_conf
+ if self.gconf.use_cpu then
+ self.mat_type = self.gconf.mmat_type
+ else
+ self.mat_type = self.gconf.cumat_type
+ end
self.debug = global_conf.debug
if self.debug == nil then
self.debug = false
@@ -41,12 +46,15 @@ function KaldiReader:get_data()
end
local res = {}
-- read Kaldi feature
- local raw = self.gconf.cumat_type.new_from_host(self.feat_repo:cur_utter(self.debug))
+ local raw = self.feat_repo:cur_utter(self.debug)
+ if not self.gconf.use_cpu then
+ raw = self.gconf.cumat_type.new_from_host(raw)
+ end
local rearranged
if self.frm_ext and self.frm_ext > 0 then
local step = self.frm_ext * 2 + 1
-- expand the feature
- local expanded = self.gconf.cumat_type(raw:nrow(), raw:ncol() * step)
+ local expanded = self.mat_type(raw:nrow(), raw:ncol() * step)
expanded:expand_frm(raw, self.frm_ext)
-- rearrange the feature (``transpose'' operation in TNet)
if self.gconf.rearrange then
@@ -63,8 +71,12 @@ function KaldiReader:get_data()
feat_utter = self.gconf.mmat_type(rearranged:nrow() - self.gconf.frm_trim * 2, rearranged:ncol())
rearranged:copy_toh(feat_utter, self.gconf.frm_trim, rearranged:nrow() - self.gconf.frm_trim)
else
- feat_utter = self.gconf.mmat_type(rearranged:nrow(), rearranged:ncol())
- rearranged:copy_toh(feat_utter)
+ if self.gconf.use_cpu then
+ feat_utter = rearranged
+ else
+ feat_utter = self.gconf.mmat_type(rearranged:nrow(), rearranged:ncol())
+ rearranged:copy_toh(feat_utter)
+ end
end
res[self.feat_id] = feat_utter
if self.need_key then
diff --git a/speech_utils/init.lua b/speech_utils/init.lua
index f89f4fd..9e8adba 100644
--- a/speech_utils/init.lua
+++ b/speech_utils/init.lua
@@ -9,8 +9,13 @@ function nerv.speech_utils.global_transf(feat_utter, global_transf,
global_transf:init(input[1]:nrow())
global_transf:propagate(input, output)
-- trim frames
- expanded = gconf.cumat_type(output[1]:nrow() - frm_trim * 2, output[1]:ncol())
- expanded:copy_fromd(output[1], frm_trim, feat_utter:nrow() - frm_trim)
+ if gconf.use_cpu then
+ mat_type = gconf.mmat_type
+ else
+ mat_type = gconf.cumat_type
+ end
+ expanded = mat_type(output[1]:nrow() - frm_trim * 2, output[1]:ncol())
+ expanded:copy_from(output[1], frm_trim, feat_utter:nrow() - frm_trim)
collectgarbage("collect")
return expanded
end