summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitmodules3
-rw-r--r--Makefile7
m---------Penlight0
-rw-r--r--nerv/examples/asr_trainer.lua17
-rw-r--r--nerv/init.lua28
-rw-r--r--nerv/nerv-scm-1.rockspec2
6 files changed, 51 insertions, 6 deletions
diff --git a/.gitmodules b/.gitmodules
index 9f556c5..acce7f3 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -4,3 +4,6 @@
[submodule "luarocks"]
path = luarocks
url = https://github.com/keplerproject/luarocks.git
+[submodule "Penlight"]
+ path = Penlight
+ url = https://github.com/stevedonovan/Penlight.git
diff --git a/Makefile b/Makefile
index 1711b3c..28012da 100644
--- a/Makefile
+++ b/Makefile
@@ -23,9 +23,10 @@ export KALDI_BASE
export BLAS_LDFLAGS
.PHONY: nerv speech/speech_utils speech/htk_io speech/kaldi_io speech/kaldi_decode \
- nerv-clean speech/speech_utils-clean speech/htk_io-clean speech/kaldi_io-clean speech/kaldi_decode-clean
+ nerv-clean speech/speech_utils-clean speech/htk_io-clean speech/kaldi_io-clean speech/kaldi_decode-clean \
+ Penlight
-all: luajit luarocks nerv
+all: luajit luarocks Penlight nerv
luajit:
PREFIX=$(PREFIX) ./tools/build_luajit.sh
luarocks:
@@ -35,7 +36,7 @@ speech: speech/speech_utils speech/htk_io speech/kaldi_io speech/kaldi_decode
speech-clean: speech/speech_utils-clean speech/htk_io-clean speech/kaldi_io-clean speech/kaldi_decode-clean
clean: nerv-clean speech-clean
-nerv speech/speech_utils speech/htk_io speech/kaldi_io speech/kaldi_decode:
+nerv Penlight speech/speech_utils speech/htk_io speech/kaldi_io speech/kaldi_decode:
cd $@; $(PREFIX)/bin/luarocks make
nerv-clean speech/speech_utils-clean speech/htk_io-clean speech/kaldi_io-clean speech/kaldi_decode-clean:
cd $(subst -clean,,$@); make clean LUA_BINDIR=$(PREFIX)/bin/
diff --git a/Penlight b/Penlight
new file mode 160000
+Subproject 16d149338af9efc910528641c5240c5641aeb8d
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 684ea30..5001e12 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -1,3 +1,5 @@
+require 'lfs'
+require 'pl'
local function build_trainer(ifname)
local param_repo = nerv.ParamRepo()
param_repo:import(ifname, nil, gconf)
@@ -133,6 +135,8 @@ local trainer_defaults = {
local options = make_options(trainer_defaults)
table.insert(options, {"help", "h", "boolean",
default = false, desc = "show this help information"})
+table.insert(options, {"dir", nil, "string",
+ default = nil, desc = "specify the working directory"})
arg, opts = nerv.parse_args(arg, options)
@@ -156,8 +160,19 @@ check_and_add_defaults(trainer_defaults)
local pf0 = gconf.initialized_param
local trainer = build_trainer(pf0)
local accu_best = trainer(nil, gconf.cv_scp, false)
+local date_pattern = "%Y%m%d%H%M%S"
+local logfile_name = "log"
+local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
print_gconf()
+if not lfs.mkdir(working_dir) then
+ nerv.error("[asr_trainer] working directory already exists")
+end
+-- copy the network config
+dir.copyfile(arg[1], working_dir)
+-- set logfile path
+nerv.set_logfile(path.join(working_dir, logfile_name))
+path.chdir(working_dir)
nerv.info("initial cross validation: %.3f", accu_best)
for i = 1, gconf.max_iter do
nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
@@ -168,7 +183,7 @@ for i = 1, gconf.max_iter do
string.gsub(
(string.gsub(pf0[1], "(.*/)(.*)", "%2")),
"(.*)%..*", "%1"),
- os.date("%Y%m%d%H%M%S"),
+ os.date(date_pattern),
i, gconf.lrate,
accu_tr),
gconf.cv_scp, false)
diff --git a/nerv/init.lua b/nerv/init.lua
index a5b032c..06cb611 100644
--- a/nerv/init.lua
+++ b/nerv/init.lua
@@ -13,6 +13,10 @@ function nerv.error_method_not_implemented()
nerv.error("method not implemented");
end
+function nerv.set_logfile(filename)
+ nerv._logfile = io.open(filename, "w")
+end
+
--- Format a string just like `sprintf` in C.
-- @param fmt the format string
-- @param ... args, the data to be formatted
@@ -25,7 +29,13 @@ end
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.printf(fmt, ...)
- io.stderr:write(nerv.sprintf(fmt, ...))
+ local line = nerv.sprintf(fmt, ...)
+ io.stderr:write(line)
+ -- duplicate the all output to the log file, if set
+ if nerv._logfile then
+ nerv._logfile:write(line)
+ nerv._logfile:flush()
+ end
end
--- Raise an global error with the formatted message.
@@ -328,6 +338,22 @@ function nerv.print_usage(options)
nerv.printf("\n")
end
+-- function nerv.copy_file(fname1, fname2)
+-- local fin, fout, err
+-- fin, err = io.open(fname1, "r")
+-- if fin then
+-- fout, err = io.open(fname2, "w")
+-- end
+-- if not (fin and fout) then
+-- nerv.error("[copy] from %s to %s: %s", fname1, fname2, err)
+-- end
+-- while true do
+-- local b = fin:read(1024)
+-- if b == nil then break end
+-- fout:write(b)
+-- end
+-- end
+
-- the following lines trigger the initialization of basic modules
nerv.include('matrix/init.lua')
diff --git a/nerv/nerv-scm-1.rockspec b/nerv/nerv-scm-1.rockspec
index 9dbe771..d039e85 100644
--- a/nerv/nerv-scm-1.rockspec
+++ b/nerv/nerv-scm-1.rockspec
@@ -12,7 +12,7 @@ description = {
}
dependencies = {
"lua >= 5.1",
- "luafilesystem >= 1.6.3"
+ "penlight >= 1.3.2"
}
build = {
type = "make",