Skip to content

Commit

Permalink
add connection handshake tests for non-async version
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Sep 16, 2023
1 parent 59ece83 commit e9990f2
Showing 1 changed file with 89 additions and 1 deletion.
90 changes: 89 additions & 1 deletion tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
import socketserver
import ssl
import threading
from unittest.mock import patch

import pytest
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection
from redis.connection import (
Connection,
ResponseError,
SSLConnection,
UnixDomainSocketConnection,
)

from . import resp
from .ssl_utils import get_ssl_filename
Expand Down Expand Up @@ -55,6 +61,88 @@ def test_tcp_ssl_connect(tcp_address):
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


@pytest.mark.parametrize(
("use_server_ver", "use_protocol", "use_auth", "use_client_name"),
[
(5, 2, False, True),
(5, 2, True, True),
(5, 3, True, True),
(6, 2, False, True),
(6, 2, True, True),
(6, 3, False, False),
(6, 3, True, False),
(6, 3, False, True),
(6, 3, True, True),
],
)
# @pytest.mark.parametrize("use_protocol", [2, 3])
# @pytest.mark.parametrize("use_auth", [False, True])
def test_tcp_auth(tcp_address, use_protocol, use_auth, use_server_ver, use_client_name):
"""
Test that various initial handshake cases are handled correctly by the client
"""
got_auth = []
got_protocol = None
got_name = None

def on_auth(self, auth):
got_auth[:] = auth

def on_protocol(self, proto):
nonlocal got_protocol
got_protocol = proto

def on_setname(self, name):
nonlocal got_name
got_name = name

def get_server_version(self):
return use_server_ver

if use_auth:
auth_args = {"username": "myuser", "password": "mypassword"}
else:
auth_args = {}
got_protocol = None
host, port = tcp_address
conn = Connection(
host=host,
port=port,
client_name=_CLIENT_NAME if use_client_name else None,
socket_timeout=10,
protocol=use_protocol,
**auth_args,
)
try:
with patch.multiple(
resp.RespServer,
on_auth=on_auth,
get_server_version=get_server_version,
on_protocol=on_protocol,
on_setname=on_setname,
):
if use_server_ver < 6 and use_protocol > 2:
with pytest.raises(ResponseError):
_assert_connect(conn, tcp_address)
return

_assert_connect(conn, tcp_address)
if use_protocol == 3:
assert got_protocol == use_protocol
if use_auth:
if use_server_ver < 6:
assert got_auth == ["mypassword"]
else:
assert got_auth == ["myuser", "mypassword"]

if use_client_name:
assert got_name == _CLIENT_NAME
else:
assert got_name is None
finally:
conn.disconnect()


def _assert_connect(conn, server_address, certfile=None, keyfile=None):
if isinstance(server_address, str):
if not _RedisUDSServer:
Expand Down

0 comments on commit e9990f2

Please sign in to comment.