diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index 34400ae5ee..2388dcc747 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -2,10 +2,12 @@ import logging import socket import ssl +from unittest.mock import patch import pytest from redis.asyncio.connection import ( Connection, + ResponseError, SSLConnection, UnixDomainSocketConnection, ) @@ -61,6 +63,90 @@ async def test_tcp_ssl_connect(tcp_address): await conn.disconnect() +@pytest.mark.parametrize( + ("use_server_ver", "use_protocol", "use_auth", "use_client_name"), + [ + (5, 2, False, True), + (5, 2, True, True), + (5, 3, True, True), + (6, 2, False, True), + (6, 2, True, True), + (6, 3, False, False), + (6, 3, True, False), + (6, 3, False, True), + (6, 3, True, True), + ], +) +# @pytest.mark.parametrize("use_protocol", [2, 3]) +# @pytest.mark.parametrize("use_auth", [False, True]) +async def test_tcp_auth( + tcp_address, use_protocol, use_auth, use_server_ver, use_client_name +): + """ + Test that various initial handshake cases are handled correctly by the client + """ + got_auth = [] + got_protocol = None + got_name = None + + def on_auth(self, auth): + got_auth[:] = auth + + def on_protocol(self, proto): + nonlocal got_protocol + got_protocol = proto + + def on_setname(self, name): + nonlocal got_name + got_name = name + + def get_server_version(self): + return use_server_ver + + if use_auth: + auth_args = {"username": "myuser", "password": "mypassword"} + else: + auth_args = {} + got_protocol = None + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name=_CLIENT_NAME if use_client_name else None, + socket_timeout=10, + protocol=use_protocol, + **auth_args, + ) + try: + with patch.multiple( + resp.RespServer, + on_auth=on_auth, + get_server_version=get_server_version, + on_protocol=on_protocol, + on_setname=on_setname, + ): + if use_server_ver < 6 and use_protocol > 2: + with pytest.raises(ResponseError): + await _assert_connect(conn, tcp_address) + return + + await _assert_connect(conn, tcp_address) + if use_protocol == 3: + assert got_protocol == use_protocol + if use_auth: + if use_server_ver < 6: + assert got_auth == ["mypassword"] + else: + assert got_auth == ["myuser", "mypassword"] + + if use_client_name: + assert got_name == _CLIENT_NAME + else: + assert got_name is None + finally: + await conn.disconnect() + + async def _assert_connect(conn, server_address, certfile=None, keyfile=None): stop_event = asyncio.Event() finished = asyncio.Event()