Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split Sumcheck #245

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions cpp/src/barretenberg/honk/composer/standard_honk_composer.test.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#include "standard_honk_composer.hpp"
#include "barretenberg/honk/sumcheck/relations/relation.hpp"
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/proof_system/flavor/flavor.hpp"
#include <cstdint>
#include "barretenberg/honk/proof_system/prover.hpp"
#include "barretenberg/honk/sumcheck/sumcheck_round.hpp"
#include "barretenberg/honk/sumcheck/relations/grand_product_computation_relation.hpp"
#include "barretenberg/honk/sumcheck/relations/grand_product_initialization_relation.hpp"
#include "barretenberg/honk/sumcheck/relations/arithmetic_relation.hpp"
#include "barretenberg/honk/sumcheck/relations/relation.hpp"
#include "barretenberg/honk/utils/public_inputs.hpp"

#include <gtest/gtest.h>
#include <cstdint>

using namespace honk;

Expand Down Expand Up @@ -335,16 +335,13 @@ TEST(StandardHonkComposer, SumcheckRelationCorrectness)
// Generate beta and gamma
fr beta = fr::random_element();
fr gamma = fr::random_element();
fr zeta = fr::random_element();

// Compute public input delta
const auto public_inputs = composer.circuit_constructor.get_public_inputs();
auto public_input_delta =
honk::compute_public_input_delta<fr>(public_inputs, beta, gamma, prover.key->circuit_size);

sumcheck::RelationParameters<fr> params{
.zeta = zeta,
.alpha = fr::one(),
.beta = beta,
.gamma = gamma,
.public_input_delta = public_input_delta,
Expand Down Expand Up @@ -380,11 +377,6 @@ TEST(StandardHonkComposer, SumcheckRelationCorrectness)
evaluations_array[POLYNOMIAL::LAGRANGE_FIRST] = prover.key->polynomial_store.get("L_first_lagrange");
evaluations_array[POLYNOMIAL::LAGRANGE_LAST] = prover.key->polynomial_store.get("L_last_lagrange");

// Construct the round for applying sumcheck relations and results for storing computed results
auto relations = std::tuple(honk::sumcheck::ArithmeticRelation<fr>(),
honk::sumcheck::GrandProductComputationRelation<fr>(),
honk::sumcheck::GrandProductInitializationRelation<fr>());

fr result = 0;
for (size_t i = 0; i < prover.key->circuit_size; i++) {
// Compute an array containing all the evaluations at a given row i
Expand All @@ -397,13 +389,16 @@ TEST(StandardHonkComposer, SumcheckRelationCorrectness)
// i-th row/vertex of the hypercube.
// We use ASSERT_EQ instead of EXPECT_EQ so that the tests stops at the first index at which the result is not
// 0, since result = 0 + C(transposed), which we expect will equal 0.
std::get<0>(relations).add_full_relation_value_contribution(result, evaluations_at_index_i, params);
result = honk::sumcheck::ArithmeticRelation<fr>::evaluate_full_relation_value_contribution(
evaluations_at_index_i, params);
ASSERT_EQ(result, 0);

std::get<1>(relations).add_full_relation_value_contribution(result, evaluations_at_index_i, params);
result = honk::sumcheck::GrandProductComputationRelation<fr>::evaluate_full_relation_value_contribution(
evaluations_at_index_i, params);
ASSERT_EQ(result, 0);

std::get<2>(relations).add_full_relation_value_contribution(result, evaluations_at_index_i, params);
result = honk::sumcheck::GrandProductInitializationRelation<fr>::evaluate_full_relation_value_contribution(
evaluations_at_index_i, params);
ASSERT_EQ(result, 0);
}
}
Expand Down
75 changes: 35 additions & 40 deletions cpp/src/barretenberg/honk/proof_system/prover.cpp
Original file line number Diff line number Diff line change
@@ -1,33 +1,26 @@
#include "prover.hpp"
#include <algorithm>
#include <cstddef>
#include "barretenberg/honk/sumcheck/sumcheck.hpp" // will need
#include <array>
#include "barretenberg/honk/sumcheck/polynomials/univariate.hpp" // will go away
#include "barretenberg/honk/utils/power_polynomial.hpp"
#include "barretenberg/honk/sumcheck/sumcheck.hpp"
#include "barretenberg/honk/pcs/commitment_key.hpp"
#include <memory>
#include <span>
#include <utility>
#include <vector>
#include "barretenberg/ecc/curves/bn254/fr.hpp"
#include "barretenberg/ecc/curves/bn254/g1.hpp"
#include "barretenberg/honk/sumcheck/relations/arithmetic_relation.hpp"
#include "barretenberg/honk/sumcheck/relations/grand_product_computation_relation.hpp"
#include "barretenberg/honk/sumcheck/relations/grand_product_initialization_relation.hpp"
#include "barretenberg/honk/utils/public_inputs.hpp"
#include "barretenberg/polynomials/polynomial.hpp"
#include "barretenberg/proof_system/flavor/flavor.hpp"
#include "barretenberg/transcript/transcript_wrappers.hpp"

