Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added a UnivariateMonomial representation to reduce field ops in protogalaxy+sumcheck #10401

Merged
merged 34 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
1926c6b
add monomial accumulator
zac-williamson Nov 8, 2024
3c96f4f
wip
zac-williamson Nov 8, 2024
b38ff10
relation updates
zac-williamson Nov 10, 2024
0922902
wip
zac-williamson Nov 10, 2024
9c47d26
wip
zac-williamson Nov 10, 2024
f0673e2
wip
zac-williamson Nov 10, 2024
019d8aa
wip
zac-williamson Nov 10, 2024
af2df21
wip
zac-williamson Nov 10, 2024
ae4246e
wip
zac-williamson Nov 10, 2024
a5d457e
wip
zac-williamson Nov 10, 2024
1d0fbb0
wip
zac-williamson Nov 12, 2024
e66e3d6
wip
zac-williamson Nov 12, 2024
ec321d1
wip
zac-williamson Nov 18, 2024
5c07a8a
wip
zac-williamson Nov 18, 2024
32f9858
Merge branch 'master' into zw/monomial-accumulaator
zac-williamson Dec 4, 2024
5d78d66
small cleanup
zac-williamson Dec 4, 2024
a3e47d9
test inlines
zac-williamson Dec 4, 2024
911801f
add short univariate case to sumcheck prover
zac-williamson Dec 4, 2024
2a0b185
move small computations within thread block
zac-williamson Dec 4, 2024
da6b1a6
tweaks
zac-williamson Dec 4, 2024
81315c7
tweaks
zac-williamson Dec 4, 2024
160e842
updated flavor.hpp.hbs to generate USE_SHORT_MONOMIALS
zac-williamson Dec 5, 2024
e8d3508
recursive_flavor derives USE_SHORT_MONOMIALS
zac-williamson Dec 5, 2024
4b9af43
compiler fixes
zac-williamson Dec 5, 2024
21b8d22
reduced tparams in ProtogalaxyProverInternal
zac-williamson Dec 10, 2024
3e1065b
Updated to reflect PR comments
zac-williamson Dec 10, 2024
12959ef
more PR feedback
zac-williamson Dec 10, 2024
4e4aa74
more PR feedback
zac-williamson Dec 10, 2024
f051728
added tests for UnivariateMonomial
zac-williamson Dec 11, 2024
a099809
renamed "UnivariateMonomial" to "UnivariateCoefficientBasis"
zac-williamson Dec 11, 2024
159c1aa
Merge branch 'master' into zw/monomial-accumulaator
zac-williamson Dec 12, 2024
2a627df
PR comments
zac-williamson Dec 12, 2024
2b05b43
Merge branch 'master' into zw/monomial-accumulaator
zac-williamson Dec 12, 2024
874e7ef
Merge branch 'master' into zw/monomial-accumulaator
zac-williamson Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ template <typename Flavor, typename Relation> void execute_relation_for_univaria
template <typename Flavor, typename Relation> void execute_relation_for_pg_univariates(::benchmark::State& state)
{
using DeciderProvingKeys = DeciderProvingKeys_<Flavor>;
using Input = ProtogalaxyProverInternal<DeciderProvingKeys>::ExtendedUnivariatesNoOptimisticSkipping;
using Accumulator = typename Relation::template ProtogalaxyTupleOfUnivariatesOverSubrelationsNoOptimisticSkipping<
DeciderProvingKeys::NUM>;
using Input = ProtogalaxyProverInternal<DeciderProvingKeys>::ExtendedUnivariates;
using Accumulator =
typename Relation::template ProtogalaxyTupleOfUnivariatesOverSubrelations<DeciderProvingKeys::NUM>;

execute_relation<Flavor, Relation, Input, Accumulator>(state);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace bb {
template <class Params_> struct alignas(32) field {
public:
using View = field;
using CoefficientAccumulator = field;
using Params = Params_;
using in_buf = const uint8_t*;
using vec_in_buf = const uint8_t*;
Expand Down
3 changes: 3 additions & 0 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class ECCVMFlavor {
using RelationSeparator = FF;
using MSM = bb::eccvm::MSM<CycleGroup>;

// indicates when evaluating sumcheck, edges must be extended to be MAX_TOTAL_RELATION_LENGTH
static constexpr bool USE_SHORT_MONOMIALS = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add to translator and eccvm recursive flavor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed. I'm curious why the tests all passed without this. Do we not compile tests that use the sumcheck Prover for the translator and eccvm recursive flavours?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, now that I think about it, this flag is only useful for the prover, and given we only ever instantiate the verifiers with recursive flavors the flag is not necessary in them, my bad


// Indicates that this flavor runs with ZK Sumcheck.
static constexpr bool HasZK = true;
static constexpr size_t NUM_WIRES = 85;
Expand Down
94 changes: 90 additions & 4 deletions barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "barretenberg/common/assert.hpp"
#include "barretenberg/common/serialize.hpp"
#include "barretenberg/polynomials/barycentric.hpp"
#include "barretenberg/polynomials/univariate_coefficient_basis.hpp"
#include <span>

namespace bb {
Expand Down Expand Up @@ -30,6 +31,8 @@ template <class Fr, size_t domain_end, size_t domain_start = 0, size_t skip_coun
static constexpr size_t LENGTH = domain_end - domain_start;
static constexpr size_t SKIP_COUNT = skip_count;
using View = UnivariateView<Fr, domain_end, domain_start, skip_count>;
static constexpr size_t MONOMIAL_LENGTH = LENGTH > 1 ? 2 : 1;
using CoefficientAccumulator = UnivariateCoefficientBasis<Fr, MONOMIAL_LENGTH, true>;

using value_type = Fr; // used to get the type of the elements consistently with std::array

Expand All @@ -47,6 +50,58 @@ template <class Fr, size_t domain_end, size_t domain_start = 0, size_t skip_coun
Univariate& operator=(const Univariate& other) = default;
Univariate& operator=(Univariate&& other) noexcept = default;

explicit operator UnivariateCoefficientBasis<Fr, 2, true>() const
requires(LENGTH > 1)
{
static_assert(domain_end >= 2);
static_assert(domain_start == 0);

UnivariateCoefficientBasis<Fr, 2, true> result;
result.coefficients[0] = evaluations[0];
result.coefficients[1] = evaluations[1] - evaluations[0];
result.coefficients[2] = evaluations[1];
return result;
}

template <bool has_a0_plus_a1> Univariate(UnivariateCoefficientBasis<Fr, 2, has_a0_plus_a1> monomial)
{
static_assert(domain_start == 0);
Fr to_add = monomial.coefficients[1];
evaluations[0] = monomial.coefficients[0];
auto prev = evaluations[0];
for (size_t i = 1; i < skip_count + 1; ++i) {
evaluations[i] = 0;
prev = prev + to_add;
}

for (size_t i = skip_count + 1; i < domain_end; ++i) {
prev = prev + to_add;
evaluations[i] = prev;
}
}

template <bool has_a0_plus_a1> Univariate(UnivariateCoefficientBasis<Fr, 3, has_a0_plus_a1> monomial)
{
static_assert(domain_start == 0);
Fr to_add = monomial.coefficients[1]; // a1 + a2
Fr derivative = monomial.coefficients[2] + monomial.coefficients[2]; // 2a2
evaluations[0] = monomial.coefficients[0];
auto prev = evaluations[0];
for (size_t i = 1; i < skip_count + 1; ++i) {
evaluations[i] = 0;
prev += to_add;
to_add += derivative;
}

for (size_t i = skip_count + 1; i < domain_end - 1; ++i) {
prev += to_add;
evaluations[i] = prev;
to_add += derivative;
}
prev += to_add;
evaluations[domain_end - 1] = prev;
}

/**
* @brief Convert from a version with skipped evaluations to one without skipping (with zeroes in previously skipped
* locations)
Expand Down Expand Up @@ -104,15 +159,12 @@ template <class Fr, size_t domain_end, size_t domain_start = 0, size_t skip_coun
// Check if the univariate is identically zero
bool is_zero() const
{
if (!evaluations[0].is_zero()) {
return false;
}
for (size_t i = skip_count + 1; i < LENGTH; ++i) {
if (!evaluations[i].is_zero()) {
return false;
}
}
return true;
return evaluations[0].is_zero();
}

// Write the Univariate evaluations to a buffer
Expand Down Expand Up @@ -350,6 +402,13 @@ template <class Fr, size_t domain_end, size_t domain_start = 0, size_t skip_coun
return os;
}

template <size_t EXTENDED_DOMAIN_END, size_t NUM_SKIPPED_INDICES = 0>
explicit operator Univariate<Fr, EXTENDED_DOMAIN_END, 0, NUM_SKIPPED_INDICES>()
requires(domain_start == 0 && domain_end == 2)
{
return extend_to<EXTENDED_DOMAIN_END, NUM_SKIPPED_INDICES>();
}

/**
* @brief Given a univariate f represented by {f(domain_start), ..., f(domain_end - 1)}, compute the
* evaluations {f(domain_end),..., f(extended_domain_end -1)} and return the Univariate represented by
Expand Down Expand Up @@ -576,15 +635,42 @@ template <class Fr, size_t domain_end, size_t domain_start = 0, size_t skip_coun
public:
static constexpr size_t LENGTH = domain_end - domain_start;
std::span<const Fr, LENGTH> evaluations;
static constexpr size_t MONOMIAL_LENGTH = LENGTH > 1 ? 2 : 1;
using CoefficientAccumulator = UnivariateCoefficientBasis<Fr, MONOMIAL_LENGTH, true>;

UnivariateView() = default;

bool operator==(const UnivariateView& other) const
{
bool r = true;
r = r && (evaluations[0] == other.evaluations[0]);
// a view might have nonzero terms in its skip_count if accessing an original monomial
for (size_t i = skip_count + 1; i < LENGTH; ++i) {
r = r && (evaluations[i] == other.evaluations[i]);
}
return r;
};

const Fr& value_at(size_t i) const { return evaluations[i]; };

template <size_t full_domain_end, size_t full_domain_start = 0>
explicit UnivariateView(const Univariate<Fr, full_domain_end, full_domain_start, skip_count>& univariate_in)
: evaluations(std::span<const Fr>(univariate_in.evaluations.data(), LENGTH)){};

explicit operator UnivariateCoefficientBasis<Fr, 2, true>() const
requires(LENGTH > 1)
{
static_assert(domain_end >= 2);
static_assert(domain_start == 0);

UnivariateCoefficientBasis<Fr, 2, true> result;

result.coefficients[0] = evaluations[0];
result.coefficients[1] = evaluations[1] - evaluations[0];
result.coefficients[2] = evaluations[1];
return result;
}

Univariate<Fr, domain_end, domain_start, skip_count> operator+(const UnivariateView& other) const
{
Univariate<Fr, domain_end, domain_start, skip_count> res(*this);
Expand Down
Loading
Loading