From e3ea298fd1f7326199e6e35b3523aadb2b12a925 Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:49:53 +0200 Subject: [PATCH] chore: add more cases for assert_equal conversion (#8446) Transform arithmetic gate of the kind a==b into a copy constraint between a and b, as long as a or b is already constrained. In that case, we mark both a and b as constrained. --- .../dsl/acir_format/acir_format.cpp | 9 ++- .../dsl/acir_format/acir_format.hpp | 2 + .../acir_format/acir_to_constraint_buf.cpp | 78 ++++++++++++++++++- 3 files changed, 83 insertions(+), 6 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp index ef8ee0eebc7..42faf8bf36b 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp @@ -7,6 +7,7 @@ #include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp" #include "proof_surgeon.hpp" #include +#include namespace acir_format { @@ -55,7 +56,11 @@ void build_constraints(Builder& builder, // Add range constraint for (size_t i = 0; i < constraint_system.range_constraints.size(); ++i) { const auto& constraint = constraint_system.range_constraints.at(i); - builder.create_range_constraint(constraint.witness, constraint.num_bits, ""); + uint32_t range = constraint.num_bits; + if (constraint_system.minimal_range.contains(constraint.witness)) { + range = constraint_system.minimal_range[constraint.witness]; + } + builder.create_range_constraint(constraint.witness, range, ""); gate_counter.track_diff(constraint_system.gates_per_opcode, constraint_system.original_opcode_indices.range_constraints.at(i)); } @@ -212,10 +217,10 @@ void build_constraints(Builder& builder, gate_counter.track_diff(constraint_system.gates_per_opcode, constraint_system.original_opcode_indices.bigint_to_le_bytes_constraints.at(i)); } + // assert equals for (size_t i = 0; i < constraint_system.assert_equalities.size(); ++i) { const auto& constraint = constraint_system.assert_equalities.at(i); - builder.assert_equal(constraint.a, constraint.b); gate_counter.track_diff(constraint_system.gates_per_opcode, constraint_system.original_opcode_indices.assert_equalities.at(i)); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.hpp index ef80b59020e..7e72e80e318 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.hpp @@ -25,6 +25,7 @@ #include "recursion_constraint.hpp" #include "schnorr_verify.hpp" #include "sha256_constraint.hpp" +#include #include #include @@ -125,6 +126,7 @@ struct AcirFormat { // Set of constrained witnesses std::set constrained_witness = {}; + std::map minimal_range = {}; // Indices of the original opcode that originated each constraint in AcirFormat. AcirFormatOriginalOpcodeIndices original_opcode_indices; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp index c42ba62db64..245509e6b5c 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp @@ -193,7 +193,8 @@ std::pair is_assert_equal(Program::Opcode::AssertZero const& return { 0, 0 }; } if (pt.q_l == -pt.q_r && pt.q_l != bb::fr::zero() && pt.q_c == bb::fr::zero()) { - if (af.constrained_witness.contains(pt.a) && af.constrained_witness.contains(pt.b)) { + // we require that one of the 2 witnesses to be constrained in an arithmetic gate + if (af.constrained_witness.contains(pt.a) || af.constrained_witness.contains(pt.b)) { return { pt.a, pt.b }; } } @@ -210,13 +211,39 @@ void handle_arithmetic(Program::Opcode::AssertZero const& arg, AcirFormat& af, s uint32_t w2 = std::get<1>(assert_equal); if (w1 != 0) { if (w1 != w2) { + if (!af.constrained_witness.contains(pt.a)) { + // we mark it as constrained because it is going to be asserted to be equal to a constrained one. + af.constrained_witness.insert(pt.a); + // swap the witnesses so that the first one is always properly constrained. + auto tmp = pt.a; + pt.a = pt.b; + pt.b = tmp; + } + if (!af.constrained_witness.contains(pt.b)) { + // we mark it as constrained because it is going to be asserted to be equal to a constrained one. + af.constrained_witness.insert(pt.b); + } + // minimal_range of a witness is the smallest range of the witness and the witness that are + // 'assert_equal' to it + if (af.minimal_range.contains(pt.b) && af.minimal_range.contains(pt.a)) { + if (af.minimal_range[pt.a] < af.minimal_range[pt.b]) { + af.minimal_range[pt.a] = af.minimal_range[pt.b]; + } else { + af.minimal_range[pt.b] = af.minimal_range[pt.a]; + } + } else if (af.minimal_range.contains(pt.b)) { + af.minimal_range[pt.a] = af.minimal_range[pt.b]; + } else if (af.minimal_range.contains(pt.a)) { + af.minimal_range[pt.b] = af.minimal_range[pt.a]; + } + af.assert_equalities.push_back(pt); af.original_opcode_indices.assert_equalities.push_back(opcode_index); } return; } // Even if the number of linear terms is less than 3, we might not be able to fit it into a width-3 arithmetic - // gate. This is the case if the linear terms are all disctinct witness from the multiplication term. In that + // gate. This is the case if the linear terms are all distinct witness from the multiplication term. In that // case, the serialize_arithmetic_gate() function will return a poly_triple with all 0's, and we use a width-4 // gate instead. We could probably always use a width-4 gate in fact. if (pt == poly_triple{ 0, 0, 0, 0, 0, 0, 0, 0 }) { @@ -288,6 +315,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .num_bits = arg.lhs.num_bits, .is_xor_gate = false, }); + af.constrained_witness.insert(af.logic_constraints.back().result); af.original_opcode_indices.logic_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { auto lhs_input = parse_input(arg.lhs); @@ -299,6 +327,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .num_bits = arg.lhs.num_bits, .is_xor_gate = true, }); + af.constrained_witness.insert(af.logic_constraints.back().result); af.original_opcode_indices.logic_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { auto witness_input = get_witness_from_function_input(arg.input); @@ -307,7 +336,13 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .num_bits = arg.input.num_bits, }); af.original_opcode_indices.range_constraints.push_back(opcode_index); - + if (af.minimal_range.contains(witness_input)) { + if (af.minimal_range[witness_input] > arg.input.num_bits) { + af.minimal_range[witness_input] = arg.input.num_bits; + } + } else { + af.minimal_range[witness_input] = arg.input.num_bits; + } } else if constexpr (std::is_same_v) { af.aes128_constraints.push_back(AES128Constraint{ .inputs = map(arg.inputs, [](auto& e) { return parse_input(e); }), @@ -315,6 +350,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .key = map(arg.key, [](auto& e) { return parse_input(e); }), .outputs = map(arg.outputs, [](auto& e) { return e.value; }), }); + for (auto& output : af.aes128_constraints.back().outputs) { + af.constrained_witness.insert(output); + } af.original_opcode_indices.aes128_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { @@ -337,6 +375,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .hash_values = map(arg.hash_values, [](auto& e) { return parse_input(e); }), .result = map(arg.outputs, [](auto& e) { return e.value; }), }); + for (auto& output : af.sha256_compression.back().result) { + af.constrained_witness.insert(output); + } af.original_opcode_indices.sha256_compression.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.blake2s_constraints.push_back(Blake2sConstraint{ @@ -349,6 +390,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, }), .result = map(arg.outputs, [](auto& e) { return e.value; }), }); + for (auto& output : af.blake2s_constraints.back().result) { + af.constrained_witness.insert(output); + } af.original_opcode_indices.blake2s_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.blake3_constraints.push_back(Blake3Constraint{ @@ -361,6 +405,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, }), .result = map(arg.outputs, [](auto& e) { return e.value; }), }); + for (auto& output : af.blake3_constraints.back().result) { + af.constrained_witness.insert(output); + } af.original_opcode_indices.blake3_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { auto input_pkey_x = get_witness_from_function_input(arg.public_key_x); @@ -373,14 +420,16 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .signature = map(arg.signature, [](auto& e) { return get_witness_from_function_input(e); }), }); af.original_opcode_indices.schnorr_constraints.push_back(opcode_index); + af.constrained_witness.insert(af.schnorr_constraints.back().result); } else if constexpr (std::is_same_v) { - af.pedersen_constraints.push_back(PedersenConstraint{ .scalars = map(arg.inputs, [](auto& e) { return get_witness_from_function_input(e); }), .hash_index = arg.domain_separator, .result_x = arg.outputs[0].value, .result_y = arg.outputs[1].value, }); + af.constrained_witness.insert(af.pedersen_constraints.back().result_x); + af.constrained_witness.insert(af.pedersen_constraints.back().result_y); af.original_opcode_indices.pedersen_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.pedersen_hash_constraints.push_back(PedersenHashConstraint{ @@ -388,6 +437,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .hash_index = arg.domain_separator, .result = arg.output.value, }); + af.constrained_witness.insert(af.pedersen_hash_constraints.back().result); af.original_opcode_indices.pedersen_hash_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.ecdsa_k1_constraints.push_back(EcdsaSecp256k1Constraint{ @@ -398,6 +448,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .pub_y_indices = map(arg.public_key_y, [](auto& e) { return get_witness_from_function_input(e); }), .result = arg.output.value, }); + af.constrained_witness.insert(af.ecdsa_k1_constraints.back().result); af.original_opcode_indices.ecdsa_k1_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.ecdsa_r1_constraints.push_back(EcdsaSecp256r1Constraint{ @@ -408,6 +459,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .result = arg.output.value, .signature = map(arg.signature, [](auto& e) { return get_witness_from_function_input(e); }), }); + af.constrained_witness.insert(af.ecdsa_r1_constraints.back().result); af.original_opcode_indices.ecdsa_r1_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.multi_scalar_mul_constraints.push_back(MultiScalarMul{ @@ -417,6 +469,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .out_point_y = arg.outputs[1].value, .out_point_is_infinite = arg.outputs[2].value, }); + af.constrained_witness.insert(af.multi_scalar_mul_constraints.back().out_point_x); + af.constrained_witness.insert(af.multi_scalar_mul_constraints.back().out_point_y); + af.constrained_witness.insert(af.multi_scalar_mul_constraints.back().out_point_is_infinite); af.original_opcode_indices.multi_scalar_mul_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { auto input_1_x = parse_input(arg.input1[0]); @@ -437,6 +492,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .result_y = arg.outputs[1].value, .result_infinite = arg.outputs[2].value, }); + af.constrained_witness.insert(af.ec_add_constraints.back().result_x); + af.constrained_witness.insert(af.ec_add_constraints.back().result_y); + af.constrained_witness.insert(af.ec_add_constraints.back().result_infinite); af.original_opcode_indices.ec_add_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { auto input_var_message_size = get_witness_from_function_input(arg.var_message_size); @@ -452,12 +510,18 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .result = map(arg.outputs, [](auto& e) { return e.value; }), .var_message_size = input_var_message_size, }); + for (auto& output : af.keccak_constraints.back().result) { + af.constrained_witness.insert(output); + } af.original_opcode_indices.keccak_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.keccak_permutations.push_back(Keccakf1600{ .state = map(arg.inputs, [](auto& e) { return parse_input(e); }), .result = map(arg.outputs, [](auto& e) { return e.value; }), }); + for (auto& output : af.keccak_permutations.back().result) { + af.constrained_witness.insert(output); + } af.original_opcode_indices.keccak_permutations.push_back(opcode_index); } else if constexpr (std::is_same_v) { @@ -509,6 +573,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .input = arg.input, .result = map(arg.outputs, [](auto& e) { return e.value; }), }); + for (auto& output : af.bigint_to_le_bytes_constraints.back().result) { + af.constrained_witness.insert(output); + } af.original_opcode_indices.bigint_to_le_bytes_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.bigint_operations.push_back(BigIntOperation{ @@ -548,6 +615,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .result = map(arg.outputs, [](auto& e) { return e.value; }), .len = arg.len, }); + for (auto& output : af.poseidon2_constraints.back().result) { + af.constrained_witness.insert(output); + } af.original_opcode_indices.poseidon2_constraints.push_back(opcode_index); } },