Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zw/recursion constraint reduction #377

Merged
merged 14 commits into from
Apr 24, 2023
192 changes: 118 additions & 74 deletions cpp/src/barretenberg/plonk/composer/ultra_composer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ std::shared_ptr<proving_key> UltraComposer::compute_proving_key()
* our circuit is finalised, and we must not execute these functions again.
*/
if (!circuit_finalised) {
process_non_native_field_multiplications();
process_ROM_arrays(public_inputs.size());
process_RAM_arrays(public_inputs.size());
process_range_lists();
Expand Down Expand Up @@ -1889,8 +1890,6 @@ std::array<uint32_t, 2> UltraComposer::evaluate_non_native_field_multiplication(
constexpr barretenberg::fr LIMB_SHIFT = uint256_t(1) << DEFAULT_NON_NATIVE_FIELD_LIMB_BITS;
constexpr barretenberg::fr LIMB_SHIFT_2 = uint256_t(1) << (2 * DEFAULT_NON_NATIVE_FIELD_LIMB_BITS);
constexpr barretenberg::fr LIMB_SHIFT_3 = uint256_t(1) << (3 * DEFAULT_NON_NATIVE_FIELD_LIMB_BITS);
constexpr barretenberg::fr LIMB_RSHIFT =
barretenberg::fr(1) / barretenberg::fr(uint256_t(1) << DEFAULT_NON_NATIVE_FIELD_LIMB_BITS);
constexpr barretenberg::fr LIMB_RSHIFT_2 =
barretenberg::fr(1) / barretenberg::fr(uint256_t(1) << (2 * DEFAULT_NON_NATIVE_FIELD_LIMB_BITS));

Expand Down Expand Up @@ -1939,82 +1938,127 @@ std::array<uint32_t, 2> UltraComposer::evaluate_non_native_field_multiplication(
range_constrain_two_limbs(input.q[2], input.q[3]);
}

// product gate 1
// (lo_0 + q_0(p_0 + p_1*2^b) + q_1(p_0*2^b) - (r_1)2^b)2^-2b - lo_1 = 0
create_big_add_gate({ input.q[0],
input.q[1],
input.r[1],
lo_1_idx,
input.neg_modulus[0] + input.neg_modulus[1] * LIMB_SHIFT,
input.neg_modulus[0] * LIMB_SHIFT,
-LIMB_SHIFT,
-LIMB_SHIFT.sqr(),
0 },
true);
// Add witnesses into the multiplication cache
// (when finalising the circuit, we will remove duplicates; several dups produced by biggroup.hpp methods)
cached_non_native_field_multiplication cache{
codygunton marked this conversation as resolved.
Show resolved Hide resolved
.a = input.a,
.b = input.b,
.q = input.q,
.r = input.r,
.cross_terms = { lo_0_idx, lo_1_idx, hi_0_idx, hi_1_idx, hi_2_idx, hi_3_idx },
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: clearer to make this a struct.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sanity check: these are witness indices, so the multiplications being removed really are duplicates. ✔️.

.neg_modulus = input.neg_modulus,
};
cached_non_native_field_multiplications.emplace_back(cache);

w_l.emplace_back(input.a[1]);
w_r.emplace_back(input.b[1]);
w_o.emplace_back(input.r[0]);
w_4.emplace_back(lo_0_idx);
apply_aux_selectors(AUX_SELECTORS::NON_NATIVE_FIELD_1);
++num_gates;
w_l.emplace_back(input.a[0]);
w_r.emplace_back(input.b[0]);
w_o.emplace_back(input.a[3]);
w_4.emplace_back(input.b[3]);
apply_aux_selectors(AUX_SELECTORS::NON_NATIVE_FIELD_2);
++num_gates;
w_l.emplace_back(input.a[2]);
w_r.emplace_back(input.b[2]);
w_o.emplace_back(input.r[3]);
w_4.emplace_back(hi_0_idx);
apply_aux_selectors(AUX_SELECTORS::NON_NATIVE_FIELD_3);
++num_gates;
w_l.emplace_back(input.a[1]);
w_r.emplace_back(input.b[1]);
w_o.emplace_back(input.r[2]);
w_4.emplace_back(hi_1_idx);
apply_aux_selectors(AUX_SELECTORS::NONE);
++num_gates;
return std::array<uint32_t, 2>{ lo_1_idx, hi_3_idx };
}

/**
* product gate 6
*
* hi_2 - hi_1 - lo_1 - q[2](p[1].2^b + p[0]) - q[3](p[0].2^b) = 0
*
**/
create_big_add_gate(
{
input.q[2],
input.q[3],
lo_1_idx,
hi_1_idx,
-input.neg_modulus[1] * LIMB_SHIFT - input.neg_modulus[0],
-input.neg_modulus[0] * LIMB_SHIFT,
-1,
-1,
0,
},
true);
/**
* @brief Called in `compute_proving_key` when finalizing circuit.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: name and description of evaluate_non_native_field_multiplication are no longer correct.

* Iterates over the cached_non_native_field_multiplication objects,
* removes duplicates, and instantiates the remainder as constraints
*/
void UltraComposer::process_non_native_field_multiplications()
{
std::sort(cached_non_native_field_multiplications.begin(), cached_non_native_field_multiplications.end());

/**
* product gate 7
*
* hi_3 - (hi_2 - q[0](p[3].2^b + p[2]) - q[1](p[2].2^b + p[1])).2^-2b
**/
create_big_add_gate({
hi_3_idx,
input.q[0],
input.q[1],
hi_2_idx,
-1,
input.neg_modulus[3] * LIMB_RSHIFT + input.neg_modulus[2] * LIMB_RSHIFT_2,
input.neg_modulus[2] * LIMB_RSHIFT + input.neg_modulus[1] * LIMB_RSHIFT_2,
LIMB_RSHIFT_2,
0,
});
auto last =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eliminates duplicates by shifting unique entries toward the beginning of the vector, then returns an iterator just past the last unique element ✔️.

std::unique(cached_non_native_field_multiplications.begin(), cached_non_native_field_multiplications.end());

return std::array<uint32_t, 2>{ lo_1_idx, hi_3_idx };
auto it = cached_non_native_field_multiplications.begin();

constexpr barretenberg::fr LIMB_SHIFT = uint256_t(1) << DEFAULT_NON_NATIVE_FIELD_LIMB_BITS;
constexpr barretenberg::fr LIMB_RSHIFT =
barretenberg::fr(1) / barretenberg::fr(uint256_t(1) << DEFAULT_NON_NATIVE_FIELD_LIMB_BITS);
constexpr barretenberg::fr LIMB_RSHIFT_2 =
barretenberg::fr(1) / barretenberg::fr(uint256_t(1) << (2 * DEFAULT_NON_NATIVE_FIELD_LIMB_BITS));

// iterate over the cached items and create constraints
while (it != last) {
const auto input = *it;
const auto lo_0_idx = input.cross_terms[0];
const auto lo_1_idx = input.cross_terms[1];
const auto hi_0_idx = input.cross_terms[2];
const auto hi_1_idx = input.cross_terms[3];
const auto hi_2_idx = input.cross_terms[4];
const auto hi_3_idx = input.cross_terms[5];

// product gate 1
// (lo_0 + q_0(p_0 + p_1*2^b) + q_1(p_0*2^b) - (r_1)2^b)2^-2b - lo_1 = 0
create_big_add_gate({ input.q[0],
input.q[1],
input.r[1],
lo_1_idx,
input.neg_modulus[0] + input.neg_modulus[1] * LIMB_SHIFT,
input.neg_modulus[0] * LIMB_SHIFT,
-LIMB_SHIFT,
-LIMB_SHIFT.sqr(),
0 },
true);

w_l.emplace_back(input.a[1]);
w_r.emplace_back(input.b[1]);
w_o.emplace_back(input.r[0]);
w_4.emplace_back(lo_0_idx);
apply_aux_selectors(AUX_SELECTORS::NON_NATIVE_FIELD_1);
++num_gates;
w_l.emplace_back(input.a[0]);
w_r.emplace_back(input.b[0]);
w_o.emplace_back(input.a[3]);
w_4.emplace_back(input.b[3]);
apply_aux_selectors(AUX_SELECTORS::NON_NATIVE_FIELD_2);
++num_gates;
w_l.emplace_back(input.a[2]);
w_r.emplace_back(input.b[2]);
w_o.emplace_back(input.r[3]);
w_4.emplace_back(hi_0_idx);
apply_aux_selectors(AUX_SELECTORS::NON_NATIVE_FIELD_3);
++num_gates;
w_l.emplace_back(input.a[1]);
w_r.emplace_back(input.b[1]);
w_o.emplace_back(input.r[2]);
w_4.emplace_back(hi_1_idx);
apply_aux_selectors(AUX_SELECTORS::NONE);
++num_gates;

/**
* product gate 6
*
* hi_2 - hi_1 - lo_1 - q[2](p[1].2^b + p[0]) - q[3](p[0].2^b) = 0
*
**/
create_big_add_gate(
{
input.q[2],
input.q[3],
lo_1_idx,
hi_1_idx,
-input.neg_modulus[1] * LIMB_SHIFT - input.neg_modulus[0],
-input.neg_modulus[0] * LIMB_SHIFT,
-1,
-1,
0,
},
true);

/**
* product gate 7
*
* hi_3 - (hi_2 - q[0](p[3].2^b + p[2]) - q[1](p[2].2^b + p[1])).2^-2b
**/
create_big_add_gate({
hi_3_idx,
input.q[0],
input.q[1],
hi_2_idx,
-1,
input.neg_modulus[3] * LIMB_RSHIFT + input.neg_modulus[2] * LIMB_RSHIFT_2,
input.neg_modulus[2] * LIMB_RSHIFT + input.neg_modulus[1] * LIMB_RSHIFT_2,
LIMB_RSHIFT_2,
0,
});
++it;
}
}

/**
Expand Down
80 changes: 70 additions & 10 deletions cpp/src/barretenberg/plonk/composer/ultra_composer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class UltraComposer : public ComposerBase {
static constexpr uint32_t UNINITIALIZED_MEMORY_RECORD = UINT32_MAX;
static constexpr size_t NUMBER_OF_GATES_PER_RAM_ACCESS = 2;
static constexpr size_t NUMBER_OF_ARITHMETIC_GATES_PER_RAM_ARRAY = 1;

// number of gates created per non-native field operation in process_non_native_field_multiplications
static constexpr size_t GATES_PER_NON_NATIVE_FIELD_MULTIPLICATION_ARITHMETIC = 7;
struct non_native_field_witnesses {
// first 4 array elements = limbs
// 5th element = prime basis limb
Expand All @@ -39,6 +40,57 @@ class UltraComposer : public ComposerBase {
barretenberg::fr modulus;
};

/**
* @brief Used to store instructions to create non_native_field_multiplication gates.
* We want to cache these (and remove duplicates) as the stdlib code can end up multiplying the same inputs
* repeatedly.
*/
struct cached_non_native_field_multiplication {
std::array<uint32_t, 5> a;
std::array<uint32_t, 5> b;
std::array<uint32_t, 5> q;
std::array<uint32_t, 5> r;
std::array<uint32_t, 6> cross_terms;
std::array<barretenberg::fr, 5> neg_modulus;

bool operator==(const cached_non_native_field_multiplication& other) const
{
bool valid = true;
for (size_t i = 0; i < 5; ++i) {
valid = valid && (a[i] == other.a[i]);
valid = valid && (b[i] == other.b[i]);
valid = valid && (q[i] == other.q[i]);
valid = valid && (r[i] == other.r[i]);
}
return valid;
}
bool operator<(const cached_non_native_field_multiplication& other) const
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each return is a way in which we can see that this and other are different. cross_terms entry is determined by a static formula in a, b, q, r, hence there is no need to compare these. In case the formula were computed in a dynamic way, we'd have an issue here, but it's not worth worrying about that now.

{
if (a < other.a) {
return true;
}
if (a == other.a) {
if (b < other.b) {
return true;
}
if (b == other.b) {
if (q < other.q) {
return true;
}
if (q == other.q) {
if (r < other.r) {
return true;
}
}
}
}
return false;
}
};

std::vector<cached_non_native_field_multiplication> cached_non_native_field_multiplications;
void process_non_native_field_multiplications();

enum AUX_SELECTORS {
NONE,
LIMB_ACCUMULATE_1,
Expand Down Expand Up @@ -249,11 +301,10 @@ class UltraComposer : public ComposerBase {
* @param rangecount return argument, extra gates due to range checks
* @param romcount return argument, extra gates due to rom reads
* @param ramcount return argument, extra gates due to ram read/writes
* @param nnfcount return argument, extra gates due to queued non native field gates
*/
void get_num_gates_split_into_components(size_t& count,
size_t& rangecount,
size_t& romcount,
size_t& ramcount) const
void get_num_gates_split_into_components(
size_t& count, size_t& rangecount, size_t& romcount, size_t& ramcount, size_t& nnfcount) const
{
count = num_gates;
// each ROM gate adds +1 extra gate due to the rom reads being copied to a sorted list set
Expand Down Expand Up @@ -321,6 +372,13 @@ class UltraComposer : public ComposerBase {
rangecount += ram_range_sizes[i];
}
}
std::vector<cached_non_native_field_multiplication> nnf_copy(cached_non_native_field_multiplications);
// update nnfcount
std::sort(nnf_copy.begin(), nnf_copy.end());

auto last = std::unique(nnf_copy.begin(), nnf_copy.end());
const size_t num_nnf_ops = static_cast<size_t>(std::distance(nnf_copy.begin(), last));
nnfcount = num_nnf_ops * GATES_PER_NON_NATIVE_FIELD_MULTIPLICATION_ARITHMETIC;
}

/**
Expand All @@ -342,8 +400,9 @@ class UltraComposer : public ComposerBase {
size_t rangecount = 0;
size_t romcount = 0;
size_t ramcount = 0;
get_num_gates_split_into_components(count, rangecount, romcount, ramcount);
return count + romcount + ramcount + rangecount;
size_t nnfcount = 0;
get_num_gates_split_into_components(count, rangecount, romcount, ramcount, nnfcount);
return count + romcount + ramcount + rangecount + nnfcount;
}

virtual size_t get_total_circuit_size() const override
Expand All @@ -366,12 +425,13 @@ class UltraComposer : public ComposerBase {
size_t rangecount = 0;
size_t romcount = 0;
size_t ramcount = 0;

get_num_gates_split_into_components(count, rangecount, romcount, ramcount);
size_t nnfcount = 0;
get_num_gates_split_into_components(count, rangecount, romcount, ramcount, nnfcount);

size_t total = count + romcount + ramcount + rangecount;
std::cout << "gates = " << total << " (arith " << count << ", rom " << romcount << ", ram " << ramcount
<< ", range " << rangecount << "), pubinp = " << public_inputs.size() << std::endl;
<< ", range " << rangecount << ", non native field gates " << nnfcount
<< "), pubinp = " << public_inputs.size() << std::endl;
}

void assert_equal_constant(const uint32_t a_idx,
Expand Down
41 changes: 36 additions & 5 deletions cpp/src/barretenberg/stdlib/recursion/transcript/transcript.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,23 @@ template <typename Composer> class Transcript {
}
byte_array<Composer> compressed_buffer(T0);

byte_array<Composer> base_hash = stdlib::blake3s(compressed_buffer);

// TODO(@zac-williamson) make this a Poseidon hash
byte_array<Composer> base_hash;
if constexpr (Composer::type == ComposerType::PLOOKUP) {
std::vector<field_pt> compression_buffer;
field_pt working_element(context);
size_t byte_counter = 0;
split(working_element, compression_buffer, field_pt(compressed_buffer), byte_counter, 32);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: comment describing split is incorrect. The function splits element and inserts it into element_buffer, using working_element as an auxiliary field element to accomplish this.

if (byte_counter != 0) {
const uint256_t down_shift = uint256_t(1) << uint256_t((bytes_per_element - byte_counter) * 8);
working_element = working_element / barretenberg::fr(down_shift);
working_element = working_element.normalize();
compression_buffer.push_back(working_element);
}
base_hash = stdlib::pedersen_plookup_commitment<Composer>::compress(compression_buffer);
} else {
base_hash = stdlib::blake3s(compressed_buffer);
}
byte_array<Composer> first(field_pt(0), 16);
first.write(base_hash.slice(0, 16));
round_challenges.push_back(first);
Expand All @@ -267,9 +282,25 @@ template <typename Composer> class Transcript {

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block of code only executes for num_challenges > 2, which (currently) only happens in the nu round when we need to generate short scalars. In this case, we generate 32-byte challenges and split them in half to get the relevant challenges.

for (size_t i = 2; i < num_challenges; i += 2) {
byte_array<Composer> rolling_buffer = base_hash;
rolling_buffer.write(byte_array<Composer>(field_pt(i / 2), 1));
byte_array<Composer> hash_output = stdlib::blake3s(rolling_buffer);

byte_array<Composer> hash_output;
if constexpr (Composer::type == ComposerType::PLOOKUP) {
// TODO(@zac-williamson) make this a Poseidon hash not a Pedersen hash
std::vector<field_pt> compression_buffer;
field_pt working_element(context);
size_t byte_counter = 0;
split(working_element, compression_buffer, field_pt(rolling_buffer), byte_counter, 32);
split(working_element, compression_buffer, field_pt(field_pt(i / 2)), byte_counter, 1);
if (byte_counter != 0) {
const uint256_t down_shift = uint256_t(1) << uint256_t((bytes_per_element - byte_counter) * 8);
working_element = working_element / barretenberg::fr(down_shift);
working_element = working_element.normalize();
compression_buffer.push_back(working_element);
}
hash_output = stdlib::pedersen_plookup_commitment<Composer>::compress(compression_buffer);
} else {
rolling_buffer.write(byte_array<Composer>(field_pt(i / 2), 1));
hash_output = stdlib::blake3s(rolling_buffer);
}
byte_array<Composer> hi(field_pt(0), 16);
hi.write(hash_output.slice(0, 16));
round_challenges.push_back(hi);
Expand Down
Loading