Skip to content

Commit

Permalink
refactor: share code between provers (#4655)
Browse files Browse the repository at this point in the history
A lot of code is repeated between the Decider prover, folding prover,
and ultra honk prover. This PR aims to reduce duplication by creating a
OinkProver, which supports the 5 round functions before sumcheck. 
The OinkProver is used in the folding and ultra honk provers.

It does not address the shared code between verifiers or the shared code
between prover and verifier. It also is an initial step at a round
abstraction, where each round is implemented as a separate class and the
data being used/modified in each round is clearly defined.

Resolves #795.
  • Loading branch information
lucasxia01 authored and AztecBot committed Mar 13, 2024
1 parent 3094154 commit d2620cf
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ BB_PROFILE static void test_round_inner(State& state, GoblinUltraProver& prover,
}
};

time_if_index(PREAMBLE, [&] { prover.execute_preamble_round(); });
time_if_index(WIRE_COMMITMENTS, [&] { prover.execute_wire_commitments_round(); });
time_if_index(SORTED_LIST_ACCUMULATOR, [&] { prover.execute_sorted_list_accumulator_round(); });
time_if_index(LOG_DERIVATIVE_INVERSE, [&] { prover.execute_log_derivative_inverse_round(); });
time_if_index(GRAND_PRODUCT_COMPUTATION, [&] { prover.execute_grand_product_computation_round(); });
time_if_index(PREAMBLE, [&] { prover.oink_prover.execute_preamble_round(); });
time_if_index(WIRE_COMMITMENTS, [&] { prover.oink_prover.execute_wire_commitments_round(); });
time_if_index(SORTED_LIST_ACCUMULATOR, [&] { prover.oink_prover.execute_sorted_list_accumulator_round(); });
time_if_index(LOG_DERIVATIVE_INVERSE, [&] { prover.oink_prover.execute_log_derivative_inverse_round(); });
time_if_index(GRAND_PRODUCT_COMPUTATION, [&] { prover.oink_prover.execute_grand_product_computation_round(); });
time_if_index(RELATION_CHECK, [&] { prover.execute_relation_check_rounds(); });
time_if_index(ZEROMORPH, [&] { prover.execute_zeromorph_rounds(); });
}
Expand All @@ -62,7 +62,10 @@ BB_PROFILE static void test_round(State& state, size_t index) noexcept
auto prover = bb::mock_proofs::get_prover<GoblinUltraProver>(
&bb::mock_proofs::generate_basic_arithmetic_circuit<GoblinUltraCircuitBuilder>, log2_num_gates);
for (auto _ : state) {
state.PauseTiming();
test_round_inner(state, prover, index);
state.ResumeTiming();
// NOTE: google bench is very finnicky, must end in ResumeTiming() for correctness
}
}
#define ROUND_BENCHMARK(round) \
Expand Down
90 changes: 13 additions & 77 deletions cpp/src/barretenberg/protogalaxy/protogalaxy_prover.cpp
Original file line number Diff line number Diff line change
@@ -1,93 +1,29 @@
#include "protogalaxy_prover.hpp"
#include "barretenberg/flavor/flavor.hpp"
#include "barretenberg/ultra_honk/oink_prover.hpp"
namespace bb {
template <class ProverInstances>
void ProtoGalaxyProver_<ProverInstances>::finalise_and_send_instance(std::shared_ptr<Instance> instance,
const std::string& domain_separator)
{
instance->initialize_prover_polynomials();
OinkProver<Flavor> oink_prover(instance, commitment_key, transcript, domain_separator + '_');

const auto instance_size = static_cast<uint32_t>(instance->proving_key->circuit_size);
const auto num_public_inputs = static_cast<uint32_t>(instance->proving_key->num_public_inputs);
transcript->send_to_verifier(domain_separator + "_instance_size", instance_size);
transcript->send_to_verifier(domain_separator + "_public_input_size", num_public_inputs);
// Add circuit size public input size and public inputs to transcript
oink_prover.execute_preamble_round();

for (size_t i = 0; i < instance->proving_key->public_inputs.size(); ++i) {
auto public_input_i = instance->proving_key->public_inputs[i];
transcript->send_to_verifier(domain_separator + "_public_input_" + std::to_string(i), public_input_i);
}
transcript->send_to_verifier(domain_separator + "_pub_inputs_offset",
static_cast<uint32_t>(instance->proving_key->pub_inputs_offset));

auto& witness_commitments = instance->witness_commitments;

// Commit to the first three wire polynomials of the instance
// We only commit to the fourth wire polynomial after adding memory recordss
witness_commitments.w_l = commitment_key->commit(instance->proving_key->w_l);
witness_commitments.w_r = commitment_key->commit(instance->proving_key->w_r);
witness_commitments.w_o = commitment_key->commit(instance->proving_key->w_o);

auto wire_comms = witness_commitments.get_wires();
auto commitment_labels = instance->commitment_labels;
auto wire_labels = commitment_labels.get_wires();
for (size_t idx = 0; idx < 3; ++idx) {
transcript->send_to_verifier(domain_separator + "_" + wire_labels[idx], wire_comms[idx]);
}

if constexpr (IsGoblinFlavor<Flavor>) {
// Commit to Goblin ECC op wires
witness_commitments.ecc_op_wire_1 = commitment_key->commit(instance->proving_key->ecc_op_wire_1);
witness_commitments.ecc_op_wire_2 = commitment_key->commit(instance->proving_key->ecc_op_wire_2);
witness_commitments.ecc_op_wire_3 = commitment_key->commit(instance->proving_key->ecc_op_wire_3);
witness_commitments.ecc_op_wire_4 = commitment_key->commit(instance->proving_key->ecc_op_wire_4);

auto op_wire_comms = instance->witness_commitments.get_ecc_op_wires();
auto labels = commitment_labels.get_ecc_op_wires();
for (size_t idx = 0; idx < Flavor::NUM_WIRES; ++idx) {
transcript->send_to_verifier(domain_separator + "_" + labels[idx], op_wire_comms[idx]);
}
// Commit to DataBus columns
witness_commitments.calldata = commitment_key->commit(instance->proving_key->calldata);
witness_commitments.calldata_read_counts = commitment_key->commit(instance->proving_key->calldata_read_counts);
transcript->send_to_verifier(domain_separator + "_" + commitment_labels.calldata,
instance->witness_commitments.calldata);
transcript->send_to_verifier(domain_separator + "_" + commitment_labels.calldata_read_counts,
instance->witness_commitments.calldata_read_counts);
}

auto eta = transcript->template get_challenge<FF>(domain_separator + "_eta");
instance->compute_sorted_accumulator_polynomials(eta);

// Commit to the sorted witness-table accumulator and the finalized (i.e. with memory records) fourth wire
// polynomial
witness_commitments.sorted_accum = commitment_key->commit(instance->prover_polynomials.sorted_accum);
witness_commitments.w_4 = commitment_key->commit(instance->prover_polynomials.w_4);
// Compute first three wire commitments
oink_prover.execute_wire_commitments_round();

transcript->send_to_verifier(domain_separator + "_" + commitment_labels.sorted_accum,
witness_commitments.sorted_accum);
transcript->send_to_verifier(domain_separator + "_" + commitment_labels.w_4, witness_commitments.w_4);

auto [beta, gamma] =
transcript->template get_challenges<FF>(domain_separator + "_beta", domain_separator + "_gamma");

if constexpr (IsGoblinFlavor<Flavor>) {
// Compute and commit to the logderivative inverse used in DataBus
instance->compute_logderivative_inverse(beta, gamma);
instance->witness_commitments.lookup_inverses =
commitment_key->commit(instance->prover_polynomials.lookup_inverses);
transcript->send_to_verifier(domain_separator + "_" + commitment_labels.lookup_inverses,
instance->witness_commitments.lookup_inverses);
}
// Compute sorted list accumulator and commitment
oink_prover.execute_sorted_list_accumulator_round();

instance->compute_grand_product_polynomials(beta, gamma);
// Fiat-Shamir: beta & gamma
oink_prover.execute_log_derivative_inverse_round();

witness_commitments.z_perm = commitment_key->commit(instance->prover_polynomials.z_perm);
witness_commitments.z_lookup = commitment_key->commit(instance->prover_polynomials.z_lookup);
// Compute grand product(s) and commitments.
oink_prover.execute_grand_product_computation_round();

transcript->send_to_verifier(domain_separator + "_" + commitment_labels.z_perm,
instance->witness_commitments.z_perm);
transcript->send_to_verifier(domain_separator + "_" + commitment_labels.z_lookup,
instance->witness_commitments.z_lookup);
// Generate relation separators alphas for sumcheck
for (size_t idx = 0; idx < NUM_SUBRELATIONS - 1; idx++) {
instance->alphas[idx] =
transcript->template get_challenge<FF>(domain_separator + "_alpha_" + std::to_string(idx));
Expand Down
7 changes: 3 additions & 4 deletions cpp/src/barretenberg/protogalaxy/protogalaxy_verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,20 @@ void ProtoGalaxyVerifier_<VerifierInstances>::receive_and_finalise_instance(cons
{
// Get circuit parameters and the public inputs
inst->verification_key->circuit_size =
transcript->template receive_from_prover<uint32_t>(domain_separator + "_instance_size");
transcript->template receive_from_prover<uint32_t>(domain_separator + "_circuit_size");
inst->verification_key->log_circuit_size =
static_cast<size_t>(numeric::get_msb(inst->verification_key->circuit_size));
inst->verification_key->num_public_inputs =
transcript->template receive_from_prover<uint32_t>(domain_separator + "_public_input_size");
inst->verification_key->pub_inputs_offset =
transcript->template receive_from_prover<uint32_t>(domain_separator + "_pub_inputs_offset");
inst->verification_key->public_inputs.clear();
for (size_t i = 0; i < inst->verification_key->num_public_inputs; ++i) {
auto public_input_i =
transcript->template receive_from_prover<FF>(domain_separator + "_public_input_" + std::to_string(i));
inst->verification_key->public_inputs.emplace_back(public_input_i);
}

inst->verification_key->pub_inputs_offset =
transcript->template receive_from_prover<uint32_t>(domain_separator + "_pub_inputs_offset");

// Get commitments to first three wire polynomials
auto labels = inst->commitment_labels;
auto& witness_commitments = inst->witness_commitments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,24 @@ void ProtoGalaxyRecursiveVerifier_<VerifierInstances>::receive_and_finalise_inst
const std::shared_ptr<Instance>& inst, const std::string& domain_separator)
{
// Get circuit parameters and the public inputs
const auto instance_size = transcript->template receive_from_prover<FF>(domain_separator + "_instance_size");
const auto instance_size = transcript->template receive_from_prover<FF>(domain_separator + "_circuit_size");
const auto public_input_size =
transcript->template receive_from_prover<FF>(domain_separator + "_public_input_size");
inst->verification_key->circuit_size = uint32_t(instance_size.get_value());
inst->verification_key->log_circuit_size =
static_cast<size_t>(numeric::get_msb(inst->verification_key->circuit_size));
inst->verification_key->num_public_inputs = uint32_t(public_input_size.get_value());
const auto pub_inputs_offset =
transcript->template receive_from_prover<FF>(domain_separator + "_pub_inputs_offset");
inst->verification_key->pub_inputs_offset = uint32_t(pub_inputs_offset.get_value());

inst->verification_key->public_inputs.clear();
for (size_t i = 0; i < inst->verification_key->num_public_inputs; ++i) {
auto public_input_i =
transcript->template receive_from_prover<FF>(domain_separator + "_public_input_" + std::to_string(i));
inst->verification_key->public_inputs.emplace_back(public_input_i);
}

const auto pub_inputs_offset =
transcript->template receive_from_prover<FF>(domain_separator + "_pub_inputs_offset");

inst->verification_key->pub_inputs_offset = uint32_t(pub_inputs_offset.get_value());

// Get commitments to first three wire polynomials
auto labels = inst->commitment_labels;
auto& witness_commitments = inst->witness_commitments;
Expand Down
133 changes: 133 additions & 0 deletions cpp/src/barretenberg/ultra_honk/oink_prover.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
#include "barretenberg/ultra_honk/oink_prover.hpp"

namespace bb {

/**
* @brief Add circuit size, public input size, and public inputs to transcript
*
*/
template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_preamble_round()
{
const auto circuit_size = static_cast<uint32_t>(instance->proving_key->circuit_size);
const auto num_public_inputs = static_cast<uint32_t>(instance->proving_key->num_public_inputs);
transcript->send_to_verifier(domain_separator + "circuit_size", circuit_size);
transcript->send_to_verifier(domain_separator + "public_input_size", num_public_inputs);
transcript->send_to_verifier(domain_separator + "pub_inputs_offset",
static_cast<uint32_t>(instance->proving_key->pub_inputs_offset));

ASSERT(instance->proving_key->num_public_inputs == instance->proving_key->public_inputs.size());

for (size_t i = 0; i < instance->proving_key->num_public_inputs; ++i) {
auto public_input_i = instance->proving_key->public_inputs[i];
transcript->send_to_verifier(domain_separator + "public_input_" + std::to_string(i), public_input_i);
}
}

/**
* @brief Commit to the wire polynomials (part of the witness), with the exception of the fourth wire, which is
* only commited to after adding memory records. In the Goblin Flavor, we also commit to the ECC OP wires and the
* DataBus columns.
*/
template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_wire_commitments_round()
{
auto& witness_commitments = instance->witness_commitments;

// Commit to the first three wire polynomials of the instance
// We only commit to the fourth wire polynomial after adding memory recordss
witness_commitments.w_l = commitment_key->commit(instance->proving_key->w_l);
witness_commitments.w_r = commitment_key->commit(instance->proving_key->w_r);
witness_commitments.w_o = commitment_key->commit(instance->proving_key->w_o);

auto wire_comms = witness_commitments.get_wires();
auto& commitment_labels = instance->commitment_labels;
auto wire_labels = commitment_labels.get_wires();
for (size_t idx = 0; idx < 3; ++idx) {
transcript->send_to_verifier(domain_separator + wire_labels[idx], wire_comms[idx]);
}

if constexpr (IsGoblinFlavor<Flavor>) {
// Commit to Goblin ECC op wires
witness_commitments.ecc_op_wire_1 = commitment_key->commit(instance->proving_key->ecc_op_wire_1);
witness_commitments.ecc_op_wire_2 = commitment_key->commit(instance->proving_key->ecc_op_wire_2);
witness_commitments.ecc_op_wire_3 = commitment_key->commit(instance->proving_key->ecc_op_wire_3);
witness_commitments.ecc_op_wire_4 = commitment_key->commit(instance->proving_key->ecc_op_wire_4);

auto op_wire_comms = witness_commitments.get_ecc_op_wires();
auto labels = commitment_labels.get_ecc_op_wires();
for (size_t idx = 0; idx < Flavor::NUM_WIRES; ++idx) {
transcript->send_to_verifier(domain_separator + labels[idx], op_wire_comms[idx]);
}
// Commit to DataBus columns
witness_commitments.calldata = commitment_key->commit(instance->proving_key->calldata);
witness_commitments.calldata_read_counts = commitment_key->commit(instance->proving_key->calldata_read_counts);
transcript->send_to_verifier(domain_separator + commitment_labels.calldata, witness_commitments.calldata);
transcript->send_to_verifier(domain_separator + commitment_labels.calldata_read_counts,
witness_commitments.calldata_read_counts);
}
}

/**
* @brief Compute sorted witness-table accumulator and commit to the resulting polynomials.
*
*/
template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_sorted_list_accumulator_round()
{
auto& witness_commitments = instance->witness_commitments;
const auto& commitment_labels = instance->commitment_labels;

auto eta = transcript->template get_challenge<FF>(domain_separator + "eta");
instance->compute_sorted_accumulator_polynomials(eta);

// Commit to the sorted witness-table accumulator and the finalized (i.e. with memory records) fourth wire
// polynomial
witness_commitments.sorted_accum = commitment_key->commit(instance->prover_polynomials.sorted_accum);
witness_commitments.w_4 = commitment_key->commit(instance->prover_polynomials.w_4);

transcript->send_to_verifier(domain_separator + commitment_labels.sorted_accum, witness_commitments.sorted_accum);
transcript->send_to_verifier(domain_separator + commitment_labels.w_4, witness_commitments.w_4);
}

/**
* @brief Compute log derivative inverse polynomial and its commitment, if required
*
*/
template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_log_derivative_inverse_round()
{
auto& witness_commitments = instance->witness_commitments;
const auto& commitment_labels = instance->commitment_labels;

auto [beta, gamma] = transcript->template get_challenges<FF>(domain_separator + "beta", domain_separator + "gamma");
instance->relation_parameters.beta = beta;
instance->relation_parameters.gamma = gamma;
if constexpr (IsGoblinFlavor<Flavor>) {
// Compute and commit to the logderivative inverse used in DataBus
instance->compute_logderivative_inverse(beta, gamma);
witness_commitments.lookup_inverses = commitment_key->commit(instance->prover_polynomials.lookup_inverses);
transcript->send_to_verifier(domain_separator + commitment_labels.lookup_inverses,
witness_commitments.lookup_inverses);
}
}

/**
* @brief Compute permutation and lookup grand product polynomials and their commitments
*
*/
template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_grand_product_computation_round()
{
auto& witness_commitments = instance->witness_commitments;
const auto& commitment_labels = instance->commitment_labels;

instance->compute_grand_product_polynomials(instance->relation_parameters.beta,
instance->relation_parameters.gamma);

witness_commitments.z_perm = commitment_key->commit(instance->prover_polynomials.z_perm);
witness_commitments.z_lookup = commitment_key->commit(instance->prover_polynomials.z_lookup);

transcript->send_to_verifier(domain_separator + commitment_labels.z_perm, witness_commitments.z_perm);
transcript->send_to_verifier(domain_separator + commitment_labels.z_lookup, witness_commitments.z_lookup);
}

template class OinkProver<UltraFlavor>;
template class OinkProver<GoblinUltraFlavor>;

} // namespace bb
49 changes: 49 additions & 0 deletions cpp/src/barretenberg/ultra_honk/oink_prover.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once
#include <utility>

#include "barretenberg/flavor/goblin_ultra.hpp"
#include "barretenberg/flavor/ultra.hpp"
#include "barretenberg/sumcheck/instance/prover_instance.hpp"
#include "barretenberg/transcript/transcript.hpp"

namespace bb {

/**
* @brief Class for all the oink rounds, which are shared between the folding prover and ultra prover.
* @details This class contains execute_preamble_round(), execute_wire_commitments_round(),
* execute_sorted_list_accumulator_round(), execute_log_derivative_inverse_round(), and
* execute_grand_product_computation_round().
*
* @tparam Flavor
*/
template <IsUltraFlavor Flavor> class OinkProver {
using CommitmentKey = typename Flavor::CommitmentKey;
using Instance = ProverInstance_<Flavor>;
using Transcript = typename Flavor::Transcript;
using FF = typename Flavor::FF;

public:
std::shared_ptr<Instance> instance;
std::shared_ptr<Transcript> transcript;
std::shared_ptr<CommitmentKey> commitment_key;
std::string domain_separator;

OinkProver(const std::shared_ptr<ProverInstance_<Flavor>>& inst,
const std::shared_ptr<typename Flavor::CommitmentKey>& commitment_key,
const std::shared_ptr<typename Flavor::Transcript>& transcript,
std::string domain_separator = "")
: instance(inst)
, transcript(transcript)
, commitment_key(commitment_key)
, domain_separator(std::move(domain_separator))
{
instance->initialize_prover_polynomials();
}

void execute_preamble_round();
void execute_wire_commitments_round();
void execute_sorted_list_accumulator_round();
void execute_log_derivative_inverse_round();
void execute_grand_product_computation_round();
};
} // namespace bb
Loading

0 comments on commit d2620cf

Please sign in to comment.