Skip to content

Commit

Permalink
Making sumcheck work with challenges (#102)
Browse files Browse the repository at this point in the history
Fixes sumcheck relations so the full relation is now correct
---------

Co-authored-by: Rumata888 <[email protected]>
Co-authored-by: ledwards2225 <[email protected]>
  • Loading branch information
3 people authored Jan 30, 2023
1 parent c697813 commit c9554f0
Show file tree
Hide file tree
Showing 12 changed files with 214 additions and 172 deletions.
2 changes: 1 addition & 1 deletion barretenberg/cpp/src/aztec/honk/proof_system/prover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ template <typename settings> void Prover<settings>::execute_relation_check_round
using Sumcheck = sumcheck::Sumcheck<Multivariates,
Transcript,
sumcheck::ArithmeticRelation,
// sumcheck::GrandProductComputationRelation,
sumcheck::GrandProductComputationRelation,
sumcheck::GrandProductInitializationRelation>;

// Compute alpha challenge
Expand Down
2 changes: 1 addition & 1 deletion barretenberg/cpp/src/aztec/honk/proof_system/verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ template <typename program_settings> bool Verifier<program_settings>::verify_pro
auto sumcheck = Sumcheck<Multivariates,
Transcript,
ArithmeticRelation,
// GrandProductComputationRelation,
GrandProductComputationRelation,
GrandProductInitializationRelation>(transcript);
bool sumcheck_result = sumcheck.execute_verifier();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ template <typename FF> class ArithmeticRelation : public Relation<FF> {
*
* @param extended_edges Contain inputs for the relation
* @param evals Contains the resulting univariate polynomial
*
* The final parameter is left to conform to the general argument structure (input,output, challenges) even though
* we don't need challenges in this relation.
*/
void add_edge_contribution(auto& extended_edges, Univariate<FF, RELATION_LENGTH>& evals)
template <typename T> void add_edge_contribution(auto& extended_edges, Univariate<FF, RELATION_LENGTH>& evals, T)
{
add_edge_contribution_internal(extended_edges, evals);
};
Expand Down Expand Up @@ -70,7 +73,8 @@ template <typename FF> class ArithmeticRelation : public Relation<FF> {
evals += q_c;
};

void add_full_relation_value_contribution(auto& purported_evaluations, FF& full_honk_relation_value)
template <typename T>
void add_full_relation_value_contribution(auto& purported_evaluations, FF& full_honk_relation_value, T)
{
auto w_l = purported_evaluations[MULTIVARIATE::W_L];
auto w_r = purported_evaluations[MULTIVARIATE::W_R];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@ template <typename FF> class GrandProductComputationRelation : public Relation<F
using MULTIVARIATE = StandardHonk::MULTIVARIATE;

public:
// TODO(luke): make these real challenges once manifest is done
const FF beta_default = FF::one();
const FF gamma_default = FF::one();
const FF public_input_delta_default = FF::one();

GrandProductComputationRelation() = default;
explicit GrandProductComputationRelation(auto){}; // TODO(luke): should just be default?
/**
* @brief Add contribution of the permutation relation for a given edge (used by sumcheck round)
*/
void add_edge_contribution(auto& extended_edges, Univariate<FF, RELATION_LENGTH>& evals)
void add_edge_contribution(auto& extended_edges,
Univariate<FF, RELATION_LENGTH>& evals,
const RelationParameters<FF>& relation_parameters)
{
add_edge_contribution_internal(extended_edges, evals, beta_default, gamma_default, public_input_delta_default);
add_edge_contribution_internal(extended_edges,
evals,
relation_parameters.beta,
relation_parameters.gamma,
relation_parameters.public_input_delta);
};

/**
Expand Down Expand Up @@ -53,8 +54,11 @@ template <typename FF> class GrandProductComputationRelation : public Relation<F
* delta is the public input correction term
*
*/
inline void add_edge_contribution_internal(
auto& extended_edges, Univariate<FF, RELATION_LENGTH>& evals, FF beta, FF gamma, FF public_input_delta)
inline void add_edge_contribution_internal(auto& extended_edges,
Univariate<FF, RELATION_LENGTH>& evals,
const FF& beta,
const FF& gamma,
const FF& public_input_delta)
{
auto w_1 = UnivariateView<FF, RELATION_LENGTH>(extended_edges[MULTIVARIATE::W_L]);
auto w_2 = UnivariateView<FF, RELATION_LENGTH>(extended_edges[MULTIVARIATE::W_R]);
Expand All @@ -77,7 +81,9 @@ template <typename FF> class GrandProductComputationRelation : public Relation<F
(w_2 + sigma_2 * beta + gamma) * (w_3 + sigma_3 * beta + gamma));
};

void add_full_relation_value_contribution(auto& purported_evaluations, FF& full_honk_relation_value)
void add_full_relation_value_contribution(auto& purported_evaluations,
FF& full_honk_relation_value,
const RelationParameters<FF>& relation_parameters)
{
auto w_1 = purported_evaluations[MULTIVARIATE::W_L];
auto w_2 = purported_evaluations[MULTIVARIATE::W_R];
Expand All @@ -86,20 +92,22 @@ template <typename FF> class GrandProductComputationRelation : public Relation<F
auto sigma_2 = purported_evaluations[MULTIVARIATE::SIGMA_2];
auto sigma_3 = purported_evaluations[MULTIVARIATE::SIGMA_3];
auto id_1 = purported_evaluations[MULTIVARIATE::ID_1];
auto id_2 = purported_evaluations[MULTIVARIATE::ID_1];
auto id_3 = purported_evaluations[MULTIVARIATE::ID_1];
auto id_2 = purported_evaluations[MULTIVARIATE::ID_2];
auto id_3 = purported_evaluations[MULTIVARIATE::ID_3];
auto z_perm = purported_evaluations[MULTIVARIATE::Z_PERM];
auto z_perm_shift = purported_evaluations[MULTIVARIATE::Z_PERM_SHIFT];
auto lagrange_first = purported_evaluations[MULTIVARIATE::LAGRANGE_FIRST];
auto lagrange_last = purported_evaluations[MULTIVARIATE::LAGRANGE_LAST];

// Contribution (1)
full_honk_relation_value +=
(z_perm + lagrange_first) * (w_1 + beta_default * id_1 + gamma_default) *
(w_2 + beta_default * id_2 + gamma_default) * (w_3 + beta_default * id_3 + gamma_default) -
(z_perm_shift + lagrange_last * public_input_delta_default) *
(w_1 + beta_default * sigma_1 + gamma_default) * (w_2 + beta_default * sigma_2 + gamma_default) *
(w_3 + beta_default * sigma_3 + gamma_default);
full_honk_relation_value += (z_perm + lagrange_first) *
(w_1 + relation_parameters.beta * id_1 + relation_parameters.gamma) *
(w_2 + relation_parameters.beta * id_2 + relation_parameters.gamma) *
(w_3 + relation_parameters.beta * id_3 + relation_parameters.gamma) -
(z_perm_shift + lagrange_last * relation_parameters.public_input_delta) *
(w_1 + relation_parameters.beta * sigma_1 + relation_parameters.gamma) *
(w_2 + relation_parameters.beta * sigma_2 + relation_parameters.gamma) *
(w_3 + relation_parameters.beta * sigma_3 + relation_parameters.gamma);
};
};
} // namespace honk::sumcheck
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ template <typename FF> class GrandProductInitializationRelation : public Relatio
* This file handles the relation Z_perm_shift(n_last) = 0 via the relation:
*
* C(X) = L_LAST(X) * Z_perm_shift(X)
*
*
* The final parameter is left to conform to the general argument structure (input,output, challenges) even though
* we don't need challenges in this relation.
*
*/
void add_edge_contribution(auto& extended_edges, Univariate<FF, RELATION_LENGTH>& evals)
template <typename T> void add_edge_contribution(auto& extended_edges, Univariate<FF, RELATION_LENGTH>& evals, T)
{
add_edge_contribution_internal(extended_edges, evals);
};
Expand Down Expand Up @@ -55,7 +60,8 @@ template <typename FF> class GrandProductInitializationRelation : public Relatio
add_edge_contribution_internal(extended_edges, evals);
}

void add_full_relation_value_contribution(auto& purported_evaluations, FF& full_honk_relation_value)
template <typename T>
void add_full_relation_value_contribution(auto& purported_evaluations, FF& full_honk_relation_value, T)
{
auto z_perm_shift = purported_evaluations[MULTIVARIATE::Z_PERM_SHIFT];
auto lagrange_last = purported_evaluations[MULTIVARIATE::LAGRANGE_LAST];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,10 @@ namespace honk::sumcheck {

template <typename Fr> class Relation {}; // TODO(Cody): Use or eventually remove.

template <typename FF> struct RelationParameters {
FF alpha;
FF beta;
FF gamma;
FF public_input_delta;
};
} // namespace honk::sumcheck
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TYPED_TEST(SumcheckRelation, ArithmeticRelation)
Univariate expected_evals = (q_m * w_r * w_l) + (q_r * w_r) + (q_l * w_l) + (q_o * w_o) + (q_c);

auto evals = Univariate<FF, relation.RELATION_LENGTH>();
relation.add_edge_contribution(extended_edges, evals);
relation.add_edge_contribution(extended_edges, evals, 0);

EXPECT_EQ(evals, expected_evals);
};
Expand Down Expand Up @@ -120,9 +120,12 @@ TYPED_TEST(SumcheckRelation, GrandProductComputationRelation)
auto lagrange_first = UnivariateView(extended_edges[MULTIVARIATE::LAGRANGE_FIRST]);
auto lagrange_last = UnivariateView(extended_edges[MULTIVARIATE::LAGRANGE_LAST]);
// TODO(luke): use real transcript/challenges once manifest is done
FF beta = FF::one();
FF gamma = FF::one();
FF public_input_delta = FF::one();
FF beta = FF::random_element();
FF gamma = FF::random_element();
FF public_input_delta = FF::random_element();
const RelationParameters<FF> relation_parameters = RelationParameters<FF>{
.alpha = FF ::zero(), .beta = beta, .gamma = gamma, .public_input_delta = public_input_delta
};

