diff options
author | Ted Yin <[email protected]> | 2015-08-31 16:16:00 +0800 |
---|---|---|
committer | Ted Yin <[email protected]> | 2015-08-31 16:16:00 +0800 |
commit | 014bbfb7e64999a75f9e0dc52267a36741281624 (patch) | |
tree | 127888bcefaf8ce4991bb4c173d6538f1172f35f /htk_io/init.lua | |
parent | 9e1a0931be43ea80fe7d41154007839b637d4e08 (diff) | |
parent | 2196e0a591b9bc254aa95e180adf188fd70ded68 (diff) |
Merge pull request #4 from uphantom/master
support fastnn multi-thread TNetReader
Diffstat (limited to 'htk_io/init.lua')
-rw-r--r-- | htk_io/init.lua | 30 |
1 files changed, 28 insertions, 2 deletions
diff --git a/htk_io/init.lua b/htk_io/init.lua index b360b67..677d3e9 100644 --- a/htk_io/init.lua +++ b/htk_io/init.lua @@ -1,8 +1,9 @@ require 'libhtkio' require 'speech_utils' +require 'threads' local TNetReader = nerv.class("nerv.TNetReader", "nerv.DataReader") -function TNetReader:__init(global_conf, reader_conf) +function TNetReader:__init(global_conf, reader_conf, feat_repo_shareid, data_mutex_shareid) self.feat_id = reader_conf.id self.frm_ext = reader_conf.frm_ext self.gconf = global_conf @@ -10,9 +11,22 @@ function TNetReader:__init(global_conf, reader_conf) if self.debug == nil then self.debug = false end - self.feat_repo = nerv.TNetFeatureRepo(reader_conf.scp_file, + + if feat_repo_shareid ~= nil then + self.feat_repo = nerv.TNetFeatureRepo(feat_repo_shareid) + else + self.feat_repo = nerv.TNetFeatureRepo(reader_conf.scp_file, reader_conf.conf_file, reader_conf.frm_ext) + end + + --print(self.feat_repo) + self.data_mutex = nil + if data_mutex_shareid ~= nil then + self.data_mutex = threads.Mutex(data_mutex_shareid) + end + + self.lab_repo = {} if reader_conf.mlfs then for id, mlf_spec in pairs(reader_conf.mlfs) do @@ -26,7 +40,14 @@ function TNetReader:__init(global_conf, reader_conf) end function TNetReader:get_data() + if self.data_mutex ~= nil then + self.data_mutex:lock() + end + if self.feat_repo:is_end() then + if self.data_mutex ~= nil then + self.data_mutex:unlock() + end return nil end local res = {} @@ -66,6 +87,11 @@ function TNetReader:get_data() end -- move the pointer to next self.feat_repo:next() + + if self.data_mutex ~= nil then + self.data_mutex:unlock() + end + collectgarbage("collect") return res end |