From 06ec1ea1cc7be6034144bd06f07c35eb9d1b4953 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Sun, 15 Nov 2020 16:25:51 +0100 Subject: [PATCH] Fix #162: Blinding uses slow algorithm Store blinding factor + its inverse, so that they can be reused & updated on every blinding operation. This avoids expensive computations. The reuse of the previous blinding factor is done via squaring (mod n), as per section 9 of 'A Timing Attack against RSA with the Chinese Remainder Theorem' by Werner Schindler, https://tls.mbed.org/public/WSchindler-RSA_Timing_Attack.pdf --- CHANGELOG.md | 2 ++ rsa/key.py | 52 +++++++++++++++++++++++++++++------------------ tests/test_key.py | 17 ++++++++++++---- 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f61b3c4..fe1ab28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ code - Add padding length check as described by PKCS#1 v1.5 (Fixes [#164](https://github.com/sybrenstuvel/python-rsa/issues/164)) +- Reuse of blinding factors to speed up blinding operations. + Fixes [#162](https://github.com/sybrenstuvel/python-rsa/issues/162). ## Version 4.4 & 4.6 - released 2020-06-12 diff --git a/rsa/key.py b/rsa/key.py index b1e2030..e0e7b11 100644 --- a/rsa/key.py +++ b/rsa/key.py @@ -49,12 +49,15 @@ class AbstractKey: """Abstract superclass for private and public keys.""" - __slots__ = ('n', 'e') + __slots__ = ('n', 'e', 'blindfac', 'blindfac_inverse') def __init__(self, n: int, e: int) -> None: self.n = n self.e = e + # These will be computed properly on the first call to blind(). + self.blindfac = self.blindfac_inverse = -1 + @classmethod def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey': """Loads a key in PKCS#1 PEM format, implement in a subclass. @@ -145,7 +148,7 @@ def save_pkcs1(self, format: str = 'PEM') -> bytes: method = self._assert_format_exists(format, methods) return method() - def blind(self, message: int, r: int) -> int: + def blind(self, message: int) -> int: """Performs blinding on the message using random number 'r'. :param message: the message, as integer, to blind. @@ -159,10 +162,10 @@ def blind(self, message: int, r: int) -> int: See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29 """ + self._update_blinding_factor() + return (message * pow(self.blindfac, self.e, self.n)) % self.n - return (message * pow(r, self.e, self.n)) % self.n - - def unblind(self, blinded: int, r: int) -> int: + def unblind(self, blinded: int) -> int: """Performs blinding on the message using random number 'r'. :param blinded: the blinded message, as integer, to unblind. @@ -174,8 +177,27 @@ def unblind(self, blinded: int, r: int) -> int: See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29 """ - return (rsa.common.inverse(r, self.n) * blinded) % self.n + return (self.blindfac_inverse * blinded) % self.n + def _initial_blinding_factor(self) -> int: + for _ in range(1000): + blind_r = rsa.randnum.randint(self.n - 1) + if rsa.prime.are_relatively_prime(self.n, blind_r): + return blind_r + raise RuntimeError('unable to find blinding factor') + + def _update_blinding_factor(self): + if self.blindfac < 0: + # Compute initial blinding factor, which is rather slow to do. + self.blindfac = self._initial_blinding_factor() + self.blindfac_inverse = rsa.common.inverse(self.blindfac, self.n) + else: + # Reuse previous blinding factor as per section 9 of 'A Timing + # Attack against RSA with the Chinese Remainder Theorem' by Werner + # Schindler. + # See https://tls.mbed.org/public/WSchindler-RSA_Timing_Attack.pdf + self.blindfac = pow(self.blindfac, 2, self.n) + self.blindfac_inverse = pow(self.blindfac_inverse, 2, self.n) class PublicKey(AbstractKey): """Represents a public RSA key. @@ -414,13 +436,6 @@ def __ne__(self, other: typing.Any) -> bool: def __hash__(self) -> int: return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef)) - def _get_blinding_factor(self) -> int: - for _ in range(1000): - blind_r = rsa.randnum.randint(self.n - 1) - if rsa.prime.are_relatively_prime(self.n, blind_r): - return blind_r - raise RuntimeError('unable to find blinding factor') - def blinded_decrypt(self, encrypted: int) -> int: """Decrypts the message using blinding to prevent side-channel attacks. @@ -431,11 +446,9 @@ def blinded_decrypt(self, encrypted: int) -> int: :rtype: int """ - blind_r = self._get_blinding_factor() - blinded = self.blind(encrypted, blind_r) # blind before decrypting + blinded = self.blind(encrypted) # blind before decrypting decrypted = rsa.core.decrypt_int(blinded, self.d, self.n) - - return self.unblind(decrypted, blind_r) + return self.unblind(decrypted) def blinded_encrypt(self, message: int) -> int: """Encrypts the message using blinding to prevent side-channel attacks. @@ -447,10 +460,9 @@ def blinded_encrypt(self, message: int) -> int: :rtype: int """ - blind_r = self._get_blinding_factor() - blinded = self.blind(message, blind_r) # blind before encrypting + blinded = self.blind(message) # blind before encrypting encrypted = rsa.core.encrypt_int(blinded, self.d, self.n) - return self.unblind(encrypted, blind_r) + return self.unblind(encrypted) @classmethod def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey': diff --git a/tests/test_key.py b/tests/test_key.py index 9db30ce..b00e26d 100644 --- a/tests/test_key.py +++ b/tests/test_key.py @@ -21,11 +21,20 @@ def test_blinding(self): message = 12345 encrypted = rsa.core.encrypt_int(message, pk.e, pk.n) - blinded = pk.blind(encrypted, 4134431) # blind before decrypting - decrypted = rsa.core.decrypt_int(blinded, pk.d, pk.n) - unblinded = pk.unblind(decrypted, 4134431) + blinded_1 = pk.blind(encrypted) # blind before decrypting + decrypted = rsa.core.decrypt_int(blinded_1, pk.d, pk.n) + unblinded_1 = pk.unblind(decrypted) - self.assertEqual(unblinded, message) + self.assertEqual(unblinded_1, message) + + # Re-blinding should use a different blinding factor. + blinded_2 = pk.blind(encrypted) # blind before decrypting + self.assertNotEqual(blinded_1, blinded_2) + + # The unblinding should still work, though. + decrypted = rsa.core.decrypt_int(blinded_2, pk.d, pk.n) + unblinded_2 = pk.unblind(decrypted) + self.assertEqual(unblinded_2, message) class KeyGenTest(unittest.TestCase):