Skip to content

Commit

Permalink
add test for goblin batch mul
Browse files Browse the repository at this point in the history
  • Loading branch information
ledwards2225 committed Aug 28, 2023
1 parent fcc9aa7 commit 3d95476
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,6 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
EXPECT_CIRCUIT_CORRECTNESS(composer);
}

// WORKTODO: add a test for goblin_batch_mul

static void test_batch_mul()
{
const size_t num_points = 5;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ namespace stdlib {

/**
* @brief Goblin style batch multiplication
* @note (Luke): The approach of having a distinct interface for goblin style group operations is limited/flawed. The
* natural alternative is to abstract the details away from the circuit writer and to simply allow the strategy to be
* determined by the type of circuit constructor (i.e. Goblin or not) from within biggroup. Currently, the goblin-style
* circuit builder functionality has been incorporated directly into the UltraCircuitBuilder, thus there is no
* means for distinction. If we decide it is preferable to support fully flexible goblin-style group operations via the
* existing biggroup API, we will need to make an independent GoblinUltraCircuitBuilder class (plausibly via inheritance
* from UltraCircuitBuilder) and implement Goblin-style strategies for each of the operations in biggroup.
*
* @details In goblin-style arithmetization, the operands (points/scalars) for each mul-accumulate operation are
* decomposed into smaller components and written to an operation queue via the builder. The components are also added
* as witness variables. This function adds constraints demonstrating the fidelity of the point/scalar decompositions
* given the indices of the components in the variables array. The actual mul-accumulate operations are performed
* natively (without constraints) under the hood, and the final result is obtained by queueing an equality operation via
* the builder. The components of the result are returned as indices into the variables array from which the resulting
* accumulator point is re-constructed.
*
* @tparam C CircuitBuilder
* @tparam Fq Base field
Expand Down Expand Up @@ -41,7 +42,7 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::goblin_batch_mul(const std::vector<
// Populate the goblin-style ecc op gates for the given mul inputs
auto op_tuple = builder->queue_ecc_mul_accum(point.get_value(), scalar.get_value());

// Constrain decomposition of point coordinates to reconstruct original values.
// Adds constraints demonstrating proper decomposition of point coordinates.
// Note: may need to do point.x.assert_is_in_field() prior to the assert_eq() according to Kesha.
auto x_lo = Fr::from_witness_index(builder, op_tuple.x_lo);
auto x_hi = Fr::from_witness_index(builder, op_tuple.x_hi);
Expand All @@ -54,7 +55,7 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::goblin_batch_mul(const std::vector<
point.x.assert_equal(point_x);
point.y.assert_equal(point_y);

// Constrain endomorphism scalars to reconstruct scalar
// Add constraints demonstrating proper decomposition of scalar into endomorphism scalars
auto z_1 = Fr::from_witness_index(builder, op_tuple.z_1);
auto z_2 = Fr::from_witness_index(builder, op_tuple.z_2);
auto beta = G::subgroup_field::cube_root_of_unity();
Expand All @@ -64,7 +65,7 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::goblin_batch_mul(const std::vector<
// Populate equality gates based on the internal accumulator point
auto op_tuple = builder->queue_ecc_eq();

// Reconstruct the result of the batch mul
// Reconstruct the result of the batch mul using indices into the variables array
auto x_lo = Fr::from_witness_index(builder, op_tuple.x_lo);
auto x_hi = Fr::from_witness_index(builder, op_tuple.x_hi);
auto y_lo = Fr::from_witness_index(builder, op_tuple.y_lo);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#include "barretenberg/common/test.hpp"
#include <type_traits>

#include "../biggroup/biggroup.hpp"
#include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders.hpp"

#include "barretenberg/stdlib/primitives/curves/bn254.hpp"

#include "barretenberg/numeric/random/engine.hpp"
#include <memory>

namespace test_stdlib_biggroup_goblin {
namespace {
auto& engine = numeric::random::get_debug_engine();
}

using namespace proof_system::plonk;

template <typename Curve> class stdlib_biggroup_goblin : public testing::Test {
using element_ct = typename Curve::Element;
using scalar_ct = typename Curve::ScalarField;

using fq = typename Curve::BaseFieldNative;
using fr = typename Curve::ScalarFieldNative;
using g1 = typename Curve::GroupNative;
using affine_element = typename g1::affine_element;
using element = typename g1::element;

using Builder = typename Curve::Builder;

static constexpr auto EXPECT_CIRCUIT_CORRECTNESS = [](Builder& builder, bool expected_result = true) {
info("builder gates = ", builder.get_num_gates());
EXPECT_EQ(builder.check_circuit(), expected_result);
};

public:
/**
* @brief Test goblin-style batch mul
* @details Check that 1) Goblin-style batch mul returns correct value, and 2) resulting circuit is correct
*
*/
static void test_goblin_style_batch_mul()
{
const bool goblin_flag = true; // used to indicate goblin-style in batch_mul
const size_t num_points = 5;
Builder builder;

std::vector<affine_element> points;
std::vector<fr> scalars;
for (size_t i = 0; i < num_points; ++i) {
points.push_back(affine_element(element::random_element()));
scalars.push_back(fr::random_element());
}

std::vector<element_ct> circuit_points;
std::vector<scalar_ct> circuit_scalars;
for (size_t i = 0; i < num_points; ++i) {
circuit_points.push_back(element_ct::from_witness(&builder, points[i]));
circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i]));
}

element_ct result_point = element_ct::template batch_mul<goblin_flag>(circuit_points, circuit_scalars);

element expected_point = g1::one;
expected_point.self_set_infinity();
for (size_t i = 0; i < num_points; ++i) {
expected_point += (element(points[i]) * scalars[i]);
}

expected_point = expected_point.normalize();
fq result_x(result_point.x.get_value().lo);
fq result_y(result_point.y.get_value().lo);

EXPECT_EQ(result_x, expected_point.x);
EXPECT_EQ(result_y, expected_point.y);

EXPECT_CIRCUIT_CORRECTNESS(builder);
}
};

using TestTypes = testing::Types<stdlib::bn254<proof_system::UltraCircuitBuilder>>;

TYPED_TEST_SUITE(stdlib_biggroup_goblin, TestTypes);

HEAVY_TYPED_TEST(stdlib_biggroup_goblin, batch_mul)
{
TestFixture::test_goblin_style_batch_mul();
}
} // namespace test_stdlib_biggroup_goblin

0 comments on commit 3d95476

Please sign in to comment.