diff options
-rw-r--r-- | include/salticidae/conn.h | 1 | ||||
-rw-r--r-- | include/salticidae/msg.h | 9 | ||||
-rw-r--r-- | include/salticidae/network.h | 68 | ||||
-rw-r--r-- | test/test_msg.cpp | 40 |
4 files changed, 59 insertions, 59 deletions
diff --git a/include/salticidae/conn.h b/include/salticidae/conn.h index a254505..3742975 100644 --- a/include/salticidae/conn.h +++ b/include/salticidae/conn.h @@ -191,6 +191,7 @@ class ConnPool { int get_fd() const { return fd; } const NetAddr &get_addr() const { return addr; } ConnMode get_mode() const { return mode; } + ConnPool *get_pool() const { return cpool; } SegBuffer &read() { return recv_buffer; } /** Set the buffer size used for send/receive data. */ void set_seg_buff_size(size_t size) { seg_buff_size = size; } diff --git a/include/salticidae/msg.h b/include/salticidae/msg.h index 8a63a50..3a1eebf 100644 --- a/include/salticidae/msg.h +++ b/include/salticidae/msg.h @@ -56,6 +56,15 @@ class MsgBase { public: MsgBase(): magic(0x0), no_payload(true) {} + template<typename MsgType, + typename = typename std::enable_if< + !std::is_same<MsgType, MsgBase>::value && + !std::is_same<MsgType, uint8_t *>::value>::type> + MsgBase(const MsgType &msg): magic(0x0) { + set_opcode(MsgType::opcode); + set_payload(std::move(msg.serialized)); + } + MsgBase(const MsgBase &other): magic(other.magic), opcode(other.opcode), diff --git a/include/salticidae/network.h b/include/salticidae/network.h index 1e0f560..0d754d6 100644 --- a/include/salticidae/network.h +++ b/include/salticidae/network.h @@ -71,6 +71,7 @@ class MsgNetwork: public ConnPool { public callback_traits<ReturnType(Args...)> {}; class Conn: public ConnPool::Conn { + friend MsgNetwork; enum MsgState { HEADER, PAYLOAD @@ -78,7 +79,9 @@ class MsgNetwork: public ConnPool { Msg msg; MsgState msg_state; - MsgNetwork *mn; + MsgNetwork *get_net() { + return static_cast<MsgNetwork *>(get_pool()); + } protected: #ifdef SALTICIDAE_MSG_STAT @@ -87,9 +90,7 @@ class MsgNetwork: public ConnPool { #endif public: - friend MsgNetwork; - Conn(MsgNetwork *mn): - msg_state(HEADER), mn(mn) + Conn(): msg_state(HEADER) #ifdef SALTICIDAE_MSG_STAT , nsent(0), nrecv(0) #endif @@ -109,7 +110,6 @@ class MsgNetwork: public ConnPool { }; using conn_t = RcObj<Conn>; - using msg_callback_t = std::function<void(const Msg &msg, conn_t conn)>; #ifdef SALTICIDAE_MSG_STAT class msg_stat_by_opcode_t: public std::unordered_map<typename Msg::opcode_t, @@ -124,8 +124,9 @@ class MsgNetwork: public ConnPool { #endif private: - std::unordered_map<typename Msg::opcode_t, - msg_callback_t> handler_map; + std::unordered_map< + typename Msg::opcode_t, + std::function<void(const Msg &msg, conn_t conn)>> handler_map; protected: #ifdef SALTICIDAE_MSG_STAT @@ -133,10 +134,9 @@ class MsgNetwork: public ConnPool { mutable msg_stat_by_opcode_t recv_by_opcode; #endif - ConnPool::conn_t create_conn() override { return (new Conn(this))->self(); } + ConnPool::conn_t create_conn() override { return (new Conn())->self(); } public: - MsgNetwork(const EventContext &eb, int max_listen_backlog, double conn_server_timeout, @@ -181,12 +181,13 @@ class ClientNetwork: public MsgNetwork<OpcodeType> { public: class Conn: public MsgNet::Conn { - ClientNetwork *cn; + friend ClientNetwork; + ClientNetwork *get_net() { + return static_cast<ClientNetwork *>(ConnPool::Conn::get_pool()); + } public: - Conn(ClientNetwork *cn): - MsgNet::Conn(static_cast<MsgNet *>(cn)), - cn(cn) {} + Conn() = default; protected: void on_setup() override; @@ -196,7 +197,7 @@ class ClientNetwork: public MsgNetwork<OpcodeType> { using conn_t = RcObj<Conn>; protected: - ConnPool::conn_t create_conn() override { return (new Conn(this))->self(); } + ConnPool::conn_t create_conn() override { return (new Conn())->self(); } public: ClientNetwork(const EventContext &eb, @@ -231,16 +232,16 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { }; class Conn: public MsgNet::Conn { + friend PeerNetwork; NetAddr peer_id; Event ev_timeout; - PeerNetwork *pn; + PeerNetwork *get_net() { + return static_cast<PeerNetwork *>(ConnPool::Conn::get_pool()); + } public: - friend PeerNetwork; + Conn() = default; const NetAddr &get_peer() { return peer_id; } - Conn(PeerNetwork *pn): - MsgNet::Conn(static_cast<MsgNet *>(pn)), - pn(pn) {} protected: void on_close() override { @@ -260,7 +261,6 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { NetAddr addr; /** the underlying connection, may be invalid when connected = false */ conn_t conn; - PeerNetwork *pn; Event ev_ping_timer; Event ev_retry_timer; bool ping_timer_ok; @@ -268,8 +268,8 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { bool connected; Peer() = delete; - Peer(NetAddr addr, conn_t conn, PeerNetwork *pn, const EventContext &eb): - addr(addr), conn(conn), pn(pn), + Peer(NetAddr addr, conn_t conn, const EventContext &eb): + addr(addr), conn(conn), ev_ping_timer( Event(eb, -1, 0, std::bind(&Peer::ping_timer, this, _1, _2))), connected(false) {} @@ -329,7 +329,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { void start_active_conn(const NetAddr &paddr); protected: - ConnPool::conn_t create_conn() override { return (new Conn(this))->self(); } + ConnPool::conn_t create_conn() override { return (new Conn())->self(); } virtual double gen_conn_timeout() { return gen_rand_timeout(retry_conn_delay); } @@ -368,7 +368,7 @@ class PeerNetwork: public MsgNetwork<OpcodeType> { template<typename OpcodeType> void MsgNetwork<OpcodeType>::Conn::on_read() { auto &recv_buffer = read(); - auto conn = static_pointer_cast<Conn>(self()); + auto mn = get_net(); while (get_fd() != -1) { if (msg_state == Conn::HEADER) @@ -401,7 +401,7 @@ void MsgNetwork<OpcodeType>::Conn::on_read() { SALTICIDAE_LOG_DEBUG("got message %s from %s", std::string(msg).c_str(), std::string(*this).c_str()); - it->second(msg, conn); + it->second(msg, static_pointer_cast<Conn>(self())); #ifdef SALTICIDAE_MSG_STAT nrecv++; mn->recv_by_opcode.add(msg); @@ -430,6 +430,7 @@ void PeerNetwork<O, _, __>::Peer::reset_conn(conn_t new_conn) { template<typename O, O _, O __> void PeerNetwork<O, _, __>::Conn::on_setup() { + auto pn = get_net(); assert(!ev_timeout); ev_timeout = Event(pn->eb, -1, 0, [this](evutil_socket_t, short) { SALTICIDAE_LOG_INFO("peer ping-pong timeout"); @@ -443,6 +444,7 @@ void PeerNetwork<O, _, __>::Conn::on_setup() { template<typename O, O _, O __> void PeerNetwork<O, _, __>::Conn::on_teardown() { + auto pn = get_net(); auto it = pn->id2peer.find(peer_id); if (it == pn->id2peer.end()) return; auto p = it->second.get(); @@ -454,8 +456,7 @@ void PeerNetwork<O, _, __>::Conn::on_teardown() { std::string(*this).c_str(), std::string(peer_id).c_str()); p->ev_retry_timer = Event(pn->eb, -1, 0, - [pn = this->pn, - peer_id = this->peer_id](evutil_socket_t, short) { + [pn, peer_id = this->peer_id](evutil_socket_t, short) { pn->start_active_conn(peer_id); }); p->ev_retry_timer.add_with_timeout(pn->gen_conn_timeout()); @@ -545,7 +546,7 @@ void PeerNetwork<O, _, __>::add_peer(const NetAddr &addr) { auto it = id2peer.find(addr); if (it != id2peer.end()) throw PeerNetworkError("peer already exists"); - id2peer.insert(std::make_pair(addr, new Peer(addr, nullptr, this, this->eb))); + id2peer.insert(std::make_pair(addr, new Peer(addr, nullptr, this->eb))); peer_list.push_back(addr); start_active_conn(addr); } @@ -567,9 +568,7 @@ bool PeerNetwork<O, _, __>::has_peer(const NetAddr &paddr) const { template<typename OpcodeType> template<typename MsgType> void MsgNetwork<OpcodeType>::send_msg(const MsgType &_msg, conn_t conn) { - Msg msg; - msg.set_opcode(MsgType::opcode); - msg.set_payload(std::move(_msg.serialized)); + Msg msg(_msg); bytearray_t msg_data = msg.serialize(); SALTICIDAE_LOG_DEBUG("wrote message %s to %s", std::string(msg).c_str(), @@ -607,7 +606,8 @@ template<typename O, O _, O __> void PeerNetwork<O, _, __>::Peer::reset_ping_timer() { assert(ev_ping_timer); ev_ping_timer.del(); - ev_ping_timer.add_with_timeout(gen_rand_timeout(pn->ping_period)); + ev_ping_timer.add_with_timeout( + gen_rand_timeout(conn->get_net()->ping_period)); } template<typename O, O _, O __> @@ -620,6 +620,7 @@ void PeerNetwork<O, _, __>::reset_conn_timeout(conn_t conn) { template<typename O, O _, O __> void PeerNetwork<O, _, __>::Peer::send_ping() { + auto pn = conn->get_net(); ping_timer_ok = false; pong_msg_ok = false; pn->reset_conn_timeout(conn); @@ -645,6 +646,7 @@ template<typename OpcodeType> void ClientNetwork<OpcodeType>::Conn::on_setup() { assert(this->get_mode() == Conn::PASSIVE); const auto &addr = this->get_addr(); + auto cn = get_net(); cn->addr2conn.erase(addr); cn->addr2conn.insert( std::make_pair(addr, @@ -654,7 +656,7 @@ void ClientNetwork<OpcodeType>::Conn::on_setup() { template<typename OpcodeType> void ClientNetwork<OpcodeType>::Conn::on_teardown() { assert(this->get_mode() == Conn::PASSIVE); - cn->addr2conn.erase(this->get_addr()); + get_net()->addr2conn.erase(this->get_addr()); } template<typename OpcodeType> diff --git a/test/test_msg.cpp b/test/test_msg.cpp index 08be1ef..e4f35fa 100644 --- a/test/test_msg.cpp +++ b/test/test_msg.cpp @@ -29,27 +29,28 @@ using salticidae::uint256_t; using salticidae::DataStream; using salticidae::get_hash; using salticidae::get_hex; -/* -struct MsgTest: public salticidae::MsgBase<> { - using MsgBase::MsgBase; +using salticidae::htole; +using salticidae::letoh; - void gen_testhashes(int cnt) { - DataStream s; - set_opcode(0x01); - s << (uint32_t)cnt; +using opcode_t = uint8_t; + +struct MsgTest { + static const opcode_t opcode = 0x0; + DataStream serialized; + MsgTest(int cnt) { + serialized << htole((uint32_t)cnt); for (int i = 0; i < cnt; i++) { uint256_t hash = get_hash(i); printf("adding hash %s\n", get_hex(hash).c_str()); - s << hash; + serialized << hash; } - set_payload(std::move(s)); } - void parse_testhashes() { - DataStream s(get_payload()); + MsgTest(DataStream &&s) { uint32_t cnt; s >> cnt; + cnt = letoh(cnt); printf("got %d hashes\n", cnt); for (int i = 0; i < cnt; i++) { @@ -59,23 +60,10 @@ struct MsgTest: public salticidae::MsgBase<> { } } }; -*/ int main() { - /* - MsgTest msg; - msg.gen_ping(1234); - printf("%s\n", std::string(msg).c_str()); - msg.gen_testhashes(5); + salticidae::MsgBase<opcode_t> msg(MsgTest(10)); printf("%s\n", std::string(msg).c_str()); - msg.parse_testhashes(); - try - { - msg.parse_testhashes(); - } catch (std::runtime_error &e) { - printf("caught: %s\n", e.what()); - } - */ - salticidae::PeerNetwork<> pn(salticidae::EventContext()); + MsgTest parse(msg.get_payload()); return 0; } |