From d18122af2f57b8dd81db49385484f0e51d167a23 Mon Sep 17 00:00:00 2001 From: txh18 Date: Tue, 3 Nov 2015 18:36:43 +0800 Subject: still working on TNN --- nerv/examples/lmptb/lmptb/lmseqreader.lua | 21 +- nerv/examples/lmptb/m-tests/dagl_test.lua | 18 +- nerv/examples/lmptb/m-tests/lmseqreader_test.lua | 36 ++- nerv/examples/lmptb/rnn/layer_tdag.lua | 302 ----------------------- nerv/examples/lmptb/rnn/tnn.lua | 299 ++++++++++++++++++++++ 5 files changed, 352 insertions(+), 324 deletions(-) delete mode 100644 nerv/examples/lmptb/rnn/layer_tdag.lua create mode 100644 nerv/examples/lmptb/rnn/tnn.lua diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua index 307c5a3..006b5cb 100644 --- a/nerv/examples/lmptb/lmptb/lmseqreader.lua +++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua @@ -1,5 +1,5 @@ require 'lmptb.lmvocab' -require 'rnn.layer_tdag' +require 'rnn.tnn' local LMReader = nerv.class("nerv.LMSeqReader") @@ -66,7 +66,7 @@ function LMReader:refresh_stream(id) end --feeds: a table that will be filled by the reader ---Returns: inputs_m, labels_m +--Returns: bool function LMReader:get_batch(feeds) if (feeds == nil or type(feeds) ~= "table") then nerv.error("feeds is not a table") @@ -74,36 +74,49 @@ function LMReader:get_batch(feeds) feeds["inputs_s"] = {} feeds["labels_s"] = {} - inputs_s = feeds.inputs_s - labels_s = feeds.labels_s + local inputs_s = feeds.inputs_s + local labels_s = feeds.labels_s for i = 1, self.chunk_size, 1 do inputs_s[i] = {} labels_s[i] = {} end + local inputs_m = feeds.inputs_m --port 1 : word_id, port 2 : label + local flags = feeds.flags_now + local got_new = false for i = 1, self.batch_size, 1 do local st = self.streams[i] for j = 1, self.chunk_size, 1 do + flags[j][i] = 0 self:refresh_stream(i) if (st.store[st.head] ~= nil) then inputs_s[j][i] = st.store[st.head] + inputs_m[j][1][i - 1][0] = self.vocab:get_word_str(st.store[st.head]).id - 1 else inputs_s[j][i] = self.vocab.null_token + inputs_m[j][1][i - 1][0] = 0 end + inputs_m[j][2][i - 1]:fill(0) if (st.store[st.head + 1] ~= nil) then labels_s[j][i] = st.store[st.head + 1] + inputs_m[j][2][i - 1][self.vocab:get_word_str(st.store[st.head + 1]).id - 1] = 1 else labels_s[j][i] = self.vocab.null_token end if (inputs_s[j][i] ~= self.vocab.null_token) then + flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_NORM) got_new = true st.store[st.head] = nil st.head = st.head + 1 if (labels_s[j][i] == self.vocab.sen_end_token) then + flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_END) st.store[st.head] = nil --sentence end is passed st.head = st.head + 1 end + if (inputs_s[j][i] == self.vocab.send_end_token) then + flags[j][i] = bit.bor(flags[j][i], nerv.TNN.FC.SEQ_START) + end end end end diff --git a/nerv/examples/lmptb/m-tests/dagl_test.lua b/nerv/examples/lmptb/m-tests/dagl_test.lua index 5e90551..a50107d 100644 --- a/nerv/examples/lmptb/m-tests/dagl_test.lua +++ b/nerv/examples/lmptb/m-tests/dagl_test.lua @@ -98,14 +98,11 @@ end --global_conf: table --layerRepo: nerv.LayerRepo ---Returns: a nerv.TDAGLayer +--Returns: a nerv.TNN function prepare_dagLayer(global_conf, layerRepo) - printf("%s Initing daglayer ...\n", global_conf.sche_log_pre) + printf("%s Initing TNN ...\n", global_conf.sche_log_pre) --input: input_w, input_w, ... 
input_w_now, last_activation - local dim_in_t = {} - dim_in_t[1] = 1 --input to select_linear layer - dim_in_t[2] = global_conf.vocab:size() --input to softmax label local connections_t = { {"[1]", "selectL1[1]", 0}, {"selectL1[1]", "recurrentL1[1]", 0}, @@ -124,11 +121,11 @@ function prepare_dagLayer(global_conf, layerRepo) end ]]-- - local dagL = nerv.TDAGLayer("dagL", global_conf, {["dim_in"] = dim_in_t, ["dim_out"] = {1}, ["sub_layers"] = layerRepo, + local tnn = nerv.TNN("TNN", global_conf, {["dim_in"] = {1, global_conf.vocab:size()}, ["dim_out"] = {1}, ["sub_layers"] = layerRepo, ["connections"] = connections_t, }) - printf("%s Initing DAGLayer end.\n", global_conf.sche_log_pre) - return dagL + printf("%s Initing TNN end.\n", global_conf.sche_log_pre) + return tnn end train_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text' @@ -160,7 +157,6 @@ global_conf["vocab"] = vocab global_conf.vocab:build_file(global_conf.train_fn, false) local paramRepo = prepare_parameters(global_conf, true) local layerRepo = prepare_layers(global_conf, paramRepo) -local dagL = prepare_dagLayer(global_conf, layerRepo) -dagL:init(global_conf.batch_size, global_conf.chunk_size) - +local tnn = prepare_dagLayer(global_conf, layerRepo) +tnn:init(global_conf.batch_size, global_conf.chunk_size) diff --git a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua index 504698f..cbcdcbe 100644 --- a/nerv/examples/lmptb/m-tests/lmseqreader_test.lua +++ b/nerv/examples/lmptb/m-tests/lmseqreader_test.lua @@ -1,4 +1,5 @@ require 'lmptb.lmseqreader' +require 'lmptb.lmutil' local printf = nerv.printf @@ -8,15 +9,36 @@ local vocab = nerv.LMVocab() vocab:build_file(test_fn) local chunk_size = 5 local batch_size = 3 -local reader = nerv.LMSeqReader({}, batch_size, chunk_size, vocab) +local global_conf = { + lrate = 1, wcost = 1e-6, momentum = 0, + cumat_type = nerv.CuMatrixFloat, + mmat_type = nerv.CuMatrixFloat, + + hidden_size = 20, + chunk_size = chunk_size, + batch_size = batch_size, + max_iter = 18, + param_random = function() return (math.random() / 5 - 0.1) end, + independent = true, + + train_fn = train_fn, + test_fn = test_fn, + sche_log_pre = "[SCHEDULER]:", + log_w_num = 10, --give a message when log_w_num words have been processed + timer = nerv.Timer(), + + vocab = vocab +} + +local reader = nerv.LMSeqReader(global_conf, batch_size, chunk_size, vocab) reader:open_file(test_fn) -local input = {} -local label = {} -for i = 1, batch_size, 1 do - input[i] = {} - label[i] = {} -end local feeds = {} +feeds.flags_now = {} +feeds.inputs_m = {} +for j = 1, chunk_size do + feeds.inputs_m[j] = {global_conf.cumat_type(batch_size, 1), global_conf.cumat_type(batch_size, global_conf.vocab:size())} + feeds.flags_now[j] = {} +end while (1) do local r = reader:get_batch(feeds) if (r == false) then break end diff --git a/nerv/examples/lmptb/rnn/layer_tdag.lua b/nerv/examples/lmptb/rnn/layer_tdag.lua deleted file mode 100644 index 6e5d774..0000000 --- a/nerv/examples/lmptb/rnn/layer_tdag.lua +++ /dev/null @@ -1,302 +0,0 @@ -local DAGLayer = nerv.class("nerv.TDAGLayer", "nerv.Layer") - -local function parse_id(str) - --used to parse layerid[portid],time - local id, port, time, _ - _, _, id, port, time = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%][,]*([0-9]*)") - if id == nil or port == nil then - _, _, id, port, time = string.find(str, "(.+)%[([0-9]+)%][,]*([0-9]*)") - if not (id == "" or id == "") then - nerv.error("wrong format of connection id") - end - 
end - --print(str, id, port, time) - port = tonumber(port) - if (time == nil) then - time = 0 - else - time = tonumber(time) - end - --now time don't need to be parsed - return id, port -end - -local function discover(id, layers, layer_repo) - local ref = layers[id] - if id == "" or id == "" then - return nil - end - if ref == nil then - local layer = layer_repo:get_layer(id) - local dim_in, dim_out = layer:get_dim() - ref = { - layer = layer, - inputs_m = {}, --storage for computation, inputs_m[port][time] - outputs_m = {}, - err_inputs_m = {}, - err_outputs_m = {}, - conns_i = {}, --list of inputing connections - conns_o = {}, --list of outputing connections - dim_in = dim_in, --list of dimensions of ports - dim_out = dim_out, - } - layers[id] = ref - end - return ref -end - -function DAGLayer.makeInitialStore(dim, batch_size, chunk_size, global_conf) - --Return a table of matrix storage from time (1-chunk_size)..(2*chunk_size) - st = {} - for i = 1 - chunk_size, chunk_size * 2 do - st[i] = global_conf.cumat_type(batch_size, dim) - end - return st -end - -function DAGLayer:__init(id, global_conf, layer_conf) - local layers = {} - local inputs_p = {} --map:port of the TDAGLayer to layer ref and port - local outputs_p = {} - local dim_in = layer_conf.dim_in - local dim_out = layer_conf.dim_out - local parsed_conns = {} - local _ - - for _, ll in pairs(layer_conf.connections) do - local id_from, port_from = parse_id(ll[1]) - local id_to, port_to = parse_id(ll[2]) - local time_to = ll[3] - - print(id_from, id_to, time_to) - - local ref_from = discover(id_from, layers, layer_conf.sub_layers) - local ref_to = discover(id_to, layers, layer_conf.sub_layers) - - if (id_from == "") then - if (dim_in[port_from] ~= ref_to.dim_in[port_to] or time_to ~= 0) then - nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3]) - end - inputs_p[port_from] = {["ref"] = ref_to, ["port"] = port_to} - ref_to.inputs_m[port_to] = {} --just a place holder - elseif (id_to == "") then - if (dim_out[port_to] ~= ref_from.dim_out[port_from] or time_to ~= 0) then - nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3]) - end - outputs_p[port_to] = {["ref"] = ref_from, ["port"] = port_from} - ref_from.outputs_m[port_from] = {} --just a place holder - else - conn_now = { - ["src"] = {["ref"] = ref_from, ["port"] = port_from}, - ["dst"] = {["ref"] = ref_to, ["port"] = port_to}, - ["time"] = time_to - } - if (ref_to.dim_in[port_to] ~= ref_from.dim_out[port_from]) then - nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3]) - end - table.insert(parsed_conns, conn_now) - table.insert(ref_to.conns_i, conn_now) - table.insert(ref_from.conns_o, conn_now) - end - end - - self.layers = layers - self.inputs_p = inputs_p - self.outputs_p = outputs_p - self.id = id - self.dim_in = dim_in - self.dim_out = dim_out - self.parsed_conns = parsed_conns - self.gconf = global_conf -end - -function DAGLayer:init(batch_size, chunk_size) - for i, conn in ipairs(self.parsed_conns) do - local _, output_dim - local ref_from, port_from, ref_to, port_to - ref_from, port_from = conn.src.ref, conn.src.port - ref_to, port_to = conn.dst.ref, conn.dst.port - local dim = ref_from.dim_out[port_from] - if (dim == 0) then - nerv.error("layer %s has a zero dim port", ref_from.layer.id) - end - - local mid = DAGLayer.makeInitialStore(dim, batch_size, chunk_size, global_conf) - local err_mid = DAGLayer.makeInitialStore(dim, batch_size, chunk_size, global_conf) - - print(ref_from.layer.id, "->", 
ref_to.layer.id) - - ref_from.outputs_m[port_from] = mid - ref_to.inputs_m[port_to] = mid - - ref_from.err_inputs_m[port_from] = err_mid - ref_to.err_outputs_m[port_to] = err_mid - end - for id, ref in pairs(self.layers) do - for i = 1, #ref.dim_in do - if ref.inputs_m[i] == nil then - nerv.error("dangling input port %d of layer %s", i, id) - end - end - for i = 1, #ref.dim_out do - if ref.outputs_m[i] == nil then - nerv.error("dangling output port %d of layer %s", i, id) - end - end - -- initialize sub layers - ref.layer:init(batch_size) - end - for i = 1, #self.dim_in do - if self.inputs_p[i] == nil then - nerv.error(" port %d not attached", i) - end - end - for i = 1, #self.dim_out do - if self.outputs_p[i] == nil then - nerv.error(" port %d not attached", i) - end - end -end - ---[[ -function DAGLayer:batch_resize(batch_size) - self.gconf.batch_size = batch_size - - for i, conn in ipairs(self.parsed_conn) do - local _, output_dim - local ref_from, port_from, ref_to, port_to - ref_from, port_from = unpack(conn[1]) - ref_to, port_to = unpack(conn[2]) - _, output_dim = ref_from.layer:get_dim() - - if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then - local mid = self.gconf.cumat_type(batch_size, output_dim[port_from]) - local err_mid = mid:create() - - ref_from.outputs[port_from] = mid - ref_to.inputs[port_to] = mid - - ref_from.err_inputs[port_from] = err_mid - ref_to.err_outputs[port_to] = err_mid - end - end - for id, ref in pairs(self.layers) do - ref.layer:batch_resize(batch_size) - end - collectgarbage("collect") -end -]]-- - -function DAGLayer:set_inputs(inputs_m) - for i = 1, #self.dim_in do - if inputs_m[i] == nil then - nerv.error("inputs_m[%d] is not provided", i); - end - local ref = self.inputs_p[i].ref - local p = self.inputs_p[i].port - ref.inputs_m[p] = inputs_m[i] - end -end - -function DAGLayer:set_outputs(outputs_m) - for i = 1, #self.dim_out do - if outputs_m[i] == nil then - nerv.error("outputs_m[%d] is not provided", i); - end - local ref = self.outputs_p[i].ref - local p = self.outputs_p[i].port - ref.outputs_m[p] = outputs_m[i] - end -end - -function DAGLayer:set_err_inputs(bp_errs_m) - for i = 1, #self.dim_out do - if bp_errs_m[i] == nil then - nerv.error("bp_errs_m[%d] is not provided", i); - end - local ref = self.outputs_p[i].ref - local p = self.outputs_p[i].port - ref.err_inputs_m[p] = bp_errs_m[i] - end -end - -function DAGLayer:set_err_outputs(next_bp_err) - for i = 1, #self.dim_in do - if (next_bp_err[i] == nil) then - nerv.error("next_bp_err[%d] is not provided", i) - end - local ref = self.inputs_p[i].ref - local p = self.inputs_p[i].port - ref.err_outputs_m[p] = next_bp_err[i] - end -end - -function DAGLayer:update(bp_err, input, output) - self:set_err_inputs(bp_err) - self:set_inputs(input) - self:set_outputs(output) - -- print("update") - for id, ref in pairs(self.queue) do - -- print(ref.layer.id) - ref.layer:update(ref.err_inputs, ref.inputs, ref.outputs) - end -end - -function DAGLayer:propagate(input, output) - self:set_inputs(input) - self:set_outputs(output) - local ret = false - for i = 1, #self.queue do - local ref = self.queue[i] - -- print(ref.layer.id) - ret = ref.layer:propagate(ref.inputs, ref.outputs) - end - return ret -end - -function DAGLayer:back_propagate(bp_err, next_bp_err, input, output) - self:set_err_outputs(next_bp_err) - self:set_err_inputs(bp_err) - self:set_inputs(input) - self:set_outputs(output) - for i = #self.queue, 1, -1 do - local ref = self.queue[i] - -- print(ref.layer.id) - 
ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs) - end -end - -function DAGLayer:get_params() - local param_repos = {} - for id, ref in pairs(self.queue) do - table.insert(param_repos, ref.layer:get_params()) - end - return nerv.ParamRepo.merge(param_repos) -end - -DAGLayer.PORT_TYPES = { - INPUT = {}, - OUTPUT = {}, - ERR_INPUT = {}, - ERR_OUTPUT = {} -} - -function DAGLayer:get_intermediate(id, port_type) - if id == "" or id == "" then - nerv.error("an actual real layer id is expected") - end - local layer = self.layers[id] - if layer == nil then - nerv.error("layer id %s not found", id) - end - if port_type == DAGLayer.PORT_TYPES.INPUT then - return layer.inputs - elseif port_type == DAGLayer.PORT_TYPES.OUTPUT then - return layer.outputs - elseif port_type == DAGLayer.PORT_TYPES.ERR_INPUT then - return layer.err_inputs - elseif port_type == DAGLayer.PORT_TYPES.ERR_OUTPUT then - return layer.err_outputs - end - nerv.error("unrecognized port type") -end diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua new file mode 100644 index 0000000..3f192b5 --- /dev/null +++ b/nerv/examples/lmptb/rnn/tnn.lua @@ -0,0 +1,299 @@ +local TNN = nerv.class("nerv.TNN", "nerv.Layer") +local DAGLayer = TNN + +local function parse_id(str) + --used to parse layerid[portid],time + local id, port, time, _ + _, _, id, port, time = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%][,]*([0-9]*)") + if id == nil or port == nil then + _, _, id, port, time = string.find(str, "(.+)%[([0-9]+)%][,]*([0-9]*)") + if not (id == "" or id == "") then + nerv.error("wrong format of connection id") + end + end + --print(str, id, port, time) + port = tonumber(port) + if (time == nil) then + time = 0 + else + time = tonumber(time) + end + --now time don't need to be parsed + return id, port +end + +local function discover(id, layers, layer_repo) + local ref = layers[id] + if id == "" or id == "" then + return nil + end + if ref == nil then + local layer = layer_repo:get_layer(id) + local dim_in, dim_out = layer:get_dim() + ref = { + layer = layer, + inputs_m = {}, --storage for computation, inputs_m[port][time] + outputs_m = {}, + err_inputs_m = {}, + err_outputs_m = {}, + conns_i = {}, --list of inputing connections + conns_o = {}, --list of outputing connections + dim_in = dim_in, --list of dimensions of ports + dim_out = dim_out, + } + layers[id] = ref + end + return ref +end + +nerv.TNN.FC = {} --flag const +nerv.TNN.FC.SEQ_START = 4 +nerv.TNN.FC.SEQ_END = 8 +nerv.TNN.FC.HAS_INPUT = 1 +nerv.TNN.FC.HAS_LABEL = 2 +nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --This instance have both input and label + +function DAGLayer.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c) + --Return a table of matrix storage from time (1-chunk_size)..(2*chunk_size) + if (type(st) ~= "table") then + nerv.error("st should be a table") + end + for i = 1 - chunk_size, chunk_size * 2 do + if (st[i] == nil) then + st[i] = {} + end + st[i][p] = global_conf.cumat_type(batch_size, dim) + st[i][p]:fill(0) + if (st_c ~= nil) then + if (st_c[i] == nil) then + st_c[i] = {} + end + st_c[i][p_c] = st[i][p] + end + end +end + +function DAGLayer:__init(id, global_conf, layer_conf) + local layers = {} + local inputs_p = {} --map:port of the TDAGLayer to layer ref and port + local outputs_p = {} + local dim_in = layer_conf.dim_in + local dim_out = layer_conf.dim_out + local parsed_conns = {} + local _ + + for _, ll in pairs(layer_conf.connections) do + 
local id_from, port_from = parse_id(ll[1]) + local id_to, port_to = parse_id(ll[2]) + local time_to = ll[3] + + print(id_from, id_to, time_to) + + local ref_from = discover(id_from, layers, layer_conf.sub_layers) + local ref_to = discover(id_to, layers, layer_conf.sub_layers) + + if (id_from == "") then + if (dim_in[port_from] ~= ref_to.dim_in[port_to] or time_to ~= 0) then + nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3]) + end + inputs_p[port_from] = {["ref"] = ref_to, ["port"] = port_to} + ref_to.inputs_m[port_to] = {} --just a place holder + elseif (id_to == "") then + if (dim_out[port_to] ~= ref_from.dim_out[port_from] or time_to ~= 0) then + nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3]) + end + outputs_p[port_to] = {["ref"] = ref_from, ["port"] = port_from} + ref_from.outputs_m[port_from] = {} --just a place holder + else + conn_now = { + ["src"] = {["ref"] = ref_from, ["port"] = port_from}, + ["dst"] = {["ref"] = ref_to, ["port"] = port_to}, + ["time"] = time_to + } + if (ref_to.dim_in[port_to] ~= ref_from.dim_out[port_from]) then + nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3]) + end + table.insert(parsed_conns, conn_now) + table.insert(ref_to.conns_i, conn_now) + table.insert(ref_from.conns_o, conn_now) + end + end + + self.layers = layers + self.inputs_p = inputs_p + self.outputs_p = outputs_p + self.id = id + self.dim_in = dim_in + self.dim_out = dim_out + self.parsed_conns = parsed_conns + self.gconf = global_conf +end + +function DAGLayer:init(batch_size, chunk_size) + for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN + local _, output_dim + local ref_from, port_from, ref_to, port_to + ref_from, port_from = conn.src.ref, conn.src.port + ref_to, port_to = conn.dst.ref, conn.dst.port + local dim = ref_from.dim_out[port_from] + if (dim == 0) then + nerv.error("layer %s has a zero dim port", ref_from.layer.id) + end + + print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id) + self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.inputs_m, port_to) + self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.err_outputs_m, port_to) + + end + + self.outputs_m = {} + self.err_inputs_m = {} + for i = 1, #self.dim_out do --Init storage for output ports + local ref = self.outputs_p[i].ref + local p = self.outputs_p[i].port + self.makeInitialStore(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.outputs_m, i) + self.makeInitialStore(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.err_inputs_m, i) + end + + self.inputs_m = {} + self.err_outputs_m = {} + for i = 1, #self.dim_in do --Init storage for input ports + local ref = self.inputs_p[i].ref + local p = self.inputs_p[i].port + self.makeInitialStore(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.inputs_m, i) + self.makeInitialStore(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.err_outputs_m, i) + end + + for id, ref in pairs(self.layers) do --Calling init for child layers + for i = 1, #ref.dim_in do + if (ref.inputs_m[i] == nil or ref.err_outputs_m[i] == nil) then + nerv.error("dangling input port %d of layer %s", i, id) + end + end + for i = 1, #ref.dim_out do + if (ref.outputs_m[i] == nil or ref.err_inputs_m[i] == nil) then + nerv.error("dangling output port %d of layer %s", i, id) + 
end + end + -- initialize sub layers + ref.layer:init(batch_size) + end + + local flags_now = {} + for i = 1, chunk_size do + flags_now[i] = {} + end + + self.feeds_now = {} --feeds is for the reader to fill + self.feeds_now.inputs_m = self.inputs_m + self.feeds_now.flags_now = flags_now +end + +--[[ +function DAGLayer:batch_resize(batch_size) + self.gconf.batch_size = batch_size + + for i, conn in ipairs(self.parsed_conn) do + local _, output_dim + local ref_from, port_from, ref_to, port_to + ref_from, port_from = unpack(conn[1]) + ref_to, port_to = unpack(conn[2]) + _, output_dim = ref_from.layer:get_dim() + + if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then + local mid = self.gconf.cumat_type(batch_size, output_dim[port_from]) + local err_mid = mid:create() + + ref_from.outputs[port_from] = mid + ref_to.inputs[port_to] = mid + + ref_from.err_inputs[port_from] = err_mid + ref_to.err_outputs[port_to] = err_mid + end + end + for id, ref in pairs(self.layers) do + ref.layer:batch_resize(batch_size) + end + collectgarbage("collect") +end +]]-- + +--reader: some reader +--Returns: bool, whether has new feed +--Returns: feeds, a table that will be filled with the reader's feeds +function DAGLayer:getFeedFromReader(reader) + local feeds = self.feeds_now + local got_new = reader:get_batch(feeds) + return got_new, feeds +end + +function DAGLayer:update(bp_err, input, output) + self:set_err_inputs(bp_err) + self:set_inputs(input) + self:set_outputs(output) + -- print("update") + for id, ref in pairs(self.queue) do + -- print(ref.layer.id) + ref.layer:update(ref.err_inputs, ref.inputs, ref.outputs) + end +end + +function DAGLayer:propagate(input, output) + self:set_inputs(input) + self:set_outputs(output) + local ret = false + for i = 1, #self.queue do + local ref = self.queue[i] + -- print(ref.layer.id) + ret = ref.layer:propagate(ref.inputs, ref.outputs) + end + return ret +end + +function DAGLayer:back_propagate(bp_err, next_bp_err, input, output) + self:set_err_outputs(next_bp_err) + self:set_err_inputs(bp_err) + self:set_inputs(input) + self:set_outputs(output) + for i = #self.queue, 1, -1 do + local ref = self.queue[i] + -- print(ref.layer.id) + ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs) + end +end + +--Return: nerv.ParamRepo +function DAGLayer:get_params() + local param_repos = {} + for id, ref in pairs(self.queue) do + table.insert(param_repos, ref.layer:get_params()) + end + return nerv.ParamRepo.merge(param_repos) +end + +DAGLayer.PORT_TYPES = { + INPUT = {}, + OUTPUT = {}, + ERR_INPUT = {}, + ERR_OUTPUT = {} +} + +function DAGLayer:get_intermediate(id, port_type) + if id == "" or id == "" then + nerv.error("an actual real layer id is expected") + end + local layer = self.layers[id] + if layer == nil then + nerv.error("layer id %s not found", id) + end + if port_type == DAGLayer.PORT_TYPES.INPUT then + return layer.inputs + elseif port_type == DAGLayer.PORT_TYPES.OUTPUT then + return layer.outputs + elseif port_type == DAGLayer.PORT_TYPES.ERR_INPUT then + return layer.err_inputs + elseif port_type == DAGLayer.PORT_TYPES.ERR_OUTPUT then + return layer.err_outputs + end + nerv.error("unrecognized port type") +end -- cgit v1.2.3-70-g09d2
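
Usage sketch (illustrative only, not a hunk of this commit): the patch above introduces the per-timestep flag constants in nerv.TNN.FC, a feeds table that LMSeqReader:get_batch fills in place, and TNN:getFeedFromReader, but only wires them into the m-tests. The Lua below is a minimal sketch of how a driver loop might consume that layout. It assumes a global_conf and layerRepo prepared as in m-tests/dagl_test.lua (including the prepare_dagLayer helper defined there), that the LuaJIT bit library is available (the reader itself already calls bit.bor), and that the chunk-level propagate/update step is still to come, since this commit is "still working on TNN".

require 'lmptb.lmseqreader'
require 'rnn.tnn'

-- reader and TNN set up as in the test scripts of this patch
local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size,
                                global_conf.chunk_size, global_conf.vocab)
reader:open_file(global_conf.train_fn)

local tnn = prepare_dagLayer(global_conf, layerRepo) -- returns a nerv.TNN
tnn:init(global_conf.batch_size, global_conf.chunk_size)

while true do
    -- the TNN owns the storage; get_batch fills tnn.feeds_now in place
    local got_new, feeds = tnn:getFeedFromReader(reader)
    if not got_new then break end

    for t = 1, global_conf.chunk_size do
        for b = 1, global_conf.batch_size do
            local f = feeds.flags_now[t][b]
            if bit.band(f, nerv.TNN.FC.SEQ_START) ~= 0 then
                -- a new sentence begins in stream b at time t; with
                -- global_conf.independent set, recurrent history for this
                -- stream would be reset here (assumed behaviour, not in patch)
            end
            if bit.band(f, nerv.TNN.FC.SEQ_NORM) ~= 0 then
                -- this slot carries both an input word id (inputs_m[t][1])
                -- and a one-hot label row (inputs_m[t][2])
            end
        end
    end

    -- chunk-level propagate/update over feeds.inputs_m would go here once
    -- TNN exposes it; in this commit those methods still reference the old
    -- DAGLayer queue and are not yet functional
end

The point of the flag layout is that each (time, stream) slot describes itself: HAS_INPUT/HAS_LABEL (combined as SEQ_NORM) say whether the slot holds real data or null-token padding, while SEQ_START/SEQ_END mark sentence boundaries so the network can be unrolled over a fixed chunk_size without mixing adjacent sentences.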