Skip to content

Commit

Permalink
feat!: add is_infinite to curve addition opcode (#6384)
Browse files Browse the repository at this point in the history
Resolves noir-lang/noir#4978

Since elliptic curve addition in barretenberg is already handling the
point at infinity, I simply expose it in the ACIR opcode.
  • Loading branch information
guipublic authored and AztecBot committed May 18, 2024
1 parent 4a6f4d4 commit c65a009
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 37 deletions.
12 changes: 8 additions & 4 deletions cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,19 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, Aci
.scalars = map(arg.scalars, [](auto& e) { return e.witness.value; }),
.out_point_x = arg.outputs[0].value,
.out_point_y = arg.outputs[1].value,
.out_point_is_infinite = arg.outputs[2].value,
});
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::EmbeddedCurveAdd>) {
af.ec_add_constraints.push_back(EcAdd{
.input1_x = arg.input1_x.witness.value,
.input1_y = arg.input1_y.witness.value,
.input2_x = arg.input2_x.witness.value,
.input2_y = arg.input2_y.witness.value,
.input1_x = arg.input1[0].witness.value,
.input1_y = arg.input1[1].witness.value,
.input1_infinite = arg.input1[2].witness.value,
.input2_x = arg.input2[0].witness.value,
.input2_y = arg.input2[1].witness.value,
.input2_infinite = arg.input2[2].witness.value,
.result_x = arg.outputs[0].value,
.result_y = arg.outputs[1].value,
.result_infinite = arg.outputs[2].value,
});
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Keccak256>) {
af.keccak_constraints.push_back(KeccakConstraint{
Expand Down
13 changes: 9 additions & 4 deletions cpp/src/barretenberg/dsl/acir_format/ec_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,35 @@ void create_ec_add_constraint(Builder& builder, const EcAdd& input, bool has_val
// Input to cycle_group points
using cycle_group_ct = bb::stdlib::cycle_group<Builder>;
using field_ct = bb::stdlib::field_t<Builder>;
using bool_ct = bb::stdlib::bool_t<Builder>;

auto x1 = field_ct::from_witness_index(&builder, input.input1_x);
auto y1 = field_ct::from_witness_index(&builder, input.input1_y);
auto x2 = field_ct::from_witness_index(&builder, input.input2_x);
auto y2 = field_ct::from_witness_index(&builder, input.input2_y);
auto infinite1 = bool_ct(field_ct::from_witness_index(&builder, input.input1_infinite));
auto infinite2 = bool_ct(field_ct::from_witness_index(&builder, input.input2_infinite));
if (!has_valid_witness_assignments) {
auto g1 = grumpkin::g1::affine_one;
// We need to have correct values representing points on the curve
builder.variables[input.input1_x] = g1.x;
builder.variables[input.input1_y] = g1.y;
builder.variables[input.input1_infinite] = fr(0);
builder.variables[input.input2_x] = g1.x;
builder.variables[input.input2_y] = g1.y;
builder.variables[input.input2_infinite] = fr(0);
}

cycle_group_ct input1_point(x1, y1, false);
cycle_group_ct input2_point(x2, y2, false);

cycle_group_ct input1_point(x1, y1, infinite1);
cycle_group_ct input2_point(x2, y2, infinite2);
// Addition
cycle_group_ct result = input1_point + input2_point;

auto x_normalized = result.x.normalize();
auto y_normalized = result.y.normalize();
auto infinite = result.is_point_at_infinity().normalize();
builder.assert_equal(x_normalized.witness_index, input.result_x);
builder.assert_equal(y_normalized.witness_index, input.result_y);
builder.assert_equal(infinite.witness_index, input.result_infinite);
}

template void create_ec_add_constraint<UltraCircuitBuilder>(UltraCircuitBuilder& builder,
Expand Down
6 changes: 5 additions & 1 deletion cpp/src/barretenberg/dsl/acir_format/ec_operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ namespace acir_format {
struct EcAdd {
uint32_t input1_x;
uint32_t input1_y;
uint32_t input1_infinite;
uint32_t input2_x;
uint32_t input2_y;
uint32_t input2_infinite;
uint32_t result_x;
uint32_t result_y;
uint32_t result_infinite;

// for serialization, update with any new fields
MSGPACK_FIELDS(input1_x, input1_y, input2_x, input2_y, result_x, result_y);
MSGPACK_FIELDS(
input1_x, input1_y, input1_infinite, input2_x, input2_y, input2_infinite, result_x, result_y, result_infinite);
friend bool operator==(EcAdd const& lhs, EcAdd const& rhs) = default;
};

Expand Down
91 changes: 91 additions & 0 deletions cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,18 @@ size_t generate_ec_add_constraint(EcAdd& ec_add_constraint, WitnessVector& witne
witness_values.push_back(g1.y);
witness_values.push_back(result.x.get_value());
witness_values.push_back(result.y.get_value());
witness_values.push_back(fr(0));
witness_values.push_back(fr(0));
ec_add_constraint = EcAdd{
.input1_x = 1,
.input1_y = 2,
.input1_infinite = 7,
.input2_x = 3,
.input2_y = 4,
.input2_infinite = 7,
.result_x = 5,
.result_y = 6,
.result_infinite = 8,
};
return witness_values.size();
}
Expand Down Expand Up @@ -85,6 +90,92 @@ TEST_F(EcOperations, TestECOperations)
auto prover = composer.create_prover(builder);

auto proof = prover.construct_proof();

EXPECT_TRUE(CircuitChecker::check(builder));
auto verifier = composer.create_verifier(builder);
EXPECT_EQ(verifier.verify_proof(proof), true);
}

TEST_F(EcOperations, TestECMultiScalarMul)
{
MultiScalarMul msm_constrain;

WitnessVector witness_values;
witness_values.emplace_back(fr(0));

witness_values = {
// dummy
fr(0),
// g1: x,y,infinite
fr(1),
fr("0x0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c"),
fr(0),
// low, high scalars
fr(1),
fr(0),
// result
fr("0x06ce1b0827aafa85ddeb49cdaa36306d19a74caa311e13d46d8bc688cdbffffe"),
fr("0x1c122f81a3a14964909ede0ba2a6855fc93faf6fa1a788bf467be7e7a43f80ac"),
fr(0),
};
msm_constrain = MultiScalarMul{
.points = { 1, 2, 3, 1, 2, 3 },
.scalars = { 4, 5, 4, 5 },
.out_point_x = 6,
.out_point_y = 7,
.out_point_is_infinite = 0,
};
auto res_x = fr("0x06ce1b0827aafa85ddeb49cdaa36306d19a74caa311e13d46d8bc688cdbffffe");
auto assert_equal = poly_triple{
.a = 6,
.b = 0,
.c = 0,
.q_m = 0,
.q_l = fr::neg_one(),
.q_r = 0,
.q_o = 0,
.q_c = res_x,
};

size_t num_variables = witness_values.size();
AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables + 1),
.recursive = false,
.num_acir_opcodes = 1,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
.aes128_constraints = {},
.sha256_constraints = {},
.sha256_compression = {},
.schnorr_constraints = {},
.ecdsa_k1_constraints = {},
.ecdsa_r1_constraints = {},
.blake2s_constraints = {},
.blake3_constraints = {},
.keccak_constraints = {},
.keccak_permutations = {},
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.poseidon2_constraints = {},
.multi_scalar_mul_constraints = { msm_constrain },
.ec_add_constraints = {},
.recursion_constraints = {},
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.poly_triple_constraints = { assert_equal },
.quad_constraints = {},
.block_constraints = {},
};

auto builder = create_circuit(constraint_system, /*size_hint*/ 0, witness_values);

auto composer = Composer();
auto prover = composer.create_prover(builder);

auto proof = prover.construct_proof();

EXPECT_TRUE(CircuitChecker::check(builder));
auto verifier = composer.create_verifier(builder);
EXPECT_EQ(verifier.verify_proof(proof), true);
Expand Down
12 changes: 7 additions & 5 deletions cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@ template <typename Builder> void create_multi_scalar_mul_constraint(Builder& bui
using cycle_group_ct = bb::stdlib::cycle_group<Builder>;
using cycle_scalar_ct = typename bb::stdlib::cycle_group<Builder>::cycle_scalar;
using field_ct = bb::stdlib::field_t<Builder>;
using bool_ct = bb::stdlib::bool_t<Builder>;

std::vector<cycle_group_ct> points;
std::vector<cycle_scalar_ct> scalars;

for (size_t i = 0; i < input.points.size(); i += 2) {
for (size_t i = 0; i < input.points.size(); i += 3) {
// Instantiate the input point/variable base as `cycle_group_ct`
auto point_x = field_ct::from_witness_index(&builder, input.points[i]);
auto point_y = field_ct::from_witness_index(&builder, input.points[i + 1]);
cycle_group_ct input_point(point_x, point_y, false);

auto infinite = bool_ct(field_ct::from_witness_index(&builder, input.points[i + 2]));
cycle_group_ct input_point(point_x, point_y, infinite);
// Reconstruct the scalar from the low and high limbs
field_ct scalar_low_as_field = field_ct::from_witness_index(&builder, input.scalars[i]);
field_ct scalar_high_as_field = field_ct::from_witness_index(&builder, input.scalars[i + 1]);
field_ct scalar_low_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3)]);
field_ct scalar_high_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3) + 1]);
cycle_scalar_ct scalar(scalar_low_as_field, scalar_high_as_field);

