From df737041e4a9f3f55978cc74db9a9cea27fa9fa0 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 5 Jun 2015 10:58:57 +0800 Subject: add profiling; add ce accurarcy; several other changes --- examples/test_nn_lib.lua | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'examples') diff --git a/examples/test_nn_lib.lua b/examples/test_nn_lib.lua index 9600917..04fd7d6 100644 --- a/examples/test_nn_lib.lua +++ b/examples/test_nn_lib.lua @@ -116,7 +116,8 @@ tnet_reader = nerv.TNetReader(gconf, buffer = nerv.SGDBuffer(gconf, { - buffer_size = 8192, + buffer_size = 81920, + -- randomize = true, readers = { { reader = tnet_reader, data = {main_scp = 429, ref = 1}} @@ -126,10 +127,11 @@ buffer = nerv.SGDBuffer(gconf, sm = sublayer_repo:get_layer("softmax_ce0") main = layer_repo:get_layer("main") main:init(gconf.batch_size) -cnt = 0 +gconf.cnt = 0 for data in buffer.get_data, buffer do - if cnt == 1000 then break end - cnt = cnt + 1 + if gconf.cnt == 1000 then break end + gconf.cnt = gconf.cnt + 1 + input = {data.main_scp, data.ref} output = {} err_input = {} @@ -140,7 +142,10 @@ for data in buffer.get_data, buffer do main:update(err_input, input, output) nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce) - nerv.utils.printf("frames: %.8f\n", sm.total_frames) + nerv.utils.printf("correct: %d\n", sm.total_correct) + nerv.utils.printf("frames: %d\n", sm.total_frames) nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames) + nerv.utils.printf("accuracy: %.8f\n", sm.total_correct / sm.total_frames) collectgarbage("collect") end +nerv.Matrix.print_profile() -- cgit v1.2.3-70-g09d2