auto expected_evals = Univariate();
// expected_evals in the below step { { 27, 250, 1029, 2916, 6655 } }
Expand All @@ -133,7 +136,7 @@ TYPED_TEST(SumcheckRelation, GrandProductComputationRelation)
(w_2 + sigma_2 * beta + gamma) * (w_3 + sigma_3 * beta + gamma);

auto evals = Univariate();
relation.add_edge_contribution(extended_edges, evals);
relation.add_edge_contribution(extended_edges, evals, relation_parameters);

EXPECT_EQ(evals, expected_evals);
};
Expand All @@ -157,7 +160,7 @@ TYPED_TEST(SumcheckRelation, GrandProductInitializationRelation)

// Compute the edge contribution using add_edge_contribution
auto evals = Univariate();
relation.add_edge_contribution(extended_edges, evals);
relation.add_edge_contribution(extended_edges, evals, 0);

EXPECT_EQ(evals, expected_evals);
};
Expand Down
40 changes: 35 additions & 5 deletions barretenberg/cpp/src/aztec/honk/sumcheck/sumcheck.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
#include "common/serialize.hpp"
#include "proof_system/types/polynomial_manifest.hpp"
#include <honk/utils/public_inputs.hpp>
#include "common/throw_or_abort.hpp"
#include "ecc/curves/bn254/fr.hpp"
#include "sumcheck_round.hpp"
Expand Down Expand Up @@ -35,6 +36,34 @@ template <class Multivariates, class Transcript, template <class> class... Relat
, transcript(transcript)
, round(std::tuple(Relations<FF>()...)){};

