diff --git a/fbpcf/engine/tuple_generator/oblivious_transfer/ferret/test/benchmarks/CotBenchmark.h b/fbpcf/engine/tuple_generator/oblivious_transfer/ferret/test/benchmarks/CotBenchmark.h index 1f0609c3..b325d1a5 100644 --- a/fbpcf/engine/tuple_generator/oblivious_transfer/ferret/test/benchmarks/CotBenchmark.h +++ b/fbpcf/engine/tuple_generator/oblivious_transfer/ferret/test/benchmarks/CotBenchmark.h @@ -55,24 +55,30 @@ class SinglePointCotBenchmark final : public util::NetworkedBenchmark { agent0_ = std::move(agent0); agent1_ = std::move(agent1); - SinglePointCotFactory factory; - sender_ = factory.create(agent0_); - receiver_ = factory.create(agent1_); - auto baseOtSize = std::log2(kExtendedSize / kWeight); auto [baseOTSend, baseOTReceive, delta] = getBaseOT(baseOtSize); baseOTSend_ = std::move(baseOTSend); baseOTReceive_ = std::move(baseOTReceive); - - sender_->senderInit(delta); - receiver_->receiverInit(); + delta_ = delta; } protected: + void initSender() override { + SinglePointCotFactory factory; + sender_ = factory.create(agent0_); + sender_->senderInit(delta_); + } + void runSender() override { sender_->senderExtend(std::move(baseOTSend_)); } + void initReceiver() override { + SinglePointCotFactory factory; + receiver_ = factory.create(agent1_); + receiver_->receiverInit(); + } + void runReceiver() override { receiver_->receiverExtend(std::move(baseOTReceive_)); } @@ -90,6 +96,7 @@ class SinglePointCotBenchmark final : public util::NetworkedBenchmark { std::vector<__m128i> baseOTSend_; std::vector<__m128i> baseOTReceive_; + __m128i delta_; }; class RegularErrorMultiPointCotBenchmark final @@ -100,26 +107,31 @@ class RegularErrorMultiPointCotBenchmark final agent0_ = std::move(agent0); agent1_ = std::move(agent1); - RegularErrorMultiPointCotFactory factory( + factory_ = std::make_unique( std::make_unique()); - sender_ = factory.create(agent0_); - receiver_ = factory.create(agent1_); - auto baseOtSize = std::log2(kExtendedSize / kWeight) * kWeight; auto [baseOTSend, baseOTReceive, delta] = getBaseOT(baseOtSize); baseOTSend_ = std::move(baseOTSend); baseOTReceive_ = std::move(baseOTReceive); - - sender_->senderInit(delta, kExtendedSize, kWeight); - receiver_->receiverInit(kExtendedSize, kWeight); + delta_ = delta; } protected: + void initSender() override { + sender_ = factory_->create(agent0_); + sender_->senderInit(delta_, kExtendedSize, kWeight); + } + void runSender() override { sender_->senderExtend(std::move(baseOTSend_)); } + void initReceiver() override { + receiver_ = factory_->create(agent1_); + receiver_->receiverInit(kExtendedSize, kWeight); + } + void runReceiver() override { receiver_->receiverExtend(std::move(baseOTReceive_)); } @@ -129,6 +141,8 @@ class RegularErrorMultiPointCotBenchmark final } private: + std::unique_ptr factory_; + std::unique_ptr agent0_; std::unique_ptr agent1_; @@ -137,38 +151,45 @@ class RegularErrorMultiPointCotBenchmark final std::vector<__m128i> baseOTSend_; std::vector<__m128i> baseOTReceive_; + __m128i delta_; }; class RcotExtenderBenchmark final : public util::NetworkedBenchmark { public: void setup() override { auto [agent0, agent1] = util::getSocketAgents(); + agent0_ = std::move(agent0); + agent1_ = std::move(agent1); - RcotExtenderFactory factory( + factory_ = std::make_unique( std::make_unique(), std::make_unique( std::make_unique())); - sender_ = factory.create(); - receiver_ = factory.create(); - - sender_->setCommunicationAgent(std::move(agent0)); - receiver_->setCommunicationAgent(std::move(agent1)); - auto baseOtSize = kBaseSize + std::log2(kExtendedSize / kWeight) * kWeight; auto [baseOTSend, baseOTReceive, delta] = getBaseOT(baseOtSize); baseOTSend_ = std::move(baseOTSend); baseOTReceive_ = std::move(baseOTReceive); - - sender_->senderInit(delta, kExtendedSize, kBaseSize, kWeight); - receiver_->receiverInit(kExtendedSize, kBaseSize, kWeight); + delta_ = delta; } protected: + void initSender() override { + sender_ = factory_->create(); + sender_->setCommunicationAgent(std::move(agent0_)); + sender_->senderInit(delta_, kExtendedSize, kBaseSize, kWeight); + } + void runSender() override { sender_->senderExtendRcot(std::move(baseOTSend_)); } + void initReceiver() override { + receiver_ = factory_->create(); + receiver_->setCommunicationAgent(std::move(agent1_)); + receiver_->receiverInit(kExtendedSize, kBaseSize, kWeight); + } + void runReceiver() override { receiver_->receiverExtendRcot(std::move(baseOTReceive_)); } @@ -178,11 +199,17 @@ class RcotExtenderBenchmark final : public util::NetworkedBenchmark { } private: + std::unique_ptr agent0_; + std::unique_ptr agent1_; + + std::unique_ptr factory_; + std::unique_ptr sender_; std::unique_ptr receiver_; std::vector<__m128i> baseOTSend_; std::vector<__m128i> baseOTReceive_; + __m128i delta_; }; } // namespace fbpcf::engine::tuple_generator::oblivious_transfer::ferret diff --git a/fbpcf/engine/tuple_generator/oblivious_transfer/test/benchmarks/OtBenchmark.cpp b/fbpcf/engine/tuple_generator/oblivious_transfer/test/benchmarks/OtBenchmark.cpp index 3f751e51..e05ea772 100644 --- a/fbpcf/engine/tuple_generator/oblivious_transfer/test/benchmarks/OtBenchmark.cpp +++ b/fbpcf/engine/tuple_generator/oblivious_transfer/test/benchmarks/OtBenchmark.cpp @@ -32,18 +32,27 @@ class NpBaseObliviousTransferBenchmark : public util::NetworkedBenchmark { public: void setup() override { auto [agent0, agent1] = util::getSocketAgents(); - - NpBaseObliviousTransferFactory factory; - sender_ = factory.create(std::move(agent0)); - receiver_ = factory.create(std::move(agent1)); + agent0_ = std::move(agent0); + agent1_ = std::move(agent1); choice_ = util::getRandomBoolVector(size_); } + protected: + void initSender() override { + NpBaseObliviousTransferFactory factory; + sender_ = factory.create(std::move(agent0_)); + } + void runSender() override { sender_->send(size_); } + void initReceiver() override { + NpBaseObliviousTransferFactory factory; + receiver_ = factory.create(std::move(agent1_)); + } + void runReceiver() override { receiver_->receive(choice_); } @@ -55,6 +64,9 @@ class NpBaseObliviousTransferBenchmark : public util::NetworkedBenchmark { private: size_t size_ = 1024; + std::unique_ptr agent0_; + std::unique_ptr agent1_; + std::unique_ptr sender_; std::unique_ptr receiver_; @@ -81,13 +93,20 @@ class RandomCorrelatedObliviousTransferBenchmark util::setLsbTo1(delta_); } - void runSender() override { + protected: + void initSender() override { sender_ = factory_->create(delta_, std::move(agent0_)); + } + + void runSender() override { sender_->rcot(size_); } - void runReceiver() override { + void initReceiver() override { receiver_ = factory_->create(std::move(agent1_)); + } + + void runReceiver() override { receiver_->rcot(size_); } @@ -95,7 +114,6 @@ class RandomCorrelatedObliviousTransferBenchmark return sender_->getTrafficStatistics(); } - protected: std::unique_ptr factory_; private: @@ -192,7 +210,6 @@ BENCHMARK_COUNTERS( } class BidirectionObliviousTransferBenchmark : public util::NetworkedBenchmark { - protected: public: void setup() override { auto [agentFactory0, agentFactory1] = util::getSocketAgentFactories(); @@ -215,13 +232,20 @@ class BidirectionObliviousTransferBenchmark : public util::NetworkedBenchmark { receiverChoice_ = util::getRandomBoolVector(size_); } - void runSender() override { + protected: + void initSender() override { sender_ = senderFactory_->create(1); + } + + void runSender() override { sender_->biDirectionOT(senderInput0_, senderInput1_, senderChoice_); } - void runReceiver() override { + void initReceiver() override { receiver_ = receiverFactory_->create(0); + } + + void runReceiver() override { receiver_->biDirectionOT(receiverInput0_, receiverInput1_, receiverChoice_); } @@ -229,7 +253,6 @@ class BidirectionObliviousTransferBenchmark : public util::NetworkedBenchmark { return sender_->getTrafficStatistics(); } - protected: virtual std::unique_ptr getRcotFactory() = 0; diff --git a/fbpcf/engine/tuple_generator/test/benchmarks/TupleGeneratorBenchmark.cpp b/fbpcf/engine/tuple_generator/test/benchmarks/TupleGeneratorBenchmark.cpp index c35a7fed..0108307c 100644 --- a/fbpcf/engine/tuple_generator/test/benchmarks/TupleGeneratorBenchmark.cpp +++ b/fbpcf/engine/tuple_generator/test/benchmarks/TupleGeneratorBenchmark.cpp @@ -47,13 +47,20 @@ class ProductShareGeneratorBenchmark : public util::NetworkedBenchmark { receiverRight_ = util::getRandomBoolVector(size_); } - void runSender() override { + protected: + void initSender() override { sender_ = senderFactory_->create(1); + } + + void runSender() override { sender_->generateBooleanProductShares(senderLeft_, senderRight_); } - void runReceiver() override { + void initReceiver() override { receiver_ = receiverFactory_->create(0); + } + + void runReceiver() override { receiver_->generateBooleanProductShares(receiverLeft_, receiverRight_); } @@ -98,13 +105,20 @@ class BaseTupleGeneratorBenchmark : public util::NetworkedBenchmark { receiverFactory_ = getTupleGeneratorFactory(1, *agentFactory1_); } - void runSender() override { + protected: + void initSender() override { sender_ = senderFactory_->create(); + } + + void runSender() override { sender_->getBooleanTuple(size_); } - void runReceiver() override { + void initReceiver() override { receiver_ = receiverFactory_->create(); + } + + void runReceiver() override { receiver_->getBooleanTuple(size_); } @@ -112,7 +126,6 @@ class BaseTupleGeneratorBenchmark : public util::NetworkedBenchmark { return sender_->getTrafficStatistics(); } - protected: virtual std::unique_ptr getTupleGeneratorFactory( int myId, communication::IPartyCommunicationAgentFactory& agentFactory) = 0; diff --git a/fbpcf/engine/util/test/benchmarks/NetworkedBenchmark.h b/fbpcf/engine/util/test/benchmarks/NetworkedBenchmark.h index 68cf867b..6a87b92a 100644 --- a/fbpcf/engine/util/test/benchmarks/NetworkedBenchmark.h +++ b/fbpcf/engine/util/test/benchmarks/NetworkedBenchmark.h @@ -7,6 +7,7 @@ #pragma once +#include #include #include @@ -22,8 +23,26 @@ class NetworkedBenchmark { virtual ~NetworkedBenchmark() = default; void runBenchmark(folly::UserCounters& counters) { + uint64_t initTransmittedBytes; BENCHMARK_SUSPEND { setup(); + + auto start = std::chrono::high_resolution_clock::now(); + + auto initSenderTask = std::async([this]() { initSender(); }); + auto initReceiverTask = std::async([this]() { initReceiver(); }); + + initSenderTask.get(); + initReceiverTask.get(); + + auto end = std::chrono::high_resolution_clock::now(); + counters["init_time_usec"] = + std::chrono::duration_cast(end - start) + .count(); + + auto [sent, received] = getTrafficStatistics(); + initTransmittedBytes = sent + received; + counters["init_transmitted_bytes"] = initTransmittedBytes; } auto senderTask = std::async([this]() { runSender(); }); @@ -34,12 +53,16 @@ class NetworkedBenchmark { BENCHMARK_SUSPEND { auto [sent, received] = getTrafficStatistics(); - counters["transmitted_bytes"] = sent + received; + counters["transmitted_bytes"] = sent + received - initTransmittedBytes; } } protected: virtual void setup() = 0; + + virtual void initSender() = 0; + virtual void initReceiver() = 0; + virtual void runSender() = 0; virtual void runReceiver() = 0;