From f17165703002d456f7fc4b66e2b5e7410963751c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 24 Nov 2022 22:29:52 +0000 Subject: [PATCH] [App] Add utility to get install command for package extras (#15809) --- src/lightning_app/utilities/imports.py | 24 +++++++++++++++++++++++ tests/tests_app/utilities/test_imports.py | 12 +++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/lightning_app/utilities/imports.py b/src/lightning_app/utilities/imports.py index c44cae515fb00..60747ff3624c0 100644 --- a/src/lightning_app/utilities/imports.py +++ b/src/lightning_app/utilities/imports.py @@ -17,6 +17,30 @@ from typing import List, Union from lightning_utilities.core.imports import module_available +from packaging.requirements import Marker, Requirement + +try: + from importlib import metadata +except ImportError: + # Python < 3.8 + import importlib_metadata as metadata # type: ignore + + +def _get_extras(extras: str) -> str: + """Get the given extras as a space delimited string. + + Used by the platform to install cloud extras in the cloud. + """ + from lightning_app import __package_name__ + + requirements = {r: Requirement(r) for r in metadata.requires(__package_name__)} + marker = Marker(f'extra == "{extras}"') + requirements = [r for r, req in requirements.items() if str(req.marker) == str(marker)] + + if requirements: + requirements = [f"'{r.split(';')[0].strip()}'" for r in requirements] + return " ".join(requirements) + return "" def requires(module_paths: Union[str, List]): diff --git a/tests/tests_app/utilities/test_imports.py b/tests/tests_app/utilities/test_imports.py index 00a24d41a09f0..f9f36c007625a 100644 --- a/tests/tests_app/utilities/test_imports.py +++ b/tests/tests_app/utilities/test_imports.py @@ -3,7 +3,17 @@ import pytest -from lightning_app.utilities.imports import requires +from lightning_app import __package_name__ +from lightning_app.utilities.imports import _get_extras, requires + + +def test_get_extras(): + extras = "app-cloud" if __package_name__ == "lightning" else "cloud" + extras = _get_extras(extras) + assert "docker" in extras + assert "redis" in extras + + assert _get_extras("fake-extras") == "" @mock.patch.dict(os.environ, {"LIGHTING_TESTING": "0"})