/**
* @brief Get all the challenges and computed parameters used in sumcheck in a convenient format
*
* @return RelationParameters<FF>
*/
RelationParameters<FF> retrieve_proof_parameters()
{
const FF alpha = FF::serialize_from_buffer(transcript.get_challenge("alpha").begin());
const FF beta = FF::serialize_from_buffer(transcript.get_challenge("beta").begin());
const FF gamma = FF::serialize_from_buffer(transcript.get_challenge("beta", 1).begin());
const auto public_input_size_vector = transcript.get_element("public_input_size");
const size_t public_input_size = (static_cast<size_t>(public_input_size_vector[0]) << 24) |
(static_cast<size_t>(public_input_size_vector[1]) << 16) |
(static_cast<size_t>(public_input_size_vector[2]) << 8) |

static_cast<size_t>(public_input_size_vector[3]);
const auto circut_size_vector = transcript.get_element("circuit_size");
const size_t n = (static_cast<size_t>(circut_size_vector[0]) << 24) |
(static_cast<size_t>(circut_size_vector[1]) << 16) |
(static_cast<size_t>(circut_size_vector[2]) << 8) | static_cast<size_t>(circut_size_vector[3]);
std::vector<FF> public_inputs = many_from_buffer<FF>(transcript.get_element("public_inputs"));
ASSERT(public_inputs.size() == public_input_size);
FF public_input_delta = honk::compute_public_input_delta<FF>(public_inputs, beta, gamma, n);
const RelationParameters<FF> relation_parameters = RelationParameters<FF>{
.alpha = alpha, .beta = beta, .gamma = gamma, .public_input_delta = public_input_delta
};
return relation_parameters;
}
/**
* @brief Compute univariate restriction place in transcript, generate challenge, fold,... repeat until final round,
* then compute multivariate evaluations and place in transcript.
Expand All @@ -45,8 +74,9 @@ template <class Multivariates, class Transcript, template <class> class... Relat
{
// First round
// This populates multivariates.folded_polynomials.
FF alpha = FF::serialize_from_buffer(transcript.get_challenge("alpha").begin());
auto round_univariate = round.compute_univariate(multivariates.full_polynomials, alpha);

const auto relation_parameters = retrieve_proof_parameters();
auto round_univariate = round.compute_univariate(multivariates.full_polynomials, relation_parameters);
transcript.add_element("univariate_" + std::to_string(multivariates.multivariate_d),
round_univariate.to_buffer());
std::string challenge_label = "u_" + std::to_string(multivariates.multivariate_d);
Expand All @@ -59,7 +89,7 @@ template <class Multivariates, class Transcript, template <class> class... Relat
// We operate on multivariates.folded_polynomials in place.
for (size_t round_idx = 1; round_idx < multivariates.multivariate_d; round_idx++) {
// Write the round univariate to the transcript
round_univariate = round.compute_univariate(multivariates.folded_polynomials, alpha);
round_univariate = round.compute_univariate(multivariates.folded_polynomials, relation_parameters);
transcript.add_element("univariate_" + std::to_string(multivariates.multivariate_d - round_idx),
round_univariate.to_buffer());
challenge_label = "u_" + std::to_string(multivariates.multivariate_d - round_idx);
Expand Down Expand Up @@ -102,6 +132,7 @@ template <class Multivariates, class Transcript, template <class> class... Relat
{
bool verified(true);

const auto relation_parameters = retrieve_proof_parameters();
// All but final round.
// target_total_sum is initialized to zero then mutated in place.

Expand All @@ -127,9 +158,8 @@ template <class Multivariates, class Transcript, template <class> class... Relat

// Final round
auto purported_evaluations = transcript.get_field_element_vector("multivariate_evaluations");
FF alpha = FF::serialize_from_buffer(transcript.get_challenge("alpha").begin());
FF full_honk_relation_purported_value =
round.compute_full_honk_relation_purported_value(purported_evaluations, alpha);
round.compute_full_honk_relation_purported_value(purported_evaluations, relation_parameters);
verified = verified && (full_honk_relation_purported_value == round.target_total_sum);
return verified;
};
Expand Down
Loading

0 comments on commit c9554f0

Please sign in to comment.