From 41783bd18ac37e537ca19bb0e1e87b53473b55d9 Mon Sep 17 00:00:00 2001 From: Suyash Bagad Date: Wed, 29 Mar 2023 17:57:48 +0100 Subject: [PATCH] `array_push` for Generic Type (#291) * test wip * fix tests. * Make `compress_native` static in recursive vk. * Add comments. --- .../stdlib/primitives/address/address.hpp | 2 + .../stdlib/primitives/field/array.hpp | 24 ++++++ .../stdlib/primitives/field/field.test.cpp | 84 +++++++++++++++++++ .../verification_key/verification_key.hpp | 4 +- 4 files changed, 112 insertions(+), 2 deletions(-) diff --git a/cpp/src/barretenberg/stdlib/primitives/address/address.hpp b/cpp/src/barretenberg/stdlib/primitives/address/address.hpp index 49587baa70..f50c890f0c 100644 --- a/cpp/src/barretenberg/stdlib/primitives/address/address.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/address/address.hpp @@ -102,6 +102,8 @@ template class address_t { bool_t operator==(const address_t& other) const { return this->to_field() == other.to_field(); } + bool_t operator!=(const address_t& other) const { return this->to_field() != other.to_field(); } + field_t to_field() const { return address_; } fr get_value() const { return address_.get_value(); }; diff --git a/cpp/src/barretenberg/stdlib/primitives/field/array.hpp b/cpp/src/barretenberg/stdlib/primitives/field/array.hpp index 5ed9ff32f7..724b65bce6 100644 --- a/cpp/src/barretenberg/stdlib/primitives/field/array.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/field/array.hpp @@ -61,6 +61,10 @@ void array_push(std::array, SIZE>& arr, field_t cons already_pushed.assert_equal(true, "array_push cannot push to a full array"); }; +/** + * Note: this assumes `0` always means 'not used', so be careful. If you actually want `0` to be counted, you'll need + * something else. + */ template inline size_t array_push(std::array>, SIZE>& arr, field_t const& value) { @@ -73,6 +77,10 @@ inline size_t array_push(std::array>, SIZE>& arr throw_or_abort("array_push cannot push to a full array"); }; +/** + * Note: this assumes `0` always means 'not used', so be careful. If you actually want `0` to be counted, you'll need + * something else. + */ template inline size_t array_push(std::array, SIZE>& arr, std::shared_ptr const& value) { @@ -85,6 +93,22 @@ inline size_t array_push(std::array, SIZE>& arr, std::shared_ throw_or_abort("array_push cannot push to a full array"); }; +/** + * Note: this assumes `0` always means 'not used', so be careful. If you actually want `0` to be counted, you'll need + * something else. + */ +template inline void array_push(std::array& arr, T const& value) +{ + bool_t already_pushed = false; + for (size_t i = 0; i < arr.size(); ++i) { + bool_t is_zero = arr[i].is_empty(); + arr[i].conditional_select(!already_pushed && is_zero, value); + + already_pushed |= is_zero; + } + already_pushed.assert_equal(true, "array_push cannot push to a full array"); +}; + /** * Note: this assumes `0` always means 'not used', so be careful. If you actually want `0` to be counted, you'll need * something else. diff --git a/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp b/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp index 3304b12dc2..d18d133872 100644 --- a/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp +++ b/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp @@ -3,6 +3,7 @@ #include "array.hpp" #include "barretenberg/plonk/proof_system/constants.hpp" #include +#include #include "barretenberg/honk/composer/standard_honk_composer.hpp" #include "barretenberg/plonk/composer/standard_composer.hpp" #include "barretenberg/plonk/composer/ultra_composer.hpp" @@ -1235,6 +1236,81 @@ template class stdlib_field : public testing::Test { EXPECT_EQ(composer.failed(), true); EXPECT_EQ(composer.err(), "push_array_to_array target array capacity exceeded"); } + + class MockClass { + public: + MockClass() + : m_a(field_ct(0)) + , m_b(field_ct(0)) + {} + MockClass(field_ct a, field_ct b) + : m_a(a) + , m_b(b) + {} + + bool_ct is_empty() const { return m_a == 0 && m_b == 0; } + + std::pair get_values() { return std::make_pair(m_a, m_b); } + + void conditional_select(bool_ct const& condition, MockClass const& other) + { + m_a = field_ct::conditional_assign(condition, other.m_a, m_a); + m_b = field_ct::conditional_assign(condition, other.m_b, m_b); + } + + private: + field_ct m_a; + field_ct m_b; + }; + + void test_array_push_generic() + { + Composer composer = Composer(); + + constexpr size_t SIZE = 5; + std::array arr{}; + + // Push values into the array + plonk::stdlib::array_push(arr, MockClass(witness_ct(&composer, 1), witness_ct(&composer, 10))); + plonk::stdlib::array_push(arr, MockClass(witness_ct(&composer, 2), witness_ct(&composer, 20))); + plonk::stdlib::array_push(arr, MockClass(witness_ct(&composer, 3), witness_ct(&composer, 30))); + + // Check the values in the array + EXPECT_EQ(arr[0].get_values().first.get_value(), 1); + EXPECT_EQ(arr[0].get_values().second.get_value(), 10); + EXPECT_EQ(arr[1].get_values().first.get_value(), 2); + EXPECT_EQ(arr[1].get_values().second.get_value(), 20); + EXPECT_EQ(arr[2].get_values().first.get_value(), 3); + EXPECT_EQ(arr[2].get_values().second.get_value(), 30); + + auto prover = composer.create_prover(); + auto verifier = composer.create_verifier(); + auto proof = prover.construct_proof(); + info("composer gates = ", composer.get_num_gates()); + bool proof_result = verifier.verify_proof(proof); + EXPECT_EQ(proof_result, true); + } + + void test_array_push_generic_full() + { + Composer composer = Composer(); + + constexpr size_t SIZE = 5; + std::array arr{}; + + // Push values into the array + plonk::stdlib::array_push(arr, MockClass(witness_ct(&composer, 1), witness_ct(&composer, 10))); + plonk::stdlib::array_push(arr, MockClass(witness_ct(&composer, 2), witness_ct(&composer, 20))); + plonk::stdlib::array_push(arr, MockClass(witness_ct(&composer, 3), witness_ct(&composer, 30))); + plonk::stdlib::array_push(arr, MockClass(witness_ct(&composer, 4), witness_ct(&composer, 40))); + plonk::stdlib::array_push(arr, MockClass(witness_ct(&composer, 5), witness_ct(&composer, 50))); + + // Try to push into a full array + plonk::stdlib::array_push(arr, MockClass(witness_ct(&composer, 6), witness_ct(&composer, 60))); + + EXPECT_EQ(composer.failed(), true); + EXPECT_EQ(composer.err(), "array_push cannot push to a full array"); + } }; typedef testing::Types @@ -1382,6 +1458,14 @@ TYPED_TEST(stdlib_field, test_array_push_optional) { TestFixture::test_array_push_optional(); } +TYPED_TEST(stdlib_field, test_array_push_generic) +{ + TestFixture::test_array_push_generic(); +} +TYPED_TEST(stdlib_field, test_array_push_generic_full) +{ + TestFixture::test_array_push_generic_full(); +} TYPED_TEST(stdlib_field, test_array_push_array_to_array) { TestFixture::test_push_array_to_array(); diff --git a/cpp/src/barretenberg/stdlib/recursion/verification_key/verification_key.hpp b/cpp/src/barretenberg/stdlib/recursion/verification_key/verification_key.hpp index e36fc3995a..d25d2930ef 100644 --- a/cpp/src/barretenberg/stdlib/recursion/verification_key/verification_key.hpp +++ b/cpp/src/barretenberg/stdlib/recursion/verification_key/verification_key.hpp @@ -187,7 +187,7 @@ template struct verification_key { } } - private: + public: field_t compress() { field_t compressed_domain = domain.compress(); @@ -215,7 +215,7 @@ template struct verification_key { return compressed_key; } - barretenberg::fr compress_native(const std::shared_ptr& key) + static barretenberg::fr compress_native(const std::shared_ptr& key) { barretenberg::fr compressed_domain = evaluation_domain::compress_native(key->domain);