From eba6049a82455499c68ee875843b6f44d6164fa5 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 5 Jun 2015 16:56:33 +0800 Subject: add close method for ChunkFile, fix #18 --- examples/chunk_file_example.lua | 53 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 examples/chunk_file_example.lua (limited to 'examples') diff --git a/examples/chunk_file_example.lua b/examples/chunk_file_example.lua new file mode 100644 index 0000000..5961c98 --- /dev/null +++ b/examples/chunk_file_example.lua @@ -0,0 +1,53 @@ +-- To define a readable and writable chunk, one must define a class with the +-- following methods: __init(id, global_conf), read(handle), write(handle), +-- get_info(), set_info(info) and an id attribute. This file demonstrates a +-- basic chunk implementation which manages the I/O of a matrix + +local MatrixChunk = nerv.class("nerv.MatrixChunk") + +function MatrixChunk:__init(id, global_conf) + self.id = id + self.info = {} + self.gconf = global_conf +end + +function MatrixChunk:read(handle) + -- pass the read handle to the matrix method + self.data = nerv.MMatrixFloat.load(handle) +end + +function MatrixChunk:write(handle) + -- pass the write handle to the matrix method + self.data:save(handle) +end + +function MatrixChunk:get_info() + return self.info +end + +function MatrixChunk:set_info(info) + self.info = info +end + +function MatrixChunk.create_from_matrix(id, mat) + local ins = nerv.MatrixChunk(id) + ins.data = mat + return ins +end + +mat = nerv.MMatrixFloat(3, 4) +for i = 0, 2 do + for j = 0, 3 do + mat[i][j] = i + j + end +end + +cd = nerv.MatrixChunk.create_from_matrix("matrix1", mat) + +cf = nerv.ChunkFile("test.nerv", "w") +cf:write_chunk(cd) +cf:close() + +cf2 = nerv.ChunkFile("test.nerv", "r") +cd2 = cf2:read_chunk("matrix1") +print(cd2.data) -- cgit v1.2.3 From 37af4bed9c3680fdb9db569605f15013e9b6b64d Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 5 Jun 2015 17:53:05 +0800 Subject: add get_params to all layers --- examples/test_nn_lib.lua | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) (limited to 'examples') diff --git a/examples/test_nn_lib.lua b/examples/test_nn_lib.lua index 04fd7d6..6fdbd67 100644 --- a/examples/test_nn_lib.lua +++ b/examples/test_nn_lib.lua @@ -117,7 +117,7 @@ tnet_reader = nerv.TNetReader(gconf, buffer = nerv.SGDBuffer(gconf, { buffer_size = 81920, - -- randomize = true, + randomize = true, readers = { { reader = tnet_reader, data = {main_scp = 429, ref = 1}} @@ -128,9 +128,12 @@ sm = sublayer_repo:get_layer("softmax_ce0") main = layer_repo:get_layer("main") main:init(gconf.batch_size) gconf.cnt = 0 +-- data = buffer:get_data() +-- input = {data.main_scp, data.ref} +-- while true do for data in buffer.get_data, buffer do - if gconf.cnt == 1000 then break end - gconf.cnt = gconf.cnt + 1 +-- if gconf.cnt == 100 then break end +-- gconf.cnt = gconf.cnt + 1 input = {data.main_scp, data.ref} output = {} @@ -141,11 +144,21 @@ for data in buffer.get_data, buffer do main:back_propagate(err_output, err_input, input, output) main:update(err_input, input, output) - nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce) - nerv.utils.printf("correct: %d\n", sm.total_correct) - nerv.utils.printf("frames: %d\n", sm.total_frames) - nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames) - nerv.utils.printf("accuracy: %.8f\n", sm.total_correct / sm.total_frames) +-- nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce) +-- nerv.utils.printf("correct: %d\n", sm.total_correct) +-- nerv.utils.printf("frames: %d\n", sm.total_frames) +-- nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames) +-- nerv.utils.printf("accuracy: %.8f\n", sm.total_correct / sm.total_frames) collectgarbage("collect") end +nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce) +nerv.utils.printf("correct: %d\n", sm.total_correct) +nerv.utils.printf("accuracy: %.3f%%\n", sm.total_correct / sm.total_frames * 100) +nerv.utils.printf("writing back...\n") +cf = nerv.ChunkFile("output.nerv", "w") +for i, p in ipairs(main:get_params()) do + print(p) + cf:write_chunk(p) +end +cf:close() nerv.Matrix.print_profile() -- cgit v1.2.3