Skip to content

Commit

Permalink
fix: Ecdsa Malleability Bug (AztecProtocol/barretenberg#512)
Browse files Browse the repository at this point in the history
Co-authored-by: Rumata888 <[email protected]>
  • Loading branch information
suyash67 and Rumata888 authored Jun 26, 2023
1 parent a39390a commit 0452563
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 5 deletions.
28 changes: 23 additions & 5 deletions barretenberg/cpp/src/barretenberg/crypto/ecdsa/ecdsa_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ signature construct_signature(const std::string& message, const key_pair<Fr, G1>
Fr r_fr = Fr::serialize_from_buffer(&sig.r[0]);
Fr s_fr = (z + r_fr * account.private_key) / k;

Fr::serialize_to_buffer(s_fr, &sig.s[0]);
// Ensure that the value of s is "low", i.e. s := min{ s_fr, (|Fr| - s_fr) }
const bool is_s_low = (uint256_t(s_fr) < (uint256_t(Fr::modulus) / 2));
uint256_t s_uint256 = is_s_low ? uint256_t(s_fr) : (uint256_t(Fr::modulus) - uint256_t(s_fr));

Fr::serialize_to_buffer(Fr(s_uint256), &sig.s[0]);

// compute recovery_id: given R = (x, y)
// 0: y is even && x < |Fr|
Expand All @@ -39,9 +43,10 @@ signature construct_signature(const std::string& message, const key_pair<Fr, G1>
Fq r_fq = Fq(R.x);
bool is_r_finite = (uint256_t(r_fq) == uint256_t(r_fr));
bool y_parity = uint256_t(R.y).get_bit(0);
bool recovery_bit = y_parity ^ is_s_low;
constexpr uint8_t offset = 27;

int value = offset + y_parity + static_cast<uint8_t>(2) * !is_r_finite;
int value = offset + recovery_bit + static_cast<uint8_t>(2) * !is_r_finite;
ASSERT(value <= UINT8_MAX);
sig.v = static_cast<uint8_t>(value);
return sig;
Expand All @@ -54,6 +59,7 @@ typename G1::affine_element recover_public_key(const std::string& message, const
uint256_t r_uint;
uint256_t s_uint;
uint8_t v_uint;
uint256_t mod = uint256_t(Fr::modulus);

const auto* r_buf = &sig.r[0];
const auto* s_buf = &sig.s[0];
Expand All @@ -63,13 +69,18 @@ typename G1::affine_element recover_public_key(const std::string& message, const
read(v_buf, v_uint);

// We need to check that r and s are in Field according to specification
if ((r_uint >= Fr::modulus) || (s_uint >= Fr::modulus)) {
if ((r_uint >= mod) || (s_uint >= mod)) {
throw_or_abort("r or s value exceeds the modulus");
}
if ((r_uint == 0) || (s_uint == 0)) {
throw_or_abort("r or s value is zero");
}

// Check that the s value is less than |Fr| / 2
if (s_uint * 2 > mod) {
throw_or_abort("s value is not less than curve order by 2");
}

// Check that v must either be in {27, 28, 29, 30}
Fr r = Fr(r_uint);
Fr s = Fr(s_uint);
Expand All @@ -94,7 +105,7 @@ typename G1::affine_element recover_public_key(const std::string& message, const

// Negate the y-coordinate point of R based on the parity of v
bool y_parity_R = uint256_t(point_R.y).get_bit(0);
if ((v_uint & 1) == y_parity_R) {
if ((v_uint & 1) ^ y_parity_R) {
point_R.y = -point_R.y;
}

Expand All @@ -119,6 +130,7 @@ bool verify_signature(const std::string& message, const typename G1::affine_elem
using serialize::read;
uint256_t r_uint;
uint256_t s_uint;
uint256_t mod = uint256_t(Fr::modulus);
if (!public_key.on_curve()) {
return false;
}
Expand All @@ -127,12 +139,18 @@ bool verify_signature(const std::string& message, const typename G1::affine_elem
read(r_buf, r_uint);
read(s_buf, s_uint);
// We need to check that r and s are in Field according to specification
if ((r_uint >= Fr::modulus) || (s_uint >= Fr::modulus)) {
if ((r_uint >= mod) || (s_uint >= mod)) {
return false;
}
if ((r_uint == 0) || (s_uint == 0)) {
return false;
}

// Check that the s value is less than |Fr| / 2
if (s_uint * 2 > mod) {
throw_or_abort("s value is not less than curve order by 2");
}

Fr r = Fr(r_uint);
Fr s = Fr(s_uint);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ bool_t<Composer> verify_signature(const stdlib::byte_array<Composer>& message,
r.assert_is_not_equal(Fr::zero());
s.assert_is_not_equal(Fr::zero());

// s should be less than |Fr| / 2
// Read more about this at: https://www.derpturkey.com/inherent-malleability-of-ecdsa-signatures/amp/
s.assert_less_than((Fr::modulus + 1) / 2);

Fr u1 = z / s;
Fr u2 = r / s;

Expand Down Expand Up @@ -143,6 +147,10 @@ bool_t<Composer> verify_signature_prehashed_message_noassert(const stdlib::byte_
r.assert_is_not_equal(Fr::zero());
s.assert_is_not_equal(Fr::zero());

// s should be less than |Fr| / 2
// Read more about this at: https://www.derpturkey.com/inherent-malleability-of-ecdsa-signatures/amp/
s.assert_less_than((Fr::modulus + 1) / 2);

Fr u1 = z / s;
Fr u2 = r / s;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ template <typename Composer, typename T> class bigfield {
bigfield conditional_select(const bigfield& other, const bool_t<Composer>& predicate) const;

void assert_is_in_field() const;
void assert_less_than(const uint256_t upper_limit) const;
void assert_equal(const bigfield& other) const;
void assert_is_not_equal(const bigfield& other) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,53 @@ template <typename Composer> class stdlib_bigfield : public testing::Test {
EXPECT_EQ(result, true);
}

static void test_assert_less_than_success()
{
auto composer = Composer();
size_t num_repetitions = 10;
constexpr size_t num_bits = 200;
constexpr uint256_t bit_mask = (uint256_t(1) << num_bits) - 1;
for (size_t i = 0; i < num_repetitions; ++i) {

fq inputs[4]{ uint256_t(fq::random_element()) && bit_mask,
uint256_t(fq::random_element()) && bit_mask,
uint256_t(fq::random_element()) && bit_mask,
uint256_t(fq::random_element()) && bit_mask };

fq_ct a(witness_ct(&composer, fr(uint256_t(inputs[0]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&composer,
fr(uint256_t(inputs[0]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
fq_ct b(witness_ct(&composer, fr(uint256_t(inputs[1]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&composer,
fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));

fq_ct c = a;
fq expected = inputs[0];
for (size_t i = 0; i < 16; ++i) {
c = b * b + c;
expected = inputs[1] * inputs[1] + expected;
}
// fq_ct c = a + a + a + a - b - b - b - b;
c.assert_less_than(bit_mask + 1);
uint256_t result = (c.get_value().lo);
EXPECT_EQ(result, uint256_t(expected));
EXPECT_EQ(c.get_value().get_msb() < num_bits, true);
}
bool result = composer.check_circuit();
EXPECT_EQ(result, true);
// Checking edge conditions
fq random_input = fq::random_element();
fq_ct a(witness_ct(&composer, fr(uint256_t(random_input).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&composer,
fr(uint256_t(random_input).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));

a.assert_less_than(random_input + 1);
EXPECT_EQ(composer.check_circuit(), true);

a.assert_less_than(random_input);
EXPECT_EQ(composer.check_circuit(), false);
}

static void test_byte_array_constructors()
{
auto composer = Composer();
Expand Down Expand Up @@ -866,6 +913,10 @@ TYPED_TEST(stdlib_bigfield, assert_is_in_field_succes)
{
TestFixture::test_assert_is_in_field_success();
}
TYPED_TEST(stdlib_bigfield, assert_less_than_success)
{
TestFixture::test_assert_less_than_success();
}
TYPED_TEST(stdlib_bigfield, byte_array_constructors)
{
TestFixture::test_byte_array_constructors();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,71 @@ template <typename C, typename T> void bigfield<C, T>::assert_is_in_field() cons
}
}

template <typename C, typename T> void bigfield<C, T>::assert_less_than(const uint256_t upper_limit) const
{
// TODO(kesha): Merge this with assert_is_in_field
// Warning: this assumes we have run circuit construction at least once in debug mode where large non reduced
// constants are allowed via ASSERT
if (is_constant()) {
return;
}
ASSERT(upper_limit != 0);
// The circuit checks that limit - this >= 0, so if we are doing a less_than comparison, we need to subtract 1 from
// the limit
uint256_t strict_upper_limit = upper_limit - uint256_t(1);
self_reduce(); // this method in particular enforces limb vals are <2^b - needed for logic described above
uint256_t value = get_value().lo;

const uint256_t upper_limit_value_0 = strict_upper_limit.slice(0, NUM_LIMB_BITS);
const uint256_t upper_limit_value_1 = strict_upper_limit.slice(NUM_LIMB_BITS, NUM_LIMB_BITS * 2);
const uint256_t upper_limit_value_2 = strict_upper_limit.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 3);
const uint256_t upper_limit_value_3 = strict_upper_limit.slice(NUM_LIMB_BITS * 3, NUM_LIMB_BITS * 4);

bool borrow_0_value = value.slice(0, NUM_LIMB_BITS) > upper_limit_value_0;
bool borrow_1_value =
(value.slice(NUM_LIMB_BITS, NUM_LIMB_BITS * 2) + uint256_t(borrow_0_value)) > (upper_limit_value_1);
bool borrow_2_value =
(value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 3) + uint256_t(borrow_1_value)) > (upper_limit_value_2);

field_t<C> upper_limit_0(context, upper_limit_value_0);
field_t<C> upper_limit_1(context, upper_limit_value_1);
field_t<C> upper_limit_2(context, upper_limit_value_2);
field_t<C> upper_limit_3(context, upper_limit_value_3);
bool_t<C> borrow_0(witness_t<C>(context, borrow_0_value));
bool_t<C> borrow_1(witness_t<C>(context, borrow_1_value));
bool_t<C> borrow_2(witness_t<C>(context, borrow_2_value));
// The way we use borrows here ensures that we are checking that upper_limit - binary_basis > 0.
// We check that the result in each limb is > 0.
// If the modulus part in this limb is smaller, we simply borrow the value from the higher limb.
// The prover can rearrange the borrows the way they like. The important thing is that the borrows are
// constrained.
field_t<C> r0 = upper_limit_0 - binary_basis_limbs[0].element + static_cast<field_t<C>>(borrow_0) * shift_1;
field_t<C> r1 = upper_limit_1 - binary_basis_limbs[1].element + static_cast<field_t<C>>(borrow_1) * shift_1 -
static_cast<field_t<C>>(borrow_0);
field_t<C> r2 = upper_limit_2 - binary_basis_limbs[2].element + static_cast<field_t<C>>(borrow_2) * shift_1 -
static_cast<field_t<C>>(borrow_1);
field_t<C> r3 = upper_limit_3 - binary_basis_limbs[3].element - static_cast<field_t<C>>(borrow_2);
r0 = r0.normalize();
r1 = r1.normalize();
r2 = r2.normalize();
r3 = r3.normalize();
if constexpr (C::type == ComposerType::PLOOKUP) {
context->decompose_into_default_range(r0.witness_index, static_cast<size_t>(NUM_LIMB_BITS));
context->decompose_into_default_range(r1.witness_index, static_cast<size_t>(NUM_LIMB_BITS));
context->decompose_into_default_range(r2.witness_index, static_cast<size_t>(NUM_LIMB_BITS));
context->decompose_into_default_range(r3.witness_index, static_cast<size_t>(NUM_LIMB_BITS));
} else {
context->decompose_into_base4_accumulators(
r0.witness_index, static_cast<size_t>(NUM_LIMB_BITS), "bigfield: assert_less_than range constraint 1.");
context->decompose_into_base4_accumulators(
r1.witness_index, static_cast<size_t>(NUM_LIMB_BITS), "bigfield: assert_less_than range constraint 2.");
context->decompose_into_base4_accumulators(
r2.witness_index, static_cast<size_t>(NUM_LIMB_BITS), "bigfield: assert_less_than range constraint 3.");
context->decompose_into_base4_accumulators(
r3.witness_index, static_cast<size_t>(NUM_LIMB_BITS), "bigfield: assert_less_than range constraint 4.");
}
}

// check elements are equal mod p by proving their integer difference is a multiple of p.
// This relies on the minus operator for a-b increasing a by a multiple of p large enough so diff is non-negative
template <typename C, typename T> void bigfield<C, T>::assert_equal(const bigfield& other) const
Expand Down

0 comments on commit 0452563

Please sign in to comment.