// Add the point and scalar to the vectors
Expand All @@ -38,6 +39,7 @@ template <typename Builder> void create_multi_scalar_mul_constraint(Builder& bui
// Add the constraints
builder.assert_equal(output_point.x.get_witness_index(), input.out_point_x);
builder.assert_equal(output_point.y.get_witness_index(), input.out_point_y);
builder.assert_equal(output_point.is_point_at_infinity().witness_index, input.out_point_is_infinite);
}

template void create_multi_scalar_mul_constraint<UltraCircuitBuilder>(UltraCircuitBuilder& builder,
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ struct MultiScalarMul {
std::vector<uint32_t> scalars;
uint32_t out_point_x;
uint32_t out_point_y;
uint32_t out_point_is_infinite;

// for serialization, update with any new fields
MSGPACK_FIELDS(points, scalars, out_point_x, out_point_y);
MSGPACK_FIELDS(points, scalars, out_point_x, out_point_y, out_point_is_infinite);
friend bool operator==(MultiScalarMul const& lhs, MultiScalarMul const& rhs) = default;
};

Expand Down
44 changes: 22 additions & 22 deletions cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,17 @@ struct BlackBoxFuncCall {
struct MultiScalarMul {
std::vector<Program::FunctionInput> points;
std::vector<Program::FunctionInput> scalars;
std::array<Program::Witness, 2> outputs;
std::array<Program::Witness, 3> outputs;

friend bool operator==(const MultiScalarMul&, const MultiScalarMul&);
std::vector<uint8_t> bincodeSerialize() const;
static MultiScalarMul bincodeDeserialize(std::vector<uint8_t>);
};

struct EmbeddedCurveAdd {
Program::FunctionInput input1_x;
Program::FunctionInput input1_y;
Program::FunctionInput input2_x;
Program::FunctionInput input2_y;
std::array<Program::Witness, 2> outputs;
std::array<Program::FunctionInput, 3> input1;
std::array<Program::FunctionInput, 3> input2;
std::array<Program::Witness, 3> outputs;

friend bool operator==(const EmbeddedCurveAdd&, const EmbeddedCurveAdd&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -807,8 +805,10 @@ struct BlackBoxOp {
struct EmbeddedCurveAdd {
Program::MemoryAddress input1_x;
Program::MemoryAddress input1_y;
Program::MemoryAddress input1_infinite;
Program::MemoryAddress input2_x;
Program::MemoryAddress input2_y;
Program::MemoryAddress input2_infinite;
Program::HeapArray result;

friend bool operator==(const EmbeddedCurveAdd&, const EmbeddedCurveAdd&);
Expand Down Expand Up @@ -3194,16 +3194,10 @@ namespace Program {

inline bool operator==(const BlackBoxFuncCall::EmbeddedCurveAdd& lhs, const BlackBoxFuncCall::EmbeddedCurveAdd& rhs)
{
if (!(lhs.input1_x == rhs.input1_x)) {
if (!(lhs.input1 == rhs.input1)) {
return false;
}
if (!(lhs.input1_y == rhs.input1_y)) {
return false;
}
if (!(lhs.input2_x == rhs.input2_x)) {
return false;
}
if (!(lhs.input2_y == rhs.input2_y)) {
if (!(lhs.input2 == rhs.input2)) {
return false;
}
if (!(lhs.outputs == rhs.outputs)) {
Expand Down Expand Up @@ -3237,10 +3231,8 @@ template <typename Serializer>
void serde::Serializable<Program::BlackBoxFuncCall::EmbeddedCurveAdd>::serialize(
const Program::BlackBoxFuncCall::EmbeddedCurveAdd& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.input1_x)>::serialize(obj.input1_x, serializer);
serde::Serializable<decltype(obj.input1_y)>::serialize(obj.input1_y, serializer);
serde::Serializable<decltype(obj.input2_x)>::serialize(obj.input2_x, serializer);
serde::Serializable<decltype(obj.input2_y)>::serialize(obj.input2_y, serializer);
serde::Serializable<decltype(obj.input1)>::serialize(obj.input1, serializer);
serde::Serializable<decltype(obj.input2)>::serialize(obj.input2, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
}

Expand All @@ -3250,10 +3242,8 @@ Program::BlackBoxFuncCall::EmbeddedCurveAdd serde::Deserializable<
Program::BlackBoxFuncCall::EmbeddedCurveAdd>::deserialize(Deserializer& deserializer)
{
Program::BlackBoxFuncCall::EmbeddedCurveAdd obj;
obj.input1_x = serde::Deserializable<decltype(obj.input1_x)>::deserialize(deserializer);
obj.input1_y = serde::Deserializable<decltype(obj.input1_y)>::deserialize(deserializer);
obj.input2_x = serde::Deserializable<decltype(obj.input2_x)>::deserialize(deserializer);
obj.input2_y = serde::Deserializable<decltype(obj.input2_y)>::deserialize(deserializer);
obj.input1 = serde::Deserializable<decltype(obj.input1)>::deserialize(deserializer);
obj.input2 = serde::Deserializable<decltype(obj.input2)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
return obj;
}
Expand Down Expand Up @@ -4638,12 +4628,18 @@ inline bool operator==(const BlackBoxOp::EmbeddedCurveAdd& lhs, const BlackBoxOp
if (!(lhs.input1_y == rhs.input1_y)) {
return false;
}
if (!(lhs.input1_infinite == rhs.input1_infinite)) {
return false;
}
if (!(lhs.input2_x == rhs.input2_x)) {
return false;
}
if (!(lhs.input2_y == rhs.input2_y)) {
return false;
}
if (!(lhs.input2_infinite == rhs.input2_infinite)) {
return false;
}
if (!(lhs.result == rhs.result)) {
return false;
}
Expand Down Expand Up @@ -4676,8 +4672,10 @@ void serde::Serializable<Program::BlackBoxOp::EmbeddedCurveAdd>::serialize(
{
serde::Serializable<decltype(obj.input1_x)>::serialize(obj.input1_x, serializer);
serde::Serializable<decltype(obj.input1_y)>::serialize(obj.input1_y, serializer);
serde::Serializable<decltype(obj.input1_infinite)>::serialize(obj.input1_infinite, serializer);
serde::Serializable<decltype(obj.input2_x)>::serialize(obj.input2_x, serializer);
serde::Serializable<decltype(obj.input2_y)>::serialize(obj.input2_y, serializer);
serde::Serializable<decltype(obj.input2_infinite)>::serialize(obj.input2_infinite, serializer);
serde::Serializable<decltype(obj.result)>::serialize(obj.result, serializer);
}

Expand All @@ -4689,8 +4687,10 @@ Program::BlackBoxOp::EmbeddedCurveAdd serde::Deserializable<Program::BlackBoxOp:
Program::BlackBoxOp::EmbeddedCurveAdd obj;
obj.input1_x = serde::Deserializable<decltype(obj.input1_x)>::deserialize(deserializer);
obj.input1_y = serde::Deserializable<decltype(obj.input1_y)>::deserialize(deserializer);
obj.input1_infinite = serde::Deserializable<decltype(obj.input1_infinite)>::deserialize(deserializer);
obj.input2_x = serde::Deserializable<decltype(obj.input2_x)>::deserialize(deserializer);
obj.input2_y = serde::Deserializable<decltype(obj.input2_y)>::deserialize(deserializer);
obj.input2_infinite = serde::Deserializable<decltype(obj.input2_infinite)>::deserialize(deserializer);
obj.result = serde::Deserializable<decltype(obj.result)>::deserialize(deserializer);
return obj;
}
Expand Down

0 comments on commit c65a009

Please sign in to comment.