diff options
author | cloudygoose <[email protected]> | 2015-06-05 21:40:45 +0800 |
---|---|---|
committer | cloudygoose <[email protected]> | 2015-06-05 21:40:45 +0800 |
commit | 5b4cc22736ade93f4d8348513c4a35f6a9f9be04 (patch) | |
tree | 255fbddedcdb25b88f4a70268cb6b1ffbaa5afed /examples/test_nn_lib.lua | |
parent | 90f2b7c257c286e6c52432ed43807f332d97cc7e (diff) | |
parent | 37af4bed9c3680fdb9db569605f15013e9b6b64d (diff) |
...
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'examples/test_nn_lib.lua')
-rw-r--r-- | examples/test_nn_lib.lua | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/examples/test_nn_lib.lua b/examples/test_nn_lib.lua index 04fd7d6..6fdbd67 100644 --- a/examples/test_nn_lib.lua +++ b/examples/test_nn_lib.lua @@ -117,7 +117,7 @@ tnet_reader = nerv.TNetReader(gconf, buffer = nerv.SGDBuffer(gconf, { buffer_size = 81920, - -- randomize = true, + randomize = true, readers = { { reader = tnet_reader, data = {main_scp = 429, ref = 1}} @@ -128,9 +128,12 @@ sm = sublayer_repo:get_layer("softmax_ce0") main = layer_repo:get_layer("main") main:init(gconf.batch_size) gconf.cnt = 0 +-- data = buffer:get_data() +-- input = {data.main_scp, data.ref} +-- while true do for data in buffer.get_data, buffer do - if gconf.cnt == 1000 then break end - gconf.cnt = gconf.cnt + 1 +-- if gconf.cnt == 100 then break end +-- gconf.cnt = gconf.cnt + 1 input = {data.main_scp, data.ref} output = {} @@ -141,11 +144,21 @@ for data in buffer.get_data, buffer do main:back_propagate(err_output, err_input, input, output) main:update(err_input, input, output) - nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce) - nerv.utils.printf("correct: %d\n", sm.total_correct) - nerv.utils.printf("frames: %d\n", sm.total_frames) - nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames) - nerv.utils.printf("accuracy: %.8f\n", sm.total_correct / sm.total_frames) +-- nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce) +-- nerv.utils.printf("correct: %d\n", sm.total_correct) +-- nerv.utils.printf("frames: %d\n", sm.total_frames) +-- nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames) +-- nerv.utils.printf("accuracy: %.8f\n", sm.total_correct / sm.total_frames) collectgarbage("collect") end +nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce) +nerv.utils.printf("correct: %d\n", sm.total_correct) +nerv.utils.printf("accuracy: %.3f%%\n", sm.total_correct / sm.total_frames * 100) +nerv.utils.printf("writing back...\n") +cf = nerv.ChunkFile("output.nerv", "w") +for i, p in ipairs(main:get_params()) do + print(p) + cf:write_chunk(p) +end +cf:close() nerv.Matrix.print_profile() |