Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 20-30% cost reduction in recursive ipa algorithm #9420

Merged
merged 18 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 71 additions & 60 deletions barretenberg/cpp/src/barretenberg/commitment_schemes/ipa/ipa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,20 +369,29 @@ template <typename Curve_> class IPA {
// Construct vector s
std::vector<Fr> s_vec(poly_length, Fr::one());

// TODO(https://github.com/AztecProtocol/barretenberg/issues/857): This code is not efficient as its
// O(nlogn). This can be optimized to be linear by computing a tree of products. Its very readable, so we're
// leaving it unoptimized for now.
parallel_for_heuristic(
poly_length,
[&](size_t i) {
for (size_t j = (log_poly_degree - 1); j != static_cast<size_t>(-1); j--) {
auto bit = (i >> j) & 1;
bool b = static_cast<bool>(bit);
if (b) {
s_vec[i] *= round_challenges_inv[log_poly_degree - 1 - j];
}
}
}, thread_heuristics::FF_MULTIPLICATION_COST * log_poly_degree);
std::vector<Fr> s_vec_temporaries(poly_length / 2);

Fr* previous_round_s = &s_vec_temporaries[0];
Fr* current_round_s = &s_vec[0];
// if number of rounds is even we need to swap these so that s_vec always contains the result
if ((log_poly_degree & 1) == 0)
{
std::swap(previous_round_s, current_round_s);
}
previous_round_s[0] = Fr(1);
for (size_t i = 0; i < log_poly_degree; ++i)
{
const size_t round_size = 1 << (i + 1);
const Fr round_challenge = round_challenges_inv[i];
parallel_for_heuristic(
round_size / 2,
[&](size_t j) {
current_round_s[j * 2] = previous_round_s[j];
current_round_s[j * 2 + 1] = previous_round_s[j] * round_challenge;
}, thread_heuristics::FF_MULTIPLICATION_COST * 2);
std::swap(current_round_s, previous_round_s);
}


std::span<const Commitment> srs_elements = vk->get_monomial_points();
if (poly_length * 2 > srs_elements.size()) {
Expand Down Expand Up @@ -454,28 +463,20 @@ template <typename Curve_> class IPA {
const Fr generator_challenge = transcript->template get_challenge<Fr>("IPA:generator_challenge");
auto builder = generator_challenge.get_context();

Commitment aux_generator = Commitment::one(builder) * generator_challenge;

const auto log_poly_degree = numeric::get_msb(static_cast<uint32_t>(poly_length));

// Step 3.
// Compute C' = C + f(\beta) ⋅ U
GroupElement C_prime = opening_claim.commitment + aux_generator * opening_claim.opening_pair.evaluation;

auto pippenger_size = 2 * log_poly_degree;
std::vector<Fr> round_challenges(log_poly_degree);
std::vector<Fr> round_challenges_inv(log_poly_degree);
std::vector<Commitment> msm_elements(pippenger_size);
std::vector<Fr> msm_scalars(pippenger_size);

// Step 4.
// Step 3.
// Receive all L_i and R_i and prepare for MSM
for (size_t i = 0; i < log_poly_degree; i++) {
std::string index = std::to_string(log_poly_degree - i - 1);
auto element_L = transcript->template receive_from_prover<Commitment>("IPA:L_" + index);
auto element_R = transcript->template receive_from_prover<Commitment>("IPA:R_" + index);
round_challenges[i] = transcript->template get_challenge<Fr>("IPA:round_challenge_" + index);

round_challenges_inv[i] = round_challenges[i].invert();

msm_elements[2 * i] = element_L;
Expand All @@ -484,63 +485,73 @@ template <typename Curve_> class IPA {
msm_scalars[2 * i + 1] = round_challenges[i];
}

// Step 5.
// Compute C₀ = C' + ∑_{j ∈ [k]} u_j^{-1}L_j + ∑_{j ∈ [k]} u_jR_j
GroupElement LR_sums = GroupElement::batch_mul(msm_elements, msm_scalars);

GroupElement C_zero = C_prime + LR_sums;

// Step 6.
// Step 4.
// Compute b_zero where b_zero can be computed using the polynomial:
// g(X) = ∏_{i ∈ [k]} (1 + u_{i-1}^{-1}.X^{2^{i-1}}).
// b_zero = g(evaluation) = ∏_{i ∈ [k]} (1 + u_{i-1}^{-1}. (evaluation)^{2^{i-1}})

Fr b_zero = Fr(1);
Fr challenge = opening_claim.opening_pair.challenge;
for (size_t i = 0; i < log_poly_degree; i++) {
b_zero *= Fr(1) + (round_challenges_inv[log_poly_degree - 1 - i] *
opening_claim.opening_pair.challenge.pow(1 << i));
b_zero *= Fr(1) + (round_challenges_inv[log_poly_degree - 1 - i] * challenge);
if (i != log_poly_degree - 1)
{
challenge = challenge * challenge;
}
}

// Step 7.

// Step 5.
// Construct vector s
// We implement a linear-time algorithm to optimally compute this vector
// Note: currently requires an extra vector of size `poly_length / 2` to cache temporaries
// this might able to be optimized if we care enough, but the size of this poly shouldn't be large relative to the builder polynomial sizes
std::vector<Fr> s_vec_temporaries(poly_length / 2);
std::vector<Fr> s_vec(poly_length);

// TODO(https://github.com/AztecProtocol/barretenberg/issues/857): This code is not efficient as its
// O(nlogn). This can be optimized to be linear by computing a tree of products.
for (size_t i = 0; i < poly_length; i++) {
Fr s_vec_scalar = Fr(1);
for (size_t j = (log_poly_degree - 1); j != static_cast<size_t>(-1); j--) {
auto bit = (i >> j) & 1;
bool b = static_cast<bool>(bit);
if (b) {
s_vec_scalar *= round_challenges_inv[log_poly_degree - 1 - j];
}
Fr* previous_round_s = &s_vec_temporaries[0];
Fr* current_round_s = &s_vec[0];
// if number of rounds is even we need to swap these so that s_vec always contains the result
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it'd be nice to add more comments in this section

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can add more myself later

if ((log_poly_degree & 1) == 0)
{
std::swap(previous_round_s, current_round_s);
}
previous_round_s[0] = Fr(1);
for (size_t i = 0; i < log_poly_degree; ++i)
{
const size_t round_size = 1 << (i + 1);
const Fr round_challenge = round_challenges_inv[i];
for (size_t j = 0; j < round_size / 2; ++j)
{
current_round_s[j * 2] = previous_round_s[j];
current_round_s[j * 2 + 1] = previous_round_s[j] * round_challenge;
}
s_vec[i] = s_vec_scalar;
std::swap(current_round_s, previous_round_s);
}

auto srs_elements = vk->get_monomial_points();

// TODO(https://github.com/AztecProtocol/barretenberg/issues/1023): Unify the two batch_muls
// Step 6.
// Receive a₀ from the prover
auto a_zero = transcript->template receive_from_prover<Fr>("IPA:a_0");
lucasxia01 marked this conversation as resolved.
Show resolved Hide resolved

// Step 8.
// Step 7.
// Compute G₀
// Unlike the native verification function, the verifier commitment key only containts the SRS so we can apply
// batch_mul directly on it.
auto srs_elements = vk->get_monomial_points();
lucasxia01 marked this conversation as resolved.
Show resolved Hide resolved
Commitment G_zero = Commitment::batch_mul(srs_elements, s_vec);

// Step 9.
// Receive a₀ from the prover
auto a_zero = transcript->template receive_from_prover<Fr>("IPA:a_0");

// Step 10.
// Compute C_right
GroupElement right_hand_side = G_zero * a_zero + aux_generator * a_zero * b_zero;

// Step 11.
// Check if C_right == C₀
C_zero.assert_equal(right_hand_side);
return (C_zero.get_value() == right_hand_side.get_value());
// Step 8.
// Compute R = C' + ∑_{j ∈ [k]} u_j^{-1}L_j + ∑_{j ∈ [k]} u_jR_j - G₀ * a₀ - (f(\beta) + a₀ * b₀) ⋅ U
// This is a combination of several IPA relations into a large batch mul
// which should be equal to -C
msm_elements.emplace_back(-G_zero);
msm_elements.emplace_back(-Commitment::one(builder));
msm_scalars.emplace_back(a_zero);
msm_scalars.emplace_back(generator_challenge * a_zero.madd(b_zero, {opening_claim.opening_pair.evaluation}));
GroupElement ipa_relation = GroupElement::batch_mul(msm_elements, msm_scalars);
ipa_relation.assert_equal(-opening_claim.commitment);

return (ipa_relation.get_value() == -opening_claim.commitment.get_value());
}

public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ template <typename RecursiveFlavor> class ECCVMRecursiveTests : public ::testing
OuterBuilder outer_circuit;
RecursiveVerifier verifier{ &outer_circuit, verification_key };
verifier.verify_proof(proof);
info("Recursive Verifier: num gates = ", outer_circuit.num_gates);
info("Recursive Verifier: num gates = ", outer_circuit.get_estimated_num_finalized_gates());
lucasxia01 marked this conversation as resolved.
Show resolved Hide resolved

// Check for a failure flag in the recursive verifier circuit
EXPECT_EQ(outer_circuit.failed(), false) << outer_circuit.err();
Expand Down Expand Up @@ -135,10 +135,21 @@ template <typename RecursiveFlavor> class ECCVMRecursiveTests : public ::testing
OuterBuilder outer_circuit;
RecursiveVerifier verifier{ &outer_circuit, verification_key };
verifier.verify_proof(proof);
info("Recursive Verifier: num gates = ", outer_circuit.num_gates);
info("Recursive Verifier: num gates = ", outer_circuit.get_estimated_num_finalized_gates());

// Check for a failure flag in the recursive verifier circuit
EXPECT_EQ(outer_circuit.failed(), true) << outer_circuit.err();

{
auto proving_key = std::make_shared<OuterDeciderProvingKey>(outer_circuit);
OuterProver prover(proving_key);
auto verification_key = std::make_shared<typename OuterFlavor::VerificationKey>(proving_key->proving_key);
OuterVerifier verifier(verification_key);
auto proof = prover.construct_proof();
bool verified = verifier.verify_proof(proof);

EXPECT_FALSE(verified);
}
}
};
using FlavorTypes = testing::Types<ECCVMRecursiveFlavor_<UltraCircuitBuilder>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,18 @@ TEST_F(GoblinRecursiveVerifierTests, ECCVMFailure)

// Tamper with the ECCVM proof
for (auto& val : proof.eccvm_proof) {
if (val > 0) { // tamper by finding the first non-zero value and incrementing it by 1
if (val > 0) { // tamper by finding the tenth non-zero value and incrementing it by 1
// tamper by finding the first non-zero value
// and incrementing it by 1
val += 1;
break;
}
}

Builder builder;
GoblinRecursiveVerifier verifier{ &builder, verifier_input };
verifier.verify(proof);

EXPECT_FALSE(CircuitChecker::check(builder));
EXPECT_DEBUG_DEATH(verifier.verify(proof), "(sumcheck_verified && batched_opening_verified)");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,23 +678,110 @@ typename cycle_group<Builder>::cycle_scalar cycle_group<Builder>::cycle_scalar::
template <typename Builder> cycle_group<Builder>::cycle_scalar::cycle_scalar(BigScalarField& scalar)
{
auto* ctx = get_context() ? get_context() : scalar.get_context();
const uint256_t value((scalar.get_value() % uint512_t(ScalarField::modulus)).lo);
const uint256_t value_lo = value.slice(0, LO_BITS);
const uint256_t value_hi = value.slice(LO_BITS, HI_BITS);

if (scalar.is_constant()) {
const uint256_t value((scalar.get_value() % uint512_t(ScalarField::modulus)).lo);
const uint256_t value_lo = value.slice(0, LO_BITS);
const uint256_t value_hi = value.slice(LO_BITS, HI_BITS);

lo = value_lo;
hi = value_hi;
// N.B. to be able to call assert equal, these cannot be constants
} else {
lo = witness_t(ctx, value_lo);
hi = witness_t(ctx, value_hi);
field_t zero = field_t(0);
zero.convert_constant_to_fixed_witness(ctx);
BigScalarField lo_big(lo, zero);
BigScalarField hi_big(hi, zero);
BigScalarField res = lo_big + hi_big * BigScalarField((uint256_t(1) << LO_BITS));
scalar.assert_equal(res);
validate_scalar_is_in_field();
// To efficiently convert a bigfield into a cycle scalar,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking the time to add all the docs :)

// we are going to explicitly rely on the fact that `scalar.lo` and `scalar.hi`
// are implicitly range-constrained to be 128 bits when they are converted into 4-bit lookup window slices

// First check: can the scalar actually fit into LO_BITS + HI_BITS?
// If it can, we can tolerate the scalar being > ScalarField::modulus, because performing a scalar mul
// implicilty performs a modular reduction
// If not, call `self_reduce` to cut enougn modulus multiples until the above condition is met
if (scalar.get_maximum_value() >= (uint512_t(1) << (LO_BITS + HI_BITS))) {
scalar.self_reduce();
}

field_t limb0 = scalar.binary_basis_limbs[0].element;
field_t limb1 = scalar.binary_basis_limbs[1].element;
field_t limb2 = scalar.binary_basis_limbs[2].element;
field_t limb3 = scalar.binary_basis_limbs[3].element;

// The general plan is as follows:
// 1. ensure limb0 contains no more than BigScalarField::NUM_LIMB_BITS
// 2. define limb1_lo = limb1.slice(0, LO_BITS - BigScalarField::NUM_LIMB_BITS)
// 3. define limb1_hi = limb1.slice(LO_BITS - BigScalarField::NUM_LIMB_BITS, <whatever maximum bound of limb1
// is>)
// 4. construct *this.lo out of limb0 and limb1_lo
// 5. construct *this.hi out of limb1_hi, limb2 and limb3
// This is a lot of logic, but very cheap on constraints.
// For fresh bignums that have come out of a MUL operation,
// the only "expensive" part is a size (LO_BITS - BigScalarField::NUM_LIMB_BITS) range check

// to convert into a cycle_scalar, we need to convert 4*68 bit limbs into 2*128 bit limbs
// we also need to ensure that the number of bits in cycle_scalar is < LO_BITS + HI_BITS
// note: we do not need to validate that the scalar is within the field modulus
// because performing a scalar multiplication implicitly performs a modular reduction (ecc group is
// multiplicative modulo BigField::modulus)

uint256_t limb1_max = scalar.binary_basis_limbs[1].maximum_value;

// Ensure that limb0 only contains at most NUM_LIMB_BITS. If it exceeds this value, slice of the excess and add
// it into limb1
if (scalar.binary_basis_limbs[0].maximum_value > BigScalarField::DEFAULT_MAXIMUM_LIMB) {
const uint256_t limb = limb0.get_value();
const uint256_t lo_v = limb.slice(0, BigScalarField::NUM_LIMB_BITS);
const uint256_t hi_v = limb >> BigScalarField::NUM_LIMB_BITS;
field_t lo = field_t::from_witness(ctx, lo_v);
field_t hi = field_t::from_witness(ctx, hi_v);
Comment on lines +733 to +734
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't this two lines create an unconstrained witness, shouldn't we create a constant and then call convert_constant_to_fixed_witness?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those aren’t constants. They vary depending on the value of limb. We constrain them to be correct via the lines 738-740
i.e. we’re taking the value of limb and chopping it up into two components lo, hi, where we know that lo has at most BigScalarField::NUM_LIMB_BITS and hi has whatever the overflow is

If we called convert_constant_to_fixed_witness, the values of lo, hi would have to be identical for every proof, which they aren't


uint256_t hi_max = (scalar.binary_basis_limbs[0].maximum_value >> BigScalarField::NUM_LIMB_BITS);
const size_t hi_bits = hi_max.get_msb() + 1;
lucasxia01 marked this conversation as resolved.
Show resolved Hide resolved
lo.create_range_constraint(BigScalarField::NUM_LIMB_BITS);
hi.create_range_constraint(hi_bits);
limb0.assert_equal(lo + hi * BigScalarField::shift_1);

limb1 += hi;
limb1_max += hi_max;
limb0 = lo;
}

// sanity check that limb[1] is the limb that contributs both to *this.lo and *this.hi
ASSERT((BigScalarField::NUM_LIMB_BITS * 2 > LO_BITS) && (BigScalarField::NUM_LIMB_BITS < LO_BITS));

// limb1 is the tricky one as it contributs to both *this.lo and *this.hi
// By this point, we know that limb1 fits in the range `1 << BigScalarField::NUM_LIMB_BITS to (1 <<
// BigScalarField::NUM_LIMB_BITS) + limb1_max.get_maximum_value() we need to slice this limb into 2. The first
// is LO_BITS - BigScalarField::NUM_LIMB_BITS (which reprsents its contribution to *this.lo) and the second
// represents the limbs contribution to *this.hi Step 1: compute the max bit sizes of both slices
const size_t lo_bits_in_limb_1 = LO_BITS - BigScalarField::NUM_LIMB_BITS;
const size_t hi_bits_in_limb_1 = (limb1_max.get_msb() + 1) - lo_bits_in_limb_1;

// Step 2: compute the witness values of both slices
const uint256_t limb_1 = limb1.get_value();
const uint256_t limb_1_hi_multiplicand = (uint256_t(1) << lo_bits_in_limb_1);
const uint256_t limb_1_hi_v = limb_1 >> lo_bits_in_limb_1;
const uint256_t limb_1_lo_v = limb_1 - (limb_1_hi_v << lo_bits_in_limb_1);

// Step 3: instantiate both slices as witnesses and validate their sum equals limb1
field_t limb_1_lo = field_t::from_witness(ctx, limb_1_lo_v);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question here

field_t limb_1_hi = field_t::from_witness(ctx, limb_1_hi_v);
limb1.assert_equal(limb_1_hi * limb_1_hi_multiplicand + limb_1_lo);

// Step 4: apply range constraints to validate both slices represent the expected contributions to *this.lo and
// *this,hi
limb_1_lo.create_range_constraint(lo_bits_in_limb_1);
limb_1_hi.create_range_constraint(hi_bits_in_limb_1);

// construct *this.lo out of:
// a. `limb0` (the first NUM_LIMB_BITS bits of scalar)
// b. `limb_1_lo` (the first LO_BITS - NUM_LIMB_BITS) of limb1
lo = limb0 + (limb_1_lo * BigScalarField::shift_1);

const uint256_t limb_2_shift = uint256_t(1) << (BigScalarField::NUM_LIMB_BITS - lo_bits_in_limb_1);
const uint256_t limb_3_shift =
uint256_t(1) << ((BigScalarField::NUM_LIMB_BITS - lo_bits_in_limb_1) + BigScalarField::NUM_LIMB_BITS);

// construct *this.hi out of limb2, limb3 and the remaining term from limb1 not contributing to `lo`
hi = limb_1_hi.add_two(limb2 * limb_2_shift, limb3 * limb_3_shift);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,13 @@ class ECCOpQueue {
*/
void add_erroneous_equality_op_for_testing()
{
auto base_point = Point::random_element();
info("erroneous equality op point ", base_point);
raw_ops.emplace_back(ECCVMOperation{ .add = false,
.mul = false,
.eq = true,
.reset = true,
.base_point = Point::random_element(),
.base_point = base_point,
.z1 = 0,
.z2 = 0,
.mul_scalar_full = 0 });
Expand Down
Loading