Skip to content

Commit

Permalink
refactor: towards plain struct flavor entities (#3216)
Browse files Browse the repository at this point in the history
Step towards removing the array backing in the flavor entity classes.
This systematically removes all usage of operator[] so that the array
backing can be deleted in a followup. This is achieved by having a
pointer_view method that is defined in const and non-const forms
(cleanest way I could find was through a simple macro) and returns an
array of each pointer. Pointers were used over references to not have to
deal with references decaying into values.

Happy to take feedback on the approach, it isn't cleaning up much just
yet, but will make way for a followup PR that removes the backing
arrays.

Also:
- Adds a zip_view class that aims to get us using zip_view before
C++2023 fully lands. I yoinked some open source and slimmed it down a
bit for what we're likely to use.

---------

Co-authored-by: ludamad <[email protected]>
  • Loading branch information
ludamad and ludamad0 authored Nov 10, 2023
1 parent 11e6ca7 commit 3ba89cf
Show file tree
Hide file tree
Showing 18 changed files with 737 additions and 105 deletions.
6 changes: 5 additions & 1 deletion barretenberg/cpp/CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@
"displayName": "Debugging build with Clang-16",
"description": "Build with globally installed Clang-16 in debug mode",
"inherits": "clang16",
"binaryDir": "build-debug",
"environment": {
"CMAKE_BUILD_TYPE": "Debug"
"CMAKE_BUILD_TYPE": "Debug",
"CFLAGS": "-gdwarf-4",
"CXXFLAGS": "-gdwarf-4",
"LDFLAGS": "-gdwarf-4"
},
"cacheVariables": {
"ENABLE_ASAN": "OFF",
Expand Down
190 changes: 190 additions & 0 deletions barretenberg/cpp/src/barretenberg/common/zip_view.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
#pragma once
/* ********************************* FILE ************************************/
/** \file mzip.hpp
*
* \brief This header contains the zip iterator class.
*
* WARNING this is a zip view, not a zip copy!
*
* \remark
* - c++17
* - no dependencies
* - header only
* - tested by test_zip_iterator.cpp
* - not thread safe
* - view !
* - extends lifetime of rvalue inputs untill the end of the for loop
*
* \todo
* - add algorithm tests, probably does not work at all...
*
*
* \example
* std::vector<int> as{1,2},bs{1,2,3};
* for(auto [index, a,b]: zip(as,bs)){
* a++;
* }
* cout<<as<<endl; // shows (2, 3)
* works for any number
*
* zip returns tuples of references to the contents
*
*
*
*
*
*
*
*
*
* does not copy the containers
* returns tuple of references to the containers content
* iterates untill the first iterator hits end.
* extends ownership to the end of the for loop, or untill zip goes out of scope.
*
* possibly risky behaviour on clang, gcc for fun(const zip& z) when called as fun(zip(a,b))
*
*
* Depends on the following behaviour for for loops:
*
* // in for(x:zip)
* // equiv:
* { // c++ 11+
* auto && __range = range_expression ;
* for (auto __begin = begin_expr, __end = end_expr; __begin != __end; ++__begin) {
* range_declaration = *__begin;
* loop_statement
* }
* }
*
* { // in c++ 17
* auto && __range = range_expression ;
* auto __begin = begin_expr ;
* auto __end = end_expr ;
* for ( ; __begin != __end; ++__begin) {
* range_declaration = *__begin;
* loop_statement
* }
* }
*
*
* \author Mikael Persson
* \date 2019-09-01
******************************************************************************/

static_assert(__cplusplus >= 201703L,
" must be c++17 or greater"); // could be rewritten in c++11, but the features you must use will be buggy
// in an older compiler anyways.
#include <cassert>
#include <functional>
#include <iostream>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <vector>

template <class T>
/**
* @brief The zip_iterator class
*
* Provides a zip iterator which is at end when any is at end
*/
class zip_iterator {
public:
// speeds up compilation a little bit...
using tuple_indexes = std::make_index_sequence<std::tuple_size_v<std::remove_reference_t<T>>>;

zip_iterator(T iter, T iter_end)
: iter(iter)
, iter_end(iter_end)
{}
// prefix, inc first, then return
zip_iterator& operator++()
{
for_each_in_tuple([](auto&& x) { return x++; }, iter);
// then if any hit end, update all to point to end.
auto end = apply2([](auto x, auto y) { return x == y; }, iter, iter_end);
if (if_any_in(end)) {
apply2([](auto& x, auto y) { return x = y; }, iter, iter_end);
}
index++;
return *this;
}
// sufficient because ++ keeps track and sets all to end when any is
bool operator!=(const zip_iterator& other) const { return other.iter != iter; }
auto operator*() const
{
return std::forward<decltype(get_refs(iter, tuple_indexes{}))>(get_refs(iter, tuple_indexes{}));
}

private:
T iter, iter_end;
std::size_t index = 0;

template <std::size_t... I> auto get_refs(T t, std::index_sequence<I...>) const
{
return std::make_tuple(std::ref(*std::get<I>(t))...);
}

template <class F, class A, std::size_t... I> auto apply2_impl(F&& f, A&& a, A&& b, std::index_sequence<I...>)
{
return std::make_tuple(f(std::get<I>(a), std::get<I>(b))...);
}
template <class F, class A> auto apply2(F&& f, A&& a, A&& b)
{
return apply2_impl(std::forward<F>(f), std::forward<A>(a), std::forward<A>(b), tuple_indexes{});
}
template <class A, std::size_t... I> bool if_any_impl(const A& t, std::index_sequence<I...>) const
{
return (... || std::get<I>(t)); // c++17
}

// in general context we must enforce that these are tuples
template <class A> bool if_any_in(A&& t) const { return if_any_impl(std::forward<A>(t), tuple_indexes{}); }

template <class F, class Tuple, std::size_t... I>
auto for_each_in_impl(F&& f, Tuple&& t, std::index_sequence<I...>) const
{
return std::make_tuple(f(std::get<I>(t))...);
}

template <class F, class A> void for_each_in_tuple(F&& f, A&& t) const
{
for_each_in_impl(std::forward<F>(f), std::forward<A>(t), tuple_indexes{});
}
};

template <class... S> class zip_view {
using arg_indexes = std::make_index_sequence<sizeof...(S)>;

public:
zip_view(S... args)
: args(std::forward<S>(args)...)
{}
auto begin() const { return get_begins(arg_indexes{}); }
auto end() const { return get_ends(arg_indexes{}); }
[[nodiscard]] std::size_t size() const { return size_impl(arg_indexes{}); }

private:
std::tuple<S...> args;
template <std::size_t... I> auto get_begins(std::index_sequence<I...>) const
{
return zip_iterator(std::make_tuple(std::get<I>(args).begin()...), std::make_tuple(std::get<I>(args).end()...));
}
template <std::size_t... I> auto get_ends(std::index_sequence<I...>) const
{
return zip_iterator(std::make_tuple(std::get<I>(args).end()...), std::make_tuple(std::get<I>(args).end()...));
}
template <std::size_t... I> auto size_impl(std::index_sequence<I...>) const
{
return std::max({ std::size_t(std::get<I>(args).size())... });
}

template <class A, std::size_t... I> bool if_any_impl(const A& t, std::index_sequence<I...>) const
{
return (... || std::get<I>(t)); // c++17
}
};

// deduction guide,
template <class... S> zip_view(S&&...) -> zip_view<S...>;
120 changes: 112 additions & 8 deletions barretenberg/cpp/src/barretenberg/flavor/ecc_vm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,113 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
}
// clang-format on

// defines a method pointer_view that returns the following, with const and non-const variants
DEFINE_POINTER_VIEW(NUM_ALL_ENTITIES,
&lagrange_first,
&lagrange_second,
&lagrange_last,
&transcript_add,
&transcript_mul,
&transcript_eq,
&transcript_collision_check,
&transcript_msm_transition,
&transcript_pc,
&transcript_msm_count,
&transcript_x,
&transcript_y,
&transcript_z1,
&transcript_z2,
&transcript_z1zero,
&transcript_z2zero,
&transcript_op,
&transcript_accumulator_x,
&transcript_accumulator_y,
&transcript_msm_x,
&transcript_msm_y,
&precompute_pc,
&precompute_point_transition,
&precompute_round,
&precompute_scalar_sum,
&precompute_s1hi,
&precompute_s1lo,
&precompute_s2hi,
&precompute_s2lo,
&precompute_s3hi,
&precompute_s3lo,
&precompute_s4hi,
&precompute_s4lo,
&precompute_skew,
&precompute_dx,
&precompute_dy,
&precompute_tx,
&precompute_ty,
&msm_transition,
&msm_add,
&msm_double,
&msm_skew,
&msm_accumulator_x,
&msm_accumulator_y,
&msm_pc,
&msm_size_of_msm,
&msm_count,
&msm_round,
&msm_add1,
&msm_add2,
&msm_add3,
&msm_add4,
&msm_x1,
&msm_y1,
&msm_x2,
&msm_y2,
&msm_x3,
&msm_y3,
&msm_x4,
&msm_y4,
&msm_collision_x1,
&msm_collision_x2,
&msm_collision_x3,
&msm_collision_x4,
&msm_lambda1,
&msm_lambda2,
&msm_lambda3,
&msm_lambda4,
&msm_slice1,
&msm_slice2,
&msm_slice3,
&msm_slice4,
&transcript_accumulator_empty,
&transcript_reset_accumulator,
&precompute_select,
&lookup_read_counts_0,
&lookup_read_counts_1,
&z_perm,
&lookup_inverses,
&transcript_mul_shift,
&transcript_msm_count_shift,
&transcript_accumulator_x_shift,
&transcript_accumulator_y_shift,
&precompute_scalar_sum_shift,
&precompute_s1hi_shift,
&precompute_dx_shift,
&precompute_dy_shift,
&precompute_tx_shift,
&precompute_ty_shift,
&msm_transition_shift,
&msm_add_shift,
&msm_double_shift,
&msm_skew_shift,
&msm_accumulator_x_shift,
&msm_accumulator_y_shift,
&msm_count_shift,
&msm_round_shift,
&msm_add1_shift,
&msm_pc_shift,
&precompute_pc_shift,
&transcript_pc_shift,
&precompute_round_shift,
&transcript_accumulator_empty_shift,
&precompute_select_shift,
&z_perm_shift)
std::vector<HandleType> get_wires() override
{
return {
Expand Down Expand Up @@ -680,13 +787,12 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
*/
class AllPolynomials : public AllEntities<Polynomial, PolynomialHandle> {
public:
[[nodiscard]] size_t get_polynomial_size() const { return this->lagrange_first.size(); }
AllValues get_row(const size_t row_idx) const
{
AllValues result;
size_t column_idx = 0; // // TODO(https://github.com/AztecProtocol/barretenberg/issues/391) zip
for (auto& column : this->_data) {
result[column_idx] = column[row_idx];
column_idx++;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
}
return result;
}
Expand Down Expand Up @@ -736,10 +842,8 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
AllValues get_row(const size_t row_idx)
{
AllValues result;
size_t column_idx = 0; // TODO(https://github.com/AztecProtocol/barretenberg/issues/391) zip
for (auto& column : this->_data) {
result[column_idx] = column[row_idx];
column_idx++;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
}
return result;
}
Expand Down
21 changes: 19 additions & 2 deletions barretenberg/cpp/src/barretenberg/flavor/flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
*/

#pragma once
#include "barretenberg/common/zip_view.hpp"
#include "barretenberg/polynomials/barycentric.hpp"
#include "barretenberg/polynomials/evaluation_domain.hpp"
#include "barretenberg/polynomials/univariate.hpp"
Expand All @@ -74,6 +75,23 @@

namespace proof_system::honk::flavor {

template <std::size_t ExpectedSize, typename T, std::size_t N> static auto _assert_array_size(std::array<T, N>&& array)
{
static_assert(N == ExpectedSize,
"Expected array size to match given size (first parameter) in DEFINE_POINTER_VIEW");
return array;
}

#define DEFINE_POINTER_VIEW(ExpectedSize, ...) \
[[nodiscard]] auto pointer_view() \
{ \
return _assert_array_size<ExpectedSize>(std::array{ __VA_ARGS__ }); \
} \
[[nodiscard]] auto pointer_view() const \
{ \
return _assert_array_size<ExpectedSize>(std::array{ __VA_ARGS__ }); \
}

/**
* @brief Base data class template, a wrapper for std::array, from which every flavor class ultimately derives.
*
Expand All @@ -87,8 +105,7 @@ template <typename DataType, typename HandleType, size_t NUM_ENTITIES> class Ent
ArrayType _data;

virtual ~Entities_() = default;

DataType& operator[](size_t idx) { return _data[idx]; };
// TODO(AD): remove these with the backing array
typename ArrayType::iterator begin() { return _data.begin(); };
typename ArrayType::iterator end() { return _data.end(); };

Expand Down
Loading

0 comments on commit 3ba89cf

Please sign in to comment.