aboutsummaryrefslogtreecommitdiff
path: root/nerv
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2016-03-03 15:49:10 +0800
committerDeterminant <[email protected]>2016-03-03 15:49:10 +0800
commite1c004f89ec3a783cf1948165baae5975489f775 (patch)
treef6b630853eac2719ef290d5bb8b7573e4f2c8567 /nerv
parentc964d74b1927ffd9fd431f457f7e385b2d2ba5ba (diff)
add Penlight to facilitate file copying and dir making
Diffstat (limited to 'nerv')
-rw-r--r--nerv/examples/asr_trainer.lua17
-rw-r--r--nerv/init.lua28
-rw-r--r--nerv/nerv-scm-1.rockspec2
3 files changed, 44 insertions, 3 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 684ea30..5001e12 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -1,3 +1,5 @@
+require 'lfs'
+require 'pl'
local function build_trainer(ifname)
local param_repo = nerv.ParamRepo()
param_repo:import(ifname, nil, gconf)
@@ -133,6 +135,8 @@ local trainer_defaults = {
local options = make_options(trainer_defaults)
table.insert(options, {"help", "h", "boolean",
default = false, desc = "show this help information"})
+table.insert(options, {"dir", nil, "string",
+ default = nil, desc = "specify the working directory"})
arg, opts = nerv.parse_args(arg, options)
@@ -156,8 +160,19 @@ check_and_add_defaults(trainer_defaults)
local pf0 = gconf.initialized_param
local trainer = build_trainer(pf0)
local accu_best = trainer(nil, gconf.cv_scp, false)
+local date_pattern = "%Y%m%d%H%M%S"
+local logfile_name = "log"
+local working_dir = opts["dir"].val or string.format("nerv_%s", os.date(date_pattern))
print_gconf()
+if not lfs.mkdir(working_dir) then
+ nerv.error("[asr_trainer] working directory already exists")
+end
+-- copy the network config
+dir.copyfile(arg[1], working_dir)
+-- set logfile path
+nerv.set_logfile(path.join(working_dir, logfile_name))
+path.chdir(working_dir)
nerv.info("initial cross validation: %.3f", accu_best)
for i = 1, gconf.max_iter do
nerv.info("[NN] begin iteration %d with lrate = %.6f", i, gconf.lrate)
@@ -168,7 +183,7 @@ for i = 1, gconf.max_iter do
string.gsub(
(string.gsub(pf0[1], "(.*/)(.*)", "%2")),
"(.*)%..*", "%1"),
- os.date("%Y%m%d%H%M%S"),
+ os.date(date_pattern),
i, gconf.lrate,
accu_tr),
gconf.cv_scp, false)
diff --git a/nerv/init.lua b/nerv/init.lua
index a5b032c..06cb611 100644
--- a/nerv/init.lua
+++ b/nerv/init.lua
@@ -13,6 +13,10 @@ function nerv.error_method_not_implemented()
nerv.error("method not implemented");
end
+function nerv.set_logfile(filename)
+ nerv._logfile = io.open(filename, "w")
+end
+
--- Format a string just like `sprintf` in C.
-- @param fmt the format string
-- @param ... args, the data to be formatted
@@ -25,7 +29,13 @@ end
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.printf(fmt, ...)
- io.stderr:write(nerv.sprintf(fmt, ...))
+ local line = nerv.sprintf(fmt, ...)
+ io.stderr:write(line)
+ -- duplicate the all output to the log file, if set
+ if nerv._logfile then
+ nerv._logfile:write(line)
+ nerv._logfile:flush()
+ end
end
--- Raise an global error with the formatted message.
@@ -328,6 +338,22 @@ function nerv.print_usage(options)
nerv.printf("\n")
end
+-- function nerv.copy_file(fname1, fname2)
+-- local fin, fout, err
+-- fin, err = io.open(fname1, "r")
+-- if fin then
+-- fout, err = io.open(fname2, "w")
+-- end
+-- if not (fin and fout) then
+-- nerv.error("[copy] from %s to %s: %s", fname1, fname2, err)
+-- end
+-- while true do
+-- local b = fin:read(1024)
+-- if b == nil then break end
+-- fout:write(b)
+-- end
+-- end
+
-- the following lines trigger the initialization of basic modules
nerv.include('matrix/init.lua')
diff --git a/nerv/nerv-scm-1.rockspec b/nerv/nerv-scm-1.rockspec
index 9dbe771..d039e85 100644
--- a/nerv/nerv-scm-1.rockspec
+++ b/nerv/nerv-scm-1.rockspec
@@ -12,7 +12,7 @@ description = {
}
dependencies = {
"lua >= 5.1",
- "luafilesystem >= 1.6.3"
+ "penlight >= 1.3.2"
}
build = {
type = "make",