summaryrefslogtreecommitdiff
path: root/init.lua
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-06-03 09:42:05 +0800
committerDeterminant <[email protected]>2015-06-03 09:42:05 +0800
commit38962683e518dcbebc0cfa6c0c9c9616b25d5bd1 (patch)
treef62b9c670960004f00d0cfd860b925f487edcf9f /init.lua
parent0c6ca6a17f06821cd5d612f489ca6cb68c2c4d5b (diff)
add TNetReader
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua43
1 files changed, 43 insertions, 0 deletions
diff --git a/init.lua b/init.lua
new file mode 100644
index 0000000..1f20096
--- /dev/null
+++ b/init.lua
@@ -0,0 +1,43 @@
+require 'libspeech'
+local TNetReader = nerv.class("nerv.TNetReader", "nerv.DataReader")
+
+function TNetReader:__init(global_conf, reader_conf)
+ self.feat_id = reader_conf.id
+ self.frm_ext = reader_conf.frm_ext
+ self.feat_repo = nerv.TNetFeatureRepo(reader_conf.scp_file,
+ reader_conf.conf_file,
+ reader_conf.frm_ext)
+ self.lab_repo = {}
+ for id, mlf_spec in pairs(reader_conf.mlfs) do
+ self.lab_repo[id] = nerv.TNetLabelRepo(mlf_spec.file,
+ mlf_spec.format,
+ mlf_spec.format_arg,
+ mlf_spec.dir,
+ mlf_spec.ext)
+ end
+ self.global_transf = reader_conf.global_transf
+end
+
+function TNetReader:get_data()
+ local res = {}
+ local frm_ext = self.frm_ext
+ local step = frm_ext * 2 + 1
+ local feat_utter = self.feat_repo:cur_utter()
+ local expanded = nerv.CuMatrixFloat(feat_utter:nrow(), feat_utter:ncol() * step)
+ expanded:expand_frm(nerv.CuMatrixFloat.new_from_host(feat_utter), frm_ext)
+ local rearranged = expanded:create()
+ rearranged:rearrange_frm(expanded, step)
+ local input = {rearranged}
+ local output = {rearranged:create()}
+ self.global_transf:init()
+ self.global_transf:propagate(input, output)
+ expanded = nerv.CuMatrixFloat(output[1]:nrow() - frm_ext * 2, output[1]:ncol())
+ expanded:copy_fromd(output[1], frm_ext, feat_utter:nrow() - frm_ext)
+ res[self.feat_id] = expanded
+ for id, repo in pairs(self.lab_repo) do
+ local lab_utter = repo:get_utter(self.feat_repo, expanded:nrow())
+ res[id] = lab_utter
+ end
+ self.feat_repo:next()
+ return res
+end