aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/test_nn_lib.lua15
1 files changed, 10 insertions, 5 deletions
diff --git a/examples/test_nn_lib.lua b/examples/test_nn_lib.lua
index 9600917..04fd7d6 100644
--- a/examples/test_nn_lib.lua
+++ b/examples/test_nn_lib.lua
@@ -116,7 +116,8 @@ tnet_reader = nerv.TNetReader(gconf,
buffer = nerv.SGDBuffer(gconf,
{
- buffer_size = 8192,
+ buffer_size = 81920,
+ -- randomize = true,
readers = {
{ reader = tnet_reader,
data = {main_scp = 429, ref = 1}}
@@ -126,10 +127,11 @@ buffer = nerv.SGDBuffer(gconf,
sm = sublayer_repo:get_layer("softmax_ce0")
main = layer_repo:get_layer("main")
main:init(gconf.batch_size)
-cnt = 0
+gconf.cnt = 0
for data in buffer.get_data, buffer do
- if cnt == 1000 then break end
- cnt = cnt + 1
+ if gconf.cnt == 1000 then break end
+ gconf.cnt = gconf.cnt + 1
+
input = {data.main_scp, data.ref}
output = {}
err_input = {}
@@ -140,7 +142,10 @@ for data in buffer.get_data, buffer do
main:update(err_input, input, output)
nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce)
- nerv.utils.printf("frames: %.8f\n", sm.total_frames)
+ nerv.utils.printf("correct: %d\n", sm.total_correct)
+ nerv.utils.printf("frames: %d\n", sm.total_frames)
nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames)
+ nerv.utils.printf("accuracy: %.8f\n", sm.total_correct / sm.total_frames)
collectgarbage("collect")
end
+nerv.Matrix.print_profile()