Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: share code between provers #4655

Merged
merged 21 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ BB_PROFILE static void test_round_inner(State& state, GoblinUltraProver& prover,
}
};

time_if_index(PREAMBLE, [&] { prover.execute_preamble_round(); });
time_if_index(WIRE_COMMITMENTS, [&] { prover.execute_wire_commitments_round(); });
time_if_index(SORTED_LIST_ACCUMULATOR, [&] { prover.execute_sorted_list_accumulator_round(); });
time_if_index(LOG_DERIVATIVE_INVERSE, [&] { prover.execute_log_derivative_inverse_round(); });
time_if_index(GRAND_PRODUCT_COMPUTATION, [&] { prover.execute_grand_product_computation_round(); });
time_if_index(PREAMBLE, [&] { prover.pre_sumcheck_prover.execute_preamble_round(); });
time_if_index(WIRE_COMMITMENTS, [&] { prover.pre_sumcheck_prover.execute_wire_commitments_round(); });
time_if_index(SORTED_LIST_ACCUMULATOR, [&] { prover.pre_sumcheck_prover.execute_sorted_list_accumulator_round(); });
time_if_index(LOG_DERIVATIVE_INVERSE, [&] { prover.pre_sumcheck_prover.execute_log_derivative_inverse_round(); });
time_if_index(GRAND_PRODUCT_COMPUTATION,
[&] { prover.pre_sumcheck_prover.execute_grand_product_computation_round(); });
time_if_index(RELATION_CHECK, [&] { prover.execute_relation_check_rounds(); });
time_if_index(ZEROMORPH, [&] { prover.execute_zeromorph_rounds(); });
}
Expand All @@ -63,7 +64,10 @@ BB_PROFILE static void test_round(State& state, size_t index) noexcept
GoblinUltraProver prover = bb::mock_proofs::get_prover(
composer, &bb::mock_proofs::generate_basic_arithmetic_circuit<GoblinUltraCircuitBuilder>, log2_num_gates);
for (auto _ : state) {
state.PauseTiming();
test_round_inner(state, prover, index);
state.ResumeTiming();
// NOTE: google bench is very finnicky, must end in ResumeTiming() for correctness
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure why this was removed in the protogalaxy rounds bench. This is required for google bench to function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, this is not strictly related to this PR and we should be a bit more conservative withchanges to benchmarks so maybe address in a follow-up?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I can just make a separate PR for this

}
}
#define ROUND_BENCHMARK(round) \
Expand Down
135 changes: 135 additions & 0 deletions barretenberg/cpp/src/barretenberg/protogalaxy/presumcheck_prover.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#include "barretenberg/protogalaxy/presumcheck_prover.hpp"