#include <string>
#include "barretenberg/honk/pcs/claim.hpp"
#include <array>
#include <algorithm>
#include <cstddef>
#include <memory>
#include <span>
#include <utility>
#include <vector>

namespace honk {

using Fr = barretenberg::fr;
using Commitment = barretenberg::g1::affine_element;
using Polynomial = barretenberg::Polynomial<Fr>;
using POLYNOMIAL = bonk::StandardArithmetization::POLYNOMIAL;

/**
* Create Prover from proving key, witness and manifest.
*
Expand All @@ -37,7 +30,7 @@ using POLYNOMIAL = bonk::StandardArithmetization::POLYNOMIAL;
* @tparam settings Settings class.
* */
template <typename settings>
Prover<settings>::Prover(std::vector<barretenberg::polynomial>&& wire_polys,
Prover<settings>::Prover(std::vector<Polynomial>&& wire_polys,
std::shared_ptr<bonk::proving_key> input_key,
const transcript::Manifest& input_manifest)
: transcript(input_manifest, settings::hash_type, settings::num_challenge_bytes)
Expand Down Expand Up @@ -66,6 +59,13 @@ Prover<settings>::Prover(std::vector<barretenberg::polynomial>&& wire_polys,
prover_polynomials[POLYNOMIAL::W_L] = wire_polynomials[0];
prover_polynomials[POLYNOMIAL::W_R] = wire_polynomials[1];
prover_polynomials[POLYNOMIAL::W_O] = wire_polynomials[2];

// Add public inputs to transcript from the second wire polynomial
std::span<Fr> public_wires_source = prover_polynomials[POLYNOMIAL::W_R];

for (size_t i = 0; i < key->num_public_inputs; ++i) {
public_inputs.emplace_back(public_wires_source[i]);
}
}

/**
Expand Down Expand Up @@ -109,10 +109,11 @@ template <typename settings> void Prover<settings>::compute_wire_commitments()
* one batch inversion (at the expense of more multiplications)
*/
// TODO(#222)(luke): Parallelize
template <typename settings> Polynomial Prover<settings>::compute_grand_product_polynomial(Fr beta, Fr gamma)
template <typename settings>
barretenberg::Polynomial<barretenberg::fr> Prover<settings>::compute_grand_product_polynomial(Fr beta, Fr gamma)
{
using barretenberg::polynomial_arithmetic::copy_polynomial;
static const size_t program_width = settings::program_width;
constexpr size_t program_width = settings::program_width;

// Allocate scratch space for accumulators
std::array<Fr*, program_width> numerator_accumulator;
Expand Down Expand Up @@ -224,14 +225,8 @@ template <typename settings> void Prover<settings>::execute_wire_commitments_rou
// queue.flush_queue(); // NOTE: Don't remove; we may reinstate the queue
compute_wire_commitments();

// Add public inputs to transcript from the second wire polynomial
const Polynomial& public_wires_source = wire_polynomials[1];

std::vector<Fr> public_wires;
for (size_t i = 0; i < key->num_public_inputs; ++i) {
public_wires.push_back(public_wires_source[i]);
}
transcript.add_element("public_inputs", ::to_buffer(public_wires));
// Add public inputs to transcript
transcript.add_element("public_inputs", ::to_buffer(public_inputs));
}

/**
Expand All @@ -257,6 +252,15 @@ template <typename settings> void Prover<settings>::execute_grand_product_comput

auto beta = transcript.get_challenge_field_element("beta", 0);
auto gamma = transcript.get_challenge_field_element("beta", 1);

auto public_input_delta = compute_public_input_delta<Fr>(public_inputs, beta, gamma, key->circuit_size);

relation_parameters = sumcheck::RelationParameters<Fr>{
.beta = beta,
.gamma = gamma,
.public_input_delta = public_input_delta,
};

z_permutation = compute_grand_product_polynomial(beta, gamma);
// The actual polynomial is of length n+1, but commitment key is just n, so we need to limit it
auto commitment = commitment_key->commit(z_permutation);
Expand All @@ -276,16 +280,14 @@ template <typename settings> void Prover<settings>::execute_relation_check_round
// queue.flush_queue(); // NOTE: Don't remove; we may reinstate the queue

using Sumcheck = sumcheck::Sumcheck<Fr,
Transcript,
sumcheck::ArithmeticRelation,
sumcheck::GrandProductComputationRelation,
sumcheck::GrandProductInitializationRelation>;

transcript.apply_fiat_shamir("alpha");

auto sumcheck = Sumcheck(key->circuit_size, transcript);

sumcheck.execute_prover(prover_polynomials);
sumcheck_output =
Sumcheck::execute_prover(key->log_circuit_size, relation_parameters, prover_polynomials, transcript);
}

/**
Expand All @@ -298,20 +300,13 @@ template <typename settings> void Prover<settings>::execute_univariatization_rou
const size_t NUM_POLYNOMIALS = bonk::StandardArithmetization::NUM_POLYNOMIALS;
const size_t NUM_UNSHIFTED_POLYS = bonk::StandardArithmetization::NUM_UNSHIFTED_POLYNOMIALS;

// Construct MLE opening point u = (u_0, ..., u_{d-1})
std::vector<Fr> opening_point; // u
for (size_t round_idx = 0; round_idx < key->log_circuit_size; round_idx++) {
std::string label = "u_" + std::to_string(round_idx);
opening_point.emplace_back(transcript.get_challenge_field_element(label));
}

// Generate batching challenge ρ and powers 1,ρ,…,ρᵐ⁻¹
transcript.apply_fiat_shamir("rho");
Fr rho = Fr::serialize_from_buffer(transcript.get_challenge("rho").begin());
std::vector<Fr> rhos = Gemini::powers_of_rho(rho, NUM_POLYNOMIALS);

// Get vector of multivariate evaluations produced by Sumcheck
auto multivariate_evaluations = transcript.get_field_element_vector("multivariate_evaluations");
auto [opening_point, multivariate_evaluations] = sumcheck_output;

// Batch the unshifted polynomials and the to-be-shifted polynomials using ρ
Polynomial batched_poly_unshifted(key->circuit_size); // batched unshifted polynomials
Expand Down
42 changes: 23 additions & 19 deletions cpp/src/barretenberg/honk/proof_system/prover.hpp
Original file line number Diff line number Diff line change
@@ -1,34 +1,33 @@
#pragma once
#include "barretenberg/ecc/curves/bn254/fr.hpp"
#include "barretenberg/honk/pcs/shplonk/shplonk.hpp"
#include "barretenberg/polynomials/polynomial.hpp"
#include "barretenberg/proof_system/flavor/flavor.hpp"
#include <array>
#include "barretenberg/proof_system/proving_key/proving_key.hpp"
#include "barretenberg/honk/sumcheck/sumcheck.hpp"
#include "barretenberg/honk/sumcheck/relations/relation.hpp"
#include "barretenberg/honk/pcs/commitment_key.hpp"
#include "barretenberg/plonk/proof_system/types/proof.hpp"
#include "barretenberg/plonk/proof_system/types/program_settings.hpp"
#include "barretenberg/honk/pcs/gemini/gemini.hpp"
#include "barretenberg/honk/pcs/shplonk/shplonk_single.hpp"
#include "barretenberg/honk/pcs/shplonk/shplonk.hpp"
#include "barretenberg/honk/pcs/kzg/kzg.hpp"
#include "barretenberg/transcript/transcript_wrappers.hpp"
#include "barretenberg/plonk/proof_system/types/proof.hpp"
#include "barretenberg/proof_system/proving_key/proving_key.hpp"
#include "barretenberg/proof_system/flavor/flavor.hpp"
#include "barretenberg/plonk/proof_system/types/prover_settings.hpp"

#include <array>
#include <span>
#include <unordered_map>
#include <vector>
#include <algorithm>
#include <cstddef>
#include <memory>
#include <utility>
#include <string>
#include "barretenberg/honk/pcs/claim.hpp"

namespace honk {

using Fr = barretenberg::fr;

template <typename settings> class Prover {

using Fr = barretenberg::fr;
using Polynomial = barretenberg::Polynomial<Fr>;
using POLYNOMIAL = bonk::StandardArithmetization::POLYNOMIAL;

public:
Prover(std::vector<barretenberg::polynomial>&& wire_polys,
Prover(std::vector<Polynomial>&& wire_polys,
std::shared_ptr<bonk::proving_key> input_key = nullptr,
const transcript::Manifest& manifest = transcript::Manifest());

Expand All @@ -44,7 +43,7 @@ template <typename settings> class Prover {

void compute_wire_commitments();

barretenberg::polynomial compute_grand_product_polynomial(Fr beta, Fr gamma);
Polynomial compute_grand_product_polynomial(Fr beta, Fr gamma);

void construct_prover_polynomials();

Expand All @@ -53,8 +52,12 @@ template <typename settings> class Prover {

transcript::StandardTranscript transcript;

std::vector<barretenberg::polynomial> wire_polynomials;
barretenberg::polynomial z_permutation;
std::vector<Fr> public_inputs;

std::vector<Polynomial> wire_polynomials;
Polynomial z_permutation;

sumcheck::RelationParameters<Fr> relation_parameters;

std::shared_ptr<bonk::proving_key> key;

Expand All @@ -79,6 +82,7 @@ template <typename settings> class Prover {
// This makes 'settings' accesible from Prover
using settings_ = settings;

sumcheck::SumcheckOutput<Fr> sumcheck_output;
pcs::gemini::ProverOutput<pcs::kzg::Params> gemini_output;
pcs::shplonk::ProverOutput<pcs::kzg::Params> shplonk_output;

Expand Down
Loading