Skip to content

Commit

Permalink
feat: Use structured polys to reduce prover memory (#8587)
Browse files Browse the repository at this point in the history
We use the new structured polynomial class to reduce the amount of
memory used by the Prover. For ClientIVCBench, this results in a
reduction of 36.5%, going from 2377.99MiB to 1511.34MiB.

This is due to a restricting polynomials down to smaller sizes. For
lagrange_first and last, we only allocate 1 element. For the gate
selectors, we only allocate the fixed block size for each one, cutting
the 8 gate selectors into almost 1 selector (caveat is that the
arithmetic selector spans the aux block for now). For the 5 ecc_op
polynomials, we restrict them to just the ecc_op block. For 9 of the 10
databus polynomials, we restrict them to MAX_DATABUS_SIZE. For the 4
table polynomials and the lookup read counts and read tag polynomials,
we restrict them to MAX_LOOKUP_TABLES_SIZE. We also restrict the inverse
polynomials, but this is complicated to explain.

Overall, this essentially allows us to cut down on 28 of the 54 total
polynomials, which leads to the drop of 867MiB.

There's more juice to be squeezed here, but this is a massive reduction
that should basically get us there.

Before:
<img width="1331" alt="Screenshot 2024-09-20 at 5 00 27 PM"
src="https://github.com/user-attachments/assets/7572a5d2-4fa9-4b4f-af1d-7885260d6756">
After:
<img width="1363" alt="Screenshot 2024-09-26 at 10 03 54 AM"
src="https://github.com/user-attachments/assets/aed64b1d-862c-4a21-9e32-160993d1f5c3">

For one instance, we cut down memory by 97MiB. 

And timing benchmark:
```
--------------------------------------------------------------------------------
Benchmark                      Time             CPU   Iterations UserCounters...
--------------------------------------------------------------------------------
ClientIVCBench/Full/6      33216 ms        30637 ms            1 Arithmetic::accumulate=3.89126M Arithmetic::accumulate(t)=7.32768G Auxiliary::accumulate=1.98134M Auxiliary::accumulate(t)=13.4156G COMMIT::databus=108 COMMIT::databus(t)=8.50634M COMMIT::databus_inverses=36 COMMIT::databus_inverses(t)=11.8267M COMMIT::ecc_op_wires=48 COMMIT::ecc_op_wires(t)=38.2178M COMMIT::lookup_counts_tags=12 COMMIT::lookup_counts_tags(t)=107.571M COMMIT::lookup_inverses=12 COMMIT::lookup_inverses(t)=257.772M COMMIT::wires=24 COMMIT::wires(t)=2.23405G COMMIT::z_perm=12 COMMIT::z_perm(t)=2.31578G DatabusRead::accumulate=447 DatabusRead::accumulate(t)=1.72333M Decider::construct_proof=1 Decider::construct_proof(t)=1.57152G DeciderProvingKey(Circuit&)=12 DeciderProvingKey(Circuit&)(t)=2.63528G DeltaRange::accumulate=1.87876M DeltaRange::accumulate(t)=4.27884G ECCVMProver(CircuitBuilder&)=1 ECCVMProver(CircuitBuilder&)(t)=228.84M ECCVMProver::construct_proof=1 ECCVMProver::construct_proof(t)=2.59672G Elliptic::accumulate=183.692k Elliptic::accumulate(t)=451.988M Goblin::merge=23 Goblin::merge(t)=116.924M Lookup::accumulate=1.66363M Lookup::accumulate(t)=3.74588G MegaFlavor::get_row=6.18564M MegaFlavor::get_row(t)=4.44329G OinkProver::execute_grand_product_computation_round=12 OinkProver::execute_grand_product_computation_round(t)=3.59852G OinkProver::execute_log_derivative_inverse_round=12 OinkProver::execute_log_derivative_inverse_round(t)=2.4985G OinkProver::execute_preamble_round=12 OinkProver::execute_preamble_round(t)=178.858k OinkProver::execute_sorted_list_accumulator_round=12 OinkProver::execute_sorted_list_accumulator_round(t)=683.402M OinkProver::execute_wire_commitments_round=12 OinkProver::execute_wire_commitments_round(t)=1.71268G OinkProver::generate_alphas_round=12 OinkProver::generate_alphas_round(t)=3.50247M Permutation::accumulate=10.6427M Permutation::accumulate(t)=40.1379G PoseidonExt::accumulate=30.452k PoseidonExt::accumulate(t)=76.6116M PoseidonInt::accumulate=210.454k PoseidonInt::accumulate(t)=365.722M ProtogalaxyProver::prove=11 ProtogalaxyProver::prove(t)=19.9675G ProtogalaxyProver_::combiner_quotient_round=11 ProtogalaxyProver_::combiner_quotient_round(t)=8.76403G ProtogalaxyProver_::compute_row_evaluations=11 ProtogalaxyProver_::compute_row_evaluations(t)=1.9728G ProtogalaxyProver_::perturbator_round=11 ProtogalaxyProver_::perturbator_round(t)=2.86884G ProtogalaxyProver_::run_oink_prover_on_each_incomplete_key=11 ProtogalaxyProver_::run_oink_prover_on_each_incomplete_key(t)=7.66211G ProtogalaxyProver_::update_target_sum_and_fold=11 ProtogalaxyProver_::update_target_sum_and_fold(t)=672.424M TranslatorCircuitBuilder::constructor=1 TranslatorCircuitBuilder::constructor(t)=32.9044M TranslatorProver=1 TranslatorProver(t)=43.1984M TranslatorProver::construct_proof=1 TranslatorProver::construct_proof(t)=832.913M batch_mul_with_endomorphism=16 batch_mul_with_endomorphism(t)=408.881M commit=543 commit(t)=6.5699G commit_sparse=36 commit_sparse(t)=11.813M compute_combiner=11 compute_combiner(t)=8.32169G compute_perturbator=11 compute_perturbator(t)=2.86857G compute_univariate=51 compute_univariate(t)=2.20204G construct_circuits=12 construct_circuits(t)=4.30706G pippenger=215 pippenger(t)=102.025M pippenger_unsafe_optimized_for_non_dyadic_polys=543 pippenger_unsafe_optimized_for_non_dyadic_polys(t)=6.56543G
Benchmarking lock deleted.
client_ivc_bench.json                                                                                                                                                                                                                  100% 6930   190.2KB/s   00:00    
function                                  ms     % sum
construct_circuits(t)                   4307    13.35%
DeciderProvingKey(Circuit&)(t)          2635     8.17%
ProtogalaxyProver::prove(t)            19967    61.90%
Decider::construct_proof(t)             1572     4.87%
ECCVMProver(CircuitBuilder&)(t)          229     0.71%
ECCVMProver::construct_proof(t)         2597     8.05%
TranslatorProver::construct_proof(t)     833     2.58%
Goblin::merge(t)                         117     0.36%

Total time accounted for: 32257ms/33216ms = 97.11%

Major contributors:
function                                  ms    % sum
commit(t)                               6570   20.37%
compute_combiner(t)                     8322   25.80%
compute_perturbator(t)                  2869    8.89%
compute_univariate(t)                   2202    6.83%

Breakdown of ProtogalaxyProver::prove:
ProtogalaxyProver_::run_oink_prover_on_each_incomplete_key(t)    7662    38.37%
ProtogalaxyProver_::perturbator_round(t)                         2869    14.37%
ProtogalaxyProver_::combiner_quotient_round(t)                   8764    43.89%
ProtogalaxyProver_::update_target_sum_and_fold(t)                 672     3.37%

Relation contributions (times to be interpreted relatively):
Total time accounted for (ms):    69802
operation                       ms     % sum
Arithmetic::accumulate(t)     7328    10.50%
Permutation::accumulate(t)   40138    57.50%
Lookup::accumulate(t)         3746     5.37%
DeltaRange::accumulate(t)     4279     6.13%
Elliptic::accumulate(t)        452     0.65%
Auxiliary::accumulate(t)     13416    19.22%
EccOp::accumulate(t)             0     0.00%
DatabusRead::accumulate(t)       2     0.00%
PoseidonExt::accumulate(t)      77     0.11%
PoseidonInt::accumulate(t)     366     0.52%

Commitment contributions:
Total time accounted for (ms):     4974
operation                          ms     % sum
COMMIT::wires(t)                 2234    44.92%
COMMIT::z_perm(t)                2316    46.56%
COMMIT::databus(t)                  9     0.17%
COMMIT::ecc_op_wires(t)            38     0.77%
COMMIT::lookup_inverses(t)        258     5.18%
COMMIT::databus_inverses(t)        12     0.24%
COMMIT::lookup_counts_tags(t)     108     2.16%
```

Compared to master, the notable differences are:
`DeciderProvingKey(Circuit&)` was at 8043ms and now is 2635ms. 
`ProtogalaxyProver::prove` was 20953ms and now is 19967ms. Unclear if
this is expected or not.
`commit` was 7033ms and is now 6570ms.
  • Loading branch information
lucasxia01 authored and Rumata888 committed Sep 27, 2024
1 parent 4f36d35 commit 22544ce
Show file tree
Hide file tree
Showing 21 changed files with 388 additions and 131 deletions.
Empty file modified barretenberg/cpp/scripts/analyze_client_ivc_bench.py
100644 → 100755
Empty file.
4 changes: 4 additions & 0 deletions barretenberg/cpp/src/barretenberg/constants.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,8 @@ static constexpr uint32_t CONST_PROOF_SIZE_LOG_N = 28;
// to ensure a constant PG proof size and a PG recursive verifier circuit that is independent of the size of the
// circuits being folded.
static constexpr uint32_t CONST_PG_LOG_N = 20;

static constexpr uint32_t MAX_LOOKUP_TABLES_SIZE = 70000;

static constexpr uint32_t MAX_DATABUS_SIZE = 10;
} // namespace bb
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,24 @@
#include "barretenberg/stdlib_circuit_builders/ultra_keccak_flavor.hpp"
namespace bb {

template <class Flavor> void ExecutionTrace_<Flavor>::populate_public_inputs_block(Builder& builder)
{
ZoneScopedN("populate_public_inputs_block");
// Update the public inputs block
for (const auto& idx : builder.public_inputs) {
for (size_t wire_idx = 0; wire_idx < NUM_WIRES; ++wire_idx) {
if (wire_idx < 2) { // first two wires get a copy of the public inputs
builder.blocks.pub_inputs.wires[wire_idx].emplace_back(idx);
} else { // the remaining wires get zeros
builder.blocks.pub_inputs.wires[wire_idx].emplace_back(builder.zero_idx);
}
}
for (auto& selector : builder.blocks.pub_inputs.selectors) {
selector.emplace_back(0);
}
}
}

template <class Flavor>
void ExecutionTrace_<Flavor>::populate(Builder& builder, typename Flavor::ProvingKey& proving_key, bool is_structured)
{
Expand Down Expand Up @@ -56,10 +74,13 @@ typename ExecutionTrace_<Flavor>::TraceData ExecutionTrace_<Flavor>::construct_t
Builder& builder, typename Flavor::ProvingKey& proving_key, bool is_structured)
{
ZoneScopedN("construct_trace_data");
TraceData trace_data{ builder, proving_key };

// Complete the public inputs execution trace block from builder.public_inputs
populate_public_inputs_block(builder);
if constexpr (IsPlonkFlavor<Flavor>) {
// Complete the public inputs execution trace block from builder.public_inputs
populate_public_inputs_block(builder);
}

TraceData trace_data{ builder, proving_key };

uint32_t offset = Flavor::has_zero_row ? 1 : 0; // Offset at which to place each block in the trace polynomials
// For each block in the trace, populate wire polys, copy cycles and selector polys
Expand Down Expand Up @@ -87,8 +108,7 @@ typename ExecutionTrace_<Flavor>::TraceData ExecutionTrace_<Flavor>::construct_t
// Insert the selector values for this block into the selector polynomials at the correct offset
// TODO(https://github.com/AztecProtocol/barretenberg/issues/398): implicit arithmetization/flavor consistency
for (size_t selector_idx = 0; selector_idx < NUM_SELECTORS; selector_idx++) {
auto selector_poly = trace_data.selectors[selector_idx];
auto selector = block.selectors[selector_idx];
auto& selector = block.selectors[selector_idx];
for (size_t row_idx = 0; row_idx < block_size; ++row_idx) {
size_t trace_row_idx = row_idx + offset;
trace_data.selectors[selector_idx].set_if_valid_index(trace_row_idx, selector[row_idx]);
Expand All @@ -111,35 +131,19 @@ typename ExecutionTrace_<Flavor>::TraceData ExecutionTrace_<Flavor>::construct_t
return trace_data;
}

template <class Flavor> void ExecutionTrace_<Flavor>::populate_public_inputs_block(Builder& builder)
{
ZoneScopedN("populate_public_inputs_block");
// Update the public inputs block
for (auto& idx : builder.public_inputs) {
for (size_t wire_idx = 0; wire_idx < NUM_WIRES; ++wire_idx) {
if (wire_idx < 2) { // first two wires get a copy of the public inputs
builder.blocks.pub_inputs.wires[wire_idx].emplace_back(idx);
} else { // the remaining wires get zeros
builder.blocks.pub_inputs.wires[wire_idx].emplace_back(builder.zero_idx);
}
}
for (auto& selector : builder.blocks.pub_inputs.selectors) {
selector.emplace_back(0);
}
}
}

template <class Flavor>
void ExecutionTrace_<Flavor>::add_ecc_op_wires_to_proving_key(Builder& builder,
typename Flavor::ProvingKey& proving_key)
requires IsGoblinFlavor<Flavor>
{
// Copy the ecc op data from the conventional wires into the op wires over the range of ecc op gates
auto& ecc_op_selector = proving_key.polynomials.lagrange_ecc_op;
const size_t op_wire_offset = Flavor::has_zero_row ? 1 : 0;

// Copy the ecc op data from the conventional wires into the op wires over the range of ecc op gates
const size_t num_ecc_ops = builder.blocks.ecc_op.size();
for (auto [ecc_op_wire, wire] :
zip_view(proving_key.polynomials.get_ecc_op_wires(), proving_key.polynomials.get_wires())) {
for (size_t i = 0; i < builder.blocks.ecc_op.size(); ++i) {
for (size_t i = 0; i < num_ecc_ops; ++i) {
size_t idx = i + op_wire_offset;
ecc_op_wire.at(idx) = wire[idx];
ecc_op_selector.at(idx) = 1; // construct selector as the indicator on the ecc op block
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ template <class Flavor> class ExecutionTrace_ {
for (auto [selector, other_selector] : zip_view(selectors, proving_key.polynomials.get_selectors())) {
selector = other_selector.share();
}
proving_key.polynomials.set_shifted(); // Ensure shifted wires are set correctly
} else {
// Initialize and share the wire and selector polynomials
for (size_t idx = 0; idx < NUM_WIRES; ++idx) {
Expand Down Expand Up @@ -74,6 +73,14 @@ template <class Flavor> class ExecutionTrace_ {
*/
static void populate(Builder& builder, ProvingKey&, bool is_structured = false);

/**
* @brief Populate the public inputs block
* @details The first two wires are a copy of the public inputs and the other wires and all selectors are zero
*
* @param circuit
*/
static void populate_public_inputs_block(Builder& builder);

private:
/**
* @brief Add the memory records indicating which rows correspond to RAM/ROM reads/writes
Expand Down Expand Up @@ -104,14 +111,6 @@ template <class Flavor> class ExecutionTrace_ {
typename Flavor::ProvingKey& proving_key,
bool is_structured = false);

/**
* @brief Populate the public inputs block
* @details The first two wires are a copy of the public inputs and the other wires and all selectors are zero
*
* @param builder
*/
static void populate_public_inputs_block(Builder& builder);

/**
* @brief Construct and add the goblin ecc op wires to the proving key
* @details The ecc op wires vanish everywhere except on the ecc op block, where they contain a copy of the ecc op
Expand Down
3 changes: 3 additions & 0 deletions barretenberg/cpp/src/barretenberg/flavor/flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ concept IsGoblinFlavor = IsAnyOf<T, MegaFlavor,
MegaRecursiveFlavor_<UltraCircuitBuilder>,
MegaRecursiveFlavor_<MegaCircuitBuilder>, MegaRecursiveFlavor_<CircuitSimulatorBN254>>;

template <typename T>
concept HasDataBus = IsGoblinFlavor<T>;

template <typename T>
concept IsRecursiveFlavor = IsAnyOf<T, UltraRecursiveFlavor_<UltraCircuitBuilder>,
UltraRecursiveFlavor_<MegaCircuitBuilder>,
Expand Down
2 changes: 1 addition & 1 deletion barretenberg/cpp/src/barretenberg/flavor/flavor.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ TEST(Flavor, Getters)
// set
size_t coset_idx = 0;
for (auto& id_poly : proving_key.polynomials.get_ids()) {
typename Flavor::Polynomial new_poly(proving_key.circuit_size);
id_poly = typename Flavor::Polynomial(proving_key.circuit_size);
for (size_t i = 0; i < proving_key.circuit_size; ++i) {
id_poly.at(i) = coset_idx * proving_key.circuit_size + i;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ template <typename FF, size_t NUM_WIRES, size_t NUM_SELECTORS> class ExecutionTr
Selectors selectors;
bool has_ram_rom = false; // does the block contain RAM/ROM gates
bool is_pub_inputs = false; // is this the public inputs block

uint32_t fixed_size = 0; // Fixed size for use in structured trace
uint32_t trace_offset = 0; // where this block starts in the trace

bool operator==(const ExecutionTraceBlock& other) const = default;

Expand Down Expand Up @@ -104,6 +103,8 @@ template <typename FF, size_t NUM_WIRES, size_t NUM_SELECTORS> class ExecutionTr
}
}
#endif
private:
uint32_t fixed_size = 0; // Fixed size for use in structured trace
};

} // namespace bb
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ template <typename FF_> class MegaArith {
aux, lookup, busread, poseidon2_external, poseidon2_internal };
}

auto get_gate_blocks()
{
return RefArray{ arithmetic, delta_range, elliptic, aux,
lookup, busread, poseidon2_external, poseidon2_internal };
}

bool operator==(const MegaTraceBlocks& other) const = default;
};

Expand Down Expand Up @@ -153,7 +159,11 @@ template <typename FF_> class MegaArith {

struct TraceBlocks : public MegaTraceBlocks<MegaTraceBlock> {

E2eStructuredBlockSizes fixed_block_sizes;
TraceBlocks()
{
this->aux.has_ram_rom = true;
this->pub_inputs.is_pub_inputs = true;
}

// Set fixed block sizes for use in structured trace
void set_fixed_block_sizes(TraceStructure setting)
Expand All @@ -178,10 +188,13 @@ template <typename FF_> class MegaArith {
}
}

TraceBlocks()
void compute_offsets(bool is_structured)
{
this->aux.has_ram_rom = true;
this->pub_inputs.is_pub_inputs = true;
uint32_t offset = 1; // start at 1 because the 0th row is unused for selectors for Honk
for (auto& block : this->get()) {
block.trace_offset = offset;
offset += block.get_fixed_size(is_structured);
}
}

void summarize() const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ template <typename FF_> class UltraArith {
aux, lookup, poseidon2_external, poseidon2_internal };
}

auto get_gate_blocks()
{
return RefArray{ arithmetic, delta_range, elliptic, aux, lookup, poseidon2_external, poseidon2_internal };
}

bool operator==(const UltraTraceBlocks& other) const = default;
};

Expand Down Expand Up @@ -90,6 +95,12 @@ template <typename FF_> class UltraArith {

struct TraceBlocks : public UltraTraceBlocks<UltraTraceBlock> {

TraceBlocks()
{
this->aux.has_ram_rom = true;
this->pub_inputs.is_pub_inputs = true;
}

// Set fixed block sizes for use in structured trace
void set_fixed_block_sizes(TraceStructure setting)
{
Expand All @@ -110,10 +121,13 @@ template <typename FF_> class UltraArith {
}
}

TraceBlocks()
void compute_offsets(bool is_structured)
{
this->aux.has_ram_rom = true;
this->pub_inputs.is_pub_inputs = true;
uint32_t offset = 1; // start at 1 because the 0th row is unused for selectors for Honk
for (auto& block : this->get()) {
block.trace_offset = offset;
offset += block.get_fixed_size(is_structured);
}
}

auto get()
Expand All @@ -137,7 +151,7 @@ template <typename FF_> class UltraArith {

size_t get_total_structured_size()
{
size_t total_size = 0;
size_t total_size = 1; // start at 1 because the 0th row is unused for selectors for Honk
for (auto block : this->get()) {
total_size += block.get_fixed_size();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ namespace bb {
template <typename Flavor>
void construct_lookup_table_polynomials(const RefArray<typename Flavor::Polynomial, 4>& table_polynomials,
const typename Flavor::CircuitBuilder& circuit,
size_t dyadic_circuit_size,
size_t additional_offset = 0)
const size_t dyadic_circuit_size,
const size_t additional_offset = 0)
{
// Create lookup selector polynomials which interpolate each table column.
// Our selector polys always need to interpolate the full subgroup size, so here we offset so as to
Expand All @@ -22,8 +22,9 @@ void construct_lookup_table_polynomials(const RefArray<typename Flavor::Polynomi
// | table randomness
// ignored, as used for regular constraints and padding to the next power of 2.
// TODO(https://github.com/AztecProtocol/barretenberg/issues/1033): construct tables and counts at top of trace
ASSERT(dyadic_circuit_size > circuit.get_tables_size() + additional_offset);
size_t offset = dyadic_circuit_size - circuit.get_tables_size() - additional_offset;
const size_t tables_size = circuit.get_tables_size();
ASSERT(dyadic_circuit_size > tables_size + additional_offset);
size_t offset = dyadic_circuit_size - tables_size - additional_offset;

for (const auto& table : circuit.lookup_tables) {
const fr table_index(table.table_index);
Expand All @@ -49,12 +50,12 @@ template <typename Flavor>
void construct_lookup_read_counts(typename Flavor::Polynomial& read_counts,
typename Flavor::Polynomial& read_tags,
typename Flavor::CircuitBuilder& circuit,
size_t dyadic_circuit_size)
const size_t dyadic_circuit_size)
{
const size_t tables_size = circuit.get_tables_size();
// TODO(https://github.com/AztecProtocol/barretenberg/issues/1033): construct tables and counts at top of trace
size_t offset = dyadic_circuit_size - circuit.get_tables_size();
size_t table_offset = dyadic_circuit_size - tables_size;

size_t table_offset = offset; // offset of the present table in the table polynomials
// loop over all tables used in the circuit; each table contains data about the lookups made on it
for (auto& table : circuit.lookup_tables) {
table.initialize_index_map();
Expand Down
Loading

0 comments on commit 22544ce

Please sign in to comment.