diff --git a/CHANGES b/CHANGES index 3d9d6292a1..e0959b0ef3 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Allow to control the minimum SSL version * Add an optional lock_name attribute to LockError. * Fix return types for `get`, `set_path` and `strappend` in JSONCommands * Connection.register_connect_callback() is made public. diff --git a/docs/examples/ssl_connection_examples.ipynb b/docs/examples/ssl_connection_examples.ipynb index ab3b4415ae..a3d015619f 100644 --- a/docs/examples/ssl_connection_examples.ipynb +++ b/docs/examples/ssl_connection_examples.ipynb @@ -76,6 +76,42 @@ "ssl_connection.ping()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Connecting to a Redis instance via SSL, while specifying a minimum TLS version" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import redis\n", + "import ssl\n", + "\n", + "ssl_conn = redis.Redis(\n", + " host=\"localhost\",\n", + " port=6666,\n", + " ssl=True,\n", + " ssl_min_version=ssl.TLSVersion.TLSv1_3,\n", + ")\n", + "ssl_conn.ping()" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 88de893f5b..62bdc7dd5c 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -2,6 +2,7 @@ import copy import inspect import re +import ssl import warnings from typing import ( TYPE_CHECKING, @@ -226,6 +227,7 @@ def __init__( ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = False, + ssl_min_version: Optional[ssl.TLSVersion] = None, max_connections: Optional[int] = None, single_connection_client: bool = False, health_check_interval: int = 0, @@ -332,6 +334,7 @@ def __init__( "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_check_hostname": ssl_check_hostname, + "ssl_min_version": ssl_min_version, } ) # This arg only used if no pool is passed in diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 337c7bbdcc..4fb2fc4647 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -2,6 +2,7 @@ import collections import random import socket +import ssl import warnings from typing import ( Any, @@ -271,6 +272,7 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, + ssl_min_version: Optional[ssl.TLSVersion] = None, protocol: Optional[int] = 2, address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, cache_enabled: bool = False, @@ -344,6 +346,7 @@ def __init__( "ssl_certfile": ssl_certfile, "ssl_check_hostname": ssl_check_hostname, "ssl_keyfile": ssl_keyfile, + "ssl_min_version": ssl_min_version, } ) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 07c4262233..81df3b3543 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -823,6 +823,7 @@ def __init__( ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = False, + ssl_min_version: Optional[ssl.TLSVersion] = None, **kwargs, ): self.ssl_context: RedisSSLContext = RedisSSLContext( @@ -832,6 +833,7 @@ def __init__( ca_certs=ssl_ca_certs, ca_data=ssl_ca_data, check_hostname=ssl_check_hostname, + min_version=ssl_min_version, ) super().__init__(**kwargs) @@ -864,6 +866,10 @@ def ca_data(self): def check_hostname(self): return self.ssl_context.check_hostname + @property + def min_version(self): + return self.ssl_context.min_version + class RedisSSLContext: __slots__ = ( @@ -874,6 +880,7 @@ class RedisSSLContext: "ca_data", "context", "check_hostname", + "min_version", ) def __init__( @@ -884,6 +891,7 @@ def __init__( ca_certs: Optional[str] = None, ca_data: Optional[str] = None, check_hostname: bool = False, + min_version: Optional[ssl.TLSVersion] = None, ): self.keyfile = keyfile self.certfile = certfile @@ -903,6 +911,7 @@ def __init__( self.ca_certs = ca_certs self.ca_data = ca_data self.check_hostname = check_hostname + self.min_version = min_version self.context: Optional[ssl.SSLContext] = None def get(self) -> ssl.SSLContext: @@ -914,6 +923,8 @@ def get(self) -> ssl.SSLContext: context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) if self.ca_certs or self.ca_data: context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data) + if self.min_version is not None: + context.minimum_version = self.min_version self.context = context return self.context diff --git a/redis/client.py b/redis/client.py index 2d4c512699..1209a978d2 100755 --- a/redis/client.py +++ b/redis/client.py @@ -198,6 +198,7 @@ def __init__( ssl_validate_ocsp_stapled=False, ssl_ocsp_context=None, ssl_ocsp_expected_cert=None, + ssl_min_version=None, max_connections=None, single_connection_client=False, health_check_interval=0, @@ -311,6 +312,7 @@ def __init__( "ssl_validate_ocsp": ssl_validate_ocsp, "ssl_ocsp_context": ssl_ocsp_context, "ssl_ocsp_expected_cert": ssl_ocsp_expected_cert, + "ssl_min_version": ssl_min_version, } ) connection_pool = ConnectionPool(**kwargs) diff --git a/redis/connection.py b/redis/connection.py index 1f46267146..c9f7fc55d0 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -769,6 +769,7 @@ def __init__( ssl_validate_ocsp_stapled=False, ssl_ocsp_context=None, ssl_ocsp_expected_cert=None, + ssl_min_version=None, **kwargs, ): """Constructor @@ -787,6 +788,7 @@ def __init__( ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service. + ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module. Raises: RedisError @@ -819,6 +821,7 @@ def __init__( self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled self.ssl_ocsp_context = ssl_ocsp_context self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert + self.ssl_min_version = ssl_min_version super().__init__(**kwargs) def _connect(self): @@ -841,6 +844,8 @@ def _connect(self): context.load_verify_locations( cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data ) + if self.ssl_min_version is not None: + context.minimum_version = self.ssl_min_version sslsock = context.wrap_socket(sock, server_hostname=self.host) if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False: raise RedisError("cryptography is not installed.") diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index 5e6b120fb3..5497501258 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -10,6 +10,7 @@ SSLConnection, UnixDomainSocketConnection, ) +from redis.exceptions import ConnectionError from ..ssl_utils import get_ssl_filename @@ -50,7 +51,17 @@ async def test_uds_connect(uds_address): @pytest.mark.ssl -async def test_tcp_ssl_connect(tcp_address): +@pytest.mark.parametrize( + "ssl_min_version", + [ + ssl.TLSVersion.TLSv1_2, + pytest.param( + ssl.TLSVersion.TLSv1_3, + marks=pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3"), + ), + ], +) +async def test_tcp_ssl_connect(tcp_address, ssl_min_version): host, port = tcp_address certfile = get_ssl_filename("server-cert.pem") keyfile = get_ssl_filename("server-key.pem") @@ -60,12 +71,44 @@ async def test_tcp_ssl_connect(tcp_address): client_name=_CLIENT_NAME, ssl_ca_certs=certfile, socket_timeout=10, + ssl_min_version=ssl_min_version, ) await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) await conn.disconnect() -async def _assert_connect(conn, server_address, certfile=None, keyfile=None): +@pytest.mark.ssl +@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3") +async def test_tcp_ssl_version_mismatch(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=1, + ssl_min_version=ssl.TLSVersion.TLSv1_3, + ) + with pytest.raises(ConnectionError): + await _assert_connect( + conn, + tcp_address, + certfile=certfile, + keyfile=keyfile, + ssl_version=ssl.TLSVersion.TLSv1_2, + ) + await conn.disconnect() + + +async def _assert_connect( + conn, + server_address, + certfile=None, + keyfile=None, + ssl_version=None, +): stop_event = asyncio.Event() finished = asyncio.Event() @@ -82,7 +125,9 @@ async def _handler(reader, writer): elif certfile: host, port = server_address context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - context.minimum_version = ssl.TLSVersion.TLSv1_2 + if ssl_version is not None: + context.minimum_version = ssl_version + context.maximum_version = ssl_version context.load_cert_chain(certfile=certfile, keyfile=keyfile) server = await asyncio.start_server(_handler, host=host, port=port, ssl=context) else: @@ -94,6 +139,9 @@ async def _handler(reader, writer): try: await conn.connect() await conn.disconnect() + except ConnectionError: + finished.set() + raise finally: stop_event.set() aserver.close() diff --git a/tests/test_connect.py b/tests/test_connect.py index 696e69ceea..0fdbb7005f 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -7,6 +7,7 @@ import pytest from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection +from redis.exceptions import ConnectionError from .ssl_utils import get_ssl_filename @@ -45,7 +46,17 @@ def test_uds_connect(uds_address): @pytest.mark.ssl -def test_tcp_ssl_connect(tcp_address): +@pytest.mark.parametrize( + "ssl_min_version", + [ + ssl.TLSVersion.TLSv1_2, + pytest.param( + ssl.TLSVersion.TLSv1_3, + marks=pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3"), + ), + ], +) +def test_tcp_ssl_connect(tcp_address, ssl_min_version): host, port = tcp_address certfile = get_ssl_filename("server-cert.pem") keyfile = get_ssl_filename("server-key.pem") @@ -55,19 +66,42 @@ def test_tcp_ssl_connect(tcp_address): client_name=_CLIENT_NAME, ssl_ca_certs=certfile, socket_timeout=10, + ssl_min_version=ssl_min_version, ) _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) -def _assert_connect(conn, server_address, certfile=None, keyfile=None): +@pytest.mark.ssl +@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3") +def test_tcp_ssl_version_mismatch(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ssl_min_version=ssl.TLSVersion.TLSv1_3, + ) + with pytest.raises(ConnectionError): + _assert_connect( + conn, + tcp_address, + certfile=certfile, + keyfile=keyfile, + ssl_version=ssl.PROTOCOL_TLSv1_2, + ) + + +def _assert_connect(conn, server_address, **tcp_kw): if isinstance(server_address, str): if not _RedisUDSServer: pytest.skip("Unix domain sockets are not supported on this platform") server = _RedisUDSServer(server_address, _RedisRequestHandler) else: - server = _RedisTCPServer( - server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile - ) + server = _RedisTCPServer(server_address, _RedisRequestHandler, **tcp_kw) with server as aserver: t = threading.Thread(target=aserver.serve_forever) t.start() @@ -81,11 +115,19 @@ def _assert_connect(conn, server_address, certfile=None, keyfile=None): class _RedisTCPServer(socketserver.TCPServer): - def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None: + def __init__( + self, + *args, + certfile=None, + keyfile=None, + ssl_version=ssl.PROTOCOL_TLS, + **kw, + ) -> None: self._ready_event = threading.Event() self._stop_requested = False self._certfile = certfile self._keyfile = keyfile + self._ssl_version = ssl_version super().__init__(*args, **kw) def service_actions(self): @@ -110,7 +152,7 @@ def get_request(self): server_side=True, certfile=self._certfile, keyfile=self._keyfile, - ssl_version=ssl.PROTOCOL_TLSv1_2, + ssl_version=self._ssl_version, ) return connstream, fromaddr