Skip to content

Commit

Permalink
harden also key decoding
Browse files Browse the repository at this point in the history
as assert statements will be removed in optimised build, do use a custom
exception that inherits from AssertionError so that the failures are
caught

this is reworking of d47a238 to implement the same checks but
without implementing support for SEC1/X9.62 formatted keys
  • Loading branch information
tomato42 committed Oct 7, 2019
1 parent 3427fa2 commit 99c907d
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 14 deletions.
5 changes: 4 additions & 1 deletion ecdsa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
__all__ = ["curves", "der", "ecdsa", "ellipticcurve", "keys", "numbertheory",
"test_pyecdsa", "util", "six"]
from .keys import SigningKey, VerifyingKey, BadSignatureError, BadDigestError
from .keys import SigningKey, VerifyingKey, BadSignatureError, BadDigestError,\
MalformedPointError
from .curves import NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, SECP256k1
from .der import UnexpectedDER

_hush_pyflakes = [SigningKey, VerifyingKey, BadSignatureError, BadDigestError,
MalformedPointError, UnexpectedDER,
NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, SECP256k1]
del _hush_pyflakes

Expand Down
43 changes: 32 additions & 11 deletions ecdsa/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from . import ecdsa
from . import der
from . import rfc6979
from . import ellipticcurve
from .curves import NIST192p, find_curve
from .util import string_to_number, number_to_string, randrange
from .util import sigencode_string, sigdecode_string
Expand All @@ -15,6 +16,11 @@ class BadSignatureError(Exception):
class BadDigestError(Exception):
pass


class MalformedPointError(AssertionError):
pass


class VerifyingKey:
def __init__(self, _error__please_use_generate=None):
if not _error__please_use_generate:
Expand All @@ -33,17 +39,21 @@ def from_public_point(klass, point, curve=NIST192p, hashfunc=sha1):
def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
validate_point=True):
order = curve.order
assert len(string) == curve.verifying_key_length, \
(len(string), curve.verifying_key_length)
if len(string) != curve.verifying_key_length:
raise MalformedPointError(
"Malformed encoding of public point. Expected string {0} bytes"
" long, received {1} bytes long string".format(
curve.verifying_key_length, len(string)))
xs = string[:curve.baselen]
ys = string[curve.baselen:]
assert len(xs) == curve.baselen, (len(xs), curve.baselen)
assert len(ys) == curve.baselen, (len(ys), curve.baselen)
if len(xs) != curve.baselen:
raise MalformedPointError("Unexpected length of encoded x")
if len(ys) != curve.baselen:
raise MalformedPointError("Unexpected length of encoded y")
x = string_to_number(xs)
y = string_to_number(ys)
if validate_point:
assert ecdsa.point_is_valid(curve.generator, x, y)
from . import ellipticcurve
if validate_point and not ecdsa.point_is_valid(curve.generator, x, y):
raise MalformedPointError("Point does not lie on the curve")
point = ellipticcurve.Point(curve.curve, x, y, order)
return klass.from_public_point(point, curve, hashfunc)

Expand All @@ -65,13 +75,18 @@ def from_der(klass, string):
if empty != b(""):
raise der.UnexpectedDER("trailing junk after DER pubkey objects: %s" %
binascii.hexlify(empty))
assert oid_pk == oid_ecPublicKey, (oid_pk, oid_ecPublicKey)
if oid_pk != oid_ecPublicKey:
raise der.UnexpectedDER(
"Unexpected OID in encoding, received {0}, expected {1}"
.format(oid_pk, oid_ecPublicKey))
curve = find_curve(oid_curve)
point_str, empty = der.remove_bitstring(point_str_bitstring)
if empty != b(""):
raise der.UnexpectedDER("trailing junk after pubkey pointstring: %s" %
binascii.hexlify(empty))
assert point_str.startswith(b("\x00\x04"))
if not point_str.startswith(b("\x00\x04")):
raise der.UnexpectedDER(
"Unsupported or invalid encoding of pubcli key")
return klass.from_string(point_str[2:], curve)

