summaryrefslogtreecommitdiff
path: root/htk_io/init.lua
diff options
context:
space:
mode:
authorTed Yin <[email protected]>2015-08-31 16:16:00 +0800
committerTed Yin <[email protected]>2015-08-31 16:16:00 +0800
commit014bbfb7e64999a75f9e0dc52267a36741281624 (patch)
tree127888bcefaf8ce4991bb4c173d6538f1172f35f /htk_io/init.lua
parent9e1a0931be43ea80fe7d41154007839b637d4e08 (diff)
parent2196e0a591b9bc254aa95e180adf188fd70ded68 (diff)
Merge pull request #4 from uphantom/master
support fastnn multi-thread TNetReader
Diffstat (limited to 'htk_io/init.lua')
-rw-r--r--htk_io/init.lua30
1 files changed, 28 insertions, 2 deletions
diff --git a/htk_io/init.lua b/htk_io/init.lua
index b360b67..677d3e9 100644
--- a/htk_io/init.lua
+++ b/htk_io/init.lua
@@ -1,8 +1,9 @@
require 'libhtkio'
require 'speech_utils'
+require 'threads'
local TNetReader = nerv.class("nerv.TNetReader", "nerv.DataReader")
-function TNetReader:__init(global_conf, reader_conf)
+function TNetReader:__init(global_conf, reader_conf, feat_repo_shareid, data_mutex_shareid)
self.feat_id = reader_conf.id
self.frm_ext = reader_conf.frm_ext
self.gconf = global_conf
@@ -10,9 +11,22 @@ function TNetReader:__init(global_conf, reader_conf)
if self.debug == nil then
self.debug = false
end
- self.feat_repo = nerv.TNetFeatureRepo(reader_conf.scp_file,
+
+ if feat_repo_shareid ~= nil then
+ self.feat_repo = nerv.TNetFeatureRepo(feat_repo_shareid)
+ else
+ self.feat_repo = nerv.TNetFeatureRepo(reader_conf.scp_file,
reader_conf.conf_file,
reader_conf.frm_ext)
+ end
+
+ --print(self.feat_repo)
+ self.data_mutex = nil
+ if data_mutex_shareid ~= nil then
+ self.data_mutex = threads.Mutex(data_mutex_shareid)
+ end
+
+
self.lab_repo = {}
if reader_conf.mlfs then
for id, mlf_spec in pairs(reader_conf.mlfs) do
@@ -26,7 +40,14 @@ function TNetReader:__init(global_conf, reader_conf)
end
function TNetReader:get_data()
+ if self.data_mutex ~= nil then
+ self.data_mutex:lock()
+ end
+
if self.feat_repo:is_end() then
+ if self.data_mutex ~= nil then
+ self.data_mutex:unlock()
+ end
return nil
end
local res = {}
@@ -66,6 +87,11 @@ function TNetReader:get_data()
end
-- move the pointer to next
self.feat_repo:next()
+
+ if self.data_mutex ~= nil then
+ self.data_mutex:unlock()
+ end
+
collectgarbage("collect")
return res
end