Skip to content

Commit

Permalink
refactor: Make MSM builder more explicit (#6110)
Browse files Browse the repository at this point in the history
After trying to understand the MSM builder part of the ECCVM builder, I
did a refactor for clarity. This is almost entirely naming (e.g we had
sometimes 4+ indices `i, j, k, m, idx` in deeply nested loops that I
gave more explicit names) and comments. I also made the function that
computes the trace rows return a table rather than to mutate one since
there was no real reason to take the latter pattern.
  • Loading branch information
codygunton authored May 8, 2024
1 parent cd05b91 commit 40306b6
Show file tree
Hide file tree
Showing 7 changed files with 392 additions and 388 deletions.
15 changes: 7 additions & 8 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_builder_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp"

namespace bb::eccvm {

static constexpr size_t NUM_SCALAR_BITS = 128;
static constexpr size_t WNAF_SLICE_BITS = 4;
static constexpr size_t NUM_WNAF_SLICES = (NUM_SCALAR_BITS + WNAF_SLICE_BITS - 1) / WNAF_SLICE_BITS;
static constexpr uint64_t WNAF_MASK = static_cast<uint64_t>((1ULL << WNAF_SLICE_BITS) - 1ULL);
static constexpr size_t POINT_TABLE_SIZE = 1ULL << (WNAF_SLICE_BITS);
static constexpr size_t WNAF_SLICES_PER_ROW = 4;
static constexpr size_t NUM_SCALAR_BITS = 128; // The length of scalars handled by the ECCVVM
static constexpr size_t NUM_WNAF_DIGIT_BITS = 4; // Scalars are decompose into base 16 in wNAF form
static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = NUM_SCALAR_BITS / NUM_WNAF_DIGIT_BITS; // 32
static constexpr uint64_t WNAF_MASK = static_cast<uint64_t>((1ULL << NUM_WNAF_DIGIT_BITS) - 1ULL);
static constexpr size_t POINT_TABLE_SIZE = 1ULL << (NUM_WNAF_DIGIT_BITS);
static constexpr size_t WNAF_DIGITS_PER_ROW = 4;
static constexpr size_t ADDITIONS_PER_ROW = 4;

template <typename CycleGroup> struct VMOperation {
Expand Down Expand Up @@ -39,7 +38,7 @@ template <typename CycleGroup> struct ScalarMul {
uint32_t pc;
uint256_t scalar;
typename CycleGroup::affine_element base_point;
std::array<int, NUM_WNAF_SLICES> wnaf_slices;
std::array<int, NUM_WNAF_DIGITS_PER_SCALAR> wnaf_digits;
bool wnaf_skew;
// size bumped by 1 to record base_point.dbl()
std::array<typename CycleGroup::affine_element, POINT_TABLE_SIZE + 1> precomputed_table;
Expand Down
51 changes: 25 additions & 26 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ class ECCVMCircuitBuilder {
using AffineElement = typename CycleGroup::affine_element;

static constexpr size_t NUM_SCALAR_BITS = bb::eccvm::NUM_SCALAR_BITS;
static constexpr size_t WNAF_SLICE_BITS = bb::eccvm::WNAF_SLICE_BITS;
static constexpr size_t NUM_WNAF_SLICES = bb::eccvm::NUM_WNAF_SLICES;
static constexpr size_t NUM_WNAF_DIGIT_BITS = bb::eccvm::NUM_WNAF_DIGIT_BITS;
static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = bb::eccvm::NUM_WNAF_DIGITS_PER_SCALAR;
static constexpr uint64_t WNAF_MASK = bb::eccvm::WNAF_MASK;
static constexpr size_t POINT_TABLE_SIZE = bb::eccvm::POINT_TABLE_SIZE;
static constexpr size_t WNAF_SLICES_PER_ROW = bb::eccvm::WNAF_SLICES_PER_ROW;
static constexpr size_t WNAF_DIGITS_PER_ROW = bb::eccvm::WNAF_DIGITS_PER_ROW;
static constexpr size_t ADDITIONS_PER_ROW = bb::eccvm::ADDITIONS_PER_ROW;

using MSM = bb::eccvm::MSM<CycleGroup>;
Expand All @@ -50,7 +50,8 @@ class ECCVMCircuitBuilder {
/**
* For input point [P], return { -15[P], -13[P], ..., -[P], [P], ..., 13[P], 15[P] }
*/
const auto compute_precomputed_table = [](const AffineElement& base_point) {
const auto compute_precomputed_table =
[](const AffineElement& base_point) -> std::array<AffineElement, POINT_TABLE_SIZE + 1> {
const auto d2 = Element(base_point).dbl();
std::array<Element, POINT_TABLE_SIZE + 1> table;
table[POINT_TABLE_SIZE] = d2; // need this for later
Expand All @@ -69,10 +70,10 @@ class ECCVMCircuitBuilder {
}
return result;
};
const auto compute_wnaf_slices = [](uint256_t scalar) {
std::array<int, NUM_WNAF_SLICES> output;
const auto compute_wnaf_digits = [](uint256_t scalar) -> std::array<int, NUM_WNAF_DIGITS_PER_SCALAR> {
std::array<int, NUM_WNAF_DIGITS_PER_SCALAR> output;
int previous_slice = 0;
for (size_t i = 0; i < NUM_WNAF_SLICES; ++i) {
for (size_t i = 0; i < NUM_WNAF_DIGITS_PER_SCALAR; ++i) {
// slice the scalar into 4-bit chunks, starting with the least significant bits
uint64_t raw_slice = static_cast<uint64_t>(scalar) & WNAF_MASK;

Expand All @@ -86,19 +87,19 @@ class ECCVMCircuitBuilder {
} else if (is_even) {
// for other slices, if it's even, we add 1 to the slice value
// and subtract 16 from the previous slice to preserve the total scalar sum
static constexpr int borrow_constant = static_cast<int>(1ULL << WNAF_SLICE_BITS);
static constexpr int borrow_constant = static_cast<int>(1ULL << NUM_WNAF_DIGIT_BITS);
previous_slice -= borrow_constant;
wnaf_slice += 1;
}

if (i > 0) {
const size_t idx = i - 1;
output[NUM_WNAF_SLICES - idx - 1] = previous_slice;
output[NUM_WNAF_DIGITS_PER_SCALAR - idx - 1] = previous_slice;
}
previous_slice = wnaf_slice;

// downshift raw_slice by 4 bits
scalar = scalar >> WNAF_SLICE_BITS;
scalar = scalar >> NUM_WNAF_DIGIT_BITS;
}

ASSERT(scalar == 0);
Expand All @@ -108,8 +109,6 @@ class ECCVMCircuitBuilder {
return output;
};

// a vector of MSMs = a vector of a vector of scalar muls
// each mul
size_t msm_count = 0;
size_t active_mul_count = 0;
std::vector<size_t> msm_opqueue_index;
Expand All @@ -118,6 +117,7 @@ class ECCVMCircuitBuilder {

const auto& raw_ops = op_queue->get_raw_ops();
size_t op_idx = 0;
// populate opqueue and mul indices
for (const auto& op : raw_ops) {
if (op.mul) {
if (op.z1 != 0 || op.z2 != 0) {
Expand All @@ -142,39 +142,38 @@ class ECCVMCircuitBuilder {
msm_sizes.push_back(active_mul_count);
msm_count++;
}
std::vector<MSM> msms_test(msm_count);
std::vector<MSM> result(msm_count);
for (size_t i = 0; i < msm_count; ++i) {
auto& msm = msms_test[i];
auto& msm = result[i];
msm.resize(msm_sizes[i]);
}

run_loop_in_parallel(msm_opqueue_index.size(), [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
const size_t opqueue_index = msm_opqueue_index[i];
const auto& op = raw_ops[opqueue_index];
const auto& op = raw_ops[msm_opqueue_index[i]];
auto [msm_index, mul_index] = msm_mul_index[i];
if (op.z1 != 0) {
ASSERT(msms_test.size() > msm_index);
ASSERT(msms_test[msm_index].size() > mul_index);
msms_test[msm_index][mul_index] = (ScalarMul{
ASSERT(result.size() > msm_index);
ASSERT(result[msm_index].size() > mul_index);
result[msm_index][mul_index] = (ScalarMul{
.pc = 0,
.scalar = op.z1,
.base_point = op.base_point,
.wnaf_slices = compute_wnaf_slices(op.z1),
.wnaf_digits = compute_wnaf_digits(op.z1),
.wnaf_skew = (op.z1 & 1) == 0,
.precomputed_table = compute_precomputed_table(op.base_point),
});
mul_index++;
}
if (op.z2 != 0) {
ASSERT(msms_test.size() > msm_index);
ASSERT(msms_test[msm_index].size() > mul_index);
ASSERT(result.size() > msm_index);
ASSERT(result[msm_index].size() > mul_index);
auto endo_point = AffineElement{ op.base_point.x * FF::cube_root_of_unity(), -op.base_point.y };
msms_test[msm_index][mul_index] = (ScalarMul{
result[msm_index][mul_index] = (ScalarMul{
.pc = 0,
.scalar = op.z2,
.base_point = endo_point,
.wnaf_slices = compute_wnaf_slices(op.z2),
.wnaf_digits = compute_wnaf_digits(op.z2),
.wnaf_skew = (op.z2 & 1) == 0,
.precomputed_table = compute_precomputed_table(endo_point),
});
Expand All @@ -191,15 +190,15 @@ class ECCVMCircuitBuilder {
// sumcheck relations that involve pc (if we did the other way around, starting at 1 and ending at num_muls,
// we create a discontinuity in pc values between the last transcript row and the following empty row)
uint32_t pc = num_muls;
for (auto& msm : msms_test) {
for (auto& msm : result) {
for (auto& mul : msm) {
mul.pc = pc;
pc--;
}
}

ASSERT(pc == 0);
return msms_test;
return result;
}

static std::vector<ScalarMul> get_flattened_scalar_muls(const std::vector<MSM>& msms)
Expand Down
Loading

0 comments on commit 40306b6

Please sign in to comment.