namespace bb {

/**
* @brief Add circuit size, public input size, and public inputs to transcript
*
*/
template <IsUltraFlavor Flavor> void PreSumcheckProver<Flavor>::execute_preamble_round()
{
const auto instance_size = static_cast<uint32_t>(instance->instance_size);
const auto num_public_inputs = static_cast<uint32_t>(instance->public_inputs.size());
transcript->send_to_verifier(domain_separator + "instance_size", instance_size);
transcript->send_to_verifier(domain_separator + "public_input_size", num_public_inputs);
transcript->send_to_verifier(domain_separator + "pub_inputs_offset",
static_cast<uint32_t>(instance->pub_inputs_offset));

ASSERT(instance->proving_key->num_public_inputs == instance->public_inputs.size());

for (size_t i = 0; i < instance->public_inputs.size(); ++i) {
auto public_input_i = instance->public_inputs[i];
transcript->send_to_verifier(domain_separator + "public_input_" + std::to_string(i), public_input_i);
}
}

/**
* @brief Commit to the wire polynomials (part of the witness), with the exception of the fourth wire, which is
* only commited to after adding memory records. In the Goblin Flavor, we also commit to the ECC OP wires and the
* DataBus columns.
*/
template <IsUltraFlavor Flavor> void PreSumcheckProver<Flavor>::execute_wire_commitments_round()
{
auto& witness_commitments = instance->witness_commitments;

// Commit to the first three wire polynomials of the instance
// We only commit to the fourth wire polynomial after adding memory recordss
witness_commitments.w_l = commitment_key->commit(instance->proving_key->w_l);
witness_commitments.w_r = commitment_key->commit(instance->proving_key->w_r);
witness_commitments.w_o = commitment_key->commit(instance->proving_key->w_o);

auto wire_comms = witness_commitments.get_wires();
auto commitment_labels = instance->commitment_labels;
auto wire_labels = commitment_labels.get_wires();
for (size_t idx = 0; idx < 3; ++idx) {
transcript->send_to_verifier(domain_separator + wire_labels[idx], wire_comms[idx]);
}

if constexpr (IsGoblinFlavor<Flavor>) {
// Commit to Goblin ECC op wires
witness_commitments.ecc_op_wire_1 = commitment_key->commit(instance->proving_key->ecc_op_wire_1);
witness_commitments.ecc_op_wire_2 = commitment_key->commit(instance->proving_key->ecc_op_wire_2);
witness_commitments.ecc_op_wire_3 = commitment_key->commit(instance->proving_key->ecc_op_wire_3);
witness_commitments.ecc_op_wire_4 = commitment_key->commit(instance->proving_key->ecc_op_wire_4);

auto op_wire_comms = instance->witness_commitments.get_ecc_op_wires();
auto labels = commitment_labels.get_ecc_op_wires();
for (size_t idx = 0; idx < Flavor::NUM_WIRES; ++idx) {
transcript->send_to_verifier(domain_separator + labels[idx], op_wire_comms[idx]);
}
// Commit to DataBus columns
witness_commitments.calldata = commitment_key->commit(instance->proving_key->calldata);
witness_commitments.calldata_read_counts = commitment_key->commit(instance->proving_key->calldata_read_counts);
transcript->send_to_verifier(domain_separator + commitment_labels.calldata,
instance->witness_commitments.calldata);
transcript->send_to_verifier(domain_separator + commitment_labels.calldata_read_counts,
instance->witness_commitments.calldata_read_counts);
}
}

/**
* @brief Compute sorted witness-table accumulator and commit to the resulting polynomials.
*
*/
template <IsUltraFlavor Flavor> void PreSumcheckProver<Flavor>::execute_sorted_list_accumulator_round()
{
auto& witness_commitments = instance->witness_commitments;
const auto& commitment_labels = instance->commitment_labels;

auto eta = transcript->template get_challenge<typename Flavor::FF>(domain_separator + "eta");
instance->compute_sorted_accumulator_polynomials(eta);

// Commit to the sorted witness-table accumulator and the finalized (i.e. with memory records) fourth wire
// polynomial
witness_commitments.sorted_accum = commitment_key->commit(instance->prover_polynomials.sorted_accum);
witness_commitments.w_4 = commitment_key->commit(instance->prover_polynomials.w_4);

transcript->send_to_verifier(domain_separator + commitment_labels.sorted_accum, witness_commitments.sorted_accum);
transcript->send_to_verifier(domain_separator + commitment_labels.w_4, witness_commitments.w_4);
}

/**
* @brief Compute log derivative inverse polynomial and its commitment, if required
*
*/
template <IsUltraFlavor Flavor> void PreSumcheckProver<Flavor>::execute_log_derivative_inverse_round()
{
const auto& commitment_labels = instance->commitment_labels;

auto [beta, gamma] =
transcript->template get_challenges<typename Flavor::FF>(domain_separator + "beta", domain_separator + "gamma");
instance->relation_parameters.beta = beta;
instance->relation_parameters.gamma = gamma;
if constexpr (IsGoblinFlavor<Flavor>) {
// Compute and commit to the logderivative inverse used in DataBus
instance->compute_logderivative_inverse(beta, gamma);
instance->witness_commitments.lookup_inverses =
commitment_key->commit(instance->prover_polynomials.lookup_inverses);
transcript->send_to_verifier(domain_separator + commitment_labels.lookup_inverses,
instance->witness_commitments.lookup_inverses);
}
}

/**
* @brief Compute permutation and lookup grand product polynomials and their commitments
*
*/
template <IsUltraFlavor Flavor> void PreSumcheckProver<Flavor>::execute_grand_product_computation_round()
{
auto& witness_commitments = instance->witness_commitments;
const auto& commitment_labels = instance->commitment_labels;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commitment_labels are static and come from flavor so they could become a field in PreSumcheckProver rather than extracting from instance every time

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just making a const reference of them, so it should be fine?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, next PR will do this and remove it from instance at the same time.


instance->compute_grand_product_polynomials(instance->relation_parameters.beta,
instance->relation_parameters.gamma);

witness_commitments.z_perm = commitment_key->commit(instance->prover_polynomials.z_perm);
witness_commitments.z_lookup = commitment_key->commit(instance->prover_polynomials.z_lookup);

transcript->send_to_verifier(domain_separator + commitment_labels.z_perm, instance->witness_commitments.z_perm);
transcript->send_to_verifier(domain_separator + commitment_labels.z_lookup, instance->witness_commitments.z_lookup);
}

template class PreSumcheckProver<UltraFlavor>;
template class PreSumcheckProver<GoblinUltraFlavor>;

} // namespace bb
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#pragma once
#include <utility>

