diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 1ddd5ee1e7a22..08957e4652f54 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -337,6 +337,20 @@ def get_token(self): return self.token +def _server_thread(server_base, args, ctor_kwargs, uri, + auth_handler, tls_certificates): + """Helper function to run the Flight server as a subprocess.""" + server_instance = server_base(*args, **ctor_kwargs) + # Accept URI because flight Location is unpicklable (for + # multiprocessing) + location = flight.Location(uri) + server_instance.run( + location, + auth_handler=auth_handler, + tls_certificates=tls_certificates, + ) + + @contextlib.contextmanager def flight_server(server_base, *args, **kwargs): """Spawn a Flight server on a free port, shutting it down when done.""" @@ -361,14 +375,18 @@ def flight_server(server_base, *args, **kwargs): port = None ctor_kwargs = kwargs - server_instance = server_base(*args, **ctor_kwargs) - def _server_thread(): - server_instance.run( - location, - auth_handler=auth_handler, - tls_certificates=tls_certificates, - ) + server_process = multiprocessing.Process( + target=_server_thread, + args=( + server_base, + args, + ctor_kwargs, + location.uri, + auth_handler, + tls_certificates, + ), + ) thread = threading.Thread(target=_server_thread, daemon=True) thread.start() @@ -698,6 +716,7 @@ def test_flight_do_put_metadata(): assert idx == server_idx +@pytest.mark.slow def test_cancel_do_get(): """Test canceling a DoGet operation on the client side.""" with flight_server(ConstantFlightServer) as server_location: @@ -708,6 +727,7 @@ def test_cancel_do_get(): reader.read_next_batch() +@pytest.mark.slow def test_cancel_do_get_threaded(): """Test canceling a DoGet operation from another thread.""" with flight_server(SlowFlightServer) as server_location: