diff options
Diffstat (limited to 'nerv/examples/lmptb/rnn')
-rw-r--r-- | nerv/examples/lmptb/rnn/tnn.lua (renamed from nerv/examples/lmptb/rnn/layer_tdag.lua) | 133 |
1 files changed, 65 insertions, 68 deletions
diff --git a/nerv/examples/lmptb/rnn/layer_tdag.lua b/nerv/examples/lmptb/rnn/tnn.lua index 6e5d774..3f192b5 100644 --- a/nerv/examples/lmptb/rnn/layer_tdag.lua +++ b/nerv/examples/lmptb/rnn/tnn.lua @@ -1,4 +1,5 @@ -local DAGLayer = nerv.class("nerv.TDAGLayer", "nerv.Layer") +local TNN = nerv.class("nerv.TNN", "nerv.Layer") +local DAGLayer = TNN local function parse_id(str) --used to parse layerid[portid],time @@ -45,13 +46,31 @@ local function discover(id, layers, layer_repo) return ref end -function DAGLayer.makeInitialStore(dim, batch_size, chunk_size, global_conf) +nerv.TNN.FC = {} --flag const +nerv.TNN.FC.SEQ_START = 4 +nerv.TNN.FC.SEQ_END = 8 +nerv.TNN.FC.HAS_INPUT = 1 +nerv.TNN.FC.HAS_LABEL = 2 +nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --This instance have both input and label + +function DAGLayer.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c) --Return a table of matrix storage from time (1-chunk_size)..(2*chunk_size) - st = {} + if (type(st) ~= "table") then + nerv.error("st should be a table") + end for i = 1 - chunk_size, chunk_size * 2 do - st[i] = global_conf.cumat_type(batch_size, dim) + if (st[i] == nil) then + st[i] = {} + end + st[i][p] = global_conf.cumat_type(batch_size, dim) + st[i][p]:fill(0) + if (st_c ~= nil) then + if (st_c[i] == nil) then + st_c[i] = {} + end + st_c[i][p_c] = st[i][p] + end end - return st end function DAGLayer:__init(id, global_conf, layer_conf) @@ -111,7 +130,7 @@ function DAGLayer:__init(id, global_conf, layer_conf) end function DAGLayer:init(batch_size, chunk_size) - for i, conn in ipairs(self.parsed_conns) do + for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN local _, output_dim local ref_from, port_from, ref_to, port_to ref_from, port_from = conn.src.ref, conn.src.port @@ -121,41 +140,53 @@ function DAGLayer:init(batch_size, chunk_size) nerv.error("layer %s has a zero dim port", ref_from.layer.id) end - local mid = DAGLayer.makeInitialStore(dim, batch_size, chunk_size, global_conf) - local err_mid = DAGLayer.makeInitialStore(dim, batch_size, chunk_size, global_conf) + print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id) + self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.inputs_m, port_to) + self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.err_outputs_m, port_to) - print(ref_from.layer.id, "->", ref_to.layer.id) + end - ref_from.outputs_m[port_from] = mid - ref_to.inputs_m[port_to] = mid + self.outputs_m = {} + self.err_inputs_m = {} + for i = 1, #self.dim_out do --Init storage for output ports + local ref = self.outputs_p[i].ref + local p = self.outputs_p[i].port + self.makeInitialStore(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.outputs_m, i) + self.makeInitialStore(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.err_inputs_m, i) + end - ref_from.err_inputs_m[port_from] = err_mid - ref_to.err_outputs_m[port_to] = err_mid - end - for id, ref in pairs(self.layers) do + self.inputs_m = {} + self.err_outputs_m = {} + for i = 1, #self.dim_in do --Init storage for input ports + local ref = self.inputs_p[i].ref + local p = self.inputs_p[i].port + self.makeInitialStore(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.inputs_m, i) + self.makeInitialStore(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.err_outputs_m, i) + end + + for id, ref in pairs(self.layers) do --Calling init for child layers for i = 1, #ref.dim_in do - if ref.inputs_m[i] == nil then + if (ref.inputs_m[i] == nil or ref.err_outputs_m[i] == nil) then nerv.error("dangling input port %d of layer %s", i, id) end end for i = 1, #ref.dim_out do - if ref.outputs_m[i] == nil then + if (ref.outputs_m[i] == nil or ref.err_inputs_m[i] == nil) then nerv.error("dangling output port %d of layer %s", i, id) end end -- initialize sub layers ref.layer:init(batch_size) end - for i = 1, #self.dim_in do - if self.inputs_p[i] == nil then - nerv.error("<input> port %d not attached", i) - end - end - for i = 1, #self.dim_out do - if self.outputs_p[i] == nil then - nerv.error("<output> port %d not attached", i) - end + + local flags_now = {} + for i = 1, chunk_size do + flags_now[i] = {} end + + self.feeds_now = {} --feeds is for the reader to fill + self.feeds_now.inputs_m = self.inputs_m + self.feeds_now.flags_now = flags_now end --[[ @@ -187,48 +218,13 @@ function DAGLayer:batch_resize(batch_size) end ]]-- -function DAGLayer:set_inputs(inputs_m) - for i = 1, #self.dim_in do - if inputs_m[i] == nil then - nerv.error("inputs_m[%d] is not provided", i); - end - local ref = self.inputs_p[i].ref - local p = self.inputs_p[i].port - ref.inputs_m[p] = inputs_m[i] - end -end - -function DAGLayer:set_outputs(outputs_m) - for i = 1, #self.dim_out do - if outputs_m[i] == nil then - nerv.error("outputs_m[%d] is not provided", i); - end - local ref = self.outputs_p[i].ref - local p = self.outputs_p[i].port - ref.outputs_m[p] = outputs_m[i] - end -end - -function DAGLayer:set_err_inputs(bp_errs_m) - for i = 1, #self.dim_out do - if bp_errs_m[i] == nil then - nerv.error("bp_errs_m[%d] is not provided", i); - end - local ref = self.outputs_p[i].ref - local p = self.outputs_p[i].port - ref.err_inputs_m[p] = bp_errs_m[i] - end -end - -function DAGLayer:set_err_outputs(next_bp_err) - for i = 1, #self.dim_in do - if (next_bp_err[i] == nil) then - nerv.error("next_bp_err[%d] is not provided", i) - end - local ref = self.inputs_p[i].ref - local p = self.inputs_p[i].port - ref.err_outputs_m[p] = next_bp_err[i] - end +--reader: some reader +--Returns: bool, whether has new feed +--Returns: feeds, a table that will be filled with the reader's feeds +function DAGLayer:getFeedFromReader(reader) + local feeds = self.feeds_now + local got_new = reader:get_batch(feeds) + return got_new, feeds end function DAGLayer:update(bp_err, input, output) @@ -266,6 +262,7 @@ function DAGLayer:back_propagate(bp_err, next_bp_err, input, output) end end +--Return: nerv.ParamRepo function DAGLayer:get_params() local param_repos = {} for id, ref in pairs(self.queue) do |