#include "barretenberg/flavor/goblin_ultra.hpp"
#include "barretenberg/flavor/ultra.hpp"
#include "barretenberg/sumcheck/instance/prover_instance.hpp"
#include "barretenberg/transcript/transcript.hpp"

namespace bb {

template <IsUltraFlavor Flavor> class PreSumcheckProver {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add documentation to this class.. also maybe we should discuss whether a prover inside a prover is the right approach, I was envisioning an architecture where we have some shared rounds that can be in a utility class, similar to the utility class for operating on relations that both sumcheck and protogalaxy uses. It would be nice to have this class be static but might require some more refactoring work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will add more comments. Prover inside prover is already our current standard. We have a SumcheckProver and ZeromorphProver both inside UltraProver.

I agree that the PreSumCheckProver can be split into further classes for each Round, and can eventually be made static or something like that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a good point, I forgot about those. What do you think about creating a PreSumchVerifier as well (since we have a SumcheckVerifier and ZeromorphVerifier)? If you prefer not to do this in this PR, I think an issue should be added

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that can just go into a followup PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, please add an issue

using CommitmentKey = typename Flavor::CommitmentKey;
using Instance = ProverInstance_<Flavor>;
using Transcript = typename Flavor::Transcript;

public:
PreSumcheckProver(const std::shared_ptr<ProverInstance_<Flavor>>& inst,
const std::shared_ptr<typename Flavor::CommitmentKey>& commitment_key,
const std::shared_ptr<typename Flavor::Transcript>& transcript,
std::string domain_separator = "")
: instance(inst)
, transcript(transcript)
, commitment_key(commitment_key)
, domain_separator(std::move(domain_separator))
{
instance->initialize_prover_polynomials();
}

void execute_preamble_round();

void execute_wire_commitments_round();

void execute_sorted_list_accumulator_round();

void execute_log_derivative_inverse_round();

void execute_grand_product_computation_round();

std::shared_ptr<Instance> instance;
std::shared_ptr<Transcript> transcript;
std::shared_ptr<CommitmentKey> commitment_key;
std::string domain_separator;
};

template <IsUltraFlavor Flavor>
void prover_setup_(const std::shared_ptr<ProverInstance_<Flavor>>& instance,
lucasxia01 marked this conversation as resolved.
Show resolved Hide resolved
const std::shared_ptr<typename Flavor::CommitmentKey>& commitment_key,
const std::shared_ptr<typename Flavor::Transcript>& transcript,
const std::string& domain_separator = "");
} // namespace bb
Original file line number Diff line number Diff line change
@@ -1,94 +1,18 @@
#include "protogalaxy_prover.hpp"
#include "barretenberg/flavor/flavor.hpp"
#include "barretenberg/protogalaxy/presumcheck_prover.hpp"
namespace bb {
template <class ProverInstances>
void ProtoGalaxyProver_<ProverInstances>::finalise_and_send_instance(std::shared_ptr<Instance> instance,
const std::string& domain_separator)
{
instance->initialize_prover_polynomials();

const auto instance_size = static_cast<uint32_t>(instance->instance_size);
const auto num_public_inputs = static_cast<uint32_t>(instance->public_inputs.size());
transcript->send_to_verifier(domain_separator + "_instance_size", instance_size);
transcript->send_to_verifier(domain_separator + "_public_input_size", num_public_inputs);

for (size_t i = 0; i < instance->public_inputs.size(); ++i) {
auto public_input_i = instance->public_inputs[i];
transcript->send_to_verifier(domain_separator + "_public_input_" + std::to_string(i), public_input_i);
}
transcript->send_to_verifier(domain_separator + "_pub_inputs_offset",
static_cast<uint32_t>(instance->pub_inputs_offset));

auto& witness_commitments = instance->witness_commitments;

// Commit to the first three wire polynomials of the instance
// We only commit to the fourth wire polynomial after adding memory recordss
witness_commitments.w_l = commitment_key->commit(instance->proving_key->w_l);
witness_commitments.w_r = commitment_key->commit(instance->proving_key->w_r);
witness_commitments.w_o = commitment_key->commit(instance->proving_key->w_o);

auto wire_comms = witness_commitments.get_wires();
auto commitment_labels = instance->commitment_labels;
auto wire_labels = commitment_labels.get_wires();
for (size_t idx = 0; idx < 3; ++idx) {
transcript->send_to_verifier(domain_separator + "_" + wire_labels[idx], wire_comms[idx]);
}

if constexpr (IsGoblinFlavor<Flavor>) {
// Commit to Goblin ECC op wires
witness_commitments.ecc_op_wire_1 = commitment_key->commit(instance->proving_key->ecc_op_wire_1);
witness_commitments.ecc_op_wire_2 = commitment_key->commit(instance->proving_key->ecc_op_wire_2);
witness_commitments.ecc_op_wire_3 = commitment_key->commit(instance->proving_key->ecc_op_wire_3);
witness_commitments.ecc_op_wire_4 = commitment_key->commit(instance->proving_key->ecc_op_wire_4);

auto op_wire_comms = instance->witness_commitments.get_ecc_op_wires();
auto labels = commitment_labels.get_ecc_op_wires();
for (size_t idx = 0; idx < Flavor::NUM_WIRES; ++idx) {
transcript->send_to_verifier(domain_separator + "_" + labels[idx], op_wire_comms[idx]);
}
// Commit to DataBus columns
witness_commitments.calldata = commitment_key->commit(instance->proving_key->calldata);
witness_commitments.calldata_read_counts = commitment_key->commit(instance->proving_key->calldata_read_counts);
transcript->send_to_verifier(domain_separator + "_" + commitment_labels.calldata,
instance->witness_commitments.calldata);
transcript->send_to_verifier(domain_separator + "_" + commitment_labels.calldata_read_counts,
instance->witness_commitments.calldata_read_counts);
}

auto eta = transcript->template get_challenge<FF>(domain_separator + "_eta");
instance->compute_sorted_accumulator_polynomials(eta);

// Commit to the sorted witness-table accumulator and the finalized (i.e. with memory records) fourth wire
// polynomial
witness_commitments.sorted_accum = commitment_key->commit(instance->prover_polynomials.sorted_accum);
witness_commitments.w_4 = commitment_key->commit(instance->prover_polynomials.w_4);

transcript->send_to_verifier(domain_separator + "_" + commitment_labels.sorted_accum,
witness_commitments.sorted_accum);
transcript->send_to_verifier(domain_separator + "_" + commitment_labels.w_4, witness_commitments.w_4);

auto [beta, gamma] =
transcript->template get_challenges<FF>(domain_separator + "_beta", domain_separator + "_gamma");

if constexpr (IsGoblinFlavor<Flavor>) {
// Compute and commit to the logderivative inverse used in DataBus
instance->compute_logderivative_inverse(beta, gamma);
instance->witness_commitments.lookup_inverses =
commitment_key->commit(instance->prover_polynomials.lookup_inverses);
transcript->send_to_verifier(domain_separator + "_" + commitment_labels.lookup_inverses,
instance->witness_commitments.lookup_inverses);
}

instance->compute_grand_product_polynomials(beta, gamma);

witness_commitments.z_perm = commitment_key->commit(instance->prover_polynomials.z_perm);
witness_commitments.z_lookup = commitment_key->commit(instance->prover_polynomials.z_lookup);

transcript->send_to_verifier(domain_separator + "_" + commitment_labels.z_perm,
instance->witness_commitments.z_perm);
transcript->send_to_verifier(domain_separator + "_" + commitment_labels.z_lookup,
instance->witness_commitments.z_lookup);
for (size_t idx = 0; idx < NUM_SUBRELATIONS - 1; idx++) {
PreSumcheckProver<Flavor> pre_sumcheck_prover(instance, commitment_key, transcript, domain_separator + '_');
pre_sumcheck_prover.execute_preamble_round();
Copy link
Contributor

@maramihali maramihali Feb 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would call these just pre_sumcheck to be more compliant with how the other inner provers are used (Sumcheck and Zeromorph)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be clarified if its the prover

pre_sumcheck_prover.execute_wire_commitments_round();
Copy link
Contributor

@maramihali maramihali Feb 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please add comments summarising what each of these calls do for readability in a similar way they are found in the UltraProver

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

pre_sumcheck_prover.execute_sorted_list_accumulator_round();
pre_sumcheck_prover.execute_log_derivative_inverse_round();
pre_sumcheck_prover.execute_grand_product_computation_round();
for (size_t idx = 0; idx < Flavor::NUM_SUBRELATIONS - 1; idx++) {
instance->alphas[idx] =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this alphas generation is part of sumcheck so its not included in the shared prover

transcript->template get_challenge<FF>(domain_separator + "_alpha_" + std::to_string(idx));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,15 @@ void ProtoGalaxyVerifier_<VerifierInstances>::receive_and_finalise_instance(cons
inst->log_instance_size = static_cast<size_t>(numeric::get_msb(inst->instance_size));
inst->public_input_size =
transcript->template receive_from_prover<uint32_t>(domain_separator + "_public_input_size");
inst->pub_inputs_offset =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved above the public inputs to align with ultra_prover's ordering of what it sends

transcript->template receive_from_prover<uint32_t>(domain_separator + "_pub_inputs_offset");

for (size_t i = 0; i < inst->public_input_size; ++i) {
auto public_input_i =
transcript->template receive_from_prover<FF>(domain_separator + "_public_input_" + std::to_string(i));
inst->public_inputs.emplace_back(public_input_i);
}

inst->pub_inputs_offset =
transcript->template receive_from_prover<uint32_t>(domain_separator + "_pub_inputs_offset");

// Get commitments to first three wire polynomials
auto labels = inst->commitment_labels;
auto& witness_commitments = inst->witness_commitments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,17 @@ void ProtoGalaxyRecursiveVerifier_<VerifierInstances>::receive_and_finalise_inst
inst->instance_size = uint32_t(instance_size.get_value());
inst->log_instance_size = static_cast<size_t>(numeric::get_msb(inst->instance_size));
inst->public_input_size = uint32_t(public_input_size.get_value());
const auto pub_inputs_offset =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reordered for consistency with oink prover

transcript->template receive_from_prover<FF>(domain_separator + "_pub_inputs_offset");

inst->pub_inputs_offset = uint32_t(pub_inputs_offset.get_value());

for (size_t i = 0; i < inst->public_input_size; ++i) {
auto public_input_i =
transcript->template receive_from_prover<FF>(domain_separator + "_public_input_" + std::to_string(i));
inst->public_inputs.emplace_back(public_input_i);
}

const auto pub_inputs_offset =
transcript->template receive_from_prover<FF>(domain_separator + "_pub_inputs_offset");

inst->pub_inputs_offset = uint32_t(pub_inputs_offset.get_value());

// Get commitments to first three wire polynomials
auto labels = inst->commitment_labels;
auto& witness_commitments = inst->witness_commitments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ std::array<typename Flavor::GroupElement, 2> UltraRecursiveVerifier_<Flavor>::ve
VerifierCommitments commitments{ key };
CommitmentLabels commitment_labels;

const auto circuit_size = transcript->template receive_from_prover<FF>("circuit_size");
const auto circuit_size = transcript->template receive_from_prover<FF>("instance_size");
const auto public_input_size = transcript->template receive_from_prover<FF>("public_input_size");
const auto pub_inputs_offset = transcript->template receive_from_prover<FF>("pub_inputs_offset");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class GoblinUltraTranscriptTests : public ::testing::Test {
size_t frs_per_uint32 = bb::field_conversion::calc_num_bn254_frs<uint32_t>();

size_t round = 0;
manifest_expected.add_entry(round, "circuit_size", frs_per_uint32);
manifest_expected.add_entry(round, "instance_size", frs_per_uint32);
manifest_expected.add_entry(round, "public_input_size", frs_per_uint32);
manifest_expected.add_entry(round, "pub_inputs_offset", frs_per_uint32);
manifest_expected.add_entry(round, "public_input_0", frs_per_Fr);
Expand Down
Loading
Loading