diff --git a/.actions/setup_tools.py b/.actions/setup_tools.py
index 5088be2020738..a76e81246798c 100644
--- a/.actions/setup_tools.py
+++ b/.actions/setup_tools.py
@@ -94,11 +94,10 @@ def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
text = text.replace("pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{version}")
# codecov badge
text = text.replace("/branch/master/graph/badge.svg", f"/release/{version}/graph/badge.svg")
- # replace github badges for release ones
+ # github actions badge
text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={version}")
- # Azure...
+ # azure pipelines badge
text = text.replace("?branchName=master", f"?branchName=refs%2Ftags%2F{version}")
- text = re.sub(r"\?definitionId=\d+&branchName=master", f"?definitionId=2&branchName=refs%2Ftags%2F{version}", text)
skip_begin = r""
skip_end = r""
diff --git a/.azure/gpu-benchmark.yml b/.azure/gpu-benchmark.yml
index ac5ca6f60a6b4..0de590f2c54a6 100644
--- a/.azure/gpu-benchmark.yml
+++ b/.azure/gpu-benchmark.yml
@@ -28,7 +28,7 @@ jobs:
cancelTimeoutInMinutes: "2"
pool: azure-jirka-spot
container:
- image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12"
+ image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12-cuda11.3.1"
options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all --shm-size=32g"
workspace:
clean: all
diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml
index f37c17613affc..8ae670d265ced 100644
--- a/.azure/gpu-tests.yml
+++ b/.azure/gpu-tests.yml
@@ -26,7 +26,7 @@ jobs:
strategy:
matrix:
'PyTorch - stable':
- image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12"
+ image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12-cuda11.3.1"
# how long to run the job before automatically cancelling
timeoutInMinutes: "80"
# how much time to give 'run always even if cancelled tasks' before stopping them
@@ -44,7 +44,7 @@ jobs:
- bash: |
CHANGED_FILES=$(git diff --name-status origin/master -- . | awk '{print $2}')
- FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.azure/*'
+ FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.azure/gpu-tests.yml'
echo $CHANGED_FILES > changed_files.txt
MATCHES=$(cat changed_files.txt | grep -E $FILTER)
echo $MATCHES
@@ -75,7 +75,7 @@ jobs:
CUDA_VERSION_MM=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
pip install "bagua-cuda$CUDA_VERSION_MM>=0.9.0"
pip install -e .[strategies]
- pip install deepspeed>0.6.4 # TODO: remove when docker images are upgraded
+ pip install -U deepspeed # TODO: remove when docker images are upgraded
pip install --requirement requirements/pytorch/devel.txt
pip list
env:
diff --git a/.github/BECOMING_A_CORE_CONTRIBUTOR.md b/.github/BECOMING_A_CORE_CONTRIBUTOR.md
index a179161f687a1..fd40e29e1ebf1 100644
--- a/.github/BECOMING_A_CORE_CONTRIBUTOR.md
+++ b/.github/BECOMING_A_CORE_CONTRIBUTOR.md
@@ -62,4 +62,4 @@ We are on the lookout for new people to join, however, if you feel like you meet
## Employment
-You can also become a [Grid.ai](https://www.grid.ai) employee or intern and work on Lightning. To get started, you can email `careers@grid.ai` with your resume or check out our [open job postings](https://boards.greenhouse.io/gridai).
+You can also become a [Lightning AI](https://lightning.ai/) employee or intern and work on Lightning. To get started, you can email `careers@lightning.ai` with your resume or check out our [open job postings](https://boards.greenhouse.io/lightningai).
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index e40828557c2cf..0b4692731bff9 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -5,7 +5,7 @@
# the repo. Unless a later match takes precedence,
# @global-owner1 and @global-owner2 will be requested for
# review when someone opens a pull request.
-* @williamfalcon @borda @tchaton @SeanNaren @carmocca @awaelchli @justusschock @kaushikb11 @rohitgr7
+* @williamfalcon @borda @tchaton @awaelchli @kaushikb11 @rohitgr7
# CI/CD and configs
/.github/ @borda @carmocca @akihironitta @tchaton
@@ -26,50 +26,52 @@
/docs/source-app/expertise_levels @williamfalcon @Felonious-Spellfire @RobertLaurella
# Packages
+/src/pytorch_lightning @carmocca @justusschock
/src/pytorch_lightning/accelerators @williamfalcon @tchaton @SeanNaren @awaelchli @justusschock @kaushikb11
/src/pytorch_lightning/callbacks @williamfalcon @tchaton @carmocca @borda @kaushikb11
-/src/pytorch_lightning/core @tchaton @SeanNaren @borda @carmocca @justusschock @kaushikb11
+/src/pytorch_lightning/core @tchaton @borda @carmocca @justusschock @kaushikb11
/src/pytorch_lightning/distributed @williamfalcon @tchaton @awaelchli @kaushikb11
/src/pytorch_lightning/lite @tchaton @awaelchli @carmocca
/src/pytorch_lightning/loggers @tchaton @awaelchli @borda
-/src/pytorch_lightning/loggers/wandb.py @borisdayma
+/src/pytorch_lightning/loggers/wandb.py @borisdayma @borda
/src/pytorch_lightning/loggers/neptune.py @shnela @HubertJaworski @pkasprzyk @pitercl @Raalsky @aniezurawski @kamil-kaczmarek
/src/pytorch_lightning/loops @tchaton @awaelchli @justusschock @carmocca
-/src/pytorch_lightning/overrides @tchaton @SeanNaren @borda
-/src/pytorch_lightning/plugins @tchaton @SeanNaren @awaelchli @justusschock
+/src/pytorch_lightning/overrides @tchaton @borda
+/src/pytorch_lightning/plugins @tchaton @awaelchli @justusschock
/src/pytorch_lightning/profilers @williamfalcon @tchaton @borda @carmocca
/src/pytorch_lightning/profilers/pytorch.py @nbcsm @guotuofeng
/src/pytorch_lightning/strategies @tchaton @SeanNaren @awaelchli @justusschock @kaushikb11
-/src/pytorch_lightning/trainer @williamfalcon @borda @tchaton @SeanNaren @carmocca @awaelchli @justusschock @kaushikb11
-/src/pytorch_lightning/trainer/connectors @tchaton @SeanNaren @carmocca @borda
+/src/pytorch_lightning/trainer @williamfalcon @borda @tchaton @carmocca @awaelchli @justusschock @kaushikb11
+/src/pytorch_lightning/trainer/connectors @tchaton @carmocca @borda
/src/pytorch_lightning/tuner @SkafteNicki @borda @awaelchli
-/src/pytorch_lightning/utilities @borda @tchaton @SeanNaren @carmocca
+/src/pytorch_lightning/utilities @borda @tchaton @carmocca
-/src/lightning_app @tchaton @awaelchli @manskx @hhsecond
+/src/lightning_app @tchaton @manskx
+/src/lightning_app/cli/pl-app-template @tchaton @awaelchli @Borda
+/src/lightning_app/core @tchaton @awaelchli @manskx
+/src/lightning_app/core/queues.py @tchaton @hhsecond @manskx
+/src/lightning_app/runners/cloud.py @tchaton @hhsecond
+/src/lightning_app/testing @tchaton @manskx
+/src/lightning_app/__about__.py @nohalon @edenlightning @lantiga
# Examples
-/examples/app_* @tchaton @awaelchli @manskx @hhsecond
+/examples/app_* @tchaton @awaelchli @manskx @hhsecond
# App tests
-/tests/tests_app @tchaton @awaelchli @manskx @hhsecond
-/tests/tests_app_examples @tchaton @awaelchli @manskx @hhsecond
+/tests/tests_app @tchaton @awaelchli @manskx @hhsecond
+/tests/tests_app_examples @tchaton @awaelchli @manskx @hhsecond
# Specifics
-/src/pytorch_lightning/trainer/connectors/logger_connector @tchaton @carmocca
-/src/pytorch_lightning/trainer/progress.py @tchaton @awaelchli @carmocca
-
+/src/pytorch_lightning/trainer/connectors/logger_connector @tchaton @carmocca
+/src/pytorch_lightning/trainer/progress.py @tchaton @awaelchli @carmocca
# API
-/src/pytorch_lightning/callbacks/base.py @williamfalcon @awaelchli @ananthsub @carmocca
-/src/pytorch_lightning/core/datamodule.py @williamFalcon @awaelchli @ananthsub @carmocca
-/src/pytorch_lightning/trainer/trainer.py @williamfalcon @tchaton @awaelchli
-/src/pytorch_lightning/core/hooks.py @williamfalcon @tchaton @awaelchli @ananthsub @carmocca
-/src/pytorch_lightning/core/lightning.py @williamfalcon @tchaton @awaelchli
-
-# Testing
-/tests/helpers/boring_model.py @williamfalcon @tchaton @borda
+/src/pytorch_lightning/callbacks/callback.py @williamfalcon @awaelchli @ananthsub @carmocca
+/src/pytorch_lightning/core/datamodule.py @williamFalcon @awaelchli @ananthsub @carmocca
+/src/pytorch_lightning/trainer/trainer.py @williamfalcon @tchaton @awaelchli
+/src/pytorch_lightning/core/hooks.py @williamfalcon @tchaton @awaelchli @ananthsub @carmocca
+/src/pytorch_lightning/core/module.py @williamfalcon @tchaton @awaelchli
-/.github/CODEOWNERS @williamfalcon
-/.github/approve_config.yml @williamfalcon
-/SECURITY.md @williamfalcon
-/README.md @williamfalcon @edenlightning @borda
-/setup.py @williamfalcon @borda @carmocca
+/.github/CODEOWNERS @williamfalcon
+/SECURITY.md @williamfalcon
+/README.md @williamfalcon @edenlightning @borda
+/setup.py @williamfalcon @borda @carmocca
/src/pytorch_lightning/__about__.py @williamfalcon @borda @carmocca
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index f08865180ba1d..de4eacde1f39e 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -41,8 +41,16 @@ You can get the script and run it with:
```bash
wget https://raw.githubusercontent.com/Lightning-AI/lightning/master/requirements/collect_env_details.py
python collect_env_details.py
+
```
+
+
+ Details
+ Paste the output here and move this toggle outside of the comment block.
+
+
+
You can also fill out the list below manually.
-->
diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml
new file mode 100644
index 0000000000000..8f1d3c6fb5e86
--- /dev/null
+++ b/.github/checkgroup.yml
@@ -0,0 +1,165 @@
+custom_service_name: "Lightning CI required checker"
+subprojects:
+ - id: "CI: CircleCI"
+ paths:
+ - ".circleci/**"
+ checks:
+ - "test-on-tpus"
+
+ - id: "CI: Azure"
+ paths:
+ - ".azure/**"
+ checks:
+ - "pytorch-lightning (GPUs)"
+ - "pytorch-lightning (GPUs) (testing PyTorch - stable)"
+ - "pytorch-lightning (HPUs)"
+ - "pytorch-lightning (IPUs)"
+
+ - id: "pytorch_lightning"
+ paths:
+ # all examples don't need to be added because they aren't used in CI, but these are
+ - "examples/run_ddp_examples.sh"
+ - "examples/convert_from_pt_to_pl/**"
+ - "examples/run_pl_examples.sh"
+ - "examples/pl_basics/backbone_image_classifier.py"
+ - "examples/pl_basics/autoencoder.py"
+ - "examples/pl_loops/mnist_lite.py"
+ - "examples/pl_fault_tolerant/automatic.py"
+ - "examples/test_pl_examples.py"
+ - "examples/pl_integrations/dali_image_classifier.py"
+ - "requirements/pytorch/**"
+ - "src/pytorch_lightning/**"
+ - "tests/tests_pytorch/**"
+ - "setup.cfg" # includes pytest config
+ - ".github/workflows/ci-pytorch*.yml"
+ - ".github/workflows/docs-*.yml"
+ checks:
+ - "conda (3.8, 1.10)"
+ - "conda (3.8, 1.9)"
+ - "conda (3.9, 1.11)"
+ - "conda (3.9, 1.12)"
+ - "cpu (macOS-11, 3.10, latest, stable)"
+ - "cpu (macOS-11, 3.7, latest, stable)"
+ - "cpu (macOS-11, 3.7, oldest, stable)"
+ - "cpu (ubuntu-20.04, 3.10, latest, stable)"
+ - "cpu (ubuntu-20.04, 3.7, latest, stable)"
+ - "cpu (ubuntu-20.04, 3.7, oldest, stable)"
+ - "cpu (windows-2022, 3.10, latest, stable)"
+ - "cpu (windows-2022, 3.7, latest, stable)"
+ - "cpu (windows-2022, 3.7, oldest, stable)"
+ - "doctest (pytorch)"
+ - "make-docs (pytorch)"
+ - "mypy"
+ - "PR Gatekeeper (pytorch)"
+ - "pytorch-lightning (GPUs)"
+ - "pytorch-lightning (GPUs) (testing PyTorch - stable)"
+ - "pytorch-lightning (HPUs)"
+ - "pytorch-lightning (IPUs)"
+ - "slow (macOS-11, 3.7, 1.11)"
+ - "slow (ubuntu-20.04, 3.7, 1.11)"
+ - "slow (windows-2022, 3.7, 1.11)"
+ - "test-on-tpus"
+
+ - id: "pytorch_lightning: Docs"
+ paths:
+ - "docs/source-pytorch/**"
+ - ".github/workflows/docs-*.yml"
+ - "requirements/pytorch/**"
+ checks:
+ - "doctest (pytorch)"
+ - "make-docs (pytorch)"
+
+ - id: "pytorch_lightning: Docker"
+ paths:
+ - "dockers/**"
+ checks:
+ - "build-conda (3.8, 1.10)"
+ - "build-conda (3.8, 1.9)"
+ - "build-conda (3.9, 1.11)"
+ - "build-conda (3.9, 1.12)"
+ - "build-cuda (3.8, 1.9, 11.1.1)"
+ - "build-cuda (3.9, 1.10, 11.3.1)"
+ - "build-cuda (3.9, 1.11, 11.3.1)"
+ - "build-cuda (3.9, 1.12, 11.3.1)"
+ - "build-cuda (3.9, 1.9, 11.1.1)"
+ - "build-hpu (1.5.0, 1.11.0)"
+ - "build-ipu (3.9, 1.9)"
+ - "build-NGC"
+ - "build-pl (3.9, 1.10, 11.3.1)"
+ - "build-pl (3.9, 1.11, 11.3.1)"
+ - "build-pl (3.9, 1.12, 11.3.1)"
+ - "build-pl (3.9, 1.9, 11.1.1)"
+ - "build-xla (3.7, 1.12)"
+
+ - id: "pytorch_lightning: mypy"
+ paths:
+ - ".github/workflows/code-checks.yml"
+ - "pyproject.toml" # includes mypy config
+ checks:
+ - "mypy"
+
+ - id: "lightning_app"
+ paths:
+ - ".github/workflows/ci-app*.yml"
+ - "examples/app_**"
+ - "requirements/app/**"
+ - "src/lightning_app/**"
+ - "tests/tests_app/**"
+ - "tests/tests_app_examples/**"
+ - "tests/tests_clusters/**"
+ # the examples are used in the app CI
+ - "examples/app_*"
+ checks:
+ - "Cloud Test (boring_app)"
+ - "Cloud Test (collect_failures)"
+ - "Cloud Test (commands_and_api)"
+ - "Cloud Test (custom_work_dependencies)"
+ - "Cloud Test (drive)"
+ - "Cloud Test (idle_timeout)"
+ - "Cloud Test (payload)"
+ - "Cloud Test (template_jupyterlab)"
+ - "Cloud Test (template_react_ui)"
+ - "Cloud Test (template_streamlit_ui)"
+ - "Cloud Test (v0_app)"
+ - "doctest (app)"
+ - "make-docs (app)"
+ - "pytest (macOS-11, 3.8, latest)"
+ - "pytest (macOS-11, 3.8, oldest)"
+ - "pytest (ubuntu-20.04, 3.8, latest)"
+ - "pytest (ubuntu-20.04, 3.8, oldest)"
+ - "pytest (windows-2022, 3.8, latest)"
+ - "pytest (windows-2022, 3.8, oldest)"
+
+ - id: "lightning_app: Docs"
+ paths:
+ - "docs/source-app/**"
+ - ".github/workflows/docs-*.yml"
+ - "requirements/app/**"
+ checks:
+ - "doctest (app)"
+ - "make-docs (app)"
+
+ - id: "install"
+ paths:
+ - ".actions/setup_tools.py"
+ - ".github/workflows/ci-pkg-install.yml"
+ - "setup.py"
+ - "src/lightning/**"
+ # all __about__, __version__, __setup__
+ - "src/*/__*.py"
+ checks:
+ - "install-meta-pypi (macOS-11, 3.8)"
+ - "install-meta-pypi (ubuntu-20.04, 3.8)"
+ - "install-meta-pypi (windows-2022, 3.8)"
+ - "install-meta-src (macOS-11, 3.8)"
+ - "install-meta-src (macOS-11, lightning, 3.8)"
+ - "install-meta-src (ubuntu-20.04, 3.8)"
+ - "install-meta-src (ubuntu-20.04, lightning, 3.8)"
+ - "install-meta-src (windows-2022, 3.8)"
+ - "install-meta-src (windows-2022, lightning, 3.8)"
+ - "install-standalone (macOS-11, app, 3.8)"
+ - "install-standalone (macOS-11, pytorch, 3.8)"
+ - "install-standalone (ubuntu-20.04, app, 3.8)"
+ - "install-standalone (ubuntu-20.04, pytorch, 3.8)"
+ - "install-standalone (windows-2022, app, 3.8)"
+ - "install-standalone (windows-2022, pytorch, 3.8)"
diff --git a/.github/workflows/README.md b/.github/workflows/README.md
index 8b9e7d173b03c..4ed903c0f3a93 100644
--- a/.github/workflows/README.md
+++ b/.github/workflows/README.md
@@ -4,16 +4,16 @@
## Unit and Integration Testing
-| workflow name | workflow file | action | accelerator\* | (Python, PyTorch) | OS |
-| -------------------------- | ----------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | ------------------------------------------------ | ------------------- |
-| Test full | .github/workflows/ci_test-full.yml | Run all tests except for accelerator-specific, standalone and slow tests. | CPU | (3.7, 1.8), (3.7, 1.11), (3.9, 1.8), (3.9, 1.12) | linux, mac, windows |
-| Test with Conda | .github/workflows/ci_test-conda.yml | Same as ci_test-full.yml but with dependencies installed with conda. | CPU | (3.8, 1.8), (3.8, 1.9), (3.8, 1.10), (3.9, 1.12) | linux |
-| Test slow | .github/workflows/ci_test-slow.yml | Run only slow tests. Slow tests usually need to spawn threads and cannot be speed up or simplified. | CPU | (3.7, 1.8) | linux, mac, windows |
-| pytorch-lightning (IPUs) | .azure-pipelines/ipu-tests.yml | Run only IPU-specific tests. | IPU | (3.8, 1.9) | linux |
-| pytorch-lightning (HPUs) | .azure-pipelines/hpu-tests.yml | Run only HPU-specific tests. | HPU | (3.8, 1.10) | linux |
-| pytorch-lightning (GPUs) | .azure-pipelines/gpu-tests.yml | Run all CPU and GPU-specific tests, standalone, and examples. Each standalone test needs to be run in separate processes to avoid unwanted interactions between test cases. | GPU | (3.9, 1.12) | linux |
-| PyTorchLightning.Benchmark | .azure-pipelines/gpu-benchmark.yml | Run speed/memory benchmarks for parity with pure PyTorch. | GPU | (3.9, 1.12) | linux |
-| test-on-tpus | .circleci/config.yml | Run only TPU-specific tests. | TPU | (3.7, 1.12) | linux |
+| workflow name | workflow file | action | accelerator\* | (Python, PyTorch) | OS |
+| -------------------------- | ------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | ------------------------------------------------- | ------------------- |
+| Test PyTorch full | .github/workflows/ci-pytorch-test-full.yml | Run all tests except for accelerator-specific, standalone and slow tests. | CPU | (3.7, 1.9), (3.7, 1.12), (3.9, 1.9), (3.9, 1.12) | linux, mac, windows |
+| Test PyTorch with Conda | .github/workflows/ci-pytorch-test-conda.yml | Same as ci-pytorch-test-full.yml but with dependencies installed with conda. | CPU | (3.8, 1.9), (3.8, 1.10), (3.8, 1.11), (3.9, 1.12) | linux |
+| Test slow | .github/workflows/ci-pytorch-test-slow.yml | Run only slow tests. Slow tests usually need to spawn threads and cannot be speed up or simplified. | CPU | (3.7, 1.11) | linux, mac, windows |
+| pytorch-lightning (IPUs) | .azure-pipelines/ipu-tests.yml | Run only IPU-specific tests. | IPU | (3.8, 1.9) | linux |
+| pytorch-lightning (HPUs) | .azure-pipelines/hpu-tests.yml | Run only HPU-specific tests. | HPU | (3.8, 1.10) | linux |
+| pytorch-lightning (GPUs) | .azure-pipelines/gpu-tests.yml | Run all CPU and GPU-specific tests, standalone, and examples. Each standalone test needs to be run in separate processes to avoid unwanted interactions between test cases. | GPU | (3.9, 1.12) | linux |
+| PyTorchLightning.Benchmark | .azure-pipelines/gpu-benchmark.yml | Run speed/memory benchmarks for parity with pure PyTorch. | GPU | (3.9, 1.12) | linux |
+| test-on-tpus | .circleci/config.yml | Run only TPU-specific tests. | TPU | (3.7, 1.12) | linux |
- \*Accelerators used in CI
- GPU: 2 x NVIDIA Tesla V100
@@ -33,15 +33,15 @@
| --------------------------------- | ----------------------------------------------------------------------------------------- |
| .codecov.yml | Measure test coverage with [codecov.io](https://app.codecov.io/gh/Lightning-AI/lightning) |
| .github/workflows/code-checks.yml | Check Python typing with [MyPy](https://mypy.readthedocs.io/en/stable/). |
-| .github/workflows/ci_schema.yml | Validate the syntax of workflow files. |
+| .github/workflows/ci-schema.yml | Validate the syntax of workflow files. |
## Others
-| workflow file | action |
-| -------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| .github/workflows/ci_dockers.yml | Build docker images used for testing in CI without pushing to the [Docker Hub](https://hub.docker.com/r/pytorchlightning/pytorch_lightning). Publishing these built images takes place in `.github/workflows/release-docker.yml` which only runs in master. |
-| .github/workflows/ci_pkg-install.yml | Test if pytorch-lightning is successfully installed using pip. |
-| .github/workflows/events-recurrent.yml | Terminate TPU jobs that live more than one hour to avoid possible resource exhaustion due to hangs. |
+| workflow file | action |
+| ------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| .github/workflows/cicd-pytorch-dockers.yml | Build docker images used for testing in CI. If run on nightly schedule, push to the [Docker Hub](https://hub.docker.com/r/pytorchlightning/pytorch_lightning). |
+| .github/workflows/ci-pkg-install.yml | Test if pytorch-lightning is successfully installed using pip. |
+| .github/workflows/events-recurrent.yml | Terminate TPU jobs that live more than one hour to avoid possible resource exhaustion due to hangs. |
## Deployment
@@ -60,4 +60,4 @@
| .github/stale.yml | Close inactive issues/PRs sometimes after adding the "won't fix" label to them. |
| .github/workflows/probot-auto-cc.yml, .github/lightning-probot.yml | Notify maintainers of interest depending on labels added to an issue We utilize lightning-probot forked from PyTorch’s probot. |
| .pre-commit-config.yaml | pre-commit.ci runs a set of linters and formatters, such as black, flake8 and isort. When formatting is applied, the bot pushes a commit with its change. This configuration is also used for running pre-commit locally. |
-| .github/workflows/ci_pr-gatekeeper.yml | Prevent PRs from merging into master without any Grid.ai employees’ approval. |
+| .github/workflows/ci-pr-gatekeeper.yml | Prevent PRs from merging into master without any Grid.ai employees’ approval. |
diff --git a/.github/workflows/ci-app_cloud_e2e_test.yml b/.github/workflows/ci-app-cloud-e2e-test.yml
similarity index 99%
rename from .github/workflows/ci-app_cloud_e2e_test.yml
rename to .github/workflows/ci-app-cloud-e2e-test.yml
index 3ad455650a117..9a5a10a95cd33 100644
--- a/.github/workflows/ci-app_cloud_e2e_test.yml
+++ b/.github/workflows/ci-app-cloud-e2e-test.yml
@@ -54,7 +54,7 @@ jobs:
- custom_work_dependencies
- drive
- payload
- - commands
+ - commands_and_api
timeout-minutes: 35
steps:
- uses: actions/checkout@v2
diff --git a/.github/workflows/ci-app_examples.yml b/.github/workflows/ci-app-examples.yml
similarity index 98%
rename from .github/workflows/ci-app_examples.yml
rename to .github/workflows/ci-app-examples.yml
index ec8becd5f70d1..01570f59c2c77 100644
--- a/.github/workflows/ci-app_examples.yml
+++ b/.github/workflows/ci-app-examples.yml
@@ -17,7 +17,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: [ubuntu-20.04, macOS-11, windows-2019]
+ os: [ubuntu-20.04, macOS-11, windows-2022]
python-version: [3.8]
requires: ["oldest", "latest"]
diff --git a/.github/workflows/ci-app_tests.yml b/.github/workflows/ci-app-tests.yml
similarity index 96%
rename from .github/workflows/ci-app_tests.yml
rename to .github/workflows/ci-app-tests.yml
index 1678dab257301..fe3cc36dc16d3 100644
--- a/.github/workflows/ci-app_tests.yml
+++ b/.github/workflows/ci-app-tests.yml
@@ -21,7 +21,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: [ubuntu-20.04, macOS-11, windows-2019]
+ os: [ubuntu-20.04, macOS-11, windows-2022]
python-version: [3.8]
requires: ["oldest", "latest"]
@@ -126,7 +126,7 @@ jobs:
# - name: Clone Quick Start Example Repo
# uses: actions/checkout@v3
# # TODO: this needs to be git submodule
-# if: matrix.os == 'windows-2019' # because the install doesn't work on windows
+# if: matrix.os == 'windows-2022' # because the install doesn't work on windows
# with:
# repository: Lightning-AI/lightning-quick-start
# ref: 'main'
@@ -134,6 +134,6 @@ jobs:
#
# - name: Lightning Install quick-start
# shell: bash
-# if: matrix.os != 'windows-2019' # because the install doesn't work on windows
+# if: matrix.os != 'windows-2022' # because the install doesn't work on windows
# run: |
# python -m lightning install app lightning/quick-start -y
diff --git a/.github/workflows/ci_pkg-install.yml b/.github/workflows/ci-pkg-install.yml
similarity index 95%
rename from .github/workflows/ci_pkg-install.yml
rename to .github/workflows/ci-pkg-install.yml
index 342e027b07cfe..a9fdd36693a67 100644
--- a/.github/workflows/ci_pkg-install.yml
+++ b/.github/workflows/ci-pkg-install.yml
@@ -33,7 +33,7 @@ jobs:
fail-fast: true
max-parallel: 1
matrix:
- os: [ubuntu-20.04, macOS-11, windows-2019]
+ os: [ubuntu-20.04, macOS-11, windows-2022]
pkg: ["app", "pytorch"]
python-version: [3.8] # , 3.9
@@ -67,7 +67,7 @@ jobs:
fail-fast: false
# max-parallel: 1
matrix:
- os: [ubuntu-20.04, macOS-11, windows-2019]
+ os: [ubuntu-20.04, macOS-11, windows-2022]
pkg: ["", "lightning"]
python-version: [3.8] # , 3.9
@@ -100,7 +100,7 @@ jobs:
fail-fast: false
# max-parallel: 1
matrix:
- os: [ubuntu-20.04, macOS-11, windows-2019]
+ os: [ubuntu-20.04, macOS-11, windows-2022]
python-version: [3.8] # , 3.9
steps:
diff --git a/.github/workflows/ci_pr-gatekeeper.yml b/.github/workflows/ci-pr-gatekeeper.yml
similarity index 100%
rename from .github/workflows/ci_pr-gatekeeper.yml
rename to .github/workflows/ci-pr-gatekeeper.yml
diff --git a/.github/workflows/ci-pytorch_test-conda.yml b/.github/workflows/ci-pytorch-test-conda.yml
similarity index 98%
rename from .github/workflows/ci-pytorch_test-conda.yml
rename to .github/workflows/ci-pytorch-test-conda.yml
index 777ec2af759a0..2bbdb699c2c1e 100644
--- a/.github/workflows/ci-pytorch_test-conda.yml
+++ b/.github/workflows/ci-pytorch-test-conda.yml
@@ -22,13 +22,11 @@ jobs:
strategy:
fail-fast: false
matrix:
- # nightly: add when there's a release candidate
include:
- {python-version: "3.8", pytorch-version: "1.9"}
- {python-version: "3.8", pytorch-version: "1.10"}
- {python-version: "3.9", pytorch-version: "1.11"}
- {python-version: "3.9", pytorch-version: "1.12"}
-
timeout-minutes: 30
steps:
@@ -45,7 +43,7 @@ jobs:
id: skip
shell: bash -l {0}
run: |
- FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*'
+ FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.github/workflows/ci-pytorch-test-conda.yml'
echo "${{ steps.changed-files.outputs.all_changed_files }}" | tr " " "\n" > changed_files.txt
MATCHES=$(cat changed_files.txt | grep -E $FILTER)
echo $MATCHES
diff --git a/.github/workflows/ci-pytorch_test-full.yml b/.github/workflows/ci-pytorch-test-full.yml
similarity index 95%
rename from .github/workflows/ci-pytorch_test-full.yml
rename to .github/workflows/ci-pytorch-test-full.yml
index fb6916d1414fe..7409ce25a5128 100644
--- a/.github/workflows/ci-pytorch_test-full.yml
+++ b/.github/workflows/ci-pytorch-test-full.yml
@@ -20,10 +20,14 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: [ubuntu-20.04, windows-2019, macOS-11]
- python-version: ["3.7", "3.9"] # minimum, maximum
+ os: [ubuntu-20.04, windows-2022, macOS-11]
+ python-version: ["3.7", "3.10"] # minimum, maximum
requires: ["oldest", "latest"]
release: ["stable"]
+ exclude:
+ # There's no distribution of the oldest PyTorch 1.9 for Python 3.10.
+ # TODO: Remove the exclusion when dropping PyTorch 1.9 support.
+ - {python-version: "3.10", requires: "oldest"}
# TODO: re-enable RC testing
# include:
# - {os: ubuntu-20.04, python-version: "3.10", requires: "latest", release: "pre"}
@@ -41,7 +45,7 @@ jobs:
id: skip
shell: bash -l {0}
run: |
- FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*'
+ FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.github/workflows/ci-pytorch_test-full.yml'
echo "${{ steps.changed-files.outputs.all_changed_files }}" | tr " " "\n" > changed_files.txt
MATCHES=$(cat changed_files.txt | grep -E $FILTER)
echo $MATCHES
diff --git a/.github/workflows/ci-pytorch_test-slow.yml b/.github/workflows/ci-pytorch-test-slow.yml
similarity index 97%
rename from .github/workflows/ci-pytorch_test-slow.yml
rename to .github/workflows/ci-pytorch-test-slow.yml
index 905f60aa85699..36007d3311451 100644
--- a/.github/workflows/ci-pytorch_test-slow.yml
+++ b/.github/workflows/ci-pytorch-test-slow.yml
@@ -19,7 +19,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: [ubuntu-20.04, windows-2019, macOS-11]
+ os: [ubuntu-20.04, windows-2022, macOS-11]
# same config as '.azure-pipelines/gpu-tests.yml'
python-version: ["3.7"]
pytorch-version: ["1.11"]
@@ -36,7 +36,7 @@ jobs:
id: skip
shell: bash -l {0}
run: |
- FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*'
+ FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.github/workflows/ci-pytorch_test-slow.yml'
echo "${{ steps.changed-files.outputs.all_changed_files }}" | tr " " "\n" > changed_files.txt
MATCHES=$(cat changed_files.txt | grep -E $FILTER)
echo $MATCHES
diff --git a/.github/workflows/ci_schema.yml b/.github/workflows/ci-schema.yml
similarity index 100%
rename from .github/workflows/ci_schema.yml
rename to .github/workflows/ci-schema.yml
diff --git a/.github/workflows/cicd-pytorch_dockers.yml b/.github/workflows/cicd-pytorch-dockers.yml
similarity index 81%
rename from .github/workflows/cicd-pytorch_dockers.yml
rename to .github/workflows/cicd-pytorch-dockers.yml
index a6ba2ac4aa5f4..84051cafd82d8 100644
--- a/.github/workflows/cicd-pytorch_dockers.yml
+++ b/.github/workflows/cicd-pytorch-dockers.yml
@@ -29,17 +29,22 @@ jobs:
strategy:
fail-fast: false
matrix:
- # the config used in '.azure-pipelines/gpu-tests.yml' since the Dockerfile uses the cuda image
- python_version: ["3.9"]
- pytorch_version: ["1.12"]
+ include:
+ # We only release one docker image per PyTorch version.
+ # The matrix here is the same as the one in release-docker.yml.
+ - {python_version: "3.9", pytorch_version: "1.9", cuda_version: "11.1.1"}
+ - {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.3.1"}
+ - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"}
+ - {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.3.1"}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- uses: docker/setup-buildx-action@v2
- - uses: docker/build-push-action@v2
+ - uses: docker/build-push-action@v3
with:
build-args: |
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch_version }}
+ CUDA_VERSION=${{ matrix.cuda_version }}
file: dockers/release/Dockerfile
push: false # pushed in release-docker.yml only when PL is released
timeout-minutes: 50
@@ -53,14 +58,14 @@ jobs:
python_version: ["3.7"]
xla_version: ["1.12"]
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- uses: docker/setup-buildx-action@v2
- - uses: docker/login-action@v1
+ - uses: docker/login-action@v2
if: env.PUSH_TO_HUB == 'true'
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- - uses: docker/build-push-action@v2
+ - uses: docker/build-push-action@v3
with:
build-args: |
PYTHON_VERSION=${{ matrix.python_version }}
@@ -85,30 +90,31 @@ jobs:
fail-fast: false
matrix:
include:
- # the config used in '.azure-pipelines/gpu-tests.yml'
- - {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.3.1", ubuntu_version: "20.04"}
- # latest (used in Tutorials)
- - {python_version: "3.8", pytorch_version: "1.9", cuda_version: "11.1.1", ubuntu_version: "20.04"}
- - {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.1.1", ubuntu_version: "20.04"}
- - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1", ubuntu_version: "20.04"}
+ # These are the base images for PL release docker images,
+ # so include at least all of the combinations in release-dockers.yml.
+ - {python_version: "3.9", pytorch_version: "1.9", cuda_version: "11.1.1"}
+ - {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.3.1"}
+ - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"}
+ - {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.3.1"}
+ # Used in Lightning-AI/tutorials
+ - {python_version: "3.8", pytorch_version: "1.9", cuda_version: "11.1.1"}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- uses: docker/setup-buildx-action@v2
- - uses: docker/login-action@v1
+ - uses: docker/login-action@v2
if: env.PUSH_TO_HUB == 'true'
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- - uses: docker/build-push-action@v2
+ - uses: docker/build-push-action@v3
with:
build-args: |
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch_version }}
CUDA_VERSION=${{ matrix.cuda_version }}
- UBUNTU_VERSION=${{ matrix.ubuntu_version }}
file: dockers/base-cuda/Dockerfile
push: ${{ env.PUSH_TO_HUB }}
- tags: pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}
+ tags: pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }}
timeout-minutes: 95
- uses: ravsamhq/notify-slack-action@v1
if: failure() && env.PUSH_TO_HUB == 'true'
@@ -126,25 +132,23 @@ jobs:
fail-fast: false
matrix:
include:
- - {python_version: "3.8", pytorch_version: "1.9", cuda_version: "11.1.1"}
- - {python_version: "3.8", pytorch_version: "1.10", cuda_version: "11.1.1"}
- - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"}
- # nightly: add when there's a release candidate
- # - {python_version: "3.9", pytorch_version: "1.12"}
+ - {python_version: "3.8", pytorch_version: "1.9"}
+ - {python_version: "3.8", pytorch_version: "1.10"}
+ - {python_version: "3.9", pytorch_version: "1.11"}
+ - {python_version: "3.9", pytorch_version: "1.12"}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- uses: docker/setup-buildx-action@v2
- - uses: docker/login-action@v1
+ - uses: docker/login-action@v2
if: env.PUSH_TO_HUB == 'true'
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- - uses: docker/build-push-action@v2
+ - uses: docker/build-push-action@v3
with:
build-args: |
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch_version }}
- CUDA_VERSION=${{ matrix.cuda_version }}
file: dockers/base-conda/Dockerfile
push: ${{ env.PUSH_TO_HUB }}
tags: pytorchlightning/pytorch_lightning:base-conda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}
@@ -168,14 +172,14 @@ jobs:
# the config used in 'dockers/ci-runner-ipu/Dockerfile'
- {python_version: "3.9", pytorch_version: "1.9"}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- uses: docker/setup-buildx-action@v2
- - uses: docker/login-action@v1
+ - uses: docker/login-action@v2
if: env.PUSH_TO_HUB == 'true'
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- - uses: docker/build-push-action@v2
+ - uses: docker/build-push-action@v3
with:
build-args: |
PYTHON_VERSION=${{ matrix.python_version }}
@@ -184,7 +188,7 @@ jobs:
push: ${{ env.PUSH_TO_HUB }}
tags: pytorchlightning/pytorch_lightning:base-ipu-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}
timeout-minutes: 100
- - uses: docker/build-push-action@v2
+ - uses: docker/build-push-action@v3
with:
build-args: |
PYTHON_VERSION=${{ matrix.python_version }}
@@ -199,7 +203,7 @@ jobs:
status: ${{ job.status }}
token: ${{ secrets.GITHUB_TOKEN }}
notification_title: ${{ format('IPU; {0} py{1} for *{2}*', runner.os, matrix.python_version, matrix.pytorch_version) }}
- message_format: '{emoji} *{workflow}* {status_message}, see <{run_url}|detail>, cc: <@U01BULUS2BG>' # SeanNaren
+ message_format: '{emoji} *{workflow}* {status_message}, see <{run_url}|detail>, cc: <@U01GD29QCAV>' # kaushikb11
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
@@ -212,14 +216,14 @@ jobs:
# the config used in 'dockers/ci-runner-hpu/Dockerfile'
- {gaudi_version: "1.5.0", pytorch_version: "1.11.0"}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- uses: docker/setup-buildx-action@v2
- - uses: docker/login-action@v1
+ - uses: docker/login-action@v2
if: env.PUSH_TO_HUB == 'true'
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- - uses: docker/build-push-action@v2
+ - uses: docker/build-push-action@v3
with:
build-args: |
DIST=latest
@@ -243,10 +247,10 @@ jobs:
runs-on: ubuntu-20.04
steps:
- name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
- name: Build Conda Docker
# publish master/release
- uses: docker/build-push-action@v2
+ uses: docker/build-push-action@v3
with:
file: dockers/nvidia/Dockerfile
push: false
diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml
index ed9cd46adbe44..15bd5e9911740 100644
--- a/.github/workflows/code-checks.yml
+++ b/.github/workflows/code-checks.yml
@@ -32,8 +32,9 @@ jobs:
- name: Install dependencies
run: |
- pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+ pip install torch==1.12 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt
+ # todo: adjust requirements for both code-bases
pip install -r requirements/pytorch/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip list
diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml
index 9d87f1a582fb1..6901a24204683 100644
--- a/.github/workflows/release-docker.yml
+++ b/.github/workflows/release-docker.yml
@@ -1,6 +1,5 @@
name: Docker
-# https://www.docker.com/blog/first-docker-github-action-is-here
-# https://github.com/docker/build-push-action
+
on:
push:
branches: [master, "release/*"]
@@ -15,8 +14,12 @@ jobs:
strategy:
fail-fast: false
matrix:
- python_version: ["3.7", "3.8", "3.9"]
- pytorch_version: ["1.9", "1.10"]
+ include:
+ # We only release one docker image per PyTorch version.
+ - {python_version: "3.9", pytorch_version: "1.9", cuda_version: "11.1.1"}
+ - {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.3.1"}
+ - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"}
+ - {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.3.1"}
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -32,19 +35,29 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
dockerfile: dockers/release/Dockerfile
- build_args: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }}
- tags: "${{ steps.get_version.outputs.RELEASE_VERSION }}-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }},latest-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}"
+ build_args: |
+ PYTHON_VERSION=${{ matrix.python_version }}
+ PYTORCH_VERSION=${{ matrix.pytorch_version }}
+ CUDA_VERSION=${{ matrix.cuda_version }}
+ LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }}
+ tags: |
+ ${{ steps.get_version.outputs.RELEASE_VERSION }}-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }}
+ latest-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }}
timeout-minutes: 55
- name: Publish Latest to Docker
uses: docker/build-push-action@v1.1.0
- # only on releases and latest Python and PyTorch
- if: matrix.python_version == '3.9' && matrix.pytorch_version == '1.10'
+ # Only latest Python and PyTorch
+ if: matrix.python_version == '3.9' && matrix.pytorch_version == '1.12'
with:
repository: pytorchlightning/pytorch_lightning
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
dockerfile: dockers/release/Dockerfile
- build_args: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }}
+ build_args: |
+ PYTHON_VERSION=${{ matrix.python_version }}
+ PYTORCH_VERSION=${{ matrix.pytorch_version }}
+ CUDA_VERSION=${{ matrix.cuda_version }}
+ LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }}
tags: "latest"
timeout-minutes: 55
diff --git a/.gitignore b/.gitignore
index 719f291a492ca..259d9f271189c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,9 @@ hars*
artifacts/*
*docs/examples*
*docs/source-app/api*
+
+# tutorials
+our_model.tar
+test.png
+saved_models
+data/
diff --git a/README.md b/README.md
index 2fef343425f17..6f075f5fd42b6 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,9 @@
+### \*\* NEWS: PyTorch Lightning has been renamed Lightning! In addition to building models, you can now build research workflows and production pipelines\*\*
+
-**Build high-performance PyTorch models and deploy them with Lightning Apps (scalable end-to-end ML systems).**
+**Build high-performance (PyTorch) models, research workflows, ML production pipelines.**
______________________________________________________________________
@@ -80,21 +82,24 @@ ______________________________________________________________________
## Continuous Integration
-Lightning is rigorously tested across multiple GPUs, TPUs CPUs and against major Python and PyTorch versions.
+Lightning is rigorously tested across multiple CPUs, GPUs, TPUs, IPUs, and HPUs and against major Python and PyTorch versions.
Current build statuses
-| System / PyTorch ver. | 1.8 (LTS, min. req.) | 1.9 | 1.10 (latest) |
-| :------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| Linux py3.7 \[GPUs\*\*\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=6&branchName=master) | - | - |
-| Linux py3.7 \[TPUs\*\*\*\] | [![CircleCI](https://circleci.com/gh/Lightning-AI/lightning/tree/master.svg?style=svg)](https://circleci.com/gh/Lightning-AI/lightning/tree/master) | - | - |
-| Linux py3.8 (with Conda | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) |
-| Linux py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) |
-| OSX py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) |
-| Windows py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) |
+| System / PyTorch ver. | 1.9 | 1.10 | 1.12 (latest) |
+| :------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| Linux py3.7 \[GPUs\*\*\] | - | - | - |
+| Linux py3.7 \[TPUs\*\*\*\] | [![CircleCI](https://circleci.com/gh/Lightning-AI/lightning/tree/master.svg?style=svg)](https://circleci.com/gh/Lightning-AI/lightning/tree/master) | - | - |
+| Linux py3.8 \[IPUs\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=25&branchName=master) | - | - |
+| Linux py3.8 \[HPUs\] | - | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=26&branchName=master) | - |
+| Linux py3.8 (with Conda) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) | - |
+| Linux py3.9 (with Conda) | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) |
+| Linux py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) |
+| OSX py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) |
+| Windows py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) |
- _\*\* tests run on two NVIDIA P100_
- _\*\*\* tests run on Google GKE TPUv2/3. TPU py3.7 means we support Colab and Kaggle env._
@@ -136,8 +141,8 @@ conda install pytorch-lightning -c conda-forge
The actual status of 1.7 \[stable\] is the following:
-[![Test PyTorch full](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-full.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-full.yml?query=branch%3Arelease%2Fpytorch)
-[![Test PyTorch with Conda](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-conda.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-conda.yml?query=branch%3Arelease%2Fpytorch)
+[![Test PyTorch full](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml?query=branch%3Arelease%2Fpytorch)
+[![Test PyTorch with Conda](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml?query=branch%3Arelease%2Fpytorch)
[![TPU tests](https://dl.circleci.com/status-badge/img/gh/Lightning-AI/lightning/tree/release%2Fpytorch.svg?style=shield)](https://dl.circleci.com/status-badge/redirect/gh/Lightning-AI/lightning/tree/release%2Fpytorch)
[![Check Docs](https://github.com/Lightning-AI/lightning/actions/workflows/docs-checks.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/docs-checks.yml?query=branch%3Arelease%2Fpytorch)
diff --git a/SECURITY.md b/SECURITY.md
index 8f265f26be452..862563f84e2fe 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -1,2 +1,2 @@
-developer@grid.ai
+developer@lightning.ai
developer@pytorchlightning.ai
diff --git a/dockers/README.md b/dockers/README.md
index 533c85739f528..b1ff9826b6c1f 100644
--- a/dockers/README.md
+++ b/dockers/README.md
@@ -1,36 +1,17 @@
# Docker images
-## Builds images form attached Dockerfiles
+## Build images from Dockerfiles
You can build it on your own, note it takes lots of time, be prepared.
```bash
-git clone
-docker image build -t pytorch-lightning:latest -f dockers/conda/Dockerfile .
-```
-
-or with specific arguments
-
-```bash
-git clone
-docker image build \
- -t pytorch-lightning:base-cuda-py3.9-pt1.10 \
- -f dockers/base-cuda/Dockerfile \
- --build-arg PYTHON_VERSION=3.9 \
- --build-arg PYTORCH_VERSION=1.10 \
- .
-```
+git clone https://github.com/Lightning-AI/lightning.git
-or nightly version from Conda
+# build with the default arguments
+docker image build -t pytorch-lightning:latest -f dockers/base-cuda/Dockerfile .
-```bash
-git clone
-docker image build \
- -t pytorch-lightning:base-conda-py3.9-pt1.11 \
- -f dockers/base-conda/Dockerfile \
- --build-arg PYTHON_VERSION=3.9 \
- --build-arg PYTORCH_VERSION=1.11 \
- .
+# build with specific arguments
+docker image build -t pytorch-lightning:base-cuda-py3.9-torch1.11-cuda11.3.1 -f dockers/base-cuda/Dockerfile --build-arg PYTHON_VERSION=3.9 --build-arg PYTORCH_VERSION=1.11 --build-arg CUDA_VERSION=11.3.1 .
```
To run your docker use
@@ -49,7 +30,7 @@ docker image rm pytorch-lightning:latest
## Run docker image with GPUs
-To run docker image with access to you GPUs you need to install
+To run docker image with access to your GPUs, you need to install
```bash
# Add the package repositories
@@ -61,10 +42,10 @@ sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
sudo systemctl restart docker
```
-and later run the docker image with `--gpus all` so for example
+and later run the docker image with `--gpus all`. For example,
```
-docker run --rm -it --gpus all pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.10
+docker run --rm -it --gpus all pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.11-cuda11.3.1
```
## Run Jupyter server
@@ -73,15 +54,11 @@ Inspiration comes from https://u.group/thinking/how-to-put-jupyter-notebooks-in-
1. Build the docker image:
```bash
- docker image build \
- -t pytorch-lightning:v1.3.1 \
- -f dockers/nvidia/Dockerfile \
- --build-arg LIGHTNING_VERSION=1.3.1 \
- .
+ docker image build -t pytorch-lightning:v1.6.5 -f dockers/nvidia/Dockerfile --build-arg LIGHTNING_VERSION=1.6.5 .
```
1. start the server and map ports:
```bash
- docker run --rm -it --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all -p 8888:8888 pytorch-lightning:v1.3.1
+ docker run --rm -it --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all -p 8888:8888 pytorch-lightning:v1.6.5
```
1. Connect in local browser:
- copy the generated path e.g. `http://hostname:8888/?token=0719fa7e1729778b0cec363541a608d5003e26d4910983c6`
diff --git a/dockers/release/Dockerfile b/dockers/release/Dockerfile
index cb393c91dfbe0..c39e66509188c 100644
--- a/dockers/release/Dockerfile
+++ b/dockers/release/Dockerfile
@@ -14,8 +14,9 @@
ARG PYTHON_VERSION=3.9
ARG PYTORCH_VERSION=1.11
+ARG CUDA_VERSION=11.3.1
-FROM pytorchlightning/pytorch_lightning:base-cuda-py${PYTHON_VERSION}-torch${PYTORCH_VERSION}
+FROM pytorchlightning/pytorch_lightning:base-cuda-py${PYTHON_VERSION}-torch${PYTORCH_VERSION}-cuda${CUDA_VERSION}
LABEL maintainer="Lightning-AI "
diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst
index db4fc1e2c4cf8..8daed5ddcaf41 100644
--- a/docs/source-pytorch/api_references.rst
+++ b/docs/source-pytorch/api_references.rst
@@ -47,6 +47,20 @@ callbacks
Timer
TQDMProgressBar
+cli
+-----
+
+.. currentmodule:: pytorch_lightning.cli
+
+.. autosummary::
+ :toctree: api
+ :nosignatures:
+ :template: classtemplate.rst
+
+ LightningCLI
+ LightningArgumentParser
+ SaveConfigCallback
+
core
----
diff --git a/examples/app_commands/.lightning b/examples/app_commands_and_api/.lightning
similarity index 100%
rename from examples/app_commands/.lightning
rename to examples/app_commands_and_api/.lightning
diff --git a/examples/app_commands/app.py b/examples/app_commands_and_api/app.py
similarity index 56%
rename from examples/app_commands/app.py
rename to examples/app_commands_and_api/app.py
index 99eb15c75c709..0d15bc531bb38 100644
--- a/examples/app_commands/app.py
+++ b/examples/app_commands_and_api/app.py
@@ -1,15 +1,16 @@
from command import CustomCommand, CustomConfig
from lightning import LightningFlow
+from lightning_app.api import Post
from lightning_app.core.app import LightningApp
class ChildFlow(LightningFlow):
- def trigger_method(self, name: str):
+ def nested_command(self, name: str):
print(f"Hello {name}")
def configure_commands(self):
- return [{"nested_trigger_command": self.trigger_method}]
+ return [{"nested_command": self.nested_command}]
class FlowCommands(LightningFlow):
@@ -19,21 +20,24 @@ def __init__(self):
self.child_flow = ChildFlow()
def run(self):
- if len(self.names):
+ if self.names:
print(self.names)
- def trigger_without_client_command(self, name: str):
+ def command_without_client(self, name: str):
self.names.append(name)
- def trigger_with_client_command(self, config: CustomConfig):
+ def command_with_client(self, config: CustomConfig):
self.names.append(config.name)
def configure_commands(self):
commands = [
- {"trigger_without_client_command": self.trigger_without_client_command},
- {"trigger_with_client_command": CustomCommand(self.trigger_with_client_command)},
+ {"command_without_client": self.command_without_client},
+ {"command_with_client": CustomCommand(self.command_with_client)},
]
return commands + self.child_flow.configure_commands()
+ def configure_api(self):
+ return [Post("/user/command_without_client", self.command_without_client)]
+
app = LightningApp(FlowCommands())
diff --git a/examples/app_commands/command.py b/examples/app_commands_and_api/command.py
similarity index 100%
rename from examples/app_commands/command.py
rename to examples/app_commands_and_api/command.py
diff --git a/pyproject.toml b/pyproject.toml
index 5473e73c52e19..9f7cc28d0b002 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,19 +52,13 @@ module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.callbacks.quantization",
"pytorch_lightning.core.datamodule",
- "pytorch_lightning.core.decorators",
"pytorch_lightning.core.module",
- "pytorch_lightning.core.saving",
"pytorch_lightning.demos.boring_classes",
"pytorch_lightning.demos.mnist_datamodule",
"pytorch_lightning.profilers.base",
"pytorch_lightning.profilers.pytorch",
- "pytorch_lightning.profilers.simple",
- "pytorch_lightning.strategies.ddp",
"pytorch_lightning.strategies.sharded",
- "pytorch_lightning.strategies.sharded_spawn",
"pytorch_lightning.trainer.callback_hook",
- "pytorch_lightning.trainer.connectors.callback_connector",
"pytorch_lightning.trainer.connectors.data_connector",
"pytorch_lightning.trainer.supporters",
"pytorch_lightning.trainer.trainer",
diff --git a/requirements/app/base.txt b/requirements/app/base.txt
index 0a0b9cdb4719d..fcde2f18a300a 100644
--- a/requirements/app/base.txt
+++ b/requirements/app/base.txt
@@ -1,9 +1,8 @@
-py
-lightning-cloud==0.5.0
+lightning-cloud==0.5.3
packaging
-deepdiff>=5.7.0
+deepdiff>=5.7.0, <=5.8.1
starsessions
-fsspec>=2022.01.0
-s3fs>=2022.1.0
+fsspec>=2022.01.0, <=2022.7.1
+s3fs>=2022.1.0, <=2022.7.1
croniter # for now until we found something more robust.
traitlets<5.2.0 # Traitlets 5.2.X fails: https://github.com/ipython/traitlets/issues/741
diff --git a/requirements/app/cloud.txt b/requirements/app/cloud.txt
index 5f8bf0c48692f..6644a56a2894b 100644
--- a/requirements/app/cloud.txt
+++ b/requirements/app/cloud.txt
@@ -1,5 +1,4 @@
starsessions
redis>=4.0.0, <=4.2.4
-docker==5.0.3
-setuptools==59.5.0
-s3fs==2022.1.0
+docker>=5.0.0, <=5.0.3
+# setuptools==59.5.0
diff --git a/requirements/app/docs.txt b/requirements/app/docs.txt
index b35cc585b40c7..bf22aef2c2d92 100644
--- a/requirements/app/docs.txt
+++ b/requirements/app/docs.txt
@@ -1,18 +1,17 @@
sphinx>=4.0,<5.0
-myst-parser>=0.15
-nbsphinx>=0.8.5
+myst-parser>=0.15,<0.17
+nbsphinx>=0.8.5, <=0.8.9
ipython[notebook]
ipython_genutils
-pandoc>=1.0
-docutils>=0.16
-sphinxcontrib-fulltoc>=1.0
+pandoc>=1.0, <=2.2
+docutils>=0.16, <0.19
+sphinxcontrib-fulltoc>=1.0, <=1.2.0
sphinxcontrib-mockautodoc
https://storage.googleapis.com/grid-packages/lightning-ai-sphinx-theme/build-31.3.zip
sphinx-autodoc-typehints>=1.0,<1.15 # v1.15 failing on master (#11405)
-sphinx-paramlinks>=0.5.1
-sphinx-togglebutton>=0.2
-sphinx-copybutton>=0.3
+sphinx-paramlinks>=0.5.1, <=0.5.4
+sphinx-togglebutton>=0.2, <=0.3.2
+sphinx-copybutton>=0.3, <=0.5.0
sphinx-autobuild
-typing-extensions # already in `requirements.txt` but the docs CI job does not install it
jinja2>=3.0.0,<3.1.0
diff --git a/requirements/app/test.txt b/requirements/app/test.txt
index 9d2ed0af910ca..ab5ef8f1e85ac 100644
--- a/requirements/app/test.txt
+++ b/requirements/app/test.txt
@@ -1,15 +1,10 @@
-coverage>=5.0
-codecov>=2.1
-pytest>=5.0
-pytest-timeout
-pytest-cov
+coverage>=6.4, <=6.4.2
+codecov>=2.1, <=2.1.12
+pytest>=7.0, <=7.1.2
+pytest-timeout <=2.1.0
+pytest-cov <=3.0.0
playwright==1.22.0
# pytest-flake8
-flake8>=3.0
-check-manifest
-twine>=3.2
-isort>=5.0
-mypy>=0.720
httpx
trio
pympler
diff --git a/requirements/app/ui.txt b/requirements/app/ui.txt
index 28df7f9c2ffe0..f0e4b2cdef471 100644
--- a/requirements/app/ui.txt
+++ b/requirements/app/ui.txt
@@ -1 +1 @@
-streamlit>=1.3.1
+streamlit>=1.3.1, <=1.11.1
diff --git a/requirements/collect_env_details.py b/requirements/collect_env_details.py
index 1d65753a55553..b0c47efc43859 100644
--- a/requirements/collect_env_details.py
+++ b/requirements/collect_env_details.py
@@ -20,27 +20,17 @@
import platform
import sys
-import numpy
+import pkg_resources
import torch
-import tqdm
sys.path += [os.path.abspath(".."), os.path.abspath("")]
-import pytorch_lightning # noqa: E402
-try:
- import lightning
-except ModuleNotFoundError:
- pass
-try:
- import lightning_app
-except ModuleNotFoundError:
- pass
LEVEL_OFFSET = "\t"
KEY_PADDING = 20
-def info_system():
+def info_system() -> dict:
return {
"OS": platform.system(),
"architecture": platform.architecture(),
@@ -50,28 +40,24 @@ def info_system():
}
-def info_cuda():
+def info_cuda() -> dict:
return {
- "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())],
- # 'nvidia_driver': get_nvidia_driver_version(run_lambda),
+ "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] or None,
"available": torch.cuda.is_available(),
"version": torch.version.cuda,
}
-def info_packages():
- return {
- "numpy": numpy.__version__,
- "pyTorch_version": torch.__version__,
- "pyTorch_debug": torch.version.debug,
- "pytorch-lightning": pytorch_lightning.__version__,
- "lightning": lightning.__version__ if "lightning" in sys.modules else None,
- "lightning_app": lightning_app.__version__ if "lightning_app" in sys.modules else None,
- "tqdm": tqdm.__version__,
- }
+def info_packages() -> dict:
+ """Get name and version of all installed packages."""
+ packages = {}
+ for dist in pkg_resources.working_set:
+ package = dist.as_requirement()
+ packages[package.key] = package.specs[0][1]
+ return packages
-def nice_print(details, level=0):
+def nice_print(details: dict, level: int = 0) -> list:
lines = []
for k in sorted(details):
key = f"* {k}:" if level == 0 else f"- {k}:"
@@ -88,8 +74,9 @@ def nice_print(details, level=0):
return lines
-def main():
+def main() -> None:
details = {"System": info_system(), "CUDA": info_cuda(), "Packages": info_packages()}
+ details["Lightning"] = {k: v for k, v in details["Packages"].items() if "torch" in k or "lightning" in k}
lines = nice_print(details)
text = os.linesep.join(lines)
print(text)
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index e8743b18c73b0..49e2243319206 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -3,7 +3,7 @@
numpy>=1.17.2, <1.23.1
torch>=1.9.*, <=1.12.0
-tqdm>=4.57.0, <=4.63.0
+tqdm>=4.57.0, <4.65.0
PyYAML>=5.4, <=6.0
fsspec[http]>=2021.05.0, !=2021.06.0, <2022.6.0
tensorboard>=2.9.1, <2.10.0
diff --git a/requirements/pytorch/docs.txt b/requirements/pytorch/docs.txt
index e6fbbe322b6bf..50e7c2049f6f6 100644
--- a/requirements/pytorch/docs.txt
+++ b/requirements/pytorch/docs.txt
@@ -1,16 +1,16 @@
sphinx>=4.0,<5.0
myst-parser>=0.15,<0.17
-nbsphinx>=0.8.5
+nbsphinx>=0.8.5, <=0.8.9
ipython[notebook]
-pandoc>=1.0
-docutils>=0.16
-sphinxcontrib-fulltoc>=1.0
+pandoc>=1.0, <=2.2
+docutils>=0.16, <0.19
+sphinxcontrib-fulltoc>=1.0, <=1.2.0
sphinxcontrib-mockautodoc
pt-lightning-sphinx-theme @ https://github.com/Lightning-AI/lightning_sphinx_theme/archive/master.zip
-sphinx-autodoc-typehints>=1.11,<1.15 # v1.15 failing on master (#11405)
-sphinx-paramlinks>=0.5.1
-sphinx-togglebutton>=0.2
-sphinx-copybutton>=0.3
+sphinx-autodoc-typehints>=1.11,<1.15 # strict; v1.15 failing on master (#11405)
+sphinx-paramlinks>=0.5.1, <=0.5.4
+sphinx-togglebutton>=0.2, <=0.3.2
+sphinx-copybutton>=0.3, <=0.5.0
typing-extensions # already in `requirements.txt` but the docs CI job does not install it
jinja2>=3.0.0,<3.1.0
diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt
index c386c5581cc42..20b6c1b8dbc12 100644
--- a/requirements/pytorch/extra.txt
+++ b/requirements/pytorch/extra.txt
@@ -7,5 +7,5 @@ torchtext>=0.10.*, <0.14.0
omegaconf>=2.0.5, <2.3.0
hydra-core>=1.0.5, <1.3.0
jsonargparse[signatures]>=4.12.0, <=4.12.0
-gcsfs>=2021.5.0, <2022.6.0
+gcsfs>=2021.5.0, <2022.8.0
rich>=10.14.0, !=10.15.0.a, <13.0.0
diff --git a/requirements/pytorch/loggers.txt b/requirements/pytorch/loggers.txt
index 48a15c30f842f..df83a077f8457 100644
--- a/requirements/pytorch/loggers.txt
+++ b/requirements/pytorch/loggers.txt
@@ -7,4 +7,4 @@ neptune-client>=0.10.0, <0.16.4
comet-ml>=3.1.12, <3.31.8
mlflow>=1.0.0, <1.28.0
test_tube>=0.7.5, <=0.7.5
-wandb>=0.10.22, <0.12.20
+wandb>=0.10.22, <0.13.2
diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt
index 4e916fbc6c61f..c5fc92a67a837 100644
--- a/requirements/pytorch/strategies.txt
+++ b/requirements/pytorch/strategies.txt
@@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
fairscale>=0.4.5, <=0.4.6
-deepspeed>=0.6.0, <0.7.0
+deepspeed>=0.6.0, <=0.7.0
# no need to install with [pytorch] as pytorch is already installed
horovod>=0.21.2, !=0.24.0, <0.25.1
hivemind>=1.0.1, <=1.0.1; sys_platform == 'linux'
diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt
index ce54cd087b1de..f8bd5793a0af6 100644
--- a/requirements/pytorch/test.txt
+++ b/requirements/pytorch/test.txt
@@ -1,18 +1,17 @@
-coverage>=6.4
-codecov>=2.1
-pytest>=7.0
-pytest-cov
-pytest-forked
+coverage>=6.4, <=6.4.2
+codecov>=2.1, <=2.1.12
+pytest>=7.0, <=7.1.2
+pytest-cov <=3.0.0
+pytest-forked <=1.4.0
pytest-rerunfailures>=10.2
-mypy>=0.920
-flake8>=3.9.2
pre-commit>=1.0
+mypy==0.971
# needed in tests
-cloudpickle>=1.3
-scikit-learn>0.22.1
-onnxruntime
-psutil # for `DeviceStatsMonitor`
-pandas # needed in benchmarks
-fastapi
-uvicorn
+cloudpickle>=1.3, <=2.1.0
+scikit-learn>0.22.1, <=1.1.1
+onnxruntime<1.13.0
+psutil<=5.9.1 # for `DeviceStatsMonitor`
+pandas>1.0, <=1.4.3 # needed in benchmarks
+fastapi<=0.79.0
+uvicorn<=0.18.2
diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index 07927a1b01f87..2aa5c7cdd837c 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -9,21 +9,87 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602))
+
+
- Add support for Lightning AI BYOC cluster management ([#13835](https://github.com/Lightning-AI/lightning/pull/13835))
+
+
- Add support to run Lightning apps on Lightning AI BYOC clusters ([#13894](https://github.com/Lightning-AI/lightning/pull/13894))
+
+
- Add support for listing Lightning AI apps ([#13987](https://github.com/Lightning-AI/lightning/pull/13987))
+
+
- Adds `LightningTrainingComponent`. `LightningTrainingComponent` orchestrates multi-node training in the cloud ([#13830](https://github.com/Lightning-AI/lightning/pull/13830))
+- Add support for printing application logs using CLI `lightning show logs [components]` ([#13634](https://github.com/Lightning-AI/lightning/pull/13634))
+
+
+- Add support for `Lightning API` through the `configure_api` hook on the Lightning Flow and the `Post`, `Get`, `Delete`, `Put` HttpMethods ([#13945](https://github.com/Lightning-AI/lightning/pull/13945))
### Changed
-- Update the Lightning App docs ([#13537](https://github.com/Lightning-AI/lightning/pull/13537))
+- Default values and parameter names for Lightning AI BYOC cluster management ([#14132](https://github.com/Lightning-AI/lightning/pull/14132))
+
### Changed
-- Added `LIGHTNING_` prefix to Platform AWS credentials ([#13703](https://github.com/Lightning-AI/lightning/pull/13703))
+-
+
+
+- Run the flow only if the state has changed from the previous execution ([#14076](https://github.com/Lightning-AI/lightning/pull/14076))
+
+### Deprecated
+
+-
+
+
+### Fixed
+
+-
+
+
+## [0.5.5] - 2022-08-9
### Deprecated
+- Deprecate sheety API ([#14004](https://github.com/Lightning-AI/lightning/pull/14004))
+
### Fixed
- Resolved a bug where the work statuses will grow quickly and be duplicated ([#13970](https://github.com/Lightning-AI/lightning/pull/13970))
+- Resolved a bug about a race condition when sending the work state through the caller_queue ([#14074](https://github.com/Lightning-AI/lightning/pull/14074))
+- Fixed Start Lightning App on Cloud if Repo Begins With Name "Lightning" ([#14025](https://github.com/Lightning-AI/lightning/pull/14025))
+
+
+## [0.5.4] - 2022-08-01
+
+### Changed
+
+- Wrapped imports for traceability ([#13924](https://github.com/Lightning-AI/lightning/pull/13924))
+- Set version as today ([#13906](https://github.com/Lightning-AI/lightning/pull/13906))
+
+### Fixed
+
+- Included app templates to the lightning and app packages ([#13731](https://github.com/Lightning-AI/lightning/pull/13731))
+- Added UI for install all ([#13732](https://github.com/Lightning-AI/lightning/pull/13732))
+- Fixed build meta pkg flow ([#13926](https://github.com/Lightning-AI/lightning/pull/13926))
+
+## [0.5.3] - 2022-07-25
+
+### Changed
+
+- Pruned requirements duplicity ([#13739](https://github.com/Lightning-AI/lightning/pull/13739))
+
+### Fixed
+
+- Use correct python version in lightning component template ([#13790](https://github.com/Lightning-AI/lightning/pull/13790))
+
+## [0.5.2] - 2022-07-18
+
+### Added
+
+- Update the Lightning App docs ([#13537](https://github.com/Lightning-AI/lightning/pull/13537))
+
+### Changed
+
+- Added `LIGHTNING_` prefix to Platform AWS credentials ([#13703](https://github.com/Lightning-AI/lightning/pull/13703))
diff --git a/src/lightning_app/api/__init__.py b/src/lightning_app/api/__init__.py
new file mode 100644
index 0000000000000..25ec5c4708761
--- /dev/null
+++ b/src/lightning_app/api/__init__.py
@@ -0,0 +1,3 @@
+from lightning_app.api.http_methods import Delete, Get, Post, Put
+
+__all__ = ["Delete", "Get", "Post", "Put"]
diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py
new file mode 100644
index 0000000000000..02b6ec87f17d2
--- /dev/null
+++ b/src/lightning_app/api/http_methods.py
@@ -0,0 +1,107 @@
+import asyncio
+import inspect
+import time
+from copy import deepcopy
+from functools import wraps
+from multiprocessing import Queue
+from typing import Any, Callable, Dict, List, Optional
+from uuid import uuid4
+
+from fastapi import FastAPI
+
+from lightning_app.api.request_types import APIRequest, CommandRequest
+
+
+def _signature_proxy_function():
+ pass
+
+
+class HttpMethod:
+ def __init__(self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs):
+ """This class is used to inject user defined methods within the App Rest API.
+
+ Arguments:
+ route: The path used to route the requests
+ method: The associated flow method
+ timeout: The time in seconds taken before raising a timeout exception.
+ """
+ self.route = route
+ self.component_name = method.__self__.name
+ self.method_name = method_name or method.__name__
+ self.method_annotations = method.__annotations__
+ # TODO: Validate the signature contains only pydantic models.
+ self.method_signature = inspect.signature(method)
+ self.timeout = timeout
+ self.kwargs = kwargs
+
+ def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None:
+ # 1: Create a proxy function with the signature of the wrapped method.
+ fn = deepcopy(_signature_proxy_function)
+ fn.__annotations__ = self.method_annotations
+ fn.__name__ = self.method_name
+ setattr(fn, "__signature__", self.method_signature)
+
+ # 2: Get the route associated with the http method.
+ route = getattr(app, self.__class__.__name__.lower())
+
+ request_cls = CommandRequest if self.route.startswith("/command/") else APIRequest
+
+ # 3: Define the request handler.
+ @wraps(_signature_proxy_function)
+ async def _handle_request(*args, **kwargs):
+ async def fn(*args, **kwargs):
+ request_id = str(uuid4()).split("-")[0]
+ request_queue.put(
+ request_cls(
+ name=self.component_name,
+ method_name=self.method_name,
+ args=args,
+ kwargs=kwargs,
+ id=request_id,
+ )
+ )
+
+ t0 = time.time()
+ while request_id not in responses_store:
+ await asyncio.sleep(0.1)
+ if (time.time() - t0) > self.timeout:
+ raise Exception("The response was never received.")
+
+ return responses_store.pop(request_id)
+
+ return await asyncio.create_task(fn(*args, **kwargs))
+
+ # 4: Register the user provided route to the Rest API.
+ route(self.route, **self.kwargs)(_handle_request)
+
+
+class Post(HttpMethod):
+ pass
+
+
+class Get(HttpMethod):
+
+ pass
+
+
+class Put(HttpMethod):
+
+ pass
+
+
+class Delete(HttpMethod):
+ pass
+
+
+def _add_tags_to_api(apis: List[HttpMethod], tags: List[str]) -> None:
+ for api in apis:
+ if not api.kwargs.get("tag"):
+ api.kwargs["tags"] = tags
+
+
+def _validate_api(apis: List[HttpMethod]) -> None:
+ for api in apis:
+ if not isinstance(api, HttpMethod):
+ raise Exception(f"The provided api should be either [{Delete}, {Get}, {Post}, {Put}]")
+ if api.route.startswith("/command"):
+ raise Exception("The route `/command` is reserved for commands. Please, use something else.")
diff --git a/src/lightning_app/api/request_types.py b/src/lightning_app/api/request_types.py
new file mode 100644
index 0000000000000..53a6df25820a3
--- /dev/null
+++ b/src/lightning_app/api/request_types.py
@@ -0,0 +1,36 @@
+from dataclasses import asdict, dataclass
+from typing import Any
+
+from deepdiff import Delta
+
+
+@dataclass
+class BaseRequest:
+ def to_dict(self):
+ return asdict(self)
+
+
+@dataclass
+class DeltaRequest(BaseRequest):
+ delta: Delta
+
+ def to_dict(self):
+ return self.delta.to_dict()
+
+
+@dataclass
+class CommandRequest(BaseRequest):
+ id: str
+ name: str
+ method_name: str
+ args: Any
+ kwargs: Any
+
+
+@dataclass
+class APIRequest(BaseRequest):
+ id: str
+ name: str
+ method_name: str
+ args: Any
+ kwargs: Any
diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py
index fb4c40330dfd9..81d2a773b4619 100644
--- a/src/lightning_app/cli/lightning_cli.py
+++ b/src/lightning_app/cli/lightning_cli.py
@@ -4,11 +4,12 @@
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Tuple, Union
-from uuid import uuid4
import click
import requests
+import rich
from requests.exceptions import ConnectionError
+from rich.color import ANSI_COLOR_NAMES
from lightning_app import __version__ as ver
from lightning_app.cli import cmd_init, cmd_install, cmd_pl_init, cmd_react_ui_init
@@ -18,13 +19,16 @@
from lightning_app.core.constants import get_lightning_cloud_url, LOCAL_LAUNCH_ADMIN_VIEW
from lightning_app.runners.runtime import dispatch
from lightning_app.runners.runtime_type import RuntimeType
+from lightning_app.utilities.app_logs import _app_logs_reader
from lightning_app.utilities.cli_helpers import (
_format_input_env_variables,
_retrieve_application_url_and_available_commands,
)
+from lightning_app.utilities.cloud import _get_project
+from lightning_app.utilities.enum import OpenAPITags
from lightning_app.utilities.install_components import register_all_external_components
from lightning_app.utilities.login import Auth
-from lightning_app.utilities.state import headers_for
+from lightning_app.utilities.network import LightningClient
logger = logging.getLogger(__name__)
@@ -50,12 +54,96 @@ def main():
@click.version_option(ver)
def _main():
register_all_external_components()
+
+
+@_main.group()
+def show():
+ """Show given resource."""
pass
+@show.command()
+@click.argument("app_name", required=False)
+@click.argument("components", nargs=-1, required=False)
+@click.option("-f", "--follow", required=False, is_flag=True, help="Wait for new logs, to exit use CTRL+C.")
+def logs(app_name: str, components: List[str], follow: bool) -> None:
+ """Show cloud application logs. By default prints logs for all currently available components.
+
+ Example uses:
+
+ Print all application logs:
+
+ $ lightning show logs my-application
+
+
+ Print logs only from the flow (no work):
+
+ $ lightning show logs my-application flow
+
+
+ Print logs only from selected works:
+
+ $ lightning show logs my-application root.work_a root.work_b
+ """
+
+ client = LightningClient()
+ project = _get_project(client)
+
+ apps = {
+ app.name: app
+ for app in client.lightningapp_instance_service_list_lightningapp_instances(project.project_id).lightningapps
+ }
+
+ if not apps:
+ raise click.ClickException(
+ "You don't have any application in the cloud. Please, run an application first with `--cloud`."
+ )
+
+ if not app_name:
+ raise click.ClickException(
+ f"You have not specified any Lightning App. Please select one of available: [{', '.join(apps.keys())}]"
+ )
+
+ if app_name not in apps:
+ raise click.ClickException(
+ f"The Lightning App '{app_name}' does not exist. Please select one of following: [{', '.join(apps.keys())}]"
+ )
+
+ # Fetch all lightning works from given application
+ # 'Flow' component is somewhat implicit, only one for whole app,
+ # and not listed in lightningwork API - so we add it directly to the list
+ works = client.lightningwork_service_list_lightningwork(
+ project_id=project.project_id, app_id=apps[app_name].id
+ ).lightningworks
+ app_component_names = ["flow"] + [f.name for f in apps[app_name].spec.flow_servers] + [w.name for w in works]
+
+ if not components:
+ components = app_component_names
+
+ for component in components:
+ if component not in app_component_names:
+ raise click.ClickException(f"Component '{component}' does not exist in app {app_name}.")
+
+ log_reader = _app_logs_reader(
+ client=client,
+ project_id=project.project_id,
+ app_id=apps[app_name].id,
+ component_names=components,
+ follow=follow,
+ )
+
+ rich_colors = list(ANSI_COLOR_NAMES)
+ colors = {c: rich_colors[i + 1] for i, c in enumerate(components)}
+
+ for log_event in log_reader:
+ date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S")
+ color = colors[log_event.component_name]
+ rich.print(f"[{color}]{log_event.component_name}[/{color}] {date} {log_event.message}")
+
+
@_main.command()
def login():
- """Log in to your Lightning.ai account."""
+ """Log in to your lightning.ai account."""
auth = Auth()
auth.clear()
@@ -68,7 +156,7 @@ def login():
@_main.command()
def logout():
- """Log out of your Lightning.ai account."""
+ """Log out of your lightning.ai account."""
Auth().clear()
@@ -127,7 +215,7 @@ def on_before_run(*args):
@_main.group()
def run():
- """Run your application."""
+ """Run a Lightning application locally or on the cloud."""
@run.command("app")
@@ -174,41 +262,42 @@ def app_command():
hparams, argv = parser.parse_known_args()
# 1: Collect the url and comments from the running application
- url, commands = _retrieve_application_url_and_available_commands(hparams.app_id)
- if url is None or commands is None:
+ url, api_commands = _retrieve_application_url_and_available_commands(hparams.app_id)
+ if url is None or api_commands is None:
raise Exception("We couldn't find any matching running app.")
- if not commands:
+ if not api_commands:
raise Exception("This application doesn't expose any commands yet.")
command = argv[0]
- command_names = [c["command"] for c in commands]
- if command not in command_names:
- raise Exception(f"The provided command {command} isn't available in {command_names}")
+ if command not in api_commands:
+ raise Exception(f"The provided command {command} isn't available in {list(api_commands)}")
# 2: Send the command from the user
- command_metadata = [c for c in commands if c["command"] == command][0]
- params = command_metadata["params"]
+ metadata = api_commands[command]
# 3: Execute the command
- if not command_metadata["is_client_command"]:
- # TODO: Improve what is supported there.
- kwargs = {k.split("=")[0].replace("--", ""): k.split("=")[1] for k in argv[1:]}
- for param in params:
- if param not in kwargs:
- raise Exception(f"The argument --{param}=X hasn't been provided.")
- json = {
- "command_name": command,
- "command_arguments": kwargs,
- "affiliation": command_metadata["affiliation"],
- "id": str(uuid4()),
- }
- resp = requests.post(url + "/api/v1/commands", json=json, headers=headers_for({}))
+ if metadata["tag"] == OpenAPITags.APP_COMMAND:
+ # TODO: Improve what is current supported
+ kwargs = [v.replace("--", "") for v in argv[1:]]
+
+ for p in kwargs:
+ if p.split("=")[0] not in metadata["parameters"]:
+ raise Exception(f"Some arguments need to be provided. The keys are {list(metadata['parameters'])}.")
+ # TODO: Encode the parameters and validate their type.
+ query_parameters = "&".join(kwargs)
+ resp = requests.post(url + f"/command/{command}?{query_parameters}")
assert resp.status_code == 200, resp.json()
else:
- client_command, models = _download_command(command_metadata, hparams.app_id, debug_mode=debug_mode)
- client_command._setup(metadata=command_metadata, models=models, app_url=url)
+ client_command = _download_command(
+ command,
+ metadata["cls_path"],
+ metadata["cls_name"],
+ hparams.app_id,
+ debug_mode=debug_mode,
+ )
+ client_command._setup(command_name=command, app_url=url)
sys.argv = argv
client_command.run()
@@ -232,7 +321,7 @@ def stop():
@_main.group()
def install():
- """Install Lightning apps and components."""
+ """Install a Lightning App and/or component."""
@install.command("app")
@@ -290,7 +379,7 @@ def install_component(name, yes, version):
@_main.group()
def init():
- """Init a Lightning app and component."""
+ """Init a Lightning App and/or component."""
@init.command("app")
diff --git a/src/lightning_app/cli/lightning_cli_create.py b/src/lightning_app/cli/lightning_cli_create.py
index 7e45fe7e7c078..7e9a6b9d2143b 100644
--- a/src/lightning_app/cli/lightning_cli_create.py
+++ b/src/lightning_app/cli/lightning_cli_create.py
@@ -5,7 +5,7 @@
@click.group("create")
def create():
- """Create Lightning AI BYOC managed resources."""
+ """Create Lightning AI self-managed resources (clusters, etc…)"""
pass
@@ -33,14 +33,14 @@ def create():
help="Instance types that you want to support, for computer jobs within the cluster.",
)
@click.option(
- "--cost-savings",
- "cost_savings",
+ "--enable-performance",
+ "enable_performance",
type=bool,
required=False,
default=False,
is_flag=True,
- help=""""Use this flag to ensure that the cluster is created with a profile that is optimized for cost savings.
- This makes runs cheaper but start-up times may increase.""",
+ help=""""Use this flag to ensure that the cluster is created with a profile that is optimized for performance.
+ This makes runs more expensive but start-up times decrease.""",
)
@click.option(
"--edit-before-creation",
@@ -65,12 +65,12 @@ def create_cluster(
provider: str,
instance_types: str,
edit_before_creation: bool,
- cost_savings: bool,
+ enable_performance: bool,
wait: bool,
**kwargs,
):
"""Create a Lightning AI BYOC compute cluster with your cloud provider credentials."""
- if provider != "aws":
+ if provider.lower() != "aws":
click.echo("Only AWS is supported for now. But support for more providers is coming soon.")
return
cluster_manager = AWSClusterManager()
@@ -79,8 +79,8 @@ def create_cluster(
region=region,
role_arn=role_arn,
external_id=external_id,
- instance_types=instance_types.split(","),
+ instance_types=instance_types.split(",") if instance_types is not None else None,
edit_before_creation=edit_before_creation,
- cost_savings=cost_savings,
+ cost_savings=not enable_performance,
wait=wait,
)
diff --git a/src/lightning_app/cli/lightning_cli_delete.py b/src/lightning_app/cli/lightning_cli_delete.py
index c304b130bdf5d..366f4aa01e995 100644
--- a/src/lightning_app/cli/lightning_cli_delete.py
+++ b/src/lightning_app/cli/lightning_cli_delete.py
@@ -5,7 +5,7 @@
@click.group("delete")
def delete():
- """Delete Lightning AI BYOC managed resources."""
+ """Delete Lightning AI self-managed resources (clusters, etc…)"""
pass
diff --git a/src/lightning_app/cli/lightning_cli_list.py b/src/lightning_app/cli/lightning_cli_list.py
index d0d1d34a6dd4d..7d38b5b57760f 100644
--- a/src/lightning_app/cli/lightning_cli_list.py
+++ b/src/lightning_app/cli/lightning_cli_list.py
@@ -6,7 +6,7 @@
@click.group(name="list")
def get_list():
- """List your Lightning AI BYOC managed resources."""
+ """List Lightning AI self-managed resources (clusters, etc…)"""
pass
diff --git a/src/lightning_app/core/api.py b/src/lightning_app/core/api.py
index f19ada5340d57..8b625713e0c2c 100644
--- a/src/lightning_app/core/api.py
+++ b/src/lightning_app/core/api.py
@@ -3,7 +3,6 @@
import os
import queue
import sys
-import time
import traceback
from copy import deepcopy
from multiprocessing import Queue
@@ -21,9 +20,12 @@
from pydantic import BaseModel
from websockets.exceptions import ConnectionClosed
+from lightning_app.api.http_methods import HttpMethod
+from lightning_app.api.request_types import DeltaRequest
from lightning_app.core.constants import FRONTEND_DIR
from lightning_app.core.queues import RedisQueue
from lightning_app.utilities.app_helpers import InMemoryStateStore, StateStore
+from lightning_app.utilities.enum import OpenAPITags
from lightning_app.utilities.imports import _is_redis_available, _is_starsessions_available
if _is_starsessions_available():
@@ -42,9 +44,6 @@ class SessionMiddleware:
frontend_static_dir = os.path.join(FRONTEND_DIR, "static")
api_app_delta_queue: Queue = None
-api_commands_requests_queue: Queue = None
-api_commands_metadata_queue: Queue = None
-api_commands_responses_queue: Queue = None
template = {"ui": {}, "app": {}}
templates = Jinja2Templates(directory=FRONTEND_DIR)
@@ -56,8 +55,8 @@ class SessionMiddleware:
lock = Lock()
app_spec: Optional[List] = None
-app_commands_metadata: Optional[Dict] = None
-commands_response_store = {}
+# In the future, this would be abstracted to support horizontal scaling.
+responses_store = {}
logger = logging.getLogger(__name__)
@@ -67,11 +66,10 @@ class SessionMiddleware:
class UIRefresher(Thread):
- def __init__(self, api_publish_state_queue, api_commands_metadata_queue, api_commands_responses_queue) -> None:
+ def __init__(self, api_publish_state_queue, api_response_queue) -> None:
super().__init__(daemon=True)
self.api_publish_state_queue = api_publish_state_queue
- self.api_commands_metadata_queue = api_commands_metadata_queue
- self.api_commands_responses_queue = api_commands_responses_queue
+ self.api_response_queue = api_response_queue
self._exit_event = Event()
def run(self):
@@ -93,18 +91,11 @@ def run_once(self):
pass
try:
- metadata = self.api_commands_metadata_queue.get(timeout=0)
+ response = self.api_response_queue.get(timeout=0)
with lock:
- global app_commands_metadata
- app_commands_metadata = metadata
- except queue.Empty:
- pass
-
- try:
- response = self.api_commands_responses_queue.get(timeout=0)
- with lock:
- global commands_response_store
- commands_response_store[response["id"]] = response["response"]
+ # TODO: Abstract the responses store to support horizontal scaling.
+ global responses_store
+ responses_store[response["id"]] = response["response"]
except queue.Empty:
pass
@@ -117,6 +108,23 @@ class StateUpdate(BaseModel):
state: dict = {}
+openapi_tags = [
+ {
+ "name": OpenAPITags.APP_CLIENT_COMMAND,
+ "description": "The App Endpoints to be triggered exclusively from the CLI",
+ },
+ {
+ "name": OpenAPITags.APP_COMMAND,
+ "description": "The App Endpoints that can be triggered equally from the CLI or from a Http Request",
+ },
+ {
+ "name": OpenAPITags.APP_API,
+ "description": "The App Endpoints that can be triggered exclusively from a Http Request",
+ },
+]
+
+app = FastAPI(openapi_tags=openapi_tags)
+
fastapi_service = FastAPI()
fastapi_service.add_middleware(
@@ -176,50 +184,13 @@ async def get_spec(
return app_spec or []
-@fastapi_service.post("/api/v1/commands", response_class=JSONResponse)
-async def run_remote_command(
- request: Request,
-) -> None:
- data = await request.json()
- command_name = data.get("command_name", None)
- if not command_name:
- raise Exception("The provided command name is empty.")
- command_arguments = data.get("command_arguments", None)
- if not command_arguments:
- raise Exception("The provided command metadata is empty.")
- affiliation = data.get("affiliation", None)
- if not affiliation:
- raise Exception("The provided affiliation is empty.")
-
- async def fn(data):
- request_id = data["id"]
- api_commands_requests_queue.put(data)
-
- t0 = time.time()
- while request_id not in commands_response_store:
- await asyncio.sleep(0.1)
- if (time.time() - t0) > 15:
- raise Exception("The response was never received.")
-
- return commands_response_store[request_id]
-
- return await asyncio.create_task(fn(data))
-
-
-@fastapi_service.get("/api/v1/commands", response_class=JSONResponse)
-async def get_commands() -> Optional[Dict]:
- global app_commands_metadata
- with lock:
- return app_commands_metadata
-
-
@fastapi_service.post("/api/v1/delta")
async def post_delta(
request: Request,
x_lightning_type: Optional[str] = Header(None),
x_lightning_session_uuid: Optional[str] = Header(None),
x_lightning_session_id: Optional[str] = Header(None),
-) -> Mapping:
+) -> None:
"""This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to
update the state."""
@@ -229,9 +200,7 @@ async def post_delta(
raise Exception("Missing X-Lightning-Session-ID header")
body: Dict = await request.json()
- delta = body["delta"]
- update_delta = Delta(delta)
- api_app_delta_queue.put(update_delta)
+ api_app_delta_queue.put(DeltaRequest(delta=Delta(body["delta"])))
@fastapi_service.post("/api/v1/state")
@@ -240,7 +209,7 @@ async def post_state(
x_lightning_type: Optional[str] = Header(None),
x_lightning_session_uuid: Optional[str] = Header(None),
x_lightning_session_id: Optional[str] = Header(None),
-) -> Mapping:
+) -> None:
if x_lightning_session_uuid is None:
raise Exception("Missing X-Lightning-Session-UUID header")
if x_lightning_session_id is None:
@@ -263,8 +232,7 @@ async def post_state(
state = body["state"]
last_state = global_app_state_store.get_served_state(x_lightning_session_uuid)
deep_diff = DeepDiff(last_state, state, verbose_level=2)
- update_delta = Delta(deep_diff)
- api_app_delta_queue.put(update_delta)
+ api_app_delta_queue.put(DeltaRequest(delta=Delta(deep_diff)))
@fastapi_service.get("/healthz", status_code=200)
@@ -307,8 +275,6 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.close()
-# Catch-all for nonexistent API routes (since we define a catch-all for client-side routing)
-@fastapi_service.get("/api{full_path:path}", response_class=JSONResponse)
async def api_catch_all(request: Request, full_path: str):
raise HTTPException(status_code=404, detail="Not found")
@@ -317,14 +283,18 @@ async def api_catch_all(request: Request, full_path: str):
fastapi_service.mount("/static", StaticFiles(directory=frontend_static_dir, check_dir=False), name="static")
-# Catch-all for frontend routes, must be defined after all other routes
-@fastapi_service.get("/{full_path:path}", response_class=HTMLResponse)
async def frontend_route(request: Request, full_path: str):
if "pytest" in sys.modules:
return ""
return templates.TemplateResponse("index.html", {"request": request})
+def register_global_routes():
+ # Catch-all for nonexistent API routes (since we define a catch-all for client-side routing)
+ fastapi_service.get("/api{full_path:path}", response_class=JSONResponse)(api_catch_all)
+ fastapi_service.get("/{full_path:path}", response_class=HTMLResponse)(frontend_route)
+
+
class LightningUvicornServer(uvicorn.Server):
has_started_queue = None
@@ -346,34 +316,28 @@ async def check_is_started(self, queue):
def start_server(
api_publish_state_queue,
api_delta_queue,
- commands_requests_queue,
- commands_responses_queue,
- commands_metadata_queue,
+ api_response_queue,
has_started_queue: Optional[Queue] = None,
host="127.0.0.1",
port=8000,
uvicorn_run: bool = True,
spec: Optional[List] = None,
+ apis: Optional[List[HttpMethod]] = None,
app_state_store: Optional[StateStore] = None,
):
global api_app_delta_queue
global global_app_state_store
- global api_commands_requests_queue
- global api_commands_responses_queue
global app_spec
app_spec = spec
api_app_delta_queue = api_delta_queue
- api_commands_requests_queue = commands_requests_queue
- api_commands_responses_queue = commands_responses_queue
- api_commands_metadata_queue = commands_metadata_queue
if app_state_store is not None:
global_app_state_store = app_state_store
global_app_state_store.add(TEST_SESSION_UUID)
- refresher = UIRefresher(api_publish_state_queue, api_commands_metadata_queue, commands_responses_queue)
+ refresher = UIRefresher(api_publish_state_queue, api_response_queue)
refresher.setDaemon(True)
refresher.start()
@@ -384,6 +348,14 @@ def start_server(
LightningUvicornServer.has_started_queue = has_started_queue
# uvicorn is doing some uglyness by replacing uvicorn.main by click command.
sys.modules["uvicorn.main"].Server = LightningUvicornServer
+
+ # Register the user API.
+ if apis:
+ for api in apis:
+ api.add_route(fastapi_service, api_app_delta_queue, responses_store)
+
+ register_global_routes()
+
uvicorn.run(app=fastapi_service, host=host, port=port, log_level="error")
return refresher
diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py
index 584f94285c219..65242a1ae0a2a 100644
--- a/src/lightning_app/core/app.py
+++ b/src/lightning_app/core/app.py
@@ -11,12 +11,13 @@
from deepdiff import DeepDiff, Delta
import lightning_app
+from lightning_app.api.request_types import APIRequest, CommandRequest, DeltaRequest
from lightning_app.core.constants import FLOW_DURATION_SAMPLES, FLOW_DURATION_THRESHOLD, STATE_ACCUMULATE_WAIT
from lightning_app.core.queues import BaseQueue, SingleProcessQueue
from lightning_app.frontend import Frontend
from lightning_app.storage.path import storage_root_dir
-from lightning_app.utilities.app_helpers import _delta_to_appstate_delta, _LightningAppRef
-from lightning_app.utilities.commands.base import _populate_commands_endpoint, _process_command_requests
+from lightning_app.utilities.app_helpers import _delta_to_app_state_delta, _LightningAppRef
+from lightning_app.utilities.commands.base import _process_requests
from lightning_app.utilities.component import _convert_paths_after_init
from lightning_app.utilities.enum import AppStage, CacheCallsKeys
from lightning_app.utilities.exceptions import CacheMissException, ExitAppException
@@ -73,9 +74,7 @@ def __init__(
# queues definition.
self.delta_queue: t.Optional[BaseQueue] = None
self.readiness_queue: t.Optional[BaseQueue] = None
- self.commands_requests_queue: t.Optional[BaseQueue] = None
- self.commands_responses_queue: t.Optional[BaseQueue] = None
- self.commands_metadata_queue: t.Optional[BaseQueue] = None
+ self.api_response_queue: t.Optional[BaseQueue] = None
self.api_publish_state_queue: t.Optional[BaseQueue] = None
self.api_delta_queue: t.Optional[BaseQueue] = None
self.error_queue: t.Optional[BaseQueue] = None
@@ -94,7 +93,7 @@ def __init__(
self.processes: t.Dict[str, WorkManager] = {}
self.frontends: t.Dict[str, Frontend] = {}
self.stage = AppStage.RUNNING
- self._has_updated: bool = False
+ self._has_updated: bool = True
self._schedules: t.Dict[str, t.Dict] = {}
self.threads: t.List[threading.Thread] = []
@@ -253,7 +252,7 @@ def named_works(self) -> t.List[t.Tuple[str, "lightning_app.LightningWork"]]:
"""Returns all the works defined within this application with their names."""
return self.root.named_works(recurse=True)
- def _collect_deltas_from_ui_and_work_queues(self) -> t.List[Delta]:
+ def _collect_deltas_from_ui_and_work_queues(self) -> t.List[t.Union[Delta, APIRequest, CommandRequest]]:
# The aggregation would try to get as many deltas as possible
# from both the `api_delta_queue` and `delta_queue`
# during the `state_accumulate_wait` time.
@@ -267,8 +266,12 @@ def _collect_deltas_from_ui_and_work_queues(self) -> t.List[Delta]:
while (time() - t0) < self.state_accumulate_wait:
if self.api_delta_queue and should_get_delta_from_api:
- delta_from_api: Delta = self.get_state_changed_from_queue(self.api_delta_queue) # TODO: rename
+ delta_from_api: t.Union[DeltaRequest, APIRequest, CommandRequest] = self.get_state_changed_from_queue(
+ self.api_delta_queue
+ ) # TODO: rename
if delta_from_api:
+ if isinstance(delta_from_api, DeltaRequest):
+ delta_from_api = delta_from_api.delta
deltas.append(delta_from_api)
else:
should_get_delta_from_api = False
@@ -278,7 +281,7 @@ def _collect_deltas_from_ui_and_work_queues(self) -> t.List[Delta]:
if component_output:
logger.debug(f"Received from {component_output.id} : {component_output.delta.to_dict()}")
work = self.get_component_by_name(component_output.id)
- new_work_delta = _delta_to_appstate_delta(self.root, work, deepcopy(component_output.delta))
+ new_work_delta = _delta_to_app_state_delta(self.root, work, deepcopy(component_output.delta))
deltas.append(new_work_delta)
else:
should_get_component_output = False
@@ -307,16 +310,29 @@ def maybe_apply_changes(self) -> bool:
if not deltas:
# When no deltas are received from the Rest API or work queues,
# we need to check if the flow modified the state and populate changes.
- if Delta(DeepDiff(self.last_state, self.state, verbose_level=2)).to_dict():
+ deep_diff = DeepDiff(self.last_state, self.state, verbose_level=2)
+ if deep_diff:
+ # TODO: Resolve changes with ``CacheMissException``.
# new_state = self.populate_changes(self.last_state, self.state)
- self.set_state(self.state)
+ self.set_last_state(self.state)
self._has_updated = True
return False
logger.debug(f"Received {[d.to_dict() for d in deltas]}")
- state = self.state
+ # 1: Process the API / Command Requests first as they might affect the state.
+ state_deltas = []
for delta in deltas:
+ if isinstance(delta, (APIRequest, CommandRequest)):
+ _process_requests(self, delta)
+ else:
+ state_deltas.append(delta)
+
+ # 2: Collect the state
+ state = self.state
+
+ # 3: Apply the state delta
+ for delta in state_deltas:
try:
state += delta
except Exception as e:
@@ -329,7 +345,6 @@ def maybe_apply_changes(self) -> bool:
def run_once(self):
"""Method used to collect changes and run the root Flow once."""
done = False
- self._has_updated = False
self._last_run_time = 0.0
if self.backend is not None:
@@ -350,19 +365,23 @@ def run_once(self):
elif self.stage == AppStage.RESTARTING:
return self._apply_restarting()
- _process_command_requests(self)
+ t0 = time()
try:
self.check_error_queue()
- t0 = time()
- self.root.run()
- self._last_run_time = time() - t0
+ # Execute the flow only if:
+ # - There are state changes
+ # - It is the first execution of the flow
+ if self._has_updated:
+ self.root.run()
except CacheMissException:
self._on_cache_miss_exception()
except (ExitAppException, KeyboardInterrupt):
done = True
self.stage = AppStage.STOPPING
+ self._last_run_time = time() - t0
+
self.on_run_once_end()
return done
@@ -404,8 +423,6 @@ def _run(self) -> bool:
self._reset_run_time_monitor()
- _populate_commands_endpoint(self)
-
while not done:
done = self.run_once()
@@ -414,6 +431,8 @@ def _run(self) -> bool:
if self._has_updated and self.should_publish_changes_to_api and self.api_publish_state_queue:
self.api_publish_state_queue.put(self.state_vars)
+ self._has_updated = False
+
return True
def _update_layout(self) -> None:
@@ -430,8 +449,10 @@ def _apply_restarting(self) -> bool:
self.stage = AppStage.BLOCKING
return False
- def _has_work_finished(self, work):
+ def _has_work_finished(self, work) -> bool:
latest_call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH]
+ if latest_call_hash is None:
+ return False
return "ret" in work._calls[latest_call_hash]
def _collect_work_finish_status(self) -> dict:
diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py
index f6b6e34e81538..41c46cd868307 100644
--- a/src/lightning_app/core/flow.py
+++ b/src/lightning_app/core/flow.py
@@ -634,3 +634,36 @@ def my_remote_method(self, name):
lightning my_command_name --args name=my_own_name
"""
raise NotImplementedError
+
+ def configure_api(self):
+ """Configure the API routes of the LightningFlow.
+
+ Returns a list of HttpMethod such as Post or Get.
+
+ .. code-block:: python
+
+ from lightning_app import LightningFlow
+ from lightning_app.api import Post
+
+ from pydantic import BaseModel
+
+
+ class HandlerModel(BaseModel):
+ name: str
+
+
+ class Flow(L.LightningFlow):
+ def __init__(self):
+ super().__init__()
+ self.names = []
+
+ def handler(self, config: HandlerModel) -> None:
+ self.names.append(config.name)
+
+ def configure_api(self):
+ return [Post("/v1/api/request", self.handler)]
+
+ Once the app is running, you can access the Swagger UI of the app
+ under the ``/docs`` route.
+ """
+ raise NotImplementedError
diff --git a/src/lightning_app/core/queues.py b/src/lightning_app/core/queues.py
index efac8230047e0..2b7295d7f327f 100644
--- a/src/lightning_app/core/queues.py
+++ b/src/lightning_app/core/queues.py
@@ -36,9 +36,7 @@
ORCHESTRATOR_COPY_REQUEST_CONSTANT = "ORCHESTRATOR_COPY_REQUEST"
ORCHESTRATOR_COPY_RESPONSE_CONSTANT = "ORCHESTRATOR_COPY_RESPONSE"
WORK_QUEUE_CONSTANT = "WORK_QUEUE"
-COMMANDS_REQUESTS_QUEUE_CONSTANT = "COMMANDS_REQUESTS_QUEUE"
-COMMANDS_RESPONSES_QUEUE_CONSTANT = "COMMANDS_RESPONSES_QUEUE"
-COMMANDS_METADATA_QUEUE_CONSTANT = "COMMANDS_METADATA_QUEUE"
+API_RESPONSE_QUEUE_CONSTANT = "API_RESPONSE_QUEUE"
class QueuingSystem(Enum):
@@ -54,18 +52,8 @@ def _get_queue(self, queue_name: str) -> "BaseQueue":
else:
return SingleProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
- def get_commands_requests_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = f"{queue_id}_{COMMANDS_REQUESTS_QUEUE_CONSTANT}" if queue_id else COMMANDS_REQUESTS_QUEUE_CONSTANT
- return self._get_queue(queue_name)
-
- def get_commands_responses_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = (
- f"{queue_id}_{COMMANDS_RESPONSES_QUEUE_CONSTANT}" if queue_id else COMMANDS_RESPONSES_QUEUE_CONSTANT
- )
- return self._get_queue(queue_name)
-
- def get_commands_metadata_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = f"{queue_id}_{COMMANDS_METADATA_QUEUE_CONSTANT}" if queue_id else COMMANDS_METADATA_QUEUE_CONSTANT
+ def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
+ queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
return self._get_queue(queue_name)
def get_readiness_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
@@ -98,10 +86,6 @@ def get_api_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{API_DELTA_QUEUE_CONSTANT}" if queue_id else API_DELTA_QUEUE_CONSTANT
return self._get_queue(queue_name)
- def get_api_refresh_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = f"{queue_id}_{API_REFRESH_QUEUE_CONSTANT}" if queue_id else API_REFRESH_QUEUE_CONSTANT
- return self._get_queue(queue_name)
-
def get_orchestrator_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = (
f"{queue_id}_{ORCHESTRATOR_REQUEST_CONSTANT}_{work_name}"
diff --git a/src/lightning_app/runners/backends/backend.py b/src/lightning_app/runners/backends/backend.py
index 87bb103823fd2..a944cd4aa9093 100644
--- a/src/lightning_app/runners/backends/backend.py
+++ b/src/lightning_app/runners/backends/backend.py
@@ -82,11 +82,8 @@ def _prepare_queues(self, app):
kw = dict(queue_id=self.queue_id)
app.delta_queue = self.queues.get_delta_queue(**kw)
app.readiness_queue = self.queues.get_readiness_queue(**kw)
- app.commands_requests_queue = self.queues.get_commands_requests_queue(**kw)
- app.commands_responses_queue = self.queues.get_commands_responses_queue(**kw)
- app.commands_metadata_queue = self.queues.get_commands_metadata_queue(**kw)
+ app.api_response_queue = self.queues.get_api_response_queue(**kw)
app.error_queue = self.queues.get_error_queue(**kw)
- app.delta_queue = self.queues.get_delta_queue(**kw)
app.api_publish_state_queue = self.queues.get_api_state_publish_queue(**kw)
app.api_delta_queue = self.queues.get_api_delta_queue(**kw)
app.request_queues = {}
diff --git a/src/lightning_app/runners/cloud.py b/src/lightning_app/runners/cloud.py
index 957b60b5d2ab5..2cd98ebe4cf68 100644
--- a/src/lightning_app/runners/cloud.py
+++ b/src/lightning_app/runners/cloud.py
@@ -18,15 +18,22 @@
Gridv1ImageSpec,
V1BuildSpec,
V1DependencyFileInfo,
+ V1Drive,
+ V1DriveSpec,
+ V1DriveStatus,
+ V1DriveType,
V1EnvVar,
V1Flowserver,
V1LightningappInstanceSpec,
V1LightningappInstanceState,
+ V1LightningworkDrives,
V1LightningworkSpec,
+ V1Metadata,
V1NetworkConfig,
V1PackageManager,
V1ProjectClusterBinding,
V1PythonDependencyInfo,
+ V1SourceType,
V1UserRequestedComputeConfig,
V1Work,
)
@@ -36,6 +43,7 @@
from lightning_app.runners.backends.cloud import CloudBackend
from lightning_app.runners.runtime import Runtime
from lightning_app.source_code import LocalSourceCodeDir
+from lightning_app.storage import Drive
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.dependency_caching import get_hash
from lightning_app.utilities.packaging.app_config import AppConfig, find_config_file
@@ -107,10 +115,45 @@ def dispatch(
preemptible=work.cloud_compute.preemptible,
shm_size=work.cloud_compute.shm_size,
)
+
+ drive_specs: List[V1LightningworkDrives] = []
+ for drive_attr_name, drive in [
+ (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
+ ]:
+ if drive.protocol == "lit://":
+ drive_type = V1DriveType.NO_MOUNT_S3
+ source_type = V1SourceType.S3
+ elif drive.protocol == "s3://":
+ drive_type = V1DriveType.INDEXED_S3
+ source_type = V1SourceType.S3
+ else:
+ raise RuntimeError(
+ f"unknown drive protocol `{drive.protocol}`. Please verify this "
+ f"drive type has been configured for use in the cloud dispatcher."
+ )
+
+ drive_specs.append(
+ V1LightningworkDrives(
+ drive=V1Drive(
+ metadata=V1Metadata(
+ name=f"{work.name}.{drive_attr_name}",
+ ),
+ spec=V1DriveSpec(
+ drive_type=drive_type,
+ source_type=source_type,
+ source=f"{drive.protocol}{drive.id}",
+ ),
+ status=V1DriveStatus(),
+ ),
+ mount_location=str(drive.root_folder),
+ ),
+ )
+
random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
spec = V1LightningworkSpec(
build_spec=build_spec,
cluster_id=cluster_id,
+ drives=drive_specs,
user_requested_compute_config=user_compute_config,
network_config=[V1NetworkConfig(name=random_name, port=work.port)],
)
diff --git a/src/lightning_app/runners/multiprocess.py b/src/lightning_app/runners/multiprocess.py
index 92ec900d89c65..16e373b0a37a2 100644
--- a/src/lightning_app/runners/multiprocess.py
+++ b/src/lightning_app/runners/multiprocess.py
@@ -3,10 +3,13 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
+from lightning_app.api.http_methods import _add_tags_to_api, _validate_api
from lightning_app.core.api import start_server
from lightning_app.runners.backends import Backend
from lightning_app.runners.runtime import Runtime
from lightning_app.storage.orchestrator import StorageOrchestrator
+from lightning_app.utilities.app_helpers import is_overridden
+from lightning_app.utilities.commands.base import _commands_to_api, _prepare_commands
from lightning_app.utilities.component import _set_flow_context, _set_frontend_context
from lightning_app.utilities.load_app import extract_metadata_from_app
from lightning_app.utilities.network import find_free_network_port
@@ -60,15 +63,25 @@ def dispatch(self, *args: Any, on_before_run: Optional[Callable] = None, **kwarg
if self.start_server:
self.app.should_publish_changes_to_api = True
has_started_queue = self.backend.queues.get_has_server_started_queue()
+
+ apis = []
+ if is_overridden("configure_api", self.app.root):
+ apis = self.app.root.configure_api()
+ _validate_api(apis)
+ _add_tags_to_api(apis, ["app_api"])
+
+ if is_overridden("configure_commands", self.app.root):
+ commands = _prepare_commands(self.app)
+ apis += _commands_to_api(commands)
+
kwargs = dict(
+ apis=apis,
host=self.host,
port=self.port,
+ api_response_queue=self.app.api_response_queue,
api_publish_state_queue=self.app.api_publish_state_queue,
api_delta_queue=self.app.api_delta_queue,
has_started_queue=has_started_queue,
- commands_requests_queue=self.app.commands_requests_queue,
- commands_responses_queue=self.app.commands_responses_queue,
- commands_metadata_queue=self.app.commands_metadata_queue,
spec=extract_metadata_from_app(self.app),
)
server_proc = multiprocessing.Process(target=start_server, kwargs=kwargs)
diff --git a/src/lightning_app/storage/drive.py b/src/lightning_app/storage/drive.py
index 3bcdf72780653..f72ad38b6e130 100644
--- a/src/lightning_app/storage/drive.py
+++ b/src/lightning_app/storage/drive.py
@@ -13,7 +13,7 @@
class Drive:
__IDENTIFIER__ = "__drive__"
- __PROTOCOLS__ = ["lit://"]
+ __PROTOCOLS__ = ["lit://", "s3://"]
def __init__(
self,
@@ -35,18 +35,31 @@ def __init__(
root_folder: This is the folder from where the Drive perceives the data (e.g this acts as a mount dir).
"""
self.id = None
+ self.protocol = None
for protocol in self.__PROTOCOLS__:
if id.startswith(protocol):
self.protocol = protocol
self.id = id.replace(protocol, "")
+ break
+ else: # N.B. for-else loop
+ raise ValueError(
+ f"Unknown protocol for the drive 'id' argument '{id}`. The 'id' string "
+ f"must start with one of the following prefixes {self.__PROTOCOLS__}"
+ )
+
+ if self.protocol == "s3://" and not self.id.endswith("/"):
+ raise ValueError(
+ "S3 drives must end in a trailing slash (`/`) to indicate a folder is being mounted. "
+ f"Recieved: '{id}'. Mounting a single file is not currently supported."
+ )
if not self.id:
raise Exception(f"The Drive id needs to start with one of the following protocols: {self.__PROTOCOLS__}")
- if "/" in self.id:
+ if self.protocol != "s3://" and "/" in self.id:
raise Exception(f"The id should be unique to identify your drive. Found `{self.id}`.")
- self.root_folder = pathlib.Path(root_folder).resolve() if root_folder else os.getcwd()
+ self.root_folder = pathlib.Path(root_folder).resolve() if root_folder else pathlib.Path(os.getcwd())
if not os.path.isdir(self.root_folder):
raise Exception(f"The provided root_folder isn't a directory: {root_folder}")
self.component_name = component_name
@@ -75,6 +88,10 @@ def put(self, path: str) -> None:
raise Exception("The component name needs to be known to put a path to the Drive.")
if _is_flow_context():
raise Exception("The flow isn't allowed to put files into a Drive.")
+ if self.protocol == "s3://":
+ raise PermissionError(
+ "S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?"
+ )
self._validate_path(path)
@@ -98,6 +115,10 @@ def list(self, path: Optional[str] = ".", component_name: Optional[str] = None)
"""
if _is_flow_context():
raise Exception("The flow isn't allowed to list files from a Drive.")
+ if self.protocol == "s3://":
+ raise PermissionError(
+ "S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?"
+ )
if component_name:
paths = [
@@ -142,6 +163,10 @@ def get(
"""
if _is_flow_context():
raise Exception("The flow isn't allowed to get files from a Drive.")
+ if self.protocol == "s3://":
+ raise PermissionError(
+ "S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?"
+ )
if component_name:
shared_path = self._to_shared_path(
@@ -189,6 +214,10 @@ def delete(self, path: str) -> None:
"""
if not self.component_name:
raise Exception("The component name needs to be known to delete a path to the Drive.")
+ if self.protocol == "s3://":
+ raise PermissionError(
+ "S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?"
+ )
shared_path = self._to_shared_path(
path,
diff --git a/src/lightning_app/structures/dict.py b/src/lightning_app/structures/dict.py
index 93e2b161b2e7a..b414269b93eec 100644
--- a/src/lightning_app/structures/dict.py
+++ b/src/lightning_app/structures/dict.py
@@ -22,7 +22,7 @@ def __init__(self, **kwargs: T):
.. doctest::
>>> from lightning_app import LightningFlow, LightningWork
- >>> from lightning_app.core import Dict
+ >>> from lightning_app.structures import Dict
>>> class CounterWork(LightningWork):
... def __init__(self):
... super().__init__()
diff --git a/src/lightning_app/structures/list.py b/src/lightning_app/structures/list.py
index f5a7c5c9913ad..cf691c98a8c38 100644
--- a/src/lightning_app/structures/list.py
+++ b/src/lightning_app/structures/list.py
@@ -24,7 +24,7 @@ def __init__(self, *items: T):
.. doctest::
>>> from lightning_app import LightningFlow, LightningWork
- >>> from lightning_app.core import List
+ >>> from lightning_app.structures import List
>>> class CounterWork(LightningWork):
... def __init__(self):
... super().__init__()
diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py
index e1cc2e180dab5..884c02a0521c1 100644
--- a/src/lightning_app/testing/testing.py
+++ b/src/lightning_app/testing/testing.py
@@ -1,26 +1,30 @@
import asyncio
import json
+import logging
import os
import shutil
import subprocess
import sys
import tempfile
import time
+import traceback
from contextlib import contextmanager
from subprocess import Popen
from time import sleep
-from typing import Any, Callable, Dict, Generator, List, Type
+from typing import Any, Callable, Dict, Generator, List, Optional, Type
import requests
from lightning_cloud.openapi.rest import ApiException
from requests import Session
from rich import print
+from rich.color import ANSI_COLOR_NAMES
from lightning_app import LightningApp, LightningFlow
from lightning_app.cli.lightning_cli import run_app
from lightning_app.core.constants import LIGHTNING_CLOUD_PROJECT_ID
from lightning_app.runners.multiprocess import MultiProcessRuntime
from lightning_app.testing.config import Config
+from lightning_app.utilities.app_logs import _app_logs_reader
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.enum import CacheCallsKeys
from lightning_app.utilities.imports import _is_playwright_available, requires
@@ -32,6 +36,9 @@
from playwright.sync_api import HttpCredentials, sync_playwright
+_logger = logging.getLogger(__name__)
+
+
class LightningTestApp(LightningApp):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -282,20 +289,6 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str
var scrollingElement = (document.scrollingElement || document.body);
scrollingElement.scrollTop = scrollingElement.scrollHeight;
}, 200);
-
- if (!window._logs) {
- window._logs = [];
- }
-
- if (window.logTerminals) {
- Object.entries(window.logTerminals).forEach(
- ([key, value]) => {
- window.logTerminals[key]._onLightningWritelnHandler = function (data) {
- window._logs = window._logs.concat([data]);
- }
- }
- );
- }
"""
)
@@ -309,8 +302,46 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str
except (playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError):
pass
- def fetch_logs() -> str:
- return admin_page.evaluate("window._logs;")
+ client = LightningClient()
+ project = _get_project(client)
+ identifiers = []
+ rich_colors = list(ANSI_COLOR_NAMES)
+
+ def fetch_logs(component_names: Optional[List[str]] = None) -> Generator:
+ """This methods creates websockets connection in threads and returns the logs to the main thread."""
+ app_id = admin_page.url.split("/")[-1]
+
+ if not component_names:
+ works = client.lightningwork_service_list_lightningwork(
+ project_id=project.project_id,
+ app_id=app_id,
+ ).lightningworks
+ component_names = ["flow"] + [w.name for w in works]
+
+ def on_error_callback(ws_app, *_):
+ print(traceback.print_exc())
+ ws_app.close()
+
+ colors = {c: rich_colors[i + 1] for i, c in enumerate(component_names)}
+ gen = _app_logs_reader(
+ client=client,
+ project_id=project.project_id,
+ app_id=app_id,
+ component_names=component_names,
+ follow=False,
+ on_error_callback=on_error_callback,
+ )
+ max_length = max(len(c.replace("root.", "")) for c in component_names)
+ for log_event in gen:
+ message = log_event.message
+ identifier = f"{log_event.timestamp}{log_event.message}"
+ if identifier not in identifiers:
+ date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S")
+ identifiers.append(identifier)
+ color = colors[log_event.component_name]
+ padding = (max_length - len(log_event.component_name)) * " "
+ print(f"[{color}]{log_event.component_name}{padding}[/{color}] {date} {message}")
+ yield message
# 5. Print your application ID
print(
@@ -318,16 +349,11 @@ def fetch_logs() -> str:
)
try:
- yield admin_page, view_page, fetch_logs
+ yield admin_page, view_page, fetch_logs, name
except KeyboardInterrupt:
pass
finally:
print("##################################################")
- printed_logs = []
- for log in fetch_logs():
- if log not in printed_logs:
- printed_logs.append(log)
- print(log.split("[0m")[-1])
button = admin_page.locator('[data-cy="stop"]')
try:
button.wait_for(timeout=3 * 1000)
@@ -337,8 +363,6 @@ def fetch_logs() -> str:
context.close()
browser.close()
- client = LightningClient()
- project = _get_project(client)
list_lightningapps = client.lightningapp_instance_service_list_lightningapp_instances(project.project_id)
for lightningapp in list_lightningapps.lightningapps:
diff --git a/src/lightning_app/utilities/app_helpers.py b/src/lightning_app/utilities/app_helpers.py
index 4144c6de3ba12..faa612bba1998 100644
--- a/src/lightning_app/utilities/app_helpers.py
+++ b/src/lightning_app/utilities/app_helpers.py
@@ -299,7 +299,7 @@ def _set_child_name(component: "Component", child: "Component", new_name: str) -
return child_name
-def _delta_to_appstate_delta(root: "LightningFlow", component: "Component", delta: Delta) -> Delta:
+def _delta_to_app_state_delta(root: "LightningFlow", component: "Component", delta: Delta) -> Delta:
delta_dict = delta.to_dict()
for changed in delta_dict.values():
for delta_key in changed.copy().keys():
@@ -322,8 +322,9 @@ def _delta_to_appstate_delta(root: "LightningFlow", component: "Component", delt
delta_key_without_root = delta_key[4:] # the first 4 chars are the word 'root', strip it
new_key = new_prefix + delta_key_without_root
- changed[new_key] = val
- del changed[delta_key]
+ if new_key != delta_key:
+ changed[new_key] = val
+ del changed[delta_key]
return Delta(delta_dict)
diff --git a/src/lightning_app/utilities/app_logs.py b/src/lightning_app/utilities/app_logs.py
new file mode 100644
index 0000000000000..536fbaae05093
--- /dev/null
+++ b/src/lightning_app/utilities/app_logs.py
@@ -0,0 +1,140 @@
+import json
+import queue
+import sys
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from json import JSONDecodeError
+from threading import Thread
+from typing import Callable, Iterator, List, Optional
+
+import dateutil.parser
+from websocket import WebSocketApp
+
+from lightning_app.utilities.logs_socket_api import _LightningLogsSocketAPI
+from lightning_app.utilities.network import LightningClient
+
+
+@dataclass
+class _LogEventLabels:
+ app: str
+ container: str
+ filename: str
+ job: str
+ namespace: str
+ node_name: str
+ pod: str
+ stream: Optional[str] = None
+
+
+@dataclass
+class _LogEvent:
+ message: str
+ timestamp: datetime
+ component_name: str
+ labels: _LogEventLabels
+
+ def __ge__(self, other: "_LogEvent") -> bool:
+ return self.timestamp >= other.timestamp
+
+ def __gt__(self, other: "_LogEvent") -> bool:
+ return self.timestamp > other.timestamp
+
+
+def _push_log_events_to_read_queue_callback(component_name: str, read_queue: queue.PriorityQueue):
+ """Pushes _LogEvents from websocket to read_queue.
+
+ Returns callback function used with `on_message_callback` of websocket.WebSocketApp.
+ """
+
+ def callback(ws_app: WebSocketApp, msg: str):
+ # We strongly trust that the contract on API will hold atm :D
+ event_dict = json.loads(msg)
+ labels = _LogEventLabels(**event_dict["labels"])
+
+ if "message" in event_dict:
+ message = event_dict["message"]
+ timestamp = dateutil.parser.isoparse(event_dict["timestamp"])
+ event = _LogEvent(
+ message=message,
+ timestamp=timestamp,
+ component_name=component_name,
+ labels=labels,
+ )
+ read_queue.put(event)
+
+ return callback
+
+
+def _error_callback(ws_app: WebSocketApp, error: Exception):
+ errors = {
+ KeyError: "Malformed log message, missing key",
+ JSONDecodeError: "Malformed log message",
+ TypeError: "Malformed log format",
+ ValueError: "Malformed date format",
+ }
+ print(f"Error while reading logs ({errors.get(type(error), 'Unknown')})", file=sys.stderr)
+ ws_app.close()
+
+
+def _app_logs_reader(
+ client: LightningClient,
+ project_id: str,
+ app_id: str,
+ component_names: List[str],
+ follow: bool,
+ on_error_callback: Optional[Callable] = None,
+) -> Iterator[_LogEvent]:
+
+ read_queue = queue.PriorityQueue()
+ logs_api_client = _LightningLogsSocketAPI(client.api_client)
+
+ # We will use a socket per component
+ log_sockets = [
+ logs_api_client.create_lightning_logs_socket(
+ project_id=project_id,
+ app_id=app_id,
+ component=component_name,
+ on_message_callback=_push_log_events_to_read_queue_callback(component_name, read_queue),
+ on_error_callback=on_error_callback or _error_callback,
+ )
+ for component_name in component_names
+ ]
+
+ # And each socket on separate thread pushing log event to print queue
+ # run_forever() will run until we close() the connection from outside
+ log_threads = [Thread(target=work.run_forever) for work in log_sockets]
+
+ # Establish connection and begin pushing logs to the print queue
+ for th in log_threads:
+ th.start()
+
+ # Print logs from queue when log event is available
+ user_log_start = "<<< BEGIN USER_RUN_FLOW SECTION >>>"
+ start_timestamp = None
+
+ # Print logs from queue when log event is available
+ try:
+ while True:
+ log_event = read_queue.get(timeout=None if follow else 1.0)
+ if user_log_start in log_event.message:
+ start_timestamp = log_event.timestamp + timedelta(seconds=0.5)
+
+ if start_timestamp and log_event.timestamp > start_timestamp:
+ yield log_event
+
+ except queue.Empty:
+ # Empty is raised by queue.get if timeout is reached. Follow = False case.
+ pass
+
+ except KeyboardInterrupt:
+ # User pressed CTRL+C to exit, we sould respect that
+ pass
+
+ finally:
+ # Close connections - it will cause run_forever() to finish -> thread as finishes aswell
+ for socket in log_sockets:
+ socket.close()
+
+ # Because all socket were closed, we can just wait for threads to finish.
+ for th in log_threads:
+ th.join()
diff --git a/src/lightning_app/utilities/cli_helpers.py b/src/lightning_app/utilities/cli_helpers.py
index fcce96ec64407..6000114c3d4d6 100644
--- a/src/lightning_app/utilities/cli_helpers.py
+++ b/src/lightning_app/utilities/cli_helpers.py
@@ -49,16 +49,42 @@ def _is_url(id: Optional[str]) -> bool:
return False
+def _get_metadata_from_openapi(paths: Dict, path: str):
+ parameters = paths[path]["post"].get("parameters", {})
+ tag = paths[path]["post"].get("tags", [None])[0]
+ cls_path = paths[path]["post"].get("cls_path", None)
+ cls_name = paths[path]["post"].get("cls_name", None)
+
+ metadata = {"tag": tag, "parameters": {}}
+
+ if cls_path:
+ metadata["cls_path"] = cls_path
+
+ if cls_name:
+ metadata["cls_name"] = cls_name
+
+ if not parameters:
+ return metadata
+
+ metadata["parameters"].update({d["name"]: d["schema"]["type"] for d in parameters})
+ return metadata
+
+
+def _extract_command_from_openapi(openapi_resp: Dict) -> Dict[str, Dict[str, str]]:
+ command_paths = [p for p in openapi_resp["paths"] if p.startswith("/command/")]
+ return {p.replace("/command/", ""): _get_metadata_from_openapi(openapi_resp["paths"], p) for p in command_paths}
+
+
def _retrieve_application_url_and_available_commands(app_id_or_name_or_url: Optional[str]):
"""This function is used to retrieve the current url associated with an id."""
if _is_url(app_id_or_name_or_url):
url = app_id_or_name_or_url
assert url
- resp = requests.get(url + "/api/v1/commands")
+ resp = requests.get(url + "/openapi.json")
if resp.status_code != 200:
raise Exception(f"The server didn't process the request properly. Found {resp.json()}")
- return url, resp.json()
+ return url, _extract_command_from_openapi(resp.json())
# 2: If no identifier has been provided, evaluate the local application
failed_locally = False
@@ -66,10 +92,10 @@ def _retrieve_application_url_and_available_commands(app_id_or_name_or_url: Opti
if app_id_or_name_or_url is None:
try:
url = f"http://localhost:{APP_SERVER_PORT}"
- resp = requests.get(f"{url}/api/v1/commands")
+ resp = requests.get(f"{url}/openapi.json")
if resp.status_code != 200:
raise Exception(f"The server didn't process the request properly. Found {resp.json()}")
- return url, resp.json()
+ return url, _extract_command_from_openapi(resp.json())
except requests.exceptions.ConnectionError:
failed_locally = True
@@ -88,8 +114,8 @@ def _retrieve_application_url_and_available_commands(app_id_or_name_or_url: Opti
if lightningapp.id == app_id_or_name_or_url or lightningapp.name == app_id_or_name_or_url:
if lightningapp.status.url == "":
raise Exception("The application is starting. Try in a few moments.")
- resp = requests.get(lightningapp.status.url + "/api/v1/commands")
+ resp = requests.get(lightningapp.status.url + "/openapi.json")
if resp.status_code != 200:
raise Exception(f"The server didn't process the request properly. Found {resp.json()}")
- return lightningapp.status.url, resp.json()
+ return lightningapp.status.url, _extract_command_from_openapi(resp.json())
return None, None
diff --git a/src/lightning_app/utilities/commands/base.py b/src/lightning_app/utilities/commands/base.py
index 11661e51ca26a..c74926f542744 100644
--- a/src/lightning_app/utilities/commands/base.py
+++ b/src/lightning_app/utilities/commands/base.py
@@ -1,6 +1,5 @@
import errno
import inspect
-import logging
import os
import os.path as osp
import shutil
@@ -8,19 +7,18 @@
from getpass import getuser
from importlib.util import module_from_spec, spec_from_file_location
from tempfile import gettempdir
-from typing import Any, Callable, Dict, List, Optional, Tuple
-from uuid import uuid4
+from typing import Any, Callable, Dict, List, Optional, Union
import requests
from pydantic import BaseModel
+from lightning_app.api.http_methods import Post
+from lightning_app.api.request_types import APIRequest, CommandRequest
from lightning_app.utilities.app_helpers import is_overridden
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.network import LightningClient
from lightning_app.utilities.state import AppState
-_logger = logging.getLogger(__name__)
-
def makedirs(path: str):
r"""Recursive directory creation function."""
@@ -31,31 +29,18 @@ def makedirs(path: str):
raise e
-class _ClientCommandConfig(BaseModel):
- command: str
- affiliation: str
- params: Dict[str, str]
- is_client_command: bool
- cls_path: str
- cls_name: str
- owner: str
- requirements: Optional[List[str]]
-
-
class ClientCommand:
def __init__(self, method: Callable, requirements: Optional[List[str]] = None) -> None:
self.method = method
flow = getattr(method, "__self__", None)
self.owner = flow.name if flow else None
self.requirements = requirements
- self.metadata = None
self.models: Optional[Dict[str, BaseModel]] = None
self.app_url = None
self._state = None
- def _setup(self, metadata: Dict[str, Any], models: Dict[str, BaseModel], app_url: str) -> None:
- self.metadata = metadata
- self.models = models
+ def _setup(self, command_name: str, app_url: str) -> None:
+ self.command_name = command_name
self.app_url = app_url
@property
@@ -72,67 +57,50 @@ def state(self):
def run(self, **cli_kwargs) -> None:
"""Overrides with the logic to execute on the client side."""
- def invoke_handler(self, **kwargs: Any) -> Dict[str, Any]:
- from lightning.app.utilities.state import headers_for
-
- assert kwargs.keys() == self.models.keys()
- for k, v in kwargs.items():
- assert isinstance(v, self.models[k])
- json = {
- "command_name": self.metadata["command"],
- "command_arguments": {k: v.json() for k, v in kwargs.items()},
- "affiliation": self.metadata["affiliation"],
- "id": str(uuid4()),
- }
- resp = requests.post(self.app_url + "/api/v1/commands", json=json, headers=headers_for({}))
+ def invoke_handler(self, config: BaseModel) -> Dict[str, Any]:
+ resp = requests.post(self.app_url + f"/command/{self.command_name}", data=config.json())
assert resp.status_code == 200, resp.json()
return resp.json()
def _to_dict(self):
return {"owner": self.owner, "requirements": self.requirements}
- def __call__(self, **kwargs: Any) -> Any:
- assert self.models
- input = {}
- for k, v in kwargs.items():
- input[k] = self.models[k].parse_raw(v)
- return self.method(**input)
+ def __call__(self, **kwargs):
+ return self.method(**kwargs)
def _download_command(
- command_metadata: Dict[str, Any],
- app_id: Optional[str],
+ command_name: str,
+ cls_path: str,
+ cls_name: str,
+ app_id: Optional[str] = None,
debug_mode: bool = False,
-) -> Tuple[ClientCommand, Dict[str, BaseModel]]:
+) -> ClientCommand:
# TODO: This is a skateboard implementation and the final version will rely on versioned
# immutable commands for security concerns
- config = _ClientCommandConfig(**command_metadata)
tmpdir = osp.join(gettempdir(), f"{getuser()}_commands")
makedirs(tmpdir)
- target_file = osp.join(tmpdir, f"{config.command}.py")
+ target_file = osp.join(tmpdir, f"{command_name}.py")
if app_id:
client = LightningClient()
project_id = _get_project(client).project_id
response = client.lightningapp_instance_service_list_lightningapp_instance_artifacts(project_id, app_id)
for artifact in response.artifacts:
- if f"commands/{config.command}.py" == artifact.filename:
+ if f"commands/{command_name}.py" == artifact.filename:
r = requests.get(artifact.url, allow_redirects=True)
with open(target_file, "wb") as f:
f.write(r.content)
else:
if not debug_mode:
- shutil.copy(config.cls_path, target_file)
+ shutil.copy(cls_path, target_file)
- cls_name = config.cls_name
- spec = spec_from_file_location(config.cls_name, config.cls_path if debug_mode else target_file)
+ spec = spec_from_file_location(cls_name, cls_path if debug_mode else target_file)
mod = module_from_spec(spec)
sys.modules[cls_name] = mod
spec.loader.exec_module(mod)
- command = getattr(mod, cls_name)(method=None, requirements=config.requirements)
- models = {k: getattr(mod, v) for k, v in config.params.items()}
- if debug_mode:
- shutil.rmtree(tmpdir)
- return command, models
+ command = getattr(mod, cls_name)(method=None, requirements=[])
+ shutil.rmtree(tmpdir)
+ return command
def _to_annotation(anno: str) -> str:
@@ -142,7 +110,7 @@ def _to_annotation(anno: str) -> str:
return anno
-def _command_to_method_and_metadata(command: ClientCommand) -> Tuple[Callable, Dict[str, Any]]:
+def _validate_client_command(command: ClientCommand):
"""Extract method and its metadata from a ClientCommand."""
params = inspect.signature(command.method).parameters
command_metadata = {
@@ -170,8 +138,6 @@ def _command_to_method_and_metadata(command: ClientCommand) -> Tuple[Callable, D
raise Exception(
f"The provided annotation for the argument {k} shouldn't an instance of pydantic BaseModel."
)
- command.models[k] = config
- return method, command_metadata
def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]:
@@ -192,54 +158,68 @@ def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]:
return filepath
-def _populate_commands_endpoint(app):
+def _prepare_commands(app) -> List:
if not is_overridden("configure_commands", app.root):
- return
+ return []
- # 1: Populate commands metadata
+ # 1: Upload the command to s3.
commands = app.root.configure_commands()
- commands_metadata = []
- command_names = set()
for command_mapping in commands:
for command_name, command in command_mapping.items():
- is_client_command = isinstance(command, ClientCommand)
- extras = {}
- if is_client_command:
+ if isinstance(command, ClientCommand):
_upload_command(command_name, command)
- command, extras = _command_to_method_and_metadata(command)
- if command_name in command_names:
- raise Exception(f"The component name {command_name} has already been used. They need to be unique.")
- command_names.add(command_name)
- params = inspect.signature(command).parameters
- commands_metadata.append(
- {
- "command": command_name,
- "affiliation": command.__self__.name,
- "params": list(params.keys()),
- "is_client_command": is_client_command,
- **extras,
- }
- )
- # 1.2: Pass the collected commands through the queue to the Rest API.
- app.commands_metadata_queue.put(commands_metadata)
+ # 2: Cache the commands on the app.
app.commands = commands
+ return commands
-def _process_command_requests(app):
- if not is_overridden("configure_commands", app.root):
- return
-
- # 1: Populate commands metadata
- commands = app.commands
-
- # 2: Collect requests metadata
- command_query = app.get_state_changed_from_queue(app.commands_requests_queue)
- if command_query:
- for command in commands:
- for command_name, method in command.items():
- if command_query["command_name"] == command_name:
- # 2.1: Evaluate the method associated to a specific command.
- # Validation is done on the CLI side.
- response = method(**command_query["command_arguments"])
- app.commands_responses_queue.put({"response": response, "id": command_query["id"]})
+def _process_api_request(app, request: APIRequest) -> None:
+ flow = app.get_component_by_name(request.name)
+ method = getattr(flow, request.method_name)
+ response = method(*request.args, **request.kwargs)
+ app.api_response_queue.put({"response": response, "id": request.id})
+
+
+def _process_command_requests(app, request: CommandRequest) -> None:
+ for command in app.commands:
+ for command_name, method in command.items():
+ if request.method_name == command_name:
+ # 2.1: Evaluate the method associated to a specific command.
+ # Validation is done on the CLI side.
+ response = method(*request.args, **request.kwargs)
+ app.api_response_queue.put({"response": response, "id": request.id})
+
+
+def _process_requests(app, request: Union[APIRequest, CommandRequest]) -> None:
+ """Convert user commands to API endpoint."""
+ if isinstance(request, APIRequest):
+ _process_api_request(app, request)
+ else:
+ _process_command_requests(app, request)
+
+
+def _collect_open_api_extras(command) -> Dict:
+ if not isinstance(command, ClientCommand):
+ return {}
+ return {
+ "cls_path": inspect.getfile(command.__class__),
+ "cls_name": command.__class__.__name__,
+ }
+
+
+def _commands_to_api(commands: List[Dict[str, Union[Callable, ClientCommand]]]) -> List:
+ """Convert user commands to API endpoint."""
+ api = []
+ for command in commands:
+ for k, v in command.items():
+ api.append(
+ Post(
+ f"/command/{k}",
+ v.method if isinstance(v, ClientCommand) else v,
+ method_name=k,
+ tags=["app_client_command"] if isinstance(v, ClientCommand) else ["app_command"],
+ openapi_extra=_collect_open_api_extras(v),
+ )
+ )
+ return api
diff --git a/src/lightning_app/utilities/enum.py b/src/lightning_app/utilities/enum.py
index dbf20413aa9d9..2b88d93169930 100644
--- a/src/lightning_app/utilities/enum.py
+++ b/src/lightning_app/utilities/enum.py
@@ -72,3 +72,9 @@ def make_status(stage: str, message: Optional[str] = None, reason: Optional[str]
class CacheCallsKeys:
LATEST_CALL_HASH = "latest_call_hash"
+
+
+class OpenAPITags:
+ APP_CLIENT_COMMAND = "app_client_command"
+ APP_COMMAND = "app_command"
+ APP_API = "app_api"
diff --git a/src/lightning_app/utilities/logs_socket_api.py b/src/lightning_app/utilities/logs_socket_api.py
new file mode 100644
index 0000000000000..0ab9a5c24f3e5
--- /dev/null
+++ b/src/lightning_app/utilities/logs_socket_api.py
@@ -0,0 +1,95 @@
+from typing import Callable, Optional
+from urllib.parse import urlparse
+
+from lightning_cloud.openapi import ApiClient, AuthServiceApi, V1LoginRequest
+from websocket import WebSocketApp
+
+from lightning_app.utilities.login import Auth
+
+
+class _LightningLogsSocketAPI:
+ def __init__(self, api_client: ApiClient):
+ self.api_client = api_client
+ self._auth = Auth()
+ self._auth.authenticate()
+ self._auth_service = AuthServiceApi(api_client)
+
+ def _get_api_token(self) -> str:
+ token_resp = self._auth_service.auth_service_login(
+ body=V1LoginRequest(
+ username=self._auth.username,
+ api_key=self._auth.api_key,
+ )
+ )
+ return token_resp.token
+
+ @staticmethod
+ def _socket_url(host: str, project_id: str, app_id: str, token: str, component: str) -> str:
+ return (
+ f"wss://{host}/v1/projects/{project_id}/appinstances/{app_id}/logs?"
+ f"token={token}&component={component}&follow=true"
+ )
+
+ def create_lightning_logs_socket(
+ self,
+ project_id: str,
+ app_id: str,
+ component: str,
+ on_message_callback: Callable[[WebSocketApp, str], None],
+ on_error_callback: Optional[Callable[[Exception, str], None]] = None,
+ ) -> WebSocketApp:
+ """Creates and returns WebSocketApp to listen to lightning app logs.
+
+ .. code-block:: python
+ # Synchronous reading, run_forever() is blocking
+
+
+ def print_log_msg(ws_app, msg):
+ print(msg)
+
+
+ flow_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "flow", print_log_msg)
+ flow_socket.run_forever()
+
+ .. code-block:: python
+ # Asynchronous reading (with Threads)
+
+
+ def print_log_msg(ws_app, msg):
+ print(msg)
+
+
+ flow_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "flow", print_log_msg)
+ work_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "work_1", print_log_msg)
+
+ flow_logs_thread = Thread(target=flow_logs_socket.run_forever)
+ work_logs_thread = Thread(target=work_logs_socket.run_forever)
+
+ flow_logs_thread.start()
+ work_logs_thread.start()
+ # .......
+
+ flow_logs_socket.close()
+ work_logs_thread.close()
+
+ Arguments:
+ project_id: Project ID.
+ app_id: Application ID.
+ component: Component name eg flow.
+ on_message_callback: Callback object which is called when received data.
+ on_error_callback: Callback object which is called when we get error.
+
+ Returns:
+ WebSocketApp of the wanted socket
+ """
+ _token = self._get_api_token()
+ clean_ws_host = urlparse(self.api_client.configuration.host).netloc
+ socket_url = self._socket_url(
+ host=clean_ws_host,
+ project_id=project_id,
+ app_id=app_id,
+ token=_token,
+ component=component,
+ )
+
+ return WebSocketApp(socket_url, on_message=on_message_callback, on_error=on_error_callback)
diff --git a/src/lightning_app/utilities/network.py b/src/lightning_app/utilities/network.py
index 7fd03750a515d..050734723acc1 100644
--- a/src/lightning_app/utilities/network.py
+++ b/src/lightning_app/utilities/network.py
@@ -48,11 +48,12 @@ def _configure_session() -> Session:
return http
-def _check_service_url_is_ready(url: str, timeout: float = 100) -> bool:
+def _check_service_url_is_ready(url: str, timeout: float = 5) -> bool:
try:
response = requests.get(url, timeout=timeout)
return response.status_code in (200, 404)
except (ConnectionError, ConnectTimeout, ReadTimeout):
+ logger.debug(f"The url {url} is not ready.")
return False
diff --git a/src/lightning_app/utilities/packaging/lightning_utils.py b/src/lightning_app/utilities/packaging/lightning_utils.py
index 37f4ff22988eb..073d4d7ab613a 100644
--- a/src/lightning_app/utilities/packaging/lightning_utils.py
+++ b/src/lightning_app/utilities/packaging/lightning_utils.py
@@ -89,8 +89,13 @@ def get_dist_path_if_editable_install(project_name) -> str:
def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]:
+ """This function determines if lightning is installed in editable mode (for developers) and packages the
+ current lightning source along with the app.
- if "site-packages" in _PROJECT_ROOT:
+ For normal users who install via PyPi or Conda, then this function does not do anything.
+ """
+
+ if not get_dist_path_if_editable_install("lightning"):
return
# Packaging the Lightning codebase happens only inside the `lightning` repo.
diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py
index 2c93a6c89f38c..99ad6e2aad0cf 100644
--- a/src/lightning_app/utilities/proxies.py
+++ b/src/lightning_app/utilities/proxies.py
@@ -74,7 +74,7 @@ def _send_data_to_caller_queue(work: "LightningWork", caller_queue: "BaseQueue",
data.update({"state": work_state})
logger.debug(f"Sending to {work.name}: {data}")
- caller_queue.put(data)
+ caller_queue.put(deepcopy(data))
# Reset the calls entry.
work_state["calls"] = calls
diff --git a/src/lightning_app/utilities/scheduler.py b/src/lightning_app/utilities/scheduler.py
index 012930f017f20..e45b0879246b9 100644
--- a/src/lightning_app/utilities/scheduler.py
+++ b/src/lightning_app/utilities/scheduler.py
@@ -15,7 +15,7 @@ class SchedulerThread(threading.Thread):
def __init__(self, app) -> None:
super().__init__(daemon=True)
self._exit_event = threading.Event()
- self._sleep_time = 0.5
+ self._sleep_time = 1.0
self._app = app
def run(self) -> None:
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index a3755d7733dba..409d3f51bd46f 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
+- Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))
+
+
- Added profiling to these hooks: `on_before_batch_transfer`, `transfer_batch_to_device`, `on_after_batch_transfer`, `configure_gradient_clipping`, `clip_gradients` ([#14069](https://github.com/Lightning-AI/lightning/pull/14069))
@@ -22,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raised a `MisconfigurationException` if batch transfer hooks are overriden with `IPUAccelerator` ([13961](https://github.com/Lightning-AI/lightning/pull/13961))
+- Updated compatibility for LightningLite to run with the latest DeepSpeed 0.7.0 ([13967](https://github.com/Lightning-AI/lightning/pull/13967))
+
+
### Deprecated
- Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
@@ -30,7 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `amp_level` from `Trainer` in favour of passing it explictly via precision plugin ([#13898](https://github.com/Lightning-AI/lightning/pull/13898))
--
+- Deprecated the calls to `pytorch_lightning.utiltiies.meta` functions in favor of built-in https://github.com/pytorch/torchdistx support ([#13868](https://github.com/Lightning-AI/lightning/pull/13868))
### Removed
@@ -44,21 +50,48 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated `DDP2Strategy` ([#14026](https://github.com/Lightning-AI/lightning/pull/14026))
+- Removed the deprecated `DistributedType` and `DeviceType` enum classes ([#14045](https://github.com/Lightning-AI/lightning/pull/14045))
+
+
+- Removed the experimental `pytorch_lightning.utiltiies.meta` functions in favor of built-in https://github.com/pytorch/torchdistx support ([#13868](https://github.com/Lightning-AI/lightning/pull/13868))
+
+
### Fixed
-- Casted only floating point tensors to fp16 with IPUs ([#13983](https://github.com/Lightning-AI/lightning/pull/13983))
+- Fixed a bug that caused spurious `AttributeError` when multiple `DataLoader` classes are imported ([#14117](https://github.com/Lightning-AI/lightning/pull/14117))
-- Casted tensors to fp16 before moving them to device with `DeepSpeedStrategy` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
+- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
-- Fixed the `NeptuneLogger` dependency being unrecognized ([#13988](https://github.com/Lightning-AI/lightning/pull/13988))
+- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938))
-- Fixed an issue where users would be warned about unset `max_epochs` even when `fast_dev_run` was set ([#13262](https://github.com/Lightning-AI/lightning/pull/13262))
+- Fixed the device placement when `LightningModule.cuda()` gets called without specifying a device index and the current cuda device was not 0 ([#14128](https://github.com/Lightning-AI/lightning/pull/14128))
+
+- Avoid `metadata.entry_points` deprecation warning on Python 3.10 ([#14052](https://github.com/Lightning-AI/lightning/pull/14052))
+
+- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
+
+
+- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
+
+
+
+## [1.7.1] - 2022-08-09
+
+### Fixed
+
+- Casted only floating point tensors to fp16 with IPUs ([#13983](https://github.com/Lightning-AI/lightning/pull/13983))
+- Casted tensors to fp16 before moving them to device with `DeepSpeedStrategy` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
+- Fixed the `NeptuneLogger` dependency being unrecognized ([#13988](https://github.com/Lightning-AI/lightning/pull/13988))
+- Fixed an issue where users would be warned about unset `max_epochs` even when `fast_dev_run` was set ([#13262](https://github.com/Lightning-AI/lightning/pull/13262))
- Fixed MPS device being unrecognized ([#13992](https://github.com/Lightning-AI/lightning/pull/13992))
+- Fixed incorrect `precision="mixed"` being used with `DeepSpeedStrategy` and `IPUStrategy` ([#14041](https://github.com/Lightning-AI/lightning/pull/14041))
+- Fixed dtype inference during gradient norm computation ([#14051](https://github.com/Lightning-AI/lightning/pull/14051))
+- Fixed a bug that caused `ddp_find_unused_parameters` to be set `False`, whereas the intended default is `True` ([#14095](https://github.com/Lightning-AI/lightning/pull/14095))
## [1.7.0] - 2022-08-02
diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md
index eb1a42730b5f0..914596c0a9d2f 100644
--- a/src/pytorch_lightning/README.md
+++ b/src/pytorch_lightning/README.md
@@ -14,8 +14,8 @@ ______________________________________________________________________
Docs •
Examples •
Community •
- Grid AI •
- License
+ Lightning AI •
+ License
@@ -78,17 +78,17 @@ Lightning is rigorously tested across multiple CPUs, GPUs, TPUs, IPUs, and HPUs
-| System / PyTorch ver. | 1.9 | 1.10 | 1.12 (latest) |
-| :------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| Linux py3.7 \[GPUs\*\*\] | - | - | - |
-| Linux py3.7 \[TPUs\*\*\*\] | [![CircleCI](https://circleci.com/gh/Lightning-AI/lightning/tree/master.svg?style=svg)](https://circleci.com/gh/Lightning-AI/lightning/tree/master) | - | - |
-| Linux py3.8 \[IPUs\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=6&branchName=master) | - | - |
-| Linux py3.8 \[HPUs\] | - | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=6&branchName=master) | - |
-| Linux py3.8 (with Conda) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) | - |
-| Linux py3.9 (with Conda) | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) |
-| Linux py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) |
-| OSX py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) |
-| Windows py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) |
+| System / PyTorch ver. | 1.9 | 1.10 | 1.12 (latest) |
+| :------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| Linux py3.7 \[GPUs\*\*\] | - | - | - |
+| Linux py3.7 \[TPUs\*\*\*\] | [![CircleCI](https://circleci.com/gh/Lightning-AI/lightning/tree/master.svg?style=svg)](https://circleci.com/gh/Lightning-AI/lightning/tree/master) | - | - |
+| Linux py3.8 \[IPUs\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=25&branchName=master) | - | - |
+| Linux py3.8 \[HPUs\] | - | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=26&branchName=master) | - |
+| Linux py3.8 (with Conda) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) | - |
+| Linux py3.9 (with Conda) | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) |
+| Linux py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) |
+| OSX py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) |
+| Windows py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) |
- _\*\* tests run on two NVIDIA P100_
- _\*\*\* tests run on Google GKE TPUv2/3. TPU py3.7 means we support Colab and Kaggle env._
@@ -130,8 +130,8 @@ conda install pytorch-lightning -c conda-forge
The actual status of stable is the following:
-[![Test PyTorch full](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-full.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-full.yml)
-[![Test PyTorch with Conda](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-conda.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-conda.yml)
+[![Test PyTorch full](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml)
+[![Test PyTorch with Conda](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml)
[![GPU]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=24&branchName=release%2Fpytorch)
[![TPU](https://dl.circleci.com/status-badge/img/gh/Lightning-AI/lightning/tree/release%2Fpytorch.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/Lightning-AI/lightning/tree/release%2Fpytorch)
[![IPU]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=25&branchName=release%2Fpytorch)
diff --git a/src/pytorch_lightning/__about__.py b/src/pytorch_lightning/__about__.py
index 6d09c5264e1ab..e2fdbd9ee3016 100644
--- a/src/pytorch_lightning/__about__.py
+++ b/src/pytorch_lightning/__about__.py
@@ -13,7 +13,6 @@
# limitations under the License.
import time
-# __version__ = "1.7.0"
__author__ = "Lightning AI et al."
__author_email__ = "pytorch@lightning.ai"
__license__ = "Apache-2.0"
diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 20a3dcc3f0f26..6650bb3f0c479 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -16,7 +16,7 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
-from typing import Any, Callable, cast, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Union
import torch
from torch import nn, Tensor
@@ -24,6 +24,7 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig
@@ -112,15 +113,22 @@ def __init__(
if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")
+ self.n_averaged: Optional[torch.Tensor] = None
self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._device = device
- self._max_epochs: int
- self._model_contains_batch_norm: bool
+ self._model_contains_batch_norm: Optional[bool] = None
self._average_model: "pl.LightningModule"
+ self._initialized = False
+ self._swa_scheduler: Optional[_LRScheduler] = None
+ self._scheduler_state: Optional[Dict] = None
+ self._init_n_averaged = 0
+ self._latest_update_epoch = -1
+ self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None
+ self._max_epochs: int
@property
def swa_start(self) -> int:
@@ -147,6 +155,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
if len(trainer.lr_scheduler_configs) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
+ if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)):
+ raise MisconfigurationException("SWA does not currently support sharded models.")
+
if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
@@ -158,8 +169,13 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
assert trainer.fit_loop.max_epochs is not None
trainer.fit_loop.max_epochs += 1
+ if self._scheduler_state is not None:
+ self._clear_schedulers(trainer)
+
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- if trainer.current_epoch == self.swa_start:
+ if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
+ self._initialized = True
+
# move average model to request device.
self._average_model = self._average_model.to(self._device or pl_module.device)
@@ -180,6 +196,17 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
),
)
+ if self._scheduler_state is not None:
+ # Restore scheduler state from checkpoint
+ self._swa_scheduler.load_state_dict(self._scheduler_state)
+ elif trainer.current_epoch != self.swa_start:
+ # Log a warning if we're initializing after start without any checkpoint data,
+ # as behaviour will be different compared to having checkpoint data.
+ rank_zero_warn(
+ "SWA is initializing after swa_start without any checkpoint data. "
+ "This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
+ )
+
# We assert that there is only one optimizer on fit start, so know opt_idx is always 0
default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1
@@ -196,14 +223,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
else:
trainer.lr_scheduler_configs.append(default_scheduler_cfg)
- self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
+ if self.n_averaged is None:
+ self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device)
- if self.swa_start <= trainer.current_epoch <= self.swa_end:
+ if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
+ trainer.current_epoch > self._latest_update_epoch
+ ):
+ assert self.n_averaged is not None
self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
+ self._latest_update_epoch = trainer.current_epoch
# Note: No > here in case the callback is saved with the model and training continues
if trainer.current_epoch == self.swa_end + 1:
-
# Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)
@@ -265,6 +296,7 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No
def reset_momenta(self) -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
+ assert self.momenta is not None
for bn_module in self.momenta:
bn_module.momentum = self.momenta[bn_module]
@@ -285,3 +317,35 @@ def update_parameters(
def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
+
+ def state_dict(self) -> Dict[str, Any]:
+ return {
+ "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
+ "latest_update_epoch": self._latest_update_epoch,
+ "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(),
+ "average_model_state": None if self._average_model is None else self._average_model.state_dict(),
+ }
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ self._init_n_averaged = state_dict["n_averaged"]
+ self._latest_update_epoch = state_dict["latest_update_epoch"]
+ self._scheduler_state = state_dict["scheduler_state"]
+ self._load_average_model_state(state_dict["average_model_state"])
+
+ @staticmethod
+ def _clear_schedulers(trainer: "pl.Trainer") -> None:
+ # If we have scheduler state saved, clear the scheduler configs so that we don't try to
+ # load state into the wrong type of schedulers when restoring scheduler checkpoint state.
+ # We'll configure the scheduler and re-load its state in on_train_epoch_start.
+ # Note that this relies on the callback state being restored before the scheduler state is
+ # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of
+ # writing that is only True for deepspeed which is already not supported by SWA.
+ # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background.
+ if trainer.lr_scheduler_configs:
+ assert len(trainer.lr_scheduler_configs) == 1
+ trainer.lr_scheduler_configs.clear()
+
+ def _load_average_model_state(self, model_state: Any) -> None:
+ if self._average_model is None:
+ return
+ self._average_model.load_state_dict(model_state)
diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py
index 62e81e4839da6..2916d8b07cb4e 100644
--- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py
+++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py
@@ -116,14 +116,16 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # ty
while being optimized.
Arguments:
- device: if specified, all parameters will be
- copied to that device
+ device: If specified, all parameters will be copied to that device. If `None`, the current CUDA device
+ index will be used.
Returns:
Module: self
"""
- if device is None or isinstance(device, int):
- device = torch.device("cuda", index=(device or 0))
+ if device is None:
+ device = torch.device("cuda", torch.cuda.current_device())
+ elif isinstance(device, int):
+ device = torch.device("cuda", index=device)
self.__update_properties(device=device)
return super().cuda(device=device)
diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py
index da81e4c212560..ffdc0988a1a6e 100644
--- a/src/pytorch_lightning/core/saving.py
+++ b/src/pytorch_lightning/core/saving.py
@@ -20,10 +20,9 @@
from argparse import Namespace
from copy import deepcopy
from enum import Enum
-from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
+from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
from warnings import warn
-import torch
import yaml
import pytorch_lightning as pl
@@ -34,7 +33,7 @@
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.parsing import parse_class_init_keys
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH
log = logging.getLogger(__name__)
PRIMITIVE_TYPES = (bool, int, float, str)
@@ -58,11 +57,11 @@ class ModelIO:
def load_from_checkpoint(
cls,
checkpoint_path: Union[str, IO],
- map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
+ map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[str] = None,
strict: bool = True,
- **kwargs,
- ):
+ **kwargs: Any,
+ ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
r"""
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
@@ -171,15 +170,15 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
def _load_from_checkpoint(
- cls: Union["pl.LightningModule", "pl.LightningDataModule"],
+ cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
checkpoint_path: Union[str, IO],
- map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
+ map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[str] = None,
- strict: Optional[bool] = None,
+ strict: bool = True,
**kwargs: Any,
-) -> Any:
+) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
if map_location is None:
- map_location = lambda storage, loc: storage
+ map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage)
with pl_legacy_patch():
checkpoint = pl_load(checkpoint_path, map_location=map_location)
@@ -202,15 +201,18 @@ def _load_from_checkpoint(
if issubclass(cls, pl.LightningDataModule):
return _load_state(cls, checkpoint, **kwargs)
- return _load_state(cls, checkpoint, strict=strict, **kwargs)
+ # allow cls to be evaluated as subclassed LightningModule or,
+ # as LightningModule for internal tests
+ if issubclass(cls, pl.LightningModule):
+ return _load_state(cls, checkpoint, strict=strict, **kwargs)
def _load_state(
- cls: Union["pl.LightningModule", "pl.LightningDataModule"],
+ cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
checkpoint: Dict[str, Any],
- strict: Optional[bool] = None,
+ strict: bool = True,
**cls_kwargs_new: Any,
-) -> Any:
+) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
cls_spec = inspect.getfullargspec(cls.__init__)
cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()
@@ -228,8 +230,7 @@ def _load_state(
cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
# 2. Try to restore model hparams from checkpoint using the new key
- _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY
- cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key))
+ cls_kwargs_loaded.update(checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}))
# 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
@@ -271,7 +272,9 @@ def _load_state(
return obj
-def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object:
+def _convert_loaded_hparams(
+ model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None
+) -> Dict[str, Any]:
"""Convert hparams according given type in callable or string (past) format."""
# if not hparams type define
if not hparams_type:
diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py
index 5125bf4486a9d..981eed30635f6 100644
--- a/src/pytorch_lightning/lite/lite.py
+++ b/src/pytorch_lightning/lite/lite.py
@@ -40,7 +40,6 @@
has_iterable_dataset,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _RequirementAvailable
from pytorch_lightning.utilities.seed import seed_everything
@@ -106,8 +105,6 @@ def __init__(
self._precision_plugin = self._strategy.precision_plugin
self._models_setup: int = 0
- self._check_deepspeed_support()
-
# wrap the run method so we can inject setup logic or spawn processes for the user
setattr(self, "run", partial(self._run_impl, self.run))
@@ -459,18 +456,6 @@ def _check_strategy_support(self, strategy: Optional[Union[str, Strategy]]) -> N
f" Choose one of {supported} or pass in a `Strategy` instance."
)
- def _check_deepspeed_support(self) -> None:
- if (
- isinstance(self._strategy, DeepSpeedStrategy)
- and self._strategy.zero_stage_3
- and _RequirementAvailable("deepspeed>=0.6.5")
- ):
- # https://github.com/microsoft/DeepSpeed/issues/2139
- raise RuntimeError(
- "DeepSpeed ZeRO-3 is not supported with this version of Lightning Lite and `deepspeed>=0.6.5`."
- " Please downgrade deepspeed to 0.6.4 or check if a newer version of Lightning is available."
- )
-
@staticmethod
def _supported_device_types() -> Sequence[_AcceleratorType]:
return (
diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py
index 8e30827759b99..530fb58fabe5e 100644
--- a/src/pytorch_lightning/loggers/wandb.py
+++ b/src/pytorch_lightning/loggers/wandb.py
@@ -328,7 +328,7 @@ def __getstate__(self) -> Dict[str, Any]:
@property # type: ignore[misc]
@rank_zero_experiment
- def experiment(self) -> Run:
+ def experiment(self) -> Union[Run, RunDisabled]:
r"""
Actual wandb object. To use wandb features in your
@@ -361,11 +361,13 @@ def experiment(self) -> Run:
self._experiment = wandb.init(**self._wandb_init)
# define default x-axis
- if isinstance(self._experiment, Run) and getattr(self._experiment, "define_metric", None):
+ if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
+ self._experiment, "define_metric", None
+ ):
self._experiment.define_metric("trainer/global_step")
self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
- assert isinstance(self._experiment, Run)
+ assert isinstance(self._experiment, (Run, RunDisabled))
return self._experiment
def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None:
diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py
index 26c2837bda7e3..3e9fda2f966f5 100644
--- a/src/pytorch_lightning/overrides/base.py
+++ b/src/pytorch_lightning/overrides/base.py
@@ -75,6 +75,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
trainer = pl_module._trainer
if trainer is not None:
+ assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
if trainer.training:
output = self.module.training_step(*inputs, **kwargs)
# In manual_optimization, we need to prevent DDP reducer as
diff --git a/src/pytorch_lightning/overrides/distributed.py b/src/pytorch_lightning/overrides/distributed.py
index f09a7b9e3ae08..929d1ed486f4a 100644
--- a/src/pytorch_lightning/overrides/distributed.py
+++ b/src/pytorch_lightning/overrides/distributed.py
@@ -45,8 +45,6 @@ def _find_tensors(
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
# `prepare_for_backward` is `DistributedDataParallel` specific.
- if not isinstance(model, DistributedDataParallel):
- return
if torch.is_grad_enabled() and model.require_backward_grad_sync:
model.require_forward_param_sync = True # type: ignore[assignment]
# We'll return the output object verbatim since it is a freeform
diff --git a/src/pytorch_lightning/plugins/precision/deepspeed.py b/src/pytorch_lightning/plugins/precision/deepspeed.py
index 01d3017760b0e..456bba1e77823 100644
--- a/src/pytorch_lightning/plugins/precision/deepspeed.py
+++ b/src/pytorch_lightning/plugins/precision/deepspeed.py
@@ -60,7 +60,7 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
amp_level = amp_level or "O2"
- supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED)
+ supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
diff --git a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
index 8c693f2975bbd..60e53b880c84d 100644
--- a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
+++ b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
@@ -23,7 +23,7 @@
if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
else:
- MixedPrecision = None
+ MixedPrecision = None # type: ignore[misc,assignment]
class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
diff --git a/src/pytorch_lightning/plugins/precision/ipu.py b/src/pytorch_lightning/plugins/precision/ipu.py
index 89f544575f63f..67e5e373e9f52 100644
--- a/src/pytorch_lightning/plugins/precision/ipu.py
+++ b/src/pytorch_lightning/plugins/precision/ipu.py
@@ -19,6 +19,7 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
+from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache
@@ -35,7 +36,7 @@ class IPUPrecisionPlugin(PrecisionPlugin):
"""
def __init__(self, precision: int) -> None:
- supported_precision_values = (16, 32)
+ supported_precision_values = (PrecisionType.HALF, PrecisionType.FLOAT)
if precision not in supported_precision_values:
raise ValueError(
f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported."
diff --git a/src/pytorch_lightning/profilers/simple.py b/src/pytorch_lightning/profilers/simple.py
index 20d76f9b2d378..0fb9497ff17fb 100644
--- a/src/pytorch_lightning/profilers/simple.py
+++ b/src/pytorch_lightning/profilers/simple.py
@@ -60,7 +60,7 @@ def __init__(
"""
super().__init__(dirpath=dirpath, filename=filename)
self.current_actions: Dict[str, float] = {}
- self.recorded_durations = defaultdict(list)
+ self.recorded_durations: Dict = defaultdict(list)
self.extended = extended
self.start_time = time.monotonic()
@@ -104,20 +104,23 @@ def summary(self) -> str:
if len(self.recorded_durations) > 0:
max_key = max(len(k) for k in self.recorded_durations.keys())
- def log_row(action, mean, num_calls, total, per):
+ def log_row_extended(action: str, mean: str, num_calls: str, total: str, per: str) -> str:
row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|"
row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
return row
- header_string = log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
+ header_string = log_row_extended(
+ "Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %"
+ )
output_string_len = len(header_string.expandtabs())
sep_lines = f"{sep}{'-' * output_string_len}"
output_string += sep_lines + header_string + sep_lines
- report, total_calls, total_duration = self._make_report_extended()
- output_string += log_row("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %")
+ report_extended: _TABLE_DATA_EXTENDED
+ report_extended, total_calls, total_duration = self._make_report_extended()
+ output_string += log_row_extended("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %")
output_string += sep_lines
- for action, mean_duration, num_calls, total_duration, duration_per in report:
- output_string += log_row(
+ for action, mean_duration, num_calls, total_duration, duration_per in report_extended:
+ output_string += log_row_extended(
action,
f"{mean_duration:.5}",
f"{num_calls}",
@@ -128,7 +131,7 @@ def log_row(action, mean, num_calls, total, per):
else:
max_key = max(len(k) for k in self.recorded_durations)
- def log_row(action, mean, total):
+ def log_row(action: str, mean: str, total: str) -> str:
return f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t| {total:<15}\t|"
header_string = log_row("Action", "Mean duration (s)", "Total time (s)")
diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py
index 922730df35269..57ab3a151b011 100644
--- a/src/pytorch_lightning/strategies/ddp.py
+++ b/src/pytorch_lightning/strategies/ddp.py
@@ -32,6 +32,7 @@
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@@ -39,6 +40,7 @@
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import (
_get_process_group_backend_from_env,
@@ -57,7 +59,7 @@
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
@@ -83,12 +85,12 @@ def __init__(
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
ddp_comm_state: Optional[object] = None,
- ddp_comm_hook: Optional[callable] = None,
- ddp_comm_wrapper: Optional[callable] = None,
+ ddp_comm_hook: Optional[Callable] = None,
+ ddp_comm_wrapper: Optional[Callable] = None,
model_averaging_period: Optional[int] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
- **kwargs: Union[Any, Dict[str, Any]],
+ **kwargs: Any,
) -> None:
super().__init__(
accelerator=accelerator,
@@ -105,7 +107,7 @@ def __init__(
self._ddp_comm_wrapper = ddp_comm_wrapper
self._model_averaging_period = model_averaging_period
self._model_averager: Optional[ModelAverager] = None
- self._pids: Optional[List[int]] = None
+ self._pids: List[int] = []
self._sync_dir: Optional[str] = None
self._rank_0_will_call_children_scripts: bool = False
self._process_group_backend: Optional[str] = process_group_backend
@@ -117,6 +119,7 @@ def is_distributed(self) -> bool:
@property
def root_device(self) -> torch.device:
+ assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]
@property
@@ -129,11 +132,11 @@ def num_nodes(self, num_nodes: int) -> None:
self._num_nodes = num_nodes
@property
- def num_processes(self):
+ def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0
@property
- def distributed_sampler_kwargs(self):
+ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs
@@ -146,6 +149,7 @@ def process_group_backend(self) -> Optional[str]:
return self._process_group_backend
def _configure_launcher(self) -> None:
+ assert self.cluster_environment is not None
if not self.cluster_environment.creates_processes_externally:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
self._rank_0_will_call_children_scripts = True
@@ -156,10 +160,11 @@ def setup_environment(self) -> None:
def setup(self, trainer: "pl.Trainer") -> None:
# share ddp pids to all processes
- self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+ self._rank_0_will_call_children_scripts = bool(self.broadcast(self._rank_0_will_call_children_scripts))
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()
+ assert self.accelerator is not None
self.accelerator.setup(trainer)
# move the model to the correct device
@@ -170,6 +175,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
if trainer_fn == TrainerFn.FITTING:
if self._layer_sync:
+ assert self.model is not None
self.model = self._layer_sync.apply(self.model)
self.setup_precision_plugin()
@@ -193,7 +199,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
- def setup_distributed(self):
+ def setup_distributed(self) -> None:
log.detail(f"{self.__class__.__name__}: setting up distributed...")
reset_seed()
@@ -204,6 +210,7 @@ def setup_distributed(self):
rank_zero_only.rank = self.global_rank
self._process_group_backend = self._get_process_group_backend()
+ assert self.cluster_environment is not None
init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
def _get_process_group_backend(self) -> str:
@@ -230,6 +237,7 @@ def pre_configure_ddp(self) -> None:
def _register_ddp_hooks(self) -> None:
log.detail(f"{self.__class__.__name__}: registering ddp hooks")
if self.root_device.type == "cuda" and self._is_single_process_single_device:
+ assert isinstance(self.model, DistributedDataParallel)
register_ddp_comm_hook(
model=self.model,
ddp_comm_state=self._ddp_comm_state,
@@ -262,6 +270,7 @@ def _enable_model_averaging(self) -> None:
f"{optimizer.__class__.__name__}."
)
+ assert self._ddp_comm_state is not None
self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter
)
@@ -296,15 +305,16 @@ def optimizer_step(
def configure_ddp(self) -> None:
log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
self.pre_configure_ddp()
+ assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
self.model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()
- def determine_ddp_device_ids(self):
+ def determine_ddp_device_ids(self) -> Optional[List[int]]:
if self.root_device.type == "cpu":
return None
return [self.root_device.index]
- def barrier(self, *args, **kwargs) -> None:
+ def barrier(self, *args: Any, **kwargs: Any) -> None:
if not distributed_available():
return
if torch.distributed.get_backend() == "nccl":
@@ -312,23 +322,29 @@ def barrier(self, *args, **kwargs) -> None:
else:
torch.distributed.barrier()
- def broadcast(self, obj: object, src: int = 0) -> object:
+ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
obj = [obj]
if self.global_rank != src:
- obj = [None]
+ obj = [None] # type: ignore[list-item]
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]
def pre_backward(self, closure_loss: Tensor) -> None:
"""Run before precision plugin executes backward."""
+ if not isinstance(self.model, DistributedDataParallel):
+ return
+ assert self.lightning_module is not None
if not self.lightning_module.automatic_optimization:
prepare_for_backward(self.model, closure_loss)
- def model_to_device(self):
+ def model_to_device(self) -> None:
log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
+ assert self.model is not None
self.model.to(self.root_device)
- def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> Tensor:
+ def reduce(
+ self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
+ ) -> Tensor:
"""Reduces a tensor from several distributed processes to one aggregated tensor.
Args:
@@ -344,30 +360,38 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor
- def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+ assert self.model is not None
with self.precision_plugin.train_step_context():
return self.model(*args, **kwargs)
- def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
+ assert self.lightning_module is not None
+ assert self.model is not None
if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
# used when calling `trainer.fit`
return self.model(*args, **kwargs)
else:
# used when calling `trainer.validate`
+ assert isinstance(self.model, ValidationStep)
return self.model.validation_step(*args, **kwargs)
- def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+ def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.test_step_context():
+ assert isinstance(self.model, TestStep)
return self.model.test_step(*args, **kwargs)
- def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+ def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
+ assert isinstance(self.model, PredictStep)
return self.model.predict_step(*args, **kwargs)
- def post_training_step(self):
+ def post_training_step(self) -> None:
+ assert self.lightning_module is not None
if not self.lightning_module.automatic_optimization:
- self.model.require_backward_grad_sync = True
+ assert self.model is not None
+ self.model.require_backward_grad_sync = True # type: ignore[assignment]
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
@@ -458,7 +482,7 @@ def teardown(self) -> None:
if (
_TORCH_GREATER_EQUAL_1_11
and not self.model.static_graph
- and self.model._get_ddp_logging_data().get("can_set_static_graph")
+ and self.model._get_ddp_logging_data().get("can_set_static_graph") # type: ignore[operator]
):
rank_zero_info(
"Your model can run with static graph optimizations. For future training runs, we suggest you"
@@ -475,6 +499,7 @@ def teardown(self) -> None:
and pl_module._trainer.state.fn == TrainerFn.FITTING
and self._layer_sync
):
+ assert self.model is not None
self.model = self._layer_sync.revert(self.model)
super().teardown()
diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py
index 30bcef457c44a..de34320f54093 100644
--- a/src/pytorch_lightning/strategies/ddp_spawn.py
+++ b/src/pytorch_lightning/strategies/ddp_spawn.py
@@ -254,9 +254,10 @@ def model_to_device(self) -> None:
def pre_backward(self, closure_loss: Tensor) -> None:
"""Run before precision plugin executes backward."""
+ if not isinstance(self.model, DistributedDataParallel):
+ return
assert self.lightning_module is not None
if not self.lightning_module.automatic_optimization:
- assert isinstance(self.model, DistributedDataParallel)
prepare_for_backward(self.model, closure_loss)
def reduce(
@@ -314,10 +315,20 @@ def post_training_step(self) -> None:
def register_strategies(cls, strategy_registry: Dict) -> None:
entries = (
("ddp_spawn", "spawn"),
- ("ddp_spawn_find_unused_parameters_false", "spawn"),
("ddp_fork", "fork"),
- ("ddp_fork_find_unused_parameters_false", "fork"),
("ddp_notebook", "fork"),
+ )
+ for name, start_method in entries:
+ strategy_registry.register(
+ name,
+ cls,
+ description=f"DDP strategy with `start_method` '{start_method}'",
+ start_method=start_method,
+ )
+
+ entries = (
+ ("ddp_spawn_find_unused_parameters_false", "spawn"),
+ ("ddp_fork_find_unused_parameters_false", "fork"),
("ddp_notebook_find_unused_parameters_false", "fork"),
)
for name, start_method in entries:
diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py
index b0b55374ba1a9..8acbc80257bd1 100644
--- a/src/pytorch_lightning/strategies/deepspeed.py
+++ b/src/pytorch_lightning/strategies/deepspeed.py
@@ -19,7 +19,7 @@
import platform
from collections import OrderedDict
from pathlib import Path
-from typing import Any, cast, Dict, Generator, List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
import torch
from torch import Tensor
@@ -696,7 +696,7 @@ def _auto_select_batch_size(self) -> int:
def _format_precision_config(self) -> None:
assert isinstance(self.config, dict)
- if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
+ if self.precision_plugin.precision == PrecisionType.HALF:
if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
# FP16 is a DeepSpeed standalone AMP implementation
rank_zero_info("Enabling DeepSpeed FP16.")
@@ -831,7 +831,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
if self.load_full_weights and self.zero_stage_3:
# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing
- checkpoint_path = cast(_PATH, self.broadcast(checkpoint_path))
+ checkpoint_path = self.broadcast(checkpoint_path)
return super().load_checkpoint(checkpoint_path)
# Rely on deepspeed to load the checkpoint and necessary information
diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py
index 4c351f26fa3b9..d92931fb5cdb2 100644
--- a/src/pytorch_lightning/strategies/fully_sharded_native.py
+++ b/src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -51,7 +51,7 @@
)
from torch.distributed.fsdp.wrap import enable_wrap
else:
- MixedPrecision = None
+ MixedPrecision = None # type: ignore[misc,assignment]
BackwardPrefetch = None # type: ignore[misc,assignment]
CPUOffload = None # type: ignore[misc,assignment]
diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py
index c40addd4244b2..4bedbfd6d70fc 100644
--- a/src/pytorch_lightning/strategies/ipu.py
+++ b/src/pytorch_lightning/strategies/ipu.py
@@ -58,7 +58,7 @@ def __init__(
self.precision = precision
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
- if self.precision in (PrecisionType.MIXED, PrecisionType.HALF):
+ if self.precision == PrecisionType.HALF:
inputs = self._move_float_tensors_to_half(inputs)
return super().forward(*inputs, **kwargs)
diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
index 39bba092e9c60..2617e5fe27b10 100644
--- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py
+++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -144,7 +144,7 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train
# load last weights
if worker_output.weights_path is not None:
ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
- trainer.lightning_module.load_state_dict(ckpt) # type: ignore[arg-type]
+ trainer.lightning_module.load_state_dict(ckpt)
self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)
trainer.state = worker_output.trainer_state
diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py
index 4550e397ded80..882302e101cb6 100644
--- a/src/pytorch_lightning/strategies/sharded_spawn.py
+++ b/src/pytorch_lightning/strategies/sharded_spawn.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
import pytorch_lightning as pl
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.trainer.states import TrainerFn
@@ -42,7 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
def configure_ddp(self) -> None:
# set up optimizers after the wrapped module has been moved to the device
+ assert self.lightning_module is not None
self.setup_optimizers(self.lightning_module.trainer)
+ assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
self.model, self.optimizers = self._setup_model_and_optimizers(
model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
)
@@ -69,12 +72,13 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"
return optimizers
def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
- if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+ assert self.lightning_module
+ if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
return optimizers
return self._reinit_optimizers_with_oss(optimizers)
- def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
+ def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
if isinstance(optimizer, OSS):
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)
@@ -93,7 +97,7 @@ def block_backward_sync(self) -> Generator:
yield None
@rank_zero_only
- def _optim_state_dict(self, optimizer):
+ def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
"""
Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
:meth:`consolidate_state_dict`.
@@ -112,7 +116,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
def pre_backward(self, closure_loss: Tensor) -> None:
pass
- def post_training_step(self):
+ def post_training_step(self) -> None:
pass
@classmethod
diff --git a/src/pytorch_lightning/strategies/utils.py b/src/pytorch_lightning/strategies/utils.py
index b71458bfc30d3..cdae7bf434eca 100644
--- a/src/pytorch_lightning/strategies/utils.py
+++ b/src/pytorch_lightning/strategies/utils.py
@@ -24,7 +24,7 @@ def on_colab_kaggle() -> bool:
def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
if torch.is_floating_point(tensor):
- if precision in (PrecisionType.MIXED, PrecisionType.HALF):
+ if precision == PrecisionType.HALF:
return tensor.half()
if precision == PrecisionType.BFLOAT:
return tensor.bfloat16()
diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py
index 83881905beeb1..32d67d44ad44c 100644
--- a/src/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -17,6 +17,7 @@
from datetime import timedelta
from typing import Dict, List, Optional, Sequence, Union
+import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
Callback,
Checkpoint,
@@ -30,14 +31,14 @@
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
+from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
_log = logging.getLogger(__name__)
class CallbackConnector:
- def __init__(self, trainer):
+ def __init__(self, trainer: "pl.Trainer"):
self.trainer = trainer
def on_trainer_init(
@@ -50,7 +51,7 @@ def on_trainer_init(
enable_model_summary: bool,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
- ):
+ ) -> None:
# init folder paths for checkpoint + weights save callbacks
self.trainer._default_root_dir = default_root_dir or os.getcwd()
if weights_save_path:
@@ -95,16 +96,18 @@ def on_trainer_init(
def _configure_accumulated_gradients(
self, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None
) -> None:
- grad_accum_callback = [cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)]
+ grad_accum_callbacks: List[GradientAccumulationScheduler] = [
+ cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)
+ ]
- if grad_accum_callback:
+ if grad_accum_callbacks:
if accumulate_grad_batches is not None:
raise MisconfigurationException(
"You have set both `accumulate_grad_batches` and passed an instance of "
"`GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches` "
"from trainer or remove `GradientAccumulationScheduler` from callbacks list."
)
- grad_accum_callback = grad_accum_callback[0]
+ grad_accum_callback = grad_accum_callbacks[0]
else:
if accumulate_grad_batches is None:
accumulate_grad_batches = 1
@@ -148,6 +151,7 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
progress_bar_callback = self.trainer.progress_bar_callback
is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)
+ model_summary: ModelSummary
if progress_bar_callback is not None and is_progress_bar_rich:
model_summary = RichModelSummary()
else:
@@ -188,7 +192,7 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
timer = Timer(duration=max_time, interval="step")
self.trainer.callbacks.append(timer)
- def _configure_fault_tolerance_callbacks(self):
+ def _configure_fault_tolerance_callbacks(self) -> None:
from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint
if any(isinstance(cb, _FaultToleranceCheckpoint) for cb in self.trainer.callbacks):
@@ -196,7 +200,7 @@ def _configure_fault_tolerance_callbacks(self):
# don't use `log_dir` to minimize the chances of failure
self.trainer.callbacks.append(_FaultToleranceCheckpoint(dirpath=self.trainer.default_root_dir))
- def _attach_model_logging_functions(self):
+ def _attach_model_logging_functions(self) -> None:
lightning_module = self.trainer.lightning_module
for callback in self.trainer.callbacks:
callback.log = lightning_module.log
@@ -243,7 +247,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
A new list in which the last elements are Checkpoint if there were any present in the
input.
"""
- checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)]
+ checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)]
not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
return not_checkpoints + checkpoints
@@ -256,19 +260,24 @@ def _configure_external_callbacks() -> List[Callback]:
Return:
A list of all callbacks collected from external factories.
"""
+ group = "pytorch_lightning.callbacks_factory"
+
if _PYTHON_GREATER_EQUAL_3_8_0:
from importlib.metadata import entry_points
- factories = entry_points().get("pytorch_lightning.callbacks_factory", ())
+ if _PYTHON_GREATER_EQUAL_3_10_0:
+ factories = entry_points(group=group) # type: ignore[call-arg]
+ else:
+ factories = entry_points().get(group, {}) # type: ignore[assignment]
else:
from pkg_resources import iter_entry_points
- factories = iter_entry_points("pytorch_lightning.callbacks_factory")
+ factories = iter_entry_points(group) # type: ignore[assignment]
- external_callbacks = []
+ external_callbacks: List[Callback] = []
for factory in factories:
callback_factory = factory.load()
- callbacks_list: List[Callback] = callback_factory()
+ callbacks_list: Union[List[Callback], Callback] = callback_factory()
callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
_log.info(
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index ff882912625d0..02e17a8d93494 100644
--- a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -163,8 +163,7 @@ def update_train_epoch_metrics(self) -> None:
self.log_metrics(self.metrics["log"])
# reset result collection for next epoch
- assert self.trainer._results is not None
- self.trainer._results.reset(metrics=True)
+ self.reset_results()
"""
Utilities and properties
diff --git a/src/pytorch_lightning/utilities/__init__.py b/src/pytorch_lightning/utilities/__init__.py
index df5084dd85490..c849ba0a05d68 100644
--- a/src/pytorch_lightning/utilities/__init__.py
+++ b/src/pytorch_lightning/utilities/__init__.py
@@ -21,7 +21,6 @@
_AcceleratorType,
_StrategyType,
AMPType,
- DistributedType,
GradClipAlgorithmType,
LightningEnum,
)
diff --git a/src/pytorch_lightning/utilities/cloud_io.py b/src/pytorch_lightning/utilities/cloud_io.py
index 81482a8ab24f9..99629bcda8980 100644
--- a/src/pytorch_lightning/utilities/cloud_io.py
+++ b/src/pytorch_lightning/utilities/cloud_io.py
@@ -15,21 +15,19 @@
import io
from pathlib import Path
-from typing import Any, Callable, Dict, IO, Optional, Union
+from typing import Any, Dict, IO, Union
import fsspec
import torch
from fsspec.core import url_to_fs
from fsspec.implementations.local import AbstractFileSystem
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH
def load(
path_or_url: Union[IO, _PATH],
- map_location: Optional[
- Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
- ] = None,
+ map_location: _MAP_LOCATION_TYPE = None,
) -> Any:
"""Loads a checkpoint.
@@ -41,7 +39,10 @@ def load(
# any sort of BytesIO or similar
return torch.load(path_or_url, map_location=map_location)
if str(path_or_url).startswith("http"):
- return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location)
+ return torch.hub.load_state_dict_from_url(
+ str(path_or_url),
+ map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
+ )
fs = get_filesystem(path_or_url)
with fs.open(path_or_url, "rb") as f:
return torch.load(f, map_location=map_location)
diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py
index 00a7cb8486709..b625a046f6122 100644
--- a/src/pytorch_lightning/utilities/data.py
+++ b/src/pytorch_lightning/utilities/data.py
@@ -501,15 +501,17 @@ def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = Non
It patches the ``__init__`` method.
"""
classes = _get_all_subclasses(base_cls) | {base_cls}
- wrapped = set()
for cls in classes:
- if cls.__init__ not in wrapped:
+ # Check that __init__ belongs to the class
+ # https://stackoverflow.com/a/5253424
+ if "__init__" in cls.__dict__:
cls._old_init = cls.__init__
cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
- wrapped.add(cls.__init__)
yield
for cls in classes:
- if hasattr(cls, "_old_init"):
+ # Check that _old_init belongs to the class
+ # https://stackoverflow.com/a/5253424
+ if "_old_init" in cls.__dict__:
cls.__init__ = cls._old_init
del cls._old_init
diff --git a/src/pytorch_lightning/utilities/enums.py b/src/pytorch_lightning/utilities/enums.py
index e687d3f9f046b..06d616f87259f 100644
--- a/src/pytorch_lightning/utilities/enums.py
+++ b/src/pytorch_lightning/utilities/enums.py
@@ -15,11 +15,9 @@
from __future__ import annotations
import os
-from enum import Enum, EnumMeta
-from typing import Any
+from enum import Enum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.warnings import rank_zero_deprecation
class LightningEnum(str, Enum):
@@ -43,37 +41,6 @@ def __hash__(self) -> int:
return hash(self.value.lower())
-class _DeprecatedEnumMeta(EnumMeta):
- """Enum that calls `deprecate()` whenever a member is accessed.
-
- Adapted from: https://stackoverflow.com/a/62309159/208880
- """
-
- def __getattribute__(cls, name: str) -> Any:
- obj = super().__getattribute__(name)
- # ignore __dunder__ names -- prevents potential recursion errors
- if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
- obj.deprecate()
- return obj
-
- def __getitem__(cls, name: str) -> Any:
- member: _DeprecatedEnumMeta = super().__getitem__(name)
- member.deprecate()
- return member
-
- def __call__(cls, *args: Any, **kwargs: Any) -> Any:
- obj = super().__call__(*args, **kwargs)
- if isinstance(obj, Enum):
- obj.deprecate()
- return obj
-
-
-class _DeprecatedEnum(LightningEnum, metaclass=_DeprecatedEnumMeta):
- """_DeprecatedEnum calls an enum's `deprecate()` method on member access."""
-
- pass
-
-
class AMPType(LightningEnum):
"""Type of Automatic Mixed Precission used for training.
@@ -110,66 +77,6 @@ def supported_types() -> list[str]:
return [x.value for x in PrecisionType]
-class DistributedType(_DeprecatedEnum):
- """Define type of training strategy.
-
- Deprecated since v1.6.0 and will be removed in v1.8.0.
-
- Use `_StrategyType` instead.
- """
-
- DP = "dp"
- DDP = "ddp"
- DDP_SPAWN = "ddp_spawn"
- TPU_SPAWN = "tpu_spawn"
- DEEPSPEED = "deepspeed"
- HOROVOD = "horovod"
- DDP_SHARDED = "ddp_sharded"
- DDP_SHARDED_SPAWN = "ddp_sharded_spawn"
- DDP_FULLY_SHARDED = "ddp_fully_sharded"
- HPU_PARALLEL = "hpu_parallel"
-
- @staticmethod
- def interactive_compatible_types() -> list[DistributedType]:
- """Returns a list containing interactive compatible DistributeTypes."""
- return [
- DistributedType.DP,
- DistributedType.DDP_SPAWN,
- DistributedType.DDP_SHARDED_SPAWN,
- DistributedType.TPU_SPAWN,
- ]
-
- def is_interactive_compatible(self) -> bool:
- """Returns whether self is interactive compatible."""
- return self in DistributedType.interactive_compatible_types()
-
- def deprecate(self) -> None:
- rank_zero_deprecation(
- "`DistributedType` Enum has been deprecated in v1.6 and will be removed in v1.8."
- f" Use the string value `{self.value!r}` instead."
- )
-
-
-class DeviceType(_DeprecatedEnum):
- """Define Device type by its nature - accelerators.
-
- Deprecated since v1.6.0 and will be removed in v1.8.0.
-
- Use `_AcceleratorType` instead.
- """
-
- CPU = "CPU"
- GPU = "GPU"
- IPU = "IPU"
- TPU = "TPU"
-
- def deprecate(self) -> None:
- rank_zero_deprecation(
- "`DeviceType` Enum has been deprecated in v1.6 and will be removed in v1.8."
- f" Use the string value `{self.value!r}` instead."
- )
-
-
class GradClipAlgorithmType(LightningEnum):
"""Define gradient_clip_algorithm types - training-tricks.
NORM type means "clipping gradients by norm". This computed over all model parameters together.
diff --git a/src/pytorch_lightning/utilities/grads.py b/src/pytorch_lightning/utilities/grads.py
index 66c1b7d988522..76c3f39bdc013 100644
--- a/src/pytorch_lightning/utilities/grads.py
+++ b/src/pytorch_lightning/utilities/grads.py
@@ -41,12 +41,12 @@ def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator
raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
norms = {
- f"grad_{norm_type}_norm{group_separator}{name}": p.grad.data.norm(norm_type).item()
+ f"grad_{norm_type}_norm{group_separator}{name}": p.grad.data.norm(norm_type)
for name, p in module.named_parameters()
if p.grad is not None
}
if norms:
- total_norm = torch.tensor(list(norms.values())).norm(norm_type).item()
+ total_norm = torch.tensor(list(norms.values())).norm(norm_type)
norms[f"grad_{norm_type}_norm_total"] = total_norm
- norms = {k: round(v, 4) for k, v in norms.items()}
+ norms = {k: round(v.item(), 4) for k, v in norms.items()}
return norms
diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py
index 67bf75be3c4d3..ba437ad332dfa 100644
--- a/src/pytorch_lightning/utilities/imports.py
+++ b/src/pytorch_lightning/utilities/imports.py
@@ -124,6 +124,7 @@ def __repr__(self) -> str:
_IS_WINDOWS = platform.system() == "Windows"
_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
+_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_TORCH_GREATER_EQUAL_1_9_1 = _compare_version("torch", operator.ge, "1.9.1")
_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
_TORCH_LESSER_EQUAL_1_10_2 = _compare_version("torch", operator.le, "1.10.2")
diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py
index 9f5fe2d6b6841..073423ab60773 100644
--- a/src/pytorch_lightning/utilities/parsing.py
+++ b/src/pytorch_lightning/utilities/parsing.py
@@ -108,7 +108,9 @@ def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
del hparams_dict[k]
-def parse_class_init_keys(cls: Type["pl.LightningModule"]) -> Tuple[str, Optional[str], Optional[str]]:
+def parse_class_init_keys(
+ cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]]
+) -> Tuple[str, Optional[str], Optional[str]]:
"""Parse key words for standard ``self``, ``*args`` and ``**kwargs``.
Examples:
@@ -160,7 +162,10 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
def collect_init_args(
- frame: types.FrameType, path_args: List[Dict[str, Any]], inside: bool = False
+ frame: types.FrameType,
+ path_args: List[Dict[str, Any]],
+ inside: bool = False,
+ classes: Tuple[Type, ...] = (),
) -> List[Dict[str, Any]]:
"""Recursively collects the arguments passed to the child constructors in the inheritance tree.
@@ -168,6 +173,7 @@ def collect_init_args(
frame: the current stack frame
path_args: a list of dictionaries containing the constructor args in all parent classes
inside: track if we are inside inheritance path, avoid terminating too soon
+ classes: the classes in which to inspect the frames
Return:
A list of dictionaries where each dictionary contains the arguments passed to the
@@ -179,13 +185,13 @@ def collect_init_args(
if not isinstance(frame.f_back, types.FrameType):
return path_args
- if "__class__" in local_vars:
+ if "__class__" in local_vars and (not classes or issubclass(local_vars["__class__"], classes)):
local_args = get_init_args(frame)
# recursive update
path_args.append(local_args)
- return collect_init_args(frame.f_back, path_args, inside=True)
+ return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
if not inside:
- return collect_init_args(frame.f_back, path_args, inside)
+ return collect_init_args(frame.f_back, path_args, inside, classes=classes)
return path_args
@@ -223,7 +229,10 @@ def save_hyperparameters(
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = {}
- for local_args in collect_init_args(frame, []):
+
+ from pytorch_lightning.core.mixins import HyperparametersMixin
+
+ for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
init_args.update(local_args)
if ignore is None:
diff --git a/src/pytorch_lightning/utilities/seed.py b/src/pytorch_lightning/utilities/seed.py
index 6648b5a56b2b1..8fce6a1debfcf 100644
--- a/src/pytorch_lightning/utilities/seed.py
+++ b/src/pytorch_lightning/utilities/seed.py
@@ -24,7 +24,7 @@
import numpy as np
import torch
-from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)
@@ -66,9 +66,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = _select_seed_randomly(min_seed_value, max_seed_value)
- # using `log.info` instead of `rank_zero_info`,
- # so users can verify the seed is properly set in distributed training.
- log.info(f"Global seed set to {seed}")
+ log.info(f"[rank: {_get_rank()}] Global seed set to {seed}")
os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py
index f6c14d366805f..18e2db6feb6c6 100644
--- a/src/pytorch_lightning/utilities/types.py
+++ b/src/pytorch_lightning/utilities/types.py
@@ -19,7 +19,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
import torch
from torch import Tensor
@@ -49,6 +49,7 @@
]
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
_DEVICE = Union[torch.device, str, int]
+_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
@runtime_checkable
diff --git a/tests/tests_app/cli/test_cli.py b/tests/tests_app/cli/test_cli.py
index 16e641ac38f23..48e1a26bb6f2b 100644
--- a/tests/tests_app/cli/test_cli.py
+++ b/tests/tests_app/cli/test_cli.py
@@ -70,7 +70,18 @@ def test_main_lightning_cli_help():
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.cli.cmd_clusters.AWSClusterManager.create")
-def test_create_cluster(create: mock.MagicMock):
+@pytest.mark.parametrize(
+ "extra_arguments,expected_instance_types,expected_cost_savings_mode",
+ [
+ (["--instance-types", "t3.xlarge"], ["t3.xlarge"], True),
+ (["--instance-types", "t3.xlarge,t3.2xlarge"], ["t3.xlarge", "t3.2xlarge"], True),
+ ([], None, True),
+ (["--enable-performance"], None, False),
+ ],
+)
+def test_create_cluster(
+ create_command: mock.MagicMock, extra_arguments, expected_instance_types, expected_cost_savings_mode
+):
runner = CliRunner()
runner.invoke(
create_cluster,
@@ -82,19 +93,18 @@ def test_create_cluster(create: mock.MagicMock):
"dummy",
"--role-arn",
"arn:aws:iam::1234567890:role/lai-byoc",
- "--instance-types",
- "t2.small",
- ],
+ ]
+ + extra_arguments,
)
- create.assert_called_once_with(
+ create_command.assert_called_once_with(
cluster_name="test-7",
region="us-east-1",
role_arn="arn:aws:iam::1234567890:role/lai-byoc",
external_id="dummy",
- instance_types=["t2.small"],
+ instance_types=expected_instance_types,
edit_before_creation=False,
- cost_savings=False,
+ cost_savings=expected_cost_savings_mode,
wait=False,
)
diff --git a/tests/tests_app/cli/test_cmd_show_logs.py b/tests/tests_app/cli/test_cmd_show_logs.py
new file mode 100644
index 0000000000000..0dc06025151fa
--- /dev/null
+++ b/tests/tests_app/cli/test_cmd_show_logs.py
@@ -0,0 +1,61 @@
+from unittest import mock
+
+from click.testing import CliRunner
+
+from lightning_app.cli.lightning_cli import logs
+
+
+@mock.patch("lightning_app.cli.lightning_cli.LightningClient")
+@mock.patch("lightning_app.cli.lightning_cli._get_project")
+def test_show_logs_errors(project, client):
+ """Test that the CLI prints the errors for the show logs command."""
+
+ runner = CliRunner()
+
+ # Response prep
+ app = mock.MagicMock()
+ app.name = "MyFakeApp"
+ work = mock.MagicMock()
+ work.name = "MyFakeWork"
+ flow = mock.MagicMock()
+ flow.name = "MyFakeFlow"
+
+ # No apps ever run
+ apps = {}
+ client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps
+
+ result = runner.invoke(logs, ["NonExistentApp"])
+
+ assert result.exit_code == 1
+ assert "Error: You don't have any application in the cloud" in result.output
+
+ # App not specified
+ apps = {app}
+ client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps
+
+ result = runner.invoke(logs)
+
+ assert result.exit_code == 1
+ assert "Please select one of available: [MyFakeApp]" in str(result.output)
+
+ # App does not exit
+ apps = {app}
+ client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps
+
+ result = runner.invoke(logs, ["ThisAppDoesNotExist"])
+
+ assert result.exit_code == 1
+ assert "The Lightning App 'ThisAppDoesNotExist' does not exist." in str(result.output)
+
+ # Component does not exist
+ apps = {app}
+ works = {work}
+ flows = {flow}
+ client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps
+ client.return_value.lightningwork_service_list_lightningwork.return_value.lightningworks = works
+ app.spec.flow_servers = flows
+
+ result = runner.invoke(logs, ["MyFakeApp", "NonExistentComponent"])
+
+ assert result.exit_code == 1
+ assert "Component 'NonExistentComponent' does not exist in app MyFakeApp." in result.output
diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py
index edd2896d1951d..1b2bf2fb52fd9 100644
--- a/tests/tests_app/core/test_lightning_api.py
+++ b/tests/tests_app/core/test_lightning_api.py
@@ -2,15 +2,27 @@
import multiprocessing as mp
import os
from copy import deepcopy
+from multiprocessing import Process
+from time import sleep
from unittest import mock
import pytest
+import requests
from deepdiff import DeepDiff, Delta
from httpx import AsyncClient
+from pydantic import BaseModel
from lightning_app import LightningApp, LightningFlow, LightningWork
+from lightning_app.api.http_methods import Post
from lightning_app.core import api
-from lightning_app.core.api import fastapi_service, global_app_state_store, start_server, UIRefresher
+from lightning_app.core.api import (
+ fastapi_service,
+ global_app_state_store,
+ register_global_routes,
+ start_server,
+ UIRefresher,
+)
+from lightning_app.core.constants import APP_SERVER_PORT
from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime
from lightning_app.storage.drive import Drive
from lightning_app.testing.helpers import MockQueue
@@ -20,6 +32,8 @@
from lightning_app.utilities.redis import check_if_redis_running
from lightning_app.utilities.state import AppState, headers_for
+register_global_routes()
+
class WorkA(LightningWork):
def __init__(self):
@@ -161,12 +175,11 @@ def test_update_publish_state_and_maybe_refresh_ui():
app = AppStageTestingApp(FlowA(), debug=True)
publish_state_queue = MockQueue("publish_state_queue")
- commands_metadata_queue = MockQueue("commands_metadata_queue")
- commands_responses_queue = MockQueue("commands_metadata_queue")
+ api_response_queue = MockQueue("api_response_queue")
publish_state_queue.put(app.state_with_changes)
- thread = UIRefresher(publish_state_queue, commands_metadata_queue, commands_responses_queue)
+ thread = UIRefresher(publish_state_queue, api_response_queue)
thread.run_once()
assert global_app_state_store.get_app_state("1234") == app.state_with_changes
@@ -192,18 +205,14 @@ def get(self, timeout: int = 0):
publish_state_queue = InfiniteQueue("publish_state_queue")
change_state_queue = MockQueue("change_state_queue")
has_started_queue = MockQueue("has_started_queue")
- commands_requests_queue = MockQueue("commands_requests_queue")
- commands_responses_queue = MockQueue("commands_responses_queue")
- commands_metadata_queue = MockQueue("commands_metadata_queue")
+ api_response_queue = MockQueue("api_response_queue")
state = app.state_with_changes
publish_state_queue.put(state)
spec = extract_metadata_from_app(app)
ui_refresher = start_server(
publish_state_queue,
change_state_queue,
- commands_requests_queue,
- commands_responses_queue,
- commands_metadata_queue,
+ api_response_queue,
has_started_queue=has_started_queue,
uvicorn_run=False,
spec=spec,
@@ -343,16 +352,12 @@ def test_start_server_started():
api_publish_state_queue = mp.Queue()
api_delta_queue = mp.Queue()
has_started_queue = mp.Queue()
- commands_requests_queue = mp.Queue()
- commands_responses_queue = mp.Queue()
- commands_metadata_queue = mp.Queue()
+ api_response_queue = mp.Queue()
kwargs = dict(
api_publish_state_queue=api_publish_state_queue,
api_delta_queue=api_delta_queue,
has_started_queue=has_started_queue,
- commands_requests_queue=commands_requests_queue,
- commands_responses_queue=commands_responses_queue,
- commands_metadata_queue=commands_metadata_queue,
+ api_response_queue=api_response_queue,
port=1111,
)
@@ -372,18 +377,14 @@ def test_start_server_info_message(ui_refresher, uvicorn_run, caplog, monkeypatc
api_publish_state_queue = MockQueue()
api_delta_queue = MockQueue()
has_started_queue = MockQueue()
- commands_requests_queue = MockQueue()
- commands_responses_queue = MockQueue()
- commands_metadata_queue = MockQueue()
+ api_response_queue = MockQueue()
kwargs = dict(
host=host,
port=1111,
api_publish_state_queue=api_publish_state_queue,
api_delta_queue=api_delta_queue,
has_started_queue=has_started_queue,
- commands_requests_queue=commands_requests_queue,
- commands_responses_queue=commands_responses_queue,
- commands_metadata_queue=commands_metadata_queue,
+ api_response_queue=api_response_queue,
)
monkeypatch.setattr(api, "logger", logging.getLogger())
@@ -395,3 +396,65 @@ def test_start_server_info_message(ui_refresher, uvicorn_run, caplog, monkeypatc
ui_refresher.assert_called_once()
uvicorn_run.assert_called_once_with(host="0.0.0.1", port=1111, log_level="error", app=mock.ANY)
+
+
+class InputRequestModel(BaseModel):
+ name: str
+
+
+class OutputRequestModel(BaseModel):
+ name: str
+ counter: int
+
+
+class FlowAPI(LightningFlow):
+ def __init__(self):
+ super().__init__()
+ self.counter = 0
+
+ def run(self):
+ if self.counter == 2:
+ sleep(0.5)
+ self._exit()
+
+ def request(self, config: InputRequestModel) -> OutputRequestModel:
+ self.counter += 1
+ return OutputRequestModel(name=config.name, counter=self.counter)
+
+ def configure_api(self):
+ return [Post("/api/v1/request", self.request)]
+
+
+def target():
+ app = LightningApp(FlowAPI())
+ MultiProcessRuntime(app).dispatch()
+
+
+def test_configure_api():
+
+ process = Process(target=target)
+ process.start()
+ time_left = 15
+ while time_left > 0:
+ try:
+ requests.get(f"http://localhost:{APP_SERVER_PORT}/healthz")
+ break
+ except requests.exceptions.ConnectionError:
+ sleep(0.1)
+ time_left -= 0.1
+
+ response = requests.post(
+ f"http://localhost:{APP_SERVER_PORT}/api/v1/request", data=InputRequestModel(name="hello").json()
+ )
+ assert response.json() == {"name": "hello", "counter": 1}
+ response = requests.post(
+ f"http://localhost:{APP_SERVER_PORT}/api/v1/request", data=InputRequestModel(name="hello").json()
+ )
+ assert response.json() == {"name": "hello", "counter": 2}
+ time_left = 15
+ while time_left > 0:
+ if process.exitcode == 0:
+ break
+ sleep(0.1)
+ time_left -= 0.1
+ assert process.exitcode == 0
diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py
index a3a15085b98e3..3776481965be3 100644
--- a/tests/tests_app/core/test_lightning_app.py
+++ b/tests/tests_app/core/test_lightning_app.py
@@ -1,3 +1,4 @@
+import logging
import os
import pickle
from time import sleep
@@ -27,6 +28,8 @@
from lightning_app.utilities.redis import check_if_redis_running
from lightning_app.utilities.warnings import LightningFlowWarning
+logger = logging.getLogger()
+
class B1(LightningFlow):
def __init__(self):
@@ -439,19 +442,25 @@ def __init__(self):
self.counter = 0
def run(self):
- self.counter = 1
+ if self.counter < 2:
+ self.counter += 1
def test_maybe_apply_changes_from_flow():
"""This test validates the app `_updated` is set to True only if the state was changed in the flow."""
app = LightningApp(SimpleFlow())
- assert not app._has_updated
+ assert app._has_updated
app.maybe_apply_changes()
app.root.run()
app.maybe_apply_changes()
assert app._has_updated
app._has_updated = False
+ app.root.run()
+ app.maybe_apply_changes()
+ assert app._has_updated
+ app._has_updated = False
+ app.root.run()
app.maybe_apply_changes()
assert not app._has_updated
@@ -896,6 +905,7 @@ def __init__(self, **kwargs):
def run(self, signal: int):
self.counter += 1
+ assert len(self._calls) == 2
class SizeFlow(LightningFlow):
@@ -919,3 +929,29 @@ def test_state_size_constant_growth():
MultiProcessRuntime(app, start_server=False).dispatch()
assert app.root._state_sizes[0] <= 5904
assert app.root._state_sizes[20] <= 23736
+
+
+class FlowUpdated(LightningFlow):
+ def run(self):
+ logger.info("Hello World")
+
+
+class NonUpdatedLightningTestApp(LightningTestApp):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.counter = 0
+
+ def on_after_run_once(self):
+ self.counter += 1
+ if not self._has_updated and self.counter > 2:
+ return True
+ return super().on_after_run_once()
+
+
+def test_non_updated_flow(caplog):
+ """This tests validate the app can run 3 times and call the flow only once."""
+ with caplog.at_level(logging.INFO):
+ app = NonUpdatedLightningTestApp(FlowUpdated())
+ MultiProcessRuntime(app, start_server=False).dispatch()
+ assert caplog.messages == ["Hello World"]
+ assert app.counter == 3
diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py
index e8ce1222a3186..4c0eb23ea014c 100644
--- a/tests/tests_app/core/test_lightning_flow.py
+++ b/tests/tests_app/core/test_lightning_flow.py
@@ -16,7 +16,7 @@
from lightning_app.storage import Path
from lightning_app.storage.path import storage_root_dir
from lightning_app.testing.helpers import EmptyFlow, EmptyWork
-from lightning_app.utilities.app_helpers import _delta_to_appstate_delta, _LightningAppRef
+from lightning_app.utilities.app_helpers import _delta_to_app_state_delta, _LightningAppRef
from lightning_app.utilities.enum import CacheCallsKeys
from lightning_app.utilities.exceptions import ExitAppException
@@ -416,7 +416,7 @@ def run(self):
flow_a.work.counter = 1
work_state_2 = flow_a.work.state
delta = Delta(DeepDiff(work_state, work_state_2, verbose_level=2))
- delta = _delta_to_appstate_delta(flow_a, flow_a.work, delta)
+ delta = _delta_to_app_state_delta(flow_a, flow_a.work, delta)
new_flow_state = LightningApp.populate_changes(flow_state, flow_state + delta)
flow_a.set_state(new_flow_state)
assert flow_a.work.counter == 1
@@ -592,24 +592,23 @@ def run(self):
class FlowSchedule(LightningFlow):
def __init__(self):
super().__init__()
- self._last_time = None
+ self._last_times = []
+ self.target = 3
+ self.seconds = ",".join([str(v) for v in range(0, 60, self.target)])
def run(self):
- if self.schedule("* * * * * 0,5,10,15,20,25,30,35,40,45,50,55"):
- if self._last_time is None:
- self._last_time = False
- elif not self._last_time:
- self._last_time = time()
+ if self.schedule(f"* * * * * {self.seconds}"):
+ if len(self._last_times) < 3:
+ self._last_times.append(time())
else:
- # TODO (tchaton) Optimize flow execution.
- assert 4.0 < abs(time() - self._last_time) < 6.0
+ assert abs((time() - self._last_times[-1]) - self.target) < 3
self._exit()
def test_scheduling_api():
app = LightningApp(FlowSchedule())
- MultiProcessRuntime(app).dispatch()
+ MultiProcessRuntime(app, start_server=True).dispatch()
def test_lightning_flow():
diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py
index 4b1cf08e8554d..640eb9c114c2d 100644
--- a/tests/tests_app/runners/test_cloud.py
+++ b/tests/tests_app/runners/test_cloud.py
@@ -1,4 +1,5 @@
import logging
+from copy import copy
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock
@@ -9,21 +10,29 @@
Gridv1ImageSpec,
V1BuildSpec,
V1DependencyFileInfo,
+ V1Drive,
+ V1DriveSpec,
+ V1DriveStatus,
+ V1DriveType,
V1LightningappInstanceState,
+ V1LightningworkDrives,
V1LightningworkSpec,
V1ListLightningappInstancesResponse,
V1ListMembershipsResponse,
V1Membership,
+ V1Metadata,
V1NetworkConfig,
V1PackageManager,
V1ProjectClusterBinding,
V1PythonDependencyInfo,
+ V1SourceType,
V1UserRequestedComputeConfig,
V1Work,
)
from lightning_app import LightningApp, LightningWork
from lightning_app.runners import backends, cloud
+from lightning_app.storage import Drive
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.dependency_caching import get_hash
@@ -33,6 +42,25 @@ def run(self):
print("my run")
+class WorkWithSingleDrive(LightningWork):
+ def __init__(self):
+ super().__init__()
+ self.drive = None
+
+ def run(self):
+ pass
+
+
+class WorkWithTwoDrives(LightningWork):
+ def __init__(self):
+ super().__init__()
+ self.lit_drive = None
+ self.s3_drive = None
+
+ def run(self):
+ pass
+
+
class TestAppCreationClient:
"""Testing the calls made using GridRestClient to create the app."""
@@ -250,6 +278,134 @@ def test_call_with_work_app(self, lightningapps, monkeypatch, tmpdir):
),
image="random_base_public_image",
),
+ drives=[],
+ user_requested_compute_config=V1UserRequestedComputeConfig(
+ name="default", count=1, disk_size=0, preemptible=False, shm_size=0
+ ),
+ network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
+ ),
+ )
+ ],
+ )
+ mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
+ "test-project-id", mock.ANY, expected_body
+ )
+
+ # running dispatch with disabled dependency cache
+ mock_client.reset_mock()
+ monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True)
+ expected_body.dependency_cache_key = None
+ cloud_runtime.dispatch()
+ mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
+ "test-project-id", mock.ANY, expected_body
+ )
+ else:
+ mock_client.lightningapp_v2_service_create_lightningapp_release_instance.assert_called_once_with(
+ "test-project-id", mock.ANY, mock.ANY, mock.ANY
+ )
+
+ @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock())
+ @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
+ def test_call_with_work_app_and_attached_drives(self, lightningapps, monkeypatch, tmpdir):
+ source_code_root_dir = Path(tmpdir / "src").absolute()
+ source_code_root_dir.mkdir()
+ Path(source_code_root_dir / ".lightning").write_text("name: myapp")
+ requirements_file = Path(source_code_root_dir / "requirements.txt")
+ Path(requirements_file).touch()
+
+ mock_client = mock.MagicMock()
+ if lightningapps:
+ lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED
+ mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
+ V1ListLightningappInstancesResponse(lightningapps=lightningapps)
+ )
+ lightning_app_instance = MagicMock()
+ mock_client.lightningapp_v2_service_create_lightningapp_release = MagicMock(return_value=lightning_app_instance)
+ mock_client.lightningapp_v2_service_create_lightningapp_release_instance = MagicMock(
+ return_value=lightning_app_instance
+ )
+ existing_instance = MagicMock()
+ existing_instance.status.phase = V1LightningappInstanceState.STOPPED
+ mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance)
+ cloud_backend = mock.MagicMock()
+ cloud_backend.client = mock_client
+ monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
+ monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
+ monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
+ app = mock.MagicMock()
+ flow = mock.MagicMock()
+
+ mocked_drive = MagicMock(spec=Drive)
+ setattr(mocked_drive, "id", "foobar")
+ setattr(mocked_drive, "protocol", "lit://")
+ setattr(mocked_drive, "component_name", "test-work")
+ setattr(mocked_drive, "allow_duplicates", False)
+ setattr(mocked_drive, "root_folder", tmpdir)
+ # deepcopy on a MagicMock instance will return an empty magicmock instance. To
+ # overcome this we set the __deepcopy__ method `return_value` to equal what
+ # should be the results of the deepcopy operation (an instance of the original class)
+ mocked_drive.__deepcopy__.return_value = copy(mocked_drive)
+
+ work = WorkWithSingleDrive()
+ monkeypatch.setattr(work, "drive", mocked_drive)
+ monkeypatch.setattr(work, "_state", {"_port", "drive"})
+ monkeypatch.setattr(work, "_name", "test-work")
+ monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"])
+ monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"])
+ monkeypatch.setattr(work._cloud_build_config, "image", "random_base_public_image")
+ monkeypatch.setattr(work._cloud_compute, "disk_size", 0)
+ monkeypatch.setattr(work._cloud_compute, "preemptible", False)
+ monkeypatch.setattr(work, "_port", 8080)
+
+ flow.works = lambda recurse: [work]
+ app.flows = [flow]
+ cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py"))
+ monkeypatch.setattr(
+ "lightning_app.runners.cloud._get_project",
+ lambda x: V1Membership(name="test-project", project_id="test-project-id"),
+ )
+ cloud_runtime.dispatch()
+
+ if lightningapps:
+ expected_body = Body8(
+ description=None,
+ local_source=True,
+ app_entrypoint_file="entrypoint.py",
+ enable_app_server=True,
+ flow_servers=[],
+ dependency_cache_key=get_hash(requirements_file),
+ image_spec=Gridv1ImageSpec(
+ dependency_file_info=V1DependencyFileInfo(
+ package_manager=V1PackageManager.PIP, path="requirements.txt"
+ )
+ ),
+ works=[
+ V1Work(
+ name="test-work",
+ spec=V1LightningworkSpec(
+ build_spec=V1BuildSpec(
+ commands=["echo 'start'"],
+ python_dependencies=V1PythonDependencyInfo(
+ package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
+ ),
+ image="random_base_public_image",
+ ),
+ drives=[
+ V1LightningworkDrives(
+ drive=V1Drive(
+ metadata=V1Metadata(
+ name="test-work.drive",
+ ),
+ spec=V1DriveSpec(
+ drive_type=V1DriveType.NO_MOUNT_S3,
+ source_type=V1SourceType.S3,
+ source="lit://foobar",
+ ),
+ status=V1DriveStatus(),
+ ),
+ mount_location=str(tmpdir),
+ ),
+ ],
user_requested_compute_config=V1UserRequestedComputeConfig(
name="default", count=1, disk_size=0, preemptible=False, shm_size=0
),
@@ -275,6 +431,206 @@ def test_call_with_work_app(self, lightningapps, monkeypatch, tmpdir):
"test-project-id", mock.ANY, mock.ANY, mock.ANY
)
+ @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock())
+ @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
+ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, monkeypatch, tmpdir):
+ source_code_root_dir = Path(tmpdir / "src").absolute()
+ source_code_root_dir.mkdir()
+ Path(source_code_root_dir / ".lightning").write_text("name: myapp")
+ requirements_file = Path(source_code_root_dir / "requirements.txt")
+ Path(requirements_file).touch()
+
+ mock_client = mock.MagicMock()
+ if lightningapps:
+ lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED
+ mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
+ V1ListLightningappInstancesResponse(lightningapps=lightningapps)
+ )
+ lightning_app_instance = MagicMock()
+ mock_client.lightningapp_v2_service_create_lightningapp_release = MagicMock(return_value=lightning_app_instance)
+ mock_client.lightningapp_v2_service_create_lightningapp_release_instance = MagicMock(
+ return_value=lightning_app_instance
+ )
+ existing_instance = MagicMock()
+ existing_instance.status.phase = V1LightningappInstanceState.STOPPED
+ mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance)
+ cloud_backend = mock.MagicMock()
+ cloud_backend.client = mock_client
+ monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
+ monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
+ monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
+ app = mock.MagicMock()
+ flow = mock.MagicMock()
+
+ mocked_lit_drive = MagicMock(spec=Drive)
+ setattr(mocked_lit_drive, "id", "foobar")
+ setattr(mocked_lit_drive, "protocol", "lit://")
+ setattr(mocked_lit_drive, "component_name", "test-work")
+ setattr(mocked_lit_drive, "allow_duplicates", False)
+ setattr(mocked_lit_drive, "root_folder", tmpdir)
+ # deepcopy on a MagicMock instance will return an empty magicmock instance. To
+ # overcome this we set the __deepcopy__ method `return_value` to equal what
+ # should be the results of the deepcopy operation (an instance of the original class)
+ mocked_lit_drive.__deepcopy__.return_value = copy(mocked_lit_drive)
+
+ mocked_s3_drive = MagicMock(spec=Drive)
+ setattr(mocked_s3_drive, "id", "some-bucket/path/")
+ setattr(mocked_s3_drive, "protocol", "s3://")
+ setattr(mocked_s3_drive, "component_name", "test-work")
+ setattr(mocked_s3_drive, "allow_duplicates", False)
+ setattr(mocked_s3_drive, "root_folder", "/hello/")
+ # deepcopy on a MagicMock instance will return an empty magicmock instance. To
+ # overcome this we set the __deepcopy__ method `return_value` to equal what
+ # should be the results of the deepcopy operation (an instance of the original class)
+ mocked_s3_drive.__deepcopy__.return_value = copy(mocked_s3_drive)
+
+ work = WorkWithTwoDrives()
+ monkeypatch.setattr(work, "lit_drive", mocked_lit_drive)
+ monkeypatch.setattr(work, "s3_drive", mocked_s3_drive)
+ monkeypatch.setattr(work, "_state", {"_port", "_name", "lit_drive", "s3_drive"})
+ monkeypatch.setattr(work, "_name", "test-work")
+ monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"])
+ monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"])
+ monkeypatch.setattr(work._cloud_build_config, "image", "random_base_public_image")
+ monkeypatch.setattr(work._cloud_compute, "disk_size", 0)
+ monkeypatch.setattr(work._cloud_compute, "preemptible", False)
+ monkeypatch.setattr(work, "_port", 8080)
+
+ flow.works = lambda recurse: [work]
+ app.flows = [flow]
+ cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py"))
+ monkeypatch.setattr(
+ "lightning_app.runners.cloud._get_project",
+ lambda x: V1Membership(name="test-project", project_id="test-project-id"),
+ )
+ cloud_runtime.dispatch()
+
+ if lightningapps:
+ s3_drive_spec = V1LightningworkDrives(
+ drive=V1Drive(
+ metadata=V1Metadata(
+ name="test-work.s3_drive",
+ ),
+ spec=V1DriveSpec(
+ drive_type=V1DriveType.INDEXED_S3,
+ source_type=V1SourceType.S3,
+ source="s3://some-bucket/path/",
+ ),
+ status=V1DriveStatus(),
+ ),
+ mount_location="/hello/",
+ )
+ lit_drive_spec = V1LightningworkDrives(
+ drive=V1Drive(
+ metadata=V1Metadata(
+ name="test-work.lit_drive",
+ ),
+ spec=V1DriveSpec(
+ drive_type=V1DriveType.NO_MOUNT_S3,
+ source_type=V1SourceType.S3,
+ source="lit://foobar",
+ ),
+ status=V1DriveStatus(),
+ ),
+ mount_location=str(tmpdir),
+ )
+
+ # order of drives in the spec is non-deterministic, so there are two options
+ # depending for the expected body value on which drive is ordered in the list first.
+
+ expected_body_option_1 = Body8(
+ description=None,
+ local_source=True,
+ app_entrypoint_file="entrypoint.py",
+ enable_app_server=True,
+ flow_servers=[],
+ dependency_cache_key=get_hash(requirements_file),
+ image_spec=Gridv1ImageSpec(
+ dependency_file_info=V1DependencyFileInfo(
+ package_manager=V1PackageManager.PIP, path="requirements.txt"
+ )
+ ),
+ works=[
+ V1Work(
+ name="test-work",
+ spec=V1LightningworkSpec(
+ build_spec=V1BuildSpec(
+ commands=["echo 'start'"],
+ python_dependencies=V1PythonDependencyInfo(
+ package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
+ ),
+ image="random_base_public_image",
+ ),
+ drives=[lit_drive_spec, s3_drive_spec],
+ user_requested_compute_config=V1UserRequestedComputeConfig(
+ name="default", count=1, disk_size=0, preemptible=False, shm_size=0
+ ),
+ network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
+ ),
+ )
+ ],
+ )
+
+ expected_body_option_2 = Body8(
+ description=None,
+ local_source=True,
+ app_entrypoint_file="entrypoint.py",
+ enable_app_server=True,
+ flow_servers=[],
+ dependency_cache_key=get_hash(requirements_file),
+ image_spec=Gridv1ImageSpec(
+ dependency_file_info=V1DependencyFileInfo(
+ package_manager=V1PackageManager.PIP, path="requirements.txt"
+ )
+ ),
+ works=[
+ V1Work(
+ name="test-work",
+ spec=V1LightningworkSpec(
+ build_spec=V1BuildSpec(
+ commands=["echo 'start'"],
+ python_dependencies=V1PythonDependencyInfo(
+ package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
+ ),
+ image="random_base_public_image",
+ ),
+ drives=[s3_drive_spec, lit_drive_spec],
+ user_requested_compute_config=V1UserRequestedComputeConfig(
+ name="default", count=1, disk_size=0, preemptible=False, shm_size=0
+ ),
+ network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
+ ),
+ )
+ ],
+ )
+
+ # try both options for the expected body to avoid false
+ # positive test failures depending on system randomness
+
+ expected_body = expected_body_option_1
+ try:
+ mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
+ "test-project-id", mock.ANY, expected_body
+ )
+ except Exception:
+ expected_body = expected_body_option_2
+ mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
+ "test-project-id", mock.ANY, expected_body
+ )
+
+ # running dispatch with disabled dependency cache
+ mock_client.reset_mock()
+ monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True)
+ expected_body.dependency_cache_key = None
+ cloud_runtime.dispatch()
+ mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
+ "test-project-id", mock.ANY, expected_body
+ )
+ else:
+ mock_client.lightningapp_v2_service_create_lightningapp_release_instance.assert_called_once_with(
+ "test-project-id", mock.ANY, mock.ANY, mock.ANY
+ )
+
@mock.patch("lightning_app.core.queues.QueuingSystem", MagicMock())
@mock.patch("lightning_app.runners.backends.cloud.LightningClient", MagicMock())
diff --git a/tests/tests_app/storage/test_drive.py b/tests/tests_app/storage/test_drive.py
index 3d9db44c10e13..0d452571d9f43 100644
--- a/tests/tests_app/storage/test_drive.py
+++ b/tests/tests_app/storage/test_drive.py
@@ -11,7 +11,7 @@
from lightning_app.utilities.component import _set_flow_context
-class SyncWorkA(LightningWork):
+class SyncWorkLITDriveA(LightningWork):
def __init__(self, tmpdir):
super().__init__()
self.tmpdir = tmpdir
@@ -25,19 +25,19 @@ def run(self, drive: Drive):
os.remove(f"{self.tmpdir}/a.txt")
-class SyncWorkB(LightningWork):
+class SyncWorkLITDriveB(LightningWork):
def run(self, drive: Drive):
assert not os.path.exists("a.txt")
drive.get("a.txt")
assert os.path.exists("a.txt")
-class SyncFlow(LightningFlow):
+class SyncFlowLITDrives(LightningFlow):
def __init__(self, tmpdir):
super().__init__()
self.log_dir = Drive("lit://log_dir")
- self.work_a = SyncWorkA(str(tmpdir))
- self.work_b = SyncWorkB()
+ self.work_a = SyncWorkLITDriveA(str(tmpdir))
+ self.work_b = SyncWorkLITDriveB()
def run(self):
self.work_a.run(self.log_dir)
@@ -45,15 +45,15 @@ def run(self):
self._exit()
-def test_synchronization_drive(tmpdir):
+def test_synchronization_lit_drive(tmpdir):
if os.path.exists("a.txt"):
os.remove("a.txt")
- app = LightningApp(SyncFlow(tmpdir))
+ app = LightningApp(SyncFlowLITDrives(tmpdir))
MultiProcessRuntime(app, start_server=False).dispatch()
os.remove("a.txt")
-class Work(LightningWork):
+class LITDriveWork(LightningWork):
def __init__(self):
super().__init__(parallel=True)
self.drive = None
@@ -75,7 +75,7 @@ def run(self, *args, **kwargs):
self.counter += 1
-class Work2(LightningWork):
+class LITDriveWork2(LightningWork):
def __init__(self):
super().__init__(parallel=True)
@@ -86,11 +86,11 @@ def run(self, drive: Drive, **kwargs):
assert drive.list(".", component_name=self.name) == []
-class Flow(LightningFlow):
+class LITDriveFlow(LightningFlow):
def __init__(self):
super().__init__()
- self.work = Work()
- self.work2 = Work2()
+ self.work = LITDriveWork()
+ self.work2 = LITDriveWork2()
def run(self):
self.work.run("0")
@@ -102,15 +102,15 @@ def run(self):
self._exit()
-def test_drive_transferring_files():
- app = LightningApp(Flow())
+def test_lit_drive_transferring_files():
+ app = LightningApp(LITDriveFlow())
MultiProcessRuntime(app, start_server=False).dispatch()
os.remove("a.txt")
-def test_drive():
- with pytest.raises(Exception, match="The Drive id needs to start with one of the following protocols"):
- Drive("this_drive_id")
+def test_lit_drive():
+ with pytest.raises(Exception, match="Unknown protocol for the drive 'id' argument"):
+ Drive("invalid_drive_id")
with pytest.raises(
Exception, match="The id should be unique to identify your drive. Found `this_drive_id/something_else`."
@@ -213,9 +213,46 @@ def test_drive():
os.remove("a.txt")
-def test_maybe_create_drive():
+def test_s3_drives():
+ drive = Drive("s3://foo/", allow_duplicates=True)
+ drive.component_name = "root.work"
- drive = Drive("lit://drive_3", allow_duplicates=False)
+ with pytest.raises(
+ Exception, match="S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?"
+ ):
+ drive.put("a.txt")
+ with pytest.raises(
+ Exception,
+ match="S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?",
+ ):
+ drive.list("a.txt")
+ with pytest.raises(
+ Exception, match="S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?"
+ ):
+ drive.get("a.txt")
+ with pytest.raises(
+ Exception,
+ match="S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?",
+ ):
+ drive.delete("a.txt")
+
+ _set_flow_context()
+ with pytest.raises(Exception, match="The flow isn't allowed to put files into a Drive."):
+ drive.put("a.txt")
+ with pytest.raises(Exception, match="The flow isn't allowed to list files from a Drive."):
+ drive.list("a.txt")
+ with pytest.raises(Exception, match="The flow isn't allowed to get files from a Drive."):
+ drive.get("a.txt")
+
+
+def test_create_s3_drive_without_trailing_slash_fails():
+ with pytest.raises(ValueError, match="S3 drives must end in a trailing slash"):
+ Drive("s3://foo")
+
+
+@pytest.mark.parametrize("drive_id", ["lit://drive", "s3://drive/"])
+def test_maybe_create_drive(drive_id):
+ drive = Drive(drive_id, allow_duplicates=False)
drive.component_name = "root.work1"
new_drive = _maybe_create_drive(drive.component_name, drive.to_dict())
assert new_drive.protocol == drive.protocol
@@ -223,9 +260,9 @@ def test_maybe_create_drive():
assert new_drive.component_name == drive.component_name
-def test_drive_deepcopy():
-
- drive = Drive("lit://drive", allow_duplicates=True)
+@pytest.mark.parametrize("drive_id", ["lit://drive", "s3://drive/"])
+def test_drive_deepcopy(drive_id):
+ drive = Drive(drive_id, allow_duplicates=True)
drive.component_name = "root.work1"
new_drive = deepcopy(drive)
assert new_drive.id == drive.id
diff --git a/tests/tests_app/utilities/packaging/test_lightning_utils.py b/tests/tests_app/utilities/packaging/test_lightning_utils.py
index b34e3162d5a0c..8f30aa21dd396 100644
--- a/tests/tests_app/utilities/packaging/test_lightning_utils.py
+++ b/tests/tests_app/utilities/packaging/test_lightning_utils.py
@@ -1,4 +1,5 @@
import os
+from unittest import mock
import pytest
@@ -21,6 +22,21 @@ def test_prepare_lightning_wheels_and_requirement(tmpdir):
assert os.listdir(tmpdir) == []
+def _mocked_get_dist_path_if_editable_install(*args, **kwargs):
+ return None
+
+
+@mock.patch(
+ "lightning_app.utilities.packaging.lightning_utils.get_dist_path_if_editable_install",
+ new=_mocked_get_dist_path_if_editable_install,
+)
+def test_prepare_lightning_wheels_and_requirement_for_packages_installed_in_editable_mode(tmpdir):
+ """This test ensures the source does not get packaged inside the lightning repo if not installed in editable
+ mode."""
+ cleanup_handle = _prepare_lightning_wheels_and_requirements(tmpdir)
+ assert cleanup_handle is None
+
+
@pytest.mark.skip(reason="TODO: Find a way to check for the latest version")
@RunIf(skip_windows=True)
def test_verify_lightning_version(monkeypatch):
diff --git a/tests/tests_app/utilities/test_app_logs.py b/tests/tests_app/utilities/test_app_logs.py
new file mode 100644
index 0000000000000..7a0fe087e7c29
--- /dev/null
+++ b/tests/tests_app/utilities/test_app_logs.py
@@ -0,0 +1,13 @@
+from datetime import datetime
+from time import sleep
+from unittest.mock import MagicMock
+
+from lightning_app.utilities.app_logs import _LogEvent
+
+
+def test_log_event():
+ event_1 = _LogEvent("", datetime.now(), MagicMock(), MagicMock())
+ sleep(0.1)
+ event_2 = _LogEvent("", datetime.now(), MagicMock(), MagicMock())
+ assert event_1 < event_2
+ assert event_1 <= event_2
diff --git a/tests/tests_app/utilities/test_commands.py b/tests/tests_app/utilities/test_commands.py
index 1e8e36ed09545..1be35a3a2e290 100644
--- a/tests/tests_app/utilities/test_commands.py
+++ b/tests/tests_app/utilities/test_commands.py
@@ -14,7 +14,7 @@
from lightning_app.core.constants import APP_SERVER_PORT
from lightning_app.runners import MultiProcessRuntime
from lightning_app.testing.helpers import RunIf
-from lightning_app.utilities.commands.base import _command_to_method_and_metadata, _download_command, ClientCommand
+from lightning_app.utilities.commands.base import _download_command, _validate_client_command, ClientCommand
from lightning_app.utilities.state import AppState
@@ -25,7 +25,6 @@ class SweepConfig(BaseModel):
class SweepCommand(ClientCommand):
def run(self) -> None:
- print(sys.argv)
parser = argparse.ArgumentParser()
parser.add_argument("--sweep_name", type=str)
parser.add_argument("--num_trials", type=int)
@@ -44,7 +43,7 @@ def __init__(self):
def run(self):
if self.has_sweep and len(self.names) == 1:
- sleep(2)
+ sleep(1)
self._exit()
def trigger_method(self, name: str):
@@ -91,15 +90,15 @@ def run_failure_2(name: CustomModel):
@RunIf(skip_windows=True)
-def test_command_to_method_and_metadata():
+def test_validate_client_command():
with pytest.raises(Exception, match="The provided annotation for the argument name"):
- _command_to_method_and_metadata(ClientCommand(run_failure_0))
+ _validate_client_command(ClientCommand(run_failure_0))
with pytest.raises(Exception, match="annotate your method"):
- _command_to_method_and_metadata(ClientCommand(run_failure_1))
+ _validate_client_command(ClientCommand(run_failure_1))
with pytest.raises(Exception, match="lightning_app/utilities/commands/base.py"):
- _command_to_method_and_metadata(ClientCommand(run_failure_2))
+ _validate_client_command(ClientCommand(run_failure_2))
def test_client_commands(monkeypatch):
@@ -115,17 +114,13 @@ def test_client_commands(monkeypatch):
url = "http//"
kwargs = {"something": "1", "something_else": "1"}
command = DummyCommand(run)
- _, command_metadata = _command_to_method_and_metadata(command)
- command_metadata.update(
- {
- "command": "dummy",
- "affiliation": "root",
- "is_client_command": True,
- "owner": "root",
- }
+ _validate_client_command(command)
+ client_command = _download_command(
+ command_name="something",
+ cls_path=__file__,
+ cls_name="DummyCommand",
)
- client_command, models = _download_command(command_metadata, None)
- client_command._setup(metadata=command_metadata, models=models, app_url=url)
+ client_command._setup("something", app_url=url)
client_command.run(**kwargs)
@@ -153,10 +148,12 @@ def test_configure_commands(monkeypatch):
state = AppState()
state._request_state()
assert state.names == ["something"]
- monkeypatch.setattr(sys, "argv", ["lightning", "sweep", "--sweep_name", "my_name", "--num_trials", "1"])
+ monkeypatch.setattr(sys, "argv", ["lightning", "sweep", "--sweep_name=my_name", "--num_trials=1"])
app_command()
time_left = 15
- while time_left > 0 or process.exitcode is None:
+ while time_left > 0:
+ if process.exitcode == 0:
+ break
sleep(0.1)
time_left -= 0.1
assert process.exitcode == 0
diff --git a/tests/tests_app_examples/test_boring_app.py b/tests/tests_app_examples/test_boring_app.py
index 1f681260de5c2..afb958571d16b 100644
--- a/tests/tests_app_examples/test_boring_app.py
+++ b/tests/tests_app_examples/test_boring_app.py
@@ -1,8 +1,10 @@
import os
import pytest
+from click.testing import CliRunner
from tests_app import _PROJECT_ROOT
+from lightning_app.cli.lightning_cli import logs
from lightning_app.testing.testing import run_app_in_cloud, wait_for
@@ -11,7 +13,8 @@ def test_boring_app_example_cloud() -> None:
with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_boring/"), app_name="app_dynamic.py") as (
_,
view_page,
- _,
+ fetch_logs,
+ name,
):
def check_hello_there(*_, **__):
@@ -21,3 +24,14 @@ def check_hello_there(*_, **__):
return True
wait_for(view_page, check_hello_there)
+
+ for _ in fetch_logs():
+ pass
+
+ runner = CliRunner()
+ result = runner.invoke(logs, [name])
+ lines = result.output.splitlines()
+
+ assert result.exit_code == 0
+ assert result.exception is None
+ assert any("http://0.0.0.0:8080" in line for line in lines)
diff --git a/tests/tests_app_examples/test_collect_failures.py b/tests/tests_app_examples/test_collect_failures.py
index f263ebb1a9f58..c149211e10774 100644
--- a/tests/tests_app_examples/test_collect_failures.py
+++ b/tests/tests_app_examples/test_collect_failures.py
@@ -26,6 +26,7 @@ def test_collect_failures_example_cloud() -> None:
_,
_,
fetch_logs,
+ _,
):
last_found_log_index = -1
while len(expected_logs) != 0:
diff --git a/tests/tests_app_examples/test_commands.py b/tests/tests_app_examples/test_commands.py
deleted file mode 100644
index 5116b1b9d54bb..0000000000000
--- a/tests/tests_app_examples/test_commands.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import os
-from subprocess import Popen
-from time import sleep
-from unittest import mock
-
-import pytest
-from tests_app import _PROJECT_ROOT
-
-from lightning_app.testing.testing import run_app_in_cloud
-
-
-@mock.patch.dict(os.environ, {"SKIP_LIGHTING_UTILITY_WHEELS_BUILD": "0"})
-@pytest.mark.cloud
-def test_commands_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_commands")) as (
- admin_page,
- _,
- fetch_logs,
- ):
- app_id = admin_page.url.split("/")[-1]
- cmd = f"lightning trigger_with_client_command --name=something --app_id {app_id}"
- Popen(cmd, shell=True).wait()
- cmd = f"lightning trigger_without_client_command --name=else --app_id {app_id}"
- Popen(cmd, shell=True).wait()
-
- has_logs = False
- while not has_logs:
- for log in fetch_logs():
- if "['something', 'else']" in log:
- has_logs = True
- sleep(1)
diff --git a/tests/tests_app_examples/test_commands_and_api.py b/tests/tests_app_examples/test_commands_and_api.py
new file mode 100644
index 0000000000000..8d84cf4847ebd
--- /dev/null
+++ b/tests/tests_app_examples/test_commands_and_api.py
@@ -0,0 +1,42 @@
+import os
+from subprocess import Popen
+from time import sleep
+
+import pytest
+import requests
+from tests_app import _PROJECT_ROOT
+
+from lightning_app.testing.testing import run_app_in_cloud
+
+
+@pytest.mark.cloud
+def test_commands_and_api_example_cloud() -> None:
+ with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_commands_and_api")) as (
+ admin_page,
+ view_page,
+ fetch_logs,
+ _,
+ ):
+ # 1: Collect the app_id
+ app_id = admin_page.url.split("/")[-1]
+
+ # 2: Send the first command with the client
+ cmd = f"lightning command_with_client --name=this --app_id {app_id}"
+ Popen(cmd, shell=True).wait()
+
+ # 3: Send the second command without a client
+ cmd = f"lightning command_without_client --name=is --app_id {app_id}"
+ Popen(cmd, shell=True).wait()
+
+ # 4: Send a request to the Rest API directly.
+ base_url = view_page.url.replace("/view", "").replace("/child_flow", "")
+ resp = requests.post(base_url + "/user/command_without_client?name=awesome")
+ assert resp.status_code == 200, resp.json()
+
+ # 5: Validate the logs.
+ has_logs = False
+ while not has_logs:
+ for log in fetch_logs():
+ if "['this', 'is', 'awesome']" in log:
+ has_logs = True
+ sleep(1)
diff --git a/tests/tests_app_examples/test_custom_work_dependencies.py b/tests/tests_app_examples/test_custom_work_dependencies.py
index 8390233e2eee3..b8971e0ef2148 100644
--- a/tests/tests_app_examples/test_custom_work_dependencies.py
+++ b/tests/tests_app_examples/test_custom_work_dependencies.py
@@ -13,10 +13,10 @@ def test_custom_work_dependencies_example_cloud() -> None:
with run_app_in_cloud(
os.path.join(_PROJECT_ROOT, "tests/tests_app_examples/custom_work_dependencies/"),
app_name="app.py",
- ) as (_, _, fetch_logs):
+ ) as (_, _, fetch_logs, _):
has_logs = False
while not has_logs:
- for log in fetch_logs():
+ for log in fetch_logs(["flow"]):
if "Custom Work Dependency checker End" in log:
has_logs = True
sleep(1)
diff --git a/tests/tests_app_examples/test_drive.py b/tests/tests_app_examples/test_drive.py
index 9cebca9cf1072..dde68d1a85113 100644
--- a/tests/tests_app_examples/test_drive.py
+++ b/tests/tests_app_examples/test_drive.py
@@ -11,8 +11,9 @@
def test_drive_example_cloud() -> None:
with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_drive")) as (
_,
- view_page,
+ _,
fetch_logs,
+ _,
):
has_logs = False
diff --git a/tests/tests_app_examples/test_idle_timeout.py b/tests/tests_app_examples/test_idle_timeout.py
index fb58a83aefc93..f06181ce86ed3 100644
--- a/tests/tests_app_examples/test_idle_timeout.py
+++ b/tests/tests_app_examples/test_idle_timeout.py
@@ -13,10 +13,11 @@ def test_idle_timeout_example_cloud() -> None:
_,
_,
fetch_logs,
+ _,
):
has_logs = False
while not has_logs:
- for log in fetch_logs():
+ for log in fetch_logs(["flow"]):
if "Application End" in log:
has_logs = True
sleep(1)
diff --git a/tests/tests_app_examples/test_payload.py b/tests/tests_app_examples/test_payload.py
index 28d2391c18a2a..b40b8ca52defd 100644
--- a/tests/tests_app_examples/test_payload.py
+++ b/tests/tests_app_examples/test_payload.py
@@ -9,11 +9,11 @@
@pytest.mark.cloud
def test_payload_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_payload")) as (_, _, fetch_logs):
+ with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_payload")) as (_, _, fetch_logs, _):
has_logs = False
while not has_logs:
- for log in fetch_logs():
+ for log in fetch_logs(["flow"]):
if "Application End!" in log:
has_logs = True
sleep(1)
diff --git a/tests/tests_app_examples/test_quick_start.py b/tests/tests_app_examples/test_quick_start.py
index 9db693a5dc3d6..454c1084ca1bb 100644
--- a/tests/tests_app_examples/test_quick_start.py
+++ b/tests/tests_app_examples/test_quick_start.py
@@ -51,7 +51,7 @@ def test_quick_start_example(caplog, monkeypatch):
@pytest.mark.cloud
def test_quick_start_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "lightning-quick-start/")) as (_, view_page, _):
+ with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "lightning-quick-start/")) as (_, view_page, _, _):
def click_gradio_demo(*_, **__):
button = view_page.locator('button:has-text("Interactive demo")')
diff --git a/tests/tests_app_examples/test_template_react_ui.py b/tests/tests_app_examples/test_template_react_ui.py
index 2e348035fe6e5..4b4588d2397e5 100644
--- a/tests/tests_app_examples/test_template_react_ui.py
+++ b/tests/tests_app_examples/test_template_react_ui.py
@@ -14,6 +14,7 @@ def test_template_react_ui_example_cloud() -> None:
_,
view_page,
fetch_logs,
+ _,
):
def click_button(*_, **__):
diff --git a/tests/tests_app_examples/test_template_streamlit_ui.py b/tests/tests_app_examples/test_template_streamlit_ui.py
index a8ba93794f2a0..e2c33305298f7 100644
--- a/tests/tests_app_examples/test_template_streamlit_ui.py
+++ b/tests/tests_app_examples/test_template_streamlit_ui.py
@@ -14,6 +14,7 @@ def test_template_streamlit_ui_example_cloud() -> None:
_,
view_page,
fetch_logs,
+ _,
):
def click_button(*_, **__):
diff --git a/tests/tests_app_examples/test_v0_app.py b/tests/tests_app_examples/test_v0_app.py
index d34a92d6102f8..026c45a4e1ba1 100644
--- a/tests/tests_app_examples/test_v0_app.py
+++ b/tests/tests_app_examples/test_v0_app.py
@@ -45,7 +45,7 @@ def check_content(button_name, text_content):
wait_for(view_page, check_content, "TAB_2", "Hello from component B")
has_logs = False
while not has_logs:
- for log in fetch_logs():
+ for log in fetch_logs(["flow"]):
if "'a': 'a', 'b': 'b'" in log:
has_logs = True
sleep(1)
@@ -74,5 +74,6 @@ def test_v0_app_example_cloud() -> None:
_,
view_page,
fetch_logs,
+ _,
):
run_v0_app(fetch_logs, view_page)
diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py
index 33d59d9a835ca..db3b9d1f91952 100644
--- a/tests/tests_pytorch/accelerators/test_ipu.py
+++ b/tests/tests_pytorch/accelerators/test_ipu.py
@@ -185,7 +185,7 @@ def test_optimization(tmpdir):
@RunIf(ipu=True)
-def test_mixed_precision(tmpdir):
+def test_half_precision(tmpdir):
class TestCallback(Callback):
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
assert trainer.strategy.model.precision == 16
diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
index e9374f8ea4be1..f1ccf2a2726a2 100644
--- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
@@ -21,8 +21,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
-from tests_pytorch.helpers.datasets import RandomIterableDataset
+from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
index 859cf2fa98c0c..7f1692e30a3f2 100644
--- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
+++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
@@ -12,25 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import os
+from pathlib import Path
+from typing import ContextManager, Optional
from unittest import mock
import pytest
import torch
from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
from torch.optim.swa_utils import SWALR
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
+from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from pytorch_lightning.strategies import DDPSpawnStrategy, Strategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests_pytorch.helpers.datasets import RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf
class SwaTestModel(BoringModel):
- def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False):
+ def __init__(
+ self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None
+ ):
super().__init__()
layers = [nn.Linear(32, 32)]
if batchnorm:
@@ -39,17 +44,18 @@ def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dat
self.layer = nn.Sequential(*layers)
self.interval = interval
self.iterable_dataset = iterable_dataset
+ self.crash_on_epoch = crash_on_epoch
def training_step(self, batch, batch_idx):
+ if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
+ raise Exception("SWA crash test")
output = self.forward(batch)
loss = self.loss(batch, output)
return {"loss": loss}
def train_dataloader(self):
-
dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
dset = dset_cls(32, 64)
-
return DataLoader(dset, batch_size=2)
def configure_optimizers(self):
@@ -66,6 +72,8 @@ def configure_optimizers(self):
class SwaTestCallback(StochasticWeightAveraging):
update_parameters_calls: int = 0
transfer_weights_calls: int = 0
+ # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0
+ first_epoch: Optional[int] = None
def update_parameters(self, *args, **kwargs):
self.update_parameters_calls += 1
@@ -77,6 +85,11 @@ def transfer_weights(self, *args, **kwargs):
def on_train_epoch_start(self, trainer, *args):
super().on_train_epoch_start(trainer, *args)
+ if self.first_epoch is None and not trainer.fit_loop.restarting:
+ # since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will
+ # not update the model and just call the epoch-level hooks, for that reason, we check that we are not
+ # restarting before choosing the first epoch
+ self.first_epoch = trainer.current_epoch
assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
if self.swa_start <= trainer.current_epoch:
assert isinstance(trainer.lr_scheduler_configs[0].scheduler, SWALR)
@@ -88,6 +101,7 @@ def on_train_epoch_end(self, trainer, *args):
if self.swa_start <= trainer.current_epoch <= self.swa_end:
swa_epoch = trainer.current_epoch - self.swa_start
assert self.n_averaged == swa_epoch + 1
+ assert self._swa_scheduler is not None
# Scheduler is stepped once on initialization and then at the end of each epoch
assert self._swa_scheduler._step_count == swa_epoch + 2
elif trainer.current_epoch > self.swa_end:
@@ -103,10 +117,13 @@ def on_train_end(self, trainer, pl_module):
if not isinstance(trainer.strategy, DDPSpawnStrategy):
# check backward call count. the batchnorm update epoch should not backward
- assert trainer.strategy.backward.call_count == trainer.max_epochs * trainer.limit_train_batches
+ assert trainer.strategy.backward.call_count == (
+ (trainer.max_epochs - self.first_epoch) * trainer.limit_train_batches
+ )
# check call counts
- assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1)
+ first_swa_epoch = max(self.first_epoch, self.swa_start)
+ assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch
assert self.transfer_weights_calls == 1
@@ -140,7 +157,7 @@ def train_with_swa(
devices=devices,
)
- with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward):
+ with _backward_patch(trainer):
trainer.fit(model)
# check the model is the expected
@@ -226,9 +243,10 @@ def test_swa_multiple_lrs(tmpdir):
class TestModel(BoringModel):
def __init__(self):
- super(BoringModel, self).__init__()
+ super().__init__()
self.layer1 = torch.nn.Linear(32, 32)
self.layer2 = torch.nn.Linear(32, 2)
+ self.on_train_epoch_start_called = False
def forward(self, x):
x = self.layer1(x)
@@ -255,3 +273,98 @@ def on_train_epoch_start(self):
)
trainer.fit(model)
assert model.on_train_epoch_start_called
+
+
+def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False):
+ swa_start = 3
+ trainer_kwargs = {
+ "default_root_dir": tmpdir,
+ "max_epochs": 5,
+ "accelerator": "cpu",
+ "strategy": "ddp_spawn_find_unused_parameters_false" if ddp else None,
+ "devices": 2 if ddp else 1,
+ "limit_train_batches": 5,
+ "limit_val_batches": 0,
+ "accumulate_grad_batches": 2,
+ "enable_progress_bar": False,
+ }
+ trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
+
+ with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"):
+ trainer.fit(model)
+
+ checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints"
+ checkpoint_files = os.listdir(checkpoint_dir)
+ assert len(checkpoint_files) == 1
+ ckpt_path = str(checkpoint_dir / checkpoint_files[0])
+
+ trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
+
+ with _backward_patch(trainer):
+ trainer.fit(resume_model, ckpt_path=ckpt_path)
+
+
+class CustomSchedulerModel(SwaTestModel):
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+ def lr_lambda(current_step: int):
+ return 0.1
+
+ scheduler = LambdaLR(optimizer, lr_lambda, -1)
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": scheduler,
+ "interval": self.interval,
+ },
+ }
+
+
+@pytest.mark.parametrize("crash_on_epoch", [1, 3])
+def test_swa_resume_training_from_checkpoint(tmpdir, crash_on_epoch):
+ model = SwaTestModel(crash_on_epoch=crash_on_epoch)
+ resume_model = SwaTestModel()
+ _swa_resume_training_from_checkpoint(tmpdir, model, resume_model)
+
+
+@pytest.mark.parametrize("crash_on_epoch", [1, 3])
+def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch):
+ # Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665
+ model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch)
+ resume_model = CustomSchedulerModel()
+ _swa_resume_training_from_checkpoint(tmpdir, model, resume_model)
+
+
+@RunIf(skip_windows=True)
+def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
+ model = SwaTestModel(crash_on_epoch=3)
+ resume_model = SwaTestModel()
+ _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True)
+
+
+@pytest.mark.parametrize(
+ "strategy",
+ [
+ pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)),
+ pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)),
+ ],
+)
+def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str):
+ model = SwaTestModel()
+ swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1)
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ enable_progress_bar=False,
+ max_epochs=5,
+ callbacks=[swa_callback],
+ strategy=strategy,
+ accelerator="gpu",
+ devices=1,
+ )
+ with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"):
+ trainer.fit(model)
+
+
+def _backward_patch(trainer: Trainer) -> ContextManager:
+ return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)
diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py
index aa6c1a615f9d2..91be34c55078f 100644
--- a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py
+++ b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py
@@ -36,7 +36,6 @@
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.enums import DeviceType, DistributedType
from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from tests_pytorch.deprecated_api import no_deprecated_call
@@ -44,18 +43,6 @@
from tests_pytorch.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
-def test_v1_8_0_deprecated_distributed_type_enum():
-
- with pytest.deprecated_call(match="has been deprecated in v1.6 and will be removed in v1.8."):
- _ = DistributedType.DDP
-
-
-def test_v1_8_0_deprecated_device_type_enum():
-
- with pytest.deprecated_call(match="has been deprecated in v1.6 and will be removed in v1.8."):
- _ = DeviceType.CPU
-
-
@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
def test_v1_8_0_deprecated_torchtext_batch():
diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py
index 3443020d4528f..c9d185313e85e 100644
--- a/tests/tests_pytorch/helpers/datasets.py
+++ b/tests/tests_pytorch/helpers/datasets.py
@@ -19,7 +19,7 @@
from typing import Optional, Sequence, Tuple
import torch
-from torch.utils.data import Dataset, IterableDataset
+from torch.utils.data import Dataset
class MNIST(Dataset):
@@ -212,40 +212,3 @@ def __getitem__(self, idx):
def __len__(self):
return len(self.y)
-
-
-class RandomDictDataset(Dataset):
- def __init__(self, size: int, length: int):
- self.len = length
- self.data = torch.randn(length, size)
-
- def __getitem__(self, index):
- a = self.data[index]
- b = a + 2
- return {"a": a, "b": b}
-
- def __len__(self):
- return self.len
-
-
-class RandomIterableDataset(IterableDataset):
- def __init__(self, size: int, count: int):
- self.count = count
- self.size = size
-
- def __iter__(self):
- for _ in range(self.count):
- yield torch.randn(self.size)
-
-
-class RandomIterableDatasetWithLen(IterableDataset):
- def __init__(self, size: int, count: int):
- self.count = count
- self.size = size
-
- def __iter__(self):
- for _ in range(len(self)):
- yield torch.randn(self.size)
-
- def __len__(self):
- return self.count
diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py
index 2215ab3129780..86a0a5a82195a 100644
--- a/tests/tests_pytorch/lite/test_lite.py
+++ b/tests/tests_pytorch/lite/test_lite.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import contextlib
import os
from copy import deepcopy
from unittest import mock
@@ -30,7 +29,6 @@
from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy
from pytorch_lightning.utilities import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _RequirementAvailable
from pytorch_lightning.utilities.seed import pl_worker_init_function
from tests_pytorch.helpers.runif import RunIf
@@ -480,13 +478,4 @@ def run(self):
assert self.broadcast(True)
assert self.is_global_zero == (self.local_rank == 0)
- if _RequirementAvailable("deepspeed>=0.6.5"):
- # https://github.com/microsoft/DeepSpeed/issues/2139
- raise_if_deepspeed_incompatible = pytest.raises(
- RuntimeError, match="DeepSpeed ZeRO-3 is not supported with this version of Lightning Lite"
- )
- else:
- raise_if_deepspeed_incompatible = contextlib.suppress()
-
- with raise_if_deepspeed_incompatible:
- Lite(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
+ Lite(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py
index c130381c7832d..84311d6f780fb 100644
--- a/tests/tests_pytorch/models/test_hparams.py
+++ b/tests/tests_pytorch/models/test_hparams.py
@@ -29,6 +29,7 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
@@ -399,6 +400,24 @@ def _raw_checkpoint_path(trainer) -> str:
return raw_checkpoint_path
+@pytest.mark.parametrize("base_class", (HyperparametersMixin, LightningModule, LightningDataModule))
+def test_save_hyperparameters_under_composition(base_class):
+ """Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
+ collected."""
+
+ class ChildInComposition(base_class):
+ def __init__(self, same_arg):
+ super().__init__()
+ self.save_hyperparameters()
+
+ class NotPLSubclass: # intentionally not subclassing LightningModule/LightningDataModule
+ def __init__(self, same_arg="parent_default", other_arg="other"):
+ self.child = ChildInComposition(same_arg="cocofruit")
+
+ parent = NotPLSubclass()
+ assert parent.child.hparams == dict(same_arg="cocofruit")
+
+
class LocalVariableModelSuperLast(BoringModel):
"""This model has the super().__init__() call at the end."""
diff --git a/tests/tests_pytorch/run_standalone_tasks.sh b/tests/tests_pytorch/run_standalone_tasks.sh
index 960bd867ceaa4..698ed7863ab96 100644
--- a/tests/tests_pytorch/run_standalone_tasks.sh
+++ b/tests/tests_pytorch/run_standalone_tasks.sh
@@ -34,6 +34,10 @@ fi
# test that a user can manually launch individual processes
echo "Running manual ddp launch test"
export PYTHONPATH="${PYTHONPATH}:$(pwd)"
-args="--trainer.accelerator gpu --trainer.devices 2 --trainer.strategy ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --trainer.limit_test_batches=1"
-MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python ../../examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py ${args} &
-MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python ../../examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py ${args}
+args="fit --trainer.accelerator gpu --trainer.devices 2 --trainer.strategy ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --trainer.limit_test_batches=1"
+MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python strategies/scripts/cli_script.py ${args} &
+MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python strategies/scripts/cli_script.py ${args}
+
+# test that ddp can launched as a module (-m option)
+echo "Running ddp example as module"
+python -m strategies.scripts.cli_script ${args}
diff --git a/tests/tests_pytorch/serve/__init__.py b/tests/tests_pytorch/serve/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tests_pytorch/strategies/ddp_model.py b/tests/tests_pytorch/strategies/ddp_model.py
deleted file mode 100644
index 76d1f3f2f6866..0000000000000
--- a/tests/tests_pytorch/strategies/ddp_model.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Runs either `.fit()` or `.test()` on a single node across multiple gpus."""
-import os
-from argparse import ArgumentParser
-
-import torch
-
-from pytorch_lightning import seed_everything, Trainer
-from tests_pytorch.helpers.datamodules import ClassifDataModule
-from tests_pytorch.helpers.simple_models import ClassificationModel
-
-
-def main():
- seed_everything(4321)
-
- parser = ArgumentParser(add_help=False)
- parser = Trainer.add_argparse_args(parser)
- parser.add_argument("--trainer_method", default="fit")
- parser.add_argument("--tmpdir")
- parser.add_argument("--workdir")
- parser.set_defaults(accelerator="gpu", devices=2)
- parser.set_defaults(strategy="ddp")
- args = parser.parse_args()
-
- dm = ClassifDataModule()
- model = ClassificationModel()
- trainer = Trainer.from_argparse_args(args)
-
- if args.trainer_method == "fit":
- trainer.fit(model, datamodule=dm)
- result = None
- elif args.trainer_method == "test":
- result = trainer.test(model, datamodule=dm)
- elif args.trainer_method == "fit_test":
- trainer.fit(model, datamodule=dm)
- result = trainer.test(model, datamodule=dm)
- else:
- raise ValueError(f"Unsupported: {args.trainer_method}")
-
- result_ext = {"status": "complete", "method": args.trainer_method, "result": result}
- file_path = os.path.join(args.tmpdir, "ddp.result")
- torch.save(result_ext, file_path)
-
-
-if __name__ == "__main__":
- main()
diff --git a/tests/tests_pytorch/strategies/scripts/__init__.py b/tests/tests_pytorch/strategies/scripts/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tests_pytorch/strategies/scripts/cli_script.py b/tests/tests_pytorch/strategies/scripts/cli_script.py
new file mode 100644
index 0000000000000..17f0d29392eb9
--- /dev/null
+++ b/tests/tests_pytorch/strategies/scripts/cli_script.py
@@ -0,0 +1,24 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A trivial script that wraps a LightningCLI around the BoringModel and BoringDataModule."""
+from pytorch_lightning.cli import LightningCLI
+from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
+
+if __name__ == "__main__":
+ LightningCLI(
+ BoringModel,
+ BoringDataModule,
+ seed_everything_default=42,
+ save_config_overwrite=True,
+ )
diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py
index 4610f6153386b..9b196f3e2a97f 100644
--- a/tests/tests_pytorch/strategies/test_ddp.py
+++ b/tests/tests_pytorch/strategies/test_ddp.py
@@ -21,60 +21,41 @@
from torch.nn.parallel.distributed import DistributedDataParallel
import pytorch_lightning as pl
-from pytorch_lightning import Trainer
+from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.strategies import DDPStrategy
+from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
-from tests_pytorch.strategies import ddp_model
-from tests_pytorch.utilities.distributed import call_training_script
+from tests_pytorch.helpers.simple_models import ClassificationModel
-CLI_ARGS = "--max_epochs 1 --accelerator gpu --devices 2 --strategy ddp"
+@RunIf(min_cuda_gpus=2, standalone=True)
+def test_multi_gpu_model_ddp_fit_only(tmpdir):
+ dm = ClassifDataModule()
+ model = ClassificationModel()
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp")
+ trainer.fit(model, datamodule=dm)
-@RunIf(min_cuda_gpus=2)
-@pytest.mark.parametrize("as_module", [True, False])
-def test_multi_gpu_model_ddp_fit_only(tmpdir, as_module):
- # call the script
- call_training_script(ddp_model, CLI_ARGS, "fit", tmpdir, timeout=120, as_module=as_module)
- # load the results of the script
- result_path = os.path.join(tmpdir, "ddp.result")
- result = torch.load(result_path)
+@RunIf(min_cuda_gpus=2, standalone=True)
+def test_multi_gpu_model_ddp_test_only(tmpdir):
+ dm = ClassifDataModule()
+ model = ClassificationModel()
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp")
+ trainer.test(model, datamodule=dm)
- # verify the file wrote the expected outputs
- assert result["status"] == "complete"
+@RunIf(min_cuda_gpus=2, standalone=True)
+def test_multi_gpu_model_ddp_fit_test(tmpdir):
+ seed_everything(4321)
+ dm = ClassifDataModule()
+ model = ClassificationModel()
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp")
+ trainer.fit(model, datamodule=dm)
+ result = trainer.test(model, datamodule=dm)
-@RunIf(min_cuda_gpus=2)
-@pytest.mark.parametrize("as_module", [True, False])
-def test_multi_gpu_model_ddp_test_only(tmpdir, as_module):
- # call the script
- call_training_script(ddp_model, CLI_ARGS, "test", tmpdir, as_module=as_module)
-
- # load the results of the script
- result_path = os.path.join(tmpdir, "ddp.result")
- result = torch.load(result_path)
-
- # verify the file wrote the expected outputs
- assert result["status"] == "complete"
-
-
-@RunIf(min_cuda_gpus=2)
-@pytest.mark.parametrize("as_module", [True, False])
-def test_multi_gpu_model_ddp_fit_test(tmpdir, as_module):
- # call the script
- call_training_script(ddp_model, CLI_ARGS, "fit_test", tmpdir, timeout=20, as_module=as_module)
-
- # load the results of the script
- result_path = os.path.join(tmpdir, "ddp.result")
- result = torch.load(result_path)
-
- # verify the file wrote the expected outputs
- assert result["status"] == "complete"
-
- model_outs = result["result"]
- for out in model_outs:
+ for out in result:
assert out["test_acc"] > 0.7
@@ -194,3 +175,15 @@ def root_device(self):
assert strategy._get_process_group_backend() == expected_process_group_backend
else:
assert strategy._get_process_group_backend() == expected_process_group_backend
+
+
+@pytest.mark.parametrize(
+ "strategy_name,expected_ddp_kwargs",
+ [
+ ("ddp", {}),
+ ("ddp_find_unused_parameters_false", {"find_unused_parameters": False}),
+ ],
+)
+def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
+ trainer = Trainer(strategy=strategy_name)
+ assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs
diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
index 52427c2c8cc3a..7fb22206c45c6 100644
--- a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
+++ b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
@@ -178,3 +178,25 @@ def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
mock_init_process_group.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
)
+
+
+@pytest.mark.parametrize(
+ "strategy_name,expected_ddp_kwargs",
+ [
+ ("ddp_spawn", {}),
+ pytest.param("ddp_fork", {}, marks=RunIf(skip_windows=True)),
+ pytest.param("ddp_notebook", {}, marks=RunIf(skip_windows=True)),
+ ("ddp_spawn_find_unused_parameters_false", {"find_unused_parameters": False}),
+ pytest.param(
+ "ddp_fork_find_unused_parameters_false", {"find_unused_parameters": False}, marks=RunIf(skip_windows=True)
+ ),
+ pytest.param(
+ "ddp_notebook_find_unused_parameters_false",
+ {"find_unused_parameters": False},
+ marks=RunIf(skip_windows=True),
+ ),
+ ],
+)
+def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
+ trainer = Trainer(strategy=strategy_name)
+ assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs
diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py
index 4f2cc14b6c62d..e3c6f95f3ff47 100644
--- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py
+++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py
@@ -28,13 +28,12 @@
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
+from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin
from pytorch_lightning.strategies import DeepSpeedStrategy
from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datamodules import ClassifDataModule
-from tests_pytorch.helpers.datasets import RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf
if _DEEPSPEED_AVAILABLE:
@@ -171,12 +170,11 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config):
@RunIf(deepspeed=True)
@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1)
-@pytest.mark.parametrize("precision", [16, "mixed"])
@pytest.mark.parametrize(
"amp_backend",
["native", pytest.param("apex", marks=RunIf(amp_apex=True))],
)
-def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir):
+def test_deepspeed_precision_choice(_, amp_backend, tmpdir):
"""Test to ensure precision plugin is also correctly chosen.
DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin
@@ -188,16 +186,16 @@ def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir):
accelerator="gpu",
strategy="deepspeed",
amp_backend=amp_backend,
- precision=precision,
+ precision=16,
)
assert isinstance(trainer.strategy, DeepSpeedStrategy)
assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin)
- assert trainer.strategy.precision_plugin.precision == precision
+ assert trainer.strategy.precision_plugin.precision == 16
@RunIf(deepspeed=True)
-def test_deepspeed_with_invalid_config_path(tmpdir):
+def test_deepspeed_with_invalid_config_path():
"""Test to ensure if we pass an invalid config path we throw an exception."""
with pytest.raises(
@@ -218,7 +216,7 @@ def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config):
@RunIf(deepspeed=True)
-def test_deepspeed_defaults(tmpdir):
+def test_deepspeed_defaults():
"""Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed."""
strategy = DeepSpeedStrategy()
assert strategy.config is not None
@@ -663,7 +661,7 @@ def training_step(self, batch, batch_idx):
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config):
+def test_deepspeed_multigpu_stage_3(tmpdir):
"""Test to ensure ZeRO Stage 3 works with a parallel model."""
model = ModelParallelBoringModel()
trainer = Trainer(
diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py
index a047a10df32e3..ad0673ed1a5fa 100644
--- a/tests/tests_pytorch/strategies/test_sharded_strategy.py
+++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py
@@ -300,3 +300,17 @@ def test_block_backward_sync():
with strategy.block_backward_sync():
pass
model.no_sync.assert_called_once()
+
+
+@pytest.mark.parametrize(
+ "strategy_name,expected_ddp_kwargs",
+ [
+ ("ddp_sharded", {}),
+ ("ddp_sharded_find_unused_parameters_false", {"find_unused_parameters": False}),
+ ("ddp_sharded_spawn", {}),
+ ("ddp_sharded_spawn_find_unused_parameters_false", {"find_unused_parameters": False}),
+ ],
+)
+def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
+ trainer = Trainer(strategy=strategy_name)
+ assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs
diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
index d6d5018aa1dd0..02e846425a2a0 100644
--- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
@@ -30,7 +30,7 @@
)
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
-from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
+from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
def test_checkpoint_callbacks_are_last(tmpdir):
@@ -265,7 +265,10 @@ def _make_entry_point_query_mock(callback_factory):
entry_point = Mock()
entry_point.name = "mocked"
entry_point.load.return_value = callback_factory
- if _PYTHON_GREATER_EQUAL_3_8_0:
+ if _PYTHON_GREATER_EQUAL_3_10_0:
+ query_mock.return_value = [entry_point]
+ import_path = "importlib.metadata.entry_points"
+ elif _PYTHON_GREATER_EQUAL_3_8_0:
query_mock().get.return_value = [entry_point]
import_path = "importlib.metadata.entry_points"
else:
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index f2a98daa9c5ad..7273d7719834e 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -445,7 +445,8 @@ def test_dataloader_source_direct_access():
def test_dataloader_source_request_from_module():
"""Test requesting a dataloader from a module works."""
module = BoringModel()
- module.trainer = Trainer()
+ trainer = Trainer()
+ module.trainer = trainer
module.foo = Mock(return_value=module.train_dataloader())
source = _DataLoaderSource(module, "foo")
diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py
index 9414fd1c5096f..e5fd9b5dd2706 100644
--- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py
+++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py
@@ -16,10 +16,9 @@
import pytest
from torch.utils.data import DataLoader
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
+from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests_pytorch.helpers.datasets import RandomIterableDataset
@pytest.mark.parametrize("max_epochs", [1, 2, 3])
diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
index 5855eba4c86af..85ed3d8e3471d 100644
--- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
@@ -28,9 +28,8 @@
from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.core.module import LightningModule
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
+from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests_pytorch.helpers.datasets import RandomDictDataset
from tests_pytorch.helpers.runif import RunIf
@@ -569,11 +568,12 @@ def on_train_epoch_end(self, trainer, pl_module):
"accelerator",
[
pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)),
+ "cpu",
],
)
def test_metric_are_properly_reduced(tmpdir, accelerator):
class TestingModel(BoringModel):
- def __init__(self, *args, **kwargs) -> None:
+ def __init__(self) -> None:
super().__init__()
self.val_acc = Accuracy()
@@ -592,7 +592,6 @@ def validation_step(self, batch, batch_idx):
return super().validation_step(batch, batch_idx)
early_stop = EarlyStopping(monitor="val_acc", mode="max")
-
checkpoint = ModelCheckpoint(monitor="val_acc", save_last=True, save_top_k=2, mode="max")
model = TestingModel()
@@ -812,3 +811,28 @@ def training_step(self, batch, batch_idx):
call(metrics={"foo_epoch": 0.0, "epoch": 1}, step=3),
]
)
+
+
+@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics")
+def test_log_on_train_start(mock_log_metrics, tmpdir):
+ """Tests that logged metrics on_train_start get reset after the first epoch."""
+
+ class MyModel(BoringModel):
+ def on_train_start(self):
+ self.log("foo", 123)
+
+ model = MyModel()
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ limit_train_batches=1,
+ limit_val_batches=0,
+ max_epochs=2,
+ log_every_n_steps=1,
+ enable_model_summary=False,
+ enable_checkpointing=False,
+ enable_progress_bar=False,
+ )
+ trainer.fit(model)
+
+ assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)]
+ assert trainer.max_epochs > 1
diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
index 92a1126294dfc..846a39a748a60 100644
--- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
+++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
@@ -22,11 +22,10 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
-from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.demos.boring_classes import BoringModel, RandomIterableDataset
from pytorch_lightning.strategies.ipu import IPUStrategy
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests_pytorch.helpers.datasets import RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index 5bea5a4cbbe1c..34504392dc0c1 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -25,12 +25,16 @@
from pytorch_lightning import Callback, seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
+from pytorch_lightning.demos.boring_classes import (
+ BoringModel,
+ RandomDataset,
+ RandomIterableDataset,
+ RandomIterableDatasetWithLen,
+)
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset, has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
-from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen
from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index e4be8929f9c7e..9506acee425d0 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -41,7 +41,12 @@
from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
+from pytorch_lightning.demos.boring_classes import (
+ BoringModel,
+ RandomDataset,
+ RandomIterableDataset,
+ RandomIterableDatasetWithLen,
+)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from pytorch_lightning.strategies import (
@@ -60,7 +65,6 @@
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.seed import seed_everything
from tests_pytorch.helpers.datamodules import ClassifDataModule
-from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel
diff --git a/tests/tests_pytorch/utilities/distributed.py b/tests/tests_pytorch/utilities/distributed.py
deleted file mode 100644
index 38a50edcc7177..0000000000000
--- a/tests/tests_pytorch/utilities/distributed.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-import subprocess
-import sys
-from pathlib import Path
-from subprocess import TimeoutExpired
-
-import pytorch_lightning
-
-
-def call_training_script(module_file, cli_args, method, tmpdir, timeout=60, as_module=False):
- file = Path(module_file.__file__).absolute()
- cli_args = cli_args.split(" ") if cli_args else []
- cli_args += ["--tmpdir", str(tmpdir)]
- cli_args += ["--trainer_method", method]
- file_args = ["-m", module_file.__spec__.name] if as_module else [str(file)]
- command = [sys.executable] + file_args + cli_args
-
- # need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
- env = os.environ.copy()
- env["PYTHONPATH"] = env.get("PYTHONPATH", "") + f"{pytorch_lightning.__file__}:"
-
- # for running in ddp mode, we need to launch it's own process or pytest will get stuck
- p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
- try:
- std, err = p.communicate(timeout=timeout)
- err = str(err.decode("utf-8"))
- if "Exception" in err:
- raise Exception(err)
- except TimeoutExpired:
- p.kill()
- std, err = p.communicate()
- return std, err
diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py
index ffb898efaa815..cc70417988616 100644
--- a/tests/tests_pytorch/utilities/test_data.py
+++ b/tests/tests_pytorch/utilities/test_data.py
@@ -1,3 +1,4 @@
+import random
from dataclasses import dataclass
import pytest
@@ -6,7 +7,7 @@
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from pytorch_lightning import Trainer
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
+from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import (
@@ -23,7 +24,6 @@
warning_cache,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests_pytorch.helpers.datasets import RandomIterableDataset
from tests_pytorch.helpers.utils import no_warning_call
@@ -173,6 +173,30 @@ def __init__(self, randomize, *args, **kwargs):
assert isinstance(new_dataloader, GoodImpl)
+def test_replace_init_method_multiple_loaders_without_init():
+ """In case of a class, that inherits from a class that we are patching, but doesn't define its own `__init__`
+ method (the one we are wrapping), it can happen, that `hasattr(cls, "_old_init")` is True because of parent
+ class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error
+ occured only sometimes because it depends on the order in which we are iterating over a set of classes we are
+ patching.
+
+ This test simulates the behavior by generating sufficient number of dummy classes, which do not define `__init__`
+ and are children of `DataLoader`. We are testing that a) context manager `_replace_init_method` exits cleanly, and
+ b) the mechanism checking for presence of `_old_init` works as expected.
+ """
+ classes = [DataLoader]
+ for i in range(100):
+ classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {}))
+
+ with _replace_init_method(DataLoader, "dataset"):
+ for cls in classes[1:]: # First one is `DataLoader`
+ assert "_old_init" not in cls.__dict__
+ assert hasattr(cls, "_old_init")
+
+ assert "_old_init" in DataLoader.__dict__
+ assert hasattr(DataLoader, "_old_init")
+
+
class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
self.at1 = attribute1
diff --git a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py
index 38f72b555d52d..7c17b3d9f7642 100644
--- a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py
+++ b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py
@@ -113,7 +113,7 @@ def test_submodules_multi_gpu_ddp_spawn(tmpdir):
],
)
@RunIf(min_cuda_gpus=1)
-def test_gpu_cuda_device(device):
+def test_cuda_device(device):
model = TopModule()
model.cuda(device)
@@ -122,3 +122,25 @@ def test_gpu_cuda_device(device):
assert device.type == "cuda"
assert device.index is not None
assert device.index == torch.cuda.current_device()
+
+
+@RunIf(min_cuda_gpus=2)
+def test_cuda_current_device():
+ """Test that calling .cuda() moves the model to the correct device and respects current cuda device setting."""
+
+ class CudaModule(DeviceDtypeModuleMixin):
+ def __init__(self):
+ super().__init__()
+ self.layer = nn.Linear(1, 1)
+
+ model = CudaModule()
+
+ torch.cuda.set_device(0)
+ model.cuda(1)
+ assert model.device == torch.device("cuda", 1)
+ assert model.layer.weight.device == torch.device("cuda", 1)
+
+ torch.cuda.set_device(1)
+ model.cuda() # model is already on device 1, and calling .cuda() without device index should not move model
+ assert model.device == torch.device("cuda", 1)
+ assert model.layer.weight.device == torch.device("cuda", 1)
diff --git a/tests/tests_pytorch/utilities/test_grads.py b/tests/tests_pytorch/utilities/test_grads.py
index a548de66ab85d..49aab76403847 100644
--- a/tests/tests_pytorch/utilities/test_grads.py
+++ b/tests/tests_pytorch/utilities/test_grads.py
@@ -76,3 +76,17 @@ def __init__(self):
def test_grad_norm_invalid_norm_type(norm_type):
with pytest.raises(ValueError, match="`norm_type` must be a positive number or 'inf'"):
grad_norm(Mock(), norm_type)
+
+
+def test_grad_norm_with_double_dtype():
+ class Model(nn.Module):
+ def __init__(self):
+ super().__init__()
+ dtype = torch.double
+ self.param = nn.Parameter(torch.tensor(1.0, dtype=dtype))
+ # grad norm of this would become infinite
+ self.param.grad = torch.tensor(1e23, dtype=dtype)
+
+ model = Model()
+ norms = grad_norm(model, 2)
+ assert all(torch.isfinite(torch.tensor(v)) for v in norms.values()), norms
diff --git a/tests/tests_pytorch/utilities/test_seed.py b/tests/tests_pytorch/utilities/test_seed.py
index 7f162bd605640..6908badf1a037 100644
--- a/tests/tests_pytorch/utilities/test_seed.py
+++ b/tests/tests_pytorch/utilities/test_seed.py
@@ -1,6 +1,8 @@
import os
import random
+from typing import Mapping
from unittest import mock
+from unittest.mock import MagicMock
import numpy as np
import pytest
@@ -96,3 +98,19 @@ def test_isolate_rng():
with isolate_rng():
generated = [random.random() for _ in range(3)]
assert random.random() == generated[0]
+
+
+@mock.patch("pytorch_lightning.utilities.seed.log.info")
+@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"RANK": "1"}, {"RANK": "4"}])
+def test_seed_everything_log_info(log_mock: MagicMock, env_vars: Mapping[str, str]):
+ """Test that log message prefix with correct rank info."""
+ with mock.patch.dict(os.environ, env_vars, clear=True):
+ from pytorch_lightning.utilities.rank_zero import _get_rank
+
+ rank = _get_rank()
+
+ seed_utils.seed_everything(123)
+
+ expected_log = f"[rank: {rank}] Global seed set to 123"
+
+ log_mock.assert_called_once_with(expected_log)