aboutsummaryrefslogtreecommitdiff
path: root/include/salticidae/network.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/salticidae/network.h')
-rw-r--r--include/salticidae/network.h552
1 files changed, 552 insertions, 0 deletions
diff --git a/include/salticidae/network.h b/include/salticidae/network.h
new file mode 100644
index 0000000..3b82927
--- /dev/null
+++ b/include/salticidae/network.h
@@ -0,0 +1,552 @@
+/**
+ * Copyright (c) 2018 Cornell University.
+ *
+ * Author: Ted Yin <[email protected]>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of
+ * this software and associated documentation files (the "Software"), to deal in
+ * the Software without restriction, including without limitation the rights to
+ * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+ * of the Software, and to permit persons to whom the Software is furnished to do
+ * so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef _SALTICIDAE_NETWORK_H
+#define _SALTICIDAE_NETWORK_H
+
+#include "salticidae/netaddr.h"
+#include "salticidae/msg.h"
+#include "salticidae/conn.h"
+
+namespace salticidae {
+
+/** Network of nodes who can send async messages. */
+template<typename MsgType>
+class MsgNetwork: public ConnPool {
+ public:
+ class Conn: public ConnPool::Conn {
+ enum MsgState {
+ HEADER,
+ PAYLOAD
+ };
+ MsgType msg;
+ MsgState msg_state;
+ MsgNetwork *mn;
+
+ protected:
+ mutable size_t nsent;
+ mutable size_t nrecv;
+
+ public:
+ friend MsgNetwork;
+ Conn(MsgNetwork *mn): msg_state(HEADER), mn(mn), nsent(0), nrecv(0) {}
+ size_t get_nsent() const { return nsent; }
+ size_t get_nrecv() const { return nrecv; }
+ void clear_nsent() const { nsent = 0; }
+ void clear_nrecv() const { nrecv = 0; }
+
+ protected:
+ void on_read() override;
+ void on_setup() override {}
+ void on_teardown() override {}
+ };
+
+ using conn_t = RcObj<Conn>;
+ using msg_callback_t = std::function<void(const MsgType &msg, conn_t conn)>;
+ class msg_stat_by_opcode_t:
+ public std::unordered_map<typename MsgType::opcode_t,
+ std::pair<uint32_t, size_t>> {
+ public:
+ void add(const MsgType &msg) {
+ auto &p = this->operator[](msg.get_opcode());
+ p.first++;
+ p.second += msg.get_length();
+ }
+ };
+
+ private:
+ std::unordered_map<typename MsgType::opcode_t,
+ msg_callback_t> handler_map;
+
+ protected:
+ mutable msg_stat_by_opcode_t sent_by_opcode;
+ mutable msg_stat_by_opcode_t recv_by_opcode;
+ uint16_t listen_port;
+
+ ConnPool::conn_t create_conn() override { return (new Conn(this))->self(); }
+
+ public:
+ MsgNetwork(struct event_base *eb): ConnPool(eb) {}
+ void reg_handler(typename MsgType::opcode_t opcode, msg_callback_t handler);
+ void send_msg(const MsgType &msg, conn_t conn);
+ void init(NetAddr listen_addr);
+ msg_stat_by_opcode_t &get_sent_by_opcode() const {
+ return sent_by_opcode;
+ }
+ msg_stat_by_opcode_t &get_recv_by_opcode() const {
+ return recv_by_opcode;
+ }
+};
+
+/** Simple network that handles client-server requests. */
+template<typename MsgType>
+class ClientNetwork: public MsgNetwork<MsgType> {
+ using MsgNet = MsgNetwork<MsgType>;
+ std::unordered_map<NetAddr, typename MsgNet::conn_t> addr2conn;
+
+ public:
+ class Conn: public MsgNet::Conn {
+ ClientNetwork *cn;
+
+ public:
+ Conn(ClientNetwork *cn):
+ MsgNet::Conn(static_cast<MsgNet *>(cn)),
+ cn(cn) {}
+
+ protected:
+ void on_setup() override;
+ void on_teardown() override;
+ };
+
+ protected:
+ ConnPool::conn_t create_conn() override { return (new Conn(this))->self(); }
+
+ public:
+ ClientNetwork(struct event_base *eb): MsgNet(eb) {}
+ void send_msg(const MsgType &msg, const NetAddr &addr);
+};
+
+class PeerNetworkError: public SalticidaeError {
+ using SalticidaeError::SalticidaeError;
+};
+
+/** Peer-to-peer network where any two nodes could hold a bi-diretional message
+ * channel, established by either side. */
+template<typename MsgType>
+class PeerNetwork: public MsgNetwork<MsgType> {
+ using MsgNet= MsgNetwork<MsgType>;
+ public:
+ enum IdentityMode {
+ IP_BASED,
+ IP_PORT_BASED
+ };
+
+ class Conn: public MsgNet::Conn {
+ NetAddr peer_id;
+ Event ev_timeout;
+ PeerNetwork *pn;
+
+ public:
+ friend PeerNetwork;
+ const NetAddr &get_peer() { return peer_id; }
+ Conn(PeerNetwork *pn):
+ MsgNet::Conn(static_cast<MsgNet *>(pn)),
+ pn(pn) {}
+
+ protected:
+ void close() override {
+ ev_timeout.clear();
+ MsgNet::Conn::close();
+ }
+
+ void on_setup() override;
+ void on_teardown() override;
+ };
+
+ using conn_t = RcObj<Conn>;
+
+ private:
+ struct Peer {
+ /** connection addr, may be different due to passive mode */
+ NetAddr addr;
+ /** the underlying connection, may be invalid when connected = false */
+ conn_t conn;
+ PeerNetwork *pn;
+ Event ev_ping_timer;
+ bool ping_timer_ok;
+ bool pong_msg_ok;
+ bool connected;
+
+ Peer() = delete;
+ Peer(NetAddr addr, conn_t conn, PeerNetwork *pn, struct event_base *eb):
+ addr(addr), conn(conn), pn(pn),
+ ev_ping_timer(
+ Event(eb, -1, 0, std::bind(&Peer::ping_timer, this, _1, _2))),
+ connected(false) {}
+ ~Peer() {}
+ Peer &operator=(const Peer &) = delete;
+ Peer(const Peer &) = delete;
+
+ void ping_timer(evutil_socket_t, short);
+ void reset_ping_timer();
+ void send_ping();
+ void clear_all_events() {
+ if (ev_ping_timer)
+ ev_ping_timer.del();
+ }
+ void reset_conn(conn_t conn);
+ };
+
+ std::unordered_map <NetAddr, Peer *> id2peer;
+ std::vector<NetAddr> peer_list;
+
+ IdentityMode id_mode;
+ double ping_period;
+ double conn_timeout;
+
+ void msg_ping(const MsgType &msg, ConnPool::conn_t conn);
+ void msg_pong(const MsgType &msg, ConnPool::conn_t conn);
+ void reset_conn_timeout(conn_t conn);
+ bool check_new_conn(conn_t conn, uint16_t port);
+ void start_active_conn(const NetAddr &paddr);
+
+ protected:
+ ConnPool::conn_t create_conn() override { return (new Conn(this))->self(); }
+
+ public:
+ PeerNetwork(struct event_base *eb,
+ double ping_period = 30,
+ double conn_timeout = 180,
+ IdentityMode id_mode = IP_PORT_BASED):
+ MsgNet(eb),
+ id_mode(id_mode),
+ ping_period(ping_period),
+ conn_timeout(conn_timeout) {}
+
+ void add_peer(const NetAddr &paddr);
+ const conn_t get_peer_conn(const NetAddr &paddr) const;
+ void send_msg(const MsgType &msg, const Peer *peer);
+ void send_msg(const MsgType &msg, const NetAddr &paddr);
+ void init(NetAddr listen_addr);
+ bool has_peer(const NetAddr &paddr) const;
+ const std::vector<NetAddr> &all_peers() const;
+ using ConnPool::create_conn;
+};
+
+template<typename MsgType>
+void MsgNetwork<MsgType>::Conn::on_read() {
+ auto &recv_buffer = read();
+ auto conn = static_pointer_cast<Conn>(self());
+ while (get_fd() != -1)
+ {
+ if (msg_state == Conn::HEADER)
+ {
+ if (recv_buffer.size() < MsgType::header_size) break;
+ /* new header available */
+ bytearray_t data = recv_buffer.pop(MsgType::header_size);
+ msg = MsgType(data.data());
+ msg_state = Conn::PAYLOAD;
+ }
+ if (msg_state == Conn::PAYLOAD)
+ {
+ size_t len = msg.get_length();
+ if (recv_buffer.size() < len) break;
+ /* new payload available */
+ bytearray_t data = recv_buffer.pop(len);
+ msg.set_payload(std::move(data));
+ msg_state = Conn::HEADER;
+ if (!msg.verify_checksum())
+ {
+ SALTICIDAE_LOG_WARN("checksums do not match, dropping the message");
+ return;
+ }
+ auto it = mn->handler_map.find(msg.get_opcode());
+ if (it == mn->handler_map.end())
+ SALTICIDAE_LOG_WARN("unknown command: %s", get_hex(msg.get_opcode()));
+ else /* call the handler */
+ {
+ SALTICIDAE_LOG_DEBUG("got message %s from %s",
+ std::string(msg).c_str(),
+ std::string(*this).c_str());
+ it->second(msg, conn);
+ nrecv++;
+ mn->recv_by_opcode.add(msg);
+ }
+ }
+ }
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::Peer::reset_conn(conn_t new_conn) {
+ if (conn != new_conn)
+ {
+ if (conn)
+ {
+ SALTICIDAE_LOG_DEBUG("moving send buffer");
+ new_conn->move_send_buffer(conn);
+ SALTICIDAE_LOG_INFO("terminating old connection %s", std::string(*conn).c_str());
+ conn->terminate();
+ }
+ addr = new_conn->get_addr();
+ conn = new_conn;
+ }
+ clear_all_events();
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::Conn::on_setup() {
+ assert(!ev_timeout);
+ ev_timeout = Event(pn->eb, -1, 0, [this](evutil_socket_t, short) {
+ SALTICIDAE_LOG_INFO("peer ping-pong timeout");
+ this->terminate();
+ });
+ /* the initial ping-pong to set up the connection */
+ MsgType ping;
+ ping.gen_ping(pn->listen_port);
+ auto conn = static_pointer_cast<Conn>(this->self());
+ pn->reset_conn_timeout(conn);
+ pn->MsgNet::send_msg(ping, conn);
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::Conn::on_teardown() {
+ auto it = pn->id2peer.find(peer_id);
+ if (it == pn->id2peer.end()) return;
+ Peer *p = it->second;
+ if (this != p->conn) return;
+ p->ev_ping_timer.del();
+ p->connected = false;
+ p->conn = nullptr;
+ SALTICIDAE_LOG_INFO("connection lost %s for %s",
+ std::string(*this).c_str(),
+ std::string(peer_id).c_str());
+ pn->start_active_conn(peer_id);
+}
+
+template<typename MsgType>
+bool PeerNetwork<MsgType>::check_new_conn(conn_t conn, uint16_t port) {
+ if (conn->peer_id.is_null())
+ { /* passive connections can eventually have ids after getting the port
+ number in IP_BASED_PORT mode */
+ assert(id_mode == IP_PORT_BASED);
+ conn->peer_id.ip = conn->get_addr().ip;
+ conn->peer_id.port = port;
+ }
+ Peer *p = id2peer.find(conn->peer_id)->second;
+ if (p->connected)
+ {
+ if (conn != p->conn)
+ {
+ conn->terminate();
+ return true;
+ }
+ return false;
+ }
+ p->reset_conn(conn);
+ p->connected = true;
+ p->reset_ping_timer();
+ p->send_ping();
+ if (p->connected)
+ SALTICIDAE_LOG_INFO("PeerNetwork: established connection with id %s via %s",
+ std::string(conn->peer_id).c_str(), std::string(*conn).c_str());
+ return false;
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::msg_ping(const MsgType &msg, ConnPool::conn_t conn_) {
+ auto conn = static_pointer_cast<Conn>(conn_);
+ uint16_t port;
+ msg.parse_ping(port);
+ SALTICIDAE_LOG_INFO("ping from %s, port %u", std::string(*conn).c_str(), ntohs(port));
+ if (check_new_conn(conn, port)) return;
+ Peer *p = id2peer.find(conn->peer_id)->second;
+ MsgType pong;
+ pong.gen_pong(this->listen_port);
+ send_msg(pong, p);
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::msg_pong(const MsgType &msg, ConnPool::conn_t conn_) {
+ auto conn = static_pointer_cast<Conn>(conn_);
+ auto it = id2peer.find(conn->peer_id);
+ if (it == id2peer.end())
+ {
+ SALTICIDAE_LOG_WARN("pong message discarded");
+ return;
+ }
+ Peer *p = it->second;
+ uint16_t port;
+ msg.parse_pong(port);
+ if (check_new_conn(conn, port)) return;
+ p->pong_msg_ok = true;
+ if (p->ping_timer_ok)
+ {
+ p->reset_ping_timer();
+ p->send_ping();
+ }
+}
+
+template<typename MsgType>
+void MsgNetwork<MsgType>::init(NetAddr listen_addr) {
+ listen_port = listen_addr.port;
+ ConnPool::init(listen_addr);
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::init(NetAddr listen_addr) {
+ MsgNet::init(listen_addr);
+ this->reg_handler(MsgType::OPCODE_PING,
+ std::bind(&PeerNetwork::msg_ping, this, _1, _2));
+ this->reg_handler(MsgType::OPCODE_PONG,
+ std::bind(&PeerNetwork::msg_pong, this, _1, _2));
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::start_active_conn(const NetAddr &addr) {
+ Peer *p = id2peer.find(addr)->second;
+ if (p->connected) return;
+ auto conn = static_pointer_cast<Conn>(create_conn(addr));
+ assert(p->conn == nullptr);
+ p->conn = conn;
+ conn->peer_id = addr;
+ if (id_mode == IP_BASED)
+ conn->peer_id.port = 0;
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::add_peer(const NetAddr &addr) {
+ auto it = id2peer.find(addr);
+ if (it != id2peer.end())
+ throw PeerNetworkError("peer already exists");
+ id2peer.insert(std::make_pair(addr, new Peer(addr, nullptr, this, this->eb)));
+ peer_list.push_back(addr);
+ start_active_conn(addr);
+}
+
+template<typename MsgType>
+const typename PeerNetwork<MsgType>::conn_t
+PeerNetwork<MsgType>::get_peer_conn(const NetAddr &paddr) const {
+ auto it = id2peer.find(paddr);
+ if (it == id2peer.end())
+ throw PeerNetworkError("peer does not exist");
+ return it->second->conn;
+}
+
+template<typename MsgType>
+bool PeerNetwork<MsgType>::has_peer(const NetAddr &paddr) const {
+ return id2peer.count(paddr);
+}
+
+template<typename MsgType>
+void MsgNetwork<MsgType>::reg_handler(typename MsgType::opcode_t opcode,
+ msg_callback_t handler) {
+ handler_map[opcode] = handler;
+}
+
+template<typename MsgType>
+void MsgNetwork<MsgType>::send_msg(const MsgType &msg, conn_t conn) {
+ bytearray_t msg_data = msg.serialize();
+ SALTICIDAE_LOG_DEBUG("wrote message %s to %s",
+ std::string(msg).c_str(),
+ std::string(*conn).c_str());
+ conn->write(std::move(msg_data));
+ conn->nsent++;
+ sent_by_opcode.add(msg);
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::send_msg(const MsgType &msg, const Peer *peer) {
+ bytearray_t msg_data = msg.serialize();
+ SALTICIDAE_LOG_DEBUG("wrote message %s to %s",
+ std::string(msg).c_str(),
+ std::string(peer->addr).c_str());
+ if (peer->connected)
+ {
+ SALTICIDAE_LOG_DEBUG("wrote to ConnPool");
+ peer->conn->write(std::move(msg_data));
+ }
+ else
+ {
+ SALTICIDAE_LOG_DEBUG("dropped");
+ }
+ peer->conn->nsent++;
+ this->sent_by_opcode.add(msg);
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::send_msg(const MsgType &msg, const NetAddr &addr) {
+ auto it = id2peer.find(addr);
+ if (it == id2peer.end())
+ {
+ SALTICIDAE_LOG_ERROR("sending to non-existing peer: %s", std::string(addr).c_str());
+ throw PeerNetworkError("peer does not exist");
+ }
+ send_msg(msg, it->second);
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::Peer::reset_ping_timer() {
+ assert(ev_ping_timer);
+ ev_ping_timer.del();
+ ev_ping_timer.add_with_timeout(gen_rand_timeout(pn->ping_period));
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::reset_conn_timeout(conn_t conn) {
+ assert(conn->ev_timeout);
+ conn->ev_timeout.del();
+ conn->ev_timeout.add_with_timeout(conn_timeout);
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::Peer::send_ping() {
+ ping_timer_ok = false;
+ pong_msg_ok = false;
+ MsgType ping;
+ ping.gen_ping(pn->listen_port);
+ pn->reset_conn_timeout(conn);
+ pn->send_msg(ping, this);
+}
+
+template<typename MsgType>
+void PeerNetwork<MsgType>::Peer::ping_timer(evutil_socket_t, short) {
+ ping_timer_ok = true;
+ if (pong_msg_ok)
+ {
+ reset_ping_timer();
+ send_ping();
+ }
+}
+
+template<typename MsgType>
+const std::vector<NetAddr> &PeerNetwork<MsgType>::all_peers() const {
+ return peer_list;
+}
+
+template<typename MsgType>
+void ClientNetwork<MsgType>::Conn::on_setup() {
+ assert(this->get_mode() == Conn::PASSIVE);
+ const auto &addr = this->get_addr();
+ cn->addr2conn.erase(addr);
+ cn->addr2conn.insert(
+ std::make_pair(addr,
+ static_pointer_cast<Conn>(this->self())));
+}
+
+template<typename MsgType>
+void ClientNetwork<MsgType>::Conn::on_teardown() {
+ assert(this->get_mode() == Conn::PASSIVE);
+ cn->addr2conn.erase(this->get_addr());
+}
+
+template<typename MsgType>
+void ClientNetwork<MsgType>::send_msg(const MsgType &msg, const NetAddr &addr) {
+ auto it = addr2conn.find(addr);
+ if (it == addr2conn.end()) return;
+ MsgNet::send_msg(msg, it->second);
+}
+
+}
+
+#endif