diff options
-rw-r--r-- | .gitmodules | 3 | ||||
-rw-r--r-- | Makefile | 7 | ||||
m--------- | Penlight | 0 | ||||
-rw-r--r-- | nerv/examples/asr_trainer.lua | 17 | ||||
-rw-r--r-- | nerv/init.lua | 28 | ||||
-rw-r--r-- | nerv/nerv-scm-1.rockspec | 2 |
6 files changed, 51 insertions, 6 deletions
diff --git a/.gitmodules b/.gitmodules index 9f556c5..acce7f3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "luarocks"] path = luarocks url = https://github.com/keplerproject/luarocks.git +[submodule "Penlight"] + path = Penlight + url = https://github.com/stevedonovan/Penlight.git @@ -23,9 +23,10 @@ export KALDI_BASE export BLAS_LDFLAGS .PHONY: nerv speech/speech_utils speech/htk_io speech/kaldi_io speech/kaldi_decode \ - nerv-clean speech/speech_utils-clean speech/htk_io-clean speech/kaldi_io-clean speech/kaldi_decode-clean + nerv-clean speech/speech_utils-clean speech/htk_io-clean speech/kaldi_io-clean speech/kaldi_decode-clean \ + Penlight -all: luajit luarocks nerv +all: luajit luarocks Penlight nerv luajit: PREFIX=$(PREFIX) ./tools/build_luajit.sh luarocks: @@ -35,7 +36,7 @@ speech: speech/speech_utils speech/htk_io speech/kaldi_io speech/kaldi_decode speech-clean: speech/speech_utils-clean speech/htk_io-clean speech/kaldi_io-clean speech/kaldi_decode-clean clean: nerv-clean speech-clean -nerv speech/speech_utils speech/htk_io speech/kaldi_io speech/kaldi_decode: +nerv Penlight speech/speech_utils speech/htk_io speech/kaldi_io speech/kaldi_decode: cd $@; $(PREFIX)/bin/luarocks make nerv-clean speech/speech_utils-clean speech/htk_io-clean speech/kaldi_io-clean speech/kaldi_decode-clean: cd $(subst -clean,,$@); make clean LUA_BINDIR=$(PREFIX)/bin/ diff --git a/Penlight b/Penlight new file mode 160000 +Subproject 16d149338af9efc910528641c5240c5641aeb8d diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua index 684ea30..5001e12 100644 --- a/nerv/examples/asr_trainer.lua +++ b/nerv/examples/asr_trainer.lua @@ -1,3 +1,5 @@ +require 'lfs' +require 'pl' local function build_trainer(ifname) local param_repo = nerv.ParamRepo() param_repo:import(ifname, nil, gconf) @@ -133,6 +135,8 @@ local trainer_defaults = { local options = make_options(trainer_defaults) table.insert(options, {"help", "h", "boolean", default = false, desc = "show this help information"}) +table.insert(options, {"dir", nil, "string", + default = nil, desc = "specify the working directory"}) arg, opts = nerv.parse_args(arg, options) @@ -156,8 +160,19 @@ check_and_add_defaults(trainer_defaults) local pf0 = gconf.initialized_param local trainer = build_trainer(pf0) local accu_best = trainer(nil, gconf.cv_scp, false) +local date_pattern = "%Y%m%d%H%M%S" +local logfile_name = "log" +local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern)) print_gconf() +if not lfs.mkdir(working_dir) then + nerv.error("[asr_trainer] working directory already exists") +end +-- copy the network config +dir.copyfile(arg[1], working_dir) +-- set logfile path +nerv.set_logfile(path.join(working_dir, logfile_name)) +path.chdir(working_dir) nerv.info("initial cross validation: %.3f", accu_best) for i = 1, gconf.max_iter do nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate) @@ -168,7 +183,7 @@ for i = 1, gconf.max_iter do string.gsub( (string.gsub(pf0[1], "(.*/)(.*)", "%2")), "(.*)%..*", "%1"), - os.date("%Y%m%d%H%M%S"), + os.date(date_pattern), i, gconf.lrate, accu_tr), gconf.cv_scp, false) diff --git a/nerv/init.lua b/nerv/init.lua index a5b032c..06cb611 100644 --- a/nerv/init.lua +++ b/nerv/init.lua @@ -13,6 +13,10 @@ function nerv.error_method_not_implemented() nerv.error("method not implemented"); end +function nerv.set_logfile(filename) + nerv._logfile = io.open(filename, "w") +end + --- Format a string just like `sprintf` in C. -- @param fmt the format string -- @param ... args, the data to be formatted @@ -25,7 +29,13 @@ end -- @param fmt the format string -- @param ... args, the data to be formatted function nerv.printf(fmt, ...) - io.stderr:write(nerv.sprintf(fmt, ...)) + local line = nerv.sprintf(fmt, ...) + io.stderr:write(line) + -- duplicate the all output to the log file, if set + if nerv._logfile then + nerv._logfile:write(line) + nerv._logfile:flush() + end end --- Raise an global error with the formatted message. @@ -328,6 +338,22 @@ function nerv.print_usage(options) nerv.printf("\n") end +-- function nerv.copy_file(fname1, fname2) +-- local fin, fout, err +-- fin, err = io.open(fname1, "r") +-- if fin then +-- fout, err = io.open(fname2, "w") +-- end +-- if not (fin and fout) then +-- nerv.error("[copy] from %s to %s: %s", fname1, fname2, err) +-- end +-- while true do +-- local b = fin:read(1024) +-- if b == nil then break end +-- fout:write(b) +-- end +-- end + -- the following lines trigger the initialization of basic modules nerv.include('matrix/init.lua') diff --git a/nerv/nerv-scm-1.rockspec b/nerv/nerv-scm-1.rockspec index 9dbe771..d039e85 100644 --- a/nerv/nerv-scm-1.rockspec +++ b/nerv/nerv-scm-1.rockspec @@ -12,7 +12,7 @@ description = { } dependencies = { "lua >= 5.1", - "luafilesystem >= 1.6.3" + "penlight >= 1.3.2" } build = { type = "make", |