diff options
Diffstat (limited to 'include/hotstuff')
-rw-r--r-- | include/hotstuff/consensus.h | 7 | ||||
-rw-r--r-- | include/hotstuff/crypto.h | 36 | ||||
-rw-r--r-- | include/hotstuff/entity.h | 20 | ||||
-rw-r--r-- | include/hotstuff/hotstuff.h | 4 | ||||
-rw-r--r-- | include/hotstuff/worker.h | 92 |
5 files changed, 151 insertions, 8 deletions
diff --git a/include/hotstuff/consensus.h b/include/hotstuff/consensus.h index 65ffff2..9e2558c 100644 --- a/include/hotstuff/consensus.h +++ b/include/hotstuff/consensus.h @@ -253,6 +253,13 @@ struct Vote: public Serializable { cert->get_blk_hash() == blk_hash; } + promise_t verify(VeriPool &vpool) const { + assert(hsc != nullptr); + return cert->verify(hsc->get_config().get_pubkey(voter), vpool).then([this](bool result) { + return result && cert->get_blk_hash() == blk_hash; + }); + } + operator std::string () const { DataStream s; s << "<vote " diff --git a/include/hotstuff/crypto.h b/include/hotstuff/crypto.h index 40c9140..b79c433 100644 --- a/include/hotstuff/crypto.h +++ b/include/hotstuff/crypto.h @@ -6,6 +6,7 @@ #include "secp256k1.h" #include "salticidae/crypto.h" #include "hotstuff/type.h" +#include "hotstuff/worker.h" namespace hotstuff { @@ -31,6 +32,7 @@ using privkey_bt = BoxObj<PrivKey>; class PartCert: public Serializable, public Cloneable { public: virtual ~PartCert() = default; + virtual promise_t verify(const PubKey &pubkey, VeriPool &vpool) = 0; virtual bool verify(const PubKey &pubkey) = 0; virtual const uint256_t &get_blk_hash() const = 0; virtual PartCert *clone() override = 0; @@ -43,6 +45,7 @@ class QuorumCert: public Serializable, public Cloneable { virtual ~QuorumCert() = default; virtual void add_part(ReplicaID replica, const PartCert &pc) = 0; virtual void compute() = 0; + virtual promise_t verify(const ReplicaConfig &config, VeriPool &vpool) = 0; virtual bool verify(const ReplicaConfig &config) = 0; virtual const uint256_t &get_blk_hash() const = 0; virtual QuorumCert *clone() override = 0; @@ -85,6 +88,9 @@ class PartCertDummy: public PartCert { } bool verify(const PubKey &) override { return true; } + promise_t verify(const PubKey &, VeriPool &) override { + return promise_t([](promise_t &pm){ pm.resolve(true); }); + } const uint256_t &get_blk_hash() const override { return blk_hash; } }; @@ -112,6 +118,9 @@ class QuorumCertDummy: public QuorumCert { void add_part(ReplicaID, const PartCert &) override {} void compute() override {} bool verify(const ReplicaConfig &) override { return true; } + promise_t verify(const ReplicaConfig &, VeriPool &) override { + return promise_t([](promise_t &pm) { pm.resolve(true); }); + } const uint256_t &get_blk_hash() const override { return blk_hash; } }; @@ -243,7 +252,7 @@ class SigSecp256k1: public Serializable { secp256k1_ecdsa_signature data; secp256k1_context_t ctx; - void check_msg_length(const bytearray_t &msg) { + static void check_msg_length(const bytearray_t &msg) { if (msg.size() != 32) throw std::invalid_argument("the message should be 32-bytes"); } @@ -291,7 +300,7 @@ class SigSecp256k1: public Serializable { } bool verify(const bytearray_t &msg, const PubKeySecp256k1 &pub_key, - const secp256k1_context_t &_ctx) { + const secp256k1_context_t &_ctx) const { check_msg_length(msg); return secp256k1_ecdsa_verify( _ctx->ctx, &data, @@ -304,6 +313,22 @@ class SigSecp256k1: public Serializable { } }; +class Secp256k1VeriTask: public VeriTask { + uint256_t msg; + PubKeySecp256k1 pubkey; + SigSecp256k1 sig; + public: + Secp256k1VeriTask(const uint256_t &msg, + const PubKeySecp256k1 &pubkey, + const SigSecp256k1 &sig): + msg(msg), pubkey(pubkey), sig(sig) {} + virtual ~Secp256k1VeriTask() = default; + + bool verify() override { + return sig.verify(msg, pubkey, secp256k1_default_verify_ctx); + } +}; + class PartCertSecp256k1: public SigSecp256k1, public PartCert { uint256_t blk_hash; @@ -320,6 +345,12 @@ class PartCertSecp256k1: public SigSecp256k1, public PartCert { secp256k1_default_verify_ctx); } + promise_t verify(const PubKey &pub_key, VeriPool &vpool) override { + return vpool.verify(new Secp256k1VeriTask(blk_hash, + static_cast<const PubKeySecp256k1 &>(pub_key), + static_cast<const SigSecp256k1 &>(*this))); + } + const uint256_t &get_blk_hash() const override { return blk_hash; } PartCertSecp256k1 *clone() override { @@ -357,6 +388,7 @@ class QuorumCertSecp256k1: public QuorumCert { void compute() override {} bool verify(const ReplicaConfig &config) override; + promise_t verify(const ReplicaConfig &config, VeriPool &vpool) override; const uint256_t &get_blk_hash() const override { return blk_hash; } diff --git a/include/hotstuff/entity.h b/include/hotstuff/entity.h index 6f73db8..6327dfe 100644 --- a/include/hotstuff/entity.h +++ b/include/hotstuff/entity.h @@ -179,6 +179,16 @@ class Block { return true; } + promise_t verify(const ReplicaConfig &config, VeriPool &vpool) const { + return (qc ? qc->verify(config, vpool) : + promise_t([](promise_t &pm) { pm.resolve(true); })).then([this](bool result) { + if (!result) return false; + for (auto cmd: cmds) + if (!cmd->verify()) return false; + return true; + }); + } + int8_t get_decision() const { return decision; } bool is_delivered() const { return delivered; } @@ -223,11 +233,11 @@ class EntityStorage { } block_t add_blk(Block &&_blk, const ReplicaConfig &config) { - if (!_blk.verify(config)) - { - HOTSTUFF_LOG_WARN("invalid %s", std::string(_blk).c_str()); - return nullptr; - } + //if (!_blk.verify(config)) + //{ + // HOTSTUFF_LOG_WARN("invalid %s", std::string(_blk).c_str()); + // return nullptr; + //} block_t blk = new Block(std::move(_blk)); return blk_cache.insert(std::make_pair(blk->get_hash(), blk)).first->second; } diff --git a/include/hotstuff/hotstuff.h b/include/hotstuff/hotstuff.h index f9aad3d..983a7b3 100644 --- a/include/hotstuff/hotstuff.h +++ b/include/hotstuff/hotstuff.h @@ -121,6 +121,7 @@ class HotStuffBase: public HotStuffCore { size_t blk_size; /** libevent handle */ EventContext eb; + VeriPool vpool; private: /** whether libevent handle is owned by itself */ @@ -183,7 +184,8 @@ class HotStuffBase: public HotStuffCore { privkey_bt &&priv_key, NetAddr listen_addr, pacemaker_bt pmaker, - EventContext eb); + EventContext eb, + size_t nworker = 4); ~HotStuffBase(); diff --git a/include/hotstuff/worker.h b/include/hotstuff/worker.h new file mode 100644 index 0000000..229b1bf --- /dev/null +++ b/include/hotstuff/worker.h @@ -0,0 +1,92 @@ +#ifndef _HOTSTUFF_WORKER_H +#define _HOTSTUFF_WORKER_H + +#include <thread> +#include <unordered_map> +#include <unistd.h> +#include "concurrentqueue/blockingconcurrentqueue.h" + +namespace hotstuff { + +class VeriTask { + friend class VeriPool; + bool result; + public: + virtual bool verify() = 0; + virtual ~VeriTask() = default; +}; + +using veritask_ut = BoxObj<VeriTask>; + +class VeriPool { + using queue_t = moodycamel::BlockingConcurrentQueue<VeriTask *>; + int fin_fd[2]; + Event fin_ev; + queue_t in_queue; + queue_t out_queue; + std::thread notifier; + std::vector<std::thread> workers; + std::unordered_map<VeriTask *, std::pair<veritask_ut, promise_t>> pms; + public: + VeriPool(EventContext ec, size_t nworker) { + pipe(fin_fd); + fin_ev = Event(ec, fin_fd[0], EV_READ, [&](int fd, short) { + VeriTask *task; + bool result; + read(fd, &task, sizeof(VeriTask *)); + read(fd, &result, sizeof(bool)); + auto it = pms.find(task); + it->second.second.resolve(result); + pms.erase(it); + fin_ev.add(); + }); + fin_ev.add(); + // finish notifier thread + notifier = std::thread([this]() { + while (true) + { + VeriTask *task; + out_queue.wait_dequeue(task); + write(fin_fd[1], &task, sizeof(VeriTask *)); + write(fin_fd[1], &(task->result), sizeof(bool)); + } + }); + for (size_t i = 0; i < nworker; i++) + { + workers.push_back(std::thread([this]() { + while (true) + { + VeriTask *task; + in_queue.wait_dequeue(task); + //fprintf(stderr, "%lu working on %u\n", std::this_thread::get_id(), (uintptr_t)task); + task->result = task->verify(); + out_queue.enqueue(task); + } + })); + } + } + + ~VeriPool() { + notifier.detach(); + for (auto &w: workers) w.detach(); + close(fin_fd[0]); + close(fin_fd[1]); + } + + promise_t verify(veritask_ut &&task) { + auto ptr = task.get(); + auto ret = pms.insert(std::make_pair(ptr, + std::make_pair(std::move(task), promise_t([](promise_t &){})))); + assert(ret.second); + in_queue.enqueue(ptr); + return ret.first->second.second; + } + + int get_fd() { + return fin_fd[0]; + } +}; + +} + +#endif |