Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-108342: Make ssl TestPreHandshakeClose more reliable #108370

Merged
merged 3 commits into from
Aug 23, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 70 additions & 33 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4672,12 +4672,16 @@ class TestPreHandshakeClose(unittest.TestCase):

class SingleConnectionTestServerThread(threading.Thread):

def __init__(self, *, name, call_after_accept):
def __init__(self, *, name, call_after_accept, timeout=None):
self.call_after_accept = call_after_accept
self.received_data = b'' # set by .run()
self.wrap_error = None # set by .run()
self.listener = None # set by .start()
self.port = None # set by .start()
if timeout is None:
self.timeout = support.SHORT_TIMEOUT
else:
self.timeout = timeout
super().__init__(name=name)

def __enter__(self):
Expand All @@ -4700,13 +4704,19 @@ def start(self):
self.ssl_ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
self.listener = socket.socket()
self.port = socket_helper.bind_port(self.listener)
self.listener.settimeout(2.0)
self.listener.settimeout(self.timeout)
self.listener.listen(1)
super().start()

def run(self):
conn, address = self.listener.accept()
self.listener.close()
try:
conn, address = self.listener.accept()
except TimeoutError:
# on timeout, just close the listener
return
finally:
self.listener.close()

with conn:
if self.call_after_accept(conn):
return
Expand Down Expand Up @@ -4734,8 +4744,13 @@ def non_linux_skip_if_other_okay_error(self, err):
# we're specifically trying to test. The way this test is written
# is known to work on Linux. We'll skip it anywhere else that it
# does not present as doing so.
self.skipTest(f"Could not recreate conditions on {sys.platform}:"
f" {err=}")
try:
self.skipTest(f"Could not recreate conditions on {sys.platform}:"
f" {err=}")
finally:
# gh-108342: Explicitly break the reference cycle
err = None

# If maintaining this conditional winds up being a problem.
# just turn this into an unconditional skip anything but Linux.
# The important thing is that our CI has the logic covered.
Expand All @@ -4746,7 +4761,7 @@ def test_preauth_data_to_tls_server(self):

def call_after_accept(unused):
server_accept_called.set()
if not ready_for_server_wrap_socket.wait(2.0):
if not ready_for_server_wrap_socket.wait(support.SHORT_TIMEOUT):
raise RuntimeError("wrap_socket event never set, test may fail.")
return False # Tell the server thread to continue.

Expand All @@ -4767,23 +4782,34 @@ def call_after_accept(unused):

ready_for_server_wrap_socket.set()
server.join()

wrap_error = server.wrap_error
self.assertEqual(b"", server.received_data)
self.assertIsInstance(wrap_error, OSError) # All platforms.
self.non_linux_skip_if_other_okay_error(wrap_error)
self.assertIsInstance(wrap_error, ssl.SSLError)
self.assertIn("before TLS handshake with data", wrap_error.args[1])
self.assertIn("before TLS handshake with data", wrap_error.reason)
self.assertNotEqual(0, wrap_error.args[0])
self.assertIsNone(wrap_error.library, msg="attr must exist")
server.wrap_error = None
try:
self.assertEqual(b"", server.received_data)
self.assertIsInstance(wrap_error, OSError) # All platforms.
self.non_linux_skip_if_other_okay_error(wrap_error)
self.assertIsInstance(wrap_error, ssl.SSLError)
self.assertIn("before TLS handshake with data", wrap_error.args[1])
self.assertIn("before TLS handshake with data", wrap_error.reason)
self.assertNotEqual(0, wrap_error.args[0])
self.assertIsNone(wrap_error.library, msg="attr must exist")
finally:
# gh-108342: Explicitly break the reference cycle
wrap_error = None
server = None

def test_preauth_data_to_tls_client(self):
server_can_continue_with_wrap_socket = threading.Event()
client_can_continue_with_wrap_socket = threading.Event()

def call_after_accept(conn_to_client):
if not server_can_continue_with_wrap_socket.wait(support.SHORT_TIMEOUT):
print("ERROR: test client took too long")

# This forces an immediate connection close via RST on .close().
set_socket_so_linger_on_with_zero_timeout(conn_to_client)
conn_to_client.send(
conn_to_client.sendall(
gpshead marked this conversation as resolved.
Show resolved Hide resolved
b"HTTP/1.0 307 Temporary Redirect\r\n"
b"Location: https://example.com/someone-elses-server\r\n"
b"\r\n")
Expand All @@ -4800,8 +4826,10 @@ def call_after_accept(conn_to_client):

with socket.socket() as client:
client.connect(server.listener.getsockname())
if not client_can_continue_with_wrap_socket.wait(2.0):
self.fail("test server took too long.")
server_can_continue_with_wrap_socket.set()

if not client_can_continue_with_wrap_socket.wait(support.SHORT_TIMEOUT):
self.fail("test server took too long")
ssl_ctx = ssl.create_default_context()
try:
tls_client = ssl_ctx.wrap_socket(
Expand All @@ -4815,24 +4843,28 @@ def call_after_accept(conn_to_client):
tls_client.close()

server.join()
self.assertEqual(b"", received_data)
self.assertIsInstance(wrap_error, OSError) # All platforms.
self.non_linux_skip_if_other_okay_error(wrap_error)
self.assertIsInstance(wrap_error, ssl.SSLError)
self.assertIn("before TLS handshake with data", wrap_error.args[1])
self.assertIn("before TLS handshake with data", wrap_error.reason)
self.assertNotEqual(0, wrap_error.args[0])
self.assertIsNone(wrap_error.library, msg="attr must exist")
try:
self.assertEqual(b"", received_data)
self.assertIsInstance(wrap_error, OSError) # All platforms.
self.non_linux_skip_if_other_okay_error(wrap_error)
self.assertIsInstance(wrap_error, ssl.SSLError)
self.assertIn("before TLS handshake with data", wrap_error.args[1])
self.assertIn("before TLS handshake with data", wrap_error.reason)
self.assertNotEqual(0, wrap_error.args[0])
self.assertIsNone(wrap_error.library, msg="attr must exist")
finally:
# gh-108342: Explicitly break the reference cycle
wrap_error = None
server = None

def test_https_client_non_tls_response_ignored(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could alternatively just delete test_https_client_non_tls_response_ignored entirely because it was not a regression/bug test. It's just a demonstration I added later that it was already a non-problem for https clients.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are worried that this test is still unreliable even with my changes, may remove it in stable branches and only keep it in the main branch?


server_responding = threading.Event()

class SynchronizedHTTPSConnection(http.client.HTTPSConnection):
def connect(self):
http.client.HTTPConnection.connect(self)
super().connect()
Copy link
Member

@gpshead gpshead Aug 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was intentionally not calling super(). it called http.client.HTTPConnection's connect, skipping http.client.HTTPSConnection's connect() on purpose to skip up a level.

because HTTPSConnection connect calls _wrap_socket and that is what we needed to avoid because we do it ourselves below.

Leave the previous HTTPConnection connect call and just add an explanatory comment to make it clear.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that wasn't obvious when I read the code! Ok, I will revert and add a comment to explain this subtle function call :-)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah i don't know how i managed to not leave a comment about that. definitely unusual. :)

# Wait for our fault injection server to have done its thing.
if not server_responding.wait(1.0) and support.verbose:
if not server_responding.wait(support.SHORT_TIMEOUT) and support.verbose:
sys.stdout.write("server_responding event never set.")
self.sock = self._context.wrap_socket(
self.sock, server_hostname=self.host)
Expand All @@ -4847,28 +4879,33 @@ def call_after_accept(conn_to_client):
server_responding.set()
return True # Tell the server to stop.

timeout = 2.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is set to 2.0, all the changes to generalize the timeout in SingleConnectionTestServerThread aren't really used?

Can we at least set this to, say, 5.0 to give it some breathing room compared to the previous value?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is set to 2.0, all the changes to generalize the timeout in SingleConnectionTestServerThread aren't really used?

I increased the timeout from 2 seconds to SHORT_TIMEOUT (at least 30 seconds) in the 2 other tests to make these tests more reliable (2 seconds may be too short on a busy system).

Can we at least set this to, say, 5.0 to give it some breathing room compared to the previous value?

On Windows, the test takes timeout * 2 seconds (4 seconds) to complete :-( I don't understand why the client doesn't fail with a timeout error as soon as the server closes its listener connection!? Right now, I prefer to fix the known issues of dangling threads before attempting to change this fragile timeout.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's also fair to just skip the test on windows if reliability is platform specific. the primary reproducer only happens on Linux anyways.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior on Windows is very different than on Linux. On Windows, test_https_client_non_tls_response_ignored() takes 4 seconds (server times out, then client times out). On Linux, it completes in 100 ms.

I didn't spend too much time to try to understand why/how. I prefer to continue running the test on Windows, unless there is a good reason to no do so. With my change, the test is reliable on Windows.

server = self.SingleConnectionTestServerThread(
call_after_accept=call_after_accept,
name="non_tls_http_RST_responder")
name="non_tls_http_RST_responder",
timeout=timeout)
self.enterContext(server) # starts it & unittest.TestCase stops it.
# Redundant; call_after_accept sets SO_LINGER on the accepted conn.
set_socket_so_linger_on_with_zero_timeout(server.listener)

connection = SynchronizedHTTPSConnection(
f"localhost",
server.listener.getsockname()[0],
port=server.port,
context=ssl.create_default_context(),
timeout=2.0,
timeout=timeout,
)

# There are lots of reasons this raises as desired, long before this
# test was added. Sending the request requires a successful TLS wrapped
# socket; that fails if the connection is broken. It may seem pointless
# to test this. It serves as an illustration of something that we never
# want to happen... properly not happening.
with self.assertRaises(OSError) as err_ctx:
with self.assertRaises(OSError):
connection.request("HEAD", "/test", headers={"Host": "localhost"})
response = connection.getresponse()

server.join()


class TestEnumerations(unittest.TestCase):

Expand Down
Loading