Skip to content

Commit

Permalink
Merge branch 'master' into cg/guh-bench
Browse files Browse the repository at this point in the history
  • Loading branch information
codygunton authored Feb 19, 2024
2 parents cb90e33 + ba6048d commit 70e1a45
Show file tree
Hide file tree
Showing 11 changed files with 108 additions and 88 deletions.
41 changes: 19 additions & 22 deletions barretenberg/cpp/src/barretenberg/commitment_schemes/ipa/ipa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* @brief IPA (inner-product argument) commitment scheme class. Conforms to the specification
* https://hackmd.io/q-A8y6aITWyWJrvsGGMWNA?view.
*
*
*/
namespace bb {
template <typename Curve> class IPA {
Expand Down Expand Up @@ -141,32 +142,30 @@ template <typename Curve> class IPA {
const Fr round_challenge = transcript->get_challenge<Fr>("IPA:round_challenge_" + index);
const Fr round_challenge_inv = round_challenge.invert();

auto G_lo = GroupElement::batch_mul_with_endomorphism(
std::span{ G_vec_local.begin(), G_vec_local.begin() + static_cast<long>(round_size) },
round_challenge_inv);
auto G_hi = GroupElement::batch_mul_with_endomorphism(
std::span{ G_vec_local.begin() + static_cast<long>(round_size),
G_vec_local.begin() + static_cast<long>(round_size * 2) },
round_challenge);
round_challenge_inv);

// Update the vectors a_vec, b_vec and G_vec.
// a_vec_next = a_vec_lo * round_challenge + a_vec_hi * round_challenge_inv
// b_vec_next = b_vec_lo * round_challenge_inv + b_vec_hi * round_challenge
// G_vec_next = G_vec_lo * round_challenge_inv + G_vec_hi * round_challenge
// a_vec_next = a_vec_lo + a_vec_hi * round_challenge
// b_vec_next = b_vec_lo + b_vec_hi * round_challenge_inv
// G_vec_next = G_vec_lo + G_vec_hi * round_challenge_inv
run_loop_in_parallel_if_effective(
round_size,
[&a_vec, &b_vec, round_challenge, round_challenge_inv, round_size](size_t start, size_t end) {
for (size_t j = start; j < end; j++) {
a_vec[j] *= round_challenge;
a_vec[j] += round_challenge_inv * a_vec[round_size + j];
b_vec[j] *= round_challenge_inv;
b_vec[j] += round_challenge * b_vec[round_size + j];
a_vec[j] += round_challenge * a_vec[round_size + j];
b_vec[j] += round_challenge_inv * b_vec[round_size + j];
}
},
/*finite_field_additions_per_iteration=*/4,
/*finite_field_multiplications_per_iteration=*/8,
/*finite_field_inversions_per_iteration=*/1);
GroupElement::batch_affine_add(G_lo, G_hi, G_vec_local);
GroupElement::batch_affine_add(
std::span{ G_vec_local.begin(), G_vec_local.begin() + static_cast<long>(round_size) },
G_hi,
G_vec_local);
}

transcript->send_to_verifier("IPA:a_0", a_vec[0]);
Expand Down Expand Up @@ -196,7 +195,7 @@ template <typename Curve> class IPA {
// Compute C_prime
GroupElement C_prime = opening_claim.commitment + (aux_generator * opening_claim.opening_pair.evaluation);

// Compute C_zero = C_prime + ∑_{j ∈ [k]} u_j^2L_j + ∑_{j ∈ [k]} u_j^{-2}R_j
// Compute C_zero = C_prime + ∑_{j ∈ [k]} u_j^{-1}L_j + ∑_{j ∈ [k]} u_jR_j
auto pippenger_size = 2 * log_poly_degree;
std::vector<Fr> round_challenges(log_poly_degree);
std::vector<Fr> round_challenges_inv(log_poly_degree);
Expand All @@ -211,8 +210,8 @@ template <typename Curve> class IPA {

msm_elements[2 * i] = element_L;
msm_elements[2 * i + 1] = element_R;
msm_scalars[2 * i] = round_challenges[i].sqr();
msm_scalars[2 * i + 1] = round_challenges_inv[i].sqr();
msm_scalars[2 * i] = round_challenges_inv[i];
msm_scalars[2 * i + 1] = round_challenges[i];
}

GroupElement LR_sums = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
Expand All @@ -222,31 +221,29 @@ template <typename Curve> class IPA {
/**
* Compute b_zero where b_zero can be computed using the polynomial:
*
* g(X) = ∏_{i ∈ [k]} (u_{k-i}^{-1} + u_{k-i}.X^{2^{i-1}}).
* g(X) = ∏_{i ∈ [k]} (1 + u_{k-i}^{-1}.X^{2^{i-1}}).
*
* b_zero = g(evaluation) = ∏_{i ∈ [k]} (u_{k-i}^{-1} + u_{k-i}. (evaluation)^{2^{i-1}})
* b_zero = g(evaluation) = ∏_{i ∈ [k]} (1 + u_{k-i}^{-1}. (evaluation)^{2^{i-1}})
*/
Fr b_zero = Fr::one();
for (size_t i = 0; i < log_poly_degree; i++) {
auto exponent = static_cast<uint64_t>(Fr(2).pow(i));
b_zero *= round_challenges_inv[log_poly_degree - 1 - i] +
(round_challenges[log_poly_degree - 1 - i] * opening_claim.opening_pair.challenge.pow(exponent));
b_zero *= Fr::one() + (round_challenges_inv[log_poly_degree - 1 - i] *
opening_claim.opening_pair.challenge.pow(exponent));
}

// Compute G_zero
// First construct s_vec
std::vector<Fr> s_vec(poly_degree);
run_loop_in_parallel_if_effective(
poly_degree,
[&s_vec, &round_challenges, &round_challenges_inv, log_poly_degree](size_t start, size_t end) {
[&s_vec, &round_challenges_inv, log_poly_degree](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
Fr s_vec_scalar = Fr::one();
for (size_t j = (log_poly_degree - 1); j != size_t(-1); j--) {
auto bit = (i >> j) & 1;
bool b = static_cast<bool>(bit);
if (b) {
s_vec_scalar *= round_challenges[log_poly_degree - 1 - j];
} else {
s_vec_scalar *= round_challenges_inv[log_poly_degree - 1 - j];
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

namespace bb::crypto::merkle_tree {

typedef uint256_t index_t;
using index_t = uint256_t;

/**
* @brief Used in parallel insertions in the the IndexedTree. Workers signal to other following workes as they move up
Expand All @@ -18,11 +18,13 @@ class LevelSignal {
public:
LevelSignal(size_t initial_level)
: signal_(initial_level){};
~LevelSignal(){};
~LevelSignal() = default;
LevelSignal(const LevelSignal& other)
: signal_(other.signal_.load())
{}
LevelSignal(const LevelSignal&& other) = delete;
LevelSignal(const LevelSignal&& other) noexcept
: signal_(other.signal_.load())
{}

/**
* @brief Causes the thread to wait until the required level has been signalled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,12 +633,12 @@ plookup::BasicTable& UltraCircuitBuilder_<Arithmetization>::get_table(const ploo

template <typename Arithmetization>
plookup::ReadData<uint32_t> UltraCircuitBuilder_<Arithmetization>::create_gates_from_plookup_accumulators(
const plookup::MultiTableIdOrPtr& id,
const plookup::MultiTableId& id,
const plookup::ReadData<FF>& read_values,
const uint32_t key_a_index,
std::optional<uint32_t> key_b_index)
{
const auto& multi_table = plookup::get_table(id);
const auto& multi_table = plookup::create_table(id);
const size_t num_lookups = read_values[plookup::ColumnIdx::C1].size();
plookup::ReadData<uint32_t> read_data;
for (size_t i = 0; i < num_lookups; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ class UltraCircuitBuilder_ : public CircuitBuilderBase<typename Arithmetization:
plookup::MultiTable& create_table(const plookup::MultiTableId id);

plookup::ReadData<uint32_t> create_gates_from_plookup_accumulators(
const plookup::MultiTableIdOrPtr& id,
const plookup::MultiTableId& id,
const plookup::ReadData<FF>& read_values,
const uint32_t key_a_index,
std::optional<uint32_t> key_b_index = std::nullopt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,13 @@ const MultiTable& create_table(const MultiTableId id)
return MULTI_TABLES[id];
}

const MultiTable& get_table(const MultiTableIdOrPtr& id)
{
if (id.ptr == nullptr) {
return create_table(id.id);
}
return *id.ptr;
}

ReadData<bb::fr> get_lookup_accumulators(const MultiTableIdOrPtr& id,
ReadData<bb::fr> get_lookup_accumulators(const MultiTableId id,
const fr& key_a,
const fr& key_b,
const bool is_2_to_1_lookup)
{
// return multi-table, populating global array of all multi-tables if need be
const auto& multi_table = get_table(id);
const auto& multi_table = create_table(id);
const size_t num_lookups = multi_table.lookup_ids.size();

ReadData<bb::fr> lookup;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
namespace bb::plookup {

const MultiTable& create_table(MultiTableId id);
const MultiTable& get_table(const MultiTableIdOrPtr& id);

ReadData<bb::fr> get_lookup_accumulators(const MultiTableIdOrPtr& id,
ReadData<bb::fr> get_lookup_accumulators(MultiTableId id,
const bb::fr& key_a,
const bb::fr& key_b = 0,
bool is_2_to_1_lookup = false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,17 @@ struct MultiTable {
std::vector<bb::fr> column_1_step_sizes;
std::vector<bb::fr> column_2_step_sizes;
std::vector<bb::fr> column_3_step_sizes;
using table_out = std::array<bb::fr, 2>;
using table_in = std::array<uint64_t, 2>;
typedef std::array<bb::fr, 2> table_out;
typedef std::array<uint64_t, 2> table_in;
std::vector<table_out (*)(table_in)> get_table_values;

private:
void init_step_sizes()
{
const size_t num_lookups = column_1_coefficients.size();
column_1_step_sizes.emplace_back(1);
column_2_step_sizes.emplace_back(1);
column_3_step_sizes.emplace_back(1);
column_1_step_sizes.emplace_back(bb::fr(1));
column_2_step_sizes.emplace_back(bb::fr(1));
column_3_step_sizes.emplace_back(bb::fr(1));

std::vector<bb::fr> coefficient_inverses(column_1_coefficients.begin(), column_1_coefficients.end());
std::copy(column_2_coefficients.begin(), column_2_coefficients.end(), std::back_inserter(coefficient_inverses));
Expand Down Expand Up @@ -184,30 +184,74 @@ struct MultiTable {
init_step_sizes();
}

MultiTable() = default;
MultiTable(){};
MultiTable(const MultiTable& other) = default;
MultiTable(MultiTable&& other) = default;

MultiTable& operator=(const MultiTable& other) = default;
MultiTable& operator=(MultiTable&& other) = default;
};

// Represents either a predefined table from our enum list of supported lookup tables, or a dynamic lookup table defined
// by ACIR
struct MultiTableIdOrPtr {
// Used if we are using a lookup table from our predefined list, otherwise set to NUM_MULTI_TABLES and unused.
MultiTableId id;
// Used if we are using a lookup table from a lookup table defined by e.g. ACIR, otherwise set to nullptr.
MultiTable* ptr;
MultiTableIdOrPtr(MultiTable* ptr)
: id(NUM_MULTI_TABLES)
, ptr(ptr)
{}
MultiTableIdOrPtr(MultiTableId id)
: id(id)
, ptr(nullptr)
{}
};
// struct PlookupLargeKeyTable {
// struct KeyEntry {
// uint256_t key;
// std::array<bb::fr, 2> value{ bb::fr(0), bb::fr(0) };
// bool operator<(const KeyEntry& other) const { return key < other.key; }

// std::array<bb::fr, 3> to_sorted_list_components(const bool use_two_keys) const
// {
// return {
// key[0],
// value[0],
// value[1],
// };
// }
// };

// BasicTableId id;
// size_t table_index;
// size_t size;
// bool use_twin_keys;

// bb::fr column_1_step_size = bb::fr(0);
// bb::fr column_2_step_size = bb::fr(0);
// bb::fr column_3_step_size = bb::fr(0);
// std::vector<bb::fr> column_1;
// std::vector<bb::fr> column_3;
// std::vector<bb::fr> column_2;
// std::vector<KeyEntry> lookup_gates;

// std::array<bb::fr, 2> (*get_values_from_key)(const std::array<uint64_t, 2>);
// };

// struct PlookupFatKeyTable {
// struct KeyEntry {
// bb::fr key;
// std::array<bb::fr, 2> values{ 0, 0 };
// bool operator<(const KeyEntry& other) const
// {
// return (key.from_montgomery_form() < other.key.from_montgomery_form());
// }

// std::array<bb::fr, 3> to_sorted_list_components() const { return { key, values[0], values[0] }; }
// }

// BasicTableId id;
// size_t table_index;
// size_t size;
// bool use_twin_keys;

// bb::fr column_1_step_size = bb::fr(0);
// bb::fr column_2_step_size = bb::fr(0);
// bb::fr column_3_step_size = bb::fr(0);
// std::vector<bb::fr> column_1;
// std::vector<bb::fr> column_3;
// std::vector<bb::fr> column_2;
// std::vector<KeyEntry> lookup_gates;

// std::array<bb::fr, 2> (*get_values_from_key)(const std::array<uint64_t, 2>);

// }

/**
* @brief The structure contains the most basic table serving one function (for, example an xor table)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using plookup::MultiTableId;
using namespace bb;

template <typename Builder>
plookup::ReadData<field_t<Builder>> plookup_read<Builder>::get_lookup_accumulators(const plookup::MultiTableIdOrPtr& id,
plookup::ReadData<field_t<Builder>> plookup_read<Builder>::get_lookup_accumulators(const MultiTableId id,
const field_t<Builder>& key_a_in,
const field_t<Builder>& key_b_in,
const bool is_2_to_1_lookup)
Expand Down Expand Up @@ -64,16 +64,16 @@ plookup::ReadData<field_t<Builder>> plookup_read<Builder>::get_lookup_accumulato
}

template <typename Builder>
std::pair<field_t<Builder>, field_t<Builder>> plookup_read<Builder>::read_pair_from_table(
const plookup::MultiTableIdOrPtr& id, const field_t<Builder>& key)
std::pair<field_t<Builder>, field_t<Builder>> plookup_read<Builder>::read_pair_from_table(const MultiTableId id,
const field_t<Builder>& key)
{
const auto lookup = get_lookup_accumulators(id, key);

return { lookup[ColumnIdx::C2][0], lookup[ColumnIdx::C3][0] };
}

template <typename Builder>
field_t<Builder> plookup_read<Builder>::read_from_2_to_1_table(const plookup::MultiTableIdOrPtr& id,
field_t<Builder> plookup_read<Builder>::read_from_2_to_1_table(const MultiTableId id,
const field_t<Builder>& key_a,
const field_t<Builder>& key_b)
{
Expand All @@ -83,8 +83,7 @@ field_t<Builder> plookup_read<Builder>::read_from_2_to_1_table(const plookup::Mu
}

template <typename Builder>
field_t<Builder> plookup_read<Builder>::read_from_1_to_2_table(const plookup::MultiTableIdOrPtr& id,
const field_t<Builder>& key_a)
field_t<Builder> plookup_read<Builder>::read_from_1_to_2_table(const MultiTableId id, const field_t<Builder>& key_a)
{
const auto lookup = get_lookup_accumulators(id, key_a);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@ template <typename Builder> class plookup_read {
typedef field_t<Builder> field_pt;

public:
static std::pair<field_pt, field_pt> read_pair_from_table(const plookup::MultiTableIdOrPtr& id,
const field_pt& key);
static std::pair<field_pt, field_pt> read_pair_from_table(const plookup::MultiTableId id, const field_pt& key);

static field_pt read_from_2_to_1_table(const plookup::MultiTableIdOrPtr& id,
static field_pt read_from_2_to_1_table(const plookup::MultiTableId id,
const field_pt& key_a,
const field_pt& key_b);
static field_pt read_from_1_to_2_table(const plookup::MultiTableIdOrPtr& id, const field_pt& key_a);
static field_pt read_from_1_to_2_table(const plookup::MultiTableId id, const field_pt& key_a);

static plookup::ReadData<field_pt> get_lookup_accumulators(const plookup::MultiTableIdOrPtr& id,
static plookup::ReadData<field_pt> get_lookup_accumulators(const plookup::MultiTableId id,
const field_pt& key_a,
const field_pt& key_b = 0,
const bool is_2_to_1_lookup = false);
Expand Down
Loading

0 comments on commit 70e1a45

Please sign in to comment.