def to_string(self):
Expand Down Expand Up @@ -137,7 +152,10 @@ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1):
self.default_hashfunc = hashfunc
self.baselen = curve.baselen
n = curve.order
assert 1 <= secexp < n
if not 1 <= secexp < n:
raise MalformedPointError(
"Invalid value for secexp, expected integer between 1 and {0}"
.format(n))
pubkey_point = curve.generator*secexp
pubkey = ecdsa.Public_key(curve.generator, pubkey_point)
pubkey.order = n
Expand All @@ -149,7 +167,10 @@ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1):

@classmethod
def from_string(klass, string, curve=NIST192p, hashfunc=sha1):
assert len(string) == curve.baselen, (len(string), curve.baselen)
if len(string) != curve.baselen:
raise MalformedPointError(
"Invalid length of private key, received {0}, expected {1}"
.format(len(string), curve.baselen))
secexp = string_to_number(string)
return klass.from_secret_exponent(secexp, curve, hashfunc)

Expand Down
173 changes: 171 additions & 2 deletions ecdsa/test_pyecdsa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import with_statement, division

import unittest
try:
import unittest2 as unittest
except ImportError:
import unittest
import os
import time
import shutil
Expand All @@ -10,15 +13,17 @@

from .six import b, print_, binary_type
from .keys import SigningKey, VerifyingKey
from .keys import BadSignatureError
from .keys import BadSignatureError, MalformedPointError, BadDigestError
from . import util
from .util import sigencode_der, sigencode_strings
from .util import sigdecode_der, sigdecode_strings
from .util import encoded_oid_ecPublicKey, MalformedSignature
from .curves import Curve, UnknownCurveError
from .curves import NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, SECP256k1
from .ellipticcurve import Point
from . import der
from . import rfc6979
from . import ecdsa

class SubprocessError(Exception):
pass
Expand Down Expand Up @@ -258,6 +263,47 @@ def order(self): return 123456789
pub2 = VerifyingKey.from_pem(pem)
self.assertTruePubkeysEqual(pub1, pub2)

def test_vk_from_der_garbage_after_curve_oid(self):
type_oid_der = encoded_oid_ecPublicKey
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + \
b('garbage')
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x00\xff')
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
VerifyingKey.from_der(to_decode)

def test_vk_from_der_invalid_key_type(self):
type_oid_der = der.encode_oid(*(1, 2, 3))
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x00\xff')
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
VerifyingKey.from_der(to_decode)

def test_vk_from_der_garbage_after_point_string(self):
type_oid_der = encoded_oid_ecPublicKey
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x00\xff') + b('garbage')
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
VerifyingKey.from_der(to_decode)

def test_vk_from_der_invalid_bitstring(self):
type_oid_der = encoded_oid_ecPublicKey
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x08\xff')
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
VerifyingKey.from_der(to_decode)

def test_signature_strings(self):
priv1 = SigningKey.generate()
pub1 = priv1.get_verifying_key()
Expand All @@ -281,6 +327,86 @@ def test_signature_strings(self):
self.assertEqual(type(sig_der), binary_type)
self.assertTrue(pub1.verify(sig_der, data, sigdecode=sigdecode_der))

def test_sig_decode_strings_with_invalid_count(self):
with self.assertRaises(MalformedSignature):
sigdecode_strings([b('one'), b('two'), b('three')], 0xff)

def test_sig_decode_strings_with_wrong_r_len(self):
with self.assertRaises(MalformedSignature):
sigdecode_strings([b('one'), b('two')], 0xff)

def test_sig_decode_strings_with_wrong_s_len(self):
with self.assertRaises(MalformedSignature):
sigdecode_strings([b('\xa0'), b('\xb0\xff')], 0xff)

def test_verify_with_too_long_input(self):
sk = SigningKey.generate()
vk = sk.verifying_key

with self.assertRaises(BadDigestError):
vk.verify_digest(None, b('\x00') * 128)

def test_sk_from_secret_exponent_with_wrong_sec_exponent(self):
with self.assertRaises(MalformedPointError):
SigningKey.from_secret_exponent(0)

