diff options
Diffstat (limited to 'include/salticidae/network.h')
-rw-r--r-- | include/salticidae/network.h | 552 |
1 files changed, 552 insertions, 0 deletions
diff --git a/include/salticidae/network.h b/include/salticidae/network.h new file mode 100644 index 0000000..3b82927 --- /dev/null +++ b/include/salticidae/network.h @@ -0,0 +1,552 @@ +/** + * Copyright (c) 2018 Cornell University. + * + * Author: Ted Yin <[email protected]> + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + * of the Software, and to permit persons to whom the Software is furnished to do + * so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef _SALTICIDAE_NETWORK_H +#define _SALTICIDAE_NETWORK_H + +#include "salticidae/netaddr.h" +#include "salticidae/msg.h" +#include "salticidae/conn.h" + +namespace salticidae { + +/** Network of nodes who can send async messages. */ +template<typename MsgType> +class MsgNetwork: public ConnPool { + public: + class Conn: public ConnPool::Conn { + enum MsgState { + HEADER, + PAYLOAD + }; + MsgType msg; + MsgState msg_state; + MsgNetwork *mn; + + protected: + mutable size_t nsent; + mutable size_t nrecv; + + public: + friend MsgNetwork; + Conn(MsgNetwork *mn): msg_state(HEADER), mn(mn), nsent(0), nrecv(0) {} + size_t get_nsent() const { return nsent; } + size_t get_nrecv() const { return nrecv; } + void clear_nsent() const { nsent = 0; } + void clear_nrecv() const { nrecv = 0; } + + protected: + void on_read() override; + void on_setup() override {} + void on_teardown() override {} + }; + + using conn_t = RcObj<Conn>; + using msg_callback_t = std::function<void(const MsgType &msg, conn_t conn)>; + class msg_stat_by_opcode_t: + public std::unordered_map<typename MsgType::opcode_t, + std::pair<uint32_t, size_t>> { + public: + void add(const MsgType &msg) { + auto &p = this->operator[](msg.get_opcode()); + p.first++; + p.second += msg.get_length(); + } + }; + + private: + std::unordered_map<typename MsgType::opcode_t, + msg_callback_t> handler_map; + + protected: + mutable msg_stat_by_opcode_t sent_by_opcode; + mutable msg_stat_by_opcode_t recv_by_opcode; + uint16_t listen_port; + + ConnPool::conn_t create_conn() override { return (new Conn(this))->self(); } + + public: + MsgNetwork(struct event_base *eb): ConnPool(eb) {} + void reg_handler(typename MsgType::opcode_t opcode, msg_callback_t handler); + void send_msg(const MsgType &msg, conn_t conn); + void init(NetAddr listen_addr); + msg_stat_by_opcode_t &get_sent_by_opcode() const { + return sent_by_opcode; + } + msg_stat_by_opcode_t &get_recv_by_opcode() const { + return recv_by_opcode; + } +}; + +/** Simple network that handles client-server requests. */ +template<typename MsgType> +class ClientNetwork: public MsgNetwork<MsgType> { + using MsgNet = MsgNetwork<MsgType>; + std::unordered_map<NetAddr, typename MsgNet::conn_t> addr2conn; + + public: + class Conn: public MsgNet::Conn { + ClientNetwork *cn; + + public: + Conn(ClientNetwork *cn): + MsgNet::Conn(static_cast<MsgNet *>(cn)), + cn(cn) {} + + protected: + void on_setup() override; + void on_teardown() override; + }; + + protected: + ConnPool::conn_t create_conn() override { return (new Conn(this))->self(); } + + public: + ClientNetwork(struct event_base *eb): MsgNet(eb) {} + void send_msg(const MsgType &msg, const NetAddr &addr); +}; + +class PeerNetworkError: public SalticidaeError { + using SalticidaeError::SalticidaeError; +}; + +/** Peer-to-peer network where any two nodes could hold a bi-diretional message + * channel, established by either side. */ +template<typename MsgType> +class PeerNetwork: public MsgNetwork<MsgType> { + using MsgNet= MsgNetwork<MsgType>; + public: + enum IdentityMode { + IP_BASED, + IP_PORT_BASED + }; + + class Conn: public MsgNet::Conn { + NetAddr peer_id; + Event ev_timeout; + PeerNetwork *pn; + + public: + friend PeerNetwork; + const NetAddr &get_peer() { return peer_id; } + Conn(PeerNetwork *pn): + MsgNet::Conn(static_cast<MsgNet *>(pn)), + pn(pn) {} + + protected: + void close() override { + ev_timeout.clear(); + MsgNet::Conn::close(); + } + + void on_setup() override; + void on_teardown() override; + }; + + using conn_t = RcObj<Conn>; + + private: + struct Peer { + /** connection addr, may be different due to passive mode */ + NetAddr addr; + /** the underlying connection, may be invalid when connected = false */ + conn_t conn; + PeerNetwork *pn; + Event ev_ping_timer; + bool ping_timer_ok; + bool pong_msg_ok; + bool connected; + + Peer() = delete; + Peer(NetAddr addr, conn_t conn, PeerNetwork *pn, struct event_base *eb): + addr(addr), conn(conn), pn(pn), + ev_ping_timer( + Event(eb, -1, 0, std::bind(&Peer::ping_timer, this, _1, _2))), + connected(false) {} + ~Peer() {} + Peer &operator=(const Peer &) = delete; + Peer(const Peer &) = delete; + + void ping_timer(evutil_socket_t, short); + void reset_ping_timer(); + void send_ping(); + void clear_all_events() { + if (ev_ping_timer) + ev_ping_timer.del(); + } + void reset_conn(conn_t conn); + }; + + std::unordered_map <NetAddr, Peer *> id2peer; + std::vector<NetAddr> peer_list; + + IdentityMode id_mode; + double ping_period; + double conn_timeout; + + void msg_ping(const MsgType &msg, ConnPool::conn_t conn); + void msg_pong(const MsgType &msg, ConnPool::conn_t conn); + void reset_conn_timeout(conn_t conn); + bool check_new_conn(conn_t conn, uint16_t port); + void start_active_conn(const NetAddr &paddr); + + protected: + ConnPool::conn_t create_conn() override { return (new Conn(this))->self(); } + + public: + PeerNetwork(struct event_base *eb, + double ping_period = 30, + double conn_timeout = 180, + IdentityMode id_mode = IP_PORT_BASED): + MsgNet(eb), + id_mode(id_mode), + ping_period(ping_period), + conn_timeout(conn_timeout) {} + + void add_peer(const NetAddr &paddr); + const conn_t get_peer_conn(const NetAddr &paddr) const; + void send_msg(const MsgType &msg, const Peer *peer); + void send_msg(const MsgType &msg, const NetAddr &paddr); + void init(NetAddr listen_addr); + bool has_peer(const NetAddr &paddr) const; + const std::vector<NetAddr> &all_peers() const; + using ConnPool::create_conn; +}; + +template<typename MsgType> +void MsgNetwork<MsgType>::Conn::on_read() { + auto &recv_buffer = read(); + auto conn = static_pointer_cast<Conn>(self()); + while (get_fd() != -1) + { + if (msg_state == Conn::HEADER) + { + if (recv_buffer.size() < MsgType::header_size) break; + /* new header available */ + bytearray_t data = recv_buffer.pop(MsgType::header_size); + msg = MsgType(data.data()); + msg_state = Conn::PAYLOAD; + } + if (msg_state == Conn::PAYLOAD) + { + size_t len = msg.get_length(); + if (recv_buffer.size() < len) break; + /* new payload available */ + bytearray_t data = recv_buffer.pop(len); + msg.set_payload(std::move(data)); + msg_state = Conn::HEADER; + if (!msg.verify_checksum()) + { + SALTICIDAE_LOG_WARN("checksums do not match, dropping the message"); + return; + } + auto it = mn->handler_map.find(msg.get_opcode()); + if (it == mn->handler_map.end()) + SALTICIDAE_LOG_WARN("unknown command: %s", get_hex(msg.get_opcode())); + else /* call the handler */ + { + SALTICIDAE_LOG_DEBUG("got message %s from %s", + std::string(msg).c_str(), + std::string(*this).c_str()); + it->second(msg, conn); + nrecv++; + mn->recv_by_opcode.add(msg); + } + } + } +} + +template<typename MsgType> +void PeerNetwork<MsgType>::Peer::reset_conn(conn_t new_conn) { + if (conn != new_conn) + { + if (conn) + { + SALTICIDAE_LOG_DEBUG("moving send buffer"); + new_conn->move_send_buffer(conn); + SALTICIDAE_LOG_INFO("terminating old connection %s", std::string(*conn).c_str()); + conn->terminate(); + } + addr = new_conn->get_addr(); + conn = new_conn; + } + clear_all_events(); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::Conn::on_setup() { + assert(!ev_timeout); + ev_timeout = Event(pn->eb, -1, 0, [this](evutil_socket_t, short) { + SALTICIDAE_LOG_INFO("peer ping-pong timeout"); + this->terminate(); + }); + /* the initial ping-pong to set up the connection */ + MsgType ping; + ping.gen_ping(pn->listen_port); + auto conn = static_pointer_cast<Conn>(this->self()); + pn->reset_conn_timeout(conn); + pn->MsgNet::send_msg(ping, conn); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::Conn::on_teardown() { + auto it = pn->id2peer.find(peer_id); + if (it == pn->id2peer.end()) return; + Peer *p = it->second; + if (this != p->conn) return; + p->ev_ping_timer.del(); + p->connected = false; + p->conn = nullptr; + SALTICIDAE_LOG_INFO("connection lost %s for %s", + std::string(*this).c_str(), + std::string(peer_id).c_str()); + pn->start_active_conn(peer_id); +} + +template<typename MsgType> +bool PeerNetwork<MsgType>::check_new_conn(conn_t conn, uint16_t port) { + if (conn->peer_id.is_null()) + { /* passive connections can eventually have ids after getting the port + number in IP_BASED_PORT mode */ + assert(id_mode == IP_PORT_BASED); + conn->peer_id.ip = conn->get_addr().ip; + conn->peer_id.port = port; + } + Peer *p = id2peer.find(conn->peer_id)->second; + if (p->connected) + { + if (conn != p->conn) + { + conn->terminate(); + return true; + } + return false; + } + p->reset_conn(conn); + p->connected = true; + p->reset_ping_timer(); + p->send_ping(); + if (p->connected) + SALTICIDAE_LOG_INFO("PeerNetwork: established connection with id %s via %s", + std::string(conn->peer_id).c_str(), std::string(*conn).c_str()); + return false; +} + +template<typename MsgType> +void PeerNetwork<MsgType>::msg_ping(const MsgType &msg, ConnPool::conn_t conn_) { + auto conn = static_pointer_cast<Conn>(conn_); + uint16_t port; + msg.parse_ping(port); + SALTICIDAE_LOG_INFO("ping from %s, port %u", std::string(*conn).c_str(), ntohs(port)); + if (check_new_conn(conn, port)) return; + Peer *p = id2peer.find(conn->peer_id)->second; + MsgType pong; + pong.gen_pong(this->listen_port); + send_msg(pong, p); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::msg_pong(const MsgType &msg, ConnPool::conn_t conn_) { + auto conn = static_pointer_cast<Conn>(conn_); + auto it = id2peer.find(conn->peer_id); + if (it == id2peer.end()) + { + SALTICIDAE_LOG_WARN("pong message discarded"); + return; + } + Peer *p = it->second; + uint16_t port; + msg.parse_pong(port); + if (check_new_conn(conn, port)) return; + p->pong_msg_ok = true; + if (p->ping_timer_ok) + { + p->reset_ping_timer(); + p->send_ping(); + } +} + +template<typename MsgType> +void MsgNetwork<MsgType>::init(NetAddr listen_addr) { + listen_port = listen_addr.port; + ConnPool::init(listen_addr); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::init(NetAddr listen_addr) { + MsgNet::init(listen_addr); + this->reg_handler(MsgType::OPCODE_PING, + std::bind(&PeerNetwork::msg_ping, this, _1, _2)); + this->reg_handler(MsgType::OPCODE_PONG, + std::bind(&PeerNetwork::msg_pong, this, _1, _2)); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::start_active_conn(const NetAddr &addr) { + Peer *p = id2peer.find(addr)->second; + if (p->connected) return; + auto conn = static_pointer_cast<Conn>(create_conn(addr)); + assert(p->conn == nullptr); + p->conn = conn; + conn->peer_id = addr; + if (id_mode == IP_BASED) + conn->peer_id.port = 0; +} + +template<typename MsgType> +void PeerNetwork<MsgType>::add_peer(const NetAddr &addr) { + auto it = id2peer.find(addr); + if (it != id2peer.end()) + throw PeerNetworkError("peer already exists"); + id2peer.insert(std::make_pair(addr, new Peer(addr, nullptr, this, this->eb))); + peer_list.push_back(addr); + start_active_conn(addr); +} + +template<typename MsgType> +const typename PeerNetwork<MsgType>::conn_t +PeerNetwork<MsgType>::get_peer_conn(const NetAddr &paddr) const { + auto it = id2peer.find(paddr); + if (it == id2peer.end()) + throw PeerNetworkError("peer does not exist"); + return it->second->conn; +} + +template<typename MsgType> +bool PeerNetwork<MsgType>::has_peer(const NetAddr &paddr) const { + return id2peer.count(paddr); +} + +template<typename MsgType> +void MsgNetwork<MsgType>::reg_handler(typename MsgType::opcode_t opcode, + msg_callback_t handler) { + handler_map[opcode] = handler; +} + +template<typename MsgType> +void MsgNetwork<MsgType>::send_msg(const MsgType &msg, conn_t conn) { + bytearray_t msg_data = msg.serialize(); + SALTICIDAE_LOG_DEBUG("wrote message %s to %s", + std::string(msg).c_str(), + std::string(*conn).c_str()); + conn->write(std::move(msg_data)); + conn->nsent++; + sent_by_opcode.add(msg); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::send_msg(const MsgType &msg, const Peer *peer) { + bytearray_t msg_data = msg.serialize(); + SALTICIDAE_LOG_DEBUG("wrote message %s to %s", + std::string(msg).c_str(), + std::string(peer->addr).c_str()); + if (peer->connected) + { + SALTICIDAE_LOG_DEBUG("wrote to ConnPool"); + peer->conn->write(std::move(msg_data)); + } + else + { + SALTICIDAE_LOG_DEBUG("dropped"); + } + peer->conn->nsent++; + this->sent_by_opcode.add(msg); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::send_msg(const MsgType &msg, const NetAddr &addr) { + auto it = id2peer.find(addr); + if (it == id2peer.end()) + { + SALTICIDAE_LOG_ERROR("sending to non-existing peer: %s", std::string(addr).c_str()); + throw PeerNetworkError("peer does not exist"); + } + send_msg(msg, it->second); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::Peer::reset_ping_timer() { + assert(ev_ping_timer); + ev_ping_timer.del(); + ev_ping_timer.add_with_timeout(gen_rand_timeout(pn->ping_period)); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::reset_conn_timeout(conn_t conn) { + assert(conn->ev_timeout); + conn->ev_timeout.del(); + conn->ev_timeout.add_with_timeout(conn_timeout); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::Peer::send_ping() { + ping_timer_ok = false; + pong_msg_ok = false; + MsgType ping; + ping.gen_ping(pn->listen_port); + pn->reset_conn_timeout(conn); + pn->send_msg(ping, this); +} + +template<typename MsgType> +void PeerNetwork<MsgType>::Peer::ping_timer(evutil_socket_t, short) { + ping_timer_ok = true; + if (pong_msg_ok) + { + reset_ping_timer(); + send_ping(); + } +} + +template<typename MsgType> +const std::vector<NetAddr> &PeerNetwork<MsgType>::all_peers() const { + return peer_list; +} + +template<typename MsgType> +void ClientNetwork<MsgType>::Conn::on_setup() { + assert(this->get_mode() == Conn::PASSIVE); + const auto &addr = this->get_addr(); + cn->addr2conn.erase(addr); + cn->addr2conn.insert( + std::make_pair(addr, + static_pointer_cast<Conn>(this->self()))); +} + +template<typename MsgType> +void ClientNetwork<MsgType>::Conn::on_teardown() { + assert(this->get_mode() == Conn::PASSIVE); + cn->addr2conn.erase(this->get_addr()); +} + +template<typename MsgType> +void ClientNetwork<MsgType>::send_msg(const MsgType &msg, const NetAddr &addr) { + auto it = addr2conn.find(addr); + if (it == addr2conn.end()) return; + MsgNet::send_msg(msg, it->second); +} + +} + +#endif |