Skip to content

Commit

Permalink
refactor: poseidon2 hash uses span instead of vector (AztecProtocol#4003
Browse files Browse the repository at this point in the history
)

Updating the hash() function in poseidon2 to take in a std::span instead
of a std::vector, and avoiding a needless copy.

Also updates poseidon2 hash_buffer test to check an inequivalent hash.
  • Loading branch information
lucasxia01 authored Feb 1, 2024
1 parent 0d682c7 commit 94b6800
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ namespace bb::crypto {
template <typename Params>
typename Poseidon2<Params>::FF Poseidon2<Params>::hash(const std::vector<typename Poseidon2<Params>::FF>& input)
{
auto input_span = input;
return Sponge::hash_fixed_length(input_span);
return Sponge::hash_fixed_length(input);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,15 @@ TEST(Poseidon2, HashBufferConsistencyCheck)
// element
fr a(std::string("00000b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789"));

auto input_vec = to_buffer(a); // takes field element and converts it to 32 bytes
// takes field element and converts it to 32 bytes
auto input_vec = to_buffer(a);
bb::fr result1 = crypto::Poseidon2<crypto::Poseidon2Bn254ScalarFieldParams>::hash_buffer(input_vec);
input_vec.erase(input_vec.begin()); // erase first byte since we want 31 bytes
fr result2 = crypto::Poseidon2<crypto::Poseidon2Bn254ScalarFieldParams>::hash_buffer(input_vec);

std::vector<fr> input{ a };
auto expected = crypto::Poseidon2<crypto::Poseidon2Bn254ScalarFieldParams>::hash(input);

fr result = crypto::Poseidon2<crypto::Poseidon2Bn254ScalarFieldParams>::hash_buffer(input_vec);

EXPECT_EQ(result, expected);
EXPECT_NE(result1, expected);
EXPECT_EQ(result2, expected);
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ template <typename FF, size_t rate, size_t capacity, size_t t, typename Permutat
* @param input
* @return std::array<FF, out_len>
*/
template <size_t out_len, bool is_variable_length> static std::array<FF, out_len> hash_internal(std::span<FF> input)
template <size_t out_len, bool is_variable_length>
static std::array<FF, out_len> hash_internal(std::span<const FF> input)
{
size_t in_len = input.size();
const uint256_t iv = (static_cast<uint256_t>(in_len) << 64) + out_len - 1;
Expand All @@ -153,11 +154,11 @@ template <typename FF, size_t rate, size_t capacity, size_t t, typename Permutat
return output;
}

template <size_t out_len> static std::array<FF, out_len> hash_fixed_length(std::span<FF> input)
template <size_t out_len> static std::array<FF, out_len> hash_fixed_length(std::span<const FF> input)
{
return hash_internal<out_len, false>(input);
}
static FF hash_fixed_length(std::span<FF> input) { return hash_fixed_length<1>(input)[0]; }
static FF hash_fixed_length(std::span<const FF> input) { return hash_fixed_length<1>(input)[0]; }

template <size_t out_len> static std::array<FF, out_len> hash_variable_length(std::span<FF> input)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ template <typename C> field_t<C> poseidon2<C>::hash(C& builder, const std::vecto
* This should just call the sponge variable length hash function
*
*/
auto input{ inputs };
return Sponge::hash_fixed_length(builder, input);
return Sponge::hash_fixed_length(builder, inputs);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ template <size_t rate, size_t capacity, size_t t, typename Permutation, typename
* @return std::array<field_t, out_len>
*/
template <size_t out_len, bool is_variable_length>
static std::array<field_t, out_len> hash_internal(Builder& builder, std::span<field_t> input)
static std::array<field_t, out_len> hash_internal(Builder& builder, std::span<const field_t> input)
{
size_t in_len = input.size();
const uint256_t iv = (static_cast<uint256_t>(in_len) << 64) + out_len - 1;
Expand All @@ -160,11 +160,11 @@ template <size_t rate, size_t capacity, size_t t, typename Permutation, typename
}

template <size_t out_len>
static std::array<field_t, out_len> hash_fixed_length(Builder& builder, std::span<field_t> input)
static std::array<field_t, out_len> hash_fixed_length(Builder& builder, std::span<const field_t> input)
{
return hash_internal<out_len, false>(builder, input);
}
static field_t hash_fixed_length(Builder& builder, std::span<field_t> input)
static field_t hash_fixed_length(Builder& builder, std::span<const field_t> input)
{
return hash_fixed_length<1>(builder, input)[0];
}
Expand Down

0 comments on commit 94b6800

Please sign in to comment.