aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/rnn/tnn.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/rnn/tnn.lua')
-rw-r--r--nerv/examples/lmptb/rnn/tnn.lua299
1 files changed, 299 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
new file mode 100644
index 0000000..3f192b5
--- /dev/null
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -0,0 +1,299 @@
+local TNN = nerv.class("nerv.TNN", "nerv.Layer")
+local DAGLayer = TNN
+
+local function parse_id(str)
+ --used to parse layerid[portid],time
+ local id, port, time, _
+ _, _, id, port, time = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%][,]*([0-9]*)")
+ if id == nil or port == nil then
+ _, _, id, port, time = string.find(str, "(.+)%[([0-9]+)%][,]*([0-9]*)")
+ if not (id == "<input>" or id == "<output>") then
+ nerv.error("wrong format of connection id")
+ end
+ end
+ --print(str, id, port, time)
+ port = tonumber(port)
+ if (time == nil) then
+ time = 0
+ else
+ time = tonumber(time)
+ end
+ --now time don't need to be parsed
+ return id, port
+end
+
+local function discover(id, layers, layer_repo)
+ local ref = layers[id]
+ if id == "<input>" or id == "<output>" then
+ return nil
+ end
+ if ref == nil then
+ local layer = layer_repo:get_layer(id)
+ local dim_in, dim_out = layer:get_dim()
+ ref = {
+ layer = layer,
+ inputs_m = {}, --storage for computation, inputs_m[port][time]
+ outputs_m = {},
+ err_inputs_m = {},
+ err_outputs_m = {},
+ conns_i = {}, --list of inputing connections
+ conns_o = {}, --list of outputing connections
+ dim_in = dim_in, --list of dimensions of ports
+ dim_out = dim_out,
+ }
+ layers[id] = ref
+ end
+ return ref
+end
+
+nerv.TNN.FC = {} --flag const
+nerv.TNN.FC.SEQ_START = 4
+nerv.TNN.FC.SEQ_END = 8
+nerv.TNN.FC.HAS_INPUT = 1
+nerv.TNN.FC.HAS_LABEL = 2
+nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --This instance have both input and label
+
+function DAGLayer.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c)
+ --Return a table of matrix storage from time (1-chunk_size)..(2*chunk_size)
+ if (type(st) ~= "table") then
+ nerv.error("st should be a table")
+ end
+ for i = 1 - chunk_size, chunk_size * 2 do
+ if (st[i] == nil) then
+ st[i] = {}
+ end
+ st[i][p] = global_conf.cumat_type(batch_size, dim)
+ st[i][p]:fill(0)
+ if (st_c ~= nil) then
+ if (st_c[i] == nil) then
+ st_c[i] = {}
+ end
+ st_c[i][p_c] = st[i][p]
+ end
+ end
+end
+
+function DAGLayer:__init(id, global_conf, layer_conf)
+ local layers = {}
+ local inputs_p = {} --map:port of the TDAGLayer to layer ref and port
+ local outputs_p = {}
+ local dim_in = layer_conf.dim_in
+ local dim_out = layer_conf.dim_out
+ local parsed_conns = {}
+ local _
+
+ for _, ll in pairs(layer_conf.connections) do
+ local id_from, port_from = parse_id(ll[1])
+ local id_to, port_to = parse_id(ll[2])
+ local time_to = ll[3]
+
+ print(id_from, id_to, time_to)
+
+ local ref_from = discover(id_from, layers, layer_conf.sub_layers)
+ local ref_to = discover(id_to, layers, layer_conf.sub_layers)
+
+ if (id_from == "<input>") then
+ if (dim_in[port_from] ~= ref_to.dim_in[port_to] or time_to ~= 0) then
+ nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
+ end
+ inputs_p[port_from] = {["ref"] = ref_to, ["port"] = port_to}
+ ref_to.inputs_m[port_to] = {} --just a place holder
+ elseif (id_to == "<output>") then
+ if (dim_out[port_to] ~= ref_from.dim_out[port_from] or time_to ~= 0) then
+ nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
+ end
+ outputs_p[port_to] = {["ref"] = ref_from, ["port"] = port_from}
+ ref_from.outputs_m[port_from] = {} --just a place holder
+ else
+ conn_now = {
+ ["src"] = {["ref"] = ref_from, ["port"] = port_from},
+ ["dst"] = {["ref"] = ref_to, ["port"] = port_to},
+ ["time"] = time_to
+ }
+ if (ref_to.dim_in[port_to] ~= ref_from.dim_out[port_from]) then
+ nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
+ end
+ table.insert(parsed_conns, conn_now)
+ table.insert(ref_to.conns_i, conn_now)
+ table.insert(ref_from.conns_o, conn_now)
+ end
+ end
+
+ self.layers = layers
+ self.inputs_p = inputs_p
+ self.outputs_p = outputs_p
+ self.id = id
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+ self.parsed_conns = parsed_conns
+ self.gconf = global_conf
+end
+
+function DAGLayer:init(batch_size, chunk_size)
+ for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN
+ local _, output_dim
+ local ref_from, port_from, ref_to, port_to
+ ref_from, port_from = conn.src.ref, conn.src.port
+ ref_to, port_to = conn.dst.ref, conn.dst.port
+ local dim = ref_from.dim_out[port_from]
+ if (dim == 0) then
+ nerv.error("layer %s has a zero dim port", ref_from.layer.id)
+ end
+
+ print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id)
+ self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.inputs_m, port_to)
+ self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.err_outputs_m, port_to)
+
+ end
+
+ self.outputs_m = {}
+ self.err_inputs_m = {}
+ for i = 1, #self.dim_out do --Init storage for output ports
+ local ref = self.outputs_p[i].ref
+ local p = self.outputs_p[i].port
+ self.makeInitialStore(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.outputs_m, i)
+ self.makeInitialStore(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.err_inputs_m, i)
+ end
+
+ self.inputs_m = {}
+ self.err_outputs_m = {}
+ for i = 1, #self.dim_in do --Init storage for input ports
+ local ref = self.inputs_p[i].ref
+ local p = self.inputs_p[i].port
+ self.makeInitialStore(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.inputs_m, i)
+ self.makeInitialStore(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.err_outputs_m, i)
+ end
+
+ for id, ref in pairs(self.layers) do --Calling init for child layers
+ for i = 1, #ref.dim_in do
+ if (ref.inputs_m[i] == nil or ref.err_outputs_m[i] == nil) then
+ nerv.error("dangling input port %d of layer %s", i, id)
+ end
+ end
+ for i = 1, #ref.dim_out do
+ if (ref.outputs_m[i] == nil or ref.err_inputs_m[i] == nil) then
+ nerv.error("dangling output port %d of layer %s", i, id)
+ end
+ end
+ -- initialize sub layers
+ ref.layer:init(batch_size)
+ end
+
+ local flags_now = {}
+ for i = 1, chunk_size do
+ flags_now[i] = {}
+ end
+
+ self.feeds_now = {} --feeds is for the reader to fill
+ self.feeds_now.inputs_m = self.inputs_m
+ self.feeds_now.flags_now = flags_now
+end
+
+--[[
+function DAGLayer:batch_resize(batch_size)
+ self.gconf.batch_size = batch_size
+
+ for i, conn in ipairs(self.parsed_conn) do
+ local _, output_dim
+ local ref_from, port_from, ref_to, port_to
+ ref_from, port_from = unpack(conn[1])
+ ref_to, port_to = unpack(conn[2])
+ _, output_dim = ref_from.layer:get_dim()
+
+ if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then
+ local mid = self.gconf.cumat_type(batch_size, output_dim[port_from])
+ local err_mid = mid:create()
+
+ ref_from.outputs[port_from] = mid
+ ref_to.inputs[port_to] = mid
+
+ ref_from.err_inputs[port_from] = err_mid
+ ref_to.err_outputs[port_to] = err_mid
+ end
+ end
+ for id, ref in pairs(self.layers) do
+ ref.layer:batch_resize(batch_size)
+ end
+ collectgarbage("collect")
+end
+]]--
+
+--reader: some reader
+--Returns: bool, whether has new feed
+--Returns: feeds, a table that will be filled with the reader's feeds
+function DAGLayer:getFeedFromReader(reader)
+ local feeds = self.feeds_now
+ local got_new = reader:get_batch(feeds)
+ return got_new, feeds
+end
+
+function DAGLayer:update(bp_err, input, output)
+ self:set_err_inputs(bp_err)
+ self:set_inputs(input)
+ self:set_outputs(output)
+ -- print("update")
+ for id, ref in pairs(self.queue) do
+ -- print(ref.layer.id)
+ ref.layer:update(ref.err_inputs, ref.inputs, ref.outputs)
+ end
+end
+
+function DAGLayer:propagate(input, output)
+ self:set_inputs(input)
+ self:set_outputs(output)
+ local ret = false
+ for i = 1, #self.queue do
+ local ref = self.queue[i]
+ -- print(ref.layer.id)
+ ret = ref.layer:propagate(ref.inputs, ref.outputs)
+ end
+ return ret
+end
+
+function DAGLayer:back_propagate(bp_err, next_bp_err, input, output)
+ self:set_err_outputs(next_bp_err)
+ self:set_err_inputs(bp_err)
+ self:set_inputs(input)
+ self:set_outputs(output)
+ for i = #self.queue, 1, -1 do
+ local ref = self.queue[i]
+ -- print(ref.layer.id)
+ ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs)
+ end
+end
+
+--Return: nerv.ParamRepo
+function DAGLayer:get_params()
+ local param_repos = {}
+ for id, ref in pairs(self.queue) do
+ table.insert(param_repos, ref.layer:get_params())
+ end
+ return nerv.ParamRepo.merge(param_repos)
+end
+
+DAGLayer.PORT_TYPES = {
+ INPUT = {},
+ OUTPUT = {},
+ ERR_INPUT = {},
+ ERR_OUTPUT = {}
+}
+
+function DAGLayer:get_intermediate(id, port_type)
+ if id == "<input>" or id == "<output>" then
+ nerv.error("an actual real layer id is expected")
+ end
+ local layer = self.layers[id]
+ if layer == nil then
+ nerv.error("layer id %s not found", id)
+ end
+ if port_type == DAGLayer.PORT_TYPES.INPUT then
+ return layer.inputs
+ elseif port_type == DAGLayer.PORT_TYPES.OUTPUT then
+ return layer.outputs
+ elseif port_type == DAGLayer.PORT_TYPES.ERR_INPUT then
+ return layer.err_inputs
+ elseif port_type == DAGLayer.PORT_TYPES.ERR_OUTPUT then
+ return layer.err_outputs
+ end
+ nerv.error("unrecognized port type")
+end