Skip to content

Commit

Permalink
Track initialization metrics separately (#102)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #102

Many components in the MPC engine have a one-time initialization. Depending on how large of a sample size we use in our benchmarks, this can take up a significant amount of the overall runtime or network traffic. For this reason, I'm thinking it'd be useful to separate out the portion that's just for the initialization.

Reviewed By: adshastri

Differential Revision: D34906916

fbshipit-source-id: 9e02a44e8c99242801895ef11e95299129e443cb
  • Loading branch information
Elliott Lawrence authored and facebook-github-bot committed Mar 17, 2022
1 parent 096c62a commit 5692b16
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,30 @@ class SinglePointCotBenchmark final : public util::NetworkedBenchmark {
agent0_ = std::move(agent0);
agent1_ = std::move(agent1);

SinglePointCotFactory factory;
sender_ = factory.create(agent0_);
receiver_ = factory.create(agent1_);

auto baseOtSize = std::log2(kExtendedSize / kWeight);
auto [baseOTSend, baseOTReceive, delta] = getBaseOT(baseOtSize);
baseOTSend_ = std::move(baseOTSend);
baseOTReceive_ = std::move(baseOTReceive);

sender_->senderInit(delta);
receiver_->receiverInit();
delta_ = delta;
}

protected:
void initSender() override {
SinglePointCotFactory factory;
sender_ = factory.create(agent0_);
sender_->senderInit(delta_);
}

void runSender() override {
sender_->senderExtend(std::move(baseOTSend_));
}

void initReceiver() override {
SinglePointCotFactory factory;
receiver_ = factory.create(agent1_);
receiver_->receiverInit();
}

void runReceiver() override {
receiver_->receiverExtend(std::move(baseOTReceive_));
}
Expand All @@ -90,6 +96,7 @@ class SinglePointCotBenchmark final : public util::NetworkedBenchmark {

std::vector<__m128i> baseOTSend_;
std::vector<__m128i> baseOTReceive_;
__m128i delta_;
};

class RegularErrorMultiPointCotBenchmark final
Expand All @@ -100,26 +107,31 @@ class RegularErrorMultiPointCotBenchmark final
agent0_ = std::move(agent0);
agent1_ = std::move(agent1);

RegularErrorMultiPointCotFactory factory(
factory_ = std::make_unique<RegularErrorMultiPointCotFactory>(
std::make_unique<SinglePointCotFactory>());

sender_ = factory.create(agent0_);
receiver_ = factory.create(agent1_);

auto baseOtSize = std::log2(kExtendedSize / kWeight) * kWeight;
auto [baseOTSend, baseOTReceive, delta] = getBaseOT(baseOtSize);
baseOTSend_ = std::move(baseOTSend);
baseOTReceive_ = std::move(baseOTReceive);

sender_->senderInit(delta, kExtendedSize, kWeight);
receiver_->receiverInit(kExtendedSize, kWeight);
delta_ = delta;
}

protected:
void initSender() override {
sender_ = factory_->create(agent0_);
sender_->senderInit(delta_, kExtendedSize, kWeight);
}

void runSender() override {
sender_->senderExtend(std::move(baseOTSend_));
}

void initReceiver() override {
receiver_ = factory_->create(agent1_);
receiver_->receiverInit(kExtendedSize, kWeight);
}

void runReceiver() override {
receiver_->receiverExtend(std::move(baseOTReceive_));
}
Expand All @@ -129,6 +141,8 @@ class RegularErrorMultiPointCotBenchmark final
}

private:
std::unique_ptr<RegularErrorMultiPointCotFactory> factory_;

std::unique_ptr<communication::IPartyCommunicationAgent> agent0_;
std::unique_ptr<communication::IPartyCommunicationAgent> agent1_;

Expand All @@ -137,38 +151,45 @@ class RegularErrorMultiPointCotBenchmark final

std::vector<__m128i> baseOTSend_;
std::vector<__m128i> baseOTReceive_;
__m128i delta_;
};

