summaryrefslogtreecommitdiff
path: root/examples/test_nn_lib.lua
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-06-05 17:53:05 +0800
committerDeterminant <[email protected]>2015-06-05 17:53:05 +0800
commit37af4bed9c3680fdb9db569605f15013e9b6b64d (patch)
tree5f870d23f241edbc670c2778c955f6bd9d5eb1d5 /examples/test_nn_lib.lua
parenteba6049a82455499c68ee875843b6f44d6164fa5 (diff)
add get_params to all layers
Diffstat (limited to 'examples/test_nn_lib.lua')
-rw-r--r--examples/test_nn_lib.lua29
1 files changed, 21 insertions, 8 deletions
diff --git a/examples/test_nn_lib.lua b/examples/test_nn_lib.lua
index 04fd7d6..6fdbd67 100644
--- a/examples/test_nn_lib.lua
+++ b/examples/test_nn_lib.lua
@@ -117,7 +117,7 @@ tnet_reader = nerv.TNetReader(gconf,
buffer = nerv.SGDBuffer(gconf,
{
buffer_size = 81920,
- -- randomize = true,
+ randomize = true,
readers = {
{ reader = tnet_reader,
data = {main_scp = 429, ref = 1}}
@@ -128,9 +128,12 @@ sm = sublayer_repo:get_layer("softmax_ce0")
main = layer_repo:get_layer("main")
main:init(gconf.batch_size)
gconf.cnt = 0
+-- data = buffer:get_data()
+-- input = {data.main_scp, data.ref}
+-- while true do
for data in buffer.get_data, buffer do
- if gconf.cnt == 1000 then break end
- gconf.cnt = gconf.cnt + 1
+-- if gconf.cnt == 100 then break end
+-- gconf.cnt = gconf.cnt + 1
input = {data.main_scp, data.ref}
output = {}
@@ -141,11 +144,21 @@ for data in buffer.get_data, buffer do
main:back_propagate(err_output, err_input, input, output)
main:update(err_input, input, output)
- nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce)
- nerv.utils.printf("correct: %d\n", sm.total_correct)
- nerv.utils.printf("frames: %d\n", sm.total_frames)
- nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames)
- nerv.utils.printf("accuracy: %.8f\n", sm.total_correct / sm.total_frames)
+-- nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce)
+-- nerv.utils.printf("correct: %d\n", sm.total_correct)
+-- nerv.utils.printf("frames: %d\n", sm.total_frames)
+-- nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames)
+-- nerv.utils.printf("accuracy: %.8f\n", sm.total_correct / sm.total_frames)
collectgarbage("collect")
end
+nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce)
+nerv.utils.printf("correct: %d\n", sm.total_correct)
+nerv.utils.printf("accuracy: %.3f%%\n", sm.total_correct / sm.total_frames * 100)
+nerv.utils.printf("writing back...\n")
+cf = nerv.ChunkFile("output.nerv", "w")
+for i, p in ipairs(main:get_params()) do
+ print(p)
+ cf:write_chunk(p)
+end
+cf:close()
nerv.Matrix.print_profile()