Skip to content

Commit

Permalink
pythongh-113280: Always close socket if SSLSocket creation failed (py…
Browse files Browse the repository at this point in the history
…thonGH-114659)

Co-authored-by: Thomas Grainger <[email protected]>
  • Loading branch information
serhiy-storchaka and graingert authored Feb 4, 2024
1 parent ecabff9 commit 0ea3662
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 64 deletions.
107 changes: 53 additions & 54 deletions Lib/ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,71 +994,67 @@ def _create(cls, sock, server_side=False, do_handshake_on_connect=True,
if context.check_hostname and not server_hostname:
raise ValueError("check_hostname requires server_hostname")

sock_timeout = sock.gettimeout()
kwargs = dict(
family=sock.family, type=sock.type, proto=sock.proto,
fileno=sock.fileno()
)
self = cls.__new__(cls, **kwargs)
super(SSLSocket, self).__init__(**kwargs)
sock_timeout = sock.gettimeout()
sock.detach()

self._context = context
self._session = session
self._closed = False
self._sslobj = None
self.server_side = server_side
self.server_hostname = context._encode_hostname(server_hostname)
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs

# See if we are connected
# Now SSLSocket is responsible for closing the file descriptor.
try:
self.getpeername()
except OSError as e:
if e.errno != errno.ENOTCONN:
raise
connected = False
blocking = self.getblocking()
self.setblocking(False)
self._context = context
self._session = session
self._closed = False
self._sslobj = None
self.server_side = server_side
self.server_hostname = context._encode_hostname(server_hostname)
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs

# See if we are connected
try:
# We are not connected so this is not supposed to block, but
# testing revealed otherwise on macOS and Windows so we do
# the non-blocking dance regardless. Our raise when any data
# is found means consuming the data is harmless.
notconn_pre_handshake_data = self.recv(1)
self.getpeername()
except OSError as e:
# EINVAL occurs for recv(1) on non-connected on unix sockets.
if e.errno not in (errno.ENOTCONN, errno.EINVAL):
if e.errno != errno.ENOTCONN:
raise
notconn_pre_handshake_data = b''
self.setblocking(blocking)
if notconn_pre_handshake_data:
# This prevents pending data sent to the socket before it was
# closed from escaping to the caller who could otherwise
# presume it came through a successful TLS connection.
reason = "Closed before TLS handshake with data in recv buffer."
notconn_pre_handshake_data_error = SSLError(e.errno, reason)
# Add the SSLError attributes that _ssl.c always adds.
notconn_pre_handshake_data_error.reason = reason
notconn_pre_handshake_data_error.library = None
try:
self.close()
except OSError:
pass
connected = False
blocking = self.getblocking()
self.setblocking(False)
try:
raise notconn_pre_handshake_data_error
finally:
# Explicitly break the reference cycle.
notconn_pre_handshake_data_error = None
else:
connected = True
# We are not connected so this is not supposed to block, but
# testing revealed otherwise on macOS and Windows so we do
# the non-blocking dance regardless. Our raise when any data
# is found means consuming the data is harmless.
notconn_pre_handshake_data = self.recv(1)
except OSError as e:
# EINVAL occurs for recv(1) on non-connected on unix sockets.
if e.errno not in (errno.ENOTCONN, errno.EINVAL):
raise
notconn_pre_handshake_data = b''
self.setblocking(blocking)
if notconn_pre_handshake_data:
# This prevents pending data sent to the socket before it was
# closed from escaping to the caller who could otherwise
# presume it came through a successful TLS connection.
reason = "Closed before TLS handshake with data in recv buffer."
notconn_pre_handshake_data_error = SSLError(e.errno, reason)
# Add the SSLError attributes that _ssl.c always adds.
notconn_pre_handshake_data_error.reason = reason
notconn_pre_handshake_data_error.library = None
try:
raise notconn_pre_handshake_data_error
finally:
# Explicitly break the reference cycle.
notconn_pre_handshake_data_error = None
else:
connected = True

self.settimeout(sock_timeout) # Must come after setblocking() calls.
self._connected = connected
if connected:
# create the SSL object
try:
self.settimeout(sock_timeout) # Must come after setblocking() calls.
self._connected = connected
if connected:
# create the SSL object
self._sslobj = self._context._wrap_socket(
self, server_side, self.server_hostname,
owner=self, session=self._session,
Expand All @@ -1069,9 +1065,12 @@ def _create(cls, sock, server_side=False, do_handshake_on_connect=True,
# non-blocking
raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
self.do_handshake()
except (OSError, ValueError):
except:
try:
self.close()
raise
except OSError:
pass
raise
return self

@property
Expand Down
33 changes: 23 additions & 10 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2206,14 +2206,15 @@ def _test_get_server_certificate(test, host, port, cert=None):
sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port ,pem))

def _test_get_server_certificate_fail(test, host, port):
try:
pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
except ssl.SSLError as x:
#should fail
if support.verbose:
sys.stdout.write("%s\n" % x)
else:
test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
with warnings_helper.check_no_resource_warning(test):
try:
pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
except ssl.SSLError as x:
#should fail
if support.verbose:
sys.stdout.write("%s\n" % x)
else:
test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))


from test.ssl_servers import make_https_server
Expand Down Expand Up @@ -3026,6 +3027,16 @@ def test_check_hostname_idn(self):
server_hostname="python.example.org") as s:
with self.assertRaises(ssl.CertificateError):
s.connect((HOST, server.port))
with ThreadedEchoServer(context=server_context, chatty=True) as server:
with warnings_helper.check_no_resource_warning(self):
with self.assertRaises(UnicodeError):
context.wrap_socket(socket.socket(),
server_hostname='.pythontest.net')
with ThreadedEchoServer(context=server_context, chatty=True) as server:
with warnings_helper.check_no_resource_warning(self):
with self.assertRaises(UnicodeDecodeError):
context.wrap_socket(socket.socket(),
server_hostname=b'k\xf6nig.idn.pythontest.net')

def test_wrong_cert_tls12(self):
"""Connecting when the server rejects the client's certificate
Expand Down Expand Up @@ -4983,7 +4994,8 @@ def call_after_accept(conn_to_client):
self.assertIsNone(wrap_error.library, msg="attr must exist")
finally:
# gh-108342: Explicitly break the reference cycle
wrap_error = None
with warnings_helper.check_no_resource_warning(self):
wrap_error = None
server = None

def test_https_client_non_tls_response_ignored(self):
Expand Down Expand Up @@ -5032,7 +5044,8 @@ def call_after_accept(conn_to_client):
# socket; that fails if the connection is broken. It may seem pointless
# to test this. It serves as an illustration of something that we never
# want to happen... properly not happening.
with self.assertRaises(OSError):
with warnings_helper.check_no_resource_warning(self), \
self.assertRaises(OSError):
connection.request("HEAD", "/test", headers={"Host": "localhost"})
response = connection.getresponse()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix a leak of open socket in rare cases when error occurred in
:class:`ssl.SSLSocket` creation.

0 comments on commit 0ea3662

Please sign in to comment.