Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ultra honk arith from ultra #3274

Merged
merged 7 commits into from
Nov 10, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
builds w UCB inheritance w new arith
  • Loading branch information
ledwards2225 committed Nov 8, 2023
commit 06d50de33a7f510dd861f56baff18718bf01c34c
183 changes: 108 additions & 75 deletions barretenberg/cpp/src/barretenberg/flavor/goblin_ultra.hpp
Original file line number Diff line number Diff line change
@@ -32,12 +32,13 @@ class GoblinUltra {
// The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often
// need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`.
// Note: this number does not include the individual sorted list polynomials.
static constexpr size_t NUM_ALL_ENTITIES = 48; // 43 (UH) + 4 op wires + 1 op wire "selector"
// NUM = 43 (UH) + 4 op wires + 1 op wire "selector" + 3 (calldata + calldata_read_counts + q_busread)
static constexpr size_t NUM_ALL_ENTITIES = 51;
// The number of polynomials precomputed to describe a circuit and to aid a prover in constructing a satisfying
// assignment of witnesses. We again choose a neutral name.
static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 26; // 25 (UH) + 1 op wire "selector"
static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 27; // 25 (UH) + 1 op wire "selector" + q_busread
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 15; // 11 (UH) + 4 op wires
static constexpr size_t NUM_WITNESS_ENTITIES = 17; // 11 (UH) + 4 op wires + (calldata + calldata_read_counts)

using GrandProductRelations =
std::tuple<proof_system::UltraPermutationRelation<FF>, proof_system::LookupRelation<FF>>;
@@ -50,6 +51,7 @@ class GoblinUltra {
proof_system::EllipticRelation<FF>,
proof_system::AuxiliaryRelation<FF>,
proof_system::EccOpQueueRelation<FF>>;
// WORKTODO: add bus lookup relation!

static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = compute_max_partial_relation_length<Relations>();
static constexpr size_t MAX_TOTAL_RELATION_LENGTH = compute_max_total_relation_length<Relations>();
@@ -89,27 +91,28 @@ class GoblinUltra {
DataType& q_elliptic = std::get<8>(this->_data);
DataType& q_aux = std::get<9>(this->_data);
DataType& q_lookup = std::get<10>(this->_data);
DataType& sigma_1 = std::get<11>(this->_data);
DataType& sigma_2 = std::get<12>(this->_data);
DataType& sigma_3 = std::get<13>(this->_data);
DataType& sigma_4 = std::get<14>(this->_data);
DataType& id_1 = std::get<15>(this->_data);
DataType& id_2 = std::get<16>(this->_data);
DataType& id_3 = std::get<17>(this->_data);
DataType& id_4 = std::get<18>(this->_data);
DataType& table_1 = std::get<19>(this->_data);
DataType& table_2 = std::get<20>(this->_data);
DataType& table_3 = std::get<21>(this->_data);
DataType& table_4 = std::get<22>(this->_data);
DataType& lagrange_first = std::get<23>(this->_data);
DataType& lagrange_last = std::get<24>(this->_data);
DataType& lagrange_ecc_op = std::get<25>(this->_data); // indicator poly for ecc op gates
DataType& q_busread = std::get<11>(this->_data);
DataType& sigma_1 = std::get<12>(this->_data);
DataType& sigma_2 = std::get<13>(this->_data);
DataType& sigma_3 = std::get<14>(this->_data);
DataType& sigma_4 = std::get<15>(this->_data);
DataType& id_1 = std::get<16>(this->_data);
DataType& id_2 = std::get<17>(this->_data);
DataType& id_3 = std::get<18>(this->_data);
DataType& id_4 = std::get<19>(this->_data);
DataType& table_1 = std::get<20>(this->_data);
DataType& table_2 = std::get<21>(this->_data);
DataType& table_3 = std::get<22>(this->_data);
DataType& table_4 = std::get<23>(this->_data);
DataType& lagrange_first = std::get<24>(this->_data);
DataType& lagrange_last = std::get<25>(this->_data);
DataType& lagrange_ecc_op = std::get<26>(this->_data); // indicator poly for ecc op gates

static constexpr CircuitType CIRCUIT_TYPE = CircuitBuilder::CIRCUIT_TYPE;

std::vector<HandleType> get_selectors() override
{
return { q_m, q_c, q_l, q_r, q_o, q_4, q_arith, q_sort, q_elliptic, q_aux, q_lookup };
return { q_m, q_c, q_l, q_r, q_o, q_4, q_arith, q_sort, q_elliptic, q_aux, q_lookup, q_busread };
};
std::vector<HandleType> get_sigma_polynomials() override { return { sigma_1, sigma_2, sigma_3, sigma_4 }; };
std::vector<HandleType> get_id_polynomials() override { return { id_1, id_2, id_3, id_4 }; };
@@ -139,6 +142,8 @@ class GoblinUltra {
DataType& ecc_op_wire_2 = std::get<12>(this->_data);
DataType& ecc_op_wire_3 = std::get<13>(this->_data);
DataType& ecc_op_wire_4 = std::get<14>(this->_data);
DataType& calldata = std::get<15>(this->_data);
DataType& calldata_read_counts = std::get<16>(this->_data);

std::vector<HandleType> get_wires() override { return { w_l, w_r, w_o, w_4 }; };
std::vector<HandleType> get_ecc_op_wires()
@@ -172,43 +177,46 @@ class GoblinUltra {
DataType& q_elliptic = std::get<8>(this->_data);
DataType& q_aux = std::get<9>(this->_data);
DataType& q_lookup = std::get<10>(this->_data);
DataType& sigma_1 = std::get<11>(this->_data);
DataType& sigma_2 = std::get<12>(this->_data);
DataType& sigma_3 = std::get<13>(this->_data);
DataType& sigma_4 = std::get<14>(this->_data);
DataType& id_1 = std::get<15>(this->_data);
DataType& id_2 = std::get<16>(this->_data);
DataType& id_3 = std::get<17>(this->_data);
DataType& id_4 = std::get<18>(this->_data);
DataType& table_1 = std::get<19>(this->_data);
DataType& table_2 = std::get<20>(this->_data);
DataType& table_3 = std::get<21>(this->_data);
DataType& table_4 = std::get<22>(this->_data);
DataType& lagrange_first = std::get<23>(this->_data);
DataType& lagrange_last = std::get<24>(this->_data);
DataType& lagrange_ecc_op = std::get<25>(this->_data);
DataType& w_l = std::get<26>(this->_data);
DataType& w_r = std::get<27>(this->_data);
DataType& w_o = std::get<28>(this->_data);
DataType& w_4 = std::get<29>(this->_data);
DataType& sorted_accum = std::get<30>(this->_data);
DataType& z_perm = std::get<31>(this->_data);
DataType& z_lookup = std::get<32>(this->_data);
DataType& ecc_op_wire_1 = std::get<33>(this->_data);
DataType& ecc_op_wire_2 = std::get<34>(this->_data);
DataType& ecc_op_wire_3 = std::get<35>(this->_data);
DataType& ecc_op_wire_4 = std::get<36>(this->_data);
DataType& table_1_shift = std::get<37>(this->_data);
DataType& table_2_shift = std::get<38>(this->_data);
DataType& table_3_shift = std::get<39>(this->_data);
DataType& table_4_shift = std::get<40>(this->_data);
DataType& w_l_shift = std::get<41>(this->_data);
DataType& w_r_shift = std::get<42>(this->_data);
DataType& w_o_shift = std::get<43>(this->_data);
DataType& w_4_shift = std::get<44>(this->_data);
DataType& sorted_accum_shift = std::get<45>(this->_data);
DataType& z_perm_shift = std::get<46>(this->_data);
DataType& z_lookup_shift = std::get<47>(this->_data);
DataType& q_busread = std::get<11>(this->_data);
DataType& sigma_1 = std::get<12>(this->_data);
DataType& sigma_2 = std::get<13>(this->_data);
DataType& sigma_3 = std::get<14>(this->_data);
DataType& sigma_4 = std::get<15>(this->_data);
DataType& id_1 = std::get<16>(this->_data);
DataType& id_2 = std::get<17>(this->_data);
DataType& id_3 = std::get<18>(this->_data);
DataType& id_4 = std::get<19>(this->_data);
DataType& table_1 = std::get<20>(this->_data);
DataType& table_2 = std::get<21>(this->_data);
DataType& table_3 = std::get<22>(this->_data);
DataType& table_4 = std::get<23>(this->_data);
DataType& lagrange_first = std::get<24>(this->_data);
DataType& lagrange_last = std::get<25>(this->_data);
DataType& lagrange_ecc_op = std::get<26>(this->_data);
DataType& w_l = std::get<27>(this->_data);
DataType& w_r = std::get<28>(this->_data);
DataType& w_o = std::get<29>(this->_data);
DataType& w_4 = std::get<30>(this->_data);
DataType& sorted_accum = std::get<31>(this->_data);
DataType& z_perm = std::get<32>(this->_data);
DataType& z_lookup = std::get<33>(this->_data);
DataType& ecc_op_wire_1 = std::get<34>(this->_data);
DataType& ecc_op_wire_2 = std::get<35>(this->_data);
DataType& ecc_op_wire_3 = std::get<36>(this->_data);
DataType& ecc_op_wire_4 = std::get<37>(this->_data);
DataType& calldata = std::get<38>(this->_data);
DataType& calldata_read_counts = std::get<39>(this->_data);
DataType& table_1_shift = std::get<40>(this->_data);
DataType& table_2_shift = std::get<41>(this->_data);
DataType& table_3_shift = std::get<42>(this->_data);
DataType& table_4_shift = std::get<43>(this->_data);
DataType& w_l_shift = std::get<44>(this->_data);
DataType& w_r_shift = std::get<45>(this->_data);
DataType& w_o_shift = std::get<46>(this->_data);
DataType& w_4_shift = std::get<47>(this->_data);
DataType& sorted_accum_shift = std::get<48>(this->_data);
DataType& z_perm_shift = std::get<49>(this->_data);
DataType& z_lookup_shift = std::get<50>(this->_data);

std::vector<HandleType> get_wires() override { return { w_l, w_r, w_o, w_4 }; };
std::vector<HandleType> get_ecc_op_wires()
@@ -218,25 +226,46 @@ class GoblinUltra {
// Gemini-specific getters.
std::vector<HandleType> get_unshifted() override
{
return { q_c, q_l,
q_r, q_o,
q_4, q_m,
q_arith, q_sort,
q_elliptic, q_aux,
q_lookup, sigma_1,
sigma_2, sigma_3,
sigma_4, id_1,
id_2, id_3,
id_4, table_1,
table_2, table_3,
table_4, lagrange_first,
lagrange_last, lagrange_ecc_op,
w_l, w_r,
w_o, w_4,
sorted_accum, z_perm,
z_lookup, ecc_op_wire_1,
ecc_op_wire_2, ecc_op_wire_3,
ecc_op_wire_4 };
return { q_c,
q_l,
q_r,
q_o,
q_4,
q_m,
q_arith,
q_sort,
q_elliptic,
q_aux,
q_lookup,
q_busread,
sigma_1,
sigma_2,
sigma_3,
sigma_4,
id_1,
id_2,
id_3,
id_4,
table_1,
table_2,
table_3,
table_4,
lagrange_first,
lagrange_last,
lagrange_ecc_op,
w_l,
w_r,
w_o,
w_4,
sorted_accum,
z_perm,
z_lookup,
ecc_op_wire_1,
ecc_op_wire_2,
ecc_op_wire_3,
ecc_op_wire_4,
calldata,
calldata_read_counts };
};
std::vector<HandleType> get_to_be_shifted() override
{
@@ -384,6 +413,8 @@ class GoblinUltra {
ecc_op_wire_2 = "ECC_OP_WIRE_2";
ecc_op_wire_3 = "ECC_OP_WIRE_3";
ecc_op_wire_4 = "ECC_OP_WIRE_4";
calldata = "CALLDATA";
calldata_read_counts = "CALLDATA_READ_COUNTS";

// The ones beginning with "__" are only used for debugging
q_c = "__Q_C";
@@ -397,6 +428,7 @@ class GoblinUltra {
q_elliptic = "__Q_ELLIPTIC";
q_aux = "__Q_AUX";
q_lookup = "__Q_LOOKUP";
q_busread = "__Q_BUSREAD";
sigma_1 = "__SIGMA_1";
sigma_2 = "__SIGMA_2";
sigma_3 = "__SIGMA_3";
@@ -432,6 +464,7 @@ class GoblinUltra {
q_elliptic = verification_key->q_elliptic;
q_aux = verification_key->q_aux;
q_lookup = verification_key->q_lookup;
q_busread = verification_key->q_busread;
sigma_1 = verification_key->sigma_1;
sigma_2 = verification_key->sigma_2;
sigma_3 = verification_key->sigma_3;
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#include <cstddef>
#include <cstdint>
#include <gtest/gtest.h>

#include "barretenberg/common/log.hpp"
#include "barretenberg/honk/composer/ultra_composer.hpp"
#include "barretenberg/honk/proof_system/ultra_prover.hpp"
#include "barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.hpp"
#include "barretenberg/proof_system/circuit_builder/ultra_circuit_builder.hpp"

using namespace proof_system::honk;

namespace test_ultra_honk_composer {

namespace {
auto& engine = numeric::random::get_debug_engine();
}

class DataBusComposerTests : public ::testing::Test {
protected:
static void SetUpTestSuite() { barretenberg::srs::init_crs_factory("../srs_db/ignition"); }

using Curve = curve::BN254;
using FF = Curve::ScalarField;
using Point = Curve::AffineElement;
using CommitmentKey = pcs::CommitmentKey<Curve>;

/**
* @brief Generate a simple test circuit with some ECC op gates and conventional arithmetic gates
*
* @param builder
*/
void generate_test_circuit(auto& builder)
{
// Add some ecc op gates
for (size_t i = 0; i < 3; ++i) {
auto point = Point::one() * FF::random_element();
auto scalar = FF::random_element();
builder.queue_ecc_mul_accum(point, scalar);
}
builder.queue_ecc_eq();

// Add some conventional gates that utilize public inputs
for (size_t i = 0; i < 10; ++i) {
FF a = FF::random_element();
FF b = FF::random_element();
FF c = FF::random_element();
FF d = a + b + c;
uint32_t a_idx = builder.add_public_variable(a);
uint32_t b_idx = builder.add_variable(b);
uint32_t c_idx = builder.add_variable(c);
uint32_t d_idx = builder.add_variable(d);

builder.create_big_add_gate({ a_idx, b_idx, c_idx, d_idx, FF(1), FF(1), FF(1), FF(-1), FF(0) });
}
}

/**
* @brief Construct and a verify a Honk proof
*
*/
bool construct_and_verify_honk_proof(auto& composer, auto& builder)
{
auto instance = composer.create_instance(builder);
auto prover = composer.create_prover(instance);
auto verifier = composer.create_verifier(instance);
auto proof = prover.construct_proof();
bool verified = verifier.verify_proof(proof);

return verified;
}
};

/**
* @brief Test proof construction/verification for a circuit with ECC op gates, public inputs, and basic arithmetic
* gates
* @note We simulate op queue interactions with a previous circuit so the actual circuit under test utilizes an op queue
* with non-empty 'previous' data. This avoid complications with zero-commitments etc.
*
*/
TEST_F(DataBusComposerTests, SingleCircuit)
{
auto op_queue = std::make_shared<proof_system::ECCOpQueue>();

// Add mock data to op queue to simulate interaction with a previous circuit
op_queue->populate_with_mock_initital_data();

auto builder = proof_system::GoblinUltraCircuitBuilder(op_queue);

generate_test_circuit(builder);

auto composer = GoblinUltraComposer();

// Construct and verify Honk proof
auto honk_verified = construct_and_verify_honk_proof(composer, builder);
EXPECT_TRUE(honk_verified);
}

} // namespace test_ultra_honk_composer
Original file line number Diff line number Diff line change
@@ -647,7 +647,7 @@ TYPED_TEST(ultra_plonk_composer, non_native_field_multiplication)
const auto q_indices = get_limb_witness_indices(split_into_limbs(uint256_t(q)));
const auto r_indices = get_limb_witness_indices(split_into_limbs(uint256_t(r)));

proof_system::UltraCircuitBuilder::non_native_field_witnesses inputs{
proof_system::non_native_field_witnesses<fr> inputs{
a_indices, b_indices, q_indices, r_indices, modulus_limbs, fr(uint256_t(modulus)),
};
const auto [lo_1_idx, hi_1_idx] = builder.evaluate_non_native_field_multiplication(inputs);
Loading