diff --git a/tests/test_connect.py b/tests/test_connect.py index 574693ee5f..49c3abe506 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -3,9 +3,15 @@ import socketserver import ssl import threading +from unittest.mock import patch import pytest -from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection +from redis.connection import ( + Connection, + ResponseError, + SSLConnection, + UnixDomainSocketConnection, +) from . import resp from .ssl_utils import get_ssl_filename @@ -55,6 +61,88 @@ def test_tcp_ssl_connect(tcp_address): _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) +@pytest.mark.parametrize( + ("use_server_ver", "use_protocol", "use_auth", "use_client_name"), + [ + (5, 2, False, True), + (5, 2, True, True), + (5, 3, True, True), + (6, 2, False, True), + (6, 2, True, True), + (6, 3, False, False), + (6, 3, True, False), + (6, 3, False, True), + (6, 3, True, True), + ], +) +# @pytest.mark.parametrize("use_protocol", [2, 3]) +# @pytest.mark.parametrize("use_auth", [False, True]) +def test_tcp_auth(tcp_address, use_protocol, use_auth, use_server_ver, use_client_name): + """ + Test that various initial handshake cases are handled correctly by the client + """ + got_auth = [] + got_protocol = None + got_name = None + + def on_auth(self, auth): + got_auth[:] = auth + + def on_protocol(self, proto): + nonlocal got_protocol + got_protocol = proto + + def on_setname(self, name): + nonlocal got_name + got_name = name + + def get_server_version(self): + return use_server_ver + + if use_auth: + auth_args = {"username": "myuser", "password": "mypassword"} + else: + auth_args = {} + got_protocol = None + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name=_CLIENT_NAME if use_client_name else None, + socket_timeout=10, + protocol=use_protocol, + **auth_args, + ) + try: + with patch.multiple( + resp.RespServer, + on_auth=on_auth, + get_server_version=get_server_version, + on_protocol=on_protocol, + on_setname=on_setname, + ): + if use_server_ver < 6 and use_protocol > 2: + with pytest.raises(ResponseError): + _assert_connect(conn, tcp_address) + return + + _assert_connect(conn, tcp_address) + if use_protocol == 3: + assert got_protocol == use_protocol + if use_auth: + if use_server_ver < 6: + assert got_auth == ["mypassword"] + else: + assert got_auth == ["myuser", "mypassword"] + + if use_client_name: + assert got_name == _CLIENT_NAME + else: + assert got_name is None + finally: + conn.disconnect() + + def _assert_connect(conn, server_address, certfile=None, keyfile=None): if isinstance(server_address, str): if not _RedisUDSServer: