From 50b4a728b4c20503f6ab56c07feaa29d767cec10 Mon Sep 17 00:00:00 2001 From: Lucas Xia Date: Thu, 11 Jan 2024 18:18:48 -0500 Subject: [PATCH] feat: Poseidon2 stdlib impl (#3551) Poseidon2 permutation and sponge function stdlib implementation that follows native crypto/ implementation. Adds hash_buffer function to native and stdlib poseidon2 implementations. Updates CI tests with poseidon2 tests, stdlib_pedersen_hash tests. Adds poseidon2 end gate. Resolves https://github.com/AztecProtocol/barretenberg/issues/776 --- barretenberg/cpp/scripts/bb-tests.sh | 1 + barretenberg/cpp/scripts/stdlib-tests | 2 + .../crypto/pedersen_hash/pedersen.cpp | 12 +- .../crypto/poseidon2/poseidon2.bench.cpp | 3 +- .../crypto/poseidon2/poseidon2.cpp | 48 ++++ .../crypto/poseidon2/poseidon2.hpp | 14 +- .../crypto/poseidon2/poseidon2.test.cpp | 24 +- .../poseidon2/poseidon2_permutation.hpp | 11 +- .../arithmetization/gate_data.hpp | 14 +- .../goblin_ultra_circuit_builder.cpp | 36 +++ .../goblin_ultra_circuit_builder.hpp | 1 + .../barretenberg/stdlib/hash/CMakeLists.txt | 3 +- .../stdlib/hash/pedersen/pedersen.cpp | 19 +- .../stdlib/hash/pedersen/pedersen.hpp | 8 +- .../stdlib/hash/poseidon2/CMakeLists.txt | 1 + .../stdlib/hash/poseidon2/poseidon2.cpp | 46 ++++ .../stdlib/hash/poseidon2/poseidon2.hpp | 35 +++ .../stdlib/hash/poseidon2/poseidon2.test.cpp | 190 ++++++++++++++++ .../hash/poseidon2/poseidon2_permutation.cpp | 208 ++++++++++++++++++ .../hash/poseidon2/poseidon2_permutation.hpp | 67 ++++++ .../stdlib/hash/poseidon2/sponge/sponge.hpp | 182 +++++++++++++++ 21 files changed, 892 insertions(+), 33 deletions(-) create mode 100644 barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.cpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/CMakeLists.txt create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.cpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.hpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.test.cpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2_permutation.cpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2_permutation.hpp create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/sponge/sponge.hpp diff --git a/barretenberg/cpp/scripts/bb-tests.sh b/barretenberg/cpp/scripts/bb-tests.sh index 934817d08a2..a4ba5b39417 100755 --- a/barretenberg/cpp/scripts/bb-tests.sh +++ b/barretenberg/cpp/scripts/bb-tests.sh @@ -18,6 +18,7 @@ TESTS=( crypto_ecdsa_tests crypto_pedersen_commitment_tests crypto_pedersen_hash_tests + crypto_poseidon2_tests crypto_schnorr_tests crypto_sha256_tests dsl_tests diff --git a/barretenberg/cpp/scripts/stdlib-tests b/barretenberg/cpp/scripts/stdlib-tests index 4f61b0c2f19..6cb94611294 100644 --- a/barretenberg/cpp/scripts/stdlib-tests +++ b/barretenberg/cpp/scripts/stdlib-tests @@ -5,5 +5,7 @@ stdlib_blake3s_tests stdlib_ecdsa_tests stdlib_merkle_tree_tests stdlib_pedersen_commitment_tests +stdlib_pedersen_hash_tests +stdlib_poseidon2_tests stdlib_schnorr_tests stdlib_sha256_tests \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.cpp b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.cpp index f0f03378e80..9e702eccb2f 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.cpp @@ -30,16 +30,14 @@ std::vector pedersen_hash_base::convert_buffer }; std::vector elements; - for (size_t i = 0; i < num_elements; ++i) { - size_t bytes_to_slice = 0; - if (i == num_elements - 1) { - bytes_to_slice = num_bytes - (i * bytes_per_element); - } else { - bytes_to_slice = bytes_per_element; - } + for (size_t i = 0; i < num_elements - 1; ++i) { + size_t bytes_to_slice = bytes_per_element; Fq element = slice(input, i * bytes_per_element, bytes_to_slice); elements.emplace_back(element); } + size_t bytes_to_slice = num_bytes - ((num_elements - 1) * bytes_per_element); + Fq element = slice(input, (num_elements - 1) * bytes_per_element, bytes_to_slice); + elements.emplace_back(element); return elements; } diff --git a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.bench.cpp b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.bench.cpp index 603238bf6e8..6b1b1457997 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.bench.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.bench.cpp @@ -10,9 +10,8 @@ grumpkin::fq poseidon_function(const size_t count) for (size_t i = 0; i < count; ++i) { inputs[i] = grumpkin::fq::random_element(); } - std::span tmp(inputs); // hash count many field elements - inputs[0] = crypto::Poseidon2::hash(tmp); + inputs[0] = crypto::Poseidon2::hash(inputs); return inputs[0]; } diff --git a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.cpp b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.cpp new file mode 100644 index 00000000000..ad96dc31271 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.cpp @@ -0,0 +1,48 @@ +#include "poseidon2.hpp" + +namespace crypto { +/** + * @brief Hashes a vector of field elements + */ +template +typename Poseidon2::FF Poseidon2::hash(const std::vector::FF>& input) +{ + auto input_span = input; + return Sponge::hash_fixed_length(input_span); +} + +/** + * @brief Hashes vector of bytes by chunking it into 31 byte field elements and calling hash() + * @details Slice function cuts out the required number of bytes from the byte vector + */ +template +typename Poseidon2::FF Poseidon2::hash_buffer(const std::vector& input) +{ + const size_t num_bytes = input.size(); + const size_t bytes_per_element = 31; + size_t num_elements = static_cast(num_bytes % bytes_per_element != 0) + (num_bytes / bytes_per_element); + + const auto slice = [](const std::vector& data, const size_t start, const size_t slice_size) { + uint256_t result(0); + for (size_t i = 0; i < slice_size; ++i) { + result = (result << uint256_t(8)); + result += uint256_t(data[i + start]); + } + return FF(result); + }; + + std::vector converted; + for (size_t i = 0; i < num_elements - 1; ++i) { + size_t bytes_to_slice = bytes_per_element; + FF element = slice(input, i * bytes_per_element, bytes_to_slice); + converted.emplace_back(element); + } + size_t bytes_to_slice = num_bytes - ((num_elements - 1) * bytes_per_element); + FF element = slice(input, (num_elements - 1) * bytes_per_element, bytes_to_slice); + converted.emplace_back(element); + + return hash(converted); +} + +template class Poseidon2; +} // namespace crypto \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.hpp b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.hpp index 15488e2d0b3..6969c680e17 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.hpp @@ -10,7 +10,19 @@ template class Poseidon2 { public: using FF = typename Params::FF; + // We choose our rate to be t-1 and capacity to be 1. using Sponge = FieldSponge>; - static FF hash(std::span input) { return Sponge::hash_fixed_length(input); } + + /** + * @brief Hashes a vector of field elements + */ + static FF hash(const std::vector& input); + /** + * @brief Hashes vector of bytes by chunking it into 31 byte field elements and calling hash() + * @details Slice function cuts out the required number of bytes from the byte vector + */ + static FF hash_buffer(const std::vector& input); }; + +extern template class Poseidon2; } // namespace crypto \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.test.cpp b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.test.cpp index 33757efd2b1..9649f757728 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.test.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2.test.cpp @@ -10,7 +10,7 @@ auto& engine = numeric::random::get_debug_engine(); } namespace poseidon2_tests { -TEST(Poseidon2, BasicTests) +TEST(Poseidon2, HashBasicTests) { barretenberg::fr a = barretenberg::fr::random_element(&engine); @@ -32,17 +32,33 @@ TEST(Poseidon2, BasicTests) // N.B. these hardcoded values were extracted from the algorithm being tested. These are NOT independent test vectors! // TODO(@zac-williamson #3132): find independent test vectors we can compare against! (very hard to find given // flexibility of Poseidon's parametrisation) -TEST(Poseidon2, ConsistencyCheck) +TEST(Poseidon2, HashConsistencyCheck) { barretenberg::fr a(std::string("9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); barretenberg::fr b(std::string("9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); barretenberg::fr c(std::string("0x9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); barretenberg::fr d(std::string("0x9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); - std::array input{ a, b, c, d }; + std::vector input{ a, b, c, d }; auto result = crypto::Poseidon2::hash(input); - barretenberg::fr expected(std::string("0x150c19ae11b3290c137c7a4d760d9482a6581d731535f560c3601d6a766b0937")); + barretenberg::fr expected(std::string("0x2f43a0f83b51a6f5fc839dea0ecec74947637802a579fa9841930a25a0bcec11")); + + EXPECT_EQ(result, expected); +} + +TEST(Poseidon2, HashBufferConsistencyCheck) +{ + // 31 byte inputs because hash_buffer slicing is only injective with 31 bytes, as it slices 31 bytes for each field + // element + barretenberg::fr a(std::string("00000b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); + + auto input_vec = to_buffer(a); // takes field element and converts it to 32 bytes + input_vec.erase(input_vec.begin()); // erase first byte since we want 31 bytes + std::vector input{ a }; + auto expected = crypto::Poseidon2::hash(input); + + barretenberg::fr result = crypto::Poseidon2::hash_buffer(input_vec); EXPECT_EQ(result, expected); } diff --git a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp index 9e3931cdac3..4f0794b893c 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp @@ -7,7 +7,6 @@ #include #include #include -#include namespace crypto { @@ -123,6 +122,13 @@ template class Poseidon2Permutation { } } + /** + * @brief Native form of Poseidon2 permutation from https://eprint.iacr.org/2023/323. + * @details The permutation consists of one initial linear layer, then a set of external rounds, a set of internal + * rounds, and a set of external rounds. + * @param input + * @return constexpr State + */ static constexpr State permutation(const State& input) { // deep copy @@ -131,6 +137,7 @@ template class Poseidon2Permutation { // Apply 1st linear layer matrix_multiplication_external(current_state); + // First set of external rounds constexpr size_t rounds_f_beginning = rounds_f / 2; for (size_t i = 0; i < rounds_f_beginning; ++i) { add_round_constants(current_state, round_constants[i]); @@ -138,6 +145,7 @@ template class Poseidon2Permutation { matrix_multiplication_external(current_state); } + // Internal rounds const size_t p_end = rounds_f_beginning + rounds_p; for (size_t i = rounds_f_beginning; i < p_end; ++i) { current_state[0] += round_constants[i][0]; @@ -145,6 +153,7 @@ template class Poseidon2Permutation { matrix_multiplication_internal(current_state); } + // Remaining external rounds for (size_t i = p_end; i < NUM_ROUNDS; ++i) { add_round_constants(current_state, round_constants[i]); apply_sbox(current_state); diff --git a/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/gate_data.hpp b/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/gate_data.hpp index 55afb319b0d..1efca4931e7 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/gate_data.hpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/arithmetization/gate_data.hpp @@ -138,19 +138,29 @@ template struct databus_lookup_gate_ { uint32_t value; }; +/* External gate data for poseidon2 external round*/ template struct poseidon2_external_gate_ { uint32_t a; uint32_t b; uint32_t c; uint32_t d; - uint32_t round_idx; + size_t round_idx; }; +/* Internal gate data for poseidon2 internal round*/ template struct poseidon2_internal_gate_ { uint32_t a; uint32_t b; uint32_t c; uint32_t d; - uint32_t round_idx; + size_t round_idx; +}; + +/* Last gate for poseidon2, needed because poseidon2 gates compare against the shifted wires. */ +template struct poseidon2_end_gate_ { + uint32_t a; + uint32_t b; + uint32_t c; + uint32_t d; }; } // namespace proof_system diff --git a/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.cpp b/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.cpp index de95ff324cb..6c3911f6d05 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.cpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.cpp @@ -260,6 +260,9 @@ void GoblinUltraCircuitBuilder_::create_calldata_lookup_gate(const databus_l ++this->num_gates; } +/** + * @brief Poseidon2 external round gate, activates the q_poseidon2_external selector and relation + */ template void GoblinUltraCircuitBuilder_::create_poseidon2_external_gate(const poseidon2_external_gate_& in) { @@ -284,6 +287,9 @@ void GoblinUltraCircuitBuilder_::create_poseidon2_external_gate(const poseid ++this->num_gates; } +/** + * @brief Poseidon2 internal round gate, activates the q_poseidon2_internal selector and relation + */ template void GoblinUltraCircuitBuilder_::create_poseidon2_internal_gate(const poseidon2_internal_gate_& in) { @@ -308,6 +314,36 @@ void GoblinUltraCircuitBuilder_::create_poseidon2_internal_gate(const poseid ++this->num_gates; } +/** + * @brief Poseidon2 end round gate, needed because poseidon2 rounds compare with shifted wires + * @details The Poseidon2 permutation is 64 rounds, but needs to be a block of 65 rows, since the result of applying a + * round of Poseidon2 is stored in the next row (the shifted row). As a result, we need this end row to compare with the + * result from the 64th round of Poseidon2. Note that it does not activate any selectors since it only serves as a + * comparison through the shifted wires. + */ +template void GoblinUltraCircuitBuilder_::create_poseidon2_end_gate(const poseidon2_end_gate_& in) +{ + this->w_l().emplace_back(in.a); + this->w_r().emplace_back(in.b); + this->w_o().emplace_back(in.c); + this->w_4().emplace_back(in.d); + this->q_m().emplace_back(0); + this->q_1().emplace_back(0); + this->q_2().emplace_back(0); + this->q_3().emplace_back(0); + this->q_c().emplace_back(0); + this->q_arith().emplace_back(0); + this->q_4().emplace_back(0); + this->q_sort().emplace_back(0); + this->q_lookup_type().emplace_back(0); + this->q_elliptic().emplace_back(0); + this->q_aux().emplace_back(0); + this->q_busread().emplace_back(0); + this->q_poseidon2_external().emplace_back(0); + this->q_poseidon2_internal().emplace_back(0); + ++this->num_gates; +} + template inline FF GoblinUltraCircuitBuilder_::compute_poseidon2_external_identity(FF q_poseidon2_external_value, FF q_1_value, diff --git a/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.hpp b/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.hpp index ac6d25c2603..fb130d887a7 100644 --- a/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.hpp @@ -175,6 +175,7 @@ template class GoblinUltraCircuitBuilder_ : public UltraCircuitBui void create_poseidon2_external_gate(const poseidon2_external_gate_& in); void create_poseidon2_internal_gate(const poseidon2_internal_gate_& in); + void create_poseidon2_end_gate(const poseidon2_end_gate_& in); FF compute_poseidon2_external_identity(FF q_poseidon2_external_value, FF q_1_value, diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/CMakeLists.txt b/barretenberg/cpp/src/barretenberg/stdlib/hash/CMakeLists.txt index 2ce247a010a..0f20819bea6 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/hash/CMakeLists.txt +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/CMakeLists.txt @@ -3,4 +3,5 @@ add_subdirectory(blake3s) add_subdirectory(pedersen) add_subdirectory(sha256) add_subdirectory(keccak) -add_subdirectory(benchmarks) \ No newline at end of file +add_subdirectory(benchmarks) +add_subdirectory(poseidon2) \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.cpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.cpp index 9f424a9cdf9..efd40983af9 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.cpp @@ -6,7 +6,7 @@ using namespace barretenberg; using namespace proof_system; template -field_t pedersen_hash::hash(const std::vector& inputs, const GeneratorContext context) +field_t pedersen_hash::hash(const std::vector& inputs, const GeneratorContext context) { using cycle_scalar = typename cycle_group::cycle_scalar; using Curve = EmbeddedCurve; @@ -15,7 +15,7 @@ field_t pedersen_hash::hash(const std::vector& inputs, const Gene std::vector scalars; std::vector points; - scalars.emplace_back(cycle_scalar::create_from_bn254_scalar(field_t(inputs.size()))); + scalars.emplace_back(cycle_scalar::create_from_bn254_scalar(field_ct(inputs.size()))); points.emplace_back(crypto::pedersen_hash_base::length_generator); for (size_t i = 0; i < inputs.size(); ++i) { scalars.emplace_back(cycle_scalar::create_from_bn254_scalar(inputs[i])); @@ -28,7 +28,7 @@ field_t pedersen_hash::hash(const std::vector& inputs, const Gene } template -field_t pedersen_hash::hash_skip_field_validation(const std::vector& inputs, +field_t pedersen_hash::hash_skip_field_validation(const std::vector& inputs, const GeneratorContext context) { using cycle_scalar = typename cycle_group::cycle_scalar; @@ -38,7 +38,7 @@ field_t pedersen_hash::hash_skip_field_validation(const std::vector scalars; std::vector points; - scalars.emplace_back(cycle_scalar::create_from_bn254_scalar(field_t(inputs.size()))); + scalars.emplace_back(cycle_scalar::create_from_bn254_scalar(field_ct(inputs.size()))); points.emplace_back(crypto::pedersen_hash_base::length_generator); for (size_t i = 0; i < inputs.size(); ++i) { // `true` param = skip primality test when performing a scalar mul @@ -52,7 +52,7 @@ field_t pedersen_hash::hash_skip_field_validation(const std::vector pedersen_hash::hash_buffer(const stdlib::byte_array& input, Gen const size_t bytes_per_element = 31; size_t num_elements = static_cast(num_bytes % bytes_per_element != 0) + (num_bytes / bytes_per_element); - std::vector elements; + std::vector elements; for (size_t i = 0; i < num_elements; ++i) { size_t bytes_to_slice = 0; if (i == num_elements - 1) { @@ -72,13 +72,10 @@ field_t pedersen_hash::hash_buffer(const stdlib::byte_array& input, Gen } else { bytes_to_slice = bytes_per_element; } - auto element = static_cast(input.slice(i * bytes_per_element, bytes_to_slice)); + auto element = static_cast(input.slice(i * bytes_per_element, bytes_to_slice)); elements.emplace_back(element); } - for (auto& x : elements) { - std::cout << x << std::endl; - } - field_t hashed; + field_ct hashed; if (elements.size() < 2) { hashed = hash(elements, context); } else { diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.hpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.hpp index 26b2b484925..7babfd0d285 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/pedersen/pedersen.hpp @@ -18,17 +18,17 @@ using namespace barretenberg; template class pedersen_hash { private: - using field_t = stdlib::field_t; + using field_ct = stdlib::field_t; using bool_t = stdlib::bool_t; using EmbeddedCurve = typename cycle_group::Curve; using GeneratorContext = crypto::GeneratorContext; using cycle_group = stdlib::cycle_group; public: - static field_t hash(const std::vector& in, GeneratorContext context = {}); + static field_ct hash(const std::vector& in, GeneratorContext context = {}); // TODO health warnings! - static field_t hash_skip_field_validation(const std::vector& in, GeneratorContext context = {}); - static field_t hash_buffer(const stdlib::byte_array& input, GeneratorContext context = {}); + static field_ct hash_skip_field_validation(const std::vector& in, GeneratorContext context = {}); + static field_ct hash_buffer(const stdlib::byte_array& input, GeneratorContext context = {}); }; EXTERN_STDLIB_TYPE(pedersen_hash); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/CMakeLists.txt b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/CMakeLists.txt new file mode 100644 index 00000000000..6869b55d9bb --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/CMakeLists.txt @@ -0,0 +1 @@ +barretenberg_module(stdlib_poseidon2 stdlib_primitives crypto_poseidon2) \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.cpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.cpp new file mode 100644 index 00000000000..ff8e31dd56e --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.cpp @@ -0,0 +1,46 @@ +#include "barretenberg/stdlib/hash/poseidon2/poseidon2.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" +namespace proof_system::plonk::stdlib { + +using namespace barretenberg; +using namespace proof_system; + +/** + * @brief Hash a vector of field_ct. + */ +template field_t poseidon2::hash(C& builder, const std::vector& inputs) +{ + + /* Run the sponge by absorbing all the input and squeezing one output. + * This should just call the sponge variable length hash function + * + */ + auto input{ inputs }; + return Sponge::hash_fixed_length(builder, input); +} + +/** + * @brief Hash a byte_array. + */ +template field_t poseidon2::hash_buffer(C& builder, const stdlib::byte_array& input) +{ + const size_t num_bytes = input.size(); + const size_t bytes_per_element = 31; // 31 bytes in a fr element + size_t num_elements = static_cast(num_bytes % bytes_per_element != 0) + (num_bytes / bytes_per_element); + + std::vector elements; + for (size_t i = 0; i < num_elements; ++i) { + size_t bytes_to_slice = 0; + if (i == num_elements - 1) { + bytes_to_slice = num_bytes - (i * bytes_per_element); + } else { + bytes_to_slice = bytes_per_element; + } + auto element = static_cast(input.slice(i * bytes_per_element, bytes_to_slice)); + elements.emplace_back(element); + } + return hash(builder, elements); +} +template class poseidon2; + +} // namespace proof_system::plonk::stdlib diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.hpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.hpp new file mode 100644 index 00000000000..fcfb93f91d2 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.hpp @@ -0,0 +1,35 @@ +#pragma once +#include "barretenberg/crypto/poseidon2/poseidon2_params.hpp" +#include "barretenberg/stdlib/hash/poseidon2/sponge/sponge.hpp" +#include "barretenberg/stdlib/primitives/byte_array/byte_array.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" + +#include "../../primitives/circuit_builders/circuit_builders.hpp" + +namespace proof_system::plonk::stdlib { + +using namespace barretenberg; +/** + * @brief stdlib class that evaluates in-circuit poseidon2 hashes, consistent with behavior in + * crypto::poseidon2 + * + * @tparam Builder + */ +template class poseidon2 { + + private: + using field_ct = stdlib::field_t; + using bool_ct = stdlib::bool_t; + using Params = crypto::Poseidon2Bn254ScalarFieldParams; + using Permutation = Poseidon2Permutation; + // We choose our rate to be t-1 and capacity to be 1. + using Sponge = FieldSponge; + + public: + static field_ct hash(Builder& builder, const std::vector& in); + static field_ct hash_buffer(Builder& builder, const stdlib::byte_array& input); +}; + +extern template class poseidon2; + +} // namespace proof_system::plonk::stdlib diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.test.cpp new file mode 100644 index 00000000000..797cec6669c --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2.test.cpp @@ -0,0 +1,190 @@ +#include "poseidon2.hpp" +#include "barretenberg/common/test.hpp" +#include "barretenberg/crypto/poseidon2/poseidon2.hpp" +#include "barretenberg/numeric/random/engine.hpp" +#include "barretenberg/stdlib/primitives/curves/bn254.hpp" + +namespace test_StdlibPoseidon2 { +using namespace barretenberg; +using namespace proof_system::plonk; +namespace { +auto& engine = numeric::random::get_debug_engine(); +} + +template class StdlibPoseidon2 : public testing::Test { + using _curve = stdlib::bn254; + + using byte_array_ct = typename _curve::byte_array_ct; + using fr_ct = typename _curve::ScalarField; + using witness_ct = typename _curve::witness_ct; + using public_witness_ct = typename _curve::public_witness_ct; + using poseidon2 = typename stdlib::poseidon2; + using native_poseidon2 = crypto::Poseidon2; + + public: + /** + * @brief Call poseidon2 on a vector of inputs + * + * @param num_inputs + */ + static void test_hash(size_t num_inputs) + { + using field_ct = stdlib::field_t; + using witness_ct = stdlib::witness_t; + auto builder = Builder(); + + std::vector inputs; + std::vector inputs_native; + + for (size_t i = 0; i < num_inputs; ++i) { + const auto element = fr::random_element(&engine); + inputs_native.emplace_back(element); + inputs.emplace_back(field_ct(witness_ct(&builder, element))); + } + + auto result = stdlib::poseidon2::hash(builder, inputs); + auto expected = crypto::Poseidon2::hash(inputs_native); + + EXPECT_EQ(result.get_value(), expected); + + bool proof_result = builder.check_circuit(); + EXPECT_EQ(proof_result, true); + } + + /** + * @brief Call poseidon2 on two inputs repeatedly. + * + * @param num_inputs + */ + static void test_hash_repeated_pairs(size_t num_inputs) + { + Builder builder; + + fr left_in = fr::random_element(); + fr right_in = fr::random_element(); + + fr_ct left = witness_ct(&builder, left_in); + fr_ct right = witness_ct(&builder, right_in); + + // num_inputs - 1 iterations since the first hash hashes two elements + for (size_t i = 0; i < num_inputs - 1; ++i) { + left = poseidon2::hash(builder, { left, right }); + } + + builder.set_public_input(left.witness_index); + + info("num gates = ", builder.get_num_gates()); + + bool result = builder.check_circuit(); + EXPECT_EQ(result, true); + } + /** + * @brief Call poseidon2 hash_buffer on a vector of bytes + * + * @param num_input_bytes + */ + static void test_hash_byte_array(size_t num_input_bytes) + { + Builder builder; + + std::vector input; + input.reserve(num_input_bytes); + for (size_t i = 0; i < num_input_bytes; ++i) { + input.push_back(engine.get_random_uint8()); + } + + fr expected = native_poseidon2::hash_buffer(input); + + byte_array_ct circuit_input(&builder, input); + auto result = poseidon2::hash_buffer(builder, circuit_input); + + EXPECT_EQ(result.get_value(), expected); + + info("num gates = ", builder.get_num_gates()); + + bool proof_result = builder.check_circuit(); + EXPECT_EQ(proof_result, true); + } + + static void test_hash_zeros(size_t num_inputs) + { + Builder builder; + + std::vector inputs; + inputs.reserve(num_inputs); + std::vector> witness_inputs; + + for (size_t i = 0; i < num_inputs; ++i) { + inputs.emplace_back(0); + witness_inputs.emplace_back(witness_ct(&builder, inputs[i])); + } + + fr expected = native_poseidon2::hash(inputs); + auto result = poseidon2::hash(builder, witness_inputs); + + EXPECT_EQ(result.get_value(), expected); + } + + static void test_hash_constants() + { + Builder builder; + + std::vector inputs; + std::vector> witness_inputs; + + for (size_t i = 0; i < 8; ++i) { + inputs.push_back(barretenberg::fr::random_element()); + if (i % 2 == 1) { + witness_inputs.push_back(witness_ct(&builder, inputs[i])); + } else { + witness_inputs.push_back(fr_ct(&builder, inputs[i])); + } + } + + barretenberg::fr expected = native_poseidon2::hash(inputs); + auto result = poseidon2::hash(builder, witness_inputs); + + EXPECT_EQ(result.get_value(), expected); + } +}; + +using CircuitTypes = testing::Types; + +TYPED_TEST_SUITE(StdlibPoseidon2, CircuitTypes); + +TYPED_TEST(StdlibPoseidon2, TestHashZeros) +{ + TestFixture::test_hash_zeros(8); +}; + +TYPED_TEST(StdlibPoseidon2, TestHashSmall) +{ + TestFixture::test_hash(10); +} + +TYPED_TEST(StdlibPoseidon2, TestHashLarge) +{ + TestFixture::test_hash(1000); +} + +TYPED_TEST(StdlibPoseidon2, TestHashRepeatedPairs) +{ + TestFixture::test_hash_repeated_pairs(256); +} + +TYPED_TEST(StdlibPoseidon2, TestHashByteArraySmall) +{ + TestFixture::test_hash_byte_array(351); +}; + +TYPED_TEST(StdlibPoseidon2, TestHashByteArrayLarge) +{ + TestFixture::test_hash_byte_array(31000); +}; + +TYPED_TEST(StdlibPoseidon2, TestHashConstants) +{ + TestFixture::test_hash_constants(); +}; + +} // namespace test_StdlibPoseidon2 diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2_permutation.cpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2_permutation.cpp new file mode 100644 index 00000000000..fc8a8cd300d --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2_permutation.cpp @@ -0,0 +1,208 @@ +#include "poseidon2_permutation.hpp" + +#include "barretenberg/proof_system/arithmetization/gate_data.hpp" +#include "barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.hpp" + +namespace proof_system::plonk::stdlib { + +/** + * @brief Circuit form of Poseidon2 permutation from https://eprint.iacr.org/2023/323. + * @details The permutation consists of one initial linear layer, then a set of external rounds, a set of internal + * rounds, and a set of external rounds. + * @param builder + * @param input + * @return State + */ +template +typename Poseidon2Permutation::State Poseidon2Permutation::permutation( + Builder* builder, const typename Poseidon2Permutation::State& input) +{ + // deep copy + State current_state(input); + NativeState current_native_state; + for (size_t i = 0; i < t; ++i) { + current_native_state[i] = current_state[i].get_value(); + } + + // Apply 1st linear layer + NativePermutation::matrix_multiplication_external(current_native_state); + initial_external_matrix_multiplication(builder, current_state); + + // First set of external rounds + constexpr size_t rounds_f_beginning = rounds_f / 2; + for (size_t i = 0; i < rounds_f_beginning; ++i) { + poseidon2_external_gate_ in{ current_state[0].witness_index, + current_state[1].witness_index, + current_state[2].witness_index, + current_state[3].witness_index, + i }; + builder->create_poseidon2_external_gate(in); + // calculate the new witnesses + NativePermutation::add_round_constants(current_native_state, round_constants[i]); + NativePermutation::apply_sbox(current_native_state); + NativePermutation::matrix_multiplication_external(current_native_state); + for (size_t j = 0; j < t; ++j) { + current_state[j] = witness_t(builder, current_native_state[j]); + } + } + + // Internal rounds + const size_t p_end = rounds_f_beginning + rounds_p; + for (size_t i = rounds_f_beginning; i < p_end; ++i) { + poseidon2_internal_gate_ in{ current_state[0].witness_index, + current_state[1].witness_index, + current_state[2].witness_index, + current_state[3].witness_index, + i }; + builder->create_poseidon2_internal_gate(in); + current_native_state[0] += round_constants[i][0]; + NativePermutation::apply_single_sbox(current_native_state[0]); + NativePermutation::matrix_multiplication_internal(current_native_state); + for (size_t j = 0; j < t; ++j) { + current_state[j] = witness_t(builder, current_native_state[j]); + } + } + + // Remaining external rounds + for (size_t i = p_end; i < NUM_ROUNDS; ++i) { + poseidon2_external_gate_ in{ current_state[0].witness_index, + current_state[1].witness_index, + current_state[2].witness_index, + current_state[3].witness_index, + i }; + builder->create_poseidon2_external_gate(in); + // calculate the new witnesses + NativePermutation::add_round_constants(current_native_state, round_constants[i]); + NativePermutation::apply_sbox(current_native_state); + NativePermutation::matrix_multiplication_external(current_native_state); + for (size_t j = 0; j < t; ++j) { + current_state[j] = witness_t(builder, current_native_state[j]); + } + } + // need to add an extra row here to ensure that things check out, more details found in poseidon2_end_gate_ + // definition + poseidon2_end_gate_ in{ + current_state[0].witness_index, + current_state[1].witness_index, + current_state[2].witness_index, + current_state[3].witness_index, + }; + builder->create_poseidon2_end_gate(in); + return current_state; +} + +/** + * @brief Separate function to do just the first linear layer (equivalent to external matrix mul). + * @details We use 6 arithmetic gates to implement: + * gate 1: Compute tmp1 = state[0] + state[1] + 2 * state[3] + * gate 2: Compute tmp2 = 2 * state[1] + state[2] + state[3] + * gate 3: Compute v2 = 4 * state[0] + 4 * state[1] + tmp2 + * gate 4: Compute v1 = v2 + tmp1 + * gate 5: Compute v4 = tmp1 + 4 * state[2] + 4 * state[3] + * gate 6: Compute v3 = v4 + tmp2 + * output state is [v1, v2, v3, v4] + * @param builder + * @param state + */ +template +void Poseidon2Permutation::initial_external_matrix_multiplication( + Builder* builder, typename Poseidon2Permutation::State& state) +{ + // create the 6 gates for the initial matrix multiplication + // gate 1: Compute tmp1 = state[0] + state[1] + 2 * state[3] + field_t tmp1 = + witness_t(builder, state[0].get_value() + state[1].get_value() + FF(2) * state[3].get_value()); + builder->create_big_add_gate({ + .a = state[0].witness_index, + .b = state[1].witness_index, + .c = state[3].witness_index, + .d = tmp1.witness_index, + .a_scaling = 1, + .b_scaling = 1, + .c_scaling = 2, + .d_scaling = -1, + .const_scaling = 0, + }); + + // gate 2: Compute tmp2 = 2 * state[1] + state[2] + state[3] + field_t tmp2 = + witness_t(builder, FF(2) * state[1].get_value() + state[2].get_value() + state[3].get_value()); + builder->create_big_add_gate({ + .a = state[1].witness_index, + .b = state[2].witness_index, + .c = state[3].witness_index, + .d = tmp2.witness_index, + .a_scaling = 2, + .b_scaling = 1, + .c_scaling = 1, + .d_scaling = -1, + .const_scaling = 0, + }); + + // gate 3: Compute v2 = 4 * state[0] + 4 * state[1] + tmp2 + field_t v2 = + witness_t(builder, FF(4) * state[0].get_value() + FF(4) * state[1].get_value() + tmp2.get_value()); + builder->create_big_add_gate({ + .a = state[0].witness_index, + .b = state[1].witness_index, + .c = tmp2.witness_index, + .d = v2.witness_index, + .a_scaling = 4, + .b_scaling = 4, + .c_scaling = 1, + .d_scaling = -1, + .const_scaling = 0, + }); + + // gate 4: Compute v1 = v2 + tmp1 + field_t v1 = witness_t(builder, v2.get_value() + tmp1.get_value()); + builder->create_big_add_gate({ + .a = v2.witness_index, + .b = tmp1.witness_index, + .c = v1.witness_index, + .d = builder->zero_idx, + .a_scaling = 1, + .b_scaling = 1, + .c_scaling = -1, + .d_scaling = 0, + .const_scaling = 0, + }); + + // gate 5: Compute v4 = tmp1 + 4 * state[2] + 4 * state[3] + field_t v4 = + witness_t(builder, tmp1.get_value() + FF(4) * state[2].get_value() + FF(4) * state[3].get_value()); + builder->create_big_add_gate({ + .a = tmp1.witness_index, + .b = state[2].witness_index, + .c = state[3].witness_index, + .d = v4.witness_index, + .a_scaling = 1, + .b_scaling = 4, + .c_scaling = 4, + .d_scaling = -1, + .const_scaling = 0, + }); + + // gate 6: Compute v3 = v4 + tmp2 + field_t v3 = witness_t(builder, v4.get_value() + tmp2.get_value()); + builder->create_big_add_gate({ + .a = v4.witness_index, + .b = tmp2.witness_index, + .c = v3.witness_index, + .d = builder->zero_idx, + .a_scaling = 1, + .b_scaling = 1, + .c_scaling = -1, + .d_scaling = 0, + .const_scaling = 0, + }); + + state[0] = v1; + state[1] = v2; + state[2] = v3; + state[3] = v4; +} + +template class Poseidon2Permutation; + +} // namespace proof_system::plonk::stdlib \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2_permutation.hpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2_permutation.hpp new file mode 100644 index 00000000000..08391b7dafe --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/poseidon2_permutation.hpp @@ -0,0 +1,67 @@ +#pragma once +#include +#include +#include + +#include "barretenberg/crypto/poseidon2/poseidon2_permutation.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" + +namespace proof_system::plonk::stdlib { + +using namespace proof_system; +template class Poseidon2Permutation { + public: + using NativePermutation = crypto::Poseidon2Permutation; + // t = sponge permutation size (in field elements) + // t = rate + capacity + // capacity = 1 field element (256 bits) + // rate = number of field elements that can be compressed per permutation + static constexpr size_t t = Params::t; + // d = degree of s-box polynomials. For a given field, `d` is the smallest element of `p` such that gdc(d, p - 1) = + // 1 (excluding 1) For bn254/grumpkin, d = 5 + static constexpr size_t d = Params::d; + // sbox size = number of bits in p + static constexpr size_t sbox_size = Params::sbox_size; + // number of full sbox rounds + static constexpr size_t rounds_f = Params::rounds_f; + // number of partial sbox rounds + static constexpr size_t rounds_p = Params::rounds_p; + static constexpr size_t NUM_ROUNDS = Params::rounds_f + Params::rounds_p; + + using FF = typename Params::FF; + using State = std::array, t>; + using NativeState = std::array; + + using RoundConstants = std::array; + using RoundConstantsContainer = std::array; + static constexpr RoundConstantsContainer round_constants = Params::round_constants; + + /** + * @brief Circuit form of Poseidon2 permutation from https://eprint.iacr.org/2023/323. + * @details The permutation consists of one initial linear layer, then a set of external rounds, a set of internal + * rounds, and a set of external rounds. + * @param builder + * @param input + * @return State + */ + static State permutation(Builder* builder, const State& input); + + /** + * @brief Separate function to do just the first linear layer (equivalent to external matrix mul). + * @details We use 6 arithmetic gates to implement: + * gate 1: Compute tmp1 = state[0] + state[1] + 2 * state[3] + * gate 2: Compute tmp2 = 2 * state[1] + state[2] + state[3] + * gate 3: Compute v2 = 4 * state[0] + 4 * state[1] + tmp2 + * gate 4: Compute v1 = v2 + tmp1 + * gate 5: Compute v4 = tmp1 + 4 * state[2] + 4 * state[3] + * gate 6: Compute v3 = v4 + tmp2 + * output state is [v1, v2, v3, v4] + * @param builder + * @param state + */ + static void initial_external_matrix_multiplication(Builder* builder, State& state); +}; + +extern template class Poseidon2Permutation; + +} // namespace proof_system::plonk::stdlib \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/sponge/sponge.hpp b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/sponge/sponge.hpp new file mode 100644 index 00000000000..359faf50f94 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/hash/poseidon2/sponge/sponge.hpp @@ -0,0 +1,182 @@ +#pragma once + +#include +#include +#include +#include + +#include "barretenberg/numeric/uint256/uint256.hpp" +#include "barretenberg/stdlib/hash/poseidon2/poseidon2_permutation.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" + +namespace proof_system::plonk::stdlib { + +/** + * @brief Implements the circuit form of a cryptographic sponge over prime fields. + * Implements the sponge specification from the Community Cryptographic Specification Project + * see https://github.com/C2SP/C2SP/blob/792c1254124f625d459bfe34417e8f6bdd02eb28/poseidon-sponge.md + * (Note: this spec was not accepted into the C2SP repo, we might want to reference something else!) + * + * Note: If we ever use this sponge class for more than 1 hash functions, we should move this out of `poseidon2` + * and into its own directory + * @tparam field_t + * @tparam rate + * @tparam capacity + * @tparam t + * @tparam Permutation + */ +template class FieldSponge { + public: + /** + * @brief Defines what phase of the sponge algorithm we are in. + * + * ABSORB: 'absorbing' field elements into the sponge + * SQUEEZE: compressing the sponge and extracting a field element + * + */ + enum Mode { + ABSORB, + SQUEEZE, + }; + using field_t = stdlib::field_t; + + // sponge state. t = rate + capacity. capacity = 1 field element (~256 bits) + std::array state; + + // cached elements that have been absorbed. + std::array cache; + size_t cache_size = 0; + Mode mode = Mode::ABSORB; + Builder* builder; + + FieldSponge(Builder& builder_, field_t domain_iv = 0) + : builder(&builder_) + { + for (size_t i = 0; i < rate; ++i) { + state[i] = witness_t(builder, 0); + } + state[rate] = witness_t(builder, domain_iv.get_value()); + } + + std::array perform_duplex() + { + // zero-pad the cache + for (size_t i = cache_size; i < rate; ++i) { + cache[i] = witness_t(builder, 0); + } + // add the cache into sponge state + for (size_t i = 0; i < rate; ++i) { + state[i] += cache[i]; + } + state = Permutation::permutation(builder, state); + // return `rate` number of field elements from the sponge state. + std::array output; + for (size_t i = 0; i < rate; ++i) { + output[i] = state[i]; + } + return output; + } + + void absorb(const field_t& input) + { + if (mode == Mode::ABSORB && cache_size == rate) { + // If we're absorbing, and the cache is full, apply the sponge permutation to compress the cache + perform_duplex(); + cache[0] = input; + cache_size = 1; + } else if (mode == Mode::ABSORB && cache_size < rate) { + // If we're absorbing, and the cache is not full, add the input into the cache + cache[cache_size] = input; + cache_size += 1; + } else if (mode == Mode::SQUEEZE) { + // If we're in squeeze mode, switch to absorb mode and add the input into the cache. + // N.B. I don't think this code path can be reached?! + cache[0] = input; + cache_size = 1; + mode = Mode::ABSORB; + } + } + + field_t squeeze() + { + if (mode == Mode::SQUEEZE && cache_size == 0) { + // If we're in squeze mode and the cache is empty, there is nothing left to squeeze out of the sponge! + // Switch to absorb mode. + mode = Mode::ABSORB; + cache_size = 0; + } + if (mode == Mode::ABSORB) { + // If we're in absorb mode, apply sponge permutation to compress the cache, populate cache with compressed + // state and switch to squeeze mode. Note: this code block will execute if the previous `if` condition was + // matched + auto new_output_elements = perform_duplex(); + mode = Mode::SQUEEZE; + for (size_t i = 0; i < rate; ++i) { + cache[i] = new_output_elements[i]; + } + cache_size = rate; + } + // By this point, we should have a non-empty cache. Pop one item off the top of the cache and return it. + field_t result = cache[0]; + for (size_t i = 1; i < cache_size; ++i) { + cache[i - 1] = cache[i]; + } + cache_size -= 1; + cache[cache_size] = witness_t(builder, 0); + return result; + } + + /** + * @brief Use the sponge to hash an input string + * + * @tparam out_len + * @tparam is_variable_length. Distinguishes between hashes where the preimage length is constant/not constant + * @param input + * @return std::array + */ + template + static std::array hash_internal(Builder& builder, std::span input) + { + size_t in_len = input.size(); + const uint256_t iv = (static_cast(in_len) << 64) + out_len - 1; + FieldSponge sponge(builder, iv); + + for (size_t i = 0; i < in_len; ++i) { + sponge.absorb(input[i]); + } + + // In the case where the hash preimage is variable-length, we append `1` to the end of the input, to distinguish + // from fixed-length hashes. (the combination of this additional field element + the hash IV ensures + // fixed-length and variable-length hashes do not collide) + if constexpr (is_variable_length) { + sponge.absorb(1); + } + + std::array output; + for (size_t i = 0; i < out_len; ++i) { + output[i] = sponge.squeeze(); + } + return output; + } + + template + static std::array hash_fixed_length(Builder& builder, std::span input) + { + return hash_internal(builder, input); + } + static field_t hash_fixed_length(Builder& builder, std::span input) + { + return hash_fixed_length<1>(builder, input)[0]; + } + + template + static std::array hash_variable_length(Builder& builder, std::span input) + { + return hash_internal(builder, input); + } + static field_t hash_variable_length(Builder& builder, std::span input) + { + return hash_variable_length<1>(builder, input)[0]; + } +}; +} // namespace proof_system::plonk::stdlib \ No newline at end of file