Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add more cases for assert_equal conversion #8446

Merged
merged 7 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp"
#include "proof_surgeon.hpp"
#include <cstddef>
#include <cstdint>

namespace acir_format {

Expand Down Expand Up @@ -55,7 +56,11 @@ void build_constraints(Builder& builder,
// Add range constraint
for (size_t i = 0; i < constraint_system.range_constraints.size(); ++i) {
const auto& constraint = constraint_system.range_constraints.at(i);
builder.create_range_constraint(constraint.witness, constraint.num_bits, "");
uint32_t range = constraint.num_bits;
if (constraint_system.minimal_range.contains(constraint.witness)) {
range = constraint_system.minimal_range[constraint.witness];
}
builder.create_range_constraint(constraint.witness, range, "");
gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.range_constraints.at(i));
}
Expand Down Expand Up @@ -212,10 +217,10 @@ void build_constraints(Builder& builder,
gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.bigint_to_le_bytes_constraints.at(i));
}

// assert equals
for (size_t i = 0; i < constraint_system.assert_equalities.size(); ++i) {
const auto& constraint = constraint_system.assert_equalities.at(i);

builder.assert_equal(constraint.a, constraint.b);
gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.assert_equalities.at(i));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "recursion_constraint.hpp"
#include "schnorr_verify.hpp"
#include "sha256_constraint.hpp"
#include <cstdint>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -125,6 +126,7 @@ struct AcirFormat {

// Set of constrained witnesses
std::set<uint32_t> constrained_witness = {};
std::map<uint32_t, uint32_t> minimal_range = {};

// Indices of the original opcode that originated each constraint in AcirFormat.
AcirFormatOriginalOpcodeIndices original_opcode_indices;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ std::pair<uint32_t, uint32_t> is_assert_equal(Program::Opcode::AssertZero const&
return { 0, 0 };
}
if (pt.q_l == -pt.q_r && pt.q_l != bb::fr::zero() && pt.q_c == bb::fr::zero()) {
if (af.constrained_witness.contains(pt.a) && af.constrained_witness.contains(pt.b)) {
// we require that one of the 2 witnesses to be constrained in an arithmetic gate
if (af.constrained_witness.contains(pt.a) || af.constrained_witness.contains(pt.b)) {
return { pt.a, pt.b };
}
}
Expand All @@ -210,13 +211,39 @@ void handle_arithmetic(Program::Opcode::AssertZero const& arg, AcirFormat& af, s
uint32_t w2 = std::get<1>(assert_equal);
if (w1 != 0) {
if (w1 != w2) {
if (!af.constrained_witness.contains(pt.a)) {
// we mark it as constrained because it is going to be asserted to be equal to a constrained one.
af.constrained_witness.insert(pt.a);
// swap the witnesses so that the first one is always properly constrained.
auto tmp = pt.a;
pt.a = pt.b;
pt.b = tmp;
}
if (!af.constrained_witness.contains(pt.b)) {
// we mark it as constrained because it is going to be asserted to be equal to a constrained one.
af.constrained_witness.insert(pt.b);
}
// minimal_range of a witness is the smallest range of the witness and the witness that are
// 'assert_equal' to it
if (af.minimal_range.contains(pt.b) && af.minimal_range.contains(pt.a)) {
if (af.minimal_range[pt.a] < af.minimal_range[pt.b]) {
af.minimal_range[pt.a] = af.minimal_range[pt.b];
} else {
af.minimal_range[pt.b] = af.minimal_range[pt.a];
}
} else if (af.minimal_range.contains(pt.b)) {
af.minimal_range[pt.a] = af.minimal_range[pt.b];
} else if (af.minimal_range.contains(pt.a)) {
af.minimal_range[pt.b] = af.minimal_range[pt.a];
}

af.assert_equalities.push_back(pt);
af.original_opcode_indices.assert_equalities.push_back(opcode_index);
}
return;
}
// Even if the number of linear terms is less than 3, we might not be able to fit it into a width-3 arithmetic
// gate. This is the case if the linear terms are all disctinct witness from the multiplication term. In that
// gate. This is the case if the linear terms are all distinct witness from the multiplication term. In that
// case, the serialize_arithmetic_gate() function will return a poly_triple with all 0's, and we use a width-4
// gate instead. We could probably always use a width-4 gate in fact.
if (pt == poly_triple{ 0, 0, 0, 0, 0, 0, 0, 0 }) {
Expand Down Expand Up @@ -288,6 +315,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.num_bits = arg.lhs.num_bits,
.is_xor_gate = false,
});
af.constrained_witness.insert(af.logic_constraints.back().result);
af.original_opcode_indices.logic_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::XOR>) {
auto lhs_input = parse_input(arg.lhs);
Expand All @@ -299,6 +327,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.num_bits = arg.lhs.num_bits,
.is_xor_gate = true,
});
af.constrained_witness.insert(af.logic_constraints.back().result);
af.original_opcode_indices.logic_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::RANGE>) {
auto witness_input = get_witness_from_function_input(arg.input);
Expand All @@ -307,14 +336,23 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.num_bits = arg.input.num_bits,
});
af.original_opcode_indices.range_constraints.push_back(opcode_index);

if (af.minimal_range.contains(witness_input)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you'd like this entire if-else could be simplified to

if (af.minimal_range[witness_input] > arg.input.num_bits) {
    af.minimal_range[witness_input] = arg.input.num_bits;
}

since if the entry does not exist it will be default initialized to zero

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0 is not a good default value for the minimal range, I'd prefer not having this value in the minimal_range map.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh no, 0 would never be added since the only cases here are a) the range has already been set or b) the range has not been set (and thus gets default initialized to 0) then immediately gets set to arg.input.num_bits. The behavior is identical to what you have here. Fine to leave it as is though

if (af.minimal_range[witness_input] > arg.input.num_bits) {
af.minimal_range[witness_input] = arg.input.num_bits;
}
} else {
af.minimal_range[witness_input] = arg.input.num_bits;
}
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::AES128Encrypt>) {
af.aes128_constraints.push_back(AES128Constraint{
.inputs = map(arg.inputs, [](auto& e) { return parse_input(e); }),
.iv = map(arg.iv, [](auto& e) { return parse_input(e); }),
.key = map(arg.key, [](auto& e) { return parse_input(e); }),
.outputs = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.aes128_constraints.back().outputs) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.aes128_constraints.push_back(opcode_index);

} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::SHA256>) {
Expand All @@ -337,6 +375,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.hash_values = map(arg.hash_values, [](auto& e) { return parse_input(e); }),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.sha256_compression.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.sha256_compression.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Blake2s>) {
af.blake2s_constraints.push_back(Blake2sConstraint{
Expand All @@ -349,6 +390,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
}),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.blake2s_constraints.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.blake2s_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Blake3>) {
af.blake3_constraints.push_back(Blake3Constraint{
Expand All @@ -361,6 +405,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
}),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.blake3_constraints.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.blake3_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::SchnorrVerify>) {
auto input_pkey_x = get_witness_from_function_input(arg.public_key_x);
Expand All @@ -373,21 +420,24 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.signature = map(arg.signature, [](auto& e) { return get_witness_from_function_input(e); }),
});
af.original_opcode_indices.schnorr_constraints.push_back(opcode_index);
af.constrained_witness.insert(af.schnorr_constraints.back().result);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::PedersenCommitment>) {

af.pedersen_constraints.push_back(PedersenConstraint{
.scalars = map(arg.inputs, [](auto& e) { return get_witness_from_function_input(e); }),
.hash_index = arg.domain_separator,
.result_x = arg.outputs[0].value,
.result_y = arg.outputs[1].value,
});
af.constrained_witness.insert(af.pedersen_constraints.back().result_x);
af.constrained_witness.insert(af.pedersen_constraints.back().result_y);
af.original_opcode_indices.pedersen_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::PedersenHash>) {
af.pedersen_hash_constraints.push_back(PedersenHashConstraint{
.scalars = map(arg.inputs, [](auto& e) { return get_witness_from_function_input(e); }),
.hash_index = arg.domain_separator,
.result = arg.output.value,
});
af.constrained_witness.insert(af.pedersen_hash_constraints.back().result);
af.original_opcode_indices.pedersen_hash_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::EcdsaSecp256k1>) {
af.ecdsa_k1_constraints.push_back(EcdsaSecp256k1Constraint{
Expand All @@ -398,6 +448,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.pub_y_indices = map(arg.public_key_y, [](auto& e) { return get_witness_from_function_input(e); }),
.result = arg.output.value,
});
af.constrained_witness.insert(af.ecdsa_k1_constraints.back().result);
af.original_opcode_indices.ecdsa_k1_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::EcdsaSecp256r1>) {
af.ecdsa_r1_constraints.push_back(EcdsaSecp256r1Constraint{
Expand All @@ -408,6 +459,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.result = arg.output.value,
.signature = map(arg.signature, [](auto& e) { return get_witness_from_function_input(e); }),
});
af.constrained_witness.insert(af.ecdsa_r1_constraints.back().result);
af.original_opcode_indices.ecdsa_r1_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::MultiScalarMul>) {
af.multi_scalar_mul_constraints.push_back(MultiScalarMul{
Expand All @@ -417,6 +469,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.out_point_y = arg.outputs[1].value,
.out_point_is_infinite = arg.outputs[2].value,
});
af.constrained_witness.insert(af.multi_scalar_mul_constraints.back().out_point_x);
af.constrained_witness.insert(af.multi_scalar_mul_constraints.back().out_point_y);
af.constrained_witness.insert(af.multi_scalar_mul_constraints.back().out_point_is_infinite);
af.original_opcode_indices.multi_scalar_mul_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::EmbeddedCurveAdd>) {
auto input_1_x = parse_input(arg.input1[0]);
Expand All @@ -437,6 +492,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.result_y = arg.outputs[1].value,
.result_infinite = arg.outputs[2].value,
});
af.constrained_witness.insert(af.ec_add_constraints.back().result_x);
af.constrained_witness.insert(af.ec_add_constraints.back().result_y);
af.constrained_witness.insert(af.ec_add_constraints.back().result_infinite);
af.original_opcode_indices.ec_add_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Keccak256>) {
auto input_var_message_size = get_witness_from_function_input(arg.var_message_size);
Expand All @@ -452,12 +510,18 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.result = map(arg.outputs, [](auto& e) { return e.value; }),
.var_message_size = input_var_message_size,
});
for (auto& output : af.keccak_constraints.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.keccak_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Keccakf1600>) {
af.keccak_permutations.push_back(Keccakf1600{
.state = map(arg.inputs, [](auto& e) { return parse_input(e); }),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.keccak_permutations.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.keccak_permutations.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::RecursiveAggregation>) {

Expand Down Expand Up @@ -509,6 +573,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.input = arg.input,
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.bigint_to_le_bytes_constraints.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.bigint_to_le_bytes_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::BigIntAdd>) {
af.bigint_operations.push_back(BigIntOperation{
Expand Down Expand Up @@ -548,6 +615,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.result = map(arg.outputs, [](auto& e) { return e.value; }),
.len = arg.len,
});
for (auto& output : af.poseidon2_constraints.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.poseidon2_constraints.push_back(opcode_index);
}
},
Expand Down
Loading