Skip to content

Commit

Permalink
introduce AbstractConnection so that UnixDomainSocketConnection can c…
Browse files Browse the repository at this point in the history
…all super().__init__ (#2588)

Co-authored-by: dvora-h <[email protected]>
  • Loading branch information
woutdenolf and dvora-h authored Mar 16, 2023
1 parent c871723 commit 7d474f9
Showing 1 changed file with 120 additions and 158 deletions.
278 changes: 120 additions & 158 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import threading
import weakref
from abc import abstractmethod
from io import SEEK_END
from itertools import chain
from queue import Empty, Full, LifoQueue
Expand Down Expand Up @@ -583,20 +584,13 @@ def pack(self, *args):
return output


class Connection:
"Manages TCP communication to and from a Redis server"
class AbstractConnection:
"Manages communication to and from a Redis server"

def __init__(
self,
host="localhost",
port=6379,
db=0,
password=None,
socket_timeout=None,
socket_connect_timeout=None,
socket_keepalive=False,
socket_keepalive_options=None,
socket_type=0,
retry_on_timeout=False,
retry_on_error=SENTINEL,
encoding="utf-8",
Expand Down Expand Up @@ -627,18 +621,11 @@ def __init__(
"2. 'credential_provider'"
)
self.pid = os.getpid()
self.host = host
self.port = int(port)
self.db = db
self.client_name = client_name
self.credential_provider = credential_provider
self.password = password
self.username = username
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
self.socket_keepalive = socket_keepalive
self.socket_keepalive_options = socket_keepalive_options or {}
self.socket_type = socket_type
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
retry_on_error = []
Expand Down Expand Up @@ -671,11 +658,9 @@ def __repr__(self):
repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
return f"{self.__class__.__name__}<{repr_args}>"

@abstractmethod
def repr_pieces(self):
pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
if self.client_name:
pieces.append(("client_name", self.client_name))
return pieces
pass

def __del__(self):
try:
Expand Down Expand Up @@ -738,75 +723,17 @@ def connect(self):
if callback:
callback(self)

@abstractmethod
def _connect(self):
"Create a TCP socket connection"
# we want to mimic what socket.create_connection does to support
# ipv4/ipv6, but we want to set options prior to calling
# socket.connect()
err = None
for res in socket.getaddrinfo(
self.host, self.port, self.socket_type, socket.SOCK_STREAM
):
family, socktype, proto, canonname, socket_address = res
sock = None
try:
sock = socket.socket(family, socktype, proto)
# TCP_NODELAY
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

# TCP_KEEPALIVE
if self.socket_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
for k, v in self.socket_keepalive_options.items():
sock.setsockopt(socket.IPPROTO_TCP, k, v)

# set the socket_connect_timeout before we connect
sock.settimeout(self.socket_connect_timeout)

# connect
sock.connect(socket_address)

# set the socket_timeout now that we're connected
sock.settimeout(self.socket_timeout)
return sock

except OSError as _:
err = _
if sock is not None:
sock.close()

if err is not None:
raise err
raise OSError("socket.getaddrinfo returned an empty list")
pass

@abstractmethod
def _host_error(self):
try:
host_error = f"{self.host}:{self.port}"
except AttributeError:
host_error = "connection"

return host_error
pass

@abstractmethod
def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"

host_error = self._host_error()

if len(exception.args) == 1:
try:
return f"Error connecting to {host_error}. \
{exception.args[0]}."
except AttributeError:
return f"Connection Error: {exception.args[0]}"
else:
try:
return (
f"Error {exception.args[0]} connecting to "
f"{host_error}. {exception.args[1]}."
)
except AttributeError:
return f"Connection Error: {exception.args[0]}"
pass

def on_connect(self):
"Initialize the connection, authenticate and select a database"
Expand Down Expand Up @@ -990,6 +917,101 @@ def pack_commands(self, commands):
return output


class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"

def __init__(
self,
host="localhost",
port=6379,
socket_timeout=None,
socket_connect_timeout=None,
socket_keepalive=False,
socket_keepalive_options=None,
socket_type=0,
**kwargs,
):
self.host = host
self.port = int(port)
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
self.socket_keepalive = socket_keepalive
self.socket_keepalive_options = socket_keepalive_options or {}
self.socket_type = socket_type
super().__init__(**kwargs)

def repr_pieces(self):
pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
if self.client_name:
pieces.append(("client_name", self.client_name))
return pieces

def _connect(self):
"Create a TCP socket connection"
# we want to mimic what socket.create_connection does to support
# ipv4/ipv6, but we want to set options prior to calling
# socket.connect()
err = None
for res in socket.getaddrinfo(
self.host, self.port, self.socket_type, socket.SOCK_STREAM
):
family, socktype, proto, canonname, socket_address = res
sock = None
try:
sock = socket.socket(family, socktype, proto)
# TCP_NODELAY
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

# TCP_KEEPALIVE
if self.socket_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
for k, v in self.socket_keepalive_options.items():
sock.setsockopt(socket.IPPROTO_TCP, k, v)

# set the socket_connect_timeout before we connect
sock.settimeout(self.socket_connect_timeout)

# connect
sock.connect(socket_address)

# set the socket_timeout now that we're connected
sock.settimeout(self.socket_timeout)
return sock

except OSError as _:
err = _
if sock is not None:
sock.close()

if err is not None:
raise err
raise OSError("socket.getaddrinfo returned an empty list")

def _host_error(self):
return f"{self.host}:{self.port}"

def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"

host_error = self._host_error()

if len(exception.args) == 1:
try:
return f"Error connecting to {host_error}. \
{exception.args[0]}."
except AttributeError:
return f"Connection Error: {exception.args[0]}"
else:
try:
return (
f"Error {exception.args[0]} connecting to "
f"{host_error}. {exception.args[1]}."
)
except AttributeError:
return f"Connection Error: {exception.args[0]}"


class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
This class extends the Connection class, adding SSL functionality, and making
Expand Down Expand Up @@ -1035,8 +1057,6 @@ def __init__(
if not ssl_available:
raise RedisError("Python wasn't built with SSL support")

super().__init__(**kwargs)

self.keyfile = ssl_keyfile
self.certfile = ssl_certfile
if ssl_cert_reqs is None:
Expand All @@ -1062,6 +1082,7 @@ def __init__(
self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
self.ssl_ocsp_context = ssl_ocsp_context
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
super().__init__(**kwargs)

def _connect(self):
"Wrap the socket with SSL support"
Expand Down Expand Up @@ -1131,77 +1152,12 @@ def _connect(self):
return sslsock


class UnixDomainSocketConnection(Connection):
def __init__(
self,
path="",
db=0,
username=None,
password=None,
socket_timeout=None,
encoding="utf-8",
encoding_errors="strict",
decode_responses=False,
retry_on_timeout=False,
retry_on_error=SENTINEL,
parser_class=DefaultParser,
socket_read_size=65536,
health_check_interval=0,
client_name=None,
retry=None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
command_packer=None,
):
"""
Initialize a new UnixDomainSocketConnection.
To specify a retry policy for specific errors, first set
`retry_on_error` to a list of the error/s to retry on, then set
`retry` to a valid `Retry` object.
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
if (username or password) and credential_provider is not None:
raise DataError(
"'username' and 'password' cannot be passed along with 'credential_"
"provider'. Please provide only one of the following arguments: \n"
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)
self.pid = os.getpid()
class UnixDomainSocketConnection(AbstractConnection):
"Manages UDS communication to and from a Redis server"

def __init__(self, path="", **kwargs):
self.path = path
self.db = db
self.client_name = client_name
self.credential_provider = credential_provider
self.password = password
self.username = username
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
retry_on_error = []
if retry_on_timeout:
# Add TimeoutError to the errors list to retry on
retry_on_error.append(TimeoutError)
self.retry_on_error = retry_on_error
if self.retry_on_error:
if retry is None:
self.retry = Retry(NoBackoff(), 1)
else:
# deep-copy the Retry object as it is mutable
self.retry = copy.deepcopy(retry)
# Update the retry's supported errors with the specified errors
self.retry.update_supported_errors(retry_on_error)
else:
self.retry = Retry(NoBackoff(), 0)
self.health_check_interval = health_check_interval
self.next_health_check = 0
self.redis_connect_func = redis_connect_func
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._socket_read_size = socket_read_size
self.set_parser(parser_class)
self._connect_callbacks = []
self._buffer_cutoff = 6000
self._command_packer = self._construct_command_packer(command_packer)
super().__init__(**kwargs)

def repr_pieces(self):
pieces = [("path", self.path), ("db", self.db)]
Expand All @@ -1216,15 +1172,21 @@ def _connect(self):
sock.connect(self.path)
return sock

def _host_error(self):
return self.path

def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if len(exception.args) == 1:
return f"Error connecting to unix socket: {self.path}. {exception.args[0]}."
return (
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
)
else:
return (
f"Error {exception.args[0]} connecting to unix socket: "
f"{self.path}. {exception.args[1]}."
f"{host_error}. {exception.args[1]}."
)


Expand Down

0 comments on commit 7d474f9

Please sign in to comment.