diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py index d7ba2a4f88102..04afdb0b4f92c 100644 --- a/src/lightning/app/utilities/network.py +++ b/src/lightning/app/utilities/network.py @@ -89,18 +89,32 @@ def _find_free_network_port_cloudspace(): _DEFAULT_REQUEST_TIMEOUT = 30 # seconds +def create_retry_strategy(): + return Retry( + # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1)) + # but the the maximum wait time is 120 secs. By setting a large value (2880), we'll make sure clients + # are going to be alive for a very long time (~ 4 days) but retries every 120 seconds + total=_CONNECTION_RETRY_TOTAL, + backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR, + status_forcelist={ + 408, # Request Timeout + 429, # Too Many Requests + *range(500, 600), # Any 5xx Server Error status + }, + allowed_methods={ + "POST", # Default methods are idempotent, add POST here + *Retry.DEFAULT_ALLOWED_METHODS, + }, + ) + + def _configure_session() -> Session: """Configures the session for GET and POST requests. It enables a generous retrial strategy that waits for the application server to connect. """ - retry_strategy = Retry( - # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1)) - total=_CONNECTION_RETRY_TOTAL, - backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR, - status_forcelist=[429, 500, 502, 503, 504], - ) + retry_strategy = create_retry_strategy() adapter = HTTPAdapter(max_retries=retry_strategy) http = requests.Session() http.mount("https://", adapter) @@ -157,21 +171,7 @@ def __init__( self, base_url: str, auth_token: Optional[str] = None, log_callback: Optional[Callable] = None ) -> None: self.base_url = base_url - retry_strategy = Retry( - # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1)) - # but the the maximum wait time is 120 secs. By setting a large value (2880), we'll make sure clients - # are going to be alive for a very long time (~ 4 days) but retries every 120 seconds - total=_CONNECTION_RETRY_TOTAL, - backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR, - status_forcelist=[ - 408, # Request Timeout - 429, # Too Many Requests - 500, # Internal Server Error - 502, # Bad Gateway - 503, # Service Unavailable - 504, # Gateway Timeout - ], - ) + retry_strategy = create_retry_strategy() adapter = CustomRetryAdapter(max_retries=retry_strategy, timeout=_DEFAULT_REQUEST_TIMEOUT) self.session = requests.Session() diff --git a/tests/tests_app/cli/test_cmd_install.py b/tests/tests_app/cli/test_cmd_install.py index 5fddaea097fb3..93df54cbc8389 100644 --- a/tests/tests_app/cli/test_cmd_install.py +++ b/tests/tests_app/cli/test_cmd_install.py @@ -9,6 +9,7 @@ from lightning.app.testing.helpers import _RunIf +@pytest.mark.xfail(strict=False, reason="lightning app cli was deprecated") @mock.patch("lightning.app.cli.cmd_install.subprocess", mock.MagicMock()) def test_valid_org_app_name(): """Valid organization name.""" @@ -69,6 +70,7 @@ def test_app_install(tmpdir, monkeypatch): assert test_app_pip_name in str(new_env_output), f"{test_app_pip_name} should be in the env" +@pytest.mark.xfail(strict=False, reason="lightning app cli was deprecated") @mock.patch("lightning.app.cli.cmd_install.subprocess", mock.MagicMock()) def test_valid_org_component_name(): runner = CliRunner() @@ -135,6 +137,7 @@ def test_component_install(real_component, test_component_pip_name): ), f"{test_component_pip_name} should not be in the env after cleanup" +@pytest.mark.xfail(strict=False, reason="lightning app cli was deprecated") def test_prompt_actions(): # TODO: each of these installs must check that a package is installed in the environment correctly app_to_use = "lightning/invideo" @@ -164,6 +167,7 @@ def test_prompt_actions(): # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use], input='') +@pytest.mark.xfail(strict=False, reason="lightning app cli was deprecated") @mock.patch("lightning.app.cli.cmd_install.subprocess", mock.MagicMock()) def test_version_arg_component(tmpdir, monkeypatch): monkeypatch.chdir(tmpdir) @@ -186,6 +190,7 @@ def test_version_arg_component(tmpdir, monkeypatch): assert result.exit_code == 0 +@pytest.mark.xfail(strict=False, reason="lightning app cli was deprecated") @mock.patch("lightning.app.cli.cmd_install.subprocess", mock.MagicMock()) @mock.patch("lightning.app.cli.cmd_install.os.chdir", mock.MagicMock()) def test_version_arg_app(tmpdir): @@ -237,6 +242,7 @@ def test_install_resolve_latest_version(mock_show_install_app_prompt, tmpdir): assert mock_show_install_app_prompt.call_args[0][0]["version"] == "0.0.4" +@pytest.mark.xfail(strict=False, reason="lightning app cli was deprecated") def test_proper_url_parsing(): name = "lightning/invideo" @@ -311,12 +317,14 @@ def test_install_app_shows_error(tmpdir): # os.chdir(cwd) +@pytest.mark.xfail(strict=False, reason="lightning app cli was deprecated") def test_app_and_component_gallery_app(monkeypatch): monkeypatch.setattr(cmd_install, "_install_app_from_source", mock.MagicMock()) path = cmd_install.gallery_apps_and_components("lightning/flashy", True, "latest") assert path == os.path.join(os.getcwd(), "app.py") +@pytest.mark.xfail(strict=False, reason="lightning app cli was deprecated") def test_app_and_component_gallery_component(monkeypatch): monkeypatch.setattr(cmd_install, "_install_app_from_source", mock.MagicMock()) path = cmd_install.gallery_apps_and_components("lightning/lit-jupyter", True, "latest") diff --git a/tests/tests_app/utilities/test_network.py b/tests/tests_app/utilities/test_network.py index e3ccaf662d57d..38c8961919db6 100644 --- a/tests/tests_app/utilities/test_network.py +++ b/tests/tests_app/utilities/test_network.py @@ -1,8 +1,9 @@ +from http.client import HTTPMessage from unittest import mock import pytest from lightning.app.core import constants -from lightning.app.utilities.network import find_free_network_port +from lightning.app.utilities.network import HTTPClient, find_free_network_port def test_find_free_network_port(): @@ -42,3 +43,41 @@ def test_find_free_network_port_cloudspace(_, patch_constants): # Shouldn't use the APP_SERVER_PORT assert constants.APP_SERVER_PORT not in ports + + +@mock.patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") +def test_http_client_retry_post(getconn_mock): + getconn_mock.return_value.getresponse.side_effect = [ + mock.Mock(status=500, msg=HTTPMessage()), + mock.Mock(status=429, msg=HTTPMessage()), + mock.Mock(status=200, msg=HTTPMessage()), + ] + + client = HTTPClient(base_url="http://test.url") + r = client.post("/test") + r.raise_for_status() + + assert getconn_mock.return_value.request.mock_calls == [ + mock.call("POST", "/test", body=None, headers=mock.ANY), + mock.call("POST", "/test", body=None, headers=mock.ANY), + mock.call("POST", "/test", body=None, headers=mock.ANY), + ] + + +@mock.patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") +def test_http_client_retry_get(getconn_mock): + getconn_mock.return_value.getresponse.side_effect = [ + mock.Mock(status=500, msg=HTTPMessage()), + mock.Mock(status=429, msg=HTTPMessage()), + mock.Mock(status=200, msg=HTTPMessage()), + ] + + client = HTTPClient(base_url="http://test.url") + r = client.get("/test") + r.raise_for_status() + + assert getconn_mock.return_value.request.mock_calls == [ + mock.call("GET", "/test", body=None, headers=mock.ANY), + mock.call("GET", "/test", body=None, headers=mock.ANY), + mock.call("GET", "/test", body=None, headers=mock.ANY), + ]