class RcotExtenderBenchmark final : public util::NetworkedBenchmark {
public:
void setup() override {
auto [agent0, agent1] = util::getSocketAgents();
agent0_ = std::move(agent0);
agent1_ = std::move(agent1);

RcotExtenderFactory factory(
factory_ = std::make_unique<RcotExtenderFactory>(
std::make_unique<TenLocalLinearMatrixMultiplierFactory>(),
std::make_unique<RegularErrorMultiPointCotFactory>(
std::make_unique<SinglePointCotFactory>()));

sender_ = factory.create();
receiver_ = factory.create();

sender_->setCommunicationAgent(std::move(agent0));
receiver_->setCommunicationAgent(std::move(agent1));

auto baseOtSize = kBaseSize + std::log2(kExtendedSize / kWeight) * kWeight;
auto [baseOTSend, baseOTReceive, delta] = getBaseOT(baseOtSize);
baseOTSend_ = std::move(baseOTSend);
baseOTReceive_ = std::move(baseOTReceive);

sender_->senderInit(delta, kExtendedSize, kBaseSize, kWeight);
receiver_->receiverInit(kExtendedSize, kBaseSize, kWeight);
delta_ = delta;
}

protected:
void initSender() override {
sender_ = factory_->create();
sender_->setCommunicationAgent(std::move(agent0_));
sender_->senderInit(delta_, kExtendedSize, kBaseSize, kWeight);
}

void runSender() override {
sender_->senderExtendRcot(std::move(baseOTSend_));
}

void initReceiver() override {
receiver_ = factory_->create();
receiver_->setCommunicationAgent(std::move(agent1_));
receiver_->receiverInit(kExtendedSize, kBaseSize, kWeight);
}

void runReceiver() override {
receiver_->receiverExtendRcot(std::move(baseOTReceive_));
}
Expand All @@ -178,11 +199,17 @@ class RcotExtenderBenchmark final : public util::NetworkedBenchmark {
}

private:
std::unique_ptr<communication::IPartyCommunicationAgent> agent0_;
std::unique_ptr<communication::IPartyCommunicationAgent> agent1_;

std::unique_ptr<RcotExtenderFactory> factory_;

std::unique_ptr<IRcotExtender> sender_;
std::unique_ptr<IRcotExtender> receiver_;

std::vector<__m128i> baseOTSend_;
std::vector<__m128i> baseOTReceive_;
__m128i delta_;
};

} // namespace fbpcf::engine::tuple_generator::oblivious_transfer::ferret
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,27 @@ class NpBaseObliviousTransferBenchmark : public util::NetworkedBenchmark {
public:
void setup() override {
auto [agent0, agent1] = util::getSocketAgents();

NpBaseObliviousTransferFactory factory;
sender_ = factory.create(std::move(agent0));
receiver_ = factory.create(std::move(agent1));
agent0_ = std::move(agent0);
agent1_ = std::move(agent1);

choice_ = util::getRandomBoolVector(size_);
}

protected:
void initSender() override {
NpBaseObliviousTransferFactory factory;
sender_ = factory.create(std::move(agent0_));
}

void runSender() override {
sender_->send(size_);
}

void initReceiver() override {
NpBaseObliviousTransferFactory factory;
receiver_ = factory.create(std::move(agent1_));
}

void runReceiver() override {
receiver_->receive(choice_);
}
Expand All @@ -55,6 +64,9 @@ class NpBaseObliviousTransferBenchmark : public util::NetworkedBenchmark {
private:
size_t size_ = 1024;

std::unique_ptr<communication::IPartyCommunicationAgent> agent0_;
std::unique_ptr<communication::IPartyCommunicationAgent> agent1_;

std::unique_ptr<IBaseObliviousTransfer> sender_;
std::unique_ptr<IBaseObliviousTransfer> receiver_;

Expand All @@ -81,21 +93,27 @@ class RandomCorrelatedObliviousTransferBenchmark
util::setLsbTo1(delta_);
}

void runSender() override {
protected:
void initSender() override {
sender_ = factory_->create(delta_, std::move(agent0_));
}

void runSender() override {
sender_->rcot(size_);
}

void runReceiver() override {
void initReceiver() override {
receiver_ = factory_->create(std::move(agent1_));
}

void runReceiver() override {
receiver_->rcot(size_);
}

std::pair<uint64_t, uint64_t> getTrafficStatistics() override {
return sender_->getTrafficStatistics();
}

protected:
std::unique_ptr<IRandomCorrelatedObliviousTransferFactory> factory_;

private:
Expand Down Expand Up @@ -192,7 +210,6 @@ BENCHMARK_COUNTERS(
}

class BidirectionObliviousTransferBenchmark : public util::NetworkedBenchmark {
protected:
public:
void setup() override {
auto [agentFactory0, agentFactory1] = util::getSocketAgentFactories();
Expand All @@ -215,21 +232,27 @@ class BidirectionObliviousTransferBenchmark : public util::NetworkedBenchmark {
receiverChoice_ = util::getRandomBoolVector(size_);
}

void runSender() override {
protected:
void initSender() override {
sender_ = senderFactory_->create(1);
}

void runSender() override {
sender_->biDirectionOT(senderInput0_, senderInput1_, senderChoice_);
}

void runReceiver() override {
void initReceiver() override {
receiver_ = receiverFactory_->create(0);
}

void runReceiver() override {
receiver_->biDirectionOT(receiverInput0_, receiverInput1_, receiverChoice_);
}

std::pair<uint64_t, uint64_t> getTrafficStatistics() override {
return sender_->getTrafficStatistics();
}

protected:
virtual std::unique_ptr<IRandomCorrelatedObliviousTransferFactory>
getRcotFactory() = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,20 @@ class ProductShareGeneratorBenchmark : public util::NetworkedBenchmark {
receiverRight_ = util::getRandomBoolVector(size_);
}

void runSender() override {
protected:
void initSender() override {
sender_ = senderFactory_->create(1);
}

void runSender() override {
sender_->generateBooleanProductShares(senderLeft_, senderRight_);
}

void runReceiver() override {
void initReceiver() override {
receiver_ = receiverFactory_->create(0);
}

void runReceiver() override {
receiver_->generateBooleanProductShares(receiverLeft_, receiverRight_);
}

Expand Down Expand Up @@ -98,21 +105,27 @@ class BaseTupleGeneratorBenchmark : public util::NetworkedBenchmark {
receiverFactory_ = getTupleGeneratorFactory(1, *agentFactory1_);
}

void runSender() override {
protected:
void initSender() override {
sender_ = senderFactory_->create();
}

void runSender() override {
sender_->getBooleanTuple(size_);
}

void runReceiver() override {
void initReceiver() override {
receiver_ = receiverFactory_->create();
}

void runReceiver() override {
receiver_->getBooleanTuple(size_);
}

std::pair<uint64_t, uint64_t> getTrafficStatistics() override {
return sender_->getTrafficStatistics();
}

protected:
virtual std::unique_ptr<ITupleGeneratorFactory> getTupleGeneratorFactory(
int myId,
communication::IPartyCommunicationAgentFactory& agentFactory) = 0;
Expand Down
25 changes: 24 additions & 1 deletion fbpcf/engine/util/test/benchmarks/NetworkedBenchmark.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#pragma once

#include <chrono>
#include <future>

#include <folly/Benchmark.h>
Expand All @@ -22,8 +23,26 @@ class NetworkedBenchmark {
virtual ~NetworkedBenchmark() = default;

void runBenchmark(folly::UserCounters& counters) {
uint64_t initTransmittedBytes;
BENCHMARK_SUSPEND {
setup();

auto start = std::chrono::high_resolution_clock::now();

auto initSenderTask = std::async([this]() { initSender(); });
auto initReceiverTask = std::async([this]() { initReceiver(); });

initSenderTask.get();
initReceiverTask.get();

auto end = std::chrono::high_resolution_clock::now();
counters["init_time_usec"] =
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
.count();

auto [sent, received] = getTrafficStatistics();
initTransmittedBytes = sent + received;
counters["init_transmitted_bytes"] = initTransmittedBytes;
}

auto senderTask = std::async([this]() { runSender(); });
Expand All @@ -34,12 +53,16 @@ class NetworkedBenchmark {

BENCHMARK_SUSPEND {
auto [sent, received] = getTrafficStatistics();
counters["transmitted_bytes"] = sent + received;
counters["transmitted_bytes"] = sent + received - initTransmittedBytes;
}
}

protected:
virtual void setup() = 0;

virtual void initSender() = 0;
virtual void initReceiver() = 0;

virtual void runSender() = 0;
virtual void runReceiver() = 0;

Expand Down

0 comments on commit 5692b16

Please sign in to comment.