diff --git a/.actions/setup_tools.py b/.actions/setup_tools.py index d467c0f3ba037..47eaddac3a832 100644 --- a/.actions/setup_tools.py +++ b/.actions/setup_tools.py @@ -191,7 +191,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme load_requirements(d, file_name="base.txt", unfreeze=not freeze_requirements) for d in glob.glob(os.path.join(req_dir, "*")) # skip empty folder as git artefacts, and resolving Will's special issue - if os.path.isdir(d) and len(glob.glob(os.path.join(d, "*"))) > 0 + if os.path.isdir(d) and len(glob.glob(os.path.join(d, "*"))) > 0 and "__pycache__" not in d ] if not requires: return None diff --git a/.azure/app-cloud-e2e.yml b/.azure/app-cloud-e2e.yml index 7904fdc0980d6..0475d2c827c27 100644 --- a/.azure/app-cloud-e2e.yml +++ b/.azure/app-cloud-e2e.yml @@ -52,7 +52,7 @@ jobs: - job: App_cloud_e2e_testing pool: azure-cpus container: - image: mcr.microsoft.com/playwright/python:v1.27.1-focal + image: mcr.microsoft.com/playwright/python:v1.28.0-focal options: "--shm-size=4gb" strategy: matrix: diff --git a/.azure/hpu-tests.yml b/.azure/hpu-tests.yml index 899b247a37f04..3bc949a72d63e 100644 --- a/.azure/hpu-tests.yml +++ b/.azure/hpu-tests.yml @@ -40,7 +40,7 @@ jobs: cancelTimeoutInMinutes: "2" pool: intel-hpus container: - image: "vault.habana.ai/gaudi-docker/1.7.0/ubuntu20.04/habanalabs/pytorch-installer-1.12.0:latest" + image: "vault.habana.ai/gaudi-docker/1.7.1/ubuntu20.04/habanalabs/pytorch-installer-1.13.0:latest" options: "--runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host --shm-size=4g -v /usr/bin/docker:/tmp/docker:ro" workspace: clean: all diff --git a/.github/actions/pkg-publish/action.yml b/.github/actions/pkg-publish/action.yml index b9362784ae6c5..beef2d34f7db9 100644 --- a/.github/actions/pkg-publish/action.yml +++ b/.github/actions/pkg-publish/action.yml @@ -30,7 +30,7 @@ runs: if: inputs.pypi-test-token != '' with: user: __token__ - password: ${{ secrets.test_pypi_token_lai }} + password: ${{ inputs.pypi-test-token }} repository_url: https://test.pypi.org/legacy/ packages_dir: pypi/ verbose: true diff --git a/.github/workflows/ci-app-examples.yml b/.github/workflows/ci-app-examples.yml index b1a79ea50d9bc..96c85388ea133 100644 --- a/.github/workflows/ci-app-examples.yml +++ b/.github/workflows/ci-app-examples.yml @@ -3,9 +3,9 @@ name: Test App - examples # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows on: push: - branches: [master, "release/*"] + branches: [master, "release/*", "lite/debug"] pull_request: - branches: [master, "release/*"] + branches: [master, "release/*", "lite/debug"] types: [opened, reopened, ready_for_review, synchronize] # added `ready_for_review` since draft is skipped paths: - ".actions/**" diff --git a/.github/workflows/ci-app-tests.yml b/.github/workflows/ci-app-tests.yml index d19a408309bc4..780c8e0fa203a 100644 --- a/.github/workflows/ci-app-tests.yml +++ b/.github/workflows/ci-app-tests.yml @@ -3,9 +3,9 @@ name: Test App # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows on: push: - branches: [master, "release/*"] + branches: [master, "release/*", "lite/debug"] pull_request: - branches: [master, "release/*"] + branches: [master, "release/*", "lite/debug"] types: [opened, reopened, ready_for_review, synchronize] # added `ready_for_review` since draft is skipped paths: - ".actions/**" @@ -94,7 +94,7 @@ jobs: - name: Adjust tests if: ${{ matrix.pkg-name == 'lightning' }} - run: python .actions/assistant.py copy_replace_imports --source_dir="./tests" --source_import="lightning_app" --target_import="lightning.app" + run: python .actions/assistant.py copy_replace_imports --source_dir="./tests" --source_import="lightning_app,lightning_lite,pytorch_lightning" --target_import="lightning.app,lightning.lite,lightning.pytorch" - name: Adjust examples if: ${{ matrix.pkg-name != 'lightning' }} diff --git a/.github/workflows/ci-pytorch-dockers.yml b/.github/workflows/ci-pytorch-dockers.yml index 9719cf3701835..065f5b71d4bd4 100644 --- a/.github/workflows/ci-pytorch-dockers.yml +++ b/.github/workflows/ci-pytorch-dockers.yml @@ -2,9 +2,9 @@ name: Docker on: push: - branches: [master, "release/*"] + branches: [master, "release/*", "lite/debug"] pull_request: - branches: [master, "release/*"] + branches: [master, "release/*", "lite/debug"] types: [opened, reopened, ready_for_review, synchronize] # added `ready_for_review` since draft is skipped paths: - ".actions/**" diff --git a/.github/workflows/legacy-checkpoints.yml b/.github/workflows/legacy-checkpoints.yml index 0531f2e72c957..b27ec472c8791 100644 --- a/.github/workflows/legacy-checkpoints.yml +++ b/.github/workflows/legacy-checkpoints.yml @@ -72,6 +72,7 @@ jobs: working-directory: ./ env: PACKAGE_NAME: pytorch + FREEZE_REQUIREMENTS: 1 run: | pip install . -f https://download.pytorch.org/whl/cpu/torch_stable.html pip list diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml index 5c8bdd5b3d9b5..3a17745c93f43 100644 --- a/.github/workflows/release-pypi.yml +++ b/.github/workflows/release-pypi.yml @@ -11,9 +11,6 @@ defaults: run: shell: bash -env: - PUBLISH: ${{ startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' }} - jobs: init: runs-on: ubuntu-20.04 @@ -184,7 +181,7 @@ jobs: publish-packages: runs-on: ubuntu-20.04 - needs: waiting + needs: [build-packages, waiting] if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' steps: - uses: actions/checkout@v3 @@ -215,8 +212,8 @@ jobs: needs: [build-packages] uses: ./.github/workflows/legacy-checkpoints.yml with: - push_to_s3: ${{ env.PUBLISH }} - create_pr: ${{ env.PUBLISH }} + push_to_s3: ${{ startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' }} + create_pr: ${{ startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' }} secrets: AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_KEY_ID: ${{ secrets.AWS_SECRET_KEY_ID }} diff --git a/dockers/base-xla/tpu_workflow_lite.jsonnet b/dockers/base-xla/tpu_workflow_fabric.jsonnet similarity index 100% rename from dockers/base-xla/tpu_workflow_lite.jsonnet rename to dockers/base-xla/tpu_workflow_fabric.jsonnet diff --git a/dockers/ci-runner-hpu/Dockerfile b/dockers/ci-runner-hpu/Dockerfile index db9f9113126af..4e9a191efd318 100644 --- a/dockers/ci-runner-hpu/Dockerfile +++ b/dockers/ci-runner-hpu/Dockerfile @@ -16,8 +16,8 @@ # gaudi-docker-agent:latest ARG DIST="latest" -ARG GAUDI_VERSION="1.7.0" -ARG PYTORCH_INSTALLER_VERSION="1.12.0" +ARG GAUDI_VERSION="1.7.1" +ARG PYTORCH_INSTALLER_VERSION="1.13.0" FROM vault.habana.ai/gaudi-docker/${GAUDI_VERSION}/ubuntu20.04/habanalabs/pytorch-installer-${PYTORCH_INSTALLER_VERSION}:${DIST} LABEL maintainer="https://vault.habana.ai/" diff --git a/docs/source-app/api_reference/runners.rst b/docs/source-app/api_reference/runners.rst index 3040d3adde36c..1036df1731eb8 100644 --- a/docs/source-app/api_reference/runners.rst +++ b/docs/source-app/api_reference/runners.rst @@ -18,5 +18,4 @@ ______________ :template: classtemplate.rst ~cloud.CloudRuntime - ~singleprocess.SingleProcessRuntime ~multiprocess.MultiProcessRuntime diff --git a/docs/source-app/api_reference/storage.rst b/docs/source-app/api_reference/storage.rst index 5bcdb0973dad6..4d125b80ae244 100644 --- a/docs/source-app/api_reference/storage.rst +++ b/docs/source-app/api_reference/storage.rst @@ -20,6 +20,7 @@ ______________ ~path.Path ~drive.Drive ~payload.Payload + ~mount.Mount ---- @@ -56,6 +57,14 @@ Learn more about Storage :height: 180 :tag: Intermediate +.. displayitem:: + :header: The Mount Object. + :description: Mount an AWS S3 Bucket When Running on the Cloud. + :col_css: col-md-4 + :button_link: ../workflows/mount_aws_s3_bucket.html + :height: 180 + :tag: Intermediate + .. raw:: html diff --git a/docs/source-app/api_references.rst b/docs/source-app/api_references.rst index 30e0ade3a25ad..2272f7bf13c41 100644 --- a/docs/source-app/api_references.rst +++ b/docs/source-app/api_references.rst @@ -32,11 +32,20 @@ ___________________ :nosignatures: :template: classtemplate_no_index.rst + ~database.client.DatabaseClient + ~database.server.Database ~python.popen.PopenPythonScript ~python.tracer.TracerPythonScript ~training.LightningTrainerScript ~serve.gradio.ServeGradio ~serve.serve.ModelInferenceAPI + ~serve.python_server.PythonServer + ~serve.streamlit.ServeStreamlit + ~multi_node.base.MultiNode + ~multi_node.lite.LiteMultiNode + ~multi_node.pytorch_spawn.PyTorchSpawnMultiNode + ~multi_node.trainer.LightningTrainerMultiNode + ~auto_scaler.AutoScaler ---- @@ -71,6 +80,7 @@ _______ ~path.Path ~drive.Drive ~payload.Payload + ~mount.Mount Learn more about :ref:`Storage `. @@ -87,5 +97,19 @@ _______ :template: classtemplate_no_index.rst ~cloud.CloudRuntime - ~singleprocess.SingleProcessRuntime ~multiprocess.MultiProcessRuntime + +---- + +lightning_app.utilities.packaging +_________________________________ + +.. currentmodule:: lightning_app.utilities.packaging + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate_no_index.rst + + ~cloud_compute.CloudCompute + ~build_config.BuildConfig diff --git a/docs/source-app/glossary/index.rst b/docs/source-app/glossary/index.rst index 9652106e2aad1..a46da20ddd118 100644 --- a/docs/source-app/glossary/index.rst +++ b/docs/source-app/glossary/index.rst @@ -112,6 +112,13 @@ Glossary :button_link: ../core_api/lightning_app/index.html :height: 100 +.. displayitem:: + :header: Mounts + :description: Mount Cloud Data + :col_css: col-md-6 + :button_link: mount.html + :height: 180 + .. displayitem:: :header: Sharing Components :description: Let's create an ecosystem altogether diff --git a/docs/source-app/glossary/mount.rst b/docs/source-app/glossary/mount.rst new file mode 100644 index 0000000000000..a62d72b5b798d --- /dev/null +++ b/docs/source-app/glossary/mount.rst @@ -0,0 +1 @@ +.. include:: ../workflows/mount_cloud_object_store.rst diff --git a/docs/source-app/testing.rst b/docs/source-app/testing.rst index 6d0fe71832a7e..da52727cbde0d 100644 --- a/docs/source-app/testing.rst +++ b/docs/source-app/testing.rst @@ -120,7 +120,6 @@ We provide ``application_testing`` as a helper funtion to get your application u os.path.join(_PROJECT_ROOT, "examples/app_v0/app.py"), "--blocking", "False", - "--multiprocess", "--open-ui", "False", ] @@ -129,9 +128,7 @@ First in the list for ``command_line`` is the location of your script. It is an Next there are a couple of options you can leverage: - * ``blocking`` - Blocking is an app status that says "Do not run until I click run in the UI". For our integration test, since we are not using the UI, we are setting this to "False". -* ``multiprocess/singleprocess`` - This is the runtime your app is expected to run under. * ``open-ui`` - We set this to false since this is the routine that opens a browser for your local execution. Once you have your commandline ready, you will then be able to kick off the test and gather results: diff --git a/docs/source-app/workflows/index.rst b/docs/source-app/workflows/index.rst index 314593c2aae23..801c6ccecb9a0 100644 --- a/docs/source-app/workflows/index.rst +++ b/docs/source-app/workflows/index.rst @@ -187,6 +187,14 @@ How to: :button_link: ssh/index.html :height: 180 +.. displayitem:: + :header: Mount Cloud Data + :description: Learn how Lightning Mounts are used to make the contents of an cloud object store bucket available on disk when running in the cloud. + :col_css: col-md-4 + :button_link: mount_cloud_object_store.html + :height: 180 + + .. raw:: html diff --git a/docs/source-app/workflows/mount_cloud_object_store.rst b/docs/source-app/workflows/mount_cloud_object_store.rst new file mode 100644 index 0000000000000..4dfbe7bb37132 --- /dev/null +++ b/docs/source-app/workflows/mount_cloud_object_store.rst @@ -0,0 +1,141 @@ +:orphan: + +############## +Add Cloud Data +############## + +**Audience:** Users who want to read files stored in a Cloud Object Bucket in an app. + +****************************** +Mounting Public AWS S3 Buckets +****************************** + +=================== +Add Mount to a Work +=================== + +To mount data from a cloud bucket to your app compute, initialize a :class:`~lightning_app.storage.mount.Mount` +object with the source path of the s3 bucket and the absolute directory path where it should be mounted and +pass the :class:`~lightning_app.storage.mount.Mount` to the :class:`~lightning_app.utilities.packaging.cloud_compute.CloudCompute` +of the :class:`~lightning_app.core.work.LightningWork` it should be mounted on. + +In this example, we will mount an S3 bucket: ``s3://ryft-public-sample-data/esRedditJson/`` to ``/content/esRedditJson/``. + +.. code-block:: python + + from lightning_app import CloudCompute + from lightning_app.storage import Mount + + self.my_work = MyWorkClass( + cloud_compute=CloudCompute( + mounts=Mount( + source="s3://ryft-public-sample-data/esRedditJson/", + mount_path="/content/esRedditJson/", + ), + ) + ) + +You can also pass multiple mounts to a single work by passing a ``List[Mount(...), ...]`` to the +``CloudCompute(mounts=...)`` argument. + +.. note:: + + * Mounts supported up to 1 Million files, 5GB per file. Need larger mounts? Contact support@lightning.ai + * When adding multiple mounts, each one should have a unique ``mount_path``. + * A maximum of 10 :class:`~lightning_app.storage.mount.Mount`\s can be added to a :class:`~lightning_app.core.work.LightningWork`. + +======================= +Read Files From a Mount +======================= + +Once a :class:`~lightning_app.storage.mount.Mount` object is passed to :class:`~lightning_app.utilities.packaging.cloud_compute.CloudCompute`, +you can access, list, or read any file from the mount under the specified ``mount_path``, just like you would if it +was on your local machine. + +Assuming your ``mount_path`` is ``"/content/esRedditJson/"`` you can do the following: + +---------- +Read Files +---------- + +.. code-block:: python + + with open("/content/esRedditJson/esRedditJson1", "r") as f: + some_data = f.read() + + # do something with "some_data"... + +---------- +List Files +---------- + +.. code-block:: python + + files = os.listdir("/content/esRedditJson/") + +-------------------- +See the Full Example +-------------------- + +.. code-block:: python + :emphasize-lines: 10,15 + + import os + + import lightning as L + from lightning_app import CloudCompute + from lightning_app.storage import Mount + + class ReadMount(L.LightningWork): + def run(self): + # Print a list of files stored in the mounted S3 Bucket. + files = os.listdir("/content/esRedditJson/") + for file in files: + print(file) + + # Read the contents of a particular file in the bucket "esRedditJson1" + with open("/content/esRedditJson/esRedditJson1", "r") as f: + some_data = f.read() + # do something with "some_data"... + + class Flow(L.LightningFlow): + def __init__(self): + super().__init__() + self.my_work = ReadMount( + cloud_compute=CloudCompute( + mounts=Mount( + source="s3://ryft-public-sample-data/esRedditJson/", + mount_path="/content/esRedditJson/", + ), + ) + ) + + def run(self): + self.my_work.run() + +.. note:: + + When running a Lighting App on your local machine, any :class:`~lightning_app.utilities.packaging.cloud_compute.CloudCompute` + configuration (including a :class:`~lightning_app.storage.mount.Mount`) is ignored at runtime. If you need access to + these files on your local disk, you should download a copy of them to your machine. + +.. note:: + + Mounted files from an S3 bucket are ``read-only``. Any modifications, additions, or deletions + to files in the mounted directory will not be reflected in the cloud object store. + +---- + +********************************************** +Mounting Private AWS S3 Buckets - Coming Soon! +********************************************** + +We'll Let you know when this feature is ready! + +---- + +************************************************ +Mounting Google Cloud GCS Buckets - Coming Soon! +************************************************ + +We'll Let you know when this feature is ready! diff --git a/docs/source-pytorch/accelerators/hpu_basic.rst b/docs/source-pytorch/accelerators/hpu_basic.rst index 2fa20dd9354f2..2ee36fee2361d 100644 --- a/docs/source-pytorch/accelerators/hpu_basic.rst +++ b/docs/source-pytorch/accelerators/hpu_basic.rst @@ -113,4 +113,3 @@ Known limitations ----------------- * `Habana dataloader `__ is not supported. -* :func:`torch.inference_mode` is not supported diff --git a/docs/source-pytorch/accelerators/hpu_intermediate.rst b/docs/source-pytorch/accelerators/hpu_intermediate.rst index a1395b0e2c183..3ef5a6f2bb485 100644 --- a/docs/source-pytorch/accelerators/hpu_intermediate.rst +++ b/docs/source-pytorch/accelerators/hpu_intermediate.rst @@ -96,4 +96,4 @@ The below snippet shows how DeviceStatsMonitor can be enabled. device_stats = DeviceStatsMonitor() trainer = Trainer(accelerator="hpu", callbacks=[device_stats]) -For more details, please refer to `Memory Stats APIs `__. +For more details, please refer to `Memory Stats APIs `__. diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index d2c86db5bacf7..db5605619ad43 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -424,10 +424,9 @@ You can customize the strategy configuration by adjusting the arguments of :clas from pytorch_lightning import Trainer from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True)) + native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=True) trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4) diff --git a/docs/source-pytorch/clouds/cluster_intermediate_1.rst b/docs/source-pytorch/clouds/cluster_intermediate_1.rst index c2d92e200bde0..89d60a85534b2 100644 --- a/docs/source-pytorch/clouds/cluster_intermediate_1.rst +++ b/docs/source-pytorch/clouds/cluster_intermediate_1.rst @@ -24,7 +24,7 @@ PyTorch Lightning follows the design of `PyTorch distributed communication packa - *MASTER_PORT* - required; has to be a free port on machine with NODE_RANK 0 - *MASTER_ADDR* - required (except for NODE_RANK 0); address of NODE_RANK 0 node -- *WORLD_SIZE* - required; how many nodes are in the cluster +- *WORLD_SIZE* - required; the total number of GPUs/processes that you will use - *NODE_RANK* - required; id of the node in the cluster .. _training_script_setup: diff --git a/examples/app_dag/requirements.txt b/examples/app_dag/requirements.txt index 101182e0cd9ab..f669f518e7389 100644 --- a/examples/app_dag/requirements.txt +++ b/examples/app_dag/requirements.txt @@ -1,2 +1,2 @@ -sklearn +scikit-learn pandas diff --git a/examples/app_installation_commands/app.py b/examples/app_installation_commands/app.py index 9eb1c2944ee2e..087d84b1335b2 100644 --- a/examples/app_installation_commands/app.py +++ b/examples/app_installation_commands/app.py @@ -13,6 +13,10 @@ def run(self): print("lmdb successfully installed") print("accessing a module in a Work or Flow body works!") + @property + def ready(self) -> bool: + return True + print(f"accessing an object in main code body works!: version={lmdb.version()}") diff --git a/examples/app_mount/app.py b/examples/app_mount/app.py index 11da2f02552d8..d0d2adf3e0759 100644 --- a/examples/app_mount/app.py +++ b/examples/app_mount/app.py @@ -32,4 +32,4 @@ def run(self): self.work_1.run() -app = L.LightningApp(Flow(), log_level="debug") +app = L.LightningApp(Flow()) diff --git a/examples/app_server_with_auto_scaler/app.py b/examples/app_server_with_auto_scaler/app.py new file mode 100644 index 0000000000000..b713bd6d1dcfc --- /dev/null +++ b/examples/app_server_with_auto_scaler/app.py @@ -0,0 +1,86 @@ +from typing import Any, List + +import torch +import torchvision +from pydantic import BaseModel + +import lightning as L + + +class RequestModel(BaseModel): + image: str # bytecode + + +class BatchRequestModel(BaseModel): + inputs: List[RequestModel] + + +class BatchResponse(BaseModel): + outputs: List[Any] + + +class PyTorchServer(L.app.components.PythonServer): + def __init__(self, *args, **kwargs): + super().__init__( + port=L.app.utilities.network.find_free_network_port(), + input_type=BatchRequestModel, + output_type=BatchResponse, + cloud_compute=L.CloudCompute("gpu"), + ) + + def setup(self): + self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self._model = torchvision.models.resnet18(pretrained=True).to(self._device) + + def predict(self, requests: BatchRequestModel): + transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + images = [] + for request in requests.inputs: + image = L.app.components.serve.types.image.Image.deserialize(request.image) + image = transforms(image).unsqueeze(0) + images.append(image) + images = torch.cat(images) + images = images.to(self._device) + predictions = self._model(images) + results = predictions.argmax(1).cpu().numpy().tolist() + return BatchResponse(outputs=[{"prediction": pred} for pred in results]) + + +class MyAutoScaler(L.app.components.AutoScaler): + def scale(self, replicas: int, metrics: dict) -> int: + """The default scaling logic that users can override.""" + # scale out if the number of pending requests exceeds max batch size. + max_requests_per_work = self.max_batch_size + pending_requests_per_running_or_pending_work = metrics["pending_requests"] / ( + replicas + metrics["pending_works"] + ) + if pending_requests_per_running_or_pending_work >= max_requests_per_work: + return replicas + 1 + + # scale in if the number of pending requests is below 25% of max_requests_per_work + min_requests_per_work = max_requests_per_work * 0.25 + pending_requests_per_running_work = metrics["pending_requests"] / replicas + if pending_requests_per_running_work < min_requests_per_work: + return replicas - 1 + + return replicas + + +app = L.LightningApp( + MyAutoScaler( + PyTorchServer, + min_replicas=2, + max_replicas=4, + autoscale_interval=10, + endpoint="predict", + input_type=RequestModel, + output_type=Any, + timeout_batching=1, + ) +) diff --git a/examples/app_template_streamlit_ui/app.py b/examples/app_template_streamlit_ui/app.py index 6f344ac98eb8d..b6fc604222ce2 100644 --- a/examples/app_template_streamlit_ui/app.py +++ b/examples/app_template_streamlit_ui/app.py @@ -1,8 +1,8 @@ import logging -from lightning_app import LightningApp, LightningFlow -from lightning_app.frontend import StreamlitFrontend -from lightning_app.utilities.state import AppState +from lightning.app import LightningApp, LightningFlow +from lightning.app.frontend import StreamlitFrontend +from lightning.app.utilities.state import AppState logger = logging.getLogger(__name__) @@ -45,4 +45,4 @@ def configure_layout(self): return [{"name": "StreamLitUI", "content": self.streamlit_ui}] -app = LightningApp(HelloWorld(), log_level="debug") +app = LightningApp(HelloWorld()) diff --git a/pyproject.toml b/pyproject.toml index 191dad051fdc4..1f23e7a63e545 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,14 +30,12 @@ files = [ "src/lightning_fabric", "src/lightning_app", ] +# This section is for folders with "-" as they are not valid python modules exclude = [ + "src/lightning_app/cli/app-template", "src/lightning_app/cli/component-template", "src/lightning_app/cli/pl-app-template", "src/lightning_app/cli/react-ui-template", - "src/lightning_app/cli/app-template", - "src/lightning_app/components/database", - "src/lightning_app/components/multi_node", - "src/lightning_app/frontend/just_py/just_py", ] install_types = "True" non_interactive = "True" @@ -67,7 +65,9 @@ module = [ "lightning_app.api.request_types", "lightning_app.cli.commands.app_commands", "lightning_app.cli.commands.connection", - "lightning_app.cli.react-ui-template.example_app", + "lightning_app.cli.commands.lightning_cli", + "lightning_app.cli.commands.cmd_install", + "lightning_app.cli.cmd_install", "lightning_app.components.database.client", "lightning_app.components.database.server", "lightning_app.components.database.utilities", @@ -80,6 +80,7 @@ module = [ "lightning_app.components.serve.types.type", "lightning_app.components.serve.python_server", "lightning_app.components.training", + "lightning_app.components.auto_scaler", "lightning_app.core.api", "lightning_app.core.app", "lightning_app.core.flow", @@ -93,6 +94,7 @@ module = [ "lightning_app.frontend.streamlit_base", "lightning_app.frontend.utils", "lightning_app.frontend.web", + "lightning_app.perf.pdb", "lightning_app.runners.backends.__init__", "lightning_app.runners.backends.backend", "lightning_app.runners.backends.cloud", @@ -101,7 +103,6 @@ module = [ "lightning_app.runners.cloud", "lightning_app.runners.multiprocess", "lightning_app.runners.runtime", - "lightning_app.runners.singleprocess", "lightning_app.source_code.copytree", "lightning_app.source_code.hashing", "lightning_app.source_code.local", diff --git a/requirements/app/base.txt b/requirements/app/base.txt index 872fa2f4c84b1..37e91689bb54d 100644 --- a/requirements/app/base.txt +++ b/requirements/app/base.txt @@ -7,8 +7,9 @@ fsspec>=2022.5.0, <=2022.7.1 croniter>=1.3.0, <1.4.0 # strict; TODO: for now until we find something more robust. traitlets>=5.3.0, <=5.4.0 arrow>=1.2.0, <1.2.4 -lightning-utilities>=0.3.*, !=0.4.0, <0.5.0 +lightning-utilities>=0.3.0, !=0.4.0, <0.5.0 beautifulsoup4>=4.8.0, <4.11.2 inquirer>=2.10.0 psutil<5.9.4 click<=8.1.3 +aiohttp>=3.8.0, <=3.8.3 diff --git a/requirements/app/test.txt b/requirements/app/test.txt index 2d9a28947162d..5d000ce1ef625 100644 --- a/requirements/app/test.txt +++ b/requirements/app/test.txt @@ -4,7 +4,7 @@ pytest==7.2.0 pytest-timeout==2.1.0 pytest-cov==4.0.0 pytest-doctestplus>=0.9.0 -playwright==1.27.1 +playwright==1.28.0 httpx trio<0.22.0 pympler diff --git a/requirements/app/ui.txt b/requirements/app/ui.txt index fa051d284f0b8..6e73b96c317d3 100644 --- a/requirements/app/ui.txt +++ b/requirements/app/ui.txt @@ -1,2 +1,2 @@ -streamlit>=1.3.1, <=1.11.1 +streamlit>=1.0.0, <=1.15.2 panel>=0.12.7, <=0.13.1 diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index 45fd511d315e5..dc7fd6ea578da 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment numpy>=1.17.2, <1.23.1 -torch>=1.10.*, <=1.13.0 +torch>=1.10.0, <=1.13.0 fsspec[http]>2021.06.0, <2022.6.0 packaging>=17.0, <=21.3 typing-extensions>=4.0.0, <=4.4.0 diff --git a/requirements/fabric/examples.txt b/requirements/fabric/examples.txt index ebf32e2c2c48d..43bb03e07cc80 100644 --- a/requirements/fabric/examples.txt +++ b/requirements/fabric/examples.txt @@ -1,4 +1,4 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torchvision>=0.10.*, <=0.13.0 +torchvision>=0.10.0, <=0.13.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 483460f88ff11..31163ecb602b7 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment numpy>=1.17.2, <1.23.1 -torch>=1.10.*, <=1.13.0 +torch>=1.10.0, <=1.13.0 tqdm>=4.57.0, <4.65.0 PyYAML>=5.4, <=6.0 fsspec[http]>2021.06.0, <2022.8.0 diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt index c749c83faedb9..7e02a2f4bea99 100644 --- a/requirements/pytorch/examples.txt +++ b/requirements/pytorch/examples.txt @@ -1,6 +1,6 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torchvision>=0.11.*, <=0.14.0 +torchvision>=0.11.1, <=0.14.0 gym[classic_control]>=0.17.0, <0.26.3 ipython[all] <8.6.1 diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index ec7812e6e7417..889bf82b8e42a 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added the CLI command `lightning run model` to launch a `LightningLite` accelerated script ([#15506](https://github.com/Lightning-AI/lightning/pull/15506)) + - Added the CLI command `lightning delete app` to delete a lightning app on the cloud ([#15783](https://github.com/Lightning-AI/lightning/pull/15783)) - Show a message when `BuildConfig(requirements=[...])` is passed but a `requirements.txt` file is already present in the Work ([#15799](https://github.com/Lightning-AI/lightning/pull/15799)) @@ -17,8 +18,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a CloudMultiProcessBackend which enables running a child App from within the Flow in the cloud ([#15800](https://github.com/Lightning-AI/lightning/pull/15800)) +- Utility for pickling work object safely even from a child process ([#15836](https://github.com/Lightning-AI/lightning/pull/15836)) + +- Added `AutoScaler` component ([#15769](https://github.com/Lightning-AI/lightning/pull/15769)) + - Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921)) +- Added private work attributed `_start_method` to customize how to start the works ([#15923](https://github.com/Lightning-AI/lightning/pull/15923)) + +- Added a `configure_layout` method to the `LightningWork` which can be used to control how the work is handled in the layout of a parent flow ([#15926](https://github.com/Lightning-AI/lightning/pull/15926)) + +- Added the ability to run a Lightning App or Component directly from the Gallery using `lightning run app organization/name` ([#15941](https://github.com/Lightning-AI/lightning/pull/15941)) + +- Added automatic conversion of list and dict of works and flows to structures ([#15961](https://github.com/Lightning-AI/lightning/pull/15961)) + ### Changed @@ -38,21 +51,32 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -- +- Removed the `SingleProcessRuntime` ([#15933](https://github.com/Lightning-AI/lightning/pull/15933)) ### Fixed - Fixed SSH CLI command listing stopped components ([#15810](https://github.com/Lightning-AI/lightning/pull/15810)) +- Fixed MPS error for multinode component (defaults to cpu on mps devices now as distributed operations are not supported by pytorch on mps) ([#15748](https://github.com/Ligtning-AI/lightning/pull/15748)) - Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801)) - - Fixed the `enable_spawn` method of the `WorkRunExecutor` ([#15812](https://github.com/Lightning-AI/lightning/pull/15812) + - Fixed Sigterm Handler causing thread lock which caused KeyboardInterrupt to hang ([#15881](https://github.com/Lightning-AI/lightning/pull/15881)) +- Fixed a bug where using `L.app.structures` would cause multiple apps to be opened and fail with an error in the cloud ([#15911](https://github.com/Lightning-AI/lightning/pull/15911)) + +- Fixed PythonServer generating noise on M1 ([#15949](https://github.com/Lightning-AI/lightning/pull/15949)) + +- Fixed `ImportError` on Multinode if package not present ([#15963](https://github.com/Lightning-AI/lightning/pull/15963)) + +- Fixed multiprocessing breakpoint ([#15950](https://github.com/Lightning-AI/lightning/pull/15950)) + +- Fixed detection of a Lightning App running in debug mode ([#15951](https://github.com/Lightning-AI/lightning/pull/15951)) + ## [1.8.3] - 2022-11-22 @@ -100,6 +124,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642)) - Fixed CloudRuntime works collection with structures and accelerated multi node startup time ([#15650](https://github.com/Lightning-AI/lightning/pull/15650)) - Fixed catimage import ([#15712](https://github.com/Lightning-AI/lightning/pull/15712)) +- Fixed setting property to the LightningFlow ([#15750](https://github.com/Lightning-AI/lightning/pull/15750)) - Parse all lines in app file looking for shebangs to run commands ([#15714](https://github.com/Lightning-AI/lightning/pull/15714)) diff --git a/src/lightning_app/cli/cmd_install.py b/src/lightning_app/cli/cmd_install.py index db0467212f147..579a921179b4c 100644 --- a/src/lightning_app/cli/cmd_install.py +++ b/src/lightning_app/cli/cmd_install.py @@ -5,6 +5,7 @@ import sys from typing import Dict, Optional, Tuple +import click import requests from packaging.version import Version @@ -14,7 +15,117 @@ logger = Logger(__name__) -def gallery_component(name: str, yes_arg: bool, version_arg: str, cwd: str = None) -> None: +@click.group(name="install") +def install() -> None: + """Install Lightning AI selfresources.""" + pass + + +@install.command("app") +@click.argument("name", type=str) +@click.option( + "--yes", + "-y", + is_flag=True, + help="disables prompt to ask permission to create env and run install cmds", +) +@click.option( + "--version", + "-v", + type=str, + help="Specify the version to install. By default it uses 'latest'", + default="latest", + show_default=True, +) +@click.option( + "--overwrite", + "-f", + is_flag=True, + default=False, + help="When set, overwrite the app directory without asking if it already exists.", +) +def install_app(name: str, yes: bool, version: str, overwrite: bool = False) -> None: + _install_app_command(name, yes, version, overwrite=overwrite) + + +@install.command("component") +@click.argument("name", type=str) +@click.option( + "--yes", + "-y", + is_flag=True, + help="disables prompt to ask permission to create env and run install cmds", +) +@click.option( + "--version", + "-v", + type=str, + help="Specify the version to install. By default it uses 'latest'", + default="latest", + show_default=True, +) +def install_component(name: str, yes: bool, version: str) -> None: + _install_component_command(name, yes, version) + + +def _install_app_command(name: str, yes: bool, version: str, overwrite: bool = False) -> None: + if "github.com" in name: + if version != "latest": + logger.warn( + "When installing from GitHub, only the 'latest' version is supported. " + f"The provided version ({version}) will be ignored." + ) + return non_gallery_app(name, yes, overwrite=overwrite) + else: + return gallery_app(name, yes, version, overwrite=overwrite) + + +def _install_component_command(name: str, yes: bool, version: str, overwrite: bool = False) -> None: + if "github.com" in name: + if version != "latest": + logger.warn( + "When installing from GitHub, only the 'latest' version is supported. " + f"The provided version ({version}) will be ignored." + ) + return non_gallery_component(name, yes) + else: + return gallery_component(name, yes, version) + + +def gallery_apps_and_components( + name: str, yes_arg: bool, version_arg: str, cwd: str = None, overwrite: bool = False +) -> Optional[str]: + + try: + org, app_or_component = name.split("/") + except Exception: + return None + + entry, kind = _resolve_entry(app_or_component, version_arg) + + if kind == "app": + # give the user the chance to do a manual install + source_url, git_url, folder_name, git_sha = _show_install_app_prompt( + entry, app_or_component, org, yes_arg, resource_type="app" + ) + # run installation if requested + _install_app_from_source(source_url, git_url, folder_name, cwd=cwd, overwrite=overwrite, git_sha=git_sha) + + return os.path.join(os.getcwd(), folder_name, entry["appEntrypointFile"]) + + elif kind == "component": + # give the user the chance to do a manual install + git_url = _show_install_component_prompt(entry, app_or_component, org, yes_arg) + + # run installation if requested + _install_component_from_source(git_url) + + return os.path.join(os.getcwd(), entry["appEntrypointFile"]) + + return None + + +def gallery_component(name: str, yes_arg: bool, version_arg: str, cwd: str = None) -> str: # make sure org/component-name name is correct org, component = _validate_name(name, resource_type="component", example="lightning/LAI-slack-component") @@ -28,7 +139,9 @@ def gallery_component(name: str, yes_arg: bool, version_arg: str, cwd: str = Non git_url = _show_install_component_prompt(component_entry, component, org, yes_arg) # run installation if requested - _install_component(git_url) + _install_component_from_source(git_url) + + return os.path.join(os.getcwd(), component_entry["entrypointFile"]) def non_gallery_component(gh_url: str, yes_arg: bool, cwd: str = None) -> None: @@ -37,10 +150,10 @@ def non_gallery_component(gh_url: str, yes_arg: bool, cwd: str = None) -> None: git_url = _show_non_gallery_install_component_prompt(gh_url, yes_arg) # run installation if requested - _install_component(git_url) + _install_component_from_source(git_url) -def gallery_app(name: str, yes_arg: bool, version_arg: str, cwd: str = None, overwrite: bool = False) -> None: +def gallery_app(name: str, yes_arg: bool, version_arg: str, cwd: str = None, overwrite: bool = False) -> str: # make sure org/app-name syntax is correct org, app = _validate_name(name, resource_type="app", example="lightning/quick-start") @@ -57,7 +170,9 @@ def gallery_app(name: str, yes_arg: bool, version_arg: str, cwd: str = None, ove ) # run installation if requested - _install_app(source_url, git_url, folder_name, cwd=cwd, overwrite=overwrite, git_sha=git_sha) + _install_app_from_source(source_url, git_url, folder_name, cwd=cwd, overwrite=overwrite, git_sha=git_sha) + + return os.path.join(os.getcwd(), folder_name, app_entry["appEntrypointFile"]) def non_gallery_app(gh_url: str, yes_arg: bool, cwd: str = None, overwrite: bool = False) -> None: @@ -66,7 +181,7 @@ def non_gallery_app(gh_url: str, yes_arg: bool, cwd: str = None, overwrite: bool repo_url, folder_name = _show_non_gallery_install_app_prompt(gh_url, yes_arg) # run installation if requested - _install_app(repo_url, repo_url, folder_name, cwd=cwd, overwrite=overwrite) + _install_app_from_source(repo_url, repo_url, folder_name, cwd=cwd, overwrite=overwrite) def _show_install_component_prompt(entry: Dict[str, str], component: str, org: str, yes_arg: bool) -> str: @@ -299,7 +414,35 @@ def _validate_name(name: str, resource_type: str, example: str) -> Tuple[str, st return org, resource -def _resolve_resource(registry_url: str, name: str, version_arg: str, resource_type: str) -> Dict[str, str]: +def _resolve_entry(name, version_arg) -> Tuple[Optional[Dict], Optional[str]]: + entry = None + kind = None + + # resolve registry (orgs can have a private registry through their environment variables) + registry_url = _resolve_app_registry() + + # load the app resource + entry = _resolve_resource(registry_url, name=name, version_arg=version_arg, resource_type="app", raise_error=False) + + if not entry: + + registry_url = _resolve_component_registry() + + # load the component resource + entry = _resolve_resource( + registry_url, name=name, version_arg=version_arg, resource_type="component", raise_error=False + ) + kind = "component" if entry else None + + else: + kind = "app" + + return entry, kind + + +def _resolve_resource( + registry_url: str, name: str, version_arg: str, resource_type: str, raise_error: bool = True +) -> Dict[str, str]: gallery_entries = [] try: response = requests.get(registry_url) @@ -327,7 +470,10 @@ def _resolve_resource(registry_url: str, name: str, version_arg: str, resource_t all_versions.append(x["version"]) if len(entries) == 0: - raise SystemExit(f"{resource_type}: '{name}' is not available on ⚡ Lightning AI ⚡") + if raise_error: + raise SystemExit(f"{resource_type}: '{name}' is not available on ⚡ Lightning AI ⚡") + else: + return None entry = None if version_arg == "latest": @@ -337,11 +483,14 @@ def _resolve_resource(registry_url: str, name: str, version_arg: str, resource_t if e["version"] == version_arg: entry = e break - if entry is None: - raise Exception( - f"{resource_type}: 'Version {version_arg} for {name}' is not available on ⚡ Lightning AI ⚡. " - f"Here is the list of all availables versions:{os.linesep}{os.linesep.join(all_versions)}" - ) + if entry is None and raise_error: + if raise_error: + raise Exception( + f"{resource_type}: 'Version {version_arg} for {name}' is not available on ⚡ Lightning AI ⚡. " + f"Here is the list of all availables versions:{os.linesep}{os.linesep.join(all_versions)}" + ) + else: + return None return entry @@ -381,7 +530,7 @@ def _install_with_env(repo_url: str, folder_name: str, cwd: str = None) -> None: logger.info(m) -def _install_app( +def _install_app_from_source( source_url: str, git_url: str, folder_name: str, cwd: str = None, overwrite: bool = False, git_sha: str = None ) -> None: """Installing lighting app from the `git_url` @@ -458,7 +607,7 @@ def _install_app( logger.info(m) -def _install_component(git_url: str) -> None: +def _install_component_from_source(git_url: str) -> None: logger.info("⚡ RUN: pip install") out = subprocess.check_output(["pip", "install", git_url]) diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py index 59a4fbddc889b..cef527eb5fa67 100644 --- a/src/lightning_app/cli/lightning_cli.py +++ b/src/lightning_app/cli/lightning_cli.py @@ -232,7 +232,14 @@ def _run_app( secret: tuple, run_app_comment_commands: bool, ) -> None: - file = _prepare_file(file) + + if not os.path.exists(file): + original_file = file + file = cmd_install.gallery_apps_and_components(file, True, "latest", overwrite=False) # type: ignore[assignment] # noqa E501 + if file is None: + click.echo(f"The provided entrypoint `{original_file}` doesn't exist.") + sys.exit(1) + run_app_comment_commands = True if not cloud and cluster_id is not None: raise click.ClickException("Using the flag --cluster-id in local execution is not supported.") @@ -288,7 +295,7 @@ def run() -> None: @run.command("app") -@click.argument("file", type=click.Path(exists=True)) +@click.argument("file", type=str) @click.option("--cloud", type=bool, default=False, is_flag=True) @click.option( "--cluster-id", @@ -361,6 +368,7 @@ def run_app( _main.add_command(get_list) _main.add_command(delete) _main.add_command(create) +_main.add_command(cmd_install.install) @_main.command("ssh") @@ -444,74 +452,6 @@ def ssh(app_name: str = None, component_name: str = None) -> None: os.execv(ssh_path, ["-tt", f"{component_id}@{ssh_endpoint}"]) -@_main.group() -def install() -> None: - """Install a Lightning App and/or component.""" - - -@install.command("app") -@click.argument("name", type=str) -@click.option( - "--yes", - "-y", - is_flag=True, - help="disables prompt to ask permission to create env and run install cmds", -) -@click.option( - "--version", - "-v", - type=str, - help="Specify the version to install. By default it uses 'latest'", - default="latest", - show_default=True, -) -@click.option( - "--overwrite", - "-f", - is_flag=True, - default=False, - help="When set, overwrite the app directory without asking if it already exists.", -) -def install_app(name: str, yes: bool, version: str, overwrite: bool = False) -> None: - if "github.com" in name: - if version != "latest": - logger.warn( - f"The provided version {version} isn't the officially supported one. " - f"The provided version will be ignored." - ) - cmd_install.non_gallery_app(name, yes, overwrite=overwrite) - else: - cmd_install.gallery_app(name, yes, version, overwrite=overwrite) - - -@install.command("component") -@click.argument("name", type=str) -@click.option( - "--yes", - "-y", - is_flag=True, - help="disables prompt to ask permission to create env and run install cmds", -) -@click.option( - "--version", - "-v", - type=str, - help="Specify the version to install. By default it uses 'latest'", - default="latest", - show_default=True, -) -def install_component(name: str, yes: bool, version: str) -> None: - if "github.com" in name: - if version != "latest": - logger.warn( - f"The provided version {version} isn't the officially supported one. " - f"The provided version will be ignored." - ) - cmd_install.non_gallery_component(name, yes) - else: - cmd_install.gallery_component(name, yes, version) - - @_main.group() def init() -> None: """Init a Lightning App and/or component.""" diff --git a/src/lightning_app/components/__init__.py b/src/lightning_app/components/__init__.py index ee52fb55670f2..ca47c36071dae 100644 --- a/src/lightning_app/components/__init__.py +++ b/src/lightning_app/components/__init__.py @@ -1,3 +1,4 @@ +from lightning_app.components.auto_scaler import AutoScaler from lightning_app.components.database.client import DatabaseClient from lightning_app.components.database.server import Database from lightning_app.components.multi_node import ( @@ -15,6 +16,7 @@ from lightning_app.components.training import LightningTrainerScript, PyTorchLightningScriptRunner __all__ = [ + "AutoScaler", "DatabaseClient", "Database", "PopenPythonScript", diff --git a/src/lightning_app/components/auto_scaler.py b/src/lightning_app/components/auto_scaler.py new file mode 100644 index 0000000000000..62e6180c49665 --- /dev/null +++ b/src/lightning_app/components/auto_scaler.py @@ -0,0 +1,566 @@ +import asyncio +import logging +import os +import secrets +import time +import uuid +from base64 import b64encode +from itertools import cycle +from typing import Any, Dict, List, Tuple, Type + +import aiohttp +import aiohttp.client_exceptions +import requests +import uvicorn +from fastapi import Depends, FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from pydantic import BaseModel +from starlette.status import HTTP_401_UNAUTHORIZED + +from lightning_app.core.flow import LightningFlow +from lightning_app.core.work import LightningWork +from lightning_app.utilities.app_helpers import Logger +from lightning_app.utilities.packaging.cloud_compute import CloudCompute + +logger = Logger(__name__) + + +def _raise_granular_exception(exception: Exception) -> None: + """Handle an exception from hitting the model servers.""" + if not isinstance(exception, Exception): + return + + if isinstance(exception, HTTPException): + raise exception + + if isinstance(exception, aiohttp.client_exceptions.ServerDisconnectedError): + raise HTTPException(500, "Worker Server Disconnected") from exception + + if isinstance(exception, aiohttp.client_exceptions.ClientError): + logging.exception(exception) + raise HTTPException(500, "Worker Server error") from exception + + if isinstance(exception, asyncio.TimeoutError): + raise HTTPException(408, "Request timed out") from exception + + if isinstance(exception, Exception): + if exception.args[0] == "Server disconnected": + raise HTTPException(500, "Worker Server disconnected") from exception + + logging.exception(exception) + raise HTTPException(500, exception.args[0]) from exception + + +class _SysInfo(BaseModel): + num_workers: int + servers: List[str] + num_requests: int + processing_time: int + global_request_count: int + + +class _BatchRequestModel(BaseModel): + inputs: List[Any] + + +def _create_fastapi(title: str) -> FastAPI: + fastapi_app = FastAPI(title=title) + + fastapi_app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + fastapi_app.global_request_count = 0 + fastapi_app.num_current_requests = 0 + fastapi_app.last_processing_time = 0 + + @fastapi_app.get("/", include_in_schema=False) + async def docs(): + return RedirectResponse("/docs") + + @fastapi_app.get("/num-requests") + async def num_requests() -> int: + return fastapi_app.num_current_requests + + return fastapi_app + + +class _LoadBalancer(LightningWork): + r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton API + asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests. + + The LoadBalancer exposes system endpoints with a basic HTTP authentication, in order to activate the authentication + you need to provide a system password from environment variable:: + + lightning run app app.py --env AUTO_SCALER_AUTH_PASSWORD=PASSWORD + + After enabling you will require to send username and password from the request header for the private endpoints. + + Args: + input_type: Input type. + output_type: Output type. + endpoint: The REST API path. + max_batch_size: The number of requests processed at once. + timeout_batching: The number of seconds to wait before sending the requests to process in order to allow for + requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached. + timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received. + timeout_inference_request: The number of seconds to wait for inference. + \**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc. + """ + + def __init__( + self, + input_type: BaseModel, + output_type: BaseModel, + endpoint: str, + max_batch_size: int = 8, + # all timeout args are in seconds + timeout_batching: int = 1, + timeout_keep_alive: int = 60, + timeout_inference_request: int = 60, + **kwargs: Any, + ) -> None: + super().__init__(cloud_compute=CloudCompute("default"), **kwargs) + self._input_type = input_type + self._output_type = output_type + self._timeout_keep_alive = timeout_keep_alive + self._timeout_inference_request = timeout_inference_request + self.servers = [] + self.max_batch_size = max_batch_size + self.timeout_batching = timeout_batching + self._iter = None + self._batch = [] + self._responses = {} # {request_id: response} + self._last_batch_sent = 0 + + if not endpoint.startswith("/"): + endpoint = "/" + endpoint + + self.endpoint = endpoint + + async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]): + server = next(self._iter) # round-robin + request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch] + batch_request_data = _BatchRequestModel(inputs=request_data) + + try: + async with aiohttp.ClientSession() as session: + headers = { + "accept": "application/json", + "Content-Type": "application/json", + } + async with session.post( + f"{server}{self.endpoint}", + json=batch_request_data.dict(), + timeout=self._timeout_inference_request, + headers=headers, + ) as response: + if response.status == 408: + raise HTTPException(408, "Request timed out") + response.raise_for_status() + response = await response.json() + outputs = response["outputs"] + if len(batch) != len(outputs): + raise RuntimeError(f"result has {len(outputs)} items but batch is {len(batch)}") + result = {request[0]: r for request, r in zip(batch, outputs)} + self._responses.update(result) + except Exception as ex: + result = {request[0]: ex for request in batch} + self._responses.update(result) + + async def consumer(self): + while True: + await asyncio.sleep(0.05) + + batch = self._batch[: self.max_batch_size] + while batch and ( + (len(batch) == self.max_batch_size) or ((time.time() - self._last_batch_sent) > self.timeout_batching) + ): + asyncio.create_task(self.send_batch(batch)) + + self._batch = self._batch[self.max_batch_size :] + batch = self._batch[: self.max_batch_size] + self._last_batch_sent = time.time() + + async def process_request(self, data: BaseModel): + if not self.servers: + raise HTTPException(500, "None of the workers are healthy!") + + request_id = uuid.uuid4().hex + request: Tuple = (request_id, data) + self._batch.append(request) + + while True: + await asyncio.sleep(0.05) + + if request_id in self._responses: + result = self._responses[request_id] + del self._responses[request_id] + _raise_granular_exception(result) + return result + + def run(self): + logger.info(f"servers: {self.servers}") + lock = asyncio.Lock() + + self._iter = cycle(self.servers) + self._last_batch_sent = time.time() + + fastapi_app = _create_fastapi("Load Balancer") + security = HTTPBasic() + fastapi_app.SEND_TASK = None + + @fastapi_app.middleware("http") + async def current_request_counter(request: Request, call_next): + if not request.scope["path"] == self.endpoint: + return await call_next(request) + fastapi_app.global_request_count += 1 + fastapi_app.num_current_requests += 1 + start_time = time.time() + response = await call_next(request) + processing_time = time.time() - start_time + fastapi_app.last_processing_time = processing_time + fastapi_app.num_current_requests -= 1 + return response + + @fastapi_app.on_event("startup") + async def startup_event(): + fastapi_app.SEND_TASK = asyncio.create_task(self.consumer()) + + @fastapi_app.on_event("shutdown") + def shutdown_event(): + fastapi_app.SEND_TASK.cancel() + + def authenticate_private_endpoint(credentials: HTTPBasicCredentials = Depends(security)): + AUTO_SCALER_AUTH_PASSWORD = os.environ.get("AUTO_SCALER_AUTH_PASSWORD", "") + if len(AUTO_SCALER_AUTH_PASSWORD) == 0: + logger.warn( + "You have not set a password for private endpoints! To set a password, add " + "`--env AUTO_SCALER_AUTH_PASSWORD=` to your lightning run command." + ) + current_password_bytes = credentials.password.encode("utf8") + is_correct_password = secrets.compare_digest( + current_password_bytes, AUTO_SCALER_AUTH_PASSWORD.encode("utf8") + ) + if not is_correct_password: + raise HTTPException( + status_code=401, + detail="Incorrect password", + headers={"WWW-Authenticate": "Basic"}, + ) + return True + + @fastapi_app.get("/system/info", response_model=_SysInfo) + async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint)): + return _SysInfo( + num_workers=len(self.servers), + servers=self.servers, + num_requests=fastapi_app.num_current_requests, + processing_time=fastapi_app.last_processing_time, + global_request_count=fastapi_app.global_request_count, + ) + + @fastapi_app.put("/system/update-servers") + async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)): + async with lock: + self.servers = servers + self._iter = cycle(self.servers) + + @fastapi_app.post(self.endpoint, response_model=self._output_type) + async def balance_api(inputs: self._input_type): + return await self.process_request(inputs) + + uvicorn.run( + fastapi_app, + host=self.host, + port=self.port, + loop="uvloop", + timeout_keep_alive=self._timeout_keep_alive, + access_log=False, + ) + + def update_servers(self, server_works: List[LightningWork]): + """Updates works that load balancer distributes requests to. + + AutoScaler uses this method to increase/decrease the number of works. + """ + old_servers = set(self.servers) + server_urls: List[str] = [server.url for server in server_works if server.url] + new_servers = set(server_urls) + + if new_servers == old_servers: + return + + if new_servers - old_servers: + logger.info(f"servers added: {new_servers - old_servers}") + + deleted_servers = old_servers - new_servers + if deleted_servers: + logger.info(f"servers deleted: {deleted_servers}") + + self.send_request_to_update_servers(server_urls) + + def send_request_to_update_servers(self, servers: List[str]): + AUTHORIZATION_TYPE = "Basic" + USERNAME = "lightning" + AUTO_SCALER_AUTH_PASSWORD = os.environ.get("AUTO_SCALER_AUTH_PASSWORD", "") + + try: + param = f"{USERNAME}:{AUTO_SCALER_AUTH_PASSWORD}".encode() + data = b64encode(param).decode("utf-8") + except (ValueError, UnicodeDecodeError) as e: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Basic"}, + ) from e + headers = { + "accept": "application/json", + "username": USERNAME, + "Authorization": AUTHORIZATION_TYPE + " " + data, + } + response = requests.put(f"{self.url}/system/update-servers", json=servers, headers=headers, timeout=10) + response.raise_for_status() + + +class AutoScaler(LightningFlow): + """The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in + response to changes in the number of incoming requests. Incoming requests will be batched and balanced across + the replicas. + + Args: + min_replicas: The number of works to start when app initializes. + max_replicas: The max number of works to spawn to handle the incoming requests. + autoscale_interval: The number of seconds to wait before checking whether to upscale or downscale the works. + endpoint: Provide the REST API path. + max_batch_size: (auto-batching) The number of requests to process at once. + timeout_batching: (auto-batching) The number of seconds to wait before sending the requests to process. + input_type: Input type. + output_type: Output type. + + .. testcode:: + + import lightning as L + + # Example 1: Auto-scaling serve component out-of-the-box + app = L.LightningApp( + L.app.components.AutoScaler( + MyPythonServer, + min_replicas=1, + max_replicas=8, + autoscale_interval=10, + ) + ) + + # Example 2: Customizing the scaling logic + class MyAutoScaler(L.app.components.AutoScaler): + def scale(self, replicas: int, metrics: dict) -> int: + pending_requests_per_running_or_pending_work = metrics["pending_requests"] / ( + replicas + metrics["pending_works"] + ) + + # upscale + max_requests_per_work = self.max_batch_size + if pending_requests_per_running_or_pending_work >= max_requests_per_work: + return replicas + 1 + + # downscale + min_requests_per_work = max_requests_per_work * 0.25 + if pending_requests_per_running_or_pending_work < min_requests_per_work: + return replicas - 1 + + return replicas + + + app = L.LightningApp( + MyAutoScaler( + MyPythonServer, + min_replicas=1, + max_replicas=8, + autoscale_interval=10, + max_batch_size=8, # for auto batching + timeout_batching=1, # for auto batching + ) + ) + """ + + def __init__( + self, + work_cls: Type[LightningWork], + min_replicas: int = 1, + max_replicas: int = 4, + autoscale_interval: int = 10, + max_batch_size: int = 8, + timeout_batching: float = 1, + endpoint: str = "api/predict", + input_type: BaseModel = Dict, + output_type: BaseModel = Dict, + *work_args: Any, + **work_kwargs: Any, + ) -> None: + super().__init__() + self.num_replicas = 0 + self._work_registry = {} + + self._work_cls = work_cls + self._work_args = work_args + self._work_kwargs = work_kwargs + + self._input_type = input_type + self._output_type = output_type + self.autoscale_interval = autoscale_interval + self.max_batch_size = max_batch_size + + if max_replicas < min_replicas: + raise ValueError( + f"`max_replicas={max_replicas}` must be less than or equal to `min_replicas={min_replicas}`." + ) + self.max_replicas = max_replicas + self.min_replicas = min_replicas + self._last_autoscale = time.time() + self.fake_trigger = 0 + + self.load_balancer = _LoadBalancer( + input_type=self._input_type, + output_type=self._output_type, + endpoint=endpoint, + max_batch_size=max_batch_size, + timeout_batching=timeout_batching, + cache_calls=True, + parallel=True, + ) + for _ in range(min_replicas): + work = self.create_work() + self.add_work(work) + + @property + def workers(self) -> List[LightningWork]: + return [self.get_work(i) for i in range(self.num_replicas)] + + def create_work(self) -> LightningWork: + """Replicates a LightningWork instance with args and kwargs provided via ``__init__``.""" + # TODO: Remove `start_with_flow=False` for faster initialization on the cloud + return self._work_cls(*self._work_args, **self._work_kwargs, start_with_flow=False) + + def add_work(self, work) -> str: + """Adds a new LightningWork instance. + + Returns: + The name of the new work attribute. + """ + work_attribute = uuid.uuid4().hex + work_attribute = f"worker_{self.num_replicas}_{str(work_attribute)}" + setattr(self, work_attribute, work) + self._work_registry[self.num_replicas] = work_attribute + self.num_replicas += 1 + return work_attribute + + def remove_work(self, index: int) -> str: + """Removes the ``index`` th LightningWork instance.""" + work_attribute = self._work_registry[index] + del self._work_registry[index] + work = getattr(self, work_attribute) + work.stop() + self.num_replicas -= 1 + return work_attribute + + def get_work(self, index: int) -> LightningWork: + """Returns the ``LightningWork`` instance with the given index.""" + work_attribute = self._work_registry[index] + work = getattr(self, work_attribute) + return work + + def run(self): + if not self.load_balancer.is_running: + self.load_balancer.run() + + for work in self.workers: + work.run() + + if self.load_balancer.url: + self.fake_trigger += 1 # Note: change state to keep calling `run`. + self.autoscale() + + def scale(self, replicas: int, metrics: dict) -> int: + """The default scaling logic that users can override. + + Args: + replicas: The number of running works. + metrics: ``metrics['pending_requests']`` is the total number of requests that are currently pending. + ``metrics['pending_works']`` is the number of pending works. + + Returns: + The target number of running works. The value will be adjusted after this method runs + so that it satisfies ``min_replicas<=replicas<=max_replicas``. + """ + pending_requests_per_running_or_pending_work = metrics["pending_requests"] / ( + replicas + metrics["pending_works"] + ) + + # scale out if the number of pending requests exceeds max batch size. + max_requests_per_work = self.max_batch_size + if pending_requests_per_running_or_pending_work >= max_requests_per_work: + return replicas + 1 + + # scale in if the number of pending requests is below 25% of max_requests_per_work + min_requests_per_work = max_requests_per_work * 0.25 + if pending_requests_per_running_or_pending_work < min_requests_per_work: + return replicas - 1 + + return replicas + + @property + def num_pending_requests(self) -> int: + """Fetches the number of pending requests via load balancer.""" + return int(requests.get(f"{self.load_balancer.url}/num-requests").json()) + + @property + def num_pending_works(self) -> int: + """The number of pending works.""" + return sum(work.is_pending for work in self.workers) + + def autoscale(self) -> None: + """Adjust the number of works based on the target number returned by ``self.scale``.""" + if time.time() - self._last_autoscale < self.autoscale_interval: + return + + self.load_balancer.update_servers(self.workers) + + metrics = { + "pending_requests": self.num_pending_requests, + "pending_works": self.num_pending_works, + } + + # ensure min_replicas <= num_replicas <= max_replicas + num_target_workers = max( + self.min_replicas, + min(self.max_replicas, self.scale(self.num_replicas, metrics)), + ) + + # upscale + num_workers_to_add = num_target_workers - self.num_replicas + for _ in range(num_workers_to_add): + logger.info(f"Upscaling from {self.num_replicas} to {self.num_replicas + 1}") + work = self.create_work() + new_work_id = self.add_work(work) + logger.info(f"Work created: '{new_work_id}'") + + # downscale + num_workers_to_remove = self.num_replicas - num_target_workers + for _ in range(num_workers_to_remove): + logger.info(f"Downscaling from {self.num_replicas} to {self.num_replicas - 1}") + removed_work_id = self.remove_work(self.num_replicas - 1) + logger.info(f"Work removed: '{removed_work_id}'") + + self.load_balancer.update_servers(self.workers) + self._last_autoscale = time.time() + + def configure_layout(self): + tabs = [{"name": "Swagger", "content": self.load_balancer.url}] + return tabs diff --git a/src/lightning_app/components/database/server.py b/src/lightning_app/components/database/server.py index 01bd8f3b12033..6d187e4cda133 100644 --- a/src/lightning_app/components/database/server.py +++ b/src/lightning_app/components/database/server.py @@ -19,6 +19,8 @@ if _is_sqlmodel_available(): from sqlmodel import SQLModel +else: + SQLModel = object # Required to avoid Uvicorn Server overriding Lightning App signal handlers. diff --git a/src/lightning_app/components/multi_node/lite.py b/src/lightning_app/components/multi_node/lite.py index 14d6081872bff..d23eb0a72244f 100644 --- a/src/lightning_app/components/multi_node/lite.py +++ b/src/lightning_app/components/multi_node/lite.py @@ -1,4 +1,6 @@ +import importlib import os +import warnings from dataclasses import dataclass from typing import Any, Callable, Type @@ -30,8 +32,19 @@ def run( node_rank: int, nprocs: int, ): - from lightning.fabric import LightningLite - from lightning.fabric.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy + lites = [] + strategies = [] + mps_accelerators = [] + + for pkg_name in ("lightning.fabric", "lightning_" + "fabric"): + try: + pkg = importlib.import_module(pkg_name) + lites.append(pkg.Fabric) + strategies.append(pkg.strategies.DDPSpawnShardedStrategy) + strategies.append(pkg.strategies.DDPSpawnStrategy) + mps_accelerators.append(pkg.accelerators.MPSAccelerator) + except (ImportError, ModuleNotFoundError): + continue # Used to configure PyTorch progress group os.environ["MASTER_ADDR"] = main_address @@ -52,7 +65,15 @@ def run( def pre_fn(lite, *args, **kwargs): kwargs["devices"] = nprocs kwargs["num_nodes"] = num_nodes - kwargs["accelerator"] = "auto" + + if any(acc.is_available() for acc in mps_accelerators): + old_acc_value = kwargs.get("accelerator", "auto") + kwargs["accelerator"] = "cpu" + + if old_acc_value != kwargs["accelerator"]: + warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.") + else: + kwargs["accelerator"] = "auto" strategy = kwargs.get("strategy", None) if strategy: if isinstance(strategy, str): @@ -60,15 +81,20 @@ def pre_fn(lite, *args, **kwargs): strategy = "ddp" elif strategy == "ddp_sharded_spawn": strategy = "ddp_sharded" - elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)): - raise Exception("DDP Spawned strategies aren't supported yet.") + elif isinstance(strategy, tuple(strategies)): + raise ValueError("DDP Spawned strategies aren't supported yet.") + + kwargs["strategy"] = strategy + return {}, args, kwargs tracer = Tracer() - tracer.add_traced(LightningLite, "__init__", pre_fn=pre_fn) + for ll in lites: + tracer.add_traced(ll, "__init__", pre_fn=pre_fn) tracer._instrument() - work_run() + ret_val = work_run() tracer._restore() + return ret_val class LiteMultiNode(MultiNode): diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index 3119ffc51e0b5..013bdbcaec347 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -88,7 +88,7 @@ def run( elif world_size > 1: raise Exception("Torch distributed should be available.") - work_run(world_size, node_rank, global_rank, local_rank) + return work_run(world_size, node_rank, global_rank, local_rank) class PyTorchSpawnMultiNode(MultiNode): diff --git a/src/lightning_app/components/multi_node/trainer.py b/src/lightning_app/components/multi_node/trainer.py index b0e2f96d69660..76d744e24608c 100644 --- a/src/lightning_app/components/multi_node/trainer.py +++ b/src/lightning_app/components/multi_node/trainer.py @@ -1,4 +1,6 @@ +import importlib import os +import warnings from dataclasses import dataclass from typing import Any, Callable, Type @@ -30,9 +32,19 @@ def run( node_rank: int, nprocs: int, ): - from lightning.fabric.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy - from lightning.pytorch import Trainer as LTrainer - from pytorch_lightning import Trainer as PLTrainer + trainers = [] + strategies = [] + mps_accelerators = [] + + for pkg_name in ("lightning.pytorch", "pytorch_" + "lightning"): + try: + pkg = importlib.import_module(pkg_name) + trainers.append(pkg.Trainer) + strategies.append(pkg.strategies.DDPSpawnShardedStrategy) + strategies.append(pkg.strategies.DDPSpawnStrategy) + mps_accelerators.append(pkg.accelerators.MPSAccelerator) + except (ImportError, ModuleNotFoundError): + continue # Used to configure PyTorch progress group os.environ["MASTER_ADDR"] = main_address @@ -50,7 +62,15 @@ def run( def pre_fn(trainer, *args, **kwargs): kwargs["devices"] = nprocs kwargs["num_nodes"] = num_nodes - kwargs["accelerator"] = "auto" + if any(acc.is_available() for acc in mps_accelerators): + old_acc_value = kwargs.get("accelerator", "auto") + kwargs["accelerator"] = "cpu" + + if old_acc_value != kwargs["accelerator"]: + warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.") + else: + kwargs["accelerator"] = "auto" + strategy = kwargs.get("strategy", None) if strategy: if isinstance(strategy, str): @@ -58,16 +78,18 @@ def pre_fn(trainer, *args, **kwargs): strategy = "ddp" elif strategy == "ddp_sharded_spawn": strategy = "ddp_sharded" - elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)): - raise Exception("DDP Spawned strategies aren't supported yet.") + elif isinstance(strategy, tuple(strategies)): + raise ValueError("DDP Spawned strategies aren't supported yet.") + kwargs["strategy"] = strategy return {}, args, kwargs tracer = Tracer() - tracer.add_traced(PLTrainer, "__init__", pre_fn=pre_fn) - tracer.add_traced(LTrainer, "__init__", pre_fn=pre_fn) + for trainer in trainers: + tracer.add_traced(trainer, "__init__", pre_fn=pre_fn) tracer._instrument() - work_run() + ret_val = work_run() tracer._restore() + return ret_val class LightningTrainerMultiNode(MultiNode): diff --git a/src/lightning_app/components/python/tracer.py b/src/lightning_app/components/python/tracer.py index c476f083258fc..d10ca92252ed8 100644 --- a/src/lightning_app/components/python/tracer.py +++ b/src/lightning_app/components/python/tracer.py @@ -22,6 +22,9 @@ class Code(TypedDict): class TracerPythonScript(LightningWork): + + _start_method = "spawn" + def on_before_run(self): """Called before the python script is executed.""" diff --git a/src/lightning_app/components/serve/gradio.py b/src/lightning_app/components/serve/gradio.py index 6e9b1d8777f67..7c07129d39b25 100644 --- a/src/lightning_app/components/serve/gradio.py +++ b/src/lightning_app/components/serve/gradio.py @@ -1,10 +1,8 @@ import abc -import os from functools import partial from types import ModuleType from typing import Any, List, Optional -from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor from lightning_app.core.work import LightningWork from lightning_app.utilities.imports import _is_gradio_available, requires @@ -36,15 +34,13 @@ class ServeGradio(LightningWork, abc.ABC): title: Optional[str] = None description: Optional[str] = None + _start_method = "spawn" + def __init__(self, *args, **kwargs): requires("gradio")(super().__init__(*args, **kwargs)) assert self.inputs assert self.outputs self._model = None - # Note: Enable to run inference on GPUs. - self._run_executor_cls = ( - WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor - ) @property def model(self): @@ -78,3 +74,6 @@ def run(self, *args, **kwargs): server_port=self.port, enable_queue=self.enable_queue, ) + + def configure_layout(self) -> str: + return self.url diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py index 99d51ac1cf4fc..1868b0b357fd3 100644 --- a/src/lightning_app/components/serve/python_server.py +++ b/src/lightning_app/components/serve/python_server.py @@ -1,20 +1,18 @@ import abc import base64 import os +import platform from pathlib import Path from typing import Any, Dict, Optional import uvicorn from fastapi import FastAPI -from lightning_utilities.core.imports import module_available +from lightning_utilities.core.imports import compare_version, module_available from pydantic import BaseModel -from starlette.staticfiles import StaticFiles -from lightning_app.core.queues import MultiProcessQueue from lightning_app.core.work import LightningWork from lightning_app.utilities.app_helpers import Logger from lightning_app.utilities.imports import _is_torch_available, requires -from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver logger = Logger(__name__) @@ -28,44 +26,19 @@ __doctest_skip__ += ["PythonServer", "PythonServer.*"] -class _PyTorchSpawnRunExecutor(WorkRunExecutor): +def _get_device(): + import operator - """This Executor enables to move PyTorch tensors on GPU. + import torch - Without this executor, it would raise the following exception: - RuntimeError: Cannot re-initialize CUDA in forked subprocess. - To use CUDA with multiprocessing, you must use the 'spawn' start method - """ + _TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0") - enable_start_observer: bool = False + local_rank = int(os.getenv("LOCAL_RANK", "0")) - def __call__(self, *args: Any, **kwargs: Any): - import torch - - with self.enable_spawn(): - queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict() - torch.multiprocessing.spawn( - self.dispatch_run, - args=(self.__class__, self.work, queue, args, kwargs), - nprocs=1, - ) - - @staticmethod - def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs): - if local_rank == 0: - if isinstance(delta_queue, dict): - delta_queue = cls.process_queue(delta_queue) - work._request_queue = cls.process_queue(work._request_queue) - work._response_queue = cls.process_queue(work._response_queue) - - state_observer = WorkStateObserver(work, delta_queue=delta_queue) - state_observer.start() - _proxy_setattr(work, delta_queue, state_observer) - - unwrap(work.run)(*args, **kwargs) - - if local_rank == 0: - state_observer.join(0) + if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"): + return torch.device("mps", local_rank) + else: + return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") class _DefaultInputData(BaseModel): @@ -96,6 +69,9 @@ def _get_sample_data() -> Dict[Any, Any]: class PythonServer(LightningWork, abc.ABC): + + _start_method = "spawn" + @requires(["torch", "lightning_api_access"]) def __init__( # type: ignore self, @@ -114,26 +90,26 @@ def __init__( # type: ignore The default data type is good enough for the basic usecases and it expects the data to be a json object that has one key called `payload` - ``` - input_data = {"payload": "some data"} - ``` + .. code-block:: python + + input_data = {"payload": "some data"} and this can be accessed as `request.payload` in the `predict` method. - ``` - def predict(self, request): - data = request.payload - ``` + .. code-block:: python + + def predict(self, request): + data = request.payload output_type: Optional `output_type` to be provided. This needs to be a pydantic BaseModel class. The default data type is good enough for the basic usecases. It expects the return value of the `predict` method to be a dictionary with one key called `prediction`. - ``` - def predict(self, request): - # some code - return {"prediction": "some data"} - ``` + .. code-block:: python + + def predict(self, request): + # some code + return {"prediction": "some data"} and this can be accessed as `response.json()["prediction"]` in the client if you are using requests library @@ -161,11 +137,6 @@ def predict(self, request): self._input_type = input_type self._output_type = output_type - # Note: Enable to run inference on GPUs. - self._run_executor_cls = ( - WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor - ) - def setup(self, *args, **kwargs) -> None: """This method is called before the server starts. Override this if you need to download the model or initialize the weights, setting up pipelines etc. @@ -211,60 +182,44 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict: return out def _attach_predict_fn(self, fastapi_app: FastAPI) -> None: - from torch import inference_mode + from torch import inference_mode, no_grad input_type: type = self.configure_input_type() output_type: type = self.configure_output_type() + device = _get_device() + context = no_grad if device.type == "mps" else inference_mode + def predict_fn(request: input_type): # type: ignore - with inference_mode(): + with context(): return self.predict(request) fastapi_app.post("/predict", response_model=output_type)(predict_fn) - def _attach_frontend(self, fastapi_app: FastAPI) -> None: - from lightning_api_access import APIAccessFrontend - - class_name = self.__class__.__name__ - url = self._future_url if self._future_url else self.url - if not url: - # if the url is still empty, point it to localhost - url = f"http://127.0.0.1:{self.port}" - url = f"{url}/predict" - datatype_parse_error = False - try: - request = self._get_sample_dict_from_datatype(self.configure_input_type()) - except TypeError: - datatype_parse_error = True - - try: - response = self._get_sample_dict_from_datatype(self.configure_output_type()) - except TypeError: - datatype_parse_error = True - - if datatype_parse_error: - - @fastapi_app.get("/") - def index() -> str: - return ( - "Automatic generation of the UI is only supported for simple, " - "non-nested datatype with types string, integer, float and boolean" - ) - - return - - frontend = APIAccessFrontend( - apis=[ - { - "name": class_name, - "url": url, - "method": "POST", - "request": request, - "response": response, - } - ] - ) - fastapi_app.mount("/", StaticFiles(directory=frontend.serve_dir, html=True), name="static") + def configure_layout(self) -> None: + if module_available("lightning_api_access"): + from lightning_api_access import APIAccessFrontend + + class_name = self.__class__.__name__ + url = f"{self.url}/predict" + + try: + request = self._get_sample_dict_from_datatype(self.configure_input_type()) + response = self._get_sample_dict_from_datatype(self.configure_output_type()) + except TypeError: + return None + + return APIAccessFrontend( + apis=[ + { + "name": class_name, + "url": url, + "method": "POST", + "request": request, + "response": response, + } + ] + ) def run(self, *args: Any, **kwargs: Any) -> Any: """Run method takes care of configuring and setting up a FastAPI server behind the scenes. @@ -275,7 +230,6 @@ def run(self, *args: Any, **kwargs: Any) -> Any: fastapi_app = FastAPI() self._attach_predict_fn(fastapi_app) - self._attach_frontend(fastapi_app) logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}") uvicorn.run(app=fastapi_app, host=self.host, port=self.port, log_level="error") diff --git a/src/lightning_app/components/serve/serve.py b/src/lightning_app/components/serve/serve.py index 150ca522e591b..8b6f35364cc38 100644 --- a/src/lightning_app/components/serve/serve.py +++ b/src/lightning_app/components/serve/serve.py @@ -10,7 +10,6 @@ import uvicorn from fastapi import FastAPI from fastapi.responses import JSONResponse -from starlette.responses import RedirectResponse from lightning_app.components.serve.types import _DESERIALIZER, _SERIALIZER from lightning_app.core.work import LightningWork @@ -37,10 +36,6 @@ async def run(self, data) -> Any: return self.serialize(self.predict(self.deserialize(data))) -async def _redirect(): - return RedirectResponse("/docs") - - class ModelInferenceAPI(LightningWork, abc.ABC): def __init__( self, @@ -121,7 +116,6 @@ def run(self): def _populate_app(self, fastapi_service: FastAPI): self._model = self.build_model() - fastapi_service.get("/")(_redirect) fastapi_service.post("/predict", response_class=JSONResponse)( _InferenceCallable( deserialize=_DESERIALIZER[self.input] if self.input else self.deserialize, @@ -134,6 +128,9 @@ def _launch_server(self, fastapi_service: FastAPI): logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}") uvicorn.run(app=fastapi_service, host=self.host, port=self.port, log_level="error") + def configure_layout(self) -> str: + return f"{self.url}/docs" + def _maybe_create_instance() -> Optional[ModelInferenceAPI]: """This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi diff --git a/src/lightning_app/components/serve/streamlit.py b/src/lightning_app/components/serve/streamlit.py index ed543bd1de7b8..1a325d60fecee 100644 --- a/src/lightning_app/components/serve/streamlit.py +++ b/src/lightning_app/components/serve/streamlit.py @@ -63,6 +63,9 @@ def on_exit(self) -> None: if self._process is not None: self._process.kill() + def configure_layout(self) -> str: + return self.url + class _PatchedWork: """The ``_PatchedWork`` is used to emulate a work instance from a subprocess. This is acheived by patching the diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 42cf0f241b47e..d9389ecd27e24 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -21,7 +21,7 @@ FRONTEND_DIR, STATE_ACCUMULATE_WAIT, ) -from lightning_app.core.queues import BaseQueue, SingleProcessQueue +from lightning_app.core.queues import BaseQueue from lightning_app.core.work import LightningWork from lightning_app.frontend import Frontend from lightning_app.storage import Drive, Path, Payload @@ -93,12 +93,10 @@ def __init__( >>> from lightning_app.runners import MultiProcessRuntime >>> class RootFlow(LightningFlow): ... def run(self): - ... print("Hello World!") ... self._exit() ... >>> app = LightningApp(RootFlow()) # application can be dispatched using the `runners`. >>> MultiProcessRuntime(app).dispatch() - Hello World! """ self.root_path = root_path # when running behind a proxy @@ -486,7 +484,15 @@ def _run(self) -> bool: """ self._original_state = deepcopy(self.state) done = False - self.ready = self.root.ready + + # TODO: Re-enable the `ready` property once issues are resolved + if not self.root.ready: + warnings.warn( + "One of your Flows returned `.ready` as `False`. " + "This feature is not yet enabled so this will be ignored.", + UserWarning, + ) + self.ready = True self._start_with_flow_works() @@ -549,8 +555,6 @@ def _collect_work_finish_status(self) -> dict: def _should_snapshot(self) -> bool: if len(self.works) == 0: return True - elif isinstance(self.delta_queue, SingleProcessQueue): - return True elif self._has_updated: work_finished_status = self._collect_work_finish_status() if work_finished_status: diff --git a/src/lightning_app/core/constants.py b/src/lightning_app/core/constants.py index 4038c85e7fc1e..da99db9018320 100644 --- a/src/lightning_app/core/constants.py +++ b/src/lightning_app/core/constants.py @@ -75,3 +75,7 @@ def get_lightning_cloud_url() -> str: def enable_multiple_works_in_default_container() -> bool: return bool(int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0"))) + + +# Number of seconds to wait between filesystem checks when waiting for files in remote storage +REMOTE_STORAGE_WAIT = 0.5 diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py index 72527bf7aee6f..a79794bac3d20 100644 --- a/src/lightning_app/core/flow.py +++ b/src/lightning_app/core/flow.py @@ -10,7 +10,7 @@ from lightning_app.frontend import Frontend from lightning_app.storage import Path from lightning_app.storage.drive import _maybe_create_drive, Drive -from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, _set_child_name +from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, _set_child_name, is_overridden from lightning_app.utilities.component import _sanitize_state from lightning_app.utilities.exceptions import ExitAppException from lightning_app.utilities.introspection import _is_init_context, _is_run_context @@ -142,6 +142,14 @@ def __setattr__(self, name: str, value: Any) -> None: if name in self._works and value != getattr(self, name): raise AttributeError(f"Cannot set attributes as the work can't be changed once defined: {name}") + if isinstance(value, (list, dict)) and value: + _type = (LightningFlow, LightningWork, List, Dict) + if isinstance(value, list) and all(isinstance(va, _type) for va in value): + value = List(*value) + + if isinstance(value, dict) and all(isinstance(va, _type) for va in value.values()): + value = Dict(**value) + if isinstance(value, LightningFlow): self._flows.add(name) _set_child_name(self, value, name) @@ -163,10 +171,10 @@ def __setattr__(self, name: str, value: Any) -> None: value._register_cloud_compute() elif isinstance(value, (Dict, List)): - value._backend = self._backend self._structures.add(name) _set_child_name(self, value, name) - if self._backend: + if getattr(self, "_backend", None) is not None: + value._backend = self._backend for flow in value.flows: LightningFlow._attach_backend(flow, self._backend) for work in value.works: @@ -232,7 +240,10 @@ def __getattr__(self, item): @property def ready(self) -> bool: - """Override to customize when your App should be ready.""" + """Not currently enabled. + + Override to customize when your App should be ready. + """ flows = self.flows return all(flow.ready for flow in flows.values()) if flows else True @@ -763,6 +774,13 @@ def __init__(self, work): super().__init__() self.work = work + @property + def ready(self) -> bool: + ready = getattr(self.work, "ready", None) + if ready: + return ready + return self.work.url != "" + def run(self): if self.work.has_succeeded: self.work.stop() @@ -770,4 +788,6 @@ def run(self): self.work.run() def configure_layout(self): - return [{"name": "Main", "content": self.work}] + if is_overridden("configure_layout", self.work): + return [{"name": "Main", "content": self.work}] + return [] diff --git a/src/lightning_app/core/queues.py b/src/lightning_app/core/queues.py index a7fee9a3b6e12..db150a57eb098 100644 --- a/src/lightning_app/core/queues.py +++ b/src/lightning_app/core/queues.py @@ -49,7 +49,6 @@ class QueuingSystem(Enum): - SINGLEPROCESS = "singleprocess" MULTIPROCESS = "multiprocess" REDIS = "redis" HTTP = "http" @@ -59,10 +58,8 @@ def get_queue(self, queue_name: str) -> "BaseQueue": return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT) elif self == QueuingSystem.REDIS: return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT) - elif self == QueuingSystem.HTTP: - return HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT) else: - return SingleProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT) + return HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT) def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT @@ -179,26 +176,12 @@ def is_running(self) -> bool: return True -class SingleProcessQueue(BaseQueue): - def __init__(self, name: str, default_timeout: float): - self.name = name - self.default_timeout = default_timeout - self.queue = queue.Queue() - - def put(self, item): - self.queue.put(item) - - def get(self, timeout: int = None): - if timeout == 0: - timeout = self.default_timeout - return self.queue.get(timeout=timeout, block=(timeout is None)) - - class MultiProcessQueue(BaseQueue): def __init__(self, name: str, default_timeout: float): self.name = name self.default_timeout = default_timeout - self.queue = multiprocessing.Queue() + context = multiprocessing.get_context("spawn") + self.queue = context.Queue() def put(self, item): self.queue.put(item) diff --git a/src/lightning_app/core/work.py b/src/lightning_app/core/work.py index ab0dc8426ac91..60d1ea62d8afb 100644 --- a/src/lightning_app/core/work.py +++ b/src/lightning_app/core/work.py @@ -1,8 +1,9 @@ +import sys import time import warnings from copy import deepcopy from functools import partial, wraps -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union from deepdiff import DeepHash, Delta @@ -32,6 +33,9 @@ ) from lightning_app.utilities.proxies import Action, LightningWorkSetAttrProxy, ProxyWorkRun, unwrap, WorkRunExecutor +if TYPE_CHECKING: + from lightning_app.frontend import Frontend + class LightningWork: @@ -46,6 +50,8 @@ class LightningWork: ) _run_executor_cls: Type[WorkRunExecutor] = WorkRunExecutor + # TODO: Move to spawn for all Operating System. + _start_method = "spawn" if sys.platform == "win32" else "fork" def __init__( self, @@ -626,3 +632,45 @@ def apply_flow_delta(self, delta: Delta): property_object.fset(self, value) else: self._default_setattr(name, value) + + def configure_layout(self) -> Union[None, str, "Frontend"]: + """Configure the UI of this LightningWork. + + You can either + + 1. Return a single :class:`~lightning_app.frontend.frontend.Frontend` object to serve a user interface + for this Work. + 2. Return a string containing a URL to act as the user interface for this Work. + 3. Return ``None`` to indicate that this Work doesn't currently have a user interface. + + **Example:** Serve a static directory (with at least a file index.html inside). + + .. code-block:: python + + from lightning_app.frontend import StaticWebFrontend + + + class Work(LightningWork): + def configure_layout(self): + return StaticWebFrontend("path/to/folder/to/serve") + + **Example:** Arrange the UI of my children in tabs (default UI by Lightning). + + .. code-block:: python + + class Work(LightningWork): + def configure_layout(self): + return [ + dict(name="First Tab", content=self.child0), + dict(name="Second Tab", content=self.child1), + dict(name="Lightning", content="https://lightning.ai"), + ] + + If you don't implement ``configure_layout``, Lightning will use ``self.url``. + + Note: + This hook gets called at the time of app creation and then again as part of the loop. If desired, a + returned URL can depend on the state. This is not the case if the work returns a + :class:`~lightning_app.frontend.frontend.Frontend`. These need to be provided at the time of app creation + in order for the runtime to start the server. + """ diff --git a/src/lightning_app/perf/pdb.py b/src/lightning_app/perf/pdb.py index 5bd56960715e3..f4b42f96e842d 100644 --- a/src/lightning_app/perf/pdb.py +++ b/src/lightning_app/perf/pdb.py @@ -1,19 +1,36 @@ +import multiprocessing +import os import pdb import sys -from typing import Any +_stdin = [None] +_stdin_lock = multiprocessing.Lock() +try: + _stdin_fd = sys.stdin.fileno() +except Exception: + _stdin_fd = None + +# Taken from https://github.com/facebookresearch/metaseq/blob/main/metaseq/pdb.py class MPPdb(pdb.Pdb): - """debugger for forked programs.""" + """A Pdb wrapper that works in a multiprocessing environment.""" + + def __init__(self) -> None: + pdb.Pdb.__init__(self, nosigint=True) - def interaction(self, *args: Any, **kwargs: Any) -> None: - _stdin = sys.stdin - try: - sys.stdin = open("/dev/stdin") - pdb.Pdb.interaction(self, *args, **kwargs) - finally: - sys.stdin = _stdin + def _cmdloop(self) -> None: + stdin_back = sys.stdin + with _stdin_lock: + try: + if _stdin_fd is not None: + if not _stdin[0]: + _stdin[0] = os.fdopen(_stdin_fd) + sys.stdin = _stdin[0] + self.cmdloop() + finally: + sys.stdin = stdin_back -def set_trace(*args: Any, **kwargs: Any) -> None: - MPPdb().set_trace(*args, **kwargs) +def set_trace() -> None: + pdb = MPPdb() + pdb.set_trace(sys._getframe().f_back) diff --git a/src/lightning_app/runners/__init__.py b/src/lightning_app/runners/__init__.py index e2300663c4930..7749cbbae561e 100644 --- a/src/lightning_app/runners/__init__.py +++ b/src/lightning_app/runners/__init__.py @@ -1,7 +1,6 @@ from lightning_app.runners.cloud import CloudRuntime from lightning_app.runners.multiprocess import MultiProcessRuntime from lightning_app.runners.runtime import dispatch, Runtime -from lightning_app.runners.singleprocess import SingleProcessRuntime from lightning_app.utilities.app_commands import run_app_commands from lightning_app.utilities.load_app import load_app_from_file @@ -11,6 +10,5 @@ "run_app_commands", "Runtime", "MultiProcessRuntime", - "SingleProcessRuntime", "CloudRuntime", ] diff --git a/src/lightning_app/runners/backends/mp_process.py b/src/lightning_app/runners/backends/mp_process.py index 36a067d0bfd80..dc0681390046e 100644 --- a/src/lightning_app/runners/backends/mp_process.py +++ b/src/lightning_app/runners/backends/mp_process.py @@ -31,7 +31,10 @@ def start(self): flow_to_work_delta_queue=self.app.flow_to_work_delta_queues[self.work.name], run_executor_cls=self.work._run_executor_cls, ) - self._process = multiprocessing.Process(target=self._work_runner) + + start_method = self.work._start_method + context = multiprocessing.get_context(start_method) + self._process = context.Process(target=self._work_runner) self._process.start() def kill(self): diff --git a/src/lightning_app/runners/runtime_type.py b/src/lightning_app/runners/runtime_type.py index aca045625f5e9..c5a9b60f89072 100644 --- a/src/lightning_app/runners/runtime_type.py +++ b/src/lightning_app/runners/runtime_type.py @@ -1,21 +1,18 @@ from enum import Enum from typing import Type, TYPE_CHECKING -from lightning_app.runners import CloudRuntime, MultiProcessRuntime, SingleProcessRuntime +from lightning_app.runners import CloudRuntime, MultiProcessRuntime if TYPE_CHECKING: from lightning_app.runners.runtime import Runtime class RuntimeType(Enum): - SINGLEPROCESS = "singleprocess" MULTIPROCESS = "multiprocess" CLOUD = "cloud" def get_runtime(self) -> Type["Runtime"]: - if self == RuntimeType.SINGLEPROCESS: - return SingleProcessRuntime - elif self == RuntimeType.MULTIPROCESS: + if self == RuntimeType.MULTIPROCESS: return MultiProcessRuntime elif self == RuntimeType.CLOUD: return CloudRuntime diff --git a/src/lightning_app/runners/singleprocess.py b/src/lightning_app/runners/singleprocess.py deleted file mode 100644 index 61a67ce9ba904..0000000000000 --- a/src/lightning_app/runners/singleprocess.py +++ /dev/null @@ -1,62 +0,0 @@ -import multiprocessing as mp -import os -from typing import Any - -import click - -from lightning_app.core.api import start_server -from lightning_app.core.queues import QueuingSystem -from lightning_app.runners.runtime import Runtime -from lightning_app.utilities.app_helpers import _is_headless -from lightning_app.utilities.load_app import extract_metadata_from_app - - -class SingleProcessRuntime(Runtime): - """Runtime to launch the LightningApp into a single process.""" - - def __post_init__(self): - pass - - def dispatch(self, *args, open_ui: bool = True, **kwargs: Any): - """Method to dispatch and run the LightningApp.""" - queue = QueuingSystem.SINGLEPROCESS - - self.app.delta_queue = queue.get_delta_queue() - self.app.state_update_queue = queue.get_caller_queue(work_name="single_worker") - self.app.error_queue = queue.get_error_queue() - - if self.start_server: - self.app.should_publish_changes_to_api = True - self.app.api_publish_state_queue = QueuingSystem.MULTIPROCESS.get_api_state_publish_queue() - self.app.api_delta_queue = QueuingSystem.MULTIPROCESS.get_api_delta_queue() - has_started_queue = QueuingSystem.MULTIPROCESS.get_has_server_started_queue() - kwargs = dict( - host=self.host, - port=self.port, - api_publish_state_queue=self.app.api_publish_state_queue, - api_delta_queue=self.app.api_delta_queue, - has_started_queue=has_started_queue, - spec=extract_metadata_from_app(self.app), - root_path=self.app.root_path, - ) - server_proc = mp.Process(target=start_server, kwargs=kwargs) - self.processes["server"] = server_proc - server_proc.start() - - # wait for server to be ready. - has_started_queue.get() - - if open_ui and not _is_headless(self.app): - click.launch(self._get_app_url()) - - try: - self.app._run() - except KeyboardInterrupt: - self.terminate() - raise - finally: - self.terminate() - - @staticmethod - def _get_app_url() -> str: - return os.getenv("APP_SERVER_HOST", "http://127.0.0.1:7501/view") diff --git a/src/lightning_app/storage/orchestrator.py b/src/lightning_app/storage/orchestrator.py index 52ac7be3dc55b..9edb6344852fa 100644 --- a/src/lightning_app/storage/orchestrator.py +++ b/src/lightning_app/storage/orchestrator.py @@ -105,6 +105,7 @@ def run_once(self, work_name: str) -> None: name=request.name, path=maybe_artifact_path, hash=request.hash, + size=self.fs.info(maybe_artifact_path)["size"], destination=request.destination, ) if isinstance(request, _ExistsRequest): @@ -139,6 +140,7 @@ def run_once(self, work_name: str) -> None: path=request.path, name=request.name, hash=request.hash, + size=0, destination=request.destination, ) if isinstance(request, _ExistsRequest): diff --git a/src/lightning_app/storage/path.py b/src/lightning_app/storage/path.py index a8aa9d41e8055..4b5da1d580946 100644 --- a/src/lightning_app/storage/path.py +++ b/src/lightning_app/storage/path.py @@ -10,6 +10,7 @@ from fsspec import AbstractFileSystem from fsspec.implementations.local import LocalFileSystem +from lightning_app.core.constants import REMOTE_STORAGE_WAIT from lightning_app.core.queues import BaseQueue from lightning_app.storage.requests import _ExistsRequest, _ExistsResponse, _GetRequest, _GetResponse from lightning_app.utilities.app_helpers import Logger @@ -199,9 +200,8 @@ def get(self, overwrite: bool = False) -> None: fs = _filesystem() # 3. Wait until the file appears in shared storage - while not fs.exists(response.path): - # TODO: Existence check on folder is not enough, files may not be completely transferred yet - sleep(0.5) + while not fs.exists(response.path) or fs.info(response.path)["size"] != response.size: + sleep(REMOTE_STORAGE_WAIT) if self.exists_local() and self.is_dir(): # Delete the directory, otherwise we can't overwrite it @@ -340,10 +340,11 @@ def _handle_get_request(work: "LightningWork", request: _GetRequest) -> _GetResp destination_path = _shared_storage_path() / request.hash response = _GetResponse( source=request.source, + name=request.name, path=str(destination_path), hash=request.hash, + size=source_path.stat().st_size, destination=request.destination, - name=request.name, ) try: diff --git a/src/lightning_app/storage/payload.py b/src/lightning_app/storage/payload.py index be9f8f20ff00e..29789d31fcf75 100644 --- a/src/lightning_app/storage/payload.py +++ b/src/lightning_app/storage/payload.py @@ -5,6 +5,7 @@ from time import sleep from typing import Any, Optional, TYPE_CHECKING, Union +from lightning_app.core.constants import REMOTE_STORAGE_WAIT from lightning_app.core.queues import BaseQueue from lightning_app.storage.path import _filesystem, _shared_storage_path, Path from lightning_app.storage.requests import _ExistsRequest, _ExistsResponse, _GetRequest, _GetResponse @@ -159,9 +160,8 @@ def get(self) -> Any: fs = _filesystem() # 3. Wait until the file appears in shared storage - while not fs.exists(response.path): - # TODO: Existence check on folder is not enough, files may not be completely transferred yet - sleep(0.5) + while not fs.exists(response.path) or fs.info(response.path)["size"] != response.size: + sleep(REMOTE_STORAGE_WAIT) # 4. Copy the file from the shared storage to the destination on the local filesystem local_path = self._path @@ -234,6 +234,7 @@ def _handle_get_request(work: "LightningWork", request: _GetRequest) -> _GetResp try: payload = getattr(work, request.name) payload.save(payload.value, source_path) + response.size = source_path.stat().st_size _copy_files(source_path, destination_path) _logger.debug(f"All files copied from {request.path} to {response.path}.") except Exception as e: diff --git a/src/lightning_app/storage/requests.py b/src/lightning_app/storage/requests.py index 43c97b8f133b3..117d2b91adb9b 100644 --- a/src/lightning_app/storage/requests.py +++ b/src/lightning_app/storage/requests.py @@ -17,6 +17,7 @@ class _GetResponse: name: str path: str hash: str + size: int = 0 destination: str = "" exception: Optional[Exception] = None timedelta: Optional[float] = None diff --git a/src/lightning_app/structures/dict.py b/src/lightning_app/structures/dict.py index aaf8a3c8298d0..7bf102e19f180 100644 --- a/src/lightning_app/structures/dict.py +++ b/src/lightning_app/structures/dict.py @@ -64,10 +64,10 @@ def __setitem__(self, k, v): if isinstance(k, str) and "." in k: raise Exception(f"The provided name {k} contains . which is forbidden.") + _set_child_name(self, v, k) if self._backend: if isinstance(v, LightningFlow): LightningFlow._attach_backend(v, self._backend) - _set_child_name(self, v, k) elif isinstance(v, LightningWork): self._backend._wrap_run_method(_LightningAppRef().get_current(), v) v._name = f"{self.name}.{k}" diff --git a/src/lightning_app/structures/list.py b/src/lightning_app/structures/list.py index 416f1e6d85a05..9f110c69b1388 100644 --- a/src/lightning_app/structures/list.py +++ b/src/lightning_app/structures/list.py @@ -53,20 +53,18 @@ def __init__(self, *items: T): self._backend: t.Optional[Backend] = None for item in items: self.append(item) - _set_child_name(self, item, str(self._last_index)) - self._last_index += 1 def append(self, v): from lightning_app import LightningFlow, LightningWork + _set_child_name(self, v, str(self._last_index)) if self._backend: if isinstance(v, LightningFlow): LightningFlow._attach_backend(v, self._backend) - _set_child_name(self, v, str(self._last_index)) elif isinstance(v, LightningWork): self._backend._wrap_run_method(_LightningAppRef().get_current(), v) - v._name = f"{self.name}.{self._last_index}" - self._last_index += 1 + v._name = f"{self.name}.{self._last_index}" + self._last_index += 1 super().append(v) @property diff --git a/src/lightning_app/utilities/app_helpers.py b/src/lightning_app/utilities/app_helpers.py index a000af3e71fe6..665c50889676c 100644 --- a/src/lightning_app/utilities/app_helpers.py +++ b/src/lightning_app/utilities/app_helpers.py @@ -130,13 +130,6 @@ def set_served_session_id(self, k, v): self.store[k].session_id = v -class DistributedMode(enum.Enum): - SINGLEPROCESS = enum.auto() - MULTIPROCESS = enum.auto() - CONTAINER = enum.auto() - GRID = enum.auto() - - class _LightningAppRef: _app_instance: Optional["LightningApp"] = None @@ -518,14 +511,10 @@ def is_static_method(klass_or_instance, attr) -> bool: return isinstance(inspect.getattr_static(klass_or_instance, attr), staticmethod) -def _debugger_is_active() -> bool: - """Return if the debugger is currently active.""" - return hasattr(sys, "gettrace") and sys.gettrace() is not None - - def _should_dispatch_app() -> bool: return ( - _debugger_is_active() + __debug__ + and "_pytest.doctest" not in sys.modules and not bool(int(os.getenv("LIGHTNING_DISPATCHED", "0"))) and "LIGHTNING_APP_STATE_URL" not in os.environ ) diff --git a/src/lightning_app/utilities/layout.py b/src/lightning_app/utilities/layout.py index 15079fcb6964b..11f26019cb406 100644 --- a/src/lightning_app/utilities/layout.py +++ b/src/lightning_app/utilities/layout.py @@ -4,7 +4,7 @@ import lightning_app from lightning_app.frontend.frontend import Frontend -from lightning_app.utilities.app_helpers import _MagicMockJsonSerializable +from lightning_app.utilities.app_helpers import _MagicMockJsonSerializable, is_overridden from lightning_app.utilities.cloud import is_running_in_cloud @@ -45,9 +45,9 @@ def _collect_layout(app: "lightning_app.LightningApp", flow: "lightning_app.Ligh app.frontends.setdefault(flow.name, "mock") return flow._layout elif isinstance(layout, dict): - layout = _collect_content_layout([layout], flow) + layout = _collect_content_layout([layout], app, flow) elif isinstance(layout, (list, tuple)) and all(isinstance(item, dict) for item in layout): - layout = _collect_content_layout(layout, flow) + layout = _collect_content_layout(layout, app, flow) else: lines = _add_comment_to_literal_code(flow.configure_layout, contains="return", comment=" <------- this guy") m = f""" @@ -76,7 +76,9 @@ def configure_layout(self): return layout -def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFlow") -> List[Dict]: +def _collect_content_layout( + layout: List[Dict], app: "lightning_app.LightningApp", flow: "lightning_app.LightningFlow" +) -> Union[List[Dict], Dict]: """Process the layout returned by the ``configure_layout()`` method if the returned format represents an aggregation of child layouts.""" for entry in layout: @@ -102,12 +104,43 @@ def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFl entry["content"] = entry["content"].name elif isinstance(entry["content"], lightning_app.LightningWork): - if entry["content"].url and not entry["content"].url.startswith("/"): - entry["content"] = entry["content"].url - entry["target"] = entry["content"] - else: + work = entry["content"] + work_layout = _collect_work_layout(work) + + if work_layout is None: entry["content"] = "" - entry["target"] = "" + elif isinstance(work_layout, str): + entry["content"] = work_layout + entry["target"] = work_layout + elif isinstance(work_layout, (Frontend, _MagicMockJsonSerializable)): + if len(layout) > 1: + lines = _add_comment_to_literal_code( + flow.configure_layout, contains="return", comment=" <------- this guy" + ) + m = f""" + The return value of configure_layout() in `{flow.__class__.__name__}` is an + unsupported format: + \n{lines} + + The tab containing a `{work.__class__.__name__}` must be the only tab in the + layout of this flow. + + (see the docs for `LightningWork.configure_layout`). + """ + raise TypeError(m) + + if isinstance(work_layout, Frontend): + # If the work returned a frontend, treat it as belonging to the flow. + # NOTE: This could evolve in the future to run the Frontend directly in the work machine. + frontend = work_layout + frontend.flow = flow + elif isinstance(work_layout, _MagicMockJsonSerializable): + # The import was mocked, we set a dummy `Frontend` so that `is_headless` knows there is a UI. + frontend = "mock" + + app.frontends.setdefault(flow.name, frontend) + return flow._layout + elif isinstance(entry["content"], _MagicMockJsonSerializable): # The import was mocked, we just record dummy content so that `is_headless` knows there is a UI entry["content"] = "mock" @@ -126,3 +159,43 @@ def configure_layout(self): """ raise ValueError(m) return layout + + +def _collect_work_layout(work: "lightning_app.LightningWork") -> Union[None, str, Frontend, _MagicMockJsonSerializable]: + """Check if ``configure_layout`` is overridden on the given work and return the work layout (either a string, a + ``Frontend`` object, or an instance of a mocked import). + + Args: + work: The work to collect the layout for. + + Raises: + TypeError: If the value returned by ``configure_layout`` is not of a supported format. + """ + if is_overridden("configure_layout", work): + work_layout = work.configure_layout() + else: + work_layout = work.url + + if work_layout is None: + return None + elif isinstance(work_layout, str): + url = work_layout + # The URL isn't fully defined yet. Looks something like ``self.work.url + /something``. + if url and not url.startswith("/"): + return url + return "" + elif isinstance(work_layout, (Frontend, _MagicMockJsonSerializable)): + return work_layout + else: + m = f""" + The value returned by `{work.__class__.__name__}.configure_layout()` is of an unsupported type. + + {repr(work_layout)} + + Return a `Frontend` or a URL string, for example: + + class {work.__class__.__name__}(LightningWork): + def configure_layout(self): + return MyFrontend() OR 'http://some/url' + """ + raise TypeError(m) diff --git a/src/lightning_app/utilities/safe_pickle.py b/src/lightning_app/utilities/safe_pickle.py new file mode 100644 index 0000000000000..8788ff22a3cb6 --- /dev/null +++ b/src/lightning_app/utilities/safe_pickle.py @@ -0,0 +1,95 @@ +import contextlib +import pickle +import sys +import types +import typing +from copy import deepcopy +from pathlib import Path + +from lightning_app.core.work import LightningWork +from lightning_app.utilities.app_helpers import _LightningAppRef + +NON_PICKLABLE_WORK_ATTRIBUTES = ["_request_queue", "_response_queue", "_backend", "_setattr_replacement"] + + +@contextlib.contextmanager +def _trimmed_work(work: LightningWork, to_trim: typing.List[str]) -> typing.Iterator[None]: + """Context manager to trim the work object to remove attributes that are not picklable.""" + holder = {} + for arg in to_trim: + holder[arg] = getattr(work, arg) + setattr(work, arg, None) + yield + for arg in to_trim: + setattr(work, arg, holder[arg]) + + +def get_picklable_work(work: LightningWork) -> LightningWork: + """Pickling a LightningWork instance fails if done from the work process + itself. This function is safe to call from the work process within both MultiprocessRuntime + and Cloud. + Note: This function modifies the module information of the work object. Specifically, it injects + the relative module path into the __module__ attribute of the work object. If the object is not + importable from the CWD, then the pickle load will fail. + + Example: + for a directory structure like below and the work class is defined in the app.py where + the app.py is the entrypoint for the app, it will inject `foo.bar.app` into the + __module__ attribute + + └── foo + ├── __init__.py + └── bar + └── app.py + """ + + # If the work object not taken from the app ref, there is a thread lock reference + # somewhere thats preventing it from being pickled. Investigate it later. We + # shouldn't be fetching the work object from the app ref. TODO @sherin + app_ref = _LightningAppRef.get_current() + if app_ref is None: + raise RuntimeError("Cannot pickle LightningWork outside of a LightningApp") + for w in app_ref.works: + if work.name == w.name: + # deep-copying the work object to avoid modifying the original work object + with _trimmed_work(w, to_trim=NON_PICKLABLE_WORK_ATTRIBUTES): + copied_work = deepcopy(w) + break + else: + raise ValueError(f"Work with name {work.name} not found in the app references") + + # if work is defined in the __main__ or __mp__main__ (the entrypoint file for `lightning run app` command), + # pickling/unpickling will fail, hence we need patch the module information + if "_main__" in copied_work.__class__.__module__: + work_class_module = sys.modules[copied_work.__class__.__module__] + work_class_file = work_class_module.__file__ + if not work_class_file: + raise ValueError( + f"Cannot pickle work class {copied_work.__class__.__name__} because we " + f"couldn't identify the module file" + ) + relative_path = Path(work_class_module.__file__).relative_to(Path.cwd()) # type: ignore + expected_module_name = relative_path.as_posix().replace(".py", "").replace("/", ".") + # TODO @sherin: also check if the module is importable from the CWD + fake_module = types.ModuleType(expected_module_name) + fake_module.__dict__.update(work_class_module.__dict__) + fake_module.__dict__["__name__"] = expected_module_name + sys.modules[expected_module_name] = fake_module + for k, v in fake_module.__dict__.items(): + if not k.startswith("__") and hasattr(v, "__module__"): + if "_main__" in v.__module__: + v.__module__ = expected_module_name + return copied_work + + +def dump(work: LightningWork, f: typing.BinaryIO) -> None: + picklable_work = get_picklable_work(work) + pickle.dump(picklable_work, f) + + +def load(f: typing.BinaryIO) -> typing.Any: + # inject current working directory to sys.path + sys.path.insert(1, str(Path.cwd())) + work = pickle.load(f) + sys.path.pop(1) + return work diff --git a/src/lightning_app/utilities/state.py b/src/lightning_app/utilities/state.py index a882953ab0450..775fa49ddd0ba 100644 --- a/src/lightning_app/utilities/state.py +++ b/src/lightning_app/utilities/state.py @@ -2,6 +2,7 @@ import json import os from copy import deepcopy +from time import sleep from typing import Any, Dict, List, Optional, Tuple, Union from deepdiff import DeepDiff @@ -149,16 +150,26 @@ def _request_state(self) -> None: return app_url = f"{self._url}/api/v1/state" headers = headers_for(self._plugin.get_context()) if self._plugin else {} - try: - response = self._session.get(app_url, headers=headers, timeout=1) - except ConnectionError as e: - raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from e - self._authorized = response.status_code - if self._authorized != 200: - return - logger.debug(f"GET STATE {response} {response.json()}") - self._store_state(response.json()) + response_json = {} + + # Sometimes the state URL can return an empty JSON when things are being set-up, + # so we wait for it to be ready here. + while response_json == {}: + sleep(0.5) + try: + response = self._session.get(app_url, headers=headers, timeout=1) + except ConnectionError as e: + raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from e + + self._authorized = response.status_code + if self._authorized != 200: + return + + response_json = response.json() + + logger.debug(f"GET STATE {response} {response_json}") + self._store_state(response_json) def __getattr__(self, name: str) -> Union[Any, "AppState"]: if name in self._APP_PRIVATE_KEYS: diff --git a/src/lightning_fabric/CHANGELOG.md b/src/lightning_fabric/CHANGELOG.md index 4d40fc47ca9be..2cc9050abd43d 100644 --- a/src/lightning_fabric/CHANGELOG.md +++ b/src/lightning_fabric/CHANGELOG.md @@ -42,7 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Fixed `shuffle=False` having no effect when using DDP/DistributedSampler ([#15931](https://github.com/Lightning-AI/lightning/issues/15931)) + ## [1.8.3] - 2022-11-22 diff --git a/src/lightning_fabric/fabric.py b/src/lightning_fabric/fabric.py index c2cef2e0db0aa..5f46558059759 100644 --- a/src/lightning_fabric/fabric.py +++ b/src/lightning_fabric/fabric.py @@ -25,7 +25,7 @@ from lightning_utilities.core.rank_zero import rank_zero_warn from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import BatchSampler, DataLoader, DistributedSampler +from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler from lightning_fabric.plugins import Precision # avoid circular imports: # isort: split from lightning_fabric.accelerators.accelerator import Accelerator @@ -582,6 +582,7 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: @staticmethod def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler: + kwargs.setdefault("shuffle", isinstance(dataloader.sampler, RandomSampler)) kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) return DistributedSamplerWrapper(dataloader.sampler, **kwargs) diff --git a/src/lightning_fabric/plugins/collectives/collective.py b/src/lightning_fabric/plugins/collectives/collective.py index 4343d450a839f..0d6cceed8f3be 100644 --- a/src/lightning_fabric/plugins/collectives/collective.py +++ b/src/lightning_fabric/plugins/collectives/collective.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, List, Optional -import torch +from torch import Tensor from typing_extensions import Self from lightning_fabric.utilities.types import CollectibleGroup @@ -38,45 +38,43 @@ def group(self) -> CollectibleGroup: return self._group @abstractmethod - def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: + def broadcast(self, tensor: Tensor, src: int) -> Tensor: ... @abstractmethod - def all_reduce(self, tensor: torch.Tensor, op: str) -> torch.Tensor: + def all_reduce(self, tensor: Tensor, op: str) -> Tensor: ... @abstractmethod - def reduce(self, tensor: torch.Tensor, dst: int, op: str) -> torch.Tensor: + def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor: ... @abstractmethod - def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]: + def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]: ... @abstractmethod - def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]: + def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]: ... @abstractmethod - def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor: + def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor: ... @abstractmethod - def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], op: str) -> torch.Tensor: + def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str) -> Tensor: ... @abstractmethod - def all_to_all( - self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor] - ) -> List[torch.Tensor]: + def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]: ... @abstractmethod - def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None: + def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None: ... @abstractmethod - def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor: + def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor: ... @abstractmethod diff --git a/src/lightning_fabric/plugins/collectives/single_device.py b/src/lightning_fabric/plugins/collectives/single_device.py index 5acbac81fe099..afb5c397bbc5d 100644 --- a/src/lightning_fabric/plugins/collectives/single_device.py +++ b/src/lightning_fabric/plugins/collectives/single_device.py @@ -1,6 +1,6 @@ from typing import Any, List -import torch +from torch import Tensor from lightning_fabric.plugins.collectives.collective import Collective from lightning_fabric.utilities.types import CollectibleGroup @@ -15,42 +15,42 @@ def rank(self) -> int: def world_size(self) -> int: return 1 - def broadcast(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor: + def broadcast(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor: return tensor - def all_reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor: + def all_reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor: return tensor - def reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor: + def reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor: return tensor - def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor, **__: Any) -> List[torch.Tensor]: + def all_gather(self, tensor_list: List[Tensor], tensor: Tensor, **__: Any) -> List[Tensor]: return [tensor] - def gather(self, tensor: torch.Tensor, *_: Any, **__: Any) -> List[torch.Tensor]: + def gather(self, tensor: Tensor, *_: Any, **__: Any) -> List[Tensor]: return [tensor] def scatter( self, - tensor: torch.Tensor, - scatter_list: List[torch.Tensor], + tensor: Tensor, + scatter_list: List[Tensor], *_: Any, **__: Any, - ) -> torch.Tensor: + ) -> Tensor: return scatter_list[0] - def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], *_: Any, **__: Any) -> torch.Tensor: + def reduce_scatter(self, output: Tensor, input_list: List[Tensor], *_: Any, **__: Any) -> Tensor: return input_list[0] def all_to_all( - self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor], *_: Any, **__: Any - ) -> List[torch.Tensor]: + self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor], *_: Any, **__: Any + ) -> List[Tensor]: return input_tensor_list def send(self, *_: Any, **__: Any) -> None: pass - def recv(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor: + def recv(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor: return tensor def barrier(self, *_: Any, **__: Any) -> None: diff --git a/src/lightning_fabric/plugins/collectives/torch_collective.py b/src/lightning_fabric/plugins/collectives/torch_collective.py index 8ace0d9f82997..e841b6b9dd0bb 100644 --- a/src/lightning_fabric/plugins/collectives/torch_collective.py +++ b/src/lightning_fabric/plugins/collectives/torch_collective.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist +from torch import Tensor from typing_extensions import Self from lightning_fabric.plugins.collectives.collective import Collective @@ -33,49 +34,47 @@ def rank(self) -> int: def world_size(self) -> int: return dist.get_world_size(self.group) - def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: + def broadcast(self, tensor: Tensor, src: int) -> Tensor: dist.broadcast(tensor, src, group=self.group) return tensor - def all_reduce(self, tensor: torch.Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor: + def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor: op = self._convert_to_native_op(op) dist.all_reduce(tensor, op=op, group=self.group) return tensor - def reduce(self, tensor: torch.Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor: + def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor: op = self._convert_to_native_op(op) dist.reduce(tensor, dst, op=op, group=self.group) return tensor - def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]: + def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]: dist.all_gather(tensor_list, tensor, group=self.group) return tensor_list - def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]: + def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]: dist.gather(tensor, gather_list, dst, group=self.group) return gather_list - def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor: + def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor: dist.scatter(tensor, scatter_list, src, group=self.group) return tensor def reduce_scatter( - self, output: torch.Tensor, input_list: List[torch.Tensor], op: Union[str, ReduceOp, RedOpType] = "sum" - ) -> torch.Tensor: + self, output: Tensor, input_list: List[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum" + ) -> Tensor: op = self._convert_to_native_op(op) dist.reduce_scatter(output, input_list, op=op, group=self.group) return output - def all_to_all( - self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor] - ) -> List[torch.Tensor]: + def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]: dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group) return output_tensor_list - def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None: + def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None: dist.send(tensor, dst, tag=tag, group=self.group) - def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor: + def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor: dist.recv(tensor, src, tag=tag, group=self.group) return tensor diff --git a/src/lightning_fabric/plugins/precision/utils.py b/src/lightning_fabric/plugins/precision/utils.py index dc41a5202d817..5ef6d5f858ea8 100644 --- a/src/lightning_fabric/plugins/precision/utils.py +++ b/src/lightning_fabric/plugins/precision/utils.py @@ -14,7 +14,8 @@ from typing import Union import torch +from torch import Tensor -def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor: +def _convert_fp_tensor(tensor: Tensor, dst_type: Union[str, torch.dtype]) -> Tensor: return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor diff --git a/src/lightning_fabric/strategies/fsdp.py b/src/lightning_fabric/strategies/fsdp.py index 4614e1fde0443..7fe400179eb8d 100644 --- a/src/lightning_fabric/strategies/fsdp.py +++ b/src/lightning_fabric/strategies/fsdp.py @@ -69,11 +69,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded): `this tutorial `__ for more information. Arguments: - cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It - can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently + cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed. + You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device - to work with the optimizer. This API is subject to change. Default is ``None`` in which case there - will be no offloading. + to work with the optimizer. This API is subject to change. Default: no offoading backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows users to enable two different backward prefetching algorithms to help backward communication and computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. @@ -96,7 +95,7 @@ def __init__( precision: Optional[Precision] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, - cpu_offload: Optional["CPUOffload"] = None, + cpu_offload: Union[bool, "CPUOffload", None] = None, backward_prefetch: Optional["BackwardPrefetch"] = None, mixed_precision: Optional["MixedPrecision"] = None, activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, @@ -125,7 +124,7 @@ def __init__( [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing ) - self.cpu_offload = cpu_offload + self.cpu_offload = _init_cpu_offload(cpu_offload) self.backward_prefetch = backward_prefetch self.mixed_precision = mixed_precision @@ -276,7 +275,6 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def register_strategies(cls, strategy_registry: Dict) -> None: if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available(): return - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload strategy_registry.register( "fsdp", @@ -287,7 +285,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None: "fsdp_full_shard_offload", cls, description="Native FSDP with Full Sharding and CPU Offloading", - cpu_offload=CPUOffload(offload_params=True), + cpu_offload=True, ) def _setup_distributed(self) -> None: @@ -341,6 +339,12 @@ def no_backward_sync(self, module: Module) -> Generator: yield +def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload": + from torch.distributed.fsdp import CPUOffload + + return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload)) + + def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: from torch.distributed.fsdp import FlatParameter diff --git a/src/lightning_fabric/strategies/xla.py b/src/lightning_fabric/strategies/xla.py index 46480faa1b6ed..1cc05b6438272 100644 --- a/src/lightning_fabric/strategies/xla.py +++ b/src/lightning_fabric/strategies/xla.py @@ -156,20 +156,22 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return obj def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: - """ - Function to gather a tensor from several distributed processes + """Function to gather a tensor from several distributed processes. + Args: tensor: tensor of shape (batch, ...) group: not available with TPUs - sync_grads: not available with TPUs + sync_grads: flag that allows users to synchronize gradients for the all_gather operation Return: A tensor of shape (world_size, batch, ...) """ if isinstance(tensor, Tensor) and tensor.dim() == 0: tensor = tensor.unsqueeze(0) + + import torch_xla.core.functions as xf import torch_xla.core.xla_model as xm - return xm.all_gather(tensor) + return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor) def save_checkpoint( self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 3537cad3307e0..b56a8477dd48a 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for `torch.compile` ([#15922](https://github.com/Lightning-AI/lightning/pull/15922), [15957](https://github.com/Lightning-AI/lightning/pull/15957)) + + - Added support for DDP with `LRFinder` ([#15304](https://github.com/Lightning-AI/lightning/pull/15304)) @@ -33,6 +36,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826)) + +- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configufation object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832)) + + ### Changed - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347)) @@ -65,6 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `pytorch_lightning.utilities.memory.get_gpu_memory_map` in favor of `pytorch_lightning.accelerators.cuda.get_nvidia_gpu_stats` ([#15617](https://github.com/Lightning-AI/lightning/pull/15617)) + - Temporarily removed support for Hydra multi-run ([#15737](https://github.com/Lightning-AI/lightning/pull/15737)) @@ -81,6 +89,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253)) +- Fixed issue with unsupported torch.inference_mode() on hpu backends ([#15918](https://github.com/Lightning-AI/lightning/pull/15918)) + +- Fixed `fit_loop.restarting` to be `False` for lr finder ([#15620](https://github.com/Lightning-AI/lightning/pull/15620)) + + +- Fixed `torch.jit.script`-ing a LightningModule causing an unintended error message about deprecated `use_amp` property ([#15947](https://github.com/Lightning-AI/lightning/pull/15947)) ## [1.8.3] - 2022-11-22 @@ -99,6 +113,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the automatic fallback from `Trainer(strategy="ddp_spawn", ...)` to `Trainer(strategy="ddp", ...)` when on an LSF cluster ([#15103](https://github.com/PyTorchLightning/pytorch-lightning/issues/15103)) + ## [1.8.1] - 2022-11-10 ### Added diff --git a/src/pytorch_lightning/_graveyard/core.py b/src/pytorch_lightning/_graveyard/core.py index 49768e46569b8..e4722d60558a1 100644 --- a/src/pytorch_lightning/_graveyard/core.py +++ b/src/pytorch_lightning/_graveyard/core.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any -from pytorch_lightning import LightningDataModule, LightningModule +from pytorch_lightning import LightningDataModule def _on_save_checkpoint(_: LightningDataModule, __: Any) -> None: @@ -32,28 +32,6 @@ def _on_load_checkpoint(_: LightningDataModule, __: Any) -> None: ) -def _use_amp(_: LightningModule) -> None: - # Remove in v2.0.0 and the skip in `__jit_unused_properties__` - if not LightningModule._jit_is_scripting: - # cannot use `AttributeError` as it messes up with `nn.Module.__getattr__` - raise RuntimeError( - "`LightningModule.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8." - " Please use `Trainer.amp_backend`.", - ) - - -def _use_amp_setter(_: LightningModule, __: bool) -> None: - # Remove in v2.0.0 - # cannot use `AttributeError` as it messes up with `nn.Module.__getattr__` - raise RuntimeError( - "`LightningModule.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8." - " Please use `Trainer.amp_backend`.", - ) - - -# Properties -LightningModule.use_amp = property(fget=_use_amp, fset=_use_amp_setter) - # Methods LightningDataModule.on_save_checkpoint = _on_save_checkpoint LightningDataModule.on_load_checkpoint = _on_load_checkpoint diff --git a/src/pytorch_lightning/callbacks/lr_finder.py b/src/pytorch_lightning/callbacks/lr_finder.py index 4d235751ca791..1c950e64086b9 100644 --- a/src/pytorch_lightning/callbacks/lr_finder.py +++ b/src/pytorch_lightning/callbacks/lr_finder.py @@ -85,7 +85,7 @@ def __init__( max_lr: float = 1, num_training_steps: int = 100, mode: str = "exponential", - early_stop_threshold: float = 4.0, + early_stop_threshold: Optional[float] = 4.0, update_attr: bool = False, ) -> None: mode = mode.lower() diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index b1981f2a87a82..5abd2e17fc695 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -115,7 +115,7 @@ def __init__( if device is not None and not isinstance(device, (torch.device, str)): raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}") - self.n_averaged: Optional[torch.Tensor] = None + self.n_averaged: Optional[Tensor] = None self._swa_epoch_start = swa_epoch_start self._swa_lrs = swa_lrs self._annealing_epochs = annealing_epochs diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 36a03b0329bf0..b318df9ad5ae2 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -88,7 +88,6 @@ class LightningModule( "automatic_optimization", "truncated_bptt_steps", "trainer", - "use_amp", # from graveyard ] + _DeviceDtypeModuleMixin.__jit_unused_properties__ + HyperparametersMixin.__jit_unused_properties__ @@ -403,7 +402,7 @@ def log( " but it should not contain information about `dataloader_idx`" ) - value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name) + value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name) if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name): # if we started a new epoch (running its first batch) the hook name has changed @@ -545,10 +544,10 @@ def __check_not_nested(value: dict, name: str) -> None: def __check_allowed(v: Any, name: str, value: Any) -> None: raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged") - def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor: + def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor: value = ( value.clone().detach().to(self.device) - if isinstance(value, torch.Tensor) + if isinstance(value, Tensor) else torch.tensor(value, device=self.device) ) if not torch.numel(value) == 1: @@ -1980,9 +1979,17 @@ def from_compiled(cls, model: "torch._dynamo.OptimizedModule") -> "pl.LightningM "compiler": "dynamo", "dynamo_ctx": model.dynamo_ctx, "original_forward": orig_module.forward, + "original_training_step": orig_module.training_step, + "original_validation_step": orig_module.validation_step, + "original_test_step": orig_module.test_step, + "original_predict_step": orig_module.predict_step, } orig_module.forward = model.dynamo_ctx(orig_module.forward) # type: ignore[assignment] + orig_module.training_step = model.dynamo_ctx(orig_module.training_step) # type: ignore[assignment] + orig_module.validation_step = model.dynamo_ctx(orig_module.validation_step) # type: ignore[assignment] + orig_module.test_step = model.dynamo_ctx(orig_module.test_step) # type: ignore[assignment] + orig_module.predict_step = model.dynamo_ctx(orig_module.predict_step) # type: ignore[assignment] return orig_module @classmethod @@ -2011,6 +2018,10 @@ def to_uncompiled(cls, model: Union["pl.LightningModule", "torch._dynamo.Optimiz raise ValueError("`model` must either be an instance of torch._dynamo.OptimizedModule or LightningModule") model.forward = model._compiler_ctx["original_forward"] # type: ignore[assignment] + model.training_step = model._compiler_ctx["original_training_step"] # type: ignore[assignment] + model.validation_step = model._compiler_ctx["original_validation_step"] # type: ignore[assignment] + model.test_step = model._compiler_ctx["original_test_step"] # type: ignore[assignment] + model.predict_step = model._compiler_ctx["original_predict_step"] # type: ignore[assignment] model._compiler_ctx = None return model diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 35f9b8396dd0f..a9a00130d88b0 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -24,9 +24,9 @@ from time import time from typing import Any, Dict, Mapping, Optional, Union -import torch import yaml from lightning_utilities.core.imports import module_available +from torch import Tensor from typing_extensions import Literal from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint @@ -332,7 +332,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non for t, p, s, tag in checkpoints: metadata = { # Ensure .item() is called to store Tensor contents - "score": s.item() if isinstance(s, torch.Tensor) else s, + "score": s.item() if isinstance(s, Tensor) else s, "original_filename": Path(p).name, "Checkpoint": { k: getattr(checkpoint_callback, k) diff --git a/src/pytorch_lightning/serve/servable_module.py b/src/pytorch_lightning/serve/servable_module.py index ef95187c63245..1ceb42777eb1d 100644 --- a/src/pytorch_lightning/serve/servable_module.py +++ b/src/pytorch_lightning/serve/servable_module.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Dict, Tuple import torch +from torch import Tensor class ServableModule(torch.nn.Module): @@ -70,7 +71,7 @@ def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callab """ ... - def serve_step(self, *args: torch.Tensor, **kwargs: torch.Tensor) -> Dict[str, torch.Tensor]: + def serve_step(self, *args: Tensor, **kwargs: Tensor) -> Dict[str, Tensor]: r""" Returns the predictions of your model as a dictionary. diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 356a7556995d8..ca95dec9006e4 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -21,7 +21,11 @@ import pytorch_lightning as pl from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment -from lightning_fabric.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing +from lightning_fabric.strategies.fsdp import ( + _init_cpu_offload, + _optimizer_has_flat_params, + _setup_activation_checkpointing, +) from lightning_fabric.utilities.distributed import ( _get_default_process_group_backend_for_device, _init_dist_connection, @@ -84,14 +88,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): `this tutorial `__ for more information. Arguments: - cpu_offload: - CPU offloading config. Currently, only parameter and gradient CPU - offload is supported. It can be enabled via passing in - ``cpu_offload=CPUOffload(offload_params=True)``. Note that this - currently implicitly enables gradient offloading to CPU in order for - params and grads to be on same device to work with optimizer. This - API is subject to change. Default is ``None`` in which case there - will be no offloading. + cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed. + You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently + implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device + to work with the optimizer. This API is subject to change. Default: no offoading backward_prefetch: This is an experimental feature that is subject to change in the the near future. It allows users to enable two different backward_prefetch @@ -120,7 +120,7 @@ def __init__( checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, process_group_backend: Optional[str] = None, - cpu_offload: Optional[CPUOffload] = None, + cpu_offload: Union[bool, "CPUOffload", None] = None, backward_prefetch: Optional[BackwardPrefetch] = None, mixed_precision: Optional[MixedPrecision] = None, activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, @@ -141,7 +141,7 @@ def __init__( self._process_group = None self.num_nodes = 1 self._process_group_backend = process_group_backend - self.cpu_offload = cpu_offload + self.cpu_offload = _init_cpu_offload(cpu_offload) self.backward_prefetch = backward_prefetch self.mixed_precision = mixed_precision self._rank_0_will_call_children_scripts: bool = False @@ -403,6 +403,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None: "fsdp_native_full_shard_offload", cls, description="Native FSDP with Full Sharding and CPU Offloading", - cpu_offload=CPUOffload(offload_params=True), + cpu_offload=True, ) cls._registered_strategies.append("fsdp_native_full_shard_offload") diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index 0b7c182a859b2..167a572181506 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -289,20 +289,22 @@ def remove_checkpoint(self, filepath: _PATH) -> None: self.checkpoint_io.remove_checkpoint(filepath) def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: - """ - Function to gather a tensor from several distributed processes + """Function to gather a tensor from several distributed processes. + Args: tensor: tensor of shape (batch, ...) group: not available with TPUs - sync_grads: not available with TPUs + sync_grads: flag that allows users to synchronize gradients for the all_gather operation Return: A tensor of shape (world_size, batch, ...) """ if isinstance(tensor, Tensor) and tensor.dim() == 0: tensor = tensor.unsqueeze(0) + + import torch_xla.core.functions as xf import torch_xla.core.xla_model as xm - return xm.all_gather(tensor) + return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor) def teardown(self) -> None: super().teardown() diff --git a/src/pytorch_lightning/strategies/utils.py b/src/pytorch_lightning/strategies/utils.py index 4989f265d760c..643d1aeb1b3ae 100644 --- a/src/pytorch_lightning/strategies/utils.py +++ b/src/pytorch_lightning/strategies/utils.py @@ -16,6 +16,7 @@ from inspect import getmembers, isclass import torch +from torch import Tensor from lightning_fabric.plugins.precision.utils import _convert_fp_tensor from lightning_fabric.strategies import _StrategyRegistry @@ -40,7 +41,7 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) -> mod.register_strategies(registry) -def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor: +def _fp_to_half(tensor: Tensor, precision: PrecisionType) -> Tensor: if precision == PrecisionType.HALF: return _convert_fp_tensor(tensor, torch.half) if precision == PrecisionType.BFLOAT: diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py index 932a15d416ef4..f2a7d14b26f88 100644 --- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -80,12 +80,7 @@ from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import ( - _HOROVOD_AVAILABLE, - _HPU_AVAILABLE, - _IPU_AVAILABLE, - _TORCH_GREATER_EQUAL_1_11, -) +from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _IPU_AVAILABLE, _TORCH_GREATER_EQUAL_1_11 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn log = logging.getLogger(__name__) @@ -499,7 +494,7 @@ def _choose_auto_accelerator(self) -> str: return "tpu" if _IPU_AVAILABLE: return "ipu" - if _HPU_AVAILABLE: + if HPUAccelerator.is_available(): return "hpu" if MPSAccelerator.is_available(): return "mps" diff --git a/src/pytorch_lightning/trainer/supporters.py b/src/pytorch_lightning/trainer/supporters.py index dc66152a33e97..a1fe23ea02e06 100644 --- a/src/pytorch_lightning/trainer/supporters.py +++ b/src/pytorch_lightning/trainer/supporters.py @@ -18,6 +18,7 @@ import torch from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections +from torch import Tensor from torch.utils.data import Dataset from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader from torch.utils.data.dataset import IterableDataset @@ -59,18 +60,18 @@ def reset(self, window_length: Optional[int] = None) -> None: """Empty the accumulator.""" if window_length is not None: self.window_length = window_length - self.memory: Optional[torch.Tensor] = None + self.memory: Optional[Tensor] = None self.current_idx: int = 0 self.last_idx: Optional[int] = None self.rotated: bool = False - def last(self) -> Optional[torch.Tensor]: + def last(self) -> Optional[Tensor]: """Get the last added element.""" if self.last_idx is not None: - assert isinstance(self.memory, torch.Tensor) + assert isinstance(self.memory, Tensor) return self.memory[self.last_idx].float() - def append(self, x: torch.Tensor) -> None: + def append(self, x: Tensor) -> None: """Add an element to the accumulator.""" if self.memory is None: # tradeoff memory for speed by keeping the memory on device @@ -89,21 +90,21 @@ def append(self, x: torch.Tensor) -> None: if self.current_idx == 0: self.rotated = True - def mean(self) -> Optional[torch.Tensor]: + def mean(self) -> Optional[Tensor]: """Get mean value from stored elements.""" return self._agg_memory("mean") - def max(self) -> Optional[torch.Tensor]: + def max(self) -> Optional[Tensor]: """Get maximal value from stored elements.""" return self._agg_memory("max") - def min(self) -> Optional[torch.Tensor]: + def min(self) -> Optional[Tensor]: """Get minimal value from stored elements.""" return self._agg_memory("min") - def _agg_memory(self, how: str) -> Optional[torch.Tensor]: + def _agg_memory(self, how: str) -> Optional[Tensor]: if self.last_idx is not None: - assert isinstance(self.memory, torch.Tensor) + assert isinstance(self.memory, Tensor) if self.rotated: return getattr(self.memory.float(), how)() return getattr(self.memory[: self.current_idx].float(), how)() diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 9123ac23c5cab..8f50e6a935113 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -48,7 +48,7 @@ from lightning_fabric.utilities.data import _auto_add_worker_init_fn from lightning_fabric.utilities.types import _PATH from lightning_fabric.utilities.warnings import PossibleUserWarning -from pytorch_lightning.accelerators import Accelerator, HPUAccelerator, TPUAccelerator +from pytorch_lightning.accelerators import Accelerator, TPUAccelerator from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.datamodule import LightningDataModule @@ -991,11 +991,11 @@ def _run( if model._compiler_ctx is not None: supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy] if self.strategy is not None and not any(isinstance(self.strategy, s) for s in supported_strategies): - supported_strategy_names = " ".join(s.__name__ for s in supported_strategies) + supported_strategy_names = ", ".join(s.__name__ for s in supported_strategies) raise RuntimeError( "Using a compiled model is incompatible with the current strategy: " f"{self.strategy.__class__.__name__}. " - f"Only {supported_strategy_names} support compilation." + f"Only {supported_strategy_names} support compilation. " "Either switch to one of the supported strategies or avoid passing in " "a compiled model." ) @@ -2261,13 +2261,11 @@ def configure_optimizers(self): @contextmanager def _evaluation_context(accelerator: Accelerator, inference_mode: bool = True) -> Generator: - # inference mode is not supported with gloo backend (#9431), - # and HPU & TPU accelerators. + # inference mode is not supported with gloo backend (#9431) and TPU accelerators. context_manager_class = ( torch.inference_mode if inference_mode and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo") - and not isinstance(accelerator, HPUAccelerator) and not isinstance(accelerator, TPUAccelerator) else torch.no_grad ) diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py index 29a5d47776a9e..2652267c93ae6 100644 --- a/src/pytorch_lightning/tuner/lr_finder.py +++ b/src/pytorch_lightning/tuner/lr_finder.py @@ -208,7 +208,7 @@ def lr_find( max_lr: float = 1, num_training: int = 100, mode: str = "exponential", - early_stop_threshold: float = 4.0, + early_stop_threshold: Optional[float] = 4.0, update_attr: bool = False, ) -> Optional[_LRFinder]: """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`""" @@ -225,6 +225,8 @@ def lr_find( ckpt_path = trainer.strategy.broadcast(ckpt_path) trainer.save_checkpoint(ckpt_path) + start_steps = trainer.global_step + # Arguments we adjust during the lr finder, save for restoring params = __lr_finder_dump_params(trainer) @@ -245,7 +247,7 @@ def lr_find( _try_loop_run(trainer, params) # Prompt if we stopped early - if trainer.global_step != num_training: + if trainer.global_step != num_training + start_steps: log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.") # Transfer results from callback to lr finder object @@ -270,6 +272,7 @@ def lr_find( # Restore initial state of model trainer._checkpoint_connector.restore(ckpt_path) trainer.strategy.remove_checkpoint(ckpt_path) + trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True return lr_finder @@ -289,7 +292,7 @@ def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: } -def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None: +def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: Optional[float]) -> None: from pytorch_lightning.loggers.logger import DummyLogger trainer.strategy.lr_scheduler_configs = [] @@ -300,8 +303,8 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] # No logging trainer.logger = DummyLogger() if trainer.logger is not None else None - # Max step set to number of iterations - trainer.fit_loop.max_steps = num_training + # Max step set to number of iterations starting at current number of iterations + trainer.fit_loop.max_steps = num_training + trainer.global_step trainer.limit_val_batches = num_training @@ -340,7 +343,7 @@ class _LRCallback(Callback): def __init__( self, num_training: int, - early_stop_threshold: float = 4.0, + early_stop_threshold: Optional[float] = 4.0, progress_bar_refresh_rate: int = 0, beta: float = 0.98, ): diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py index f2efdfcb82fcf..1fcbbbe501716 100644 --- a/src/pytorch_lightning/utilities/imports.py +++ b/src/pytorch_lightning/utilities/imports.py @@ -52,6 +52,8 @@ from habana_frameworks.torch.utils.library_loader import is_habana_avaialble _HPU_AVAILABLE = is_habana_avaialble() + if _HPU_AVAILABLE: + _TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True) else: _HPU_AVAILABLE = False diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index a219c78b57a47..65db85b8edb9c 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -27,6 +27,13 @@ from torchmetrics import Metric from typing_extensions import Protocol, runtime_checkable +try: + from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler +except ImportError: + # For torch <= 1.13.x + # TODO: Remove once minimum torch version is 1.14 (or 2.0) + from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler + from lightning_fabric.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau _NUMBER = Union[int, float] @@ -111,9 +118,9 @@ def no_sync(self) -> Generator: # todo: improve LRSchedulerType naming/typing -LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) -LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] -LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] +LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) +LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] +LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau] diff --git a/tests/tests_app/cli/test_cmd_install.py b/tests/tests_app/cli/test_cmd_install.py index aa0c34ba6ed2d..2e2086348cb58 100644 --- a/tests/tests_app/cli/test_cmd_install.py +++ b/tests/tests_app/cli/test_cmd_install.py @@ -17,19 +17,19 @@ def test_valid_org_app_name(): # assert a bad app name should fail fake_app = "fakeuser/impossible/name" - result = runner.invoke(lightning_cli.install_app, [fake_app]) + result = runner.invoke(lightning_cli.cmd_install.install_app, [fake_app]) assert "app name format must have organization/app-name" in result.output # assert a good name (but unavailable name) should work fake_app = "fakeuser/ALKKLJAUHREKJ21234KLAKJDLF" - result = runner.invoke(lightning_cli.install_app, [fake_app]) + result = runner.invoke(lightning_cli.cmd_install.install_app, [fake_app]) assert f"app: '{fake_app}' is not available on ⚡ Lightning AI ⚡" in result.output assert result.exit_code # assert a good (and availablea name) works # This should be an app that's always in the gallery real_app = "lightning/invideo" - result = runner.invoke(lightning_cli.install_app, [real_app]) + result = runner.invoke(lightning_cli.cmd_install.install_app, [real_app]) assert "Press enter to continue:" in result.output @@ -47,16 +47,16 @@ def test_valid_unpublished_app_name(): assert "WARNING" in str(e.output) # assert aborted install - result = runner.invoke(lightning_cli.install_app, [real_app], input="q") + result = runner.invoke(lightning_cli.cmd_install.install_app, [real_app], input="q") assert "Installation aborted!" in result.output # assert a bad app name should fail fake_app = "https://github.com/Lightning-AI/install-appdd" - result = runner.invoke(lightning_cli.install_app, [fake_app, "--yes"]) + result = runner.invoke(lightning_cli.cmd_install.install_app, [fake_app, "--yes"]) assert "Looks like the github url was not found" in result.output # assert a good (and availablea name) works - result = runner.invoke(lightning_cli.install_app, [real_app]) + result = runner.invoke(lightning_cli.cmd_install.install_app, [real_app]) assert "Press enter to continue:" in result.output @@ -81,17 +81,17 @@ def test_valid_org_component_name(): # assert a bad name should fail fake_component = "fakeuser/impossible/name" - result = runner.invoke(lightning_cli.install_component, [fake_component]) + result = runner.invoke(lightning_cli.cmd_install.install_component, [fake_component]) assert "component name format must have organization/component-name" in result.output # assert a good name (but unavailable name) should work fake_component = "fakeuser/ALKKLJAUHREKJ21234KLAKJDLF" - result = runner.invoke(lightning_cli.install_component, [fake_component]) + result = runner.invoke(lightning_cli.cmd_install.install_component, [fake_component]) assert f"component: '{fake_component}' is not available on ⚡ Lightning AI ⚡" in result.output # assert a good (and availablea name) works fake_component = "lightning/lit-slack-messenger" - result = runner.invoke(lightning_cli.install_component, [fake_component]) + result = runner.invoke(lightning_cli.cmd_install.install_component, [fake_component]) assert "Press enter to continue:" in result.output @@ -100,13 +100,13 @@ def test_unpublished_component_url_parsing(): # assert a bad name should fail (no git@) fake_component = "https://github.com/Lightning-AI/LAI-slack-messenger" - result = runner.invoke(lightning_cli.install_component, [fake_component]) + result = runner.invoke(lightning_cli.cmd_install.install_component, [fake_component]) assert "Error, your github url must be in the following format" in result.output # assert a good (and availablea name) works sha = "14f333456ffb6758bd19458e6fa0bf12cf5575e1" real_component = f"git+https://github.com/Lightning-AI/LAI-slack-messenger.git@{sha}" - result = runner.invoke(lightning_cli.install_component, [real_component]) + result = runner.invoke(lightning_cli.cmd_install.install_component, [real_component]) assert "Press enter to continue:" in result.output @@ -148,26 +148,26 @@ def test_prompt_actions(): runner = CliRunner() # assert that the user can cancel the command with any letter other than y - result = runner.invoke(lightning_cli.install_app, [app_to_use], input="b") + result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use], input="b") assert "Installation aborted!" in result.output # assert that the install happens with --yes - # result = runner.invoke(lightning_cli.install_app, [app_to_use, "--yes"]) + # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use, "--yes"]) # assert result.exit_code == 0 # assert that the install happens with y - # result = runner.invoke(lightning_cli.install_app, [app_to_use], input='y') + # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use], input='y') # assert result.exit_code == 0 # # assert that the install happens with yes - # result = runner.invoke(lightning_cli.install_app, [app_to_use], input='yes') + # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use], input='yes') # assert result.exit_code == 0 # assert that the install happens with pressing enter - # result = runner.invoke(lightning_cli.install_app, [app_to_use]) + # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use]) # TODO: how to check the output when the user types ctrl+c? - # result = runner.invoke(lightning_cli.install_app, [app_to_use], input='') + # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use], input='') @mock.patch("lightning_app.cli.cmd_install.subprocess", mock.MagicMock()) @@ -178,7 +178,7 @@ def test_version_arg_component(tmpdir, monkeypatch): # Version does not exist component_name = "lightning/lit-slack-messenger" version_arg = "NOT-EXIST" - result = runner.invoke(lightning_cli.install_component, [component_name, f"--version={version_arg}"]) + result = runner.invoke(lightning_cli.cmd_install.install_component, [component_name, f"--version={version_arg}"]) assert f"component: 'Version {version_arg} for {component_name}' is not" in str(result.exception) assert result.exit_code == 1 @@ -186,7 +186,9 @@ def test_version_arg_component(tmpdir, monkeypatch): # This somwehow fail in test but not when you actually run it version_arg = "0.0.1" runner = CliRunner() - result = runner.invoke(lightning_cli.install_component, [component_name, f"--version={version_arg}", "--yes"]) + result = runner.invoke( + lightning_cli.cmd_install.install_component, [component_name, f"--version={version_arg}", "--yes"] + ) assert result.exit_code == 0 @@ -198,14 +200,14 @@ def test_version_arg_app(tmpdir): app_name = "lightning/invideo" version_arg = "NOT-EXIST" runner = CliRunner() - result = runner.invoke(lightning_cli.install_app, [app_name, f"--version={version_arg}"]) + result = runner.invoke(lightning_cli.cmd_install.install_app, [app_name, f"--version={version_arg}"]) assert f"app: 'Version {version_arg} for {app_name}' is not" in str(result.exception) assert result.exit_code == 1 # Version exists version_arg = "0.0.2" runner = CliRunner() - result = runner.invoke(lightning_cli.install_app, [app_name, f"--version={version_arg}", "--yes"]) + result = runner.invoke(lightning_cli.cmd_install.install_app, [app_name, f"--version={version_arg}", "--yes"]) assert result.exit_code == 0 @@ -236,7 +238,9 @@ def test_install_resolve_latest_version(mock_show_install_app_prompt, tmpdir): }, ] } - runner.invoke(lightning_cli.install_app, [app_name, "--yes"]) # no version specified so latest is installed + runner.invoke( + lightning_cli.cmd_install.install_app, [app_name, "--yes"] + ) # no version specified so latest is installed assert mock_show_install_app_prompt.called assert mock_show_install_app_prompt.call_args[0][0]["version"] == "0.0.4" @@ -274,7 +278,7 @@ def test_install_app_shows_error(tmpdir): app_folder_dir.mkdir() with pytest.raises(SystemExit, match=f"Folder {str(app_folder_dir)} exists, please delete it and try again."): - cmd_install._install_app( + cmd_install._install_app_from_source( source_url=mock.ANY, git_url=mock.ANY, folder_name=str(app_folder_dir), overwrite=False ) @@ -360,7 +364,9 @@ def test_install_app_process(subprocess_mock, source_url, git_url, git_sha, tmpd app_folder_dir = Path(tmpdir / "some_random_directory").absolute() app_folder_dir.mkdir() - cmd_install._install_app(source_url, git_url, folder_name=str(app_folder_dir), overwrite=True, git_sha=git_sha) + cmd_install._install_app_from_source( + source_url, git_url, folder_name=str(app_folder_dir), overwrite=True, git_sha=git_sha + ) assert subprocess_mock.check_output.call_args_list[0].args == (["git", "clone", git_url],) if git_sha: assert subprocess_mock.check_output.call_args_list[1].args == (["git", "checkout", git_sha],) diff --git a/tests/tests_app/components/multi_node/__init__.py b/tests/tests_app/components/multi_node/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_app/components/multi_node/test_lite.py b/tests/tests_app/components/multi_node/test_lite.py new file mode 100644 index 0000000000000..b5e8900caead9 --- /dev/null +++ b/tests/tests_app/components/multi_node/test_lite.py @@ -0,0 +1,103 @@ +import os +from copy import deepcopy +from functools import partial +from unittest import mock + +import pytest +from lightning_utilities.core.imports import module_available +from tests_app.helpers.utils import no_warning_call + +import lightning_fabric as lf +from lightning_app.components.multi_node.lite import _LiteRunExecutor + + +class DummyLite(lf.Fabric): + def run(self): + pass + + +def dummy_callable(**kwargs): + lite = DummyLite(**kwargs) + return lite._all_passed_kwargs + + +def dummy_init(self, **kwargs): + self._all_passed_kwargs = kwargs + + +def _get_args_after_tracer_injection(**kwargs): + with mock.patch.object(lf.Fabric, "__init__", dummy_init): + ret_val = _LiteRunExecutor.run( + local_rank=0, + work_run=partial(dummy_callable, **kwargs), + main_address="1.2.3.4", + main_port=5, + node_rank=6, + num_nodes=7, + nprocs=8, + ) + env_vars = deepcopy(os.environ) + return ret_val, env_vars + + +def check_lightning_lite_mps(): + if module_available("lightning_lite"): + return lf.accelerators.MPSAccelerator.is_available() + return False + + +@pytest.mark.skipif(not check_lightning_lite_mps(), reason="Lightning lite not available or mps not available") +@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")]) +def test_lite_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected): + warning_str = ( + r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported " + + "by PyTorch for distributed training on mps capable devices" + ) + if accelerator_expected != accelerator_given: + warning_context = pytest.warns(UserWarning, match=warning_str) + else: + warning_context = no_warning_call(match=warning_str + "*") + + with warning_context: + ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given) + assert ret_val["accelerator"] == accelerator_expected + + +@pytest.mark.parametrize( + "args_given,args_expected", + [ + ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}), + ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}), + ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}), + ], +) +@pytest.mark.skipif(not module_available("lightning"), reason="Lightning is required for this test") +def test_trainer_run_executor_arguments_choices(args_given: dict, args_expected: dict): + + # ddp with mps devices not available (tested separately, just patching here for cross-os testing of other args) + if lf.accelerators.MPSAccelerator.is_available(): + args_expected["accelerator"] = "cpu" + + ret_val, env_vars = _get_args_after_tracer_injection(**args_given) + + for k, v in args_expected.items(): + assert ret_val[k] == v + + assert env_vars["MASTER_ADDR"] == "1.2.3.4" + assert env_vars["MASTER_PORT"] == "5" + assert env_vars["GROUP_RANK"] == "6" + assert env_vars["RANK"] == str(0 + 6 * 8) + assert env_vars["LOCAL_RANK"] == "0" + assert env_vars["WORLD_SIZE"] == str(7 * 8) + assert env_vars["LOCAL_WORLD_SIZE"] == "8" + assert env_vars["TORCHELASTIC_RUN_ID"] == "1" + assert env_vars["LT_CLI_USED"] == "1" + + +@pytest.mark.skipif(not module_available("lightning"), reason="Lightning not available") +def test_lite_run_executor_invalid_strategy_instances(): + with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."): + _, _ = _get_args_after_tracer_injection(strategy=lf.strategies.DDPSpawnStrategy()) + + with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."): + _, _ = _get_args_after_tracer_injection(strategy=lf.strategies.DDPSpawnShardedStrategy()) diff --git a/tests/tests_app/components/multi_node/test_trainer.py b/tests/tests_app/components/multi_node/test_trainer.py new file mode 100644 index 0000000000000..c86e0968e2ab0 --- /dev/null +++ b/tests/tests_app/components/multi_node/test_trainer.py @@ -0,0 +1,99 @@ +import os +from copy import deepcopy +from functools import partial +from unittest import mock + +import pytest +from lightning_utilities.core.imports import module_available +from tests_app.helpers.utils import no_warning_call + +import pytorch_lightning as pl +from lightning_app.components.multi_node.trainer import _LightningTrainerRunExecutor + + +def dummy_callable(**kwargs): + t = pl.Trainer(**kwargs) + return t._all_passed_kwargs + + +def dummy_init(self, **kwargs): + self._all_passed_kwargs = kwargs + + +def _get_args_after_tracer_injection(**kwargs): + with mock.patch.object(pl.Trainer, "__init__", dummy_init): + ret_val = _LightningTrainerRunExecutor.run( + local_rank=0, + work_run=partial(dummy_callable, **kwargs), + main_address="1.2.3.4", + main_port=5, + node_rank=6, + num_nodes=7, + nprocs=8, + ) + env_vars = deepcopy(os.environ) + return ret_val, env_vars + + +def check_lightning_pytorch_and_mps(): + if module_available("pytorch_lightning"): + return pl.accelerators.MPSAccelerator.is_available() + return False + + +@pytest.mark.skipif(not check_lightning_pytorch_and_mps(), reason="pytorch_lightning and mps are required") +@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")]) +def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected): + warning_str = ( + r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported " + + "by PyTorch for distributed training on mps capable devices" + ) + if accelerator_expected != accelerator_given: + warning_context = pytest.warns(UserWarning, match=warning_str) + else: + warning_context = no_warning_call(match=warning_str + "*") + + with warning_context: + ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given) + assert ret_val["accelerator"] == accelerator_expected + + +@pytest.mark.parametrize( + "args_given,args_expected", + [ + ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}), + ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}), + ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}), + ], +) +@pytest.mark.skipif(not module_available("pytorch"), reason="Lightning is not available") +def test_trainer_run_executor_arguments_choices( + args_given: dict, + args_expected: dict, +): + + if pl.accelerators.MPSAccelerator.is_available(): + args_expected.pop("accelerator", None) # Cross platform tests -> MPS is tested separately + + ret_val, env_vars = _get_args_after_tracer_injection(**args_given) + + for k, v in args_expected.items(): + assert ret_val[k] == v + + assert env_vars["MASTER_ADDR"] == "1.2.3.4" + assert env_vars["MASTER_PORT"] == "5" + assert env_vars["GROUP_RANK"] == "6" + assert env_vars["RANK"] == str(0 + 6 * 8) + assert env_vars["LOCAL_RANK"] == "0" + assert env_vars["WORLD_SIZE"] == str(7 * 8) + assert env_vars["LOCAL_WORLD_SIZE"] == "8" + assert env_vars["TORCHELASTIC_RUN_ID"] == "1" + + +@pytest.mark.skipif(not module_available("lightning"), reason="lightning not available") +def test_trainer_run_executor_invalid_strategy_instances(): + with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."): + _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnStrategy()) + + with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."): + _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnShardedStrategy()) diff --git a/tests/tests_app/components/test_auto_scaler.py b/tests/tests_app/components/test_auto_scaler.py new file mode 100644 index 0000000000000..436c3517d01ca --- /dev/null +++ b/tests/tests_app/components/test_auto_scaler.py @@ -0,0 +1,92 @@ +import time +from unittest.mock import patch + +import pytest + +from lightning_app import LightningWork +from lightning_app.components import AutoScaler + + +class EmptyWork(LightningWork): + def run(self): + pass + + +class AutoScaler1(AutoScaler): + def scale(self, replicas: int, metrics) -> int: + # only upscale + return replicas + 1 + + +class AutoScaler2(AutoScaler): + def scale(self, replicas: int, metrics) -> int: + # only downscale + return replicas - 1 + + +def test_num_replicas_after_init(): + """Test the number of works is the same as min_replicas after initialization.""" + min_replicas = 2 + auto_scaler = AutoScaler(EmptyWork, min_replicas=min_replicas) + assert auto_scaler.num_replicas == min_replicas + + +@patch("uvicorn.run") +@patch("lightning_app.components.auto_scaler._LoadBalancer.url") +@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests") +def test_num_replicas_not_above_max_replicas(*_): + """Test self.num_replicas doesn't exceed max_replicas.""" + max_replicas = 6 + auto_scaler = AutoScaler1( + EmptyWork, + min_replicas=1, + max_replicas=max_replicas, + autoscale_interval=0.001, + ) + + for _ in range(max_replicas + 1): + time.sleep(0.002) + auto_scaler.run() + + assert auto_scaler.num_replicas == max_replicas + + +@patch("uvicorn.run") +@patch("lightning_app.components.auto_scaler._LoadBalancer.url") +@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests") +def test_num_replicas_not_belo_min_replicas(*_): + """Test self.num_replicas doesn't exceed max_replicas.""" + min_replicas = 1 + auto_scaler = AutoScaler2( + EmptyWork, + min_replicas=min_replicas, + max_replicas=4, + autoscale_interval=0.001, + ) + + for _ in range(3): + time.sleep(0.002) + auto_scaler.run() + + assert auto_scaler.num_replicas == min_replicas + + +@pytest.mark.parametrize( + "replicas, metrics, expected_replicas", + [ + pytest.param(1, {"pending_requests": 1, "pending_works": 0}, 2, id="increase if no pending work"), + pytest.param(1, {"pending_requests": 1, "pending_works": 1}, 1, id="dont increase if pending works"), + pytest.param(8, {"pending_requests": 1, "pending_works": 0}, 7, id="reduce if requests < 25% capacity"), + pytest.param(8, {"pending_requests": 2, "pending_works": 0}, 8, id="dont reduce if requests >= 25% capacity"), + ], +) +def test_scale(replicas, metrics, expected_replicas): + """Test `scale()`, the default scaling strategy.""" + auto_scaler = AutoScaler( + EmptyWork, + min_replicas=1, + max_replicas=8, + max_batch_size=1, + ) + + assert auto_scaler.scale(replicas, metrics) == expected_replicas diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index 82b58cc36fac3..04b89c927941a 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -28,7 +28,7 @@ UIRefresher, ) from lightning_app.core.constants import APP_SERVER_PORT -from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime +from lightning_app.runners import MultiProcessRuntime from lightning_app.storage.drive import Drive from lightning_app.testing.helpers import _MockQueue from lightning_app.utilities.component import _set_frontend_context, _set_work_context @@ -71,12 +71,10 @@ def run(self): self.work_a.run() -# TODO: Resolve singleprocess - idea: explore frame calls recursively. -@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime]) -def test_app_state_api(runtime_cls): +def test_app_state_api(): """This test validates the AppState can properly broadcast changes from work within its own process.""" app = LightningApp(_A(), log_level="debug") - runtime_cls(app, start_server=True).dispatch() + MultiProcessRuntime(app, start_server=True).dispatch() assert app.root.work_a.var_a == -1 _set_work_context() assert app.root.work_a.drive.list(".") == ["test_app_state_api.txt"] @@ -105,13 +103,10 @@ def run(self): self._exit() -# TODO: Find why this test is flaky. -@pytest.mark.skip(reason="flaky test.") -@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime]) -def test_app_state_api_with_flows(runtime_cls, tmpdir): +def test_app_state_api_with_flows(tmpdir): """This test validates the AppState can properly broadcast changes from flows.""" app = LightningApp(A2(), log_level="debug") - runtime_cls(app, start_server=True).dispatch() + MultiProcessRuntime(app, start_server=True).dispatch() assert app.root.var_a == -1 @@ -181,13 +176,12 @@ def maybe_apply_changes(self): # FIXME: This test doesn't assert anything @pytest.mark.skip(reason="TODO: Resolve flaky test.") -@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime]) -def test_app_stage_from_frontend(runtime_cls): +def test_app_stage_from_frontend(): """This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would start and stop the app.""" app = AppStageTestingApp(FlowA(), log_level="debug") app.stage = AppStage.BLOCKING - runtime_cls(app, start_server=True).dispatch() + MultiProcessRuntime(app, start_server=True).dispatch() def test_update_publish_state_and_maybe_refresh_ui(): diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index e5c265e2efde9..ea552adad7972 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -4,7 +4,6 @@ from re import escape from time import sleep from unittest import mock -from unittest.mock import ANY import pytest from deepdiff import Delta @@ -19,9 +18,9 @@ REDIS_QUEUES_READ_DEFAULT_TIMEOUT, STATE_UPDATE_TIMEOUT, ) -from lightning_app.core.queues import BaseQueue, MultiProcessQueue, RedisQueue, SingleProcessQueue +from lightning_app.core.queues import BaseQueue, MultiProcessQueue, RedisQueue from lightning_app.frontend import StreamlitFrontend -from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime +from lightning_app.runners import MultiProcessRuntime from lightning_app.storage import Path from lightning_app.storage.path import _storage_root_dir from lightning_app.testing.helpers import _RunIf @@ -82,7 +81,7 @@ def __init__(self, cache_calls: bool = True): self.has_finished = False def run(self): - self.counter += 1 + self.counter = self.counter + 1 if self.cache_calls: self.has_finished = True elif self.counter >= 3: @@ -96,40 +95,60 @@ def __init__(self): self.work_b = Work(cache_calls=False) def run(self): - self.work_a.run() - self.work_b.run() if self.work_a.has_finished and self.work_b.has_finished: self._exit() + self.work_a.run() + self.work_b.run() -@pytest.mark.skip -@pytest.mark.parametrize("component_cls", [SimpleFlow]) -@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime]) -def test_simple_app(component_cls, runtime_cls, tmpdir): - comp = component_cls() +def test_simple_app(tmpdir): + comp = SimpleFlow() app = LightningApp(comp, log_level="debug") assert app.root == comp expected = { - "app_state": ANY, - "vars": {"_layout": ANY, "_paths": {}}, + "app_state": mock.ANY, + "vars": {"_layout": mock.ANY, "_paths": {}}, "calls": {}, "flows": {}, + "structures": {}, "works": { "work_b": { - "vars": {"has_finished": False, "counter": 0, "_urls": {}, "_paths": {}}, - "calls": {}, + "vars": { + "has_finished": False, + "counter": 0, + "_cloud_compute": mock.ANY, + "_host": mock.ANY, + "_url": "", + "_future_url": "", + "_internal_ip": "", + "_paths": {}, + "_port": None, + "_restarting": False, + }, + "calls": {"latest_call_hash": None}, "changes": {}, }, "work_a": { - "vars": {"has_finished": False, "counter": 0, "_urls": {}, "_paths": {}}, - "calls": {}, + "vars": { + "has_finished": False, + "counter": 0, + "_cloud_compute": mock.ANY, + "_host": mock.ANY, + "_url": "", + "_future_url": "", + "_internal_ip": "", + "_paths": {}, + "_port": None, + "_restarting": False, + }, + "calls": {"latest_call_hash": None}, "changes": {}, }, }, "changes": {}, } assert app.state == expected - runtime_cls(app, start_server=False).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() assert comp.work_a.has_finished assert comp.work_b.has_finished @@ -357,11 +376,10 @@ def _apply_restarting(self): return True -@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime]) -def test_app_restarting_move_to_blocking(runtime_cls, tmpdir): +def test_app_restarting_move_to_blocking(tmpdir): """Validates sending restarting move the app to blocking again.""" app = SimpleApp2(CounterFlow(), log_level="debug") - runtime_cls(app, start_server=False).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() class FlowWithFrontend(LightningFlow): @@ -411,7 +429,6 @@ def run(self): @pytest.mark.parametrize( "queue_type_cls, default_timeout", [ - (SingleProcessQueue, STATE_UPDATE_TIMEOUT), (MultiProcessQueue, STATE_UPDATE_TIMEOUT), pytest.param( RedisQueue, @@ -463,7 +480,7 @@ def make_delta(i): assert generated > expect -class SimpleFlow(LightningFlow): +class SimpleFlow2(LightningFlow): def __init__(self): super().__init__() self.counter = 0 @@ -476,8 +493,8 @@ def run(self): def test_maybe_apply_changes_from_flow(): """This test validates the app `_updated` is set to True only if the state was changed in the flow.""" - app = LightningApp(SimpleFlow()) - app.delta_queue = SingleProcessQueue("a", 0) + app = LightningApp(SimpleFlow2()) + app.delta_queue = MultiProcessQueue("a", 0) assert app._has_updated app.maybe_apply_changes() app.root.run() diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py index ed668c12e9b1b..dacccfb3873aa 100644 --- a/tests/tests_app/core/test_lightning_flow.py +++ b/tests/tests_app/core/test_lightning_flow.py @@ -13,7 +13,7 @@ from lightning_app import LightningApp from lightning_app.core.flow import LightningFlow from lightning_app.core.work import LightningWork -from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime +from lightning_app.runners import MultiProcessRuntime from lightning_app.storage import Path from lightning_app.storage.path import _storage_root_dir from lightning_app.structures import Dict as LDict @@ -237,7 +237,7 @@ def run(self): flow = StateTransformationTest() assert flow.x == attribute app = LightningApp(flow) - SingleProcessRuntime(app, start_server=False).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() return app.state["vars"]["x"] @@ -519,11 +519,10 @@ def run(self): self._exit() -@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime]) @pytest.mark.parametrize("run_once", [False, True]) -def test_lightning_flow_iterate(tmpdir, runtime_cls, run_once): +def test_lightning_flow_iterate(tmpdir, run_once): app = LightningApp(CFlow(run_once)) - runtime_cls(app, start_server=False).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() assert app.root.looping == 0 assert app.root.tracker == 4 call_hash = list(v for v in app.root._calls if "experimental_iterate" in v)[0] @@ -537,7 +536,7 @@ def test_lightning_flow_iterate(tmpdir, runtime_cls, run_once): app.root.restarting = True assert app.root.looping == 0 assert app.root.tracker == 4 - runtime_cls(app, start_server=False).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() assert app.root.looping == 2 assert app.root.tracker == 10 if run_once else 20 iterate_call = app.root._calls[call_hash] @@ -555,12 +554,11 @@ def run(self): self.counter += 1 -@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime]) -def test_lightning_flow_counter(runtime_cls, tmpdir): +def test_lightning_flow_counter(tmpdir): app = LightningApp(FlowCounter()) app.checkpointing = True - runtime_cls(app, start_server=False).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() assert app.root.counter == 3 checkpoint_dir = os.path.join(_storage_root_dir(), "checkpoints") @@ -571,7 +569,7 @@ def test_lightning_flow_counter(runtime_cls, tmpdir): with open(checkpoint_path, "rb") as f: app = LightningApp(FlowCounter()) app.set_state(pickle.load(f)) - runtime_cls(app, start_server=False).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() assert app.root.counter == 3 diff --git a/tests/tests_app/core/test_queues.py b/tests/tests_app/core/test_queues.py index 899ad9f606e85..9628e2414d5cf 100644 --- a/tests/tests_app/core/test_queues.py +++ b/tests/tests_app/core/test_queues.py @@ -5,7 +5,6 @@ from unittest import mock import pytest -import redis import requests_mock from lightning_app import LightningFlow @@ -17,12 +16,13 @@ @pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running") -@pytest.mark.parametrize("queue_type", [QueuingSystem.REDIS, QueuingSystem.MULTIPROCESS, QueuingSystem.SINGLEPROCESS]) +@pytest.mark.parametrize("queue_type", [QueuingSystem.REDIS, QueuingSystem.MULTIPROCESS]) def test_queue_api(queue_type, monkeypatch): """Test the Queue API. This test run all the Queue implementation but we monkeypatch the Redis Queues to avoid external interaction """ + import redis blpop_out = (b"entry-id", pickle.dumps("test_entry")) @@ -104,12 +104,14 @@ def test_redis_queue_read_timeout(redis_mock): @pytest.mark.parametrize( "queue_type, queue_process_mock", - [(QueuingSystem.SINGLEPROCESS, queue), (QueuingSystem.MULTIPROCESS, multiprocessing)], + [(QueuingSystem.MULTIPROCESS, multiprocessing)], ) def test_process_queue_read_timeout(queue_type, queue_process_mock, monkeypatch): + context = mock.MagicMock() queue_mocked = mock.MagicMock() - monkeypatch.setattr(queue_process_mock, "Queue", queue_mocked) + context.Queue = queue_mocked + monkeypatch.setattr(queue_process_mock, "get_context", mock.MagicMock(return_value=context)) my_queue = queue_type.get_readiness_queue() # default timeout diff --git a/tests/tests_app/runners/test_runtime.py b/tests/tests_app/runners/test_runtime.py index c79ef1207cae9..cf0e1feea34ae 100644 --- a/tests/tests_app/runners/test_runtime.py +++ b/tests/tests_app/runners/test_runtime.py @@ -13,7 +13,6 @@ @pytest.mark.parametrize( "runtime_type", [ - RuntimeType.SINGLEPROCESS, RuntimeType.MULTIPROCESS, RuntimeType.CLOUD, ], diff --git a/tests/tests_app/runners/test_singleprocess.py b/tests/tests_app/runners/test_singleprocess.py deleted file mode 100644 index 998f23e66296f..0000000000000 --- a/tests/tests_app/runners/test_singleprocess.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -from unittest import mock - -import pytest - -from lightning_app import LightningFlow -from lightning_app.core.app import LightningApp -from lightning_app.runners import SingleProcessRuntime - - -class Flow(LightningFlow): - def run(self): - raise KeyboardInterrupt - - -def on_before_run(): - pass - - -def test_single_process_runtime(tmpdir): - - app = LightningApp(Flow()) - SingleProcessRuntime(app, start_server=False).dispatch(on_before_run=on_before_run) - - -@pytest.mark.parametrize( - "env,expected_url", - [ - ({}, "http://127.0.0.1:7501/view"), - ({"APP_SERVER_HOST": "http://test"}, "http://test"), - ], -) -def test_get_app_url(env, expected_url): - with mock.patch.dict(os.environ, env): - assert SingleProcessRuntime._get_app_url() == expected_url diff --git a/tests/tests_app/storage/test_copier.py b/tests/tests_app/storage/test_copier.py index df241ed34d1ec..9235c6ef9d7a3 100644 --- a/tests/tests_app/storage/test_copier.py +++ b/tests/tests_app/storage/test_copier.py @@ -22,9 +22,13 @@ def _handle_exists_request(work, request): return Path._handle_exists_request(work, request) +@mock.patch("lightning_app.storage.path.pathlib.Path.is_dir") +@mock.patch("lightning_app.storage.path.pathlib.Path.stat") @mock.patch("lightning_app.storage.copier._filesystem") -def test_copier_copies_all_files(fs_mock, tmpdir): +def test_copier_copies_all_files(fs_mock, stat_mock, dir_mock, tmpdir): """Test that the Copier calls the copy with the information provided in the request.""" + stat_mock().st_size = 0 + dir_mock.return_value = False copy_request_queue = _MockQueue() copy_response_queue = _MockQueue() work = mock.Mock() @@ -38,9 +42,13 @@ def test_copier_copies_all_files(fs_mock, tmpdir): fs_mock().put.assert_called_once_with("file", tmpdir / ".shared" / "123") -def test_copier_handles_exception(monkeypatch): +@mock.patch("lightning_app.storage.path.pathlib.Path.is_dir") +@mock.patch("lightning_app.storage.path.pathlib.Path.stat") +def test_copier_handles_exception(stat_mock, dir_mock, monkeypatch): """Test that the Copier captures exceptions from the file copy and forwards them through the queue without raising it.""" + stat_mock().st_size = 0 + dir_mock.return_value = False copy_request_queue = _MockQueue() copy_response_queue = _MockQueue() fs = mock.Mock() diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py index 05905c3421bec..3346da5a858fc 100644 --- a/tests/tests_app/structures/test_structures.py +++ b/tests/tests_app/structures/test_structures.py @@ -4,7 +4,7 @@ import pytest from lightning_app import LightningApp, LightningFlow, LightningWork -from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime +from lightning_app.runners import MultiProcessRuntime from lightning_app.storage.payload import Payload from lightning_app.structures import Dict, List from lightning_app.testing.helpers import EmptyFlow @@ -309,11 +309,10 @@ def run(self): @pytest.mark.skip(reason="tchaton: Resolve this test.") -@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime, SingleProcessRuntime]) @pytest.mark.parametrize("run_once_iterable", [False, True]) @pytest.mark.parametrize("cache_calls", [False, True]) @pytest.mark.parametrize("use_list", [False, True]) -def test_structure_with_iterate_and_fault_tolerance(runtime_cls, run_once_iterable, cache_calls, use_list): +def test_structure_with_iterate_and_fault_tolerance(run_once_iterable, cache_calls, use_list): class DummyFlow(LightningFlow): def __init__(self): super().__init__() @@ -360,7 +359,7 @@ def run(self): self.looping += 1 app = LightningApp(RootFlow(use_list, run_once_iterable, cache_calls)) - runtime_cls(app, start_server=False).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() assert app.root.iter[0 if use_list else "0"].counter == 1 assert app.root.iter[1 if use_list else "1"].counter == 0 assert app.root.iter[2 if use_list else "2"].counter == 0 @@ -368,7 +367,7 @@ def run(self): app = LightningApp(RootFlow(use_list, run_once_iterable, cache_calls)) app.root.restarting = True - runtime_cls(app, start_server=False).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() if run_once_iterable: expected_value = 1 @@ -497,3 +496,51 @@ def test_structures_with_payload(): app = LightningApp(FlowPayload(), log_level="debug") MultiProcessRuntime(app, start_server=False).dispatch() os.remove("payload") + + +def test_structures_have_name_on_init(): + """Test that the children in structures have the correct name assigned upon initialization.""" + + class ChildWork(LightningWork): + def run(self): + pass + + class Collection(EmptyFlow): + def __init__(self): + super().__init__() + self.list_structure = List() + self.list_structure.append(ChildWork()) + + self.dict_structure = Dict() + self.dict_structure["dict_child"] = ChildWork() + + flow = Collection() + LightningApp(flow) # wrap in app to init all component names + assert flow.list_structure[0].name == "root.list_structure.0" + assert flow.dict_structure["dict_child"].name == "root.dict_structure.dict_child" + + +class FlowWiStructures(LightningFlow): + def __init__(self): + super().__init__() + + self.ws = [EmptyFlow(), EmptyFlow()] + + self.ws1 = {"a": EmptyFlow(), "b": EmptyFlow()} + + self.ws2 = { + "a": EmptyFlow(), + "b": EmptyFlow(), + "c": List(EmptyFlow(), EmptyFlow()), + "d": Dict(**{"a": EmptyFlow()}), + } + + def run(self): + pass + + +def test_flow_without_structures(): + + flow = FlowWiStructures() + assert isinstance(flow.ws, List) + assert isinstance(flow.ws1, Dict) diff --git a/tests/tests_app/utilities/packaging/test_build_spec.py b/tests/tests_app/utilities/packaging/test_build_spec.py index ba497a5efbdb4..70c4a60374b67 100644 --- a/tests/tests_app/utilities/packaging/test_build_spec.py +++ b/tests/tests_app/utilities/packaging/test_build_spec.py @@ -29,7 +29,7 @@ def test_build_config_requirements_provided(): assert spec.requirements == [ "dask", "pandas", - "pytorch_" + "lightning==1.5.9", # ugly hack due to replacing `pytorch_lightning string` + "pytorch_lightning==1.5.9", "git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0", ] assert spec == BuildConfig.from_dict(spec.to_dict()) @@ -50,7 +50,7 @@ def test_build_config_dockerfile_provided(): spec = BuildConfig(dockerfile="./projects/Dockerfile.cpu") assert not spec.requirements # ugly hack due to replacing `pytorch_lightning string - assert "pytorchlightning/pytorch_" + "lightning" in spec.dockerfile.data[0] + assert "pytorchlightning/pytorch_lightning" in spec.dockerfile.data[0] class DockerfileLightningTestApp(LightningTestApp): diff --git a/tests/tests_app/utilities/test_app_helpers.py b/tests/tests_app/utilities/test_app_helpers.py index 791d2011f7651..2241e262cd381 100644 --- a/tests/tests_app/utilities/test_app_helpers.py +++ b/tests/tests_app/utilities/test_app_helpers.py @@ -1,4 +1,5 @@ import os +from functools import partial from unittest import mock import pytest @@ -10,6 +11,8 @@ ) from lightning_app import LightningApp, LightningFlow, LightningWork +from lightning_app.core.flow import _RootFlow +from lightning_app.frontend import StaticWebFrontend from lightning_app.utilities.app_helpers import ( _handle_is_headless, _is_headless, @@ -119,14 +122,9 @@ def configure_layout(self): return {"name": "test", "content": "https://appurl"} -class FlowWithWorkLayout(Flow): - def __init__(self): - super().__init__() - - self.work = Work() - +class FlowWithFrontend(Flow): def configure_layout(self): - return {"name": "test", "content": self.work} + return StaticWebFrontend(".") class FlowWithMockedFrontend(Flow): @@ -153,16 +151,62 @@ def __init__(self): self.flow = FlowWithURLLayout() +class WorkWithStringLayout(Work): + def configure_layout(self): + return "http://appurl" + + +class WorkWithMockedFrontendLayout(Work): + def configure_layout(self): + return _MagicMockJsonSerializable() + + +class WorkWithFrontendLayout(Work): + def configure_layout(self): + return StaticWebFrontend(".") + + +class WorkWithNoneLayout(Work): + def configure_layout(self): + return None + + +class FlowWithWorkLayout(Flow): + def __init__(self, work): + super().__init__() + + self.work = work() + + def configure_layout(self): + return {"name": "test", "content": self.work} + + +class WorkClassRootFlow(_RootFlow): + """A ``_RootFlow`` which takes a work class rather than the work itself.""" + + def __init__(self, work): + super().__init__(work()) + + @pytest.mark.parametrize( "flow,expected", [ (Flow, True), (FlowWithURLLayout, False), - (FlowWithWorkLayout, False), + (FlowWithFrontend, False), (FlowWithMockedFrontend, False), (FlowWithMockedContent, False), (NestedFlow, True), (NestedFlowWithURLLayout, False), + (partial(WorkClassRootFlow, WorkWithStringLayout), False), + (partial(WorkClassRootFlow, WorkWithMockedFrontendLayout), False), + (partial(WorkClassRootFlow, WorkWithFrontendLayout), False), + (partial(WorkClassRootFlow, WorkWithNoneLayout), True), + (partial(FlowWithWorkLayout, Work), False), + (partial(FlowWithWorkLayout, WorkWithStringLayout), False), + (partial(FlowWithWorkLayout, WorkWithMockedFrontendLayout), False), + (partial(FlowWithWorkLayout, WorkWithFrontendLayout), False), + (partial(FlowWithWorkLayout, WorkWithNoneLayout), True), ], ) def test_is_headless(flow, expected): diff --git a/tests/tests_app/utilities/test_introspection.py b/tests/tests_app/utilities/test_introspection.py index 1f8be2eafe04d..6b64be31e9b8a 100644 --- a/tests/tests_app/utilities/test_introspection.py +++ b/tests/tests_app/utilities/test_introspection.py @@ -49,7 +49,6 @@ def test_introspection_lightning_overrides(): "BaseProfiler", "Callback", "LightningDataModule", - "LightningLite", "LightningLoggerBase", "LightningModule", "Loop", diff --git a/tests/tests_app/utilities/test_layout.py b/tests/tests_app/utilities/test_layout.py new file mode 100644 index 0000000000000..98921e3d0000e --- /dev/null +++ b/tests/tests_app/utilities/test_layout.py @@ -0,0 +1,143 @@ +import pytest + +from lightning_app.core.flow import LightningFlow +from lightning_app.core.work import LightningWork +from lightning_app.frontend.web import StaticWebFrontend +from lightning_app.utilities.layout import _collect_layout + + +class _MockApp: + def __init__(self) -> None: + self.frontends = {} + + +class FlowWithFrontend(LightningFlow): + def configure_layout(self): + return StaticWebFrontend(".") + + +class WorkWithFrontend(LightningWork): + def run(self): + pass + + def configure_layout(self): + return StaticWebFrontend(".") + + +class FlowWithWorkWithFrontend(LightningFlow): + def __init__(self): + super().__init__() + + self.work = WorkWithFrontend() + + def configure_layout(self): + return {"name": "work", "content": self.work} + + +class FlowWithUrl(LightningFlow): + def configure_layout(self): + return {"name": "test", "content": "https://test"} + + +class WorkWithUrl(LightningWork): + def run(self): + pass + + def configure_layout(self): + return "https://test" + + +class FlowWithWorkWithUrl(LightningFlow): + def __init__(self): + super().__init__() + + self.work = WorkWithUrl() + + def configure_layout(self): + return {"name": "test", "content": self.work} + + +@pytest.mark.parametrize( + "flow,expected_layout,expected_frontends", + [ + (FlowWithFrontend, {}, [("root", StaticWebFrontend)]), + (FlowWithWorkWithFrontend, {}, [("root", StaticWebFrontend)]), + (FlowWithUrl, [{"name": "test", "content": "https://test", "target": "https://test"}], []), + (FlowWithWorkWithUrl, [{"name": "test", "content": "https://test", "target": "https://test"}], []), + ], +) +def test_collect_layout(flow, expected_layout, expected_frontends): + app = _MockApp() + flow = flow() + layout = _collect_layout(app, flow) + + assert layout == expected_layout + assert set(app.frontends.keys()) == {key for key, _ in expected_frontends} + for key, frontend_type in expected_frontends: + assert isinstance(app.frontends[key], frontend_type) + + +class FlowWithBadLayout(LightningFlow): + def configure_layout(self): + return 100 + + +class FlowWithBadLayoutDict(LightningFlow): + def configure_layout(self): + return {"this_key_should_not_be_here": "http://appurl"} + + +class FlowWithBadContent(LightningFlow): + def configure_layout(self): + return {"content": 100} + + +class WorkWithBadLayout(LightningWork): + def run(self): + pass + + def configure_layout(self): + return 100 + + +class FlowWithWorkWithBadLayout(LightningFlow): + def __init__(self): + super().__init__() + + self.work = WorkWithBadLayout() + + def configure_layout(self): + return {"name": "test", "content": self.work} + + +class FlowWithMultipleWorksWithFrontends(LightningFlow): + def __init__(self): + super().__init__() + + self.work1 = WorkWithFrontend() + self.work2 = WorkWithFrontend() + + def configure_layout(self): + return [{"name": "test1", "content": self.work1}, {"name": "test2", "content": self.work2}] + + +@pytest.mark.parametrize( + "flow,error_type,match", + [ + (FlowWithBadLayout, TypeError, "is an unsupported layout format"), + (FlowWithBadLayoutDict, ValueError, "missing a key 'content'."), + (FlowWithBadContent, ValueError, "contains an unsupported entry."), + (FlowWithWorkWithBadLayout, TypeError, "is of an unsupported type."), + ( + FlowWithMultipleWorksWithFrontends, + TypeError, + "The tab containing a `WorkWithFrontend` must be the only tab", + ), + ], +) +def test_collect_layout_errors(flow, error_type, match): + app = _MockApp() + flow = flow() + + with pytest.raises(error_type, match=match): + _collect_layout(app, flow) diff --git a/tests/tests_app/utilities/test_safe_pickle.py b/tests/tests_app/utilities/test_safe_pickle.py new file mode 100644 index 0000000000000..473fe28ed22f7 --- /dev/null +++ b/tests/tests_app/utilities/test_safe_pickle.py @@ -0,0 +1,11 @@ +import subprocess +from pathlib import Path + + +def test_safe_pickle_app(): + test_dir = Path(__file__).parent / "testdata" + proc = subprocess.Popen( + ["lightning", "run", "app", "safe_pickle_app.py", "--open-ui", "false"], stdout=subprocess.PIPE, cwd=test_dir + ) + stdout, _ = proc.communicate() + assert "Exiting the pickling app successfully" in stdout.decode("UTF-8") diff --git a/tests/tests_app/utilities/testdata/safe_pickle_app.py b/tests/tests_app/utilities/testdata/safe_pickle_app.py new file mode 100644 index 0000000000000..f15344360d85f --- /dev/null +++ b/tests/tests_app/utilities/testdata/safe_pickle_app.py @@ -0,0 +1,63 @@ +""" +This app tests three things +1. Can a work pickle `self` +2. Can the pickled work be unpickled in another work +3. Can the pickled work be unpickled from a script +""" + +import subprocess +from pathlib import Path + +from lightning_app import LightningApp, LightningFlow, LightningWork +from lightning_app.utilities import safe_pickle + + +class SelfPicklingWork(LightningWork): + def run(self): + with open("work.pkl", "wb") as f: + safe_pickle.dump(self, f) + + def get_test_string(self): + return f"Hello from {self.__class__.__name__}!" + + +class WorkThatLoadsPickledWork(LightningWork): + def run(self): + with open("work.pkl", "rb") as f: + work = safe_pickle.load(f) + assert work.get_test_string() == "Hello from SelfPicklingWork!" + + +script_load_pickled_work = """ +import pickle +work = pickle.load(open("work.pkl", "rb")) +print(work.get_test_string()) +""" + + +class RootFlow(LightningFlow): + def __init__(self): + super().__init__() + self.self_pickling_work = SelfPicklingWork() + self.work_that_loads_pickled_work = WorkThatLoadsPickledWork() + + def run(self): + self.self_pickling_work.run() + self.work_that_loads_pickled_work.run() + + with open("script_that_loads_pickled_work.py", "w") as f: + f.write(script_load_pickled_work) + + # read the output from subprocess + proc = subprocess.Popen(["python", "script_that_loads_pickled_work.py"], stdout=subprocess.PIPE) + assert "Hello from SelfPicklingWork" in proc.stdout.read().decode("UTF-8") + + # deleting the script + Path("script_that_loads_pickled_work.py").unlink() + # deleting the pkl file + Path("work.pkl").unlink() + + self._exit("Exiting the pickling app successfully!!") + + +app = LightningApp(RootFlow()) diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index f0cde3f1902b6..b02fa09d0f3e2 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -26,7 +26,7 @@ from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12 if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision @mock.patch("lightning_fabric.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False) @@ -36,13 +36,26 @@ def test_fsdp_support(*_): @RunIf(min_torch="1.12") -def test_fsdp_custom_mixed_precision(*_): +def test_fsdp_custom_mixed_precision(): """Test that passing a custom mixed precision config works.""" config = MixedPrecision() strategy = FSDPStrategy(mixed_precision=config) assert strategy.mixed_precision_config == config +@RunIf(min_torch="1.12") +def test_fsdp_cpu_offload(): + """Test the different ways cpu offloading can be enabled.""" + # bool + strategy = FSDPStrategy(cpu_offload=True) + assert strategy.cpu_offload == CPUOffload(offload_params=True) + + # dataclass + config = CPUOffload() + strategy = FSDPStrategy(cpu_offload=config) + assert strategy.cpu_offload == config + + @RunIf(min_torch="1.12") def test_fsdp_setup_optimizer_validation(): """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters.""" diff --git a/tests/tests_lite/strategies/test_xla.py b/tests/tests_lite/strategies/test_xla.py index 96f6f28ca1fb2..934d4e3df59da 100644 --- a/tests/tests_lite/strategies/test_xla.py +++ b/tests/tests_lite/strategies/test_xla.py @@ -17,6 +17,7 @@ from unittest.mock import Mock import pytest +import torch from tests_lite.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests_lite.helpers.models import RandomDataset, RandomIterableDataset from tests_lite.helpers.runif import RunIf @@ -113,3 +114,24 @@ def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatc with pytest.raises(TypeError, match="TPUs do not currently support"): XLAStrategy().process_dataloader(dataloader) + + +def tpu_all_gather_fn(strategy): + for sync_grads in [True, False]: + tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True) + result = strategy.all_gather(tensor, sync_grads=sync_grads) + summed = result.sum() + assert torch.equal(summed, torch.tensor(8.0)) + summed.backward() + if sync_grads: + assert torch.equal(tensor.grad, torch.tensor(1.0)) + else: + # As gradients are not synced, the original tensor will not have gradients. + assert tensor.grad is None + + +@RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) +def test_tpu_all_gather(): + """Test the all_gather operation on TPU.""" + xla_launch(tpu_all_gather_fn) diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 5eee22fa5a62b..e9fdb7b2ab8c8 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -23,7 +23,7 @@ from tests_lite.helpers.runif import RunIf from tests_lite.helpers.utils import no_warning_call from torch import nn -from torch.utils.data import DataLoader, DistributedSampler, Sampler +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset from lightning_fabric.fabric import Fabric from lightning_fabric.plugins import Precision @@ -40,7 +40,7 @@ from lightning_fabric.strategies.strategy import _Sharded from lightning_fabric.utilities import _StrategyType from lightning_fabric.utilities.exceptions import MisconfigurationException -from lightning_fabric.utilities.seed import pl_worker_init_function +from lightning_fabric.utilities.seed import pl_worker_init_function, seed_everything from lightning_fabric.utilities.warnings import PossibleUserWarning from lightning_fabric.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer @@ -384,6 +384,32 @@ def test_setup_dataloaders_distributed_sampler_not_needed(): assert lite_dataloader.sampler is custom_sampler +def test_setup_dataloaders_distributed_sampler_shuffle(): + """Test that the DataLoader(shuffle=True|False) setting gets carried over correctly into the distributed + sampler.""" + lite = Fabric(accelerator="cpu", strategy="ddp_spawn", devices=2) + # no lite.launch(): pretend we are on rank 0 now + + dataset = TensorDataset(torch.arange(8)) + + # shuffling turned off + no_shuffle_dataloaders = [ + DataLoader(dataset), + DataLoader(dataset, shuffle=False), + DataLoader(dataset, sampler=SequentialSampler(dataset)), + ] + for dataloader in no_shuffle_dataloaders: + dataloader = lite.setup_dataloaders(dataloader) + assert list(t[0].item() for t in iter(dataloader)) == [0, 2, 4, 6] + + # shuffling turned on + shuffle_dataloaders = [DataLoader(dataset, shuffle=True), DataLoader(dataset, sampler=RandomSampler(dataset))] + for dataloader in shuffle_dataloaders: + seed_everything(1) + dataloader = lite.setup_dataloaders(dataloader) + assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1] + + @mock.patch.dict(os.environ, {}, clear=True) def test_seed_everything(): """Test that seed everything is static and sets the worker init function on the dataloader.""" diff --git a/tests/tests_lite/test_parity.py b/tests/tests_lite/test_parity.py index c93a80f39ed4e..208d594266784 100644 --- a/tests/tests_lite/test_parity.py +++ b/tests/tests_lite/test_parity.py @@ -203,7 +203,7 @@ def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator, ) def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir): Fabric.seed_everything(42) - train_dataloader = DataLoader(RandomDataset(32, 4)) + train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True) model = BoringModel() num_epochs = 1 state_dict = deepcopy(model.state_dict()) @@ -214,13 +214,13 @@ def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir lite_model_state_dict = model.state_dict() for w_pure, w_lite in zip(state_dict.values(), lite_model_state_dict.values()): - assert not torch.equal(w_pure.cpu(), w_lite.cpu()) + assert not torch.allclose(w_pure.cpu(), w_lite.cpu()) Fabric.seed_everything(42) - train_dataloader = DataLoader(RandomDataset(32, 4)) + train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True) model = BoringModel() run(lite.global_rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir) pure_model_state_dict = model.state_dict() for w_pure, w_lite in zip(pure_model_state_dict.values(), lite_model_state_dict.values()): - assert torch.equal(w_pure.cpu(), w_lite.cpu()) + torch.testing.assert_close(w_pure.cpu(), w_lite.cpu()) diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index 32d41c2d50916..bb0562992a040 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -425,6 +425,17 @@ def test_proper_refcount(): assert sys.getrefcount(torch_module) == sys.getrefcount(lightning_module) +def test_lightning_module_scriptable(): + """Test that the LightningModule is `torch.jit.script`-able. + + Regression test for #15917. + """ + model = BoringModel() + trainer = Trainer() + model.trainer = trainer + torch.jit.script(model) + + def test_trainer_reference_recursively(): ensemble = LightningModule() inner = LightningModule() @@ -451,10 +462,27 @@ def test_compile_uncompile(): lit_model_compiled = LightningModule.from_compiled(model_compiled) + def has_dynamo(fn): + return any(el for el in dir(fn) if el.startswith("_torchdynamo")) + assert isinstance(lit_model_compiled, LightningModule) assert lit_model_compiled._compiler_ctx is not None + assert has_dynamo(lit_model_compiled.forward) + assert has_dynamo(lit_model_compiled.training_step) + assert has_dynamo(lit_model_compiled.validation_step) + assert has_dynamo(lit_model_compiled.test_step) + assert has_dynamo(lit_model_compiled.predict_step) lit_model_orig = LightningModule.to_uncompiled(lit_model) assert lit_model_orig._compiler_ctx is None assert lit_model_orig.forward == lit_model.forward + assert lit_model_orig.training_step == lit_model.training_step + assert lit_model_orig.validation_step == lit_model.validation_step + assert lit_model_orig.test_step == lit_model.test_step + assert lit_model_orig.predict_step == lit_model.predict_step + assert not has_dynamo(lit_model_orig.forward) + assert not has_dynamo(lit_model_orig.training_step) + assert not has_dynamo(lit_model_orig.validation_step) + assert not has_dynamo(lit_model_orig.test_step) + assert not has_dynamo(lit_model_orig.predict_step) diff --git a/tests/tests_pytorch/graveyard/test_core.py b/tests/tests_pytorch/graveyard/test_core.py index 8450f41f6c075..95db542658481 100644 --- a/tests/tests_pytorch/graveyard/test_core.py +++ b/tests/tests_pytorch/graveyard/test_core.py @@ -53,18 +53,3 @@ def on_load_checkpoint(self, checkpoint): match="`LightningDataModule.on_load_checkpoint`.*no longer supported as of v1.8.", ): trainer.fit(model, OnLoadDataModule()) - - -def test_v2_0_0_lightning_module_unsupported_use_amp(): - model = BoringModel() - with pytest.raises( - RuntimeError, - match="`LightningModule.use_amp`.*no longer accessible as of v1.8.", - ): - model.use_amp - - with pytest.raises( - RuntimeError, - match="`LightningModule.use_amp`.*no longer accessible as of v1.8.", - ): - model.use_amp = False diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index a9b47aad1dca5..d9f8f86dd0f79 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -17,7 +17,7 @@ from tests_pytorch.helpers.runif import RunIf if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision from torch.distributed.fsdp.wrap import wrap @@ -306,3 +306,16 @@ def __init__(self): ) as ckpt_mock: strategy._setup_model(model) ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY) + + +@RunIf(min_torch="1.12") +def test_fully_sharded_native_strategy_cpu_offload(): + """Test the different ways cpu offloading can be enabled.""" + # bool + strategy = DDPFullyShardedNativeStrategy(cpu_offload=True) + assert strategy.cpu_offload == CPUOffload(offload_params=True) + + # dataclass + config = CPUOffload() + strategy = DDPFullyShardedNativeStrategy(cpu_offload=config) + assert strategy.cpu_offload == config diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index e1793b3a356aa..d1c1324cf3018 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -43,6 +43,7 @@ from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv from pytorch_lightning.demos.boring_classes import ( + BoringDataModule, BoringModel, RandomDataset, RandomIterableDataset, @@ -2247,12 +2248,14 @@ def test_trainer_compiled_model(): model = torch.compile(model) + data = BoringDataModule() + trainer = Trainer( max_epochs=1, limit_train_batches=1, limit_val_batches=1, ) - trainer.fit(model) + trainer.fit(model, data) assert trainer.model._compiler_ctx["compiler"] == "dynamo" @@ -2260,7 +2263,7 @@ def test_trainer_compiled_model(): assert model._compiler_ctx is None - trainer.train(model) + trainer.fit(model) assert trainer.model._compiler_ctx is None diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index ed4d9d33430f0..25fdcd35f31f7 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -441,6 +441,53 @@ def test_if_lr_finder_callback_already_configured(): trainer.tune(model) +def test_lr_finder_callback_restarting(tmpdir): + """Test that `LearningRateFinder` does not set restarting=True when loading checkpoint.""" + + num_lr_steps = 100 + + class MyBoringModel(BoringModel): + def __init__(self): + super().__init__() + self.learning_rate = 0.123 + + def on_train_batch_start(self, batch, batch_idx): + if getattr(self, "_expected_max_steps", None) is not None: + assert self.trainer.fit_loop.max_steps == self._expected_max_steps + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=self.learning_rate) + + class CustomLearningRateFinder(LearningRateFinder): + milestones = (1,) + + def lr_find(self, trainer, pl_module) -> None: + pl_module._expected_max_steps = trainer.global_step + self._num_training_steps + super().lr_find(trainer, pl_module) + pl_module._expected_max_steps = None + assert not trainer.fit_loop.restarting + + def on_train_epoch_start(self, trainer, pl_module): + if trainer.current_epoch in self.milestones or trainer.current_epoch == 0: + self.lr_find(trainer, pl_module) + + model = MyBoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + callbacks=[ + CustomLearningRateFinder(early_stop_threshold=None, update_attr=True, num_training_steps=num_lr_steps) + ], + limit_train_batches=10, + limit_val_batches=0, + limit_test_batches=0, + num_sanity_val_steps=0, + enable_model_summary=False, + ) + + trainer.fit(model) + + @mock.patch.dict(os.environ, os.environ.copy(), clear=True) @RunIf(standalone=True) def test_lr_finder_with_ddp(tmpdir):