def test_sk_from_string_with_wrong_len_string(self):
with self.assertRaises(MalformedPointError):
SigningKey.from_string(b('\x01'))

def test_sk_from_der_with_junk_after_sequence(self):
ver_der = der.encode_integer(1)
to_decode = der.encode_sequence(ver_der) + b('garbage')

with self.assertRaises(der.UnexpectedDER):
SigningKey.from_der(to_decode)

def test_sk_from_der_with_wrong_version(self):
ver_der = der.encode_integer(0)
to_decode = der.encode_sequence(ver_der)

with self.assertRaises(der.UnexpectedDER):
SigningKey.from_der(to_decode)

def test_sk_from_der_invalid_const_tag(self):
ver_der = der.encode_integer(1)
privkey_der = der.encode_octet_string(b('\x00\xff'))
curve_oid_der = der.encode_oid(*(1, 2, 3))
const_der = der.encode_constructed(1, curve_oid_der)
to_decode = der.encode_sequence(ver_der, privkey_der, const_der,
curve_oid_der)

with self.assertRaises(der.UnexpectedDER):
SigningKey.from_der(to_decode)

def test_sk_from_der_garbage_after_privkey_oid(self):
ver_der = der.encode_integer(1)
privkey_der = der.encode_octet_string(b('\x00\xff'))
curve_oid_der = der.encode_oid(*(1, 2, 3)) + b('garbage')
const_der = der.encode_constructed(0, curve_oid_der)
to_decode = der.encode_sequence(ver_der, privkey_der, const_der,
curve_oid_der)

with self.assertRaises(der.UnexpectedDER):
SigningKey.from_der(to_decode)

def test_sk_from_der_with_short_privkey(self):
ver_der = der.encode_integer(1)
privkey_der = der.encode_octet_string(b('\x00\xff'))
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
const_der = der.encode_constructed(0, curve_oid_der)
to_decode = der.encode_sequence(ver_der, privkey_der, const_der,
curve_oid_der)

sk = SigningKey.from_der(to_decode)
self.assertEqual(sk.privkey.secret_multiplier, 255)

def test_sign_with_too_long_hash(self):
sk = SigningKey.from_secret_exponent(12)

with self.assertRaises(BadDigestError):
sk.sign_digest(b('\xff') * 64)

def test_hashfunc(self):
sk = SigningKey.generate(curve=NIST256p, hashfunc=sha256)
data = b("security level is 128 bits")
Expand All @@ -299,6 +425,49 @@ def test_hashfunc(self):
curve=NIST256p)
self.assertTrue(vk3.verify(sig, data, hashfunc=sha256))

def test_decoding_with_malformed_uncompressed(self):
enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3'
'\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4'
'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*')

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x02') + enc)

def test_decoding_with_point_not_on_curve(self):
enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3'
'\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4'
'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*')

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(enc[:47] + b('\x00'))

def test_decoding_with_point_at_infinity(self):
# decoding it is unsupported, as it's not necessary to encode it
with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x00'))

def test_from_string_with_invalid_curve_too_short_ver_key_len(self):
# both verifying_key_length and baselen are calculated internally
# by the Curve constructor, but since we depend on them verify
# that inconsistent values are detected
curve = Curve("test", ecdsa.curve_192, ecdsa.generator_192, (1, 2))
curve.verifying_key_length = 16
curve.baselen = 32

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x00')*16, curve)

def test_from_string_with_invalid_curve_too_long_ver_key_len(self):
# both verifying_key_length and baselen are calculated internally
# by the Curve constructor, but since we depend on them verify
# that inconsistent values are detected
curve = Curve("test", ecdsa.curve_192, ecdsa.generator_192, (1, 2))
curve.verifying_key_length = 16
curve.baselen = 16

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x00')*16, curve)


class OpenSSL(unittest.TestCase):
# test interoperability with OpenSSL tools. Note that openssl's ECDSA
Expand Down

0 comments on commit 99c907d

Please sign in to comment.