Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Ecdsa Malleability Bug #512

Merged
merged 5 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions cpp/src/barretenberg/crypto/ecdsa/ecdsa_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ signature construct_signature(const std::string& message, const key_pair<Fr, G1>
Fr r_fr = Fr::serialize_from_buffer(&sig.r[0]);
Fr s_fr = (z + r_fr * account.private_key) / k;

Fr::serialize_to_buffer(s_fr, &sig.s[0]);
// Ensure that the value of s is "low", i.e. s := min{ s_fr, (|Fr| - s_fr) }
const bool is_s_low = (uint256_t(s_fr) < (uint256_t(Fr::modulus) / 2));
uint256_t s_uint256 = is_s_low ? uint256_t(s_fr) : (uint256_t(Fr::modulus) - uint256_t(s_fr));

Fr::serialize_to_buffer(Fr(s_uint256), &sig.s[0]);

// compute recovery_id: given R = (x, y)
// 0: y is even && x < |Fr|
Expand All @@ -39,9 +43,10 @@ signature construct_signature(const std::string& message, const key_pair<Fr, G1>
Fq r_fq = Fq(R.x);
bool is_r_finite = (uint256_t(r_fq) == uint256_t(r_fr));
bool y_parity = uint256_t(R.y).get_bit(0);
bool recovery_bit = y_parity ^ is_s_low;
constexpr uint8_t offset = 27;

int value = offset + y_parity + static_cast<uint8_t>(2) * !is_r_finite;
int value = offset + recovery_bit + static_cast<uint8_t>(2) * !is_r_finite;
ASSERT(value <= UINT8_MAX);
sig.v = static_cast<uint8_t>(value);
return sig;
Expand All @@ -54,6 +59,7 @@ typename G1::affine_element recover_public_key(const std::string& message, const
uint256_t r_uint;
uint256_t s_uint;
uint8_t v_uint;
uint256_t mod = uint256_t(Fr::modulus);

const auto* r_buf = &sig.r[0];
const auto* s_buf = &sig.s[0];
Expand All @@ -63,13 +69,18 @@ typename G1::affine_element recover_public_key(const std::string& message, const
read(v_buf, v_uint);

// We need to check that r and s are in Field according to specification
if ((r_uint >= Fr::modulus) || (s_uint >= Fr::modulus)) {
if ((r_uint >= mod) || (s_uint >= mod)) {
throw_or_abort("r or s value exceeds the modulus");
}
if ((r_uint == 0) || (s_uint == 0)) {
throw_or_abort("r or s value is zero");
}

// Check that the s value is less than |Fr| / 2
if (s_uint * 2 > mod) {
throw_or_abort("s value is not less than curve order by 2");
}

// Check that v must either be in {27, 28, 29, 30}
Fr r = Fr(r_uint);
Fr s = Fr(s_uint);
Expand All @@ -94,7 +105,7 @@ typename G1::affine_element recover_public_key(const std::string& message, const

// Negate the y-coordinate point of R based on the parity of v
bool y_parity_R = uint256_t(point_R.y).get_bit(0);
if ((v_uint & 1) == y_parity_R) {
if ((v_uint & 1) ^ y_parity_R) {
point_R.y = -point_R.y;
}

Expand All @@ -119,6 +130,7 @@ bool verify_signature(const std::string& message, const typename G1::affine_elem
using serialize::read;
uint256_t r_uint;
uint256_t s_uint;
uint256_t mod = uint256_t(Fr::modulus);
if (!public_key.on_curve()) {
return false;
}
Expand All @@ -127,12 +139,18 @@ bool verify_signature(const std::string& message, const typename G1::affine_elem
read(r_buf, r_uint);
read(s_buf, s_uint);
// We need to check that r and s are in Field according to specification
if ((r_uint >= Fr::modulus) || (s_uint >= Fr::modulus)) {
if ((r_uint >= mod) || (s_uint >= mod)) {
return false;
}
if ((r_uint == 0) || (s_uint == 0)) {
return false;
}

// Check that the s value is less than |Fr| / 2
if (s_uint * 2 > mod) {
throw_or_abort("s value is not less than curve order by 2");
}

Fr r = Fr(r_uint);
Fr s = Fr(s_uint);

Expand Down
8 changes: 8 additions & 0 deletions cpp/src/barretenberg/stdlib/encryption/ecdsa/ecdsa_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ bool_t<Composer> verify_signature(const stdlib::byte_array<Composer>& message,
r.assert_is_not_equal(Fr::zero());
s.assert_is_not_equal(Fr::zero());

// s should be less than |Fr| / 2
// Read more about this at: https://www.derpturkey.com/inherent-malleability-of-ecdsa-signatures/amp/
s.assert_less_than((Fr::modulus + 1) / 2);

Fr u1 = z / s;
Fr u2 = r / s;

Expand Down Expand Up @@ -143,6 +147,10 @@ bool_t<Composer> verify_signature_prehashed_message_noassert(const stdlib::byte_
r.assert_is_not_equal(Fr::zero());
s.assert_is_not_equal(Fr::zero());

// s should be less than |Fr| / 2
// Read more about this at: https://www.derpturkey.com/inherent-malleability-of-ecdsa-signatures/amp/
s.assert_less_than((Fr::modulus + 1) / 2);

Fr u1 = z / s;
Fr u2 = r / s;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ template <typename Composer, typename T> class bigfield {
bigfield conditional_select(const bigfield& other, const bool_t<Composer>& predicate) const;

void assert_is_in_field() const;
void assert_less_than(const uint256_t upper_limit) const;
void assert_equal(const bigfield& other) const;
void assert_is_not_equal(const bigfield& other) const;

Expand Down
51 changes: 51 additions & 0 deletions cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,53 @@ template <typename Composer> class stdlib_bigfield : public testing::Test {
EXPECT_EQ(result, true);
}

static void test_assert_less_than_success()
{
auto composer = Composer();
size_t num_repetitions = 10;
constexpr size_t num_bits = 200;
constexpr uint256_t bit_mask = (uint256_t(1) << num_bits) - 1;
for (size_t i = 0; i < num_repetitions; ++i) {

fq inputs[4]{ uint256_t(fq::random_element()) && bit_mask,
uint256_t(fq::random_element()) && bit_mask,
uint256_t(fq::random_element()) && bit_mask,
uint256_t(fq::random_element()) && bit_mask };

fq_ct a(witness_ct(&composer, fr(uint256_t(inputs[0]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&composer,
fr(uint256_t(inputs[0]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
fq_ct b(witness_ct(&composer, fr(uint256_t(inputs[1]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&composer,
fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));

fq_ct c = a;
fq expected = inputs[0];
for (size_t i = 0; i < 16; ++i) {
c = b * b + c;
expected = inputs[1] * inputs[1] + expected;
}
// fq_ct c = a + a + a + a - b - b - b - b;
c.assert_less_than(bit_mask + 1);
uint256_t result = (c.get_value().lo);
EXPECT_EQ(result, uint256_t(expected));
EXPECT_EQ(c.get_value().get_msb() < num_bits, true);
}
bool result = composer.check_circuit();
EXPECT_EQ(result, true);
// Checking edge conditions
fq random_input = fq::random_element();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, but not a blocker: unit tests should ideally not contain random elements as they then become unreproducable

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are partially right. I'd add something like "random_element_for_test" into the field, so that it would automatically print it. But it is actually better to have random elements than static, because if there is a bug there is a chance that it triggers. And then you can stress the test to find it (that was the case with construct_addition_chains)

fq_ct a(witness_ct(&composer, fr(uint256_t(random_input).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&composer,
fr(uint256_t(random_input).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));

a.assert_less_than(random_input + 1);
EXPECT_EQ(composer.check_circuit(), true);

a.assert_less_than(random_input);
EXPECT_EQ(composer.check_circuit(), false);
}

static void test_byte_array_constructors()
{
auto composer = Composer();
Expand Down Expand Up @@ -866,6 +913,10 @@ TYPED_TEST(stdlib_bigfield, assert_is_in_field_succes)
{
TestFixture::test_assert_is_in_field_success();
}
TYPED_TEST(stdlib_bigfield, assert_less_than_success)
{
TestFixture::test_assert_less_than_success();
}
TYPED_TEST(stdlib_bigfield, byte_array_constructors)
{
TestFixture::test_byte_array_constructors();
Expand Down
65 changes: 65 additions & 0 deletions cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,71 @@ template <typename C, typename T> void bigfield<C, T>::assert_is_in_field() cons
}
}

template <typename C, typename T> void bigfield<C, T>::assert_less_than(const uint256_t upper_limit) const
{
// TODO(kesha): Merge this with assert_is_in_field
// Warning: this assumes we have run circuit construction at least once in debug mode where large non reduced
// constants are allowed via ASSERT
if (is_constant()) {
return;
}
ASSERT(upper_limit != 0);
// The circuit checks that limit - this >= 0, so if we are doing a less_than comparison, we need to subtract 1 from
// the limit
uint256_t strict_upper_limit = upper_limit - uint256_t(1);
self_reduce(); // this method in particular enforces limb vals are <2^b - needed for logic described above
uint256_t value = get_value().lo;

const uint256_t upper_limit_value_0 = strict_upper_limit.slice(0, NUM_LIMB_BITS);
const uint256_t upper_limit_value_1 = strict_upper_limit.slice(NUM_LIMB_BITS, NUM_LIMB_BITS * 2);
const uint256_t upper_limit_value_2 = strict_upper_limit.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 3);
const uint256_t upper_limit_value_3 = strict_upper_limit.slice(NUM_LIMB_BITS * 3, NUM_LIMB_BITS * 4);

bool borrow_0_value = value.slice(0, NUM_LIMB_BITS) > upper_limit_value_0;
bool borrow_1_value =
(value.slice(NUM_LIMB_BITS, NUM_LIMB_BITS * 2) + uint256_t(borrow_0_value)) > (upper_limit_value_1);
bool borrow_2_value =
(value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 3) + uint256_t(borrow_1_value)) > (upper_limit_value_2);

field_t<C> upper_limit_0(context, upper_limit_value_0);
field_t<C> upper_limit_1(context, upper_limit_value_1);
field_t<C> upper_limit_2(context, upper_limit_value_2);
field_t<C> upper_limit_3(context, upper_limit_value_3);
bool_t<C> borrow_0(witness_t<C>(context, borrow_0_value));
bool_t<C> borrow_1(witness_t<C>(context, borrow_1_value));
bool_t<C> borrow_2(witness_t<C>(context, borrow_2_value));
// The way we use borrows here ensures that we are checking that upper_limit - binary_basis > 0.
// We check that the result in each limb is > 0.
// If the modulus part in this limb is smaller, we simply borrow the value from the higher limb.
// The prover can rearrange the borrows the way they like. The important thing is that the borrows are
// constrained.
field_t<C> r0 = upper_limit_0 - binary_basis_limbs[0].element + static_cast<field_t<C>>(borrow_0) * shift_1;
field_t<C> r1 = upper_limit_1 - binary_basis_limbs[1].element + static_cast<field_t<C>>(borrow_1) * shift_1 -
static_cast<field_t<C>>(borrow_0);
field_t<C> r2 = upper_limit_2 - binary_basis_limbs[2].element + static_cast<field_t<C>>(borrow_2) * shift_1 -
static_cast<field_t<C>>(borrow_1);
field_t<C> r3 = upper_limit_3 - binary_basis_limbs[3].element - static_cast<field_t<C>>(borrow_2);
r0 = r0.normalize();
r1 = r1.normalize();
r2 = r2.normalize();
r3 = r3.normalize();
if constexpr (C::type == ComposerType::PLOOKUP) {
context->decompose_into_default_range(r0.witness_index, static_cast<size_t>(NUM_LIMB_BITS));
context->decompose_into_default_range(r1.witness_index, static_cast<size_t>(NUM_LIMB_BITS));
context->decompose_into_default_range(r2.witness_index, static_cast<size_t>(NUM_LIMB_BITS));
context->decompose_into_default_range(r3.witness_index, static_cast<size_t>(NUM_LIMB_BITS));
} else {
context->decompose_into_base4_accumulators(
r0.witness_index, static_cast<size_t>(NUM_LIMB_BITS), "bigfield: assert_less_than range constraint 1.");
context->decompose_into_base4_accumulators(
r1.witness_index, static_cast<size_t>(NUM_LIMB_BITS), "bigfield: assert_less_than range constraint 2.");
context->decompose_into_base4_accumulators(
r2.witness_index, static_cast<size_t>(NUM_LIMB_BITS), "bigfield: assert_less_than range constraint 3.");
context->decompose_into_base4_accumulators(
r3.witness_index, static_cast<size_t>(NUM_LIMB_BITS), "bigfield: assert_less_than range constraint 4.");
}
}

// check elements are equal mod p by proving their integer difference is a multiple of p.
// This relies on the minus operator for a-b increasing a by a multiple of p large enough so diff is non-negative
template <typename C, typename T> void bigfield<C, T>::assert_equal(const bigfield& other) const
Expand Down