Skip to content

Commit

Permalink
chore: add more cases for assert_equal conversion (#8446)
Browse files Browse the repository at this point in the history
Transform arithmetic gate of the kind a==b into a copy constraint
between a and b, as long as a or b is already constrained. In that case,
we mark both a and b as constrained.
  • Loading branch information
guipublic authored Sep 17, 2024
1 parent 38e3051 commit e3ea298
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp"
#include "proof_surgeon.hpp"
#include <cstddef>
#include <cstdint>

namespace acir_format {

Expand Down Expand Up @@ -55,7 +56,11 @@ void build_constraints(Builder& builder,
// Add range constraint
for (size_t i = 0; i < constraint_system.range_constraints.size(); ++i) {
const auto& constraint = constraint_system.range_constraints.at(i);
builder.create_range_constraint(constraint.witness, constraint.num_bits, "");
uint32_t range = constraint.num_bits;
if (constraint_system.minimal_range.contains(constraint.witness)) {
range = constraint_system.minimal_range[constraint.witness];
}
builder.create_range_constraint(constraint.witness, range, "");
gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.range_constraints.at(i));
}
Expand Down Expand Up @@ -212,10 +217,10 @@ void build_constraints(Builder& builder,
gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.bigint_to_le_bytes_constraints.at(i));
}

// assert equals
for (size_t i = 0; i < constraint_system.assert_equalities.size(); ++i) {
const auto& constraint = constraint_system.assert_equalities.at(i);

builder.assert_equal(constraint.a, constraint.b);
gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.assert_equalities.at(i));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "recursion_constraint.hpp"
#include "schnorr_verify.hpp"
#include "sha256_constraint.hpp"
#include <cstdint>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -125,6 +126,7 @@ struct AcirFormat {

// Set of constrained witnesses
std::set<uint32_t> constrained_witness = {};
std::map<uint32_t, uint32_t> minimal_range = {};

// Indices of the original opcode that originated each constraint in AcirFormat.
AcirFormatOriginalOpcodeIndices original_opcode_indices;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ std::pair<uint32_t, uint32_t> is_assert_equal(Program::Opcode::AssertZero const&
return { 0, 0 };
}
if (pt.q_l == -pt.q_r && pt.q_l != bb::fr::zero() && pt.q_c == bb::fr::zero()) {
if (af.constrained_witness.contains(pt.a) && af.constrained_witness.contains(pt.b)) {
// we require that one of the 2 witnesses to be constrained in an arithmetic gate
if (af.constrained_witness.contains(pt.a) || af.constrained_witness.contains(pt.b)) {
return { pt.a, pt.b };
}
}
Expand All @@ -210,13 +211,39 @@ void handle_arithmetic(Program::Opcode::AssertZero const& arg, AcirFormat& af, s
uint32_t w2 = std::get<1>(assert_equal);
if (w1 != 0) {
if (w1 != w2) {
if (!af.constrained_witness.contains(pt.a)) {
// we mark it as constrained because it is going to be asserted to be equal to a constrained one.
af.constrained_witness.insert(pt.a);
// swap the witnesses so that the first one is always properly constrained.
auto tmp = pt.a;
pt.a = pt.b;
pt.b = tmp;
}
if (!af.constrained_witness.contains(pt.b)) {
// we mark it as constrained because it is going to be asserted to be equal to a constrained one.
af.constrained_witness.insert(pt.b);
}
// minimal_range of a witness is the smallest range of the witness and the witness that are
// 'assert_equal' to it
if (af.minimal_range.contains(pt.b) && af.minimal_range.contains(pt.a)) {
if (af.minimal_range[pt.a] < af.minimal_range[pt.b]) {
af.minimal_range[pt.a] = af.minimal_range[pt.b];
} else {
af.minimal_range[pt.b] = af.minimal_range[pt.a];
}
} else if (af.minimal_range.contains(pt.b)) {
af.minimal_range[pt.a] = af.minimal_range[pt.b];
} else if (af.minimal_range.contains(pt.a)) {
af.minimal_range[pt.b] = af.minimal_range[pt.a];
}

af.assert_equalities.push_back(pt);
af.original_opcode_indices.assert_equalities.push_back(opcode_index);
}
return;
}
// Even if the number of linear terms is less than 3, we might not be able to fit it into a width-3 arithmetic
// gate. This is the case if the linear terms are all disctinct witness from the multiplication term. In that
// gate. This is the case if the linear terms are all distinct witness from the multiplication term. In that
// case, the serialize_arithmetic_gate() function will return a poly_triple with all 0's, and we use a width-4
// gate instead. We could probably always use a width-4 gate in fact.
if (pt == poly_triple{ 0, 0, 0, 0, 0, 0, 0, 0 }) {
Expand Down Expand Up @@ -288,6 +315,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.num_bits = arg.lhs.num_bits,
.is_xor_gate = false,
});
af.constrained_witness.insert(af.logic_constraints.back().result);
af.original_opcode_indices.logic_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::XOR>) {
auto lhs_input = parse_input(arg.lhs);
Expand All @@ -299,6 +327,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.num_bits = arg.lhs.num_bits,
.is_xor_gate = true,
});
af.constrained_witness.insert(af.logic_constraints.back().result);
af.original_opcode_indices.logic_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::RANGE>) {
auto witness_input = get_witness_from_function_input(arg.input);
Expand All @@ -307,14 +336,23 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.num_bits = arg.input.num_bits,
});
af.original_opcode_indices.range_constraints.push_back(opcode_index);

if (af.minimal_range.contains(witness_input)) {
if (af.minimal_range[witness_input] > arg.input.num_bits) {
af.minimal_range[witness_input] = arg.input.num_bits;
}
} else {
af.minimal_range[witness_input] = arg.input.num_bits;
}
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::AES128Encrypt>) {
af.aes128_constraints.push_back(AES128Constraint{
.inputs = map(arg.inputs, [](auto& e) { return parse_input(e); }),
.iv = map(arg.iv, [](auto& e) { return parse_input(e); }),
.key = map(arg.key, [](auto& e) { return parse_input(e); }),
.outputs = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.aes128_constraints.back().outputs) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.aes128_constraints.push_back(opcode_index);

} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::SHA256>) {
Expand All @@ -337,6 +375,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.hash_values = map(arg.hash_values, [](auto& e) { return parse_input(e); }),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.sha256_compression.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.sha256_compression.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Blake2s>) {
af.blake2s_constraints.push_back(Blake2sConstraint{
Expand All @@ -349,6 +390,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
}),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.blake2s_constraints.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.blake2s_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Blake3>) {
af.blake3_constraints.push_back(Blake3Constraint{
Expand All @@ -361,6 +405,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
}),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.blake3_constraints.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.blake3_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::SchnorrVerify>) {
auto input_pkey_x = get_witness_from_function_input(arg.public_key_x);
Expand All @@ -373,21 +420,24 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.signature = map(arg.signature, [](auto& e) { return get_witness_from_function_input(e); }),
});
af.original_opcode_indices.schnorr_constraints.push_back(opcode_index);
af.constrained_witness.insert(af.schnorr_constraints.back().result);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::PedersenCommitment>) {

af.pedersen_constraints.push_back(PedersenConstraint{
.scalars = map(arg.inputs, [](auto& e) { return get_witness_from_function_input(e); }),
.hash_index = arg.domain_separator,
.result_x = arg.outputs[0].value,
.result_y = arg.outputs[1].value,
});
af.constrained_witness.insert(af.pedersen_constraints.back().result_x);
af.constrained_witness.insert(af.pedersen_constraints.back().result_y);
af.original_opcode_indices.pedersen_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::PedersenHash>) {
af.pedersen_hash_constraints.push_back(PedersenHashConstraint{
.scalars = map(arg.inputs, [](auto& e) { return get_witness_from_function_input(e); }),
.hash_index = arg.domain_separator,
.result = arg.output.value,
});
af.constrained_witness.insert(af.pedersen_hash_constraints.back().result);
af.original_opcode_indices.pedersen_hash_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::EcdsaSecp256k1>) {
af.ecdsa_k1_constraints.push_back(EcdsaSecp256k1Constraint{
Expand All @@ -398,6 +448,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.pub_y_indices = map(arg.public_key_y, [](auto& e) { return get_witness_from_function_input(e); }),
.result = arg.output.value,
});
af.constrained_witness.insert(af.ecdsa_k1_constraints.back().result);
af.original_opcode_indices.ecdsa_k1_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::EcdsaSecp256r1>) {
af.ecdsa_r1_constraints.push_back(EcdsaSecp256r1Constraint{
Expand All @@ -408,6 +459,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.result = arg.output.value,
.signature = map(arg.signature, [](auto& e) { return get_witness_from_function_input(e); }),
});
af.constrained_witness.insert(af.ecdsa_r1_constraints.back().result);
af.original_opcode_indices.ecdsa_r1_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::MultiScalarMul>) {
af.multi_scalar_mul_constraints.push_back(MultiScalarMul{
Expand All @@ -417,6 +469,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.out_point_y = arg.outputs[1].value,
.out_point_is_infinite = arg.outputs[2].value,
});
af.constrained_witness.insert(af.multi_scalar_mul_constraints.back().out_point_x);
af.constrained_witness.insert(af.multi_scalar_mul_constraints.back().out_point_y);
af.constrained_witness.insert(af.multi_scalar_mul_constraints.back().out_point_is_infinite);
af.original_opcode_indices.multi_scalar_mul_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::EmbeddedCurveAdd>) {
auto input_1_x = parse_input(arg.input1[0]);
Expand All @@ -437,6 +492,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.result_y = arg.outputs[1].value,
.result_infinite = arg.outputs[2].value,
});
af.constrained_witness.insert(af.ec_add_constraints.back().result_x);
af.constrained_witness.insert(af.ec_add_constraints.back().result_y);
af.constrained_witness.insert(af.ec_add_constraints.back().result_infinite);
af.original_opcode_indices.ec_add_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Keccak256>) {
auto input_var_message_size = get_witness_from_function_input(arg.var_message_size);
Expand All @@ -452,12 +510,18 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.result = map(arg.outputs, [](auto& e) { return e.value; }),
.var_message_size = input_var_message_size,
});
for (auto& output : af.keccak_constraints.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.keccak_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Keccakf1600>) {
af.keccak_permutations.push_back(Keccakf1600{
.state = map(arg.inputs, [](auto& e) { return parse_input(e); }),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.keccak_permutations.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.keccak_permutations.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::RecursiveAggregation>) {

Expand Down Expand Up @@ -509,6 +573,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.input = arg.input,
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
for (auto& output : af.bigint_to_le_bytes_constraints.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.bigint_to_le_bytes_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::BigIntAdd>) {
af.bigint_operations.push_back(BigIntOperation{
Expand Down Expand Up @@ -548,6 +615,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
.result = map(arg.outputs, [](auto& e) { return e.value; }),
.len = arg.len,
});
for (auto& output : af.poseidon2_constraints.back().result) {
af.constrained_witness.insert(output);
}
af.original_opcode_indices.poseidon2_constraints.push_back(opcode_index);
}
},
Expand Down

0 comments on commit e3ea298

Please sign in to comment.