diff --git a/.actions/setup_tools.py b/.actions/setup_tools.py index 5088be2020738..a76e81246798c 100644 --- a/.actions/setup_tools.py +++ b/.actions/setup_tools.py @@ -94,11 +94,10 @@ def load_readme_description(path_dir: str, homepage: str, version: str) -> str: text = text.replace("pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{version}") # codecov badge text = text.replace("/branch/master/graph/badge.svg", f"/release/{version}/graph/badge.svg") - # replace github badges for release ones + # github actions badge text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={version}") - # Azure... + # azure pipelines badge text = text.replace("?branchName=master", f"?branchName=refs%2Ftags%2F{version}") - text = re.sub(r"\?definitionId=\d+&branchName=master", f"?definitionId=2&branchName=refs%2Ftags%2F{version}", text) skip_begin = r"" skip_end = r"" diff --git a/.azure/gpu-benchmark.yml b/.azure/gpu-benchmark.yml index ac5ca6f60a6b4..0de590f2c54a6 100644 --- a/.azure/gpu-benchmark.yml +++ b/.azure/gpu-benchmark.yml @@ -28,7 +28,7 @@ jobs: cancelTimeoutInMinutes: "2" pool: azure-jirka-spot container: - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12-cuda11.3.1" options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all --shm-size=32g" workspace: clean: all diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml index f37c17613affc..8ae670d265ced 100644 --- a/.azure/gpu-tests.yml +++ b/.azure/gpu-tests.yml @@ -26,7 +26,7 @@ jobs: strategy: matrix: 'PyTorch - stable': - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12-cuda11.3.1" # how long to run the job before automatically cancelling timeoutInMinutes: "80" # how much time to give 'run always even if cancelled tasks' before stopping them @@ -44,7 +44,7 @@ jobs: - bash: | CHANGED_FILES=$(git diff --name-status origin/master -- . | awk '{print $2}') - FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.azure/*' + FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.azure/gpu-tests.yml' echo $CHANGED_FILES > changed_files.txt MATCHES=$(cat changed_files.txt | grep -E $FILTER) echo $MATCHES @@ -75,7 +75,7 @@ jobs: CUDA_VERSION_MM=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))") pip install "bagua-cuda$CUDA_VERSION_MM>=0.9.0" pip install -e .[strategies] - pip install deepspeed>0.6.4 # TODO: remove when docker images are upgraded + pip install -U deepspeed # TODO: remove when docker images are upgraded pip install --requirement requirements/pytorch/devel.txt pip list env: diff --git a/.github/BECOMING_A_CORE_CONTRIBUTOR.md b/.github/BECOMING_A_CORE_CONTRIBUTOR.md index a179161f687a1..fd40e29e1ebf1 100644 --- a/.github/BECOMING_A_CORE_CONTRIBUTOR.md +++ b/.github/BECOMING_A_CORE_CONTRIBUTOR.md @@ -62,4 +62,4 @@ We are on the lookout for new people to join, however, if you feel like you meet ## Employment -You can also become a [Grid.ai](https://www.grid.ai) employee or intern and work on Lightning. To get started, you can email `careers@grid.ai` with your resume or check out our [open job postings](https://boards.greenhouse.io/gridai). +You can also become a [Lightning AI](https://lightning.ai/) employee or intern and work on Lightning. To get started, you can email `careers@lightning.ai` with your resume or check out our [open job postings](https://boards.greenhouse.io/lightningai). diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index e40828557c2cf..0b4692731bff9 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,7 +5,7 @@ # the repo. Unless a later match takes precedence, # @global-owner1 and @global-owner2 will be requested for # review when someone opens a pull request. -* @williamfalcon @borda @tchaton @SeanNaren @carmocca @awaelchli @justusschock @kaushikb11 @rohitgr7 +* @williamfalcon @borda @tchaton @awaelchli @kaushikb11 @rohitgr7 # CI/CD and configs /.github/ @borda @carmocca @akihironitta @tchaton @@ -26,50 +26,52 @@ /docs/source-app/expertise_levels @williamfalcon @Felonious-Spellfire @RobertLaurella # Packages +/src/pytorch_lightning @carmocca @justusschock /src/pytorch_lightning/accelerators @williamfalcon @tchaton @SeanNaren @awaelchli @justusschock @kaushikb11 /src/pytorch_lightning/callbacks @williamfalcon @tchaton @carmocca @borda @kaushikb11 -/src/pytorch_lightning/core @tchaton @SeanNaren @borda @carmocca @justusschock @kaushikb11 +/src/pytorch_lightning/core @tchaton @borda @carmocca @justusschock @kaushikb11 /src/pytorch_lightning/distributed @williamfalcon @tchaton @awaelchli @kaushikb11 /src/pytorch_lightning/lite @tchaton @awaelchli @carmocca /src/pytorch_lightning/loggers @tchaton @awaelchli @borda -/src/pytorch_lightning/loggers/wandb.py @borisdayma +/src/pytorch_lightning/loggers/wandb.py @borisdayma @borda /src/pytorch_lightning/loggers/neptune.py @shnela @HubertJaworski @pkasprzyk @pitercl @Raalsky @aniezurawski @kamil-kaczmarek /src/pytorch_lightning/loops @tchaton @awaelchli @justusschock @carmocca -/src/pytorch_lightning/overrides @tchaton @SeanNaren @borda -/src/pytorch_lightning/plugins @tchaton @SeanNaren @awaelchli @justusschock +/src/pytorch_lightning/overrides @tchaton @borda +/src/pytorch_lightning/plugins @tchaton @awaelchli @justusschock /src/pytorch_lightning/profilers @williamfalcon @tchaton @borda @carmocca /src/pytorch_lightning/profilers/pytorch.py @nbcsm @guotuofeng /src/pytorch_lightning/strategies @tchaton @SeanNaren @awaelchli @justusschock @kaushikb11 -/src/pytorch_lightning/trainer @williamfalcon @borda @tchaton @SeanNaren @carmocca @awaelchli @justusschock @kaushikb11 -/src/pytorch_lightning/trainer/connectors @tchaton @SeanNaren @carmocca @borda +/src/pytorch_lightning/trainer @williamfalcon @borda @tchaton @carmocca @awaelchli @justusschock @kaushikb11 +/src/pytorch_lightning/trainer/connectors @tchaton @carmocca @borda /src/pytorch_lightning/tuner @SkafteNicki @borda @awaelchli -/src/pytorch_lightning/utilities @borda @tchaton @SeanNaren @carmocca +/src/pytorch_lightning/utilities @borda @tchaton @carmocca -/src/lightning_app @tchaton @awaelchli @manskx @hhsecond +/src/lightning_app @tchaton @manskx +/src/lightning_app/cli/pl-app-template @tchaton @awaelchli @Borda +/src/lightning_app/core @tchaton @awaelchli @manskx +/src/lightning_app/core/queues.py @tchaton @hhsecond @manskx +/src/lightning_app/runners/cloud.py @tchaton @hhsecond +/src/lightning_app/testing @tchaton @manskx +/src/lightning_app/__about__.py @nohalon @edenlightning @lantiga # Examples -/examples/app_* @tchaton @awaelchli @manskx @hhsecond +/examples/app_* @tchaton @awaelchli @manskx @hhsecond # App tests -/tests/tests_app @tchaton @awaelchli @manskx @hhsecond -/tests/tests_app_examples @tchaton @awaelchli @manskx @hhsecond +/tests/tests_app @tchaton @awaelchli @manskx @hhsecond +/tests/tests_app_examples @tchaton @awaelchli @manskx @hhsecond # Specifics -/src/pytorch_lightning/trainer/connectors/logger_connector @tchaton @carmocca -/src/pytorch_lightning/trainer/progress.py @tchaton @awaelchli @carmocca - +/src/pytorch_lightning/trainer/connectors/logger_connector @tchaton @carmocca +/src/pytorch_lightning/trainer/progress.py @tchaton @awaelchli @carmocca # API -/src/pytorch_lightning/callbacks/base.py @williamfalcon @awaelchli @ananthsub @carmocca -/src/pytorch_lightning/core/datamodule.py @williamFalcon @awaelchli @ananthsub @carmocca -/src/pytorch_lightning/trainer/trainer.py @williamfalcon @tchaton @awaelchli -/src/pytorch_lightning/core/hooks.py @williamfalcon @tchaton @awaelchli @ananthsub @carmocca -/src/pytorch_lightning/core/lightning.py @williamfalcon @tchaton @awaelchli - -# Testing -/tests/helpers/boring_model.py @williamfalcon @tchaton @borda +/src/pytorch_lightning/callbacks/callback.py @williamfalcon @awaelchli @ananthsub @carmocca +/src/pytorch_lightning/core/datamodule.py @williamFalcon @awaelchli @ananthsub @carmocca +/src/pytorch_lightning/trainer/trainer.py @williamfalcon @tchaton @awaelchli +/src/pytorch_lightning/core/hooks.py @williamfalcon @tchaton @awaelchli @ananthsub @carmocca +/src/pytorch_lightning/core/module.py @williamfalcon @tchaton @awaelchli -/.github/CODEOWNERS @williamfalcon -/.github/approve_config.yml @williamfalcon -/SECURITY.md @williamfalcon -/README.md @williamfalcon @edenlightning @borda -/setup.py @williamfalcon @borda @carmocca +/.github/CODEOWNERS @williamfalcon +/SECURITY.md @williamfalcon +/README.md @williamfalcon @edenlightning @borda +/setup.py @williamfalcon @borda @carmocca /src/pytorch_lightning/__about__.py @williamfalcon @borda @carmocca diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index f08865180ba1d..de4eacde1f39e 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -41,8 +41,16 @@ You can get the script and run it with: ```bash wget https://raw.githubusercontent.com/Lightning-AI/lightning/master/requirements/collect_env_details.py python collect_env_details.py + ``` + +
+ Details + Paste the output here and move this toggle outside of the comment block. +
+ + You can also fill out the list below manually. --> diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml new file mode 100644 index 0000000000000..8f1d3c6fb5e86 --- /dev/null +++ b/.github/checkgroup.yml @@ -0,0 +1,165 @@ +custom_service_name: "Lightning CI required checker" +subprojects: + - id: "CI: CircleCI" + paths: + - ".circleci/**" + checks: + - "test-on-tpus" + + - id: "CI: Azure" + paths: + - ".azure/**" + checks: + - "pytorch-lightning (GPUs)" + - "pytorch-lightning (GPUs) (testing PyTorch - stable)" + - "pytorch-lightning (HPUs)" + - "pytorch-lightning (IPUs)" + + - id: "pytorch_lightning" + paths: + # all examples don't need to be added because they aren't used in CI, but these are + - "examples/run_ddp_examples.sh" + - "examples/convert_from_pt_to_pl/**" + - "examples/run_pl_examples.sh" + - "examples/pl_basics/backbone_image_classifier.py" + - "examples/pl_basics/autoencoder.py" + - "examples/pl_loops/mnist_lite.py" + - "examples/pl_fault_tolerant/automatic.py" + - "examples/test_pl_examples.py" + - "examples/pl_integrations/dali_image_classifier.py" + - "requirements/pytorch/**" + - "src/pytorch_lightning/**" + - "tests/tests_pytorch/**" + - "setup.cfg" # includes pytest config + - ".github/workflows/ci-pytorch*.yml" + - ".github/workflows/docs-*.yml" + checks: + - "conda (3.8, 1.10)" + - "conda (3.8, 1.9)" + - "conda (3.9, 1.11)" + - "conda (3.9, 1.12)" + - "cpu (macOS-11, 3.10, latest, stable)" + - "cpu (macOS-11, 3.7, latest, stable)" + - "cpu (macOS-11, 3.7, oldest, stable)" + - "cpu (ubuntu-20.04, 3.10, latest, stable)" + - "cpu (ubuntu-20.04, 3.7, latest, stable)" + - "cpu (ubuntu-20.04, 3.7, oldest, stable)" + - "cpu (windows-2022, 3.10, latest, stable)" + - "cpu (windows-2022, 3.7, latest, stable)" + - "cpu (windows-2022, 3.7, oldest, stable)" + - "doctest (pytorch)" + - "make-docs (pytorch)" + - "mypy" + - "PR Gatekeeper (pytorch)" + - "pytorch-lightning (GPUs)" + - "pytorch-lightning (GPUs) (testing PyTorch - stable)" + - "pytorch-lightning (HPUs)" + - "pytorch-lightning (IPUs)" + - "slow (macOS-11, 3.7, 1.11)" + - "slow (ubuntu-20.04, 3.7, 1.11)" + - "slow (windows-2022, 3.7, 1.11)" + - "test-on-tpus" + + - id: "pytorch_lightning: Docs" + paths: + - "docs/source-pytorch/**" + - ".github/workflows/docs-*.yml" + - "requirements/pytorch/**" + checks: + - "doctest (pytorch)" + - "make-docs (pytorch)" + + - id: "pytorch_lightning: Docker" + paths: + - "dockers/**" + checks: + - "build-conda (3.8, 1.10)" + - "build-conda (3.8, 1.9)" + - "build-conda (3.9, 1.11)" + - "build-conda (3.9, 1.12)" + - "build-cuda (3.8, 1.9, 11.1.1)" + - "build-cuda (3.9, 1.10, 11.3.1)" + - "build-cuda (3.9, 1.11, 11.3.1)" + - "build-cuda (3.9, 1.12, 11.3.1)" + - "build-cuda (3.9, 1.9, 11.1.1)" + - "build-hpu (1.5.0, 1.11.0)" + - "build-ipu (3.9, 1.9)" + - "build-NGC" + - "build-pl (3.9, 1.10, 11.3.1)" + - "build-pl (3.9, 1.11, 11.3.1)" + - "build-pl (3.9, 1.12, 11.3.1)" + - "build-pl (3.9, 1.9, 11.1.1)" + - "build-xla (3.7, 1.12)" + + - id: "pytorch_lightning: mypy" + paths: + - ".github/workflows/code-checks.yml" + - "pyproject.toml" # includes mypy config + checks: + - "mypy" + + - id: "lightning_app" + paths: + - ".github/workflows/ci-app*.yml" + - "examples/app_**" + - "requirements/app/**" + - "src/lightning_app/**" + - "tests/tests_app/**" + - "tests/tests_app_examples/**" + - "tests/tests_clusters/**" + # the examples are used in the app CI + - "examples/app_*" + checks: + - "Cloud Test (boring_app)" + - "Cloud Test (collect_failures)" + - "Cloud Test (commands_and_api)" + - "Cloud Test (custom_work_dependencies)" + - "Cloud Test (drive)" + - "Cloud Test (idle_timeout)" + - "Cloud Test (payload)" + - "Cloud Test (template_jupyterlab)" + - "Cloud Test (template_react_ui)" + - "Cloud Test (template_streamlit_ui)" + - "Cloud Test (v0_app)" + - "doctest (app)" + - "make-docs (app)" + - "pytest (macOS-11, 3.8, latest)" + - "pytest (macOS-11, 3.8, oldest)" + - "pytest (ubuntu-20.04, 3.8, latest)" + - "pytest (ubuntu-20.04, 3.8, oldest)" + - "pytest (windows-2022, 3.8, latest)" + - "pytest (windows-2022, 3.8, oldest)" + + - id: "lightning_app: Docs" + paths: + - "docs/source-app/**" + - ".github/workflows/docs-*.yml" + - "requirements/app/**" + checks: + - "doctest (app)" + - "make-docs (app)" + + - id: "install" + paths: + - ".actions/setup_tools.py" + - ".github/workflows/ci-pkg-install.yml" + - "setup.py" + - "src/lightning/**" + # all __about__, __version__, __setup__ + - "src/*/__*.py" + checks: + - "install-meta-pypi (macOS-11, 3.8)" + - "install-meta-pypi (ubuntu-20.04, 3.8)" + - "install-meta-pypi (windows-2022, 3.8)" + - "install-meta-src (macOS-11, 3.8)" + - "install-meta-src (macOS-11, lightning, 3.8)" + - "install-meta-src (ubuntu-20.04, 3.8)" + - "install-meta-src (ubuntu-20.04, lightning, 3.8)" + - "install-meta-src (windows-2022, 3.8)" + - "install-meta-src (windows-2022, lightning, 3.8)" + - "install-standalone (macOS-11, app, 3.8)" + - "install-standalone (macOS-11, pytorch, 3.8)" + - "install-standalone (ubuntu-20.04, app, 3.8)" + - "install-standalone (ubuntu-20.04, pytorch, 3.8)" + - "install-standalone (windows-2022, app, 3.8)" + - "install-standalone (windows-2022, pytorch, 3.8)" diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 8b9e7d173b03c..4ed903c0f3a93 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -4,16 +4,16 @@ ## Unit and Integration Testing -| workflow name | workflow file | action | accelerator\* | (Python, PyTorch) | OS | -| -------------------------- | ----------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | ------------------------------------------------ | ------------------- | -| Test full | .github/workflows/ci_test-full.yml | Run all tests except for accelerator-specific, standalone and slow tests. | CPU | (3.7, 1.8), (3.7, 1.11), (3.9, 1.8), (3.9, 1.12) | linux, mac, windows | -| Test with Conda | .github/workflows/ci_test-conda.yml | Same as ci_test-full.yml but with dependencies installed with conda. | CPU | (3.8, 1.8), (3.8, 1.9), (3.8, 1.10), (3.9, 1.12) | linux | -| Test slow | .github/workflows/ci_test-slow.yml | Run only slow tests. Slow tests usually need to spawn threads and cannot be speed up or simplified. | CPU | (3.7, 1.8) | linux, mac, windows | -| pytorch-lightning (IPUs) | .azure-pipelines/ipu-tests.yml | Run only IPU-specific tests. | IPU | (3.8, 1.9) | linux | -| pytorch-lightning (HPUs) | .azure-pipelines/hpu-tests.yml | Run only HPU-specific tests. | HPU | (3.8, 1.10) | linux | -| pytorch-lightning (GPUs) | .azure-pipelines/gpu-tests.yml | Run all CPU and GPU-specific tests, standalone, and examples. Each standalone test needs to be run in separate processes to avoid unwanted interactions between test cases. | GPU | (3.9, 1.12) | linux | -| PyTorchLightning.Benchmark | .azure-pipelines/gpu-benchmark.yml | Run speed/memory benchmarks for parity with pure PyTorch. | GPU | (3.9, 1.12) | linux | -| test-on-tpus | .circleci/config.yml | Run only TPU-specific tests. | TPU | (3.7, 1.12) | linux | +| workflow name | workflow file | action | accelerator\* | (Python, PyTorch) | OS | +| -------------------------- | ------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | ------------------------------------------------- | ------------------- | +| Test PyTorch full | .github/workflows/ci-pytorch-test-full.yml | Run all tests except for accelerator-specific, standalone and slow tests. | CPU | (3.7, 1.9), (3.7, 1.12), (3.9, 1.9), (3.9, 1.12) | linux, mac, windows | +| Test PyTorch with Conda | .github/workflows/ci-pytorch-test-conda.yml | Same as ci-pytorch-test-full.yml but with dependencies installed with conda. | CPU | (3.8, 1.9), (3.8, 1.10), (3.8, 1.11), (3.9, 1.12) | linux | +| Test slow | .github/workflows/ci-pytorch-test-slow.yml | Run only slow tests. Slow tests usually need to spawn threads and cannot be speed up or simplified. | CPU | (3.7, 1.11) | linux, mac, windows | +| pytorch-lightning (IPUs) | .azure-pipelines/ipu-tests.yml | Run only IPU-specific tests. | IPU | (3.8, 1.9) | linux | +| pytorch-lightning (HPUs) | .azure-pipelines/hpu-tests.yml | Run only HPU-specific tests. | HPU | (3.8, 1.10) | linux | +| pytorch-lightning (GPUs) | .azure-pipelines/gpu-tests.yml | Run all CPU and GPU-specific tests, standalone, and examples. Each standalone test needs to be run in separate processes to avoid unwanted interactions between test cases. | GPU | (3.9, 1.12) | linux | +| PyTorchLightning.Benchmark | .azure-pipelines/gpu-benchmark.yml | Run speed/memory benchmarks for parity with pure PyTorch. | GPU | (3.9, 1.12) | linux | +| test-on-tpus | .circleci/config.yml | Run only TPU-specific tests. | TPU | (3.7, 1.12) | linux | - \*Accelerators used in CI - GPU: 2 x NVIDIA Tesla V100 @@ -33,15 +33,15 @@ | --------------------------------- | ----------------------------------------------------------------------------------------- | | .codecov.yml | Measure test coverage with [codecov.io](https://app.codecov.io/gh/Lightning-AI/lightning) | | .github/workflows/code-checks.yml | Check Python typing with [MyPy](https://mypy.readthedocs.io/en/stable/). | -| .github/workflows/ci_schema.yml | Validate the syntax of workflow files. | +| .github/workflows/ci-schema.yml | Validate the syntax of workflow files. | ## Others -| workflow file | action | -| -------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| .github/workflows/ci_dockers.yml | Build docker images used for testing in CI without pushing to the [Docker Hub](https://hub.docker.com/r/pytorchlightning/pytorch_lightning). Publishing these built images takes place in `.github/workflows/release-docker.yml` which only runs in master. | -| .github/workflows/ci_pkg-install.yml | Test if pytorch-lightning is successfully installed using pip. | -| .github/workflows/events-recurrent.yml | Terminate TPU jobs that live more than one hour to avoid possible resource exhaustion due to hangs. | +| workflow file | action | +| ------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| .github/workflows/cicd-pytorch-dockers.yml | Build docker images used for testing in CI. If run on nightly schedule, push to the [Docker Hub](https://hub.docker.com/r/pytorchlightning/pytorch_lightning). | +| .github/workflows/ci-pkg-install.yml | Test if pytorch-lightning is successfully installed using pip. | +| .github/workflows/events-recurrent.yml | Terminate TPU jobs that live more than one hour to avoid possible resource exhaustion due to hangs. | ## Deployment @@ -60,4 +60,4 @@ | .github/stale.yml | Close inactive issues/PRs sometimes after adding the "won't fix" label to them. | | .github/workflows/probot-auto-cc.yml, .github/lightning-probot.yml | Notify maintainers of interest depending on labels added to an issue We utilize lightning-probot forked from PyTorch’s probot. | | .pre-commit-config.yaml | pre-commit.ci runs a set of linters and formatters, such as black, flake8 and isort. When formatting is applied, the bot pushes a commit with its change. This configuration is also used for running pre-commit locally. | -| .github/workflows/ci_pr-gatekeeper.yml | Prevent PRs from merging into master without any Grid.ai employees’ approval. | +| .github/workflows/ci-pr-gatekeeper.yml | Prevent PRs from merging into master without any Grid.ai employees’ approval. | diff --git a/.github/workflows/ci-app_cloud_e2e_test.yml b/.github/workflows/ci-app-cloud-e2e-test.yml similarity index 99% rename from .github/workflows/ci-app_cloud_e2e_test.yml rename to .github/workflows/ci-app-cloud-e2e-test.yml index 3ad455650a117..9a5a10a95cd33 100644 --- a/.github/workflows/ci-app_cloud_e2e_test.yml +++ b/.github/workflows/ci-app-cloud-e2e-test.yml @@ -54,7 +54,7 @@ jobs: - custom_work_dependencies - drive - payload - - commands + - commands_and_api timeout-minutes: 35 steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/ci-app_examples.yml b/.github/workflows/ci-app-examples.yml similarity index 98% rename from .github/workflows/ci-app_examples.yml rename to .github/workflows/ci-app-examples.yml index ec8becd5f70d1..01570f59c2c77 100644 --- a/.github/workflows/ci-app_examples.yml +++ b/.github/workflows/ci-app-examples.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-20.04, macOS-11, windows-2019] + os: [ubuntu-20.04, macOS-11, windows-2022] python-version: [3.8] requires: ["oldest", "latest"] diff --git a/.github/workflows/ci-app_tests.yml b/.github/workflows/ci-app-tests.yml similarity index 96% rename from .github/workflows/ci-app_tests.yml rename to .github/workflows/ci-app-tests.yml index 1678dab257301..fe3cc36dc16d3 100644 --- a/.github/workflows/ci-app_tests.yml +++ b/.github/workflows/ci-app-tests.yml @@ -21,7 +21,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-20.04, macOS-11, windows-2019] + os: [ubuntu-20.04, macOS-11, windows-2022] python-version: [3.8] requires: ["oldest", "latest"] @@ -126,7 +126,7 @@ jobs: # - name: Clone Quick Start Example Repo # uses: actions/checkout@v3 # # TODO: this needs to be git submodule -# if: matrix.os == 'windows-2019' # because the install doesn't work on windows +# if: matrix.os == 'windows-2022' # because the install doesn't work on windows # with: # repository: Lightning-AI/lightning-quick-start # ref: 'main' @@ -134,6 +134,6 @@ jobs: # # - name: Lightning Install quick-start # shell: bash -# if: matrix.os != 'windows-2019' # because the install doesn't work on windows +# if: matrix.os != 'windows-2022' # because the install doesn't work on windows # run: | # python -m lightning install app lightning/quick-start -y diff --git a/.github/workflows/ci_pkg-install.yml b/.github/workflows/ci-pkg-install.yml similarity index 95% rename from .github/workflows/ci_pkg-install.yml rename to .github/workflows/ci-pkg-install.yml index 342e027b07cfe..a9fdd36693a67 100644 --- a/.github/workflows/ci_pkg-install.yml +++ b/.github/workflows/ci-pkg-install.yml @@ -33,7 +33,7 @@ jobs: fail-fast: true max-parallel: 1 matrix: - os: [ubuntu-20.04, macOS-11, windows-2019] + os: [ubuntu-20.04, macOS-11, windows-2022] pkg: ["app", "pytorch"] python-version: [3.8] # , 3.9 @@ -67,7 +67,7 @@ jobs: fail-fast: false # max-parallel: 1 matrix: - os: [ubuntu-20.04, macOS-11, windows-2019] + os: [ubuntu-20.04, macOS-11, windows-2022] pkg: ["", "lightning"] python-version: [3.8] # , 3.9 @@ -100,7 +100,7 @@ jobs: fail-fast: false # max-parallel: 1 matrix: - os: [ubuntu-20.04, macOS-11, windows-2019] + os: [ubuntu-20.04, macOS-11, windows-2022] python-version: [3.8] # , 3.9 steps: diff --git a/.github/workflows/ci_pr-gatekeeper.yml b/.github/workflows/ci-pr-gatekeeper.yml similarity index 100% rename from .github/workflows/ci_pr-gatekeeper.yml rename to .github/workflows/ci-pr-gatekeeper.yml diff --git a/.github/workflows/ci-pytorch_test-conda.yml b/.github/workflows/ci-pytorch-test-conda.yml similarity index 98% rename from .github/workflows/ci-pytorch_test-conda.yml rename to .github/workflows/ci-pytorch-test-conda.yml index 777ec2af759a0..2bbdb699c2c1e 100644 --- a/.github/workflows/ci-pytorch_test-conda.yml +++ b/.github/workflows/ci-pytorch-test-conda.yml @@ -22,13 +22,11 @@ jobs: strategy: fail-fast: false matrix: - # nightly: add when there's a release candidate include: - {python-version: "3.8", pytorch-version: "1.9"} - {python-version: "3.8", pytorch-version: "1.10"} - {python-version: "3.9", pytorch-version: "1.11"} - {python-version: "3.9", pytorch-version: "1.12"} - timeout-minutes: 30 steps: @@ -45,7 +43,7 @@ jobs: id: skip shell: bash -l {0} run: | - FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*' + FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.github/workflows/ci-pytorch-test-conda.yml' echo "${{ steps.changed-files.outputs.all_changed_files }}" | tr " " "\n" > changed_files.txt MATCHES=$(cat changed_files.txt | grep -E $FILTER) echo $MATCHES diff --git a/.github/workflows/ci-pytorch_test-full.yml b/.github/workflows/ci-pytorch-test-full.yml similarity index 95% rename from .github/workflows/ci-pytorch_test-full.yml rename to .github/workflows/ci-pytorch-test-full.yml index fb6916d1414fe..7409ce25a5128 100644 --- a/.github/workflows/ci-pytorch_test-full.yml +++ b/.github/workflows/ci-pytorch-test-full.yml @@ -20,10 +20,14 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-20.04, windows-2019, macOS-11] - python-version: ["3.7", "3.9"] # minimum, maximum + os: [ubuntu-20.04, windows-2022, macOS-11] + python-version: ["3.7", "3.10"] # minimum, maximum requires: ["oldest", "latest"] release: ["stable"] + exclude: + # There's no distribution of the oldest PyTorch 1.9 for Python 3.10. + # TODO: Remove the exclusion when dropping PyTorch 1.9 support. + - {python-version: "3.10", requires: "oldest"} # TODO: re-enable RC testing # include: # - {os: ubuntu-20.04, python-version: "3.10", requires: "latest", release: "pre"} @@ -41,7 +45,7 @@ jobs: id: skip shell: bash -l {0} run: | - FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*' + FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.github/workflows/ci-pytorch_test-full.yml' echo "${{ steps.changed-files.outputs.all_changed_files }}" | tr " " "\n" > changed_files.txt MATCHES=$(cat changed_files.txt | grep -E $FILTER) echo $MATCHES diff --git a/.github/workflows/ci-pytorch_test-slow.yml b/.github/workflows/ci-pytorch-test-slow.yml similarity index 97% rename from .github/workflows/ci-pytorch_test-slow.yml rename to .github/workflows/ci-pytorch-test-slow.yml index 905f60aa85699..36007d3311451 100644 --- a/.github/workflows/ci-pytorch_test-slow.yml +++ b/.github/workflows/ci-pytorch-test-slow.yml @@ -19,7 +19,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-20.04, windows-2019, macOS-11] + os: [ubuntu-20.04, windows-2022, macOS-11] # same config as '.azure-pipelines/gpu-tests.yml' python-version: ["3.7"] pytorch-version: ["1.11"] @@ -36,7 +36,7 @@ jobs: id: skip shell: bash -l {0} run: | - FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*' + FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.github/workflows/ci-pytorch_test-slow.yml' echo "${{ steps.changed-files.outputs.all_changed_files }}" | tr " " "\n" > changed_files.txt MATCHES=$(cat changed_files.txt | grep -E $FILTER) echo $MATCHES diff --git a/.github/workflows/ci_schema.yml b/.github/workflows/ci-schema.yml similarity index 100% rename from .github/workflows/ci_schema.yml rename to .github/workflows/ci-schema.yml diff --git a/.github/workflows/cicd-pytorch_dockers.yml b/.github/workflows/cicd-pytorch-dockers.yml similarity index 81% rename from .github/workflows/cicd-pytorch_dockers.yml rename to .github/workflows/cicd-pytorch-dockers.yml index a6ba2ac4aa5f4..84051cafd82d8 100644 --- a/.github/workflows/cicd-pytorch_dockers.yml +++ b/.github/workflows/cicd-pytorch-dockers.yml @@ -29,17 +29,22 @@ jobs: strategy: fail-fast: false matrix: - # the config used in '.azure-pipelines/gpu-tests.yml' since the Dockerfile uses the cuda image - python_version: ["3.9"] - pytorch_version: ["1.12"] + include: + # We only release one docker image per PyTorch version. + # The matrix here is the same as the one in release-docker.yml. + - {python_version: "3.9", pytorch_version: "1.9", cuda_version: "11.1.1"} + - {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.3.1"} + - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"} + - {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.3.1"} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: docker/setup-buildx-action@v2 - - uses: docker/build-push-action@v2 + - uses: docker/build-push-action@v3 with: build-args: | PYTHON_VERSION=${{ matrix.python_version }} PYTORCH_VERSION=${{ matrix.pytorch_version }} + CUDA_VERSION=${{ matrix.cuda_version }} file: dockers/release/Dockerfile push: false # pushed in release-docker.yml only when PL is released timeout-minutes: 50 @@ -53,14 +58,14 @@ jobs: python_version: ["3.7"] xla_version: ["1.12"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: docker/setup-buildx-action@v2 - - uses: docker/login-action@v1 + - uses: docker/login-action@v2 if: env.PUSH_TO_HUB == 'true' with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - - uses: docker/build-push-action@v2 + - uses: docker/build-push-action@v3 with: build-args: | PYTHON_VERSION=${{ matrix.python_version }} @@ -85,30 +90,31 @@ jobs: fail-fast: false matrix: include: - # the config used in '.azure-pipelines/gpu-tests.yml' - - {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.3.1", ubuntu_version: "20.04"} - # latest (used in Tutorials) - - {python_version: "3.8", pytorch_version: "1.9", cuda_version: "11.1.1", ubuntu_version: "20.04"} - - {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.1.1", ubuntu_version: "20.04"} - - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1", ubuntu_version: "20.04"} + # These are the base images for PL release docker images, + # so include at least all of the combinations in release-dockers.yml. + - {python_version: "3.9", pytorch_version: "1.9", cuda_version: "11.1.1"} + - {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.3.1"} + - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"} + - {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.3.1"} + # Used in Lightning-AI/tutorials + - {python_version: "3.8", pytorch_version: "1.9", cuda_version: "11.1.1"} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: docker/setup-buildx-action@v2 - - uses: docker/login-action@v1 + - uses: docker/login-action@v2 if: env.PUSH_TO_HUB == 'true' with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - - uses: docker/build-push-action@v2 + - uses: docker/build-push-action@v3 with: build-args: | PYTHON_VERSION=${{ matrix.python_version }} PYTORCH_VERSION=${{ matrix.pytorch_version }} CUDA_VERSION=${{ matrix.cuda_version }} - UBUNTU_VERSION=${{ matrix.ubuntu_version }} file: dockers/base-cuda/Dockerfile push: ${{ env.PUSH_TO_HUB }} - tags: pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }} + tags: pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }} timeout-minutes: 95 - uses: ravsamhq/notify-slack-action@v1 if: failure() && env.PUSH_TO_HUB == 'true' @@ -126,25 +132,23 @@ jobs: fail-fast: false matrix: include: - - {python_version: "3.8", pytorch_version: "1.9", cuda_version: "11.1.1"} - - {python_version: "3.8", pytorch_version: "1.10", cuda_version: "11.1.1"} - - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"} - # nightly: add when there's a release candidate - # - {python_version: "3.9", pytorch_version: "1.12"} + - {python_version: "3.8", pytorch_version: "1.9"} + - {python_version: "3.8", pytorch_version: "1.10"} + - {python_version: "3.9", pytorch_version: "1.11"} + - {python_version: "3.9", pytorch_version: "1.12"} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: docker/setup-buildx-action@v2 - - uses: docker/login-action@v1 + - uses: docker/login-action@v2 if: env.PUSH_TO_HUB == 'true' with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - - uses: docker/build-push-action@v2 + - uses: docker/build-push-action@v3 with: build-args: | PYTHON_VERSION=${{ matrix.python_version }} PYTORCH_VERSION=${{ matrix.pytorch_version }} - CUDA_VERSION=${{ matrix.cuda_version }} file: dockers/base-conda/Dockerfile push: ${{ env.PUSH_TO_HUB }} tags: pytorchlightning/pytorch_lightning:base-conda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }} @@ -168,14 +172,14 @@ jobs: # the config used in 'dockers/ci-runner-ipu/Dockerfile' - {python_version: "3.9", pytorch_version: "1.9"} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: docker/setup-buildx-action@v2 - - uses: docker/login-action@v1 + - uses: docker/login-action@v2 if: env.PUSH_TO_HUB == 'true' with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - - uses: docker/build-push-action@v2 + - uses: docker/build-push-action@v3 with: build-args: | PYTHON_VERSION=${{ matrix.python_version }} @@ -184,7 +188,7 @@ jobs: push: ${{ env.PUSH_TO_HUB }} tags: pytorchlightning/pytorch_lightning:base-ipu-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }} timeout-minutes: 100 - - uses: docker/build-push-action@v2 + - uses: docker/build-push-action@v3 with: build-args: | PYTHON_VERSION=${{ matrix.python_version }} @@ -199,7 +203,7 @@ jobs: status: ${{ job.status }} token: ${{ secrets.GITHUB_TOKEN }} notification_title: ${{ format('IPU; {0} py{1} for *{2}*', runner.os, matrix.python_version, matrix.pytorch_version) }} - message_format: '{emoji} *{workflow}* {status_message}, see <{run_url}|detail>, cc: <@U01BULUS2BG>' # SeanNaren + message_format: '{emoji} *{workflow}* {status_message}, see <{run_url}|detail>, cc: <@U01GD29QCAV>' # kaushikb11 env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} @@ -212,14 +216,14 @@ jobs: # the config used in 'dockers/ci-runner-hpu/Dockerfile' - {gaudi_version: "1.5.0", pytorch_version: "1.11.0"} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: docker/setup-buildx-action@v2 - - uses: docker/login-action@v1 + - uses: docker/login-action@v2 if: env.PUSH_TO_HUB == 'true' with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - - uses: docker/build-push-action@v2 + - uses: docker/build-push-action@v3 with: build-args: | DIST=latest @@ -243,10 +247,10 @@ jobs: runs-on: ubuntu-20.04 steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Build Conda Docker # publish master/release - uses: docker/build-push-action@v2 + uses: docker/build-push-action@v3 with: file: dockers/nvidia/Dockerfile push: false diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml index ed9cd46adbe44..15bd5e9911740 100644 --- a/.github/workflows/code-checks.yml +++ b/.github/workflows/code-checks.yml @@ -32,8 +32,9 @@ jobs: - name: Install dependencies run: | - pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + pip install torch==1.12 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt + # todo: adjust requirements for both code-bases pip install -r requirements/pytorch/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip list diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index 9d87f1a582fb1..6901a24204683 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -1,6 +1,5 @@ name: Docker -# https://www.docker.com/blog/first-docker-github-action-is-here -# https://github.com/docker/build-push-action + on: push: branches: [master, "release/*"] @@ -15,8 +14,12 @@ jobs: strategy: fail-fast: false matrix: - python_version: ["3.7", "3.8", "3.9"] - pytorch_version: ["1.9", "1.10"] + include: + # We only release one docker image per PyTorch version. + - {python_version: "3.9", pytorch_version: "1.9", cuda_version: "11.1.1"} + - {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.3.1"} + - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"} + - {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.3.1"} steps: - name: Checkout uses: actions/checkout@v2 @@ -32,19 +35,29 @@ jobs: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} dockerfile: dockers/release/Dockerfile - build_args: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }} - tags: "${{ steps.get_version.outputs.RELEASE_VERSION }}-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }},latest-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}" + build_args: | + PYTHON_VERSION=${{ matrix.python_version }} + PYTORCH_VERSION=${{ matrix.pytorch_version }} + CUDA_VERSION=${{ matrix.cuda_version }} + LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }} + tags: | + ${{ steps.get_version.outputs.RELEASE_VERSION }}-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }} + latest-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }} timeout-minutes: 55 - name: Publish Latest to Docker uses: docker/build-push-action@v1.1.0 - # only on releases and latest Python and PyTorch - if: matrix.python_version == '3.9' && matrix.pytorch_version == '1.10' + # Only latest Python and PyTorch + if: matrix.python_version == '3.9' && matrix.pytorch_version == '1.12' with: repository: pytorchlightning/pytorch_lightning username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} dockerfile: dockers/release/Dockerfile - build_args: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }} + build_args: | + PYTHON_VERSION=${{ matrix.python_version }} + PYTORCH_VERSION=${{ matrix.pytorch_version }} + CUDA_VERSION=${{ matrix.cuda_version }} + LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }} tags: "latest" timeout-minutes: 55 diff --git a/.gitignore b/.gitignore index 719f291a492ca..259d9f271189c 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,9 @@ hars* artifacts/* *docs/examples* *docs/source-app/api* + +# tutorials +our_model.tar +test.png +saved_models +data/ diff --git a/README.md b/README.md index 2fef343425f17..6f075f5fd42b6 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ +### \*\* NEWS: PyTorch Lightning has been renamed Lightning! In addition to building models, you can now build research workflows and production pipelines\*\* +
-**Build high-performance PyTorch models and deploy them with Lightning Apps (scalable end-to-end ML systems).** +**Build high-performance (PyTorch) models, research workflows, ML production pipelines.** ______________________________________________________________________ @@ -80,21 +82,24 @@ ______________________________________________________________________ ## Continuous Integration -Lightning is rigorously tested across multiple GPUs, TPUs CPUs and against major Python and PyTorch versions. +Lightning is rigorously tested across multiple CPUs, GPUs, TPUs, IPUs, and HPUs and against major Python and PyTorch versions.
Current build statuses
-| System / PyTorch ver. | 1.8 (LTS, min. req.) | 1.9 | 1.10 (latest) | -| :------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| Linux py3.7 \[GPUs\*\*\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=6&branchName=master) | - | - | -| Linux py3.7 \[TPUs\*\*\*\] | [![CircleCI](https://circleci.com/gh/Lightning-AI/lightning/tree/master.svg?style=svg)](https://circleci.com/gh/Lightning-AI/lightning/tree/master) | - | - | -| Linux py3.8 (with Conda | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) | -| Linux py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) | -| OSX py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) | -| Windows py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) | +| System / PyTorch ver. | 1.9 | 1.10 | 1.12 (latest) | +| :------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Linux py3.7 \[GPUs\*\*\] | - | - | - | +| Linux py3.7 \[TPUs\*\*\*\] | [![CircleCI](https://circleci.com/gh/Lightning-AI/lightning/tree/master.svg?style=svg)](https://circleci.com/gh/Lightning-AI/lightning/tree/master) | - | - | +| Linux py3.8 \[IPUs\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=25&branchName=master) | - | - | +| Linux py3.8 \[HPUs\] | - | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=26&branchName=master) | - | +| Linux py3.8 (with Conda) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) | - | +| Linux py3.9 (with Conda) | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) | +| Linux py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) | +| OSX py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) | +| Windows py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) | - _\*\* tests run on two NVIDIA P100_ - _\*\*\* tests run on Google GKE TPUv2/3. TPU py3.7 means we support Colab and Kaggle env._ @@ -136,8 +141,8 @@ conda install pytorch-lightning -c conda-forge The actual status of 1.7 \[stable\] is the following: -[![Test PyTorch full](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-full.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-full.yml?query=branch%3Arelease%2Fpytorch) -[![Test PyTorch with Conda](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-conda.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-conda.yml?query=branch%3Arelease%2Fpytorch) +[![Test PyTorch full](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml?query=branch%3Arelease%2Fpytorch) +[![Test PyTorch with Conda](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml?query=branch%3Arelease%2Fpytorch) [![TPU tests](https://dl.circleci.com/status-badge/img/gh/Lightning-AI/lightning/tree/release%2Fpytorch.svg?style=shield)](https://dl.circleci.com/status-badge/redirect/gh/Lightning-AI/lightning/tree/release%2Fpytorch) [![Check Docs](https://github.com/Lightning-AI/lightning/actions/workflows/docs-checks.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/docs-checks.yml?query=branch%3Arelease%2Fpytorch) diff --git a/SECURITY.md b/SECURITY.md index 8f265f26be452..862563f84e2fe 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,2 +1,2 @@ -developer@grid.ai +developer@lightning.ai developer@pytorchlightning.ai diff --git a/dockers/README.md b/dockers/README.md index 533c85739f528..b1ff9826b6c1f 100644 --- a/dockers/README.md +++ b/dockers/README.md @@ -1,36 +1,17 @@ # Docker images -## Builds images form attached Dockerfiles +## Build images from Dockerfiles You can build it on your own, note it takes lots of time, be prepared. ```bash -git clone -docker image build -t pytorch-lightning:latest -f dockers/conda/Dockerfile . -``` - -or with specific arguments - -```bash -git clone -docker image build \ - -t pytorch-lightning:base-cuda-py3.9-pt1.10 \ - -f dockers/base-cuda/Dockerfile \ - --build-arg PYTHON_VERSION=3.9 \ - --build-arg PYTORCH_VERSION=1.10 \ - . -``` +git clone https://github.com/Lightning-AI/lightning.git -or nightly version from Conda +# build with the default arguments +docker image build -t pytorch-lightning:latest -f dockers/base-cuda/Dockerfile . -```bash -git clone -docker image build \ - -t pytorch-lightning:base-conda-py3.9-pt1.11 \ - -f dockers/base-conda/Dockerfile \ - --build-arg PYTHON_VERSION=3.9 \ - --build-arg PYTORCH_VERSION=1.11 \ - . +# build with specific arguments +docker image build -t pytorch-lightning:base-cuda-py3.9-torch1.11-cuda11.3.1 -f dockers/base-cuda/Dockerfile --build-arg PYTHON_VERSION=3.9 --build-arg PYTORCH_VERSION=1.11 --build-arg CUDA_VERSION=11.3.1 . ``` To run your docker use @@ -49,7 +30,7 @@ docker image rm pytorch-lightning:latest ## Run docker image with GPUs -To run docker image with access to you GPUs you need to install +To run docker image with access to your GPUs, you need to install ```bash # Add the package repositories @@ -61,10 +42,10 @@ sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker ``` -and later run the docker image with `--gpus all` so for example +and later run the docker image with `--gpus all`. For example, ``` -docker run --rm -it --gpus all pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.10 +docker run --rm -it --gpus all pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.11-cuda11.3.1 ``` ## Run Jupyter server @@ -73,15 +54,11 @@ Inspiration comes from https://u.group/thinking/how-to-put-jupyter-notebooks-in- 1. Build the docker image: ```bash - docker image build \ - -t pytorch-lightning:v1.3.1 \ - -f dockers/nvidia/Dockerfile \ - --build-arg LIGHTNING_VERSION=1.3.1 \ - . + docker image build -t pytorch-lightning:v1.6.5 -f dockers/nvidia/Dockerfile --build-arg LIGHTNING_VERSION=1.6.5 . ``` 1. start the server and map ports: ```bash - docker run --rm -it --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all -p 8888:8888 pytorch-lightning:v1.3.1 + docker run --rm -it --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all -p 8888:8888 pytorch-lightning:v1.6.5 ``` 1. Connect in local browser: - copy the generated path e.g. `http://hostname:8888/?token=0719fa7e1729778b0cec363541a608d5003e26d4910983c6` diff --git a/dockers/release/Dockerfile b/dockers/release/Dockerfile index cb393c91dfbe0..c39e66509188c 100644 --- a/dockers/release/Dockerfile +++ b/dockers/release/Dockerfile @@ -14,8 +14,9 @@ ARG PYTHON_VERSION=3.9 ARG PYTORCH_VERSION=1.11 +ARG CUDA_VERSION=11.3.1 -FROM pytorchlightning/pytorch_lightning:base-cuda-py${PYTHON_VERSION}-torch${PYTORCH_VERSION} +FROM pytorchlightning/pytorch_lightning:base-cuda-py${PYTHON_VERSION}-torch${PYTORCH_VERSION}-cuda${CUDA_VERSION} LABEL maintainer="Lightning-AI " diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index db4fc1e2c4cf8..8daed5ddcaf41 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -47,6 +47,20 @@ callbacks Timer TQDMProgressBar +cli +----- + +.. currentmodule:: pytorch_lightning.cli + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + LightningCLI + LightningArgumentParser + SaveConfigCallback + core ---- diff --git a/examples/app_commands/.lightning b/examples/app_commands_and_api/.lightning similarity index 100% rename from examples/app_commands/.lightning rename to examples/app_commands_and_api/.lightning diff --git a/examples/app_commands/app.py b/examples/app_commands_and_api/app.py similarity index 56% rename from examples/app_commands/app.py rename to examples/app_commands_and_api/app.py index 99eb15c75c709..0d15bc531bb38 100644 --- a/examples/app_commands/app.py +++ b/examples/app_commands_and_api/app.py @@ -1,15 +1,16 @@ from command import CustomCommand, CustomConfig from lightning import LightningFlow +from lightning_app.api import Post from lightning_app.core.app import LightningApp class ChildFlow(LightningFlow): - def trigger_method(self, name: str): + def nested_command(self, name: str): print(f"Hello {name}") def configure_commands(self): - return [{"nested_trigger_command": self.trigger_method}] + return [{"nested_command": self.nested_command}] class FlowCommands(LightningFlow): @@ -19,21 +20,24 @@ def __init__(self): self.child_flow = ChildFlow() def run(self): - if len(self.names): + if self.names: print(self.names) - def trigger_without_client_command(self, name: str): + def command_without_client(self, name: str): self.names.append(name) - def trigger_with_client_command(self, config: CustomConfig): + def command_with_client(self, config: CustomConfig): self.names.append(config.name) def configure_commands(self): commands = [ - {"trigger_without_client_command": self.trigger_without_client_command}, - {"trigger_with_client_command": CustomCommand(self.trigger_with_client_command)}, + {"command_without_client": self.command_without_client}, + {"command_with_client": CustomCommand(self.command_with_client)}, ] return commands + self.child_flow.configure_commands() + def configure_api(self): + return [Post("/user/command_without_client", self.command_without_client)] + app = LightningApp(FlowCommands()) diff --git a/examples/app_commands/command.py b/examples/app_commands_and_api/command.py similarity index 100% rename from examples/app_commands/command.py rename to examples/app_commands_and_api/command.py diff --git a/pyproject.toml b/pyproject.toml index 5473e73c52e19..9f7cc28d0b002 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,19 +52,13 @@ module = [ "pytorch_lightning.callbacks.progress.rich_progress", "pytorch_lightning.callbacks.quantization", "pytorch_lightning.core.datamodule", - "pytorch_lightning.core.decorators", "pytorch_lightning.core.module", - "pytorch_lightning.core.saving", "pytorch_lightning.demos.boring_classes", "pytorch_lightning.demos.mnist_datamodule", "pytorch_lightning.profilers.base", "pytorch_lightning.profilers.pytorch", - "pytorch_lightning.profilers.simple", - "pytorch_lightning.strategies.ddp", "pytorch_lightning.strategies.sharded", - "pytorch_lightning.strategies.sharded_spawn", "pytorch_lightning.trainer.callback_hook", - "pytorch_lightning.trainer.connectors.callback_connector", "pytorch_lightning.trainer.connectors.data_connector", "pytorch_lightning.trainer.supporters", "pytorch_lightning.trainer.trainer", diff --git a/requirements/app/base.txt b/requirements/app/base.txt index 0a0b9cdb4719d..fcde2f18a300a 100644 --- a/requirements/app/base.txt +++ b/requirements/app/base.txt @@ -1,9 +1,8 @@ -py -lightning-cloud==0.5.0 +lightning-cloud==0.5.3 packaging -deepdiff>=5.7.0 +deepdiff>=5.7.0, <=5.8.1 starsessions -fsspec>=2022.01.0 -s3fs>=2022.1.0 +fsspec>=2022.01.0, <=2022.7.1 +s3fs>=2022.1.0, <=2022.7.1 croniter # for now until we found something more robust. traitlets<5.2.0 # Traitlets 5.2.X fails: https://github.com/ipython/traitlets/issues/741 diff --git a/requirements/app/cloud.txt b/requirements/app/cloud.txt index 5f8bf0c48692f..6644a56a2894b 100644 --- a/requirements/app/cloud.txt +++ b/requirements/app/cloud.txt @@ -1,5 +1,4 @@ starsessions redis>=4.0.0, <=4.2.4 -docker==5.0.3 -setuptools==59.5.0 -s3fs==2022.1.0 +docker>=5.0.0, <=5.0.3 +# setuptools==59.5.0 diff --git a/requirements/app/docs.txt b/requirements/app/docs.txt index b35cc585b40c7..bf22aef2c2d92 100644 --- a/requirements/app/docs.txt +++ b/requirements/app/docs.txt @@ -1,18 +1,17 @@ sphinx>=4.0,<5.0 -myst-parser>=0.15 -nbsphinx>=0.8.5 +myst-parser>=0.15,<0.17 +nbsphinx>=0.8.5, <=0.8.9 ipython[notebook] ipython_genutils -pandoc>=1.0 -docutils>=0.16 -sphinxcontrib-fulltoc>=1.0 +pandoc>=1.0, <=2.2 +docutils>=0.16, <0.19 +sphinxcontrib-fulltoc>=1.0, <=1.2.0 sphinxcontrib-mockautodoc https://storage.googleapis.com/grid-packages/lightning-ai-sphinx-theme/build-31.3.zip sphinx-autodoc-typehints>=1.0,<1.15 # v1.15 failing on master (#11405) -sphinx-paramlinks>=0.5.1 -sphinx-togglebutton>=0.2 -sphinx-copybutton>=0.3 +sphinx-paramlinks>=0.5.1, <=0.5.4 +sphinx-togglebutton>=0.2, <=0.3.2 +sphinx-copybutton>=0.3, <=0.5.0 sphinx-autobuild -typing-extensions # already in `requirements.txt` but the docs CI job does not install it jinja2>=3.0.0,<3.1.0 diff --git a/requirements/app/test.txt b/requirements/app/test.txt index 9d2ed0af910ca..ab5ef8f1e85ac 100644 --- a/requirements/app/test.txt +++ b/requirements/app/test.txt @@ -1,15 +1,10 @@ -coverage>=5.0 -codecov>=2.1 -pytest>=5.0 -pytest-timeout -pytest-cov +coverage>=6.4, <=6.4.2 +codecov>=2.1, <=2.1.12 +pytest>=7.0, <=7.1.2 +pytest-timeout <=2.1.0 +pytest-cov <=3.0.0 playwright==1.22.0 # pytest-flake8 -flake8>=3.0 -check-manifest -twine>=3.2 -isort>=5.0 -mypy>=0.720 httpx trio pympler diff --git a/requirements/app/ui.txt b/requirements/app/ui.txt index 28df7f9c2ffe0..f0e4b2cdef471 100644 --- a/requirements/app/ui.txt +++ b/requirements/app/ui.txt @@ -1 +1 @@ -streamlit>=1.3.1 +streamlit>=1.3.1, <=1.11.1 diff --git a/requirements/collect_env_details.py b/requirements/collect_env_details.py index 1d65753a55553..b0c47efc43859 100644 --- a/requirements/collect_env_details.py +++ b/requirements/collect_env_details.py @@ -20,27 +20,17 @@ import platform import sys -import numpy +import pkg_resources import torch -import tqdm sys.path += [os.path.abspath(".."), os.path.abspath("")] -import pytorch_lightning # noqa: E402 -try: - import lightning -except ModuleNotFoundError: - pass -try: - import lightning_app -except ModuleNotFoundError: - pass LEVEL_OFFSET = "\t" KEY_PADDING = 20 -def info_system(): +def info_system() -> dict: return { "OS": platform.system(), "architecture": platform.architecture(), @@ -50,28 +40,24 @@ def info_system(): } -def info_cuda(): +def info_cuda() -> dict: return { - "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())], - # 'nvidia_driver': get_nvidia_driver_version(run_lambda), + "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] or None, "available": torch.cuda.is_available(), "version": torch.version.cuda, } -def info_packages(): - return { - "numpy": numpy.__version__, - "pyTorch_version": torch.__version__, - "pyTorch_debug": torch.version.debug, - "pytorch-lightning": pytorch_lightning.__version__, - "lightning": lightning.__version__ if "lightning" in sys.modules else None, - "lightning_app": lightning_app.__version__ if "lightning_app" in sys.modules else None, - "tqdm": tqdm.__version__, - } +def info_packages() -> dict: + """Get name and version of all installed packages.""" + packages = {} + for dist in pkg_resources.working_set: + package = dist.as_requirement() + packages[package.key] = package.specs[0][1] + return packages -def nice_print(details, level=0): +def nice_print(details: dict, level: int = 0) -> list: lines = [] for k in sorted(details): key = f"* {k}:" if level == 0 else f"- {k}:" @@ -88,8 +74,9 @@ def nice_print(details, level=0): return lines -def main(): +def main() -> None: details = {"System": info_system(), "CUDA": info_cuda(), "Packages": info_packages()} + details["Lightning"] = {k: v for k, v in details["Packages"].items() if "torch" in k or "lightning" in k} lines = nice_print(details) text = os.linesep.join(lines) print(text) diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index e8743b18c73b0..49e2243319206 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -3,7 +3,7 @@ numpy>=1.17.2, <1.23.1 torch>=1.9.*, <=1.12.0 -tqdm>=4.57.0, <=4.63.0 +tqdm>=4.57.0, <4.65.0 PyYAML>=5.4, <=6.0 fsspec[http]>=2021.05.0, !=2021.06.0, <2022.6.0 tensorboard>=2.9.1, <2.10.0 diff --git a/requirements/pytorch/docs.txt b/requirements/pytorch/docs.txt index e6fbbe322b6bf..50e7c2049f6f6 100644 --- a/requirements/pytorch/docs.txt +++ b/requirements/pytorch/docs.txt @@ -1,16 +1,16 @@ sphinx>=4.0,<5.0 myst-parser>=0.15,<0.17 -nbsphinx>=0.8.5 +nbsphinx>=0.8.5, <=0.8.9 ipython[notebook] -pandoc>=1.0 -docutils>=0.16 -sphinxcontrib-fulltoc>=1.0 +pandoc>=1.0, <=2.2 +docutils>=0.16, <0.19 +sphinxcontrib-fulltoc>=1.0, <=1.2.0 sphinxcontrib-mockautodoc pt-lightning-sphinx-theme @ https://github.com/Lightning-AI/lightning_sphinx_theme/archive/master.zip -sphinx-autodoc-typehints>=1.11,<1.15 # v1.15 failing on master (#11405) -sphinx-paramlinks>=0.5.1 -sphinx-togglebutton>=0.2 -sphinx-copybutton>=0.3 +sphinx-autodoc-typehints>=1.11,<1.15 # strict; v1.15 failing on master (#11405) +sphinx-paramlinks>=0.5.1, <=0.5.4 +sphinx-togglebutton>=0.2, <=0.3.2 +sphinx-copybutton>=0.3, <=0.5.0 typing-extensions # already in `requirements.txt` but the docs CI job does not install it jinja2>=3.0.0,<3.1.0 diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index c386c5581cc42..20b6c1b8dbc12 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -7,5 +7,5 @@ torchtext>=0.10.*, <0.14.0 omegaconf>=2.0.5, <2.3.0 hydra-core>=1.0.5, <1.3.0 jsonargparse[signatures]>=4.12.0, <=4.12.0 -gcsfs>=2021.5.0, <2022.6.0 +gcsfs>=2021.5.0, <2022.8.0 rich>=10.14.0, !=10.15.0.a, <13.0.0 diff --git a/requirements/pytorch/loggers.txt b/requirements/pytorch/loggers.txt index 48a15c30f842f..df83a077f8457 100644 --- a/requirements/pytorch/loggers.txt +++ b/requirements/pytorch/loggers.txt @@ -7,4 +7,4 @@ neptune-client>=0.10.0, <0.16.4 comet-ml>=3.1.12, <3.31.8 mlflow>=1.0.0, <1.28.0 test_tube>=0.7.5, <=0.7.5 -wandb>=0.10.22, <0.12.20 +wandb>=0.10.22, <0.13.2 diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt index 4e916fbc6c61f..c5fc92a67a837 100644 --- a/requirements/pytorch/strategies.txt +++ b/requirements/pytorch/strategies.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment fairscale>=0.4.5, <=0.4.6 -deepspeed>=0.6.0, <0.7.0 +deepspeed>=0.6.0, <=0.7.0 # no need to install with [pytorch] as pytorch is already installed horovod>=0.21.2, !=0.24.0, <0.25.1 hivemind>=1.0.1, <=1.0.1; sys_platform == 'linux' diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index ce54cd087b1de..f8bd5793a0af6 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -1,18 +1,17 @@ -coverage>=6.4 -codecov>=2.1 -pytest>=7.0 -pytest-cov -pytest-forked +coverage>=6.4, <=6.4.2 +codecov>=2.1, <=2.1.12 +pytest>=7.0, <=7.1.2 +pytest-cov <=3.0.0 +pytest-forked <=1.4.0 pytest-rerunfailures>=10.2 -mypy>=0.920 -flake8>=3.9.2 pre-commit>=1.0 +mypy==0.971 # needed in tests -cloudpickle>=1.3 -scikit-learn>0.22.1 -onnxruntime -psutil # for `DeviceStatsMonitor` -pandas # needed in benchmarks -fastapi -uvicorn +cloudpickle>=1.3, <=2.1.0 +scikit-learn>0.22.1, <=1.1.1 +onnxruntime<1.13.0 +psutil<=5.9.1 # for `DeviceStatsMonitor` +pandas>1.0, <=1.4.3 # needed in benchmarks +fastapi<=0.79.0 +uvicorn<=0.18.2 diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 07927a1b01f87..2aa5c7cdd837c 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -9,21 +9,87 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602)) + + - Add support for Lightning AI BYOC cluster management ([#13835](https://github.com/Lightning-AI/lightning/pull/13835)) + + - Add support to run Lightning apps on Lightning AI BYOC clusters ([#13894](https://github.com/Lightning-AI/lightning/pull/13894)) + + - Add support for listing Lightning AI apps ([#13987](https://github.com/Lightning-AI/lightning/pull/13987)) + + - Adds `LightningTrainingComponent`. `LightningTrainingComponent` orchestrates multi-node training in the cloud ([#13830](https://github.com/Lightning-AI/lightning/pull/13830)) +- Add support for printing application logs using CLI `lightning show logs [components]` ([#13634](https://github.com/Lightning-AI/lightning/pull/13634)) + + +- Add support for `Lightning API` through the `configure_api` hook on the Lightning Flow and the `Post`, `Get`, `Delete`, `Put` HttpMethods ([#13945](https://github.com/Lightning-AI/lightning/pull/13945)) ### Changed -- Update the Lightning App docs ([#13537](https://github.com/Lightning-AI/lightning/pull/13537)) +- Default values and parameter names for Lightning AI BYOC cluster management ([#14132](https://github.com/Lightning-AI/lightning/pull/14132)) + ### Changed -- Added `LIGHTNING_` prefix to Platform AWS credentials ([#13703](https://github.com/Lightning-AI/lightning/pull/13703)) +- + + +- Run the flow only if the state has changed from the previous execution ([#14076](https://github.com/Lightning-AI/lightning/pull/14076)) + +### Deprecated + +- + + +### Fixed + +- + + +## [0.5.5] - 2022-08-9 ### Deprecated +- Deprecate sheety API ([#14004](https://github.com/Lightning-AI/lightning/pull/14004)) + ### Fixed - Resolved a bug where the work statuses will grow quickly and be duplicated ([#13970](https://github.com/Lightning-AI/lightning/pull/13970)) +- Resolved a bug about a race condition when sending the work state through the caller_queue ([#14074](https://github.com/Lightning-AI/lightning/pull/14074)) +- Fixed Start Lightning App on Cloud if Repo Begins With Name "Lightning" ([#14025](https://github.com/Lightning-AI/lightning/pull/14025)) + + +## [0.5.4] - 2022-08-01 + +### Changed + +- Wrapped imports for traceability ([#13924](https://github.com/Lightning-AI/lightning/pull/13924)) +- Set version as today ([#13906](https://github.com/Lightning-AI/lightning/pull/13906)) + +### Fixed + +- Included app templates to the lightning and app packages ([#13731](https://github.com/Lightning-AI/lightning/pull/13731)) +- Added UI for install all ([#13732](https://github.com/Lightning-AI/lightning/pull/13732)) +- Fixed build meta pkg flow ([#13926](https://github.com/Lightning-AI/lightning/pull/13926)) + +## [0.5.3] - 2022-07-25 + +### Changed + +- Pruned requirements duplicity ([#13739](https://github.com/Lightning-AI/lightning/pull/13739)) + +### Fixed + +- Use correct python version in lightning component template ([#13790](https://github.com/Lightning-AI/lightning/pull/13790)) + +## [0.5.2] - 2022-07-18 + +### Added + +- Update the Lightning App docs ([#13537](https://github.com/Lightning-AI/lightning/pull/13537)) + +### Changed + +- Added `LIGHTNING_` prefix to Platform AWS credentials ([#13703](https://github.com/Lightning-AI/lightning/pull/13703)) diff --git a/src/lightning_app/api/__init__.py b/src/lightning_app/api/__init__.py new file mode 100644 index 0000000000000..25ec5c4708761 --- /dev/null +++ b/src/lightning_app/api/__init__.py @@ -0,0 +1,3 @@ +from lightning_app.api.http_methods import Delete, Get, Post, Put + +__all__ = ["Delete", "Get", "Post", "Put"] diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py new file mode 100644 index 0000000000000..02b6ec87f17d2 --- /dev/null +++ b/src/lightning_app/api/http_methods.py @@ -0,0 +1,107 @@ +import asyncio +import inspect +import time +from copy import deepcopy +from functools import wraps +from multiprocessing import Queue +from typing import Any, Callable, Dict, List, Optional +from uuid import uuid4 + +from fastapi import FastAPI + +from lightning_app.api.request_types import APIRequest, CommandRequest + + +def _signature_proxy_function(): + pass + + +class HttpMethod: + def __init__(self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs): + """This class is used to inject user defined methods within the App Rest API. + + Arguments: + route: The path used to route the requests + method: The associated flow method + timeout: The time in seconds taken before raising a timeout exception. + """ + self.route = route + self.component_name = method.__self__.name + self.method_name = method_name or method.__name__ + self.method_annotations = method.__annotations__ + # TODO: Validate the signature contains only pydantic models. + self.method_signature = inspect.signature(method) + self.timeout = timeout + self.kwargs = kwargs + + def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None: + # 1: Create a proxy function with the signature of the wrapped method. + fn = deepcopy(_signature_proxy_function) + fn.__annotations__ = self.method_annotations + fn.__name__ = self.method_name + setattr(fn, "__signature__", self.method_signature) + + # 2: Get the route associated with the http method. + route = getattr(app, self.__class__.__name__.lower()) + + request_cls = CommandRequest if self.route.startswith("/command/") else APIRequest + + # 3: Define the request handler. + @wraps(_signature_proxy_function) + async def _handle_request(*args, **kwargs): + async def fn(*args, **kwargs): + request_id = str(uuid4()).split("-")[0] + request_queue.put( + request_cls( + name=self.component_name, + method_name=self.method_name, + args=args, + kwargs=kwargs, + id=request_id, + ) + ) + + t0 = time.time() + while request_id not in responses_store: + await asyncio.sleep(0.1) + if (time.time() - t0) > self.timeout: + raise Exception("The response was never received.") + + return responses_store.pop(request_id) + + return await asyncio.create_task(fn(*args, **kwargs)) + + # 4: Register the user provided route to the Rest API. + route(self.route, **self.kwargs)(_handle_request) + + +class Post(HttpMethod): + pass + + +class Get(HttpMethod): + + pass + + +class Put(HttpMethod): + + pass + + +class Delete(HttpMethod): + pass + + +def _add_tags_to_api(apis: List[HttpMethod], tags: List[str]) -> None: + for api in apis: + if not api.kwargs.get("tag"): + api.kwargs["tags"] = tags + + +def _validate_api(apis: List[HttpMethod]) -> None: + for api in apis: + if not isinstance(api, HttpMethod): + raise Exception(f"The provided api should be either [{Delete}, {Get}, {Post}, {Put}]") + if api.route.startswith("/command"): + raise Exception("The route `/command` is reserved for commands. Please, use something else.") diff --git a/src/lightning_app/api/request_types.py b/src/lightning_app/api/request_types.py new file mode 100644 index 0000000000000..53a6df25820a3 --- /dev/null +++ b/src/lightning_app/api/request_types.py @@ -0,0 +1,36 @@ +from dataclasses import asdict, dataclass +from typing import Any + +from deepdiff import Delta + + +@dataclass +class BaseRequest: + def to_dict(self): + return asdict(self) + + +@dataclass +class DeltaRequest(BaseRequest): + delta: Delta + + def to_dict(self): + return self.delta.to_dict() + + +@dataclass +class CommandRequest(BaseRequest): + id: str + name: str + method_name: str + args: Any + kwargs: Any + + +@dataclass +class APIRequest(BaseRequest): + id: str + name: str + method_name: str + args: Any + kwargs: Any diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py index fb4c40330dfd9..81d2a773b4619 100644 --- a/src/lightning_app/cli/lightning_cli.py +++ b/src/lightning_app/cli/lightning_cli.py @@ -4,11 +4,12 @@ from argparse import ArgumentParser from pathlib import Path from typing import List, Tuple, Union -from uuid import uuid4 import click import requests +import rich from requests.exceptions import ConnectionError +from rich.color import ANSI_COLOR_NAMES from lightning_app import __version__ as ver from lightning_app.cli import cmd_init, cmd_install, cmd_pl_init, cmd_react_ui_init @@ -18,13 +19,16 @@ from lightning_app.core.constants import get_lightning_cloud_url, LOCAL_LAUNCH_ADMIN_VIEW from lightning_app.runners.runtime import dispatch from lightning_app.runners.runtime_type import RuntimeType +from lightning_app.utilities.app_logs import _app_logs_reader from lightning_app.utilities.cli_helpers import ( _format_input_env_variables, _retrieve_application_url_and_available_commands, ) +from lightning_app.utilities.cloud import _get_project +from lightning_app.utilities.enum import OpenAPITags from lightning_app.utilities.install_components import register_all_external_components from lightning_app.utilities.login import Auth -from lightning_app.utilities.state import headers_for +from lightning_app.utilities.network import LightningClient logger = logging.getLogger(__name__) @@ -50,12 +54,96 @@ def main(): @click.version_option(ver) def _main(): register_all_external_components() + + +@_main.group() +def show(): + """Show given resource.""" pass +@show.command() +@click.argument("app_name", required=False) +@click.argument("components", nargs=-1, required=False) +@click.option("-f", "--follow", required=False, is_flag=True, help="Wait for new logs, to exit use CTRL+C.") +def logs(app_name: str, components: List[str], follow: bool) -> None: + """Show cloud application logs. By default prints logs for all currently available components. + + Example uses: + + Print all application logs: + + $ lightning show logs my-application + + + Print logs only from the flow (no work): + + $ lightning show logs my-application flow + + + Print logs only from selected works: + + $ lightning show logs my-application root.work_a root.work_b + """ + + client = LightningClient() + project = _get_project(client) + + apps = { + app.name: app + for app in client.lightningapp_instance_service_list_lightningapp_instances(project.project_id).lightningapps + } + + if not apps: + raise click.ClickException( + "You don't have any application in the cloud. Please, run an application first with `--cloud`." + ) + + if not app_name: + raise click.ClickException( + f"You have not specified any Lightning App. Please select one of available: [{', '.join(apps.keys())}]" + ) + + if app_name not in apps: + raise click.ClickException( + f"The Lightning App '{app_name}' does not exist. Please select one of following: [{', '.join(apps.keys())}]" + ) + + # Fetch all lightning works from given application + # 'Flow' component is somewhat implicit, only one for whole app, + # and not listed in lightningwork API - so we add it directly to the list + works = client.lightningwork_service_list_lightningwork( + project_id=project.project_id, app_id=apps[app_name].id + ).lightningworks + app_component_names = ["flow"] + [f.name for f in apps[app_name].spec.flow_servers] + [w.name for w in works] + + if not components: + components = app_component_names + + for component in components: + if component not in app_component_names: + raise click.ClickException(f"Component '{component}' does not exist in app {app_name}.") + + log_reader = _app_logs_reader( + client=client, + project_id=project.project_id, + app_id=apps[app_name].id, + component_names=components, + follow=follow, + ) + + rich_colors = list(ANSI_COLOR_NAMES) + colors = {c: rich_colors[i + 1] for i, c in enumerate(components)} + + for log_event in log_reader: + date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S") + color = colors[log_event.component_name] + rich.print(f"[{color}]{log_event.component_name}[/{color}] {date} {log_event.message}") + + @_main.command() def login(): - """Log in to your Lightning.ai account.""" + """Log in to your lightning.ai account.""" auth = Auth() auth.clear() @@ -68,7 +156,7 @@ def login(): @_main.command() def logout(): - """Log out of your Lightning.ai account.""" + """Log out of your lightning.ai account.""" Auth().clear() @@ -127,7 +215,7 @@ def on_before_run(*args): @_main.group() def run(): - """Run your application.""" + """Run a Lightning application locally or on the cloud.""" @run.command("app") @@ -174,41 +262,42 @@ def app_command(): hparams, argv = parser.parse_known_args() # 1: Collect the url and comments from the running application - url, commands = _retrieve_application_url_and_available_commands(hparams.app_id) - if url is None or commands is None: + url, api_commands = _retrieve_application_url_and_available_commands(hparams.app_id) + if url is None or api_commands is None: raise Exception("We couldn't find any matching running app.") - if not commands: + if not api_commands: raise Exception("This application doesn't expose any commands yet.") command = argv[0] - command_names = [c["command"] for c in commands] - if command not in command_names: - raise Exception(f"The provided command {command} isn't available in {command_names}") + if command not in api_commands: + raise Exception(f"The provided command {command} isn't available in {list(api_commands)}") # 2: Send the command from the user - command_metadata = [c for c in commands if c["command"] == command][0] - params = command_metadata["params"] + metadata = api_commands[command] # 3: Execute the command - if not command_metadata["is_client_command"]: - # TODO: Improve what is supported there. - kwargs = {k.split("=")[0].replace("--", ""): k.split("=")[1] for k in argv[1:]} - for param in params: - if param not in kwargs: - raise Exception(f"The argument --{param}=X hasn't been provided.") - json = { - "command_name": command, - "command_arguments": kwargs, - "affiliation": command_metadata["affiliation"], - "id": str(uuid4()), - } - resp = requests.post(url + "/api/v1/commands", json=json, headers=headers_for({})) + if metadata["tag"] == OpenAPITags.APP_COMMAND: + # TODO: Improve what is current supported + kwargs = [v.replace("--", "") for v in argv[1:]] + + for p in kwargs: + if p.split("=")[0] not in metadata["parameters"]: + raise Exception(f"Some arguments need to be provided. The keys are {list(metadata['parameters'])}.") + # TODO: Encode the parameters and validate their type. + query_parameters = "&".join(kwargs) + resp = requests.post(url + f"/command/{command}?{query_parameters}") assert resp.status_code == 200, resp.json() else: - client_command, models = _download_command(command_metadata, hparams.app_id, debug_mode=debug_mode) - client_command._setup(metadata=command_metadata, models=models, app_url=url) + client_command = _download_command( + command, + metadata["cls_path"], + metadata["cls_name"], + hparams.app_id, + debug_mode=debug_mode, + ) + client_command._setup(command_name=command, app_url=url) sys.argv = argv client_command.run() @@ -232,7 +321,7 @@ def stop(): @_main.group() def install(): - """Install Lightning apps and components.""" + """Install a Lightning App and/or component.""" @install.command("app") @@ -290,7 +379,7 @@ def install_component(name, yes, version): @_main.group() def init(): - """Init a Lightning app and component.""" + """Init a Lightning App and/or component.""" @init.command("app") diff --git a/src/lightning_app/cli/lightning_cli_create.py b/src/lightning_app/cli/lightning_cli_create.py index 7e45fe7e7c078..7e9a6b9d2143b 100644 --- a/src/lightning_app/cli/lightning_cli_create.py +++ b/src/lightning_app/cli/lightning_cli_create.py @@ -5,7 +5,7 @@ @click.group("create") def create(): - """Create Lightning AI BYOC managed resources.""" + """Create Lightning AI self-managed resources (clusters, etc…)""" pass @@ -33,14 +33,14 @@ def create(): help="Instance types that you want to support, for computer jobs within the cluster.", ) @click.option( - "--cost-savings", - "cost_savings", + "--enable-performance", + "enable_performance", type=bool, required=False, default=False, is_flag=True, - help=""""Use this flag to ensure that the cluster is created with a profile that is optimized for cost savings. - This makes runs cheaper but start-up times may increase.""", + help=""""Use this flag to ensure that the cluster is created with a profile that is optimized for performance. + This makes runs more expensive but start-up times decrease.""", ) @click.option( "--edit-before-creation", @@ -65,12 +65,12 @@ def create_cluster( provider: str, instance_types: str, edit_before_creation: bool, - cost_savings: bool, + enable_performance: bool, wait: bool, **kwargs, ): """Create a Lightning AI BYOC compute cluster with your cloud provider credentials.""" - if provider != "aws": + if provider.lower() != "aws": click.echo("Only AWS is supported for now. But support for more providers is coming soon.") return cluster_manager = AWSClusterManager() @@ -79,8 +79,8 @@ def create_cluster( region=region, role_arn=role_arn, external_id=external_id, - instance_types=instance_types.split(","), + instance_types=instance_types.split(",") if instance_types is not None else None, edit_before_creation=edit_before_creation, - cost_savings=cost_savings, + cost_savings=not enable_performance, wait=wait, ) diff --git a/src/lightning_app/cli/lightning_cli_delete.py b/src/lightning_app/cli/lightning_cli_delete.py index c304b130bdf5d..366f4aa01e995 100644 --- a/src/lightning_app/cli/lightning_cli_delete.py +++ b/src/lightning_app/cli/lightning_cli_delete.py @@ -5,7 +5,7 @@ @click.group("delete") def delete(): - """Delete Lightning AI BYOC managed resources.""" + """Delete Lightning AI self-managed resources (clusters, etc…)""" pass diff --git a/src/lightning_app/cli/lightning_cli_list.py b/src/lightning_app/cli/lightning_cli_list.py index d0d1d34a6dd4d..7d38b5b57760f 100644 --- a/src/lightning_app/cli/lightning_cli_list.py +++ b/src/lightning_app/cli/lightning_cli_list.py @@ -6,7 +6,7 @@ @click.group(name="list") def get_list(): - """List your Lightning AI BYOC managed resources.""" + """List Lightning AI self-managed resources (clusters, etc…)""" pass diff --git a/src/lightning_app/core/api.py b/src/lightning_app/core/api.py index f19ada5340d57..8b625713e0c2c 100644 --- a/src/lightning_app/core/api.py +++ b/src/lightning_app/core/api.py @@ -3,7 +3,6 @@ import os import queue import sys -import time import traceback from copy import deepcopy from multiprocessing import Queue @@ -21,9 +20,12 @@ from pydantic import BaseModel from websockets.exceptions import ConnectionClosed +from lightning_app.api.http_methods import HttpMethod +from lightning_app.api.request_types import DeltaRequest from lightning_app.core.constants import FRONTEND_DIR from lightning_app.core.queues import RedisQueue from lightning_app.utilities.app_helpers import InMemoryStateStore, StateStore +from lightning_app.utilities.enum import OpenAPITags from lightning_app.utilities.imports import _is_redis_available, _is_starsessions_available if _is_starsessions_available(): @@ -42,9 +44,6 @@ class SessionMiddleware: frontend_static_dir = os.path.join(FRONTEND_DIR, "static") api_app_delta_queue: Queue = None -api_commands_requests_queue: Queue = None -api_commands_metadata_queue: Queue = None -api_commands_responses_queue: Queue = None template = {"ui": {}, "app": {}} templates = Jinja2Templates(directory=FRONTEND_DIR) @@ -56,8 +55,8 @@ class SessionMiddleware: lock = Lock() app_spec: Optional[List] = None -app_commands_metadata: Optional[Dict] = None -commands_response_store = {} +# In the future, this would be abstracted to support horizontal scaling. +responses_store = {} logger = logging.getLogger(__name__) @@ -67,11 +66,10 @@ class SessionMiddleware: class UIRefresher(Thread): - def __init__(self, api_publish_state_queue, api_commands_metadata_queue, api_commands_responses_queue) -> None: + def __init__(self, api_publish_state_queue, api_response_queue) -> None: super().__init__(daemon=True) self.api_publish_state_queue = api_publish_state_queue - self.api_commands_metadata_queue = api_commands_metadata_queue - self.api_commands_responses_queue = api_commands_responses_queue + self.api_response_queue = api_response_queue self._exit_event = Event() def run(self): @@ -93,18 +91,11 @@ def run_once(self): pass try: - metadata = self.api_commands_metadata_queue.get(timeout=0) + response = self.api_response_queue.get(timeout=0) with lock: - global app_commands_metadata - app_commands_metadata = metadata - except queue.Empty: - pass - - try: - response = self.api_commands_responses_queue.get(timeout=0) - with lock: - global commands_response_store - commands_response_store[response["id"]] = response["response"] + # TODO: Abstract the responses store to support horizontal scaling. + global responses_store + responses_store[response["id"]] = response["response"] except queue.Empty: pass @@ -117,6 +108,23 @@ class StateUpdate(BaseModel): state: dict = {} +openapi_tags = [ + { + "name": OpenAPITags.APP_CLIENT_COMMAND, + "description": "The App Endpoints to be triggered exclusively from the CLI", + }, + { + "name": OpenAPITags.APP_COMMAND, + "description": "The App Endpoints that can be triggered equally from the CLI or from a Http Request", + }, + { + "name": OpenAPITags.APP_API, + "description": "The App Endpoints that can be triggered exclusively from a Http Request", + }, +] + +app = FastAPI(openapi_tags=openapi_tags) + fastapi_service = FastAPI() fastapi_service.add_middleware( @@ -176,50 +184,13 @@ async def get_spec( return app_spec or [] -@fastapi_service.post("/api/v1/commands", response_class=JSONResponse) -async def run_remote_command( - request: Request, -) -> None: - data = await request.json() - command_name = data.get("command_name", None) - if not command_name: - raise Exception("The provided command name is empty.") - command_arguments = data.get("command_arguments", None) - if not command_arguments: - raise Exception("The provided command metadata is empty.") - affiliation = data.get("affiliation", None) - if not affiliation: - raise Exception("The provided affiliation is empty.") - - async def fn(data): - request_id = data["id"] - api_commands_requests_queue.put(data) - - t0 = time.time() - while request_id not in commands_response_store: - await asyncio.sleep(0.1) - if (time.time() - t0) > 15: - raise Exception("The response was never received.") - - return commands_response_store[request_id] - - return await asyncio.create_task(fn(data)) - - -@fastapi_service.get("/api/v1/commands", response_class=JSONResponse) -async def get_commands() -> Optional[Dict]: - global app_commands_metadata - with lock: - return app_commands_metadata - - @fastapi_service.post("/api/v1/delta") async def post_delta( request: Request, x_lightning_type: Optional[str] = Header(None), x_lightning_session_uuid: Optional[str] = Header(None), x_lightning_session_id: Optional[str] = Header(None), -) -> Mapping: +) -> None: """This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to update the state.""" @@ -229,9 +200,7 @@ async def post_delta( raise Exception("Missing X-Lightning-Session-ID header") body: Dict = await request.json() - delta = body["delta"] - update_delta = Delta(delta) - api_app_delta_queue.put(update_delta) + api_app_delta_queue.put(DeltaRequest(delta=Delta(body["delta"]))) @fastapi_service.post("/api/v1/state") @@ -240,7 +209,7 @@ async def post_state( x_lightning_type: Optional[str] = Header(None), x_lightning_session_uuid: Optional[str] = Header(None), x_lightning_session_id: Optional[str] = Header(None), -) -> Mapping: +) -> None: if x_lightning_session_uuid is None: raise Exception("Missing X-Lightning-Session-UUID header") if x_lightning_session_id is None: @@ -263,8 +232,7 @@ async def post_state( state = body["state"] last_state = global_app_state_store.get_served_state(x_lightning_session_uuid) deep_diff = DeepDiff(last_state, state, verbose_level=2) - update_delta = Delta(deep_diff) - api_app_delta_queue.put(update_delta) + api_app_delta_queue.put(DeltaRequest(delta=Delta(deep_diff))) @fastapi_service.get("/healthz", status_code=200) @@ -307,8 +275,6 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() -# Catch-all for nonexistent API routes (since we define a catch-all for client-side routing) -@fastapi_service.get("/api{full_path:path}", response_class=JSONResponse) async def api_catch_all(request: Request, full_path: str): raise HTTPException(status_code=404, detail="Not found") @@ -317,14 +283,18 @@ async def api_catch_all(request: Request, full_path: str): fastapi_service.mount("/static", StaticFiles(directory=frontend_static_dir, check_dir=False), name="static") -# Catch-all for frontend routes, must be defined after all other routes -@fastapi_service.get("/{full_path:path}", response_class=HTMLResponse) async def frontend_route(request: Request, full_path: str): if "pytest" in sys.modules: return "" return templates.TemplateResponse("index.html", {"request": request}) +def register_global_routes(): + # Catch-all for nonexistent API routes (since we define a catch-all for client-side routing) + fastapi_service.get("/api{full_path:path}", response_class=JSONResponse)(api_catch_all) + fastapi_service.get("/{full_path:path}", response_class=HTMLResponse)(frontend_route) + + class LightningUvicornServer(uvicorn.Server): has_started_queue = None @@ -346,34 +316,28 @@ async def check_is_started(self, queue): def start_server( api_publish_state_queue, api_delta_queue, - commands_requests_queue, - commands_responses_queue, - commands_metadata_queue, + api_response_queue, has_started_queue: Optional[Queue] = None, host="127.0.0.1", port=8000, uvicorn_run: bool = True, spec: Optional[List] = None, + apis: Optional[List[HttpMethod]] = None, app_state_store: Optional[StateStore] = None, ): global api_app_delta_queue global global_app_state_store - global api_commands_requests_queue - global api_commands_responses_queue global app_spec app_spec = spec api_app_delta_queue = api_delta_queue - api_commands_requests_queue = commands_requests_queue - api_commands_responses_queue = commands_responses_queue - api_commands_metadata_queue = commands_metadata_queue if app_state_store is not None: global_app_state_store = app_state_store global_app_state_store.add(TEST_SESSION_UUID) - refresher = UIRefresher(api_publish_state_queue, api_commands_metadata_queue, commands_responses_queue) + refresher = UIRefresher(api_publish_state_queue, api_response_queue) refresher.setDaemon(True) refresher.start() @@ -384,6 +348,14 @@ def start_server( LightningUvicornServer.has_started_queue = has_started_queue # uvicorn is doing some uglyness by replacing uvicorn.main by click command. sys.modules["uvicorn.main"].Server = LightningUvicornServer + + # Register the user API. + if apis: + for api in apis: + api.add_route(fastapi_service, api_app_delta_queue, responses_store) + + register_global_routes() + uvicorn.run(app=fastapi_service, host=host, port=port, log_level="error") return refresher diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 584f94285c219..65242a1ae0a2a 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -11,12 +11,13 @@ from deepdiff import DeepDiff, Delta import lightning_app +from lightning_app.api.request_types import APIRequest, CommandRequest, DeltaRequest from lightning_app.core.constants import FLOW_DURATION_SAMPLES, FLOW_DURATION_THRESHOLD, STATE_ACCUMULATE_WAIT from lightning_app.core.queues import BaseQueue, SingleProcessQueue from lightning_app.frontend import Frontend from lightning_app.storage.path import storage_root_dir -from lightning_app.utilities.app_helpers import _delta_to_appstate_delta, _LightningAppRef -from lightning_app.utilities.commands.base import _populate_commands_endpoint, _process_command_requests +from lightning_app.utilities.app_helpers import _delta_to_app_state_delta, _LightningAppRef +from lightning_app.utilities.commands.base import _process_requests from lightning_app.utilities.component import _convert_paths_after_init from lightning_app.utilities.enum import AppStage, CacheCallsKeys from lightning_app.utilities.exceptions import CacheMissException, ExitAppException @@ -73,9 +74,7 @@ def __init__( # queues definition. self.delta_queue: t.Optional[BaseQueue] = None self.readiness_queue: t.Optional[BaseQueue] = None - self.commands_requests_queue: t.Optional[BaseQueue] = None - self.commands_responses_queue: t.Optional[BaseQueue] = None - self.commands_metadata_queue: t.Optional[BaseQueue] = None + self.api_response_queue: t.Optional[BaseQueue] = None self.api_publish_state_queue: t.Optional[BaseQueue] = None self.api_delta_queue: t.Optional[BaseQueue] = None self.error_queue: t.Optional[BaseQueue] = None @@ -94,7 +93,7 @@ def __init__( self.processes: t.Dict[str, WorkManager] = {} self.frontends: t.Dict[str, Frontend] = {} self.stage = AppStage.RUNNING - self._has_updated: bool = False + self._has_updated: bool = True self._schedules: t.Dict[str, t.Dict] = {} self.threads: t.List[threading.Thread] = [] @@ -253,7 +252,7 @@ def named_works(self) -> t.List[t.Tuple[str, "lightning_app.LightningWork"]]: """Returns all the works defined within this application with their names.""" return self.root.named_works(recurse=True) - def _collect_deltas_from_ui_and_work_queues(self) -> t.List[Delta]: + def _collect_deltas_from_ui_and_work_queues(self) -> t.List[t.Union[Delta, APIRequest, CommandRequest]]: # The aggregation would try to get as many deltas as possible # from both the `api_delta_queue` and `delta_queue` # during the `state_accumulate_wait` time. @@ -267,8 +266,12 @@ def _collect_deltas_from_ui_and_work_queues(self) -> t.List[Delta]: while (time() - t0) < self.state_accumulate_wait: if self.api_delta_queue and should_get_delta_from_api: - delta_from_api: Delta = self.get_state_changed_from_queue(self.api_delta_queue) # TODO: rename + delta_from_api: t.Union[DeltaRequest, APIRequest, CommandRequest] = self.get_state_changed_from_queue( + self.api_delta_queue + ) # TODO: rename if delta_from_api: + if isinstance(delta_from_api, DeltaRequest): + delta_from_api = delta_from_api.delta deltas.append(delta_from_api) else: should_get_delta_from_api = False @@ -278,7 +281,7 @@ def _collect_deltas_from_ui_and_work_queues(self) -> t.List[Delta]: if component_output: logger.debug(f"Received from {component_output.id} : {component_output.delta.to_dict()}") work = self.get_component_by_name(component_output.id) - new_work_delta = _delta_to_appstate_delta(self.root, work, deepcopy(component_output.delta)) + new_work_delta = _delta_to_app_state_delta(self.root, work, deepcopy(component_output.delta)) deltas.append(new_work_delta) else: should_get_component_output = False @@ -307,16 +310,29 @@ def maybe_apply_changes(self) -> bool: if not deltas: # When no deltas are received from the Rest API or work queues, # we need to check if the flow modified the state and populate changes. - if Delta(DeepDiff(self.last_state, self.state, verbose_level=2)).to_dict(): + deep_diff = DeepDiff(self.last_state, self.state, verbose_level=2) + if deep_diff: + # TODO: Resolve changes with ``CacheMissException``. # new_state = self.populate_changes(self.last_state, self.state) - self.set_state(self.state) + self.set_last_state(self.state) self._has_updated = True return False logger.debug(f"Received {[d.to_dict() for d in deltas]}") - state = self.state + # 1: Process the API / Command Requests first as they might affect the state. + state_deltas = [] for delta in deltas: + if isinstance(delta, (APIRequest, CommandRequest)): + _process_requests(self, delta) + else: + state_deltas.append(delta) + + # 2: Collect the state + state = self.state + + # 3: Apply the state delta + for delta in state_deltas: try: state += delta except Exception as e: @@ -329,7 +345,6 @@ def maybe_apply_changes(self) -> bool: def run_once(self): """Method used to collect changes and run the root Flow once.""" done = False - self._has_updated = False self._last_run_time = 0.0 if self.backend is not None: @@ -350,19 +365,23 @@ def run_once(self): elif self.stage == AppStage.RESTARTING: return self._apply_restarting() - _process_command_requests(self) + t0 = time() try: self.check_error_queue() - t0 = time() - self.root.run() - self._last_run_time = time() - t0 + # Execute the flow only if: + # - There are state changes + # - It is the first execution of the flow + if self._has_updated: + self.root.run() except CacheMissException: self._on_cache_miss_exception() except (ExitAppException, KeyboardInterrupt): done = True self.stage = AppStage.STOPPING + self._last_run_time = time() - t0 + self.on_run_once_end() return done @@ -404,8 +423,6 @@ def _run(self) -> bool: self._reset_run_time_monitor() - _populate_commands_endpoint(self) - while not done: done = self.run_once() @@ -414,6 +431,8 @@ def _run(self) -> bool: if self._has_updated and self.should_publish_changes_to_api and self.api_publish_state_queue: self.api_publish_state_queue.put(self.state_vars) + self._has_updated = False + return True def _update_layout(self) -> None: @@ -430,8 +449,10 @@ def _apply_restarting(self) -> bool: self.stage = AppStage.BLOCKING return False - def _has_work_finished(self, work): + def _has_work_finished(self, work) -> bool: latest_call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH] + if latest_call_hash is None: + return False return "ret" in work._calls[latest_call_hash] def _collect_work_finish_status(self) -> dict: diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py index f6b6e34e81538..41c46cd868307 100644 --- a/src/lightning_app/core/flow.py +++ b/src/lightning_app/core/flow.py @@ -634,3 +634,36 @@ def my_remote_method(self, name): lightning my_command_name --args name=my_own_name """ raise NotImplementedError + + def configure_api(self): + """Configure the API routes of the LightningFlow. + + Returns a list of HttpMethod such as Post or Get. + + .. code-block:: python + + from lightning_app import LightningFlow + from lightning_app.api import Post + + from pydantic import BaseModel + + + class HandlerModel(BaseModel): + name: str + + + class Flow(L.LightningFlow): + def __init__(self): + super().__init__() + self.names = [] + + def handler(self, config: HandlerModel) -> None: + self.names.append(config.name) + + def configure_api(self): + return [Post("/v1/api/request", self.handler)] + + Once the app is running, you can access the Swagger UI of the app + under the ``/docs`` route. + """ + raise NotImplementedError diff --git a/src/lightning_app/core/queues.py b/src/lightning_app/core/queues.py index efac8230047e0..2b7295d7f327f 100644 --- a/src/lightning_app/core/queues.py +++ b/src/lightning_app/core/queues.py @@ -36,9 +36,7 @@ ORCHESTRATOR_COPY_REQUEST_CONSTANT = "ORCHESTRATOR_COPY_REQUEST" ORCHESTRATOR_COPY_RESPONSE_CONSTANT = "ORCHESTRATOR_COPY_RESPONSE" WORK_QUEUE_CONSTANT = "WORK_QUEUE" -COMMANDS_REQUESTS_QUEUE_CONSTANT = "COMMANDS_REQUESTS_QUEUE" -COMMANDS_RESPONSES_QUEUE_CONSTANT = "COMMANDS_RESPONSES_QUEUE" -COMMANDS_METADATA_QUEUE_CONSTANT = "COMMANDS_METADATA_QUEUE" +API_RESPONSE_QUEUE_CONSTANT = "API_RESPONSE_QUEUE" class QueuingSystem(Enum): @@ -54,18 +52,8 @@ def _get_queue(self, queue_name: str) -> "BaseQueue": else: return SingleProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT) - def get_commands_requests_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": - queue_name = f"{queue_id}_{COMMANDS_REQUESTS_QUEUE_CONSTANT}" if queue_id else COMMANDS_REQUESTS_QUEUE_CONSTANT - return self._get_queue(queue_name) - - def get_commands_responses_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": - queue_name = ( - f"{queue_id}_{COMMANDS_RESPONSES_QUEUE_CONSTANT}" if queue_id else COMMANDS_RESPONSES_QUEUE_CONSTANT - ) - return self._get_queue(queue_name) - - def get_commands_metadata_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": - queue_name = f"{queue_id}_{COMMANDS_METADATA_QUEUE_CONSTANT}" if queue_id else COMMANDS_METADATA_QUEUE_CONSTANT + def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": + queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT return self._get_queue(queue_name) def get_readiness_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": @@ -98,10 +86,6 @@ def get_api_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": queue_name = f"{queue_id}_{API_DELTA_QUEUE_CONSTANT}" if queue_id else API_DELTA_QUEUE_CONSTANT return self._get_queue(queue_name) - def get_api_refresh_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": - queue_name = f"{queue_id}_{API_REFRESH_QUEUE_CONSTANT}" if queue_id else API_REFRESH_QUEUE_CONSTANT - return self._get_queue(queue_name) - def get_orchestrator_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue": queue_name = ( f"{queue_id}_{ORCHESTRATOR_REQUEST_CONSTANT}_{work_name}" diff --git a/src/lightning_app/runners/backends/backend.py b/src/lightning_app/runners/backends/backend.py index 87bb103823fd2..a944cd4aa9093 100644 --- a/src/lightning_app/runners/backends/backend.py +++ b/src/lightning_app/runners/backends/backend.py @@ -82,11 +82,8 @@ def _prepare_queues(self, app): kw = dict(queue_id=self.queue_id) app.delta_queue = self.queues.get_delta_queue(**kw) app.readiness_queue = self.queues.get_readiness_queue(**kw) - app.commands_requests_queue = self.queues.get_commands_requests_queue(**kw) - app.commands_responses_queue = self.queues.get_commands_responses_queue(**kw) - app.commands_metadata_queue = self.queues.get_commands_metadata_queue(**kw) + app.api_response_queue = self.queues.get_api_response_queue(**kw) app.error_queue = self.queues.get_error_queue(**kw) - app.delta_queue = self.queues.get_delta_queue(**kw) app.api_publish_state_queue = self.queues.get_api_state_publish_queue(**kw) app.api_delta_queue = self.queues.get_api_delta_queue(**kw) app.request_queues = {} diff --git a/src/lightning_app/runners/cloud.py b/src/lightning_app/runners/cloud.py index 957b60b5d2ab5..2cd98ebe4cf68 100644 --- a/src/lightning_app/runners/cloud.py +++ b/src/lightning_app/runners/cloud.py @@ -18,15 +18,22 @@ Gridv1ImageSpec, V1BuildSpec, V1DependencyFileInfo, + V1Drive, + V1DriveSpec, + V1DriveStatus, + V1DriveType, V1EnvVar, V1Flowserver, V1LightningappInstanceSpec, V1LightningappInstanceState, + V1LightningworkDrives, V1LightningworkSpec, + V1Metadata, V1NetworkConfig, V1PackageManager, V1ProjectClusterBinding, V1PythonDependencyInfo, + V1SourceType, V1UserRequestedComputeConfig, V1Work, ) @@ -36,6 +43,7 @@ from lightning_app.runners.backends.cloud import CloudBackend from lightning_app.runners.runtime import Runtime from lightning_app.source_code import LocalSourceCodeDir +from lightning_app.storage import Drive from lightning_app.utilities.cloud import _get_project from lightning_app.utilities.dependency_caching import get_hash from lightning_app.utilities.packaging.app_config import AppConfig, find_config_file @@ -107,10 +115,45 @@ def dispatch( preemptible=work.cloud_compute.preemptible, shm_size=work.cloud_compute.shm_size, ) + + drive_specs: List[V1LightningworkDrives] = [] + for drive_attr_name, drive in [ + (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive) + ]: + if drive.protocol == "lit://": + drive_type = V1DriveType.NO_MOUNT_S3 + source_type = V1SourceType.S3 + elif drive.protocol == "s3://": + drive_type = V1DriveType.INDEXED_S3 + source_type = V1SourceType.S3 + else: + raise RuntimeError( + f"unknown drive protocol `{drive.protocol}`. Please verify this " + f"drive type has been configured for use in the cloud dispatcher." + ) + + drive_specs.append( + V1LightningworkDrives( + drive=V1Drive( + metadata=V1Metadata( + name=f"{work.name}.{drive_attr_name}", + ), + spec=V1DriveSpec( + drive_type=drive_type, + source_type=source_type, + source=f"{drive.protocol}{drive.id}", + ), + status=V1DriveStatus(), + ), + mount_location=str(drive.root_folder), + ), + ) + random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) spec = V1LightningworkSpec( build_spec=build_spec, cluster_id=cluster_id, + drives=drive_specs, user_requested_compute_config=user_compute_config, network_config=[V1NetworkConfig(name=random_name, port=work.port)], ) diff --git a/src/lightning_app/runners/multiprocess.py b/src/lightning_app/runners/multiprocess.py index 92ec900d89c65..16e373b0a37a2 100644 --- a/src/lightning_app/runners/multiprocess.py +++ b/src/lightning_app/runners/multiprocess.py @@ -3,10 +3,13 @@ from dataclasses import dataclass from typing import Any, Callable, Optional, Union +from lightning_app.api.http_methods import _add_tags_to_api, _validate_api from lightning_app.core.api import start_server from lightning_app.runners.backends import Backend from lightning_app.runners.runtime import Runtime from lightning_app.storage.orchestrator import StorageOrchestrator +from lightning_app.utilities.app_helpers import is_overridden +from lightning_app.utilities.commands.base import _commands_to_api, _prepare_commands from lightning_app.utilities.component import _set_flow_context, _set_frontend_context from lightning_app.utilities.load_app import extract_metadata_from_app from lightning_app.utilities.network import find_free_network_port @@ -60,15 +63,25 @@ def dispatch(self, *args: Any, on_before_run: Optional[Callable] = None, **kwarg if self.start_server: self.app.should_publish_changes_to_api = True has_started_queue = self.backend.queues.get_has_server_started_queue() + + apis = [] + if is_overridden("configure_api", self.app.root): + apis = self.app.root.configure_api() + _validate_api(apis) + _add_tags_to_api(apis, ["app_api"]) + + if is_overridden("configure_commands", self.app.root): + commands = _prepare_commands(self.app) + apis += _commands_to_api(commands) + kwargs = dict( + apis=apis, host=self.host, port=self.port, + api_response_queue=self.app.api_response_queue, api_publish_state_queue=self.app.api_publish_state_queue, api_delta_queue=self.app.api_delta_queue, has_started_queue=has_started_queue, - commands_requests_queue=self.app.commands_requests_queue, - commands_responses_queue=self.app.commands_responses_queue, - commands_metadata_queue=self.app.commands_metadata_queue, spec=extract_metadata_from_app(self.app), ) server_proc = multiprocessing.Process(target=start_server, kwargs=kwargs) diff --git a/src/lightning_app/storage/drive.py b/src/lightning_app/storage/drive.py index 3bcdf72780653..f72ad38b6e130 100644 --- a/src/lightning_app/storage/drive.py +++ b/src/lightning_app/storage/drive.py @@ -13,7 +13,7 @@ class Drive: __IDENTIFIER__ = "__drive__" - __PROTOCOLS__ = ["lit://"] + __PROTOCOLS__ = ["lit://", "s3://"] def __init__( self, @@ -35,18 +35,31 @@ def __init__( root_folder: This is the folder from where the Drive perceives the data (e.g this acts as a mount dir). """ self.id = None + self.protocol = None for protocol in self.__PROTOCOLS__: if id.startswith(protocol): self.protocol = protocol self.id = id.replace(protocol, "") + break + else: # N.B. for-else loop + raise ValueError( + f"Unknown protocol for the drive 'id' argument '{id}`. The 'id' string " + f"must start with one of the following prefixes {self.__PROTOCOLS__}" + ) + + if self.protocol == "s3://" and not self.id.endswith("/"): + raise ValueError( + "S3 drives must end in a trailing slash (`/`) to indicate a folder is being mounted. " + f"Recieved: '{id}'. Mounting a single file is not currently supported." + ) if not self.id: raise Exception(f"The Drive id needs to start with one of the following protocols: {self.__PROTOCOLS__}") - if "/" in self.id: + if self.protocol != "s3://" and "/" in self.id: raise Exception(f"The id should be unique to identify your drive. Found `{self.id}`.") - self.root_folder = pathlib.Path(root_folder).resolve() if root_folder else os.getcwd() + self.root_folder = pathlib.Path(root_folder).resolve() if root_folder else pathlib.Path(os.getcwd()) if not os.path.isdir(self.root_folder): raise Exception(f"The provided root_folder isn't a directory: {root_folder}") self.component_name = component_name @@ -75,6 +88,10 @@ def put(self, path: str) -> None: raise Exception("The component name needs to be known to put a path to the Drive.") if _is_flow_context(): raise Exception("The flow isn't allowed to put files into a Drive.") + if self.protocol == "s3://": + raise PermissionError( + "S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?" + ) self._validate_path(path) @@ -98,6 +115,10 @@ def list(self, path: Optional[str] = ".", component_name: Optional[str] = None) """ if _is_flow_context(): raise Exception("The flow isn't allowed to list files from a Drive.") + if self.protocol == "s3://": + raise PermissionError( + "S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?" + ) if component_name: paths = [ @@ -142,6 +163,10 @@ def get( """ if _is_flow_context(): raise Exception("The flow isn't allowed to get files from a Drive.") + if self.protocol == "s3://": + raise PermissionError( + "S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?" + ) if component_name: shared_path = self._to_shared_path( @@ -189,6 +214,10 @@ def delete(self, path: str) -> None: """ if not self.component_name: raise Exception("The component name needs to be known to delete a path to the Drive.") + if self.protocol == "s3://": + raise PermissionError( + "S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?" + ) shared_path = self._to_shared_path( path, diff --git a/src/lightning_app/structures/dict.py b/src/lightning_app/structures/dict.py index 93e2b161b2e7a..b414269b93eec 100644 --- a/src/lightning_app/structures/dict.py +++ b/src/lightning_app/structures/dict.py @@ -22,7 +22,7 @@ def __init__(self, **kwargs: T): .. doctest:: >>> from lightning_app import LightningFlow, LightningWork - >>> from lightning_app.core import Dict + >>> from lightning_app.structures import Dict >>> class CounterWork(LightningWork): ... def __init__(self): ... super().__init__() diff --git a/src/lightning_app/structures/list.py b/src/lightning_app/structures/list.py index f5a7c5c9913ad..cf691c98a8c38 100644 --- a/src/lightning_app/structures/list.py +++ b/src/lightning_app/structures/list.py @@ -24,7 +24,7 @@ def __init__(self, *items: T): .. doctest:: >>> from lightning_app import LightningFlow, LightningWork - >>> from lightning_app.core import List + >>> from lightning_app.structures import List >>> class CounterWork(LightningWork): ... def __init__(self): ... super().__init__() diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py index e1cc2e180dab5..884c02a0521c1 100644 --- a/src/lightning_app/testing/testing.py +++ b/src/lightning_app/testing/testing.py @@ -1,26 +1,30 @@ import asyncio import json +import logging import os import shutil import subprocess import sys import tempfile import time +import traceback from contextlib import contextmanager from subprocess import Popen from time import sleep -from typing import Any, Callable, Dict, Generator, List, Type +from typing import Any, Callable, Dict, Generator, List, Optional, Type import requests from lightning_cloud.openapi.rest import ApiException from requests import Session from rich import print +from rich.color import ANSI_COLOR_NAMES from lightning_app import LightningApp, LightningFlow from lightning_app.cli.lightning_cli import run_app from lightning_app.core.constants import LIGHTNING_CLOUD_PROJECT_ID from lightning_app.runners.multiprocess import MultiProcessRuntime from lightning_app.testing.config import Config +from lightning_app.utilities.app_logs import _app_logs_reader from lightning_app.utilities.cloud import _get_project from lightning_app.utilities.enum import CacheCallsKeys from lightning_app.utilities.imports import _is_playwright_available, requires @@ -32,6 +36,9 @@ from playwright.sync_api import HttpCredentials, sync_playwright +_logger = logging.getLogger(__name__) + + class LightningTestApp(LightningApp): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -282,20 +289,6 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str var scrollingElement = (document.scrollingElement || document.body); scrollingElement.scrollTop = scrollingElement.scrollHeight; }, 200); - - if (!window._logs) { - window._logs = []; - } - - if (window.logTerminals) { - Object.entries(window.logTerminals).forEach( - ([key, value]) => { - window.logTerminals[key]._onLightningWritelnHandler = function (data) { - window._logs = window._logs.concat([data]); - } - } - ); - } """ ) @@ -309,8 +302,46 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str except (playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError): pass - def fetch_logs() -> str: - return admin_page.evaluate("window._logs;") + client = LightningClient() + project = _get_project(client) + identifiers = [] + rich_colors = list(ANSI_COLOR_NAMES) + + def fetch_logs(component_names: Optional[List[str]] = None) -> Generator: + """This methods creates websockets connection in threads and returns the logs to the main thread.""" + app_id = admin_page.url.split("/")[-1] + + if not component_names: + works = client.lightningwork_service_list_lightningwork( + project_id=project.project_id, + app_id=app_id, + ).lightningworks + component_names = ["flow"] + [w.name for w in works] + + def on_error_callback(ws_app, *_): + print(traceback.print_exc()) + ws_app.close() + + colors = {c: rich_colors[i + 1] for i, c in enumerate(component_names)} + gen = _app_logs_reader( + client=client, + project_id=project.project_id, + app_id=app_id, + component_names=component_names, + follow=False, + on_error_callback=on_error_callback, + ) + max_length = max(len(c.replace("root.", "")) for c in component_names) + for log_event in gen: + message = log_event.message + identifier = f"{log_event.timestamp}{log_event.message}" + if identifier not in identifiers: + date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S") + identifiers.append(identifier) + color = colors[log_event.component_name] + padding = (max_length - len(log_event.component_name)) * " " + print(f"[{color}]{log_event.component_name}{padding}[/{color}] {date} {message}") + yield message # 5. Print your application ID print( @@ -318,16 +349,11 @@ def fetch_logs() -> str: ) try: - yield admin_page, view_page, fetch_logs + yield admin_page, view_page, fetch_logs, name except KeyboardInterrupt: pass finally: print("##################################################") - printed_logs = [] - for log in fetch_logs(): - if log not in printed_logs: - printed_logs.append(log) - print(log.split("[0m")[-1]) button = admin_page.locator('[data-cy="stop"]') try: button.wait_for(timeout=3 * 1000) @@ -337,8 +363,6 @@ def fetch_logs() -> str: context.close() browser.close() - client = LightningClient() - project = _get_project(client) list_lightningapps = client.lightningapp_instance_service_list_lightningapp_instances(project.project_id) for lightningapp in list_lightningapps.lightningapps: diff --git a/src/lightning_app/utilities/app_helpers.py b/src/lightning_app/utilities/app_helpers.py index 4144c6de3ba12..faa612bba1998 100644 --- a/src/lightning_app/utilities/app_helpers.py +++ b/src/lightning_app/utilities/app_helpers.py @@ -299,7 +299,7 @@ def _set_child_name(component: "Component", child: "Component", new_name: str) - return child_name -def _delta_to_appstate_delta(root: "LightningFlow", component: "Component", delta: Delta) -> Delta: +def _delta_to_app_state_delta(root: "LightningFlow", component: "Component", delta: Delta) -> Delta: delta_dict = delta.to_dict() for changed in delta_dict.values(): for delta_key in changed.copy().keys(): @@ -322,8 +322,9 @@ def _delta_to_appstate_delta(root: "LightningFlow", component: "Component", delt delta_key_without_root = delta_key[4:] # the first 4 chars are the word 'root', strip it new_key = new_prefix + delta_key_without_root - changed[new_key] = val - del changed[delta_key] + if new_key != delta_key: + changed[new_key] = val + del changed[delta_key] return Delta(delta_dict) diff --git a/src/lightning_app/utilities/app_logs.py b/src/lightning_app/utilities/app_logs.py new file mode 100644 index 0000000000000..536fbaae05093 --- /dev/null +++ b/src/lightning_app/utilities/app_logs.py @@ -0,0 +1,140 @@ +import json +import queue +import sys +from dataclasses import dataclass +from datetime import datetime, timedelta +from json import JSONDecodeError +from threading import Thread +from typing import Callable, Iterator, List, Optional + +import dateutil.parser +from websocket import WebSocketApp + +from lightning_app.utilities.logs_socket_api import _LightningLogsSocketAPI +from lightning_app.utilities.network import LightningClient + + +@dataclass +class _LogEventLabels: + app: str + container: str + filename: str + job: str + namespace: str + node_name: str + pod: str + stream: Optional[str] = None + + +@dataclass +class _LogEvent: + message: str + timestamp: datetime + component_name: str + labels: _LogEventLabels + + def __ge__(self, other: "_LogEvent") -> bool: + return self.timestamp >= other.timestamp + + def __gt__(self, other: "_LogEvent") -> bool: + return self.timestamp > other.timestamp + + +def _push_log_events_to_read_queue_callback(component_name: str, read_queue: queue.PriorityQueue): + """Pushes _LogEvents from websocket to read_queue. + + Returns callback function used with `on_message_callback` of websocket.WebSocketApp. + """ + + def callback(ws_app: WebSocketApp, msg: str): + # We strongly trust that the contract on API will hold atm :D + event_dict = json.loads(msg) + labels = _LogEventLabels(**event_dict["labels"]) + + if "message" in event_dict: + message = event_dict["message"] + timestamp = dateutil.parser.isoparse(event_dict["timestamp"]) + event = _LogEvent( + message=message, + timestamp=timestamp, + component_name=component_name, + labels=labels, + ) + read_queue.put(event) + + return callback + + +def _error_callback(ws_app: WebSocketApp, error: Exception): + errors = { + KeyError: "Malformed log message, missing key", + JSONDecodeError: "Malformed log message", + TypeError: "Malformed log format", + ValueError: "Malformed date format", + } + print(f"Error while reading logs ({errors.get(type(error), 'Unknown')})", file=sys.stderr) + ws_app.close() + + +def _app_logs_reader( + client: LightningClient, + project_id: str, + app_id: str, + component_names: List[str], + follow: bool, + on_error_callback: Optional[Callable] = None, +) -> Iterator[_LogEvent]: + + read_queue = queue.PriorityQueue() + logs_api_client = _LightningLogsSocketAPI(client.api_client) + + # We will use a socket per component + log_sockets = [ + logs_api_client.create_lightning_logs_socket( + project_id=project_id, + app_id=app_id, + component=component_name, + on_message_callback=_push_log_events_to_read_queue_callback(component_name, read_queue), + on_error_callback=on_error_callback or _error_callback, + ) + for component_name in component_names + ] + + # And each socket on separate thread pushing log event to print queue + # run_forever() will run until we close() the connection from outside + log_threads = [Thread(target=work.run_forever) for work in log_sockets] + + # Establish connection and begin pushing logs to the print queue + for th in log_threads: + th.start() + + # Print logs from queue when log event is available + user_log_start = "<<< BEGIN USER_RUN_FLOW SECTION >>>" + start_timestamp = None + + # Print logs from queue when log event is available + try: + while True: + log_event = read_queue.get(timeout=None if follow else 1.0) + if user_log_start in log_event.message: + start_timestamp = log_event.timestamp + timedelta(seconds=0.5) + + if start_timestamp and log_event.timestamp > start_timestamp: + yield log_event + + except queue.Empty: + # Empty is raised by queue.get if timeout is reached. Follow = False case. + pass + + except KeyboardInterrupt: + # User pressed CTRL+C to exit, we sould respect that + pass + + finally: + # Close connections - it will cause run_forever() to finish -> thread as finishes aswell + for socket in log_sockets: + socket.close() + + # Because all socket were closed, we can just wait for threads to finish. + for th in log_threads: + th.join() diff --git a/src/lightning_app/utilities/cli_helpers.py b/src/lightning_app/utilities/cli_helpers.py index fcce96ec64407..6000114c3d4d6 100644 --- a/src/lightning_app/utilities/cli_helpers.py +++ b/src/lightning_app/utilities/cli_helpers.py @@ -49,16 +49,42 @@ def _is_url(id: Optional[str]) -> bool: return False +def _get_metadata_from_openapi(paths: Dict, path: str): + parameters = paths[path]["post"].get("parameters", {}) + tag = paths[path]["post"].get("tags", [None])[0] + cls_path = paths[path]["post"].get("cls_path", None) + cls_name = paths[path]["post"].get("cls_name", None) + + metadata = {"tag": tag, "parameters": {}} + + if cls_path: + metadata["cls_path"] = cls_path + + if cls_name: + metadata["cls_name"] = cls_name + + if not parameters: + return metadata + + metadata["parameters"].update({d["name"]: d["schema"]["type"] for d in parameters}) + return metadata + + +def _extract_command_from_openapi(openapi_resp: Dict) -> Dict[str, Dict[str, str]]: + command_paths = [p for p in openapi_resp["paths"] if p.startswith("/command/")] + return {p.replace("/command/", ""): _get_metadata_from_openapi(openapi_resp["paths"], p) for p in command_paths} + + def _retrieve_application_url_and_available_commands(app_id_or_name_or_url: Optional[str]): """This function is used to retrieve the current url associated with an id.""" if _is_url(app_id_or_name_or_url): url = app_id_or_name_or_url assert url - resp = requests.get(url + "/api/v1/commands") + resp = requests.get(url + "/openapi.json") if resp.status_code != 200: raise Exception(f"The server didn't process the request properly. Found {resp.json()}") - return url, resp.json() + return url, _extract_command_from_openapi(resp.json()) # 2: If no identifier has been provided, evaluate the local application failed_locally = False @@ -66,10 +92,10 @@ def _retrieve_application_url_and_available_commands(app_id_or_name_or_url: Opti if app_id_or_name_or_url is None: try: url = f"http://localhost:{APP_SERVER_PORT}" - resp = requests.get(f"{url}/api/v1/commands") + resp = requests.get(f"{url}/openapi.json") if resp.status_code != 200: raise Exception(f"The server didn't process the request properly. Found {resp.json()}") - return url, resp.json() + return url, _extract_command_from_openapi(resp.json()) except requests.exceptions.ConnectionError: failed_locally = True @@ -88,8 +114,8 @@ def _retrieve_application_url_and_available_commands(app_id_or_name_or_url: Opti if lightningapp.id == app_id_or_name_or_url or lightningapp.name == app_id_or_name_or_url: if lightningapp.status.url == "": raise Exception("The application is starting. Try in a few moments.") - resp = requests.get(lightningapp.status.url + "/api/v1/commands") + resp = requests.get(lightningapp.status.url + "/openapi.json") if resp.status_code != 200: raise Exception(f"The server didn't process the request properly. Found {resp.json()}") - return lightningapp.status.url, resp.json() + return lightningapp.status.url, _extract_command_from_openapi(resp.json()) return None, None diff --git a/src/lightning_app/utilities/commands/base.py b/src/lightning_app/utilities/commands/base.py index 11661e51ca26a..c74926f542744 100644 --- a/src/lightning_app/utilities/commands/base.py +++ b/src/lightning_app/utilities/commands/base.py @@ -1,6 +1,5 @@ import errno import inspect -import logging import os import os.path as osp import shutil @@ -8,19 +7,18 @@ from getpass import getuser from importlib.util import module_from_spec, spec_from_file_location from tempfile import gettempdir -from typing import Any, Callable, Dict, List, Optional, Tuple -from uuid import uuid4 +from typing import Any, Callable, Dict, List, Optional, Union import requests from pydantic import BaseModel +from lightning_app.api.http_methods import Post +from lightning_app.api.request_types import APIRequest, CommandRequest from lightning_app.utilities.app_helpers import is_overridden from lightning_app.utilities.cloud import _get_project from lightning_app.utilities.network import LightningClient from lightning_app.utilities.state import AppState -_logger = logging.getLogger(__name__) - def makedirs(path: str): r"""Recursive directory creation function.""" @@ -31,31 +29,18 @@ def makedirs(path: str): raise e -class _ClientCommandConfig(BaseModel): - command: str - affiliation: str - params: Dict[str, str] - is_client_command: bool - cls_path: str - cls_name: str - owner: str - requirements: Optional[List[str]] - - class ClientCommand: def __init__(self, method: Callable, requirements: Optional[List[str]] = None) -> None: self.method = method flow = getattr(method, "__self__", None) self.owner = flow.name if flow else None self.requirements = requirements - self.metadata = None self.models: Optional[Dict[str, BaseModel]] = None self.app_url = None self._state = None - def _setup(self, metadata: Dict[str, Any], models: Dict[str, BaseModel], app_url: str) -> None: - self.metadata = metadata - self.models = models + def _setup(self, command_name: str, app_url: str) -> None: + self.command_name = command_name self.app_url = app_url @property @@ -72,67 +57,50 @@ def state(self): def run(self, **cli_kwargs) -> None: """Overrides with the logic to execute on the client side.""" - def invoke_handler(self, **kwargs: Any) -> Dict[str, Any]: - from lightning.app.utilities.state import headers_for - - assert kwargs.keys() == self.models.keys() - for k, v in kwargs.items(): - assert isinstance(v, self.models[k]) - json = { - "command_name": self.metadata["command"], - "command_arguments": {k: v.json() for k, v in kwargs.items()}, - "affiliation": self.metadata["affiliation"], - "id": str(uuid4()), - } - resp = requests.post(self.app_url + "/api/v1/commands", json=json, headers=headers_for({})) + def invoke_handler(self, config: BaseModel) -> Dict[str, Any]: + resp = requests.post(self.app_url + f"/command/{self.command_name}", data=config.json()) assert resp.status_code == 200, resp.json() return resp.json() def _to_dict(self): return {"owner": self.owner, "requirements": self.requirements} - def __call__(self, **kwargs: Any) -> Any: - assert self.models - input = {} - for k, v in kwargs.items(): - input[k] = self.models[k].parse_raw(v) - return self.method(**input) + def __call__(self, **kwargs): + return self.method(**kwargs) def _download_command( - command_metadata: Dict[str, Any], - app_id: Optional[str], + command_name: str, + cls_path: str, + cls_name: str, + app_id: Optional[str] = None, debug_mode: bool = False, -) -> Tuple[ClientCommand, Dict[str, BaseModel]]: +) -> ClientCommand: # TODO: This is a skateboard implementation and the final version will rely on versioned # immutable commands for security concerns - config = _ClientCommandConfig(**command_metadata) tmpdir = osp.join(gettempdir(), f"{getuser()}_commands") makedirs(tmpdir) - target_file = osp.join(tmpdir, f"{config.command}.py") + target_file = osp.join(tmpdir, f"{command_name}.py") if app_id: client = LightningClient() project_id = _get_project(client).project_id response = client.lightningapp_instance_service_list_lightningapp_instance_artifacts(project_id, app_id) for artifact in response.artifacts: - if f"commands/{config.command}.py" == artifact.filename: + if f"commands/{command_name}.py" == artifact.filename: r = requests.get(artifact.url, allow_redirects=True) with open(target_file, "wb") as f: f.write(r.content) else: if not debug_mode: - shutil.copy(config.cls_path, target_file) + shutil.copy(cls_path, target_file) - cls_name = config.cls_name - spec = spec_from_file_location(config.cls_name, config.cls_path if debug_mode else target_file) + spec = spec_from_file_location(cls_name, cls_path if debug_mode else target_file) mod = module_from_spec(spec) sys.modules[cls_name] = mod spec.loader.exec_module(mod) - command = getattr(mod, cls_name)(method=None, requirements=config.requirements) - models = {k: getattr(mod, v) for k, v in config.params.items()} - if debug_mode: - shutil.rmtree(tmpdir) - return command, models + command = getattr(mod, cls_name)(method=None, requirements=[]) + shutil.rmtree(tmpdir) + return command def _to_annotation(anno: str) -> str: @@ -142,7 +110,7 @@ def _to_annotation(anno: str) -> str: return anno -def _command_to_method_and_metadata(command: ClientCommand) -> Tuple[Callable, Dict[str, Any]]: +def _validate_client_command(command: ClientCommand): """Extract method and its metadata from a ClientCommand.""" params = inspect.signature(command.method).parameters command_metadata = { @@ -170,8 +138,6 @@ def _command_to_method_and_metadata(command: ClientCommand) -> Tuple[Callable, D raise Exception( f"The provided annotation for the argument {k} shouldn't an instance of pydantic BaseModel." ) - command.models[k] = config - return method, command_metadata def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]: @@ -192,54 +158,68 @@ def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]: return filepath -def _populate_commands_endpoint(app): +def _prepare_commands(app) -> List: if not is_overridden("configure_commands", app.root): - return + return [] - # 1: Populate commands metadata + # 1: Upload the command to s3. commands = app.root.configure_commands() - commands_metadata = [] - command_names = set() for command_mapping in commands: for command_name, command in command_mapping.items(): - is_client_command = isinstance(command, ClientCommand) - extras = {} - if is_client_command: + if isinstance(command, ClientCommand): _upload_command(command_name, command) - command, extras = _command_to_method_and_metadata(command) - if command_name in command_names: - raise Exception(f"The component name {command_name} has already been used. They need to be unique.") - command_names.add(command_name) - params = inspect.signature(command).parameters - commands_metadata.append( - { - "command": command_name, - "affiliation": command.__self__.name, - "params": list(params.keys()), - "is_client_command": is_client_command, - **extras, - } - ) - # 1.2: Pass the collected commands through the queue to the Rest API. - app.commands_metadata_queue.put(commands_metadata) + # 2: Cache the commands on the app. app.commands = commands + return commands -def _process_command_requests(app): - if not is_overridden("configure_commands", app.root): - return - - # 1: Populate commands metadata - commands = app.commands - - # 2: Collect requests metadata - command_query = app.get_state_changed_from_queue(app.commands_requests_queue) - if command_query: - for command in commands: - for command_name, method in command.items(): - if command_query["command_name"] == command_name: - # 2.1: Evaluate the method associated to a specific command. - # Validation is done on the CLI side. - response = method(**command_query["command_arguments"]) - app.commands_responses_queue.put({"response": response, "id": command_query["id"]}) +def _process_api_request(app, request: APIRequest) -> None: + flow = app.get_component_by_name(request.name) + method = getattr(flow, request.method_name) + response = method(*request.args, **request.kwargs) + app.api_response_queue.put({"response": response, "id": request.id}) + + +def _process_command_requests(app, request: CommandRequest) -> None: + for command in app.commands: + for command_name, method in command.items(): + if request.method_name == command_name: + # 2.1: Evaluate the method associated to a specific command. + # Validation is done on the CLI side. + response = method(*request.args, **request.kwargs) + app.api_response_queue.put({"response": response, "id": request.id}) + + +def _process_requests(app, request: Union[APIRequest, CommandRequest]) -> None: + """Convert user commands to API endpoint.""" + if isinstance(request, APIRequest): + _process_api_request(app, request) + else: + _process_command_requests(app, request) + + +def _collect_open_api_extras(command) -> Dict: + if not isinstance(command, ClientCommand): + return {} + return { + "cls_path": inspect.getfile(command.__class__), + "cls_name": command.__class__.__name__, + } + + +def _commands_to_api(commands: List[Dict[str, Union[Callable, ClientCommand]]]) -> List: + """Convert user commands to API endpoint.""" + api = [] + for command in commands: + for k, v in command.items(): + api.append( + Post( + f"/command/{k}", + v.method if isinstance(v, ClientCommand) else v, + method_name=k, + tags=["app_client_command"] if isinstance(v, ClientCommand) else ["app_command"], + openapi_extra=_collect_open_api_extras(v), + ) + ) + return api diff --git a/src/lightning_app/utilities/enum.py b/src/lightning_app/utilities/enum.py index dbf20413aa9d9..2b88d93169930 100644 --- a/src/lightning_app/utilities/enum.py +++ b/src/lightning_app/utilities/enum.py @@ -72,3 +72,9 @@ def make_status(stage: str, message: Optional[str] = None, reason: Optional[str] class CacheCallsKeys: LATEST_CALL_HASH = "latest_call_hash" + + +class OpenAPITags: + APP_CLIENT_COMMAND = "app_client_command" + APP_COMMAND = "app_command" + APP_API = "app_api" diff --git a/src/lightning_app/utilities/logs_socket_api.py b/src/lightning_app/utilities/logs_socket_api.py new file mode 100644 index 0000000000000..0ab9a5c24f3e5 --- /dev/null +++ b/src/lightning_app/utilities/logs_socket_api.py @@ -0,0 +1,95 @@ +from typing import Callable, Optional +from urllib.parse import urlparse + +from lightning_cloud.openapi import ApiClient, AuthServiceApi, V1LoginRequest +from websocket import WebSocketApp + +from lightning_app.utilities.login import Auth + + +class _LightningLogsSocketAPI: + def __init__(self, api_client: ApiClient): + self.api_client = api_client + self._auth = Auth() + self._auth.authenticate() + self._auth_service = AuthServiceApi(api_client) + + def _get_api_token(self) -> str: + token_resp = self._auth_service.auth_service_login( + body=V1LoginRequest( + username=self._auth.username, + api_key=self._auth.api_key, + ) + ) + return token_resp.token + + @staticmethod + def _socket_url(host: str, project_id: str, app_id: str, token: str, component: str) -> str: + return ( + f"wss://{host}/v1/projects/{project_id}/appinstances/{app_id}/logs?" + f"token={token}&component={component}&follow=true" + ) + + def create_lightning_logs_socket( + self, + project_id: str, + app_id: str, + component: str, + on_message_callback: Callable[[WebSocketApp, str], None], + on_error_callback: Optional[Callable[[Exception, str], None]] = None, + ) -> WebSocketApp: + """Creates and returns WebSocketApp to listen to lightning app logs. + + .. code-block:: python + # Synchronous reading, run_forever() is blocking + + + def print_log_msg(ws_app, msg): + print(msg) + + + flow_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "flow", print_log_msg) + flow_socket.run_forever() + + .. code-block:: python + # Asynchronous reading (with Threads) + + + def print_log_msg(ws_app, msg): + print(msg) + + + flow_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "flow", print_log_msg) + work_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "work_1", print_log_msg) + + flow_logs_thread = Thread(target=flow_logs_socket.run_forever) + work_logs_thread = Thread(target=work_logs_socket.run_forever) + + flow_logs_thread.start() + work_logs_thread.start() + # ....... + + flow_logs_socket.close() + work_logs_thread.close() + + Arguments: + project_id: Project ID. + app_id: Application ID. + component: Component name eg flow. + on_message_callback: Callback object which is called when received data. + on_error_callback: Callback object which is called when we get error. + + Returns: + WebSocketApp of the wanted socket + """ + _token = self._get_api_token() + clean_ws_host = urlparse(self.api_client.configuration.host).netloc + socket_url = self._socket_url( + host=clean_ws_host, + project_id=project_id, + app_id=app_id, + token=_token, + component=component, + ) + + return WebSocketApp(socket_url, on_message=on_message_callback, on_error=on_error_callback) diff --git a/src/lightning_app/utilities/network.py b/src/lightning_app/utilities/network.py index 7fd03750a515d..050734723acc1 100644 --- a/src/lightning_app/utilities/network.py +++ b/src/lightning_app/utilities/network.py @@ -48,11 +48,12 @@ def _configure_session() -> Session: return http -def _check_service_url_is_ready(url: str, timeout: float = 100) -> bool: +def _check_service_url_is_ready(url: str, timeout: float = 5) -> bool: try: response = requests.get(url, timeout=timeout) return response.status_code in (200, 404) except (ConnectionError, ConnectTimeout, ReadTimeout): + logger.debug(f"The url {url} is not ready.") return False diff --git a/src/lightning_app/utilities/packaging/lightning_utils.py b/src/lightning_app/utilities/packaging/lightning_utils.py index 37f4ff22988eb..073d4d7ab613a 100644 --- a/src/lightning_app/utilities/packaging/lightning_utils.py +++ b/src/lightning_app/utilities/packaging/lightning_utils.py @@ -89,8 +89,13 @@ def get_dist_path_if_editable_install(project_name) -> str: def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]: + """This function determines if lightning is installed in editable mode (for developers) and packages the + current lightning source along with the app. - if "site-packages" in _PROJECT_ROOT: + For normal users who install via PyPi or Conda, then this function does not do anything. + """ + + if not get_dist_path_if_editable_install("lightning"): return # Packaging the Lightning codebase happens only inside the `lightning` repo. diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index 2c93a6c89f38c..99ad6e2aad0cf 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -74,7 +74,7 @@ def _send_data_to_caller_queue(work: "LightningWork", caller_queue: "BaseQueue", data.update({"state": work_state}) logger.debug(f"Sending to {work.name}: {data}") - caller_queue.put(data) + caller_queue.put(deepcopy(data)) # Reset the calls entry. work_state["calls"] = calls diff --git a/src/lightning_app/utilities/scheduler.py b/src/lightning_app/utilities/scheduler.py index 012930f017f20..e45b0879246b9 100644 --- a/src/lightning_app/utilities/scheduler.py +++ b/src/lightning_app/utilities/scheduler.py @@ -15,7 +15,7 @@ class SchedulerThread(threading.Thread): def __init__(self, app) -> None: super().__init__(daemon=True) self._exit_event = threading.Event() - self._sleep_time = 0.5 + self._sleep_time = 1.0 self._app = app def run(self) -> None: diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index a3755d7733dba..409d3f51bd46f 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290)) + + - Added profiling to these hooks: `on_before_batch_transfer`, `transfer_batch_to_device`, `on_after_batch_transfer`, `configure_gradient_clipping`, `clip_gradients` ([#14069](https://github.com/Lightning-AI/lightning/pull/14069)) @@ -22,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raised a `MisconfigurationException` if batch transfer hooks are overriden with `IPUAccelerator` ([13961](https://github.com/Lightning-AI/lightning/pull/13961)) +- Updated compatibility for LightningLite to run with the latest DeepSpeed 0.7.0 ([13967](https://github.com/Lightning-AI/lightning/pull/13967)) + + ### Deprecated - Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000)) @@ -30,7 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `amp_level` from `Trainer` in favour of passing it explictly via precision plugin ([#13898](https://github.com/Lightning-AI/lightning/pull/13898)) -- +- Deprecated the calls to `pytorch_lightning.utiltiies.meta` functions in favor of built-in https://github.com/pytorch/torchdistx support ([#13868](https://github.com/Lightning-AI/lightning/pull/13868)) ### Removed @@ -44,21 +50,48 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the deprecated `DDP2Strategy` ([#14026](https://github.com/Lightning-AI/lightning/pull/14026)) +- Removed the deprecated `DistributedType` and `DeviceType` enum classes ([#14045](https://github.com/Lightning-AI/lightning/pull/14045)) + + +- Removed the experimental `pytorch_lightning.utiltiies.meta` functions in favor of built-in https://github.com/pytorch/torchdistx support ([#13868](https://github.com/Lightning-AI/lightning/pull/13868)) + + ### Fixed -- Casted only floating point tensors to fp16 with IPUs ([#13983](https://github.com/Lightning-AI/lightning/pull/13983)) +- Fixed a bug that caused spurious `AttributeError` when multiple `DataLoader` classes are imported ([#14117](https://github.com/Lightning-AI/lightning/pull/14117)) -- Casted tensors to fp16 before moving them to device with `DeepSpeedStrategy` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000)) +- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061)) -- Fixed the `NeptuneLogger` dependency being unrecognized ([#13988](https://github.com/Lightning-AI/lightning/pull/13988)) +- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938)) -- Fixed an issue where users would be warned about unset `max_epochs` even when `fast_dev_run` was set ([#13262](https://github.com/Lightning-AI/lightning/pull/13262)) +- Fixed the device placement when `LightningModule.cuda()` gets called without specifying a device index and the current cuda device was not 0 ([#14128](https://github.com/Lightning-AI/lightning/pull/14128)) + +- Avoid `metadata.entry_points` deprecation warning on Python 3.10 ([#14052](https://github.com/Lightning-AI/lightning/pull/14052)) + +- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061)) + + +- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151)) + + + +## [1.7.1] - 2022-08-09 + +### Fixed + +- Casted only floating point tensors to fp16 with IPUs ([#13983](https://github.com/Lightning-AI/lightning/pull/13983)) +- Casted tensors to fp16 before moving them to device with `DeepSpeedStrategy` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000)) +- Fixed the `NeptuneLogger` dependency being unrecognized ([#13988](https://github.com/Lightning-AI/lightning/pull/13988)) +- Fixed an issue where users would be warned about unset `max_epochs` even when `fast_dev_run` was set ([#13262](https://github.com/Lightning-AI/lightning/pull/13262)) - Fixed MPS device being unrecognized ([#13992](https://github.com/Lightning-AI/lightning/pull/13992)) +- Fixed incorrect `precision="mixed"` being used with `DeepSpeedStrategy` and `IPUStrategy` ([#14041](https://github.com/Lightning-AI/lightning/pull/14041)) +- Fixed dtype inference during gradient norm computation ([#14051](https://github.com/Lightning-AI/lightning/pull/14051)) +- Fixed a bug that caused `ddp_find_unused_parameters` to be set `False`, whereas the intended default is `True` ([#14095](https://github.com/Lightning-AI/lightning/pull/14095)) ## [1.7.0] - 2022-08-02 diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index eb1a42730b5f0..914596c0a9d2f 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -14,8 +14,8 @@ ______________________________________________________________________ DocsExamplesCommunity • - Grid AI • - License + Lightning AI • + License

@@ -78,17 +78,17 @@ Lightning is rigorously tested across multiple CPUs, GPUs, TPUs, IPUs, and HPUs
-| System / PyTorch ver. | 1.9 | 1.10 | 1.12 (latest) | -| :------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| Linux py3.7 \[GPUs\*\*\] | - | - | - | -| Linux py3.7 \[TPUs\*\*\*\] | [![CircleCI](https://circleci.com/gh/Lightning-AI/lightning/tree/master.svg?style=svg)](https://circleci.com/gh/Lightning-AI/lightning/tree/master) | - | - | -| Linux py3.8 \[IPUs\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=6&branchName=master) | - | - | -| Linux py3.8 \[HPUs\] | - | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=6&branchName=master) | - | -| Linux py3.8 (with Conda) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) | - | -| Linux py3.9 (with Conda) | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-conda.yml) | -| Linux py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) | -| OSX py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) | -| Windows py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci_test-full.yml) | +| System / PyTorch ver. | 1.9 | 1.10 | 1.12 (latest) | +| :------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Linux py3.7 \[GPUs\*\*\] | - | - | - | +| Linux py3.7 \[TPUs\*\*\*\] | [![CircleCI](https://circleci.com/gh/Lightning-AI/lightning/tree/master.svg?style=svg)](https://circleci.com/gh/Lightning-AI/lightning/tree/master) | - | - | +| Linux py3.8 \[IPUs\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=25&branchName=master) | - | - | +| Linux py3.8 \[HPUs\] | - | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=26&branchName=master) | - | +| Linux py3.8 (with Conda) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) | - | +| Linux py3.9 (with Conda) | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) | +| Linux py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) | +| OSX py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) | +| Windows py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) | - _\*\* tests run on two NVIDIA P100_ - _\*\*\* tests run on Google GKE TPUv2/3. TPU py3.7 means we support Colab and Kaggle env._ @@ -130,8 +130,8 @@ conda install pytorch-lightning -c conda-forge The actual status of stable is the following: -[![Test PyTorch full](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-full.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-full.yml) -[![Test PyTorch with Conda](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-conda.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch_test-conda.yml) +[![Test PyTorch full](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-full.yml) +[![Test PyTorch with Conda](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml/badge.svg?branch=release%2Fpytorch&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-test-conda.yml) [![GPU]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=24&branchName=release%2Fpytorch) [![TPU](https://dl.circleci.com/status-badge/img/gh/Lightning-AI/lightning/tree/release%2Fpytorch.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/Lightning-AI/lightning/tree/release%2Fpytorch) [![IPU]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=25&branchName=release%2Fpytorch) diff --git a/src/pytorch_lightning/__about__.py b/src/pytorch_lightning/__about__.py index 6d09c5264e1ab..e2fdbd9ee3016 100644 --- a/src/pytorch_lightning/__about__.py +++ b/src/pytorch_lightning/__about__.py @@ -13,7 +13,6 @@ # limitations under the License. import time -# __version__ = "1.7.0" __author__ = "Lightning AI et al." __author_email__ = "pytorch@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 20a3dcc3f0f26..6650bb3f0c479 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -16,7 +16,7 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ from copy import deepcopy -from typing import Any, Callable, cast, List, Optional, Union +from typing import Any, Callable, cast, Dict, List, Optional, Union import torch from torch import nn, Tensor @@ -24,6 +24,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.callback import Callback +from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig @@ -112,15 +113,22 @@ def __init__( if device is not None and not isinstance(device, (torch.device, str)): raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}") + self.n_averaged: Optional[torch.Tensor] = None self._swa_epoch_start = swa_epoch_start self._swa_lrs = swa_lrs self._annealing_epochs = annealing_epochs self._annealing_strategy = annealing_strategy self._avg_fn = avg_fn or self.avg_fn self._device = device - self._max_epochs: int - self._model_contains_batch_norm: bool + self._model_contains_batch_norm: Optional[bool] = None self._average_model: "pl.LightningModule" + self._initialized = False + self._swa_scheduler: Optional[_LRScheduler] = None + self._scheduler_state: Optional[Dict] = None + self._init_n_averaged = 0 + self._latest_update_epoch = -1 + self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None + self._max_epochs: int @property def swa_start(self) -> int: @@ -147,6 +155,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - if len(trainer.lr_scheduler_configs) > 1: raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") + if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)): + raise MisconfigurationException("SWA does not currently support sharded models.") + if isinstance(self._swa_epoch_start, float): self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start) @@ -158,8 +169,13 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - assert trainer.fit_loop.max_epochs is not None trainer.fit_loop.max_epochs += 1 + if self._scheduler_state is not None: + self._clear_schedulers(trainer) + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if trainer.current_epoch == self.swa_start: + if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end): + self._initialized = True + # move average model to request device. self._average_model = self._average_model.to(self._device or pl_module.device) @@ -180,6 +196,17 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, ), ) + if self._scheduler_state is not None: + # Restore scheduler state from checkpoint + self._swa_scheduler.load_state_dict(self._scheduler_state) + elif trainer.current_epoch != self.swa_start: + # Log a warning if we're initializing after start without any checkpoint data, + # as behaviour will be different compared to having checkpoint data. + rank_zero_warn( + "SWA is initializing after swa_start without any checkpoint data. " + "This may be caused by loading a checkpoint from an older version of PyTorch Lightning." + ) + # We assert that there is only one optimizer on fit start, so know opt_idx is always 0 default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0) assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 @@ -196,14 +223,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo else: trainer.lr_scheduler_configs.append(default_scheduler_cfg) - self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device) + if self.n_averaged is None: + self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device) - if self.swa_start <= trainer.current_epoch <= self.swa_end: + if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ( + trainer.current_epoch > self._latest_update_epoch + ): + assert self.n_averaged is not None self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn) + self._latest_update_epoch = trainer.current_epoch # Note: No > here in case the callback is saved with the model and training continues if trainer.current_epoch == self.swa_end + 1: - # Transfer weights from average model to pl_module self.transfer_weights(self._average_model, pl_module) @@ -265,6 +296,7 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No def reset_momenta(self) -> None: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.""" + assert self.momenta is not None for bn_module in self.momenta: bn_module.momentum = self.momenta[bn_module] @@ -285,3 +317,35 @@ def update_parameters( def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.""" return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) + + def state_dict(self) -> Dict[str, Any]: + return { + "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), + "latest_update_epoch": self._latest_update_epoch, + "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(), + "average_model_state": None if self._average_model is None else self._average_model.state_dict(), + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._init_n_averaged = state_dict["n_averaged"] + self._latest_update_epoch = state_dict["latest_update_epoch"] + self._scheduler_state = state_dict["scheduler_state"] + self._load_average_model_state(state_dict["average_model_state"]) + + @staticmethod + def _clear_schedulers(trainer: "pl.Trainer") -> None: + # If we have scheduler state saved, clear the scheduler configs so that we don't try to + # load state into the wrong type of schedulers when restoring scheduler checkpoint state. + # We'll configure the scheduler and re-load its state in on_train_epoch_start. + # Note that this relies on the callback state being restored before the scheduler state is + # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of + # writing that is only True for deepspeed which is already not supported by SWA. + # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background. + if trainer.lr_scheduler_configs: + assert len(trainer.lr_scheduler_configs) == 1 + trainer.lr_scheduler_configs.clear() + + def _load_average_model_state(self, model_state: Any) -> None: + if self._average_model is None: + return + self._average_model.load_state_dict(model_state) diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index 62e81e4839da6..2916d8b07cb4e 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -116,14 +116,16 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # ty while being optimized. Arguments: - device: if specified, all parameters will be - copied to that device + device: If specified, all parameters will be copied to that device. If `None`, the current CUDA device + index will be used. Returns: Module: self """ - if device is None or isinstance(device, int): - device = torch.device("cuda", index=(device or 0)) + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + elif isinstance(device, int): + device = torch.device("cuda", index=device) self.__update_properties(device=device) return super().cuda(device=device) diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index da81e4c212560..ffdc0988a1a6e 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -20,10 +20,9 @@ from argparse import Namespace from copy import deepcopy from enum import Enum -from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union +from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union from warnings import warn -import torch import yaml import pytorch_lightning as pl @@ -34,7 +33,7 @@ from pytorch_lightning.utilities.migration import pl_legacy_patch from pytorch_lightning.utilities.parsing import parse_class_init_keys from pytorch_lightning.utilities.rank_zero import rank_zero_warn -from pytorch_lightning.utilities.types import _PATH +from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH log = logging.getLogger(__name__) PRIMITIVE_TYPES = (bool, int, float, str) @@ -58,11 +57,11 @@ class ModelIO: def load_from_checkpoint( cls, checkpoint_path: Union[str, IO], - map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[str] = None, strict: bool = True, - **kwargs, - ): + **kwargs: Any, + ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: r""" Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``. @@ -171,15 +170,15 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: def _load_from_checkpoint( - cls: Union["pl.LightningModule", "pl.LightningDataModule"], + cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]], checkpoint_path: Union[str, IO], - map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[str] = None, - strict: Optional[bool] = None, + strict: bool = True, **kwargs: Any, -) -> Any: +) -> Union["pl.LightningModule", "pl.LightningDataModule"]: if map_location is None: - map_location = lambda storage, loc: storage + map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage) with pl_legacy_patch(): checkpoint = pl_load(checkpoint_path, map_location=map_location) @@ -202,15 +201,18 @@ def _load_from_checkpoint( if issubclass(cls, pl.LightningDataModule): return _load_state(cls, checkpoint, **kwargs) - return _load_state(cls, checkpoint, strict=strict, **kwargs) + # allow cls to be evaluated as subclassed LightningModule or, + # as LightningModule for internal tests + if issubclass(cls, pl.LightningModule): + return _load_state(cls, checkpoint, strict=strict, **kwargs) def _load_state( - cls: Union["pl.LightningModule", "pl.LightningDataModule"], + cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], checkpoint: Dict[str, Any], - strict: Optional[bool] = None, + strict: bool = True, **cls_kwargs_new: Any, -) -> Any: +) -> Union["pl.LightningModule", "pl.LightningDataModule"]: cls_spec = inspect.getfullargspec(cls.__init__) cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() @@ -228,8 +230,7 @@ def _load_state( cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {})) # 2. Try to restore model hparams from checkpoint using the new key - _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY - cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key)) + cls_kwargs_loaded.update(checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})) # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE)) @@ -271,7 +272,9 @@ def _load_state( return obj -def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object: +def _convert_loaded_hparams( + model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None +) -> Dict[str, Any]: """Convert hparams according given type in callable or string (past) format.""" # if not hparams type define if not hparams_type: diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 5125bf4486a9d..981eed30635f6 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -40,7 +40,6 @@ has_iterable_dataset, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.seed import seed_everything @@ -106,8 +105,6 @@ def __init__( self._precision_plugin = self._strategy.precision_plugin self._models_setup: int = 0 - self._check_deepspeed_support() - # wrap the run method so we can inject setup logic or spawn processes for the user setattr(self, "run", partial(self._run_impl, self.run)) @@ -459,18 +456,6 @@ def _check_strategy_support(self, strategy: Optional[Union[str, Strategy]]) -> N f" Choose one of {supported} or pass in a `Strategy` instance." ) - def _check_deepspeed_support(self) -> None: - if ( - isinstance(self._strategy, DeepSpeedStrategy) - and self._strategy.zero_stage_3 - and _RequirementAvailable("deepspeed>=0.6.5") - ): - # https://github.com/microsoft/DeepSpeed/issues/2139 - raise RuntimeError( - "DeepSpeed ZeRO-3 is not supported with this version of Lightning Lite and `deepspeed>=0.6.5`." - " Please downgrade deepspeed to 0.6.4 or check if a newer version of Lightning is available." - ) - @staticmethod def _supported_device_types() -> Sequence[_AcceleratorType]: return ( diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py index 8e30827759b99..530fb58fabe5e 100644 --- a/src/pytorch_lightning/loggers/wandb.py +++ b/src/pytorch_lightning/loggers/wandb.py @@ -328,7 +328,7 @@ def __getstate__(self) -> Dict[str, Any]: @property # type: ignore[misc] @rank_zero_experiment - def experiment(self) -> Run: + def experiment(self) -> Union[Run, RunDisabled]: r""" Actual wandb object. To use wandb features in your @@ -361,11 +361,13 @@ def experiment(self) -> Run: self._experiment = wandb.init(**self._wandb_init) # define default x-axis - if isinstance(self._experiment, Run) and getattr(self._experiment, "define_metric", None): + if isinstance(self._experiment, (Run, RunDisabled)) and getattr( + self._experiment, "define_metric", None + ): self._experiment.define_metric("trainer/global_step") self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True) - assert isinstance(self._experiment, Run) + assert isinstance(self._experiment, (Run, RunDisabled)) return self._experiment def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None: diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 26c2837bda7e3..3e9fda2f966f5 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -75,6 +75,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: trainer = pl_module._trainer if trainer is not None: + assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) if trainer.training: output = self.module.training_step(*inputs, **kwargs) # In manual_optimization, we need to prevent DDP reducer as diff --git a/src/pytorch_lightning/overrides/distributed.py b/src/pytorch_lightning/overrides/distributed.py index f09a7b9e3ae08..929d1ed486f4a 100644 --- a/src/pytorch_lightning/overrides/distributed.py +++ b/src/pytorch_lightning/overrides/distributed.py @@ -45,8 +45,6 @@ def _find_tensors( # https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638 def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None: # `prepare_for_backward` is `DistributedDataParallel` specific. - if not isinstance(model, DistributedDataParallel): - return if torch.is_grad_enabled() and model.require_backward_grad_sync: model.require_forward_param_sync = True # type: ignore[assignment] # We'll return the output object verbatim since it is a freeform diff --git a/src/pytorch_lightning/plugins/precision/deepspeed.py b/src/pytorch_lightning/plugins/precision/deepspeed.py index 01d3017760b0e..456bba1e77823 100644 --- a/src/pytorch_lightning/plugins/precision/deepspeed.py +++ b/src/pytorch_lightning/plugins/precision/deepspeed.py @@ -60,7 +60,7 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona amp_level = amp_level or "O2" - supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED) + supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT) if precision not in supported_precision: raise ValueError( f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported." diff --git a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 8c693f2975bbd..60e53b880c84d 100644 --- a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -23,7 +23,7 @@ if _TORCH_GREATER_EQUAL_1_12: from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision else: - MixedPrecision = None + MixedPrecision = None # type: ignore[misc,assignment] class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): diff --git a/src/pytorch_lightning/plugins/precision/ipu.py b/src/pytorch_lightning/plugins/precision/ipu.py index 89f544575f63f..67e5e373e9f52 100644 --- a/src/pytorch_lightning/plugins/precision/ipu.py +++ b/src/pytorch_lightning/plugins/precision/ipu.py @@ -19,6 +19,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType +from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.warnings import WarningCache @@ -35,7 +36,7 @@ class IPUPrecisionPlugin(PrecisionPlugin): """ def __init__(self, precision: int) -> None: - supported_precision_values = (16, 32) + supported_precision_values = (PrecisionType.HALF, PrecisionType.FLOAT) if precision not in supported_precision_values: raise ValueError( f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported." diff --git a/src/pytorch_lightning/profilers/simple.py b/src/pytorch_lightning/profilers/simple.py index 20d76f9b2d378..0fb9497ff17fb 100644 --- a/src/pytorch_lightning/profilers/simple.py +++ b/src/pytorch_lightning/profilers/simple.py @@ -60,7 +60,7 @@ def __init__( """ super().__init__(dirpath=dirpath, filename=filename) self.current_actions: Dict[str, float] = {} - self.recorded_durations = defaultdict(list) + self.recorded_durations: Dict = defaultdict(list) self.extended = extended self.start_time = time.monotonic() @@ -104,20 +104,23 @@ def summary(self) -> str: if len(self.recorded_durations) > 0: max_key = max(len(k) for k in self.recorded_durations.keys()) - def log_row(action, mean, num_calls, total, per): + def log_row_extended(action: str, mean: str, num_calls: str, total: str, per: str) -> str: row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|" row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" return row - header_string = log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %") + header_string = log_row_extended( + "Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %" + ) output_string_len = len(header_string.expandtabs()) sep_lines = f"{sep}{'-' * output_string_len}" output_string += sep_lines + header_string + sep_lines - report, total_calls, total_duration = self._make_report_extended() - output_string += log_row("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %") + report_extended: _TABLE_DATA_EXTENDED + report_extended, total_calls, total_duration = self._make_report_extended() + output_string += log_row_extended("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %") output_string += sep_lines - for action, mean_duration, num_calls, total_duration, duration_per in report: - output_string += log_row( + for action, mean_duration, num_calls, total_duration, duration_per in report_extended: + output_string += log_row_extended( action, f"{mean_duration:.5}", f"{num_calls}", @@ -128,7 +131,7 @@ def log_row(action, mean, num_calls, total, per): else: max_key = max(len(k) for k in self.recorded_durations) - def log_row(action, mean, total): + def log_row(action: str, mean: str, total: str) -> str: return f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t| {total:<15}\t|" header_string = log_row("Action", "Mean duration (s)", "Total time (s)") diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index 922730df35269..57ab3a151b011 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -32,6 +32,7 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment @@ -39,6 +40,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy +from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.distributed import ( _get_process_group_backend_from_env, @@ -57,7 +59,7 @@ from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep if _FAIRSCALE_AVAILABLE: from fairscale.optim import OSS @@ -83,12 +85,12 @@ def __init__( checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, ddp_comm_state: Optional[object] = None, - ddp_comm_hook: Optional[callable] = None, - ddp_comm_wrapper: Optional[callable] = None, + ddp_comm_hook: Optional[Callable] = None, + ddp_comm_wrapper: Optional[Callable] = None, model_averaging_period: Optional[int] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, - **kwargs: Union[Any, Dict[str, Any]], + **kwargs: Any, ) -> None: super().__init__( accelerator=accelerator, @@ -105,7 +107,7 @@ def __init__( self._ddp_comm_wrapper = ddp_comm_wrapper self._model_averaging_period = model_averaging_period self._model_averager: Optional[ModelAverager] = None - self._pids: Optional[List[int]] = None + self._pids: List[int] = [] self._sync_dir: Optional[str] = None self._rank_0_will_call_children_scripts: bool = False self._process_group_backend: Optional[str] = process_group_backend @@ -117,6 +119,7 @@ def is_distributed(self) -> bool: @property def root_device(self) -> torch.device: + assert self.parallel_devices is not None return self.parallel_devices[self.local_rank] @property @@ -129,11 +132,11 @@ def num_nodes(self, num_nodes: int) -> None: self._num_nodes = num_nodes @property - def num_processes(self): + def num_processes(self) -> int: return len(self.parallel_devices) if self.parallel_devices is not None else 0 @property - def distributed_sampler_kwargs(self): + def distributed_sampler_kwargs(self) -> Dict[str, Any]: distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs @@ -146,6 +149,7 @@ def process_group_backend(self) -> Optional[str]: return self._process_group_backend def _configure_launcher(self) -> None: + assert self.cluster_environment is not None if not self.cluster_environment.creates_processes_externally: self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) self._rank_0_will_call_children_scripts = True @@ -156,10 +160,11 @@ def setup_environment(self) -> None: def setup(self, trainer: "pl.Trainer") -> None: # share ddp pids to all processes - self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) + self._rank_0_will_call_children_scripts = bool(self.broadcast(self._rank_0_will_call_children_scripts)) if self._should_run_deadlock_detection(): self._share_information_to_prevent_deadlock() + assert self.accelerator is not None self.accelerator.setup(trainer) # move the model to the correct device @@ -170,6 +175,7 @@ def setup(self, trainer: "pl.Trainer") -> None: if trainer_fn == TrainerFn.FITTING: if self._layer_sync: + assert self.model is not None self.model = self._layer_sync.apply(self.model) self.setup_precision_plugin() @@ -193,7 +199,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) - def setup_distributed(self): + def setup_distributed(self) -> None: log.detail(f"{self.__class__.__name__}: setting up distributed...") reset_seed() @@ -204,6 +210,7 @@ def setup_distributed(self): rank_zero_only.rank = self.global_rank self._process_group_backend = self._get_process_group_backend() + assert self.cluster_environment is not None init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) def _get_process_group_backend(self) -> str: @@ -230,6 +237,7 @@ def pre_configure_ddp(self) -> None: def _register_ddp_hooks(self) -> None: log.detail(f"{self.__class__.__name__}: registering ddp hooks") if self.root_device.type == "cuda" and self._is_single_process_single_device: + assert isinstance(self.model, DistributedDataParallel) register_ddp_comm_hook( model=self.model, ddp_comm_state=self._ddp_comm_state, @@ -262,6 +270,7 @@ def _enable_model_averaging(self) -> None: f"{optimizer.__class__.__name__}." ) + assert self._ddp_comm_state is not None self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager( period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter ) @@ -296,15 +305,16 @@ def optimizer_step( def configure_ddp(self) -> None: log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel") self.pre_configure_ddp() + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) self.model = self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() - def determine_ddp_device_ids(self): + def determine_ddp_device_ids(self) -> Optional[List[int]]: if self.root_device.type == "cpu": return None return [self.root_device.index] - def barrier(self, *args, **kwargs) -> None: + def barrier(self, *args: Any, **kwargs: Any) -> None: if not distributed_available(): return if torch.distributed.get_backend() == "nccl": @@ -312,23 +322,29 @@ def barrier(self, *args, **kwargs) -> None: else: torch.distributed.barrier() - def broadcast(self, obj: object, src: int = 0) -> object: + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: obj = [obj] if self.global_rank != src: - obj = [None] + obj = [None] # type: ignore[list-item] torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) return obj[0] def pre_backward(self, closure_loss: Tensor) -> None: """Run before precision plugin executes backward.""" + if not isinstance(self.model, DistributedDataParallel): + return + assert self.lightning_module is not None if not self.lightning_module.automatic_optimization: prepare_for_backward(self.model, closure_loss) - def model_to_device(self): + def model_to_device(self) -> None: log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...") + assert self.model is not None self.model.to(self.root_device) - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> Tensor: + def reduce( + self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" + ) -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. Args: @@ -344,30 +360,38 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor - def training_step(self, *args, **kwargs) -> STEP_OUTPUT: + def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + assert self.model is not None with self.precision_plugin.train_step_context(): return self.model(*args, **kwargs) - def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: with self.precision_plugin.val_step_context(): + assert self.lightning_module is not None + assert self.model is not None if self.lightning_module.trainer.state.fn == TrainerFn.FITTING: # used when calling `trainer.fit` return self.model(*args, **kwargs) else: # used when calling `trainer.validate` + assert isinstance(self.model, ValidationStep) return self.model.validation_step(*args, **kwargs) - def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: with self.precision_plugin.test_step_context(): + assert isinstance(self.model, TestStep) return self.model.test_step(*args, **kwargs) - def predict_step(self, *args, **kwargs) -> STEP_OUTPUT: + def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: with self.precision_plugin.predict_step_context(): + assert isinstance(self.model, PredictStep) return self.model.predict_step(*args, **kwargs) - def post_training_step(self): + def post_training_step(self) -> None: + assert self.lightning_module is not None if not self.lightning_module.automatic_optimization: - self.model.require_backward_grad_sync = True + assert self.model is not None + self.model.require_backward_grad_sync = True # type: ignore[assignment] @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: @@ -458,7 +482,7 @@ def teardown(self) -> None: if ( _TORCH_GREATER_EQUAL_1_11 and not self.model.static_graph - and self.model._get_ddp_logging_data().get("can_set_static_graph") + and self.model._get_ddp_logging_data().get("can_set_static_graph") # type: ignore[operator] ): rank_zero_info( "Your model can run with static graph optimizations. For future training runs, we suggest you" @@ -475,6 +499,7 @@ def teardown(self) -> None: and pl_module._trainer.state.fn == TrainerFn.FITTING and self._layer_sync ): + assert self.model is not None self.model = self._layer_sync.revert(self.model) super().teardown() diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 30bcef457c44a..de34320f54093 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -254,9 +254,10 @@ def model_to_device(self) -> None: def pre_backward(self, closure_loss: Tensor) -> None: """Run before precision plugin executes backward.""" + if not isinstance(self.model, DistributedDataParallel): + return assert self.lightning_module is not None if not self.lightning_module.automatic_optimization: - assert isinstance(self.model, DistributedDataParallel) prepare_for_backward(self.model, closure_loss) def reduce( @@ -314,10 +315,20 @@ def post_training_step(self) -> None: def register_strategies(cls, strategy_registry: Dict) -> None: entries = ( ("ddp_spawn", "spawn"), - ("ddp_spawn_find_unused_parameters_false", "spawn"), ("ddp_fork", "fork"), - ("ddp_fork_find_unused_parameters_false", "fork"), ("ddp_notebook", "fork"), + ) + for name, start_method in entries: + strategy_registry.register( + name, + cls, + description=f"DDP strategy with `start_method` '{start_method}'", + start_method=start_method, + ) + + entries = ( + ("ddp_spawn_find_unused_parameters_false", "spawn"), + ("ddp_fork_find_unused_parameters_false", "fork"), ("ddp_notebook_find_unused_parameters_false", "fork"), ) for name, start_method in entries: diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index b0b55374ba1a9..8acbc80257bd1 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -19,7 +19,7 @@ import platform from collections import OrderedDict from pathlib import Path -from typing import Any, cast, Dict, Generator, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union import torch from torch import Tensor @@ -696,7 +696,7 @@ def _auto_select_batch_size(self) -> int: def _format_precision_config(self) -> None: assert isinstance(self.config, dict) - if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED): + if self.precision_plugin.precision == PrecisionType.HALF: if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE: # FP16 is a DeepSpeed standalone AMP implementation rank_zero_info("Enabling DeepSpeed FP16.") @@ -831,7 +831,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing - checkpoint_path = cast(_PATH, self.broadcast(checkpoint_path)) + checkpoint_path = self.broadcast(checkpoint_path) return super().load_checkpoint(checkpoint_path) # Rely on deepspeed to load the checkpoint and necessary information diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 4c351f26fa3b9..d92931fb5cdb2 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -51,7 +51,7 @@ ) from torch.distributed.fsdp.wrap import enable_wrap else: - MixedPrecision = None + MixedPrecision = None # type: ignore[misc,assignment] BackwardPrefetch = None # type: ignore[misc,assignment] CPUOffload = None # type: ignore[misc,assignment] diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index c40addd4244b2..4bedbfd6d70fc 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -58,7 +58,7 @@ def __init__( self.precision = precision def forward(self, *inputs: Any, **kwargs: Any) -> Any: - if self.precision in (PrecisionType.MIXED, PrecisionType.HALF): + if self.precision == PrecisionType.HALF: inputs = self._move_float_tensors_to_half(inputs) return super().forward(*inputs, **kwargs) diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py index 39bba092e9c60..2617e5fe27b10 100644 --- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py +++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py @@ -144,7 +144,7 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train # load last weights if worker_output.weights_path is not None: ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path) - trainer.lightning_module.load_state_dict(ckpt) # type: ignore[arg-type] + trainer.lightning_module.load_state_dict(ckpt) self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path) trainer.state = worker_output.trainer_state diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py index 4550e397ded80..882302e101cb6 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Dict, Generator, List, Optional, Tuple +from typing import Any, Dict, Generator, List, Optional, Tuple from torch import Tensor from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl +from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.states import TrainerFn @@ -42,7 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy): def configure_ddp(self) -> None: # set up optimizers after the wrapped module has been moved to the device + assert self.lightning_module is not None self.setup_optimizers(self.lightning_module.trainer) + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) self.model, self.optimizers = self._setup_model_and_optimizers( model=LightningShardedDataParallel(self.model), optimizers=self.optimizers ) @@ -69,12 +72,13 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS" return optimizers def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]: - if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING: + assert self.lightning_module + if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING: return optimizers return self._reinit_optimizers_with_oss(optimizers) - def optimizer_state(self, optimizer: "OSS") -> Optional[dict]: + def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]: if isinstance(optimizer, OSS): optimizer.consolidate_state_dict() return self._optim_state_dict(optimizer) @@ -93,7 +97,7 @@ def block_backward_sync(self) -> Generator: yield None @rank_zero_only - def _optim_state_dict(self, optimizer): + def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]: """ Retrieves state dict only on rank 0, which contains the entire optimizer state after calling :meth:`consolidate_state_dict`. @@ -112,7 +116,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: def pre_backward(self, closure_loss: Tensor) -> None: pass - def post_training_step(self): + def post_training_step(self) -> None: pass @classmethod diff --git a/src/pytorch_lightning/strategies/utils.py b/src/pytorch_lightning/strategies/utils.py index b71458bfc30d3..cdae7bf434eca 100644 --- a/src/pytorch_lightning/strategies/utils.py +++ b/src/pytorch_lightning/strategies/utils.py @@ -24,7 +24,7 @@ def on_colab_kaggle() -> bool: def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor: if torch.is_floating_point(tensor): - if precision in (PrecisionType.MIXED, PrecisionType.HALF): + if precision == PrecisionType.HALF: return tensor.half() if precision == PrecisionType.BFLOAT: return tensor.bfloat16() diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index 83881905beeb1..32d67d44ad44c 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -17,6 +17,7 @@ from datetime import timedelta from typing import Dict, List, Optional, Sequence, Union +import pytorch_lightning as pl from pytorch_lightning.callbacks import ( Callback, Checkpoint, @@ -30,14 +31,14 @@ from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 +from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info _log = logging.getLogger(__name__) class CallbackConnector: - def __init__(self, trainer): + def __init__(self, trainer: "pl.Trainer"): self.trainer = trainer def on_trainer_init( @@ -50,7 +51,7 @@ def on_trainer_init( enable_model_summary: bool, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None, - ): + ) -> None: # init folder paths for checkpoint + weights save callbacks self.trainer._default_root_dir = default_root_dir or os.getcwd() if weights_save_path: @@ -95,16 +96,18 @@ def on_trainer_init( def _configure_accumulated_gradients( self, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None ) -> None: - grad_accum_callback = [cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)] + grad_accum_callbacks: List[GradientAccumulationScheduler] = [ + cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler) + ] - if grad_accum_callback: + if grad_accum_callbacks: if accumulate_grad_batches is not None: raise MisconfigurationException( "You have set both `accumulate_grad_batches` and passed an instance of " "`GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches` " "from trainer or remove `GradientAccumulationScheduler` from callbacks list." ) - grad_accum_callback = grad_accum_callback[0] + grad_accum_callback = grad_accum_callbacks[0] else: if accumulate_grad_batches is None: accumulate_grad_batches = 1 @@ -148,6 +151,7 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None: progress_bar_callback = self.trainer.progress_bar_callback is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar) + model_summary: ModelSummary if progress_bar_callback is not None and is_progress_bar_rich: model_summary = RichModelSummary() else: @@ -188,7 +192,7 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic timer = Timer(duration=max_time, interval="step") self.trainer.callbacks.append(timer) - def _configure_fault_tolerance_callbacks(self): + def _configure_fault_tolerance_callbacks(self) -> None: from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint if any(isinstance(cb, _FaultToleranceCheckpoint) for cb in self.trainer.callbacks): @@ -196,7 +200,7 @@ def _configure_fault_tolerance_callbacks(self): # don't use `log_dir` to minimize the chances of failure self.trainer.callbacks.append(_FaultToleranceCheckpoint(dirpath=self.trainer.default_root_dir)) - def _attach_model_logging_functions(self): + def _attach_model_logging_functions(self) -> None: lightning_module = self.trainer.lightning_module for callback in self.trainer.callbacks: callback.log = lightning_module.log @@ -243,7 +247,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: A new list in which the last elements are Checkpoint if there were any present in the input. """ - checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)] + checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)] not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)] return not_checkpoints + checkpoints @@ -256,19 +260,24 @@ def _configure_external_callbacks() -> List[Callback]: Return: A list of all callbacks collected from external factories. """ + group = "pytorch_lightning.callbacks_factory" + if _PYTHON_GREATER_EQUAL_3_8_0: from importlib.metadata import entry_points - factories = entry_points().get("pytorch_lightning.callbacks_factory", ()) + if _PYTHON_GREATER_EQUAL_3_10_0: + factories = entry_points(group=group) # type: ignore[call-arg] + else: + factories = entry_points().get(group, {}) # type: ignore[assignment] else: from pkg_resources import iter_entry_points - factories = iter_entry_points("pytorch_lightning.callbacks_factory") + factories = iter_entry_points(group) # type: ignore[assignment] - external_callbacks = [] + external_callbacks: List[Callback] = [] for factory in factories: callback_factory = factory.load() - callbacks_list: List[Callback] = callback_factory() + callbacks_list: Union[List[Callback], Callback] = callback_factory() callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list _log.info( f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':" diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index ff882912625d0..02e17a8d93494 100644 --- a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -163,8 +163,7 @@ def update_train_epoch_metrics(self) -> None: self.log_metrics(self.metrics["log"]) # reset result collection for next epoch - assert self.trainer._results is not None - self.trainer._results.reset(metrics=True) + self.reset_results() """ Utilities and properties diff --git a/src/pytorch_lightning/utilities/__init__.py b/src/pytorch_lightning/utilities/__init__.py index df5084dd85490..c849ba0a05d68 100644 --- a/src/pytorch_lightning/utilities/__init__.py +++ b/src/pytorch_lightning/utilities/__init__.py @@ -21,7 +21,6 @@ _AcceleratorType, _StrategyType, AMPType, - DistributedType, GradClipAlgorithmType, LightningEnum, ) diff --git a/src/pytorch_lightning/utilities/cloud_io.py b/src/pytorch_lightning/utilities/cloud_io.py index 81482a8ab24f9..99629bcda8980 100644 --- a/src/pytorch_lightning/utilities/cloud_io.py +++ b/src/pytorch_lightning/utilities/cloud_io.py @@ -15,21 +15,19 @@ import io from pathlib import Path -from typing import Any, Callable, Dict, IO, Optional, Union +from typing import Any, Dict, IO, Union import fsspec import torch from fsspec.core import url_to_fs from fsspec.implementations.local import AbstractFileSystem -from pytorch_lightning.utilities.types import _PATH +from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH def load( path_or_url: Union[IO, _PATH], - map_location: Optional[ - Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] - ] = None, + map_location: _MAP_LOCATION_TYPE = None, ) -> Any: """Loads a checkpoint. @@ -41,7 +39,10 @@ def load( # any sort of BytesIO or similar return torch.load(path_or_url, map_location=map_location) if str(path_or_url).startswith("http"): - return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location) + return torch.hub.load_state_dict_from_url( + str(path_or_url), + map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct + ) fs = get_filesystem(path_or_url) with fs.open(path_or_url, "rb") as f: return torch.load(f, map_location=map_location) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 00a7cb8486709..b625a046f6122 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -501,15 +501,17 @@ def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = Non It patches the ``__init__`` method. """ classes = _get_all_subclasses(base_cls) | {base_cls} - wrapped = set() for cls in classes: - if cls.__init__ not in wrapped: + # Check that __init__ belongs to the class + # https://stackoverflow.com/a/5253424 + if "__init__" in cls.__dict__: cls._old_init = cls.__init__ cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) - wrapped.add(cls.__init__) yield for cls in classes: - if hasattr(cls, "_old_init"): + # Check that _old_init belongs to the class + # https://stackoverflow.com/a/5253424 + if "_old_init" in cls.__dict__: cls.__init__ = cls._old_init del cls._old_init diff --git a/src/pytorch_lightning/utilities/enums.py b/src/pytorch_lightning/utilities/enums.py index e687d3f9f046b..06d616f87259f 100644 --- a/src/pytorch_lightning/utilities/enums.py +++ b/src/pytorch_lightning/utilities/enums.py @@ -15,11 +15,9 @@ from __future__ import annotations import os -from enum import Enum, EnumMeta -from typing import Any +from enum import Enum from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.warnings import rank_zero_deprecation class LightningEnum(str, Enum): @@ -43,37 +41,6 @@ def __hash__(self) -> int: return hash(self.value.lower()) -class _DeprecatedEnumMeta(EnumMeta): - """Enum that calls `deprecate()` whenever a member is accessed. - - Adapted from: https://stackoverflow.com/a/62309159/208880 - """ - - def __getattribute__(cls, name: str) -> Any: - obj = super().__getattribute__(name) - # ignore __dunder__ names -- prevents potential recursion errors - if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum): - obj.deprecate() - return obj - - def __getitem__(cls, name: str) -> Any: - member: _DeprecatedEnumMeta = super().__getitem__(name) - member.deprecate() - return member - - def __call__(cls, *args: Any, **kwargs: Any) -> Any: - obj = super().__call__(*args, **kwargs) - if isinstance(obj, Enum): - obj.deprecate() - return obj - - -class _DeprecatedEnum(LightningEnum, metaclass=_DeprecatedEnumMeta): - """_DeprecatedEnum calls an enum's `deprecate()` method on member access.""" - - pass - - class AMPType(LightningEnum): """Type of Automatic Mixed Precission used for training. @@ -110,66 +77,6 @@ def supported_types() -> list[str]: return [x.value for x in PrecisionType] -class DistributedType(_DeprecatedEnum): - """Define type of training strategy. - - Deprecated since v1.6.0 and will be removed in v1.8.0. - - Use `_StrategyType` instead. - """ - - DP = "dp" - DDP = "ddp" - DDP_SPAWN = "ddp_spawn" - TPU_SPAWN = "tpu_spawn" - DEEPSPEED = "deepspeed" - HOROVOD = "horovod" - DDP_SHARDED = "ddp_sharded" - DDP_SHARDED_SPAWN = "ddp_sharded_spawn" - DDP_FULLY_SHARDED = "ddp_fully_sharded" - HPU_PARALLEL = "hpu_parallel" - - @staticmethod - def interactive_compatible_types() -> list[DistributedType]: - """Returns a list containing interactive compatible DistributeTypes.""" - return [ - DistributedType.DP, - DistributedType.DDP_SPAWN, - DistributedType.DDP_SHARDED_SPAWN, - DistributedType.TPU_SPAWN, - ] - - def is_interactive_compatible(self) -> bool: - """Returns whether self is interactive compatible.""" - return self in DistributedType.interactive_compatible_types() - - def deprecate(self) -> None: - rank_zero_deprecation( - "`DistributedType` Enum has been deprecated in v1.6 and will be removed in v1.8." - f" Use the string value `{self.value!r}` instead." - ) - - -class DeviceType(_DeprecatedEnum): - """Define Device type by its nature - accelerators. - - Deprecated since v1.6.0 and will be removed in v1.8.0. - - Use `_AcceleratorType` instead. - """ - - CPU = "CPU" - GPU = "GPU" - IPU = "IPU" - TPU = "TPU" - - def deprecate(self) -> None: - rank_zero_deprecation( - "`DeviceType` Enum has been deprecated in v1.6 and will be removed in v1.8." - f" Use the string value `{self.value!r}` instead." - ) - - class GradClipAlgorithmType(LightningEnum): """Define gradient_clip_algorithm types - training-tricks. NORM type means "clipping gradients by norm". This computed over all model parameters together. diff --git a/src/pytorch_lightning/utilities/grads.py b/src/pytorch_lightning/utilities/grads.py index 66c1b7d988522..76c3f39bdc013 100644 --- a/src/pytorch_lightning/utilities/grads.py +++ b/src/pytorch_lightning/utilities/grads.py @@ -41,12 +41,12 @@ def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}") norms = { - f"grad_{norm_type}_norm{group_separator}{name}": p.grad.data.norm(norm_type).item() + f"grad_{norm_type}_norm{group_separator}{name}": p.grad.data.norm(norm_type) for name, p in module.named_parameters() if p.grad is not None } if norms: - total_norm = torch.tensor(list(norms.values())).norm(norm_type).item() + total_norm = torch.tensor(list(norms.values())).norm(norm_type) norms[f"grad_{norm_type}_norm_total"] = total_norm - norms = {k: round(v, 4) for k, v in norms.items()} + norms = {k: round(v.item(), 4) for k, v in norms.items()} return norms diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py index 67bf75be3c4d3..ba437ad332dfa 100644 --- a/src/pytorch_lightning/utilities/imports.py +++ b/src/pytorch_lightning/utilities/imports.py @@ -124,6 +124,7 @@ def __repr__(self) -> str: _IS_WINDOWS = platform.system() == "Windows" _IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765 _PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8) +_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) _TORCH_GREATER_EQUAL_1_9_1 = _compare_version("torch", operator.ge, "1.9.1") _TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0") _TORCH_LESSER_EQUAL_1_10_2 = _compare_version("torch", operator.le, "1.10.2") diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py index 9f5fe2d6b6841..073423ab60773 100644 --- a/src/pytorch_lightning/utilities/parsing.py +++ b/src/pytorch_lightning/utilities/parsing.py @@ -108,7 +108,9 @@ def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None: del hparams_dict[k] -def parse_class_init_keys(cls: Type["pl.LightningModule"]) -> Tuple[str, Optional[str], Optional[str]]: +def parse_class_init_keys( + cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]] +) -> Tuple[str, Optional[str], Optional[str]]: """Parse key words for standard ``self``, ``*args`` and ``**kwargs``. Examples: @@ -160,7 +162,10 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]: def collect_init_args( - frame: types.FrameType, path_args: List[Dict[str, Any]], inside: bool = False + frame: types.FrameType, + path_args: List[Dict[str, Any]], + inside: bool = False, + classes: Tuple[Type, ...] = (), ) -> List[Dict[str, Any]]: """Recursively collects the arguments passed to the child constructors in the inheritance tree. @@ -168,6 +173,7 @@ def collect_init_args( frame: the current stack frame path_args: a list of dictionaries containing the constructor args in all parent classes inside: track if we are inside inheritance path, avoid terminating too soon + classes: the classes in which to inspect the frames Return: A list of dictionaries where each dictionary contains the arguments passed to the @@ -179,13 +185,13 @@ def collect_init_args( if not isinstance(frame.f_back, types.FrameType): return path_args - if "__class__" in local_vars: + if "__class__" in local_vars and (not classes or issubclass(local_vars["__class__"], classes)): local_args = get_init_args(frame) # recursive update path_args.append(local_args) - return collect_init_args(frame.f_back, path_args, inside=True) + return collect_init_args(frame.f_back, path_args, inside=True, classes=classes) if not inside: - return collect_init_args(frame.f_back, path_args, inside) + return collect_init_args(frame.f_back, path_args, inside, classes=classes) return path_args @@ -223,7 +229,10 @@ def save_hyperparameters( init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} else: init_args = {} - for local_args in collect_init_args(frame, []): + + from pytorch_lightning.core.mixins import HyperparametersMixin + + for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)): init_args.update(local_args) if ignore is None: diff --git a/src/pytorch_lightning/utilities/seed.py b/src/pytorch_lightning/utilities/seed.py index 6648b5a56b2b1..8fce6a1debfcf 100644 --- a/src/pytorch_lightning/utilities/seed.py +++ b/src/pytorch_lightning/utilities/seed.py @@ -24,7 +24,7 @@ import numpy as np import torch -from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only, rank_zero_warn log = logging.getLogger(__name__) @@ -66,9 +66,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") seed = _select_seed_randomly(min_seed_value, max_seed_value) - # using `log.info` instead of `rank_zero_info`, - # so users can verify the seed is properly set in distributed training. - log.info(f"Global seed set to {seed}") + log.info(f"[rank: {_get_rank()}] Global seed set to {seed}") os.environ["PL_GLOBAL_SEED"] = str(seed) random.seed(seed) np.random.seed(seed) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index f6c14d366805f..18e2db6feb6c6 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union import torch from torch import Tensor @@ -49,6 +49,7 @@ ] EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] _DEVICE = Union[torch.device, str, int] +_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] @runtime_checkable diff --git a/tests/tests_app/cli/test_cli.py b/tests/tests_app/cli/test_cli.py index 16e641ac38f23..48e1a26bb6f2b 100644 --- a/tests/tests_app/cli/test_cli.py +++ b/tests/tests_app/cli/test_cli.py @@ -70,7 +70,18 @@ def test_main_lightning_cli_help(): @mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) @mock.patch("lightning_app.cli.cmd_clusters.AWSClusterManager.create") -def test_create_cluster(create: mock.MagicMock): +@pytest.mark.parametrize( + "extra_arguments,expected_instance_types,expected_cost_savings_mode", + [ + (["--instance-types", "t3.xlarge"], ["t3.xlarge"], True), + (["--instance-types", "t3.xlarge,t3.2xlarge"], ["t3.xlarge", "t3.2xlarge"], True), + ([], None, True), + (["--enable-performance"], None, False), + ], +) +def test_create_cluster( + create_command: mock.MagicMock, extra_arguments, expected_instance_types, expected_cost_savings_mode +): runner = CliRunner() runner.invoke( create_cluster, @@ -82,19 +93,18 @@ def test_create_cluster(create: mock.MagicMock): "dummy", "--role-arn", "arn:aws:iam::1234567890:role/lai-byoc", - "--instance-types", - "t2.small", - ], + ] + + extra_arguments, ) - create.assert_called_once_with( + create_command.assert_called_once_with( cluster_name="test-7", region="us-east-1", role_arn="arn:aws:iam::1234567890:role/lai-byoc", external_id="dummy", - instance_types=["t2.small"], + instance_types=expected_instance_types, edit_before_creation=False, - cost_savings=False, + cost_savings=expected_cost_savings_mode, wait=False, ) diff --git a/tests/tests_app/cli/test_cmd_show_logs.py b/tests/tests_app/cli/test_cmd_show_logs.py new file mode 100644 index 0000000000000..0dc06025151fa --- /dev/null +++ b/tests/tests_app/cli/test_cmd_show_logs.py @@ -0,0 +1,61 @@ +from unittest import mock + +from click.testing import CliRunner + +from lightning_app.cli.lightning_cli import logs + + +@mock.patch("lightning_app.cli.lightning_cli.LightningClient") +@mock.patch("lightning_app.cli.lightning_cli._get_project") +def test_show_logs_errors(project, client): + """Test that the CLI prints the errors for the show logs command.""" + + runner = CliRunner() + + # Response prep + app = mock.MagicMock() + app.name = "MyFakeApp" + work = mock.MagicMock() + work.name = "MyFakeWork" + flow = mock.MagicMock() + flow.name = "MyFakeFlow" + + # No apps ever run + apps = {} + client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps + + result = runner.invoke(logs, ["NonExistentApp"]) + + assert result.exit_code == 1 + assert "Error: You don't have any application in the cloud" in result.output + + # App not specified + apps = {app} + client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps + + result = runner.invoke(logs) + + assert result.exit_code == 1 + assert "Please select one of available: [MyFakeApp]" in str(result.output) + + # App does not exit + apps = {app} + client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps + + result = runner.invoke(logs, ["ThisAppDoesNotExist"]) + + assert result.exit_code == 1 + assert "The Lightning App 'ThisAppDoesNotExist' does not exist." in str(result.output) + + # Component does not exist + apps = {app} + works = {work} + flows = {flow} + client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps + client.return_value.lightningwork_service_list_lightningwork.return_value.lightningworks = works + app.spec.flow_servers = flows + + result = runner.invoke(logs, ["MyFakeApp", "NonExistentComponent"]) + + assert result.exit_code == 1 + assert "Component 'NonExistentComponent' does not exist in app MyFakeApp." in result.output diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index edd2896d1951d..1b2bf2fb52fd9 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -2,15 +2,27 @@ import multiprocessing as mp import os from copy import deepcopy +from multiprocessing import Process +from time import sleep from unittest import mock import pytest +import requests from deepdiff import DeepDiff, Delta from httpx import AsyncClient +from pydantic import BaseModel from lightning_app import LightningApp, LightningFlow, LightningWork +from lightning_app.api.http_methods import Post from lightning_app.core import api -from lightning_app.core.api import fastapi_service, global_app_state_store, start_server, UIRefresher +from lightning_app.core.api import ( + fastapi_service, + global_app_state_store, + register_global_routes, + start_server, + UIRefresher, +) +from lightning_app.core.constants import APP_SERVER_PORT from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime from lightning_app.storage.drive import Drive from lightning_app.testing.helpers import MockQueue @@ -20,6 +32,8 @@ from lightning_app.utilities.redis import check_if_redis_running from lightning_app.utilities.state import AppState, headers_for +register_global_routes() + class WorkA(LightningWork): def __init__(self): @@ -161,12 +175,11 @@ def test_update_publish_state_and_maybe_refresh_ui(): app = AppStageTestingApp(FlowA(), debug=True) publish_state_queue = MockQueue("publish_state_queue") - commands_metadata_queue = MockQueue("commands_metadata_queue") - commands_responses_queue = MockQueue("commands_metadata_queue") + api_response_queue = MockQueue("api_response_queue") publish_state_queue.put(app.state_with_changes) - thread = UIRefresher(publish_state_queue, commands_metadata_queue, commands_responses_queue) + thread = UIRefresher(publish_state_queue, api_response_queue) thread.run_once() assert global_app_state_store.get_app_state("1234") == app.state_with_changes @@ -192,18 +205,14 @@ def get(self, timeout: int = 0): publish_state_queue = InfiniteQueue("publish_state_queue") change_state_queue = MockQueue("change_state_queue") has_started_queue = MockQueue("has_started_queue") - commands_requests_queue = MockQueue("commands_requests_queue") - commands_responses_queue = MockQueue("commands_responses_queue") - commands_metadata_queue = MockQueue("commands_metadata_queue") + api_response_queue = MockQueue("api_response_queue") state = app.state_with_changes publish_state_queue.put(state) spec = extract_metadata_from_app(app) ui_refresher = start_server( publish_state_queue, change_state_queue, - commands_requests_queue, - commands_responses_queue, - commands_metadata_queue, + api_response_queue, has_started_queue=has_started_queue, uvicorn_run=False, spec=spec, @@ -343,16 +352,12 @@ def test_start_server_started(): api_publish_state_queue = mp.Queue() api_delta_queue = mp.Queue() has_started_queue = mp.Queue() - commands_requests_queue = mp.Queue() - commands_responses_queue = mp.Queue() - commands_metadata_queue = mp.Queue() + api_response_queue = mp.Queue() kwargs = dict( api_publish_state_queue=api_publish_state_queue, api_delta_queue=api_delta_queue, has_started_queue=has_started_queue, - commands_requests_queue=commands_requests_queue, - commands_responses_queue=commands_responses_queue, - commands_metadata_queue=commands_metadata_queue, + api_response_queue=api_response_queue, port=1111, ) @@ -372,18 +377,14 @@ def test_start_server_info_message(ui_refresher, uvicorn_run, caplog, monkeypatc api_publish_state_queue = MockQueue() api_delta_queue = MockQueue() has_started_queue = MockQueue() - commands_requests_queue = MockQueue() - commands_responses_queue = MockQueue() - commands_metadata_queue = MockQueue() + api_response_queue = MockQueue() kwargs = dict( host=host, port=1111, api_publish_state_queue=api_publish_state_queue, api_delta_queue=api_delta_queue, has_started_queue=has_started_queue, - commands_requests_queue=commands_requests_queue, - commands_responses_queue=commands_responses_queue, - commands_metadata_queue=commands_metadata_queue, + api_response_queue=api_response_queue, ) monkeypatch.setattr(api, "logger", logging.getLogger()) @@ -395,3 +396,65 @@ def test_start_server_info_message(ui_refresher, uvicorn_run, caplog, monkeypatc ui_refresher.assert_called_once() uvicorn_run.assert_called_once_with(host="0.0.0.1", port=1111, log_level="error", app=mock.ANY) + + +class InputRequestModel(BaseModel): + name: str + + +class OutputRequestModel(BaseModel): + name: str + counter: int + + +class FlowAPI(LightningFlow): + def __init__(self): + super().__init__() + self.counter = 0 + + def run(self): + if self.counter == 2: + sleep(0.5) + self._exit() + + def request(self, config: InputRequestModel) -> OutputRequestModel: + self.counter += 1 + return OutputRequestModel(name=config.name, counter=self.counter) + + def configure_api(self): + return [Post("/api/v1/request", self.request)] + + +def target(): + app = LightningApp(FlowAPI()) + MultiProcessRuntime(app).dispatch() + + +def test_configure_api(): + + process = Process(target=target) + process.start() + time_left = 15 + while time_left > 0: + try: + requests.get(f"http://localhost:{APP_SERVER_PORT}/healthz") + break + except requests.exceptions.ConnectionError: + sleep(0.1) + time_left -= 0.1 + + response = requests.post( + f"http://localhost:{APP_SERVER_PORT}/api/v1/request", data=InputRequestModel(name="hello").json() + ) + assert response.json() == {"name": "hello", "counter": 1} + response = requests.post( + f"http://localhost:{APP_SERVER_PORT}/api/v1/request", data=InputRequestModel(name="hello").json() + ) + assert response.json() == {"name": "hello", "counter": 2} + time_left = 15 + while time_left > 0: + if process.exitcode == 0: + break + sleep(0.1) + time_left -= 0.1 + assert process.exitcode == 0 diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index a3a15085b98e3..3776481965be3 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -1,3 +1,4 @@ +import logging import os import pickle from time import sleep @@ -27,6 +28,8 @@ from lightning_app.utilities.redis import check_if_redis_running from lightning_app.utilities.warnings import LightningFlowWarning +logger = logging.getLogger() + class B1(LightningFlow): def __init__(self): @@ -439,19 +442,25 @@ def __init__(self): self.counter = 0 def run(self): - self.counter = 1 + if self.counter < 2: + self.counter += 1 def test_maybe_apply_changes_from_flow(): """This test validates the app `_updated` is set to True only if the state was changed in the flow.""" app = LightningApp(SimpleFlow()) - assert not app._has_updated + assert app._has_updated app.maybe_apply_changes() app.root.run() app.maybe_apply_changes() assert app._has_updated app._has_updated = False + app.root.run() + app.maybe_apply_changes() + assert app._has_updated + app._has_updated = False + app.root.run() app.maybe_apply_changes() assert not app._has_updated @@ -896,6 +905,7 @@ def __init__(self, **kwargs): def run(self, signal: int): self.counter += 1 + assert len(self._calls) == 2 class SizeFlow(LightningFlow): @@ -919,3 +929,29 @@ def test_state_size_constant_growth(): MultiProcessRuntime(app, start_server=False).dispatch() assert app.root._state_sizes[0] <= 5904 assert app.root._state_sizes[20] <= 23736 + + +class FlowUpdated(LightningFlow): + def run(self): + logger.info("Hello World") + + +class NonUpdatedLightningTestApp(LightningTestApp): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.counter = 0 + + def on_after_run_once(self): + self.counter += 1 + if not self._has_updated and self.counter > 2: + return True + return super().on_after_run_once() + + +def test_non_updated_flow(caplog): + """This tests validate the app can run 3 times and call the flow only once.""" + with caplog.at_level(logging.INFO): + app = NonUpdatedLightningTestApp(FlowUpdated()) + MultiProcessRuntime(app, start_server=False).dispatch() + assert caplog.messages == ["Hello World"] + assert app.counter == 3 diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py index e8ce1222a3186..4c0eb23ea014c 100644 --- a/tests/tests_app/core/test_lightning_flow.py +++ b/tests/tests_app/core/test_lightning_flow.py @@ -16,7 +16,7 @@ from lightning_app.storage import Path from lightning_app.storage.path import storage_root_dir from lightning_app.testing.helpers import EmptyFlow, EmptyWork -from lightning_app.utilities.app_helpers import _delta_to_appstate_delta, _LightningAppRef +from lightning_app.utilities.app_helpers import _delta_to_app_state_delta, _LightningAppRef from lightning_app.utilities.enum import CacheCallsKeys from lightning_app.utilities.exceptions import ExitAppException @@ -416,7 +416,7 @@ def run(self): flow_a.work.counter = 1 work_state_2 = flow_a.work.state delta = Delta(DeepDiff(work_state, work_state_2, verbose_level=2)) - delta = _delta_to_appstate_delta(flow_a, flow_a.work, delta) + delta = _delta_to_app_state_delta(flow_a, flow_a.work, delta) new_flow_state = LightningApp.populate_changes(flow_state, flow_state + delta) flow_a.set_state(new_flow_state) assert flow_a.work.counter == 1 @@ -592,24 +592,23 @@ def run(self): class FlowSchedule(LightningFlow): def __init__(self): super().__init__() - self._last_time = None + self._last_times = [] + self.target = 3 + self.seconds = ",".join([str(v) for v in range(0, 60, self.target)]) def run(self): - if self.schedule("* * * * * 0,5,10,15,20,25,30,35,40,45,50,55"): - if self._last_time is None: - self._last_time = False - elif not self._last_time: - self._last_time = time() + if self.schedule(f"* * * * * {self.seconds}"): + if len(self._last_times) < 3: + self._last_times.append(time()) else: - # TODO (tchaton) Optimize flow execution. - assert 4.0 < abs(time() - self._last_time) < 6.0 + assert abs((time() - self._last_times[-1]) - self.target) < 3 self._exit() def test_scheduling_api(): app = LightningApp(FlowSchedule()) - MultiProcessRuntime(app).dispatch() + MultiProcessRuntime(app, start_server=True).dispatch() def test_lightning_flow(): diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py index 4b1cf08e8554d..640eb9c114c2d 100644 --- a/tests/tests_app/runners/test_cloud.py +++ b/tests/tests_app/runners/test_cloud.py @@ -1,4 +1,5 @@ import logging +from copy import copy from pathlib import Path from unittest import mock from unittest.mock import MagicMock @@ -9,21 +10,29 @@ Gridv1ImageSpec, V1BuildSpec, V1DependencyFileInfo, + V1Drive, + V1DriveSpec, + V1DriveStatus, + V1DriveType, V1LightningappInstanceState, + V1LightningworkDrives, V1LightningworkSpec, V1ListLightningappInstancesResponse, V1ListMembershipsResponse, V1Membership, + V1Metadata, V1NetworkConfig, V1PackageManager, V1ProjectClusterBinding, V1PythonDependencyInfo, + V1SourceType, V1UserRequestedComputeConfig, V1Work, ) from lightning_app import LightningApp, LightningWork from lightning_app.runners import backends, cloud +from lightning_app.storage import Drive from lightning_app.utilities.cloud import _get_project from lightning_app.utilities.dependency_caching import get_hash @@ -33,6 +42,25 @@ def run(self): print("my run") +class WorkWithSingleDrive(LightningWork): + def __init__(self): + super().__init__() + self.drive = None + + def run(self): + pass + + +class WorkWithTwoDrives(LightningWork): + def __init__(self): + super().__init__() + self.lit_drive = None + self.s3_drive = None + + def run(self): + pass + + class TestAppCreationClient: """Testing the calls made using GridRestClient to create the app.""" @@ -250,6 +278,134 @@ def test_call_with_work_app(self, lightningapps, monkeypatch, tmpdir): ), image="random_base_public_image", ), + drives=[], + user_requested_compute_config=V1UserRequestedComputeConfig( + name="default", count=1, disk_size=0, preemptible=False, shm_size=0 + ), + network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)], + ), + ) + ], + ) + mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with( + "test-project-id", mock.ANY, expected_body + ) + + # running dispatch with disabled dependency cache + mock_client.reset_mock() + monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True) + expected_body.dependency_cache_key = None + cloud_runtime.dispatch() + mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with( + "test-project-id", mock.ANY, expected_body + ) + else: + mock_client.lightningapp_v2_service_create_lightningapp_release_instance.assert_called_once_with( + "test-project-id", mock.ANY, mock.ANY, mock.ANY + ) + + @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock()) + @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]]) + def test_call_with_work_app_and_attached_drives(self, lightningapps, monkeypatch, tmpdir): + source_code_root_dir = Path(tmpdir / "src").absolute() + source_code_root_dir.mkdir() + Path(source_code_root_dir / ".lightning").write_text("name: myapp") + requirements_file = Path(source_code_root_dir / "requirements.txt") + Path(requirements_file).touch() + + mock_client = mock.MagicMock() + if lightningapps: + lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED + mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = ( + V1ListLightningappInstancesResponse(lightningapps=lightningapps) + ) + lightning_app_instance = MagicMock() + mock_client.lightningapp_v2_service_create_lightningapp_release = MagicMock(return_value=lightning_app_instance) + mock_client.lightningapp_v2_service_create_lightningapp_release_instance = MagicMock( + return_value=lightning_app_instance + ) + existing_instance = MagicMock() + existing_instance.status.phase = V1LightningappInstanceState.STOPPED + mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance) + cloud_backend = mock.MagicMock() + cloud_backend.client = mock_client + monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend)) + monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock()) + monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock()) + app = mock.MagicMock() + flow = mock.MagicMock() + + mocked_drive = MagicMock(spec=Drive) + setattr(mocked_drive, "id", "foobar") + setattr(mocked_drive, "protocol", "lit://") + setattr(mocked_drive, "component_name", "test-work") + setattr(mocked_drive, "allow_duplicates", False) + setattr(mocked_drive, "root_folder", tmpdir) + # deepcopy on a MagicMock instance will return an empty magicmock instance. To + # overcome this we set the __deepcopy__ method `return_value` to equal what + # should be the results of the deepcopy operation (an instance of the original class) + mocked_drive.__deepcopy__.return_value = copy(mocked_drive) + + work = WorkWithSingleDrive() + monkeypatch.setattr(work, "drive", mocked_drive) + monkeypatch.setattr(work, "_state", {"_port", "drive"}) + monkeypatch.setattr(work, "_name", "test-work") + monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"]) + monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"]) + monkeypatch.setattr(work._cloud_build_config, "image", "random_base_public_image") + monkeypatch.setattr(work._cloud_compute, "disk_size", 0) + monkeypatch.setattr(work._cloud_compute, "preemptible", False) + monkeypatch.setattr(work, "_port", 8080) + + flow.works = lambda recurse: [work] + app.flows = [flow] + cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py")) + monkeypatch.setattr( + "lightning_app.runners.cloud._get_project", + lambda x: V1Membership(name="test-project", project_id="test-project-id"), + ) + cloud_runtime.dispatch() + + if lightningapps: + expected_body = Body8( + description=None, + local_source=True, + app_entrypoint_file="entrypoint.py", + enable_app_server=True, + flow_servers=[], + dependency_cache_key=get_hash(requirements_file), + image_spec=Gridv1ImageSpec( + dependency_file_info=V1DependencyFileInfo( + package_manager=V1PackageManager.PIP, path="requirements.txt" + ) + ), + works=[ + V1Work( + name="test-work", + spec=V1LightningworkSpec( + build_spec=V1BuildSpec( + commands=["echo 'start'"], + python_dependencies=V1PythonDependencyInfo( + package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0" + ), + image="random_base_public_image", + ), + drives=[ + V1LightningworkDrives( + drive=V1Drive( + metadata=V1Metadata( + name="test-work.drive", + ), + spec=V1DriveSpec( + drive_type=V1DriveType.NO_MOUNT_S3, + source_type=V1SourceType.S3, + source="lit://foobar", + ), + status=V1DriveStatus(), + ), + mount_location=str(tmpdir), + ), + ], user_requested_compute_config=V1UserRequestedComputeConfig( name="default", count=1, disk_size=0, preemptible=False, shm_size=0 ), @@ -275,6 +431,206 @@ def test_call_with_work_app(self, lightningapps, monkeypatch, tmpdir): "test-project-id", mock.ANY, mock.ANY, mock.ANY ) + @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock()) + @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]]) + def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, monkeypatch, tmpdir): + source_code_root_dir = Path(tmpdir / "src").absolute() + source_code_root_dir.mkdir() + Path(source_code_root_dir / ".lightning").write_text("name: myapp") + requirements_file = Path(source_code_root_dir / "requirements.txt") + Path(requirements_file).touch() + + mock_client = mock.MagicMock() + if lightningapps: + lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED + mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = ( + V1ListLightningappInstancesResponse(lightningapps=lightningapps) + ) + lightning_app_instance = MagicMock() + mock_client.lightningapp_v2_service_create_lightningapp_release = MagicMock(return_value=lightning_app_instance) + mock_client.lightningapp_v2_service_create_lightningapp_release_instance = MagicMock( + return_value=lightning_app_instance + ) + existing_instance = MagicMock() + existing_instance.status.phase = V1LightningappInstanceState.STOPPED + mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance) + cloud_backend = mock.MagicMock() + cloud_backend.client = mock_client + monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend)) + monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock()) + monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock()) + app = mock.MagicMock() + flow = mock.MagicMock() + + mocked_lit_drive = MagicMock(spec=Drive) + setattr(mocked_lit_drive, "id", "foobar") + setattr(mocked_lit_drive, "protocol", "lit://") + setattr(mocked_lit_drive, "component_name", "test-work") + setattr(mocked_lit_drive, "allow_duplicates", False) + setattr(mocked_lit_drive, "root_folder", tmpdir) + # deepcopy on a MagicMock instance will return an empty magicmock instance. To + # overcome this we set the __deepcopy__ method `return_value` to equal what + # should be the results of the deepcopy operation (an instance of the original class) + mocked_lit_drive.__deepcopy__.return_value = copy(mocked_lit_drive) + + mocked_s3_drive = MagicMock(spec=Drive) + setattr(mocked_s3_drive, "id", "some-bucket/path/") + setattr(mocked_s3_drive, "protocol", "s3://") + setattr(mocked_s3_drive, "component_name", "test-work") + setattr(mocked_s3_drive, "allow_duplicates", False) + setattr(mocked_s3_drive, "root_folder", "/hello/") + # deepcopy on a MagicMock instance will return an empty magicmock instance. To + # overcome this we set the __deepcopy__ method `return_value` to equal what + # should be the results of the deepcopy operation (an instance of the original class) + mocked_s3_drive.__deepcopy__.return_value = copy(mocked_s3_drive) + + work = WorkWithTwoDrives() + monkeypatch.setattr(work, "lit_drive", mocked_lit_drive) + monkeypatch.setattr(work, "s3_drive", mocked_s3_drive) + monkeypatch.setattr(work, "_state", {"_port", "_name", "lit_drive", "s3_drive"}) + monkeypatch.setattr(work, "_name", "test-work") + monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"]) + monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"]) + monkeypatch.setattr(work._cloud_build_config, "image", "random_base_public_image") + monkeypatch.setattr(work._cloud_compute, "disk_size", 0) + monkeypatch.setattr(work._cloud_compute, "preemptible", False) + monkeypatch.setattr(work, "_port", 8080) + + flow.works = lambda recurse: [work] + app.flows = [flow] + cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py")) + monkeypatch.setattr( + "lightning_app.runners.cloud._get_project", + lambda x: V1Membership(name="test-project", project_id="test-project-id"), + ) + cloud_runtime.dispatch() + + if lightningapps: + s3_drive_spec = V1LightningworkDrives( + drive=V1Drive( + metadata=V1Metadata( + name="test-work.s3_drive", + ), + spec=V1DriveSpec( + drive_type=V1DriveType.INDEXED_S3, + source_type=V1SourceType.S3, + source="s3://some-bucket/path/", + ), + status=V1DriveStatus(), + ), + mount_location="/hello/", + ) + lit_drive_spec = V1LightningworkDrives( + drive=V1Drive( + metadata=V1Metadata( + name="test-work.lit_drive", + ), + spec=V1DriveSpec( + drive_type=V1DriveType.NO_MOUNT_S3, + source_type=V1SourceType.S3, + source="lit://foobar", + ), + status=V1DriveStatus(), + ), + mount_location=str(tmpdir), + ) + + # order of drives in the spec is non-deterministic, so there are two options + # depending for the expected body value on which drive is ordered in the list first. + + expected_body_option_1 = Body8( + description=None, + local_source=True, + app_entrypoint_file="entrypoint.py", + enable_app_server=True, + flow_servers=[], + dependency_cache_key=get_hash(requirements_file), + image_spec=Gridv1ImageSpec( + dependency_file_info=V1DependencyFileInfo( + package_manager=V1PackageManager.PIP, path="requirements.txt" + ) + ), + works=[ + V1Work( + name="test-work", + spec=V1LightningworkSpec( + build_spec=V1BuildSpec( + commands=["echo 'start'"], + python_dependencies=V1PythonDependencyInfo( + package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0" + ), + image="random_base_public_image", + ), + drives=[lit_drive_spec, s3_drive_spec], + user_requested_compute_config=V1UserRequestedComputeConfig( + name="default", count=1, disk_size=0, preemptible=False, shm_size=0 + ), + network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)], + ), + ) + ], + ) + + expected_body_option_2 = Body8( + description=None, + local_source=True, + app_entrypoint_file="entrypoint.py", + enable_app_server=True, + flow_servers=[], + dependency_cache_key=get_hash(requirements_file), + image_spec=Gridv1ImageSpec( + dependency_file_info=V1DependencyFileInfo( + package_manager=V1PackageManager.PIP, path="requirements.txt" + ) + ), + works=[ + V1Work( + name="test-work", + spec=V1LightningworkSpec( + build_spec=V1BuildSpec( + commands=["echo 'start'"], + python_dependencies=V1PythonDependencyInfo( + package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0" + ), + image="random_base_public_image", + ), + drives=[s3_drive_spec, lit_drive_spec], + user_requested_compute_config=V1UserRequestedComputeConfig( + name="default", count=1, disk_size=0, preemptible=False, shm_size=0 + ), + network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)], + ), + ) + ], + ) + + # try both options for the expected body to avoid false + # positive test failures depending on system randomness + + expected_body = expected_body_option_1 + try: + mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with( + "test-project-id", mock.ANY, expected_body + ) + except Exception: + expected_body = expected_body_option_2 + mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with( + "test-project-id", mock.ANY, expected_body + ) + + # running dispatch with disabled dependency cache + mock_client.reset_mock() + monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True) + expected_body.dependency_cache_key = None + cloud_runtime.dispatch() + mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with( + "test-project-id", mock.ANY, expected_body + ) + else: + mock_client.lightningapp_v2_service_create_lightningapp_release_instance.assert_called_once_with( + "test-project-id", mock.ANY, mock.ANY, mock.ANY + ) + @mock.patch("lightning_app.core.queues.QueuingSystem", MagicMock()) @mock.patch("lightning_app.runners.backends.cloud.LightningClient", MagicMock()) diff --git a/tests/tests_app/storage/test_drive.py b/tests/tests_app/storage/test_drive.py index 3d9db44c10e13..0d452571d9f43 100644 --- a/tests/tests_app/storage/test_drive.py +++ b/tests/tests_app/storage/test_drive.py @@ -11,7 +11,7 @@ from lightning_app.utilities.component import _set_flow_context -class SyncWorkA(LightningWork): +class SyncWorkLITDriveA(LightningWork): def __init__(self, tmpdir): super().__init__() self.tmpdir = tmpdir @@ -25,19 +25,19 @@ def run(self, drive: Drive): os.remove(f"{self.tmpdir}/a.txt") -class SyncWorkB(LightningWork): +class SyncWorkLITDriveB(LightningWork): def run(self, drive: Drive): assert not os.path.exists("a.txt") drive.get("a.txt") assert os.path.exists("a.txt") -class SyncFlow(LightningFlow): +class SyncFlowLITDrives(LightningFlow): def __init__(self, tmpdir): super().__init__() self.log_dir = Drive("lit://log_dir") - self.work_a = SyncWorkA(str(tmpdir)) - self.work_b = SyncWorkB() + self.work_a = SyncWorkLITDriveA(str(tmpdir)) + self.work_b = SyncWorkLITDriveB() def run(self): self.work_a.run(self.log_dir) @@ -45,15 +45,15 @@ def run(self): self._exit() -def test_synchronization_drive(tmpdir): +def test_synchronization_lit_drive(tmpdir): if os.path.exists("a.txt"): os.remove("a.txt") - app = LightningApp(SyncFlow(tmpdir)) + app = LightningApp(SyncFlowLITDrives(tmpdir)) MultiProcessRuntime(app, start_server=False).dispatch() os.remove("a.txt") -class Work(LightningWork): +class LITDriveWork(LightningWork): def __init__(self): super().__init__(parallel=True) self.drive = None @@ -75,7 +75,7 @@ def run(self, *args, **kwargs): self.counter += 1 -class Work2(LightningWork): +class LITDriveWork2(LightningWork): def __init__(self): super().__init__(parallel=True) @@ -86,11 +86,11 @@ def run(self, drive: Drive, **kwargs): assert drive.list(".", component_name=self.name) == [] -class Flow(LightningFlow): +class LITDriveFlow(LightningFlow): def __init__(self): super().__init__() - self.work = Work() - self.work2 = Work2() + self.work = LITDriveWork() + self.work2 = LITDriveWork2() def run(self): self.work.run("0") @@ -102,15 +102,15 @@ def run(self): self._exit() -def test_drive_transferring_files(): - app = LightningApp(Flow()) +def test_lit_drive_transferring_files(): + app = LightningApp(LITDriveFlow()) MultiProcessRuntime(app, start_server=False).dispatch() os.remove("a.txt") -def test_drive(): - with pytest.raises(Exception, match="The Drive id needs to start with one of the following protocols"): - Drive("this_drive_id") +def test_lit_drive(): + with pytest.raises(Exception, match="Unknown protocol for the drive 'id' argument"): + Drive("invalid_drive_id") with pytest.raises( Exception, match="The id should be unique to identify your drive. Found `this_drive_id/something_else`." @@ -213,9 +213,46 @@ def test_drive(): os.remove("a.txt") -def test_maybe_create_drive(): +def test_s3_drives(): + drive = Drive("s3://foo/", allow_duplicates=True) + drive.component_name = "root.work" - drive = Drive("lit://drive_3", allow_duplicates=False) + with pytest.raises( + Exception, match="S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?" + ): + drive.put("a.txt") + with pytest.raises( + Exception, + match="S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?", + ): + drive.list("a.txt") + with pytest.raises( + Exception, match="S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?" + ): + drive.get("a.txt") + with pytest.raises( + Exception, + match="S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?", + ): + drive.delete("a.txt") + + _set_flow_context() + with pytest.raises(Exception, match="The flow isn't allowed to put files into a Drive."): + drive.put("a.txt") + with pytest.raises(Exception, match="The flow isn't allowed to list files from a Drive."): + drive.list("a.txt") + with pytest.raises(Exception, match="The flow isn't allowed to get files from a Drive."): + drive.get("a.txt") + + +def test_create_s3_drive_without_trailing_slash_fails(): + with pytest.raises(ValueError, match="S3 drives must end in a trailing slash"): + Drive("s3://foo") + + +@pytest.mark.parametrize("drive_id", ["lit://drive", "s3://drive/"]) +def test_maybe_create_drive(drive_id): + drive = Drive(drive_id, allow_duplicates=False) drive.component_name = "root.work1" new_drive = _maybe_create_drive(drive.component_name, drive.to_dict()) assert new_drive.protocol == drive.protocol @@ -223,9 +260,9 @@ def test_maybe_create_drive(): assert new_drive.component_name == drive.component_name -def test_drive_deepcopy(): - - drive = Drive("lit://drive", allow_duplicates=True) +@pytest.mark.parametrize("drive_id", ["lit://drive", "s3://drive/"]) +def test_drive_deepcopy(drive_id): + drive = Drive(drive_id, allow_duplicates=True) drive.component_name = "root.work1" new_drive = deepcopy(drive) assert new_drive.id == drive.id diff --git a/tests/tests_app/utilities/packaging/test_lightning_utils.py b/tests/tests_app/utilities/packaging/test_lightning_utils.py index b34e3162d5a0c..8f30aa21dd396 100644 --- a/tests/tests_app/utilities/packaging/test_lightning_utils.py +++ b/tests/tests_app/utilities/packaging/test_lightning_utils.py @@ -1,4 +1,5 @@ import os +from unittest import mock import pytest @@ -21,6 +22,21 @@ def test_prepare_lightning_wheels_and_requirement(tmpdir): assert os.listdir(tmpdir) == [] +def _mocked_get_dist_path_if_editable_install(*args, **kwargs): + return None + + +@mock.patch( + "lightning_app.utilities.packaging.lightning_utils.get_dist_path_if_editable_install", + new=_mocked_get_dist_path_if_editable_install, +) +def test_prepare_lightning_wheels_and_requirement_for_packages_installed_in_editable_mode(tmpdir): + """This test ensures the source does not get packaged inside the lightning repo if not installed in editable + mode.""" + cleanup_handle = _prepare_lightning_wheels_and_requirements(tmpdir) + assert cleanup_handle is None + + @pytest.mark.skip(reason="TODO: Find a way to check for the latest version") @RunIf(skip_windows=True) def test_verify_lightning_version(monkeypatch): diff --git a/tests/tests_app/utilities/test_app_logs.py b/tests/tests_app/utilities/test_app_logs.py new file mode 100644 index 0000000000000..7a0fe087e7c29 --- /dev/null +++ b/tests/tests_app/utilities/test_app_logs.py @@ -0,0 +1,13 @@ +from datetime import datetime +from time import sleep +from unittest.mock import MagicMock + +from lightning_app.utilities.app_logs import _LogEvent + + +def test_log_event(): + event_1 = _LogEvent("", datetime.now(), MagicMock(), MagicMock()) + sleep(0.1) + event_2 = _LogEvent("", datetime.now(), MagicMock(), MagicMock()) + assert event_1 < event_2 + assert event_1 <= event_2 diff --git a/tests/tests_app/utilities/test_commands.py b/tests/tests_app/utilities/test_commands.py index 1e8e36ed09545..1be35a3a2e290 100644 --- a/tests/tests_app/utilities/test_commands.py +++ b/tests/tests_app/utilities/test_commands.py @@ -14,7 +14,7 @@ from lightning_app.core.constants import APP_SERVER_PORT from lightning_app.runners import MultiProcessRuntime from lightning_app.testing.helpers import RunIf -from lightning_app.utilities.commands.base import _command_to_method_and_metadata, _download_command, ClientCommand +from lightning_app.utilities.commands.base import _download_command, _validate_client_command, ClientCommand from lightning_app.utilities.state import AppState @@ -25,7 +25,6 @@ class SweepConfig(BaseModel): class SweepCommand(ClientCommand): def run(self) -> None: - print(sys.argv) parser = argparse.ArgumentParser() parser.add_argument("--sweep_name", type=str) parser.add_argument("--num_trials", type=int) @@ -44,7 +43,7 @@ def __init__(self): def run(self): if self.has_sweep and len(self.names) == 1: - sleep(2) + sleep(1) self._exit() def trigger_method(self, name: str): @@ -91,15 +90,15 @@ def run_failure_2(name: CustomModel): @RunIf(skip_windows=True) -def test_command_to_method_and_metadata(): +def test_validate_client_command(): with pytest.raises(Exception, match="The provided annotation for the argument name"): - _command_to_method_and_metadata(ClientCommand(run_failure_0)) + _validate_client_command(ClientCommand(run_failure_0)) with pytest.raises(Exception, match="annotate your method"): - _command_to_method_and_metadata(ClientCommand(run_failure_1)) + _validate_client_command(ClientCommand(run_failure_1)) with pytest.raises(Exception, match="lightning_app/utilities/commands/base.py"): - _command_to_method_and_metadata(ClientCommand(run_failure_2)) + _validate_client_command(ClientCommand(run_failure_2)) def test_client_commands(monkeypatch): @@ -115,17 +114,13 @@ def test_client_commands(monkeypatch): url = "http//" kwargs = {"something": "1", "something_else": "1"} command = DummyCommand(run) - _, command_metadata = _command_to_method_and_metadata(command) - command_metadata.update( - { - "command": "dummy", - "affiliation": "root", - "is_client_command": True, - "owner": "root", - } + _validate_client_command(command) + client_command = _download_command( + command_name="something", + cls_path=__file__, + cls_name="DummyCommand", ) - client_command, models = _download_command(command_metadata, None) - client_command._setup(metadata=command_metadata, models=models, app_url=url) + client_command._setup("something", app_url=url) client_command.run(**kwargs) @@ -153,10 +148,12 @@ def test_configure_commands(monkeypatch): state = AppState() state._request_state() assert state.names == ["something"] - monkeypatch.setattr(sys, "argv", ["lightning", "sweep", "--sweep_name", "my_name", "--num_trials", "1"]) + monkeypatch.setattr(sys, "argv", ["lightning", "sweep", "--sweep_name=my_name", "--num_trials=1"]) app_command() time_left = 15 - while time_left > 0 or process.exitcode is None: + while time_left > 0: + if process.exitcode == 0: + break sleep(0.1) time_left -= 0.1 assert process.exitcode == 0 diff --git a/tests/tests_app_examples/test_boring_app.py b/tests/tests_app_examples/test_boring_app.py index 1f681260de5c2..afb958571d16b 100644 --- a/tests/tests_app_examples/test_boring_app.py +++ b/tests/tests_app_examples/test_boring_app.py @@ -1,8 +1,10 @@ import os import pytest +from click.testing import CliRunner from tests_app import _PROJECT_ROOT +from lightning_app.cli.lightning_cli import logs from lightning_app.testing.testing import run_app_in_cloud, wait_for @@ -11,7 +13,8 @@ def test_boring_app_example_cloud() -> None: with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_boring/"), app_name="app_dynamic.py") as ( _, view_page, - _, + fetch_logs, + name, ): def check_hello_there(*_, **__): @@ -21,3 +24,14 @@ def check_hello_there(*_, **__): return True wait_for(view_page, check_hello_there) + + for _ in fetch_logs(): + pass + + runner = CliRunner() + result = runner.invoke(logs, [name]) + lines = result.output.splitlines() + + assert result.exit_code == 0 + assert result.exception is None + assert any("http://0.0.0.0:8080" in line for line in lines) diff --git a/tests/tests_app_examples/test_collect_failures.py b/tests/tests_app_examples/test_collect_failures.py index f263ebb1a9f58..c149211e10774 100644 --- a/tests/tests_app_examples/test_collect_failures.py +++ b/tests/tests_app_examples/test_collect_failures.py @@ -26,6 +26,7 @@ def test_collect_failures_example_cloud() -> None: _, _, fetch_logs, + _, ): last_found_log_index = -1 while len(expected_logs) != 0: diff --git a/tests/tests_app_examples/test_commands.py b/tests/tests_app_examples/test_commands.py deleted file mode 100644 index 5116b1b9d54bb..0000000000000 --- a/tests/tests_app_examples/test_commands.py +++ /dev/null @@ -1,31 +0,0 @@ -import os -from subprocess import Popen -from time import sleep -from unittest import mock - -import pytest -from tests_app import _PROJECT_ROOT - -from lightning_app.testing.testing import run_app_in_cloud - - -@mock.patch.dict(os.environ, {"SKIP_LIGHTING_UTILITY_WHEELS_BUILD": "0"}) -@pytest.mark.cloud -def test_commands_example_cloud() -> None: - with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_commands")) as ( - admin_page, - _, - fetch_logs, - ): - app_id = admin_page.url.split("/")[-1] - cmd = f"lightning trigger_with_client_command --name=something --app_id {app_id}" - Popen(cmd, shell=True).wait() - cmd = f"lightning trigger_without_client_command --name=else --app_id {app_id}" - Popen(cmd, shell=True).wait() - - has_logs = False - while not has_logs: - for log in fetch_logs(): - if "['something', 'else']" in log: - has_logs = True - sleep(1) diff --git a/tests/tests_app_examples/test_commands_and_api.py b/tests/tests_app_examples/test_commands_and_api.py new file mode 100644 index 0000000000000..8d84cf4847ebd --- /dev/null +++ b/tests/tests_app_examples/test_commands_and_api.py @@ -0,0 +1,42 @@ +import os +from subprocess import Popen +from time import sleep + +import pytest +import requests +from tests_app import _PROJECT_ROOT + +from lightning_app.testing.testing import run_app_in_cloud + + +@pytest.mark.cloud +def test_commands_and_api_example_cloud() -> None: + with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_commands_and_api")) as ( + admin_page, + view_page, + fetch_logs, + _, + ): + # 1: Collect the app_id + app_id = admin_page.url.split("/")[-1] + + # 2: Send the first command with the client + cmd = f"lightning command_with_client --name=this --app_id {app_id}" + Popen(cmd, shell=True).wait() + + # 3: Send the second command without a client + cmd = f"lightning command_without_client --name=is --app_id {app_id}" + Popen(cmd, shell=True).wait() + + # 4: Send a request to the Rest API directly. + base_url = view_page.url.replace("/view", "").replace("/child_flow", "") + resp = requests.post(base_url + "/user/command_without_client?name=awesome") + assert resp.status_code == 200, resp.json() + + # 5: Validate the logs. + has_logs = False + while not has_logs: + for log in fetch_logs(): + if "['this', 'is', 'awesome']" in log: + has_logs = True + sleep(1) diff --git a/tests/tests_app_examples/test_custom_work_dependencies.py b/tests/tests_app_examples/test_custom_work_dependencies.py index 8390233e2eee3..b8971e0ef2148 100644 --- a/tests/tests_app_examples/test_custom_work_dependencies.py +++ b/tests/tests_app_examples/test_custom_work_dependencies.py @@ -13,10 +13,10 @@ def test_custom_work_dependencies_example_cloud() -> None: with run_app_in_cloud( os.path.join(_PROJECT_ROOT, "tests/tests_app_examples/custom_work_dependencies/"), app_name="app.py", - ) as (_, _, fetch_logs): + ) as (_, _, fetch_logs, _): has_logs = False while not has_logs: - for log in fetch_logs(): + for log in fetch_logs(["flow"]): if "Custom Work Dependency checker End" in log: has_logs = True sleep(1) diff --git a/tests/tests_app_examples/test_drive.py b/tests/tests_app_examples/test_drive.py index 9cebca9cf1072..dde68d1a85113 100644 --- a/tests/tests_app_examples/test_drive.py +++ b/tests/tests_app_examples/test_drive.py @@ -11,8 +11,9 @@ def test_drive_example_cloud() -> None: with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_drive")) as ( _, - view_page, + _, fetch_logs, + _, ): has_logs = False diff --git a/tests/tests_app_examples/test_idle_timeout.py b/tests/tests_app_examples/test_idle_timeout.py index fb58a83aefc93..f06181ce86ed3 100644 --- a/tests/tests_app_examples/test_idle_timeout.py +++ b/tests/tests_app_examples/test_idle_timeout.py @@ -13,10 +13,11 @@ def test_idle_timeout_example_cloud() -> None: _, _, fetch_logs, + _, ): has_logs = False while not has_logs: - for log in fetch_logs(): + for log in fetch_logs(["flow"]): if "Application End" in log: has_logs = True sleep(1) diff --git a/tests/tests_app_examples/test_payload.py b/tests/tests_app_examples/test_payload.py index 28d2391c18a2a..b40b8ca52defd 100644 --- a/tests/tests_app_examples/test_payload.py +++ b/tests/tests_app_examples/test_payload.py @@ -9,11 +9,11 @@ @pytest.mark.cloud def test_payload_example_cloud() -> None: - with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_payload")) as (_, _, fetch_logs): + with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_payload")) as (_, _, fetch_logs, _): has_logs = False while not has_logs: - for log in fetch_logs(): + for log in fetch_logs(["flow"]): if "Application End!" in log: has_logs = True sleep(1) diff --git a/tests/tests_app_examples/test_quick_start.py b/tests/tests_app_examples/test_quick_start.py index 9db693a5dc3d6..454c1084ca1bb 100644 --- a/tests/tests_app_examples/test_quick_start.py +++ b/tests/tests_app_examples/test_quick_start.py @@ -51,7 +51,7 @@ def test_quick_start_example(caplog, monkeypatch): @pytest.mark.cloud def test_quick_start_example_cloud() -> None: - with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "lightning-quick-start/")) as (_, view_page, _): + with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "lightning-quick-start/")) as (_, view_page, _, _): def click_gradio_demo(*_, **__): button = view_page.locator('button:has-text("Interactive demo")') diff --git a/tests/tests_app_examples/test_template_react_ui.py b/tests/tests_app_examples/test_template_react_ui.py index 2e348035fe6e5..4b4588d2397e5 100644 --- a/tests/tests_app_examples/test_template_react_ui.py +++ b/tests/tests_app_examples/test_template_react_ui.py @@ -14,6 +14,7 @@ def test_template_react_ui_example_cloud() -> None: _, view_page, fetch_logs, + _, ): def click_button(*_, **__): diff --git a/tests/tests_app_examples/test_template_streamlit_ui.py b/tests/tests_app_examples/test_template_streamlit_ui.py index a8ba93794f2a0..e2c33305298f7 100644 --- a/tests/tests_app_examples/test_template_streamlit_ui.py +++ b/tests/tests_app_examples/test_template_streamlit_ui.py @@ -14,6 +14,7 @@ def test_template_streamlit_ui_example_cloud() -> None: _, view_page, fetch_logs, + _, ): def click_button(*_, **__): diff --git a/tests/tests_app_examples/test_v0_app.py b/tests/tests_app_examples/test_v0_app.py index d34a92d6102f8..026c45a4e1ba1 100644 --- a/tests/tests_app_examples/test_v0_app.py +++ b/tests/tests_app_examples/test_v0_app.py @@ -45,7 +45,7 @@ def check_content(button_name, text_content): wait_for(view_page, check_content, "TAB_2", "Hello from component B") has_logs = False while not has_logs: - for log in fetch_logs(): + for log in fetch_logs(["flow"]): if "'a': 'a', 'b': 'b'" in log: has_logs = True sleep(1) @@ -74,5 +74,6 @@ def test_v0_app_example_cloud() -> None: _, view_page, fetch_logs, + _, ): run_v0_app(fetch_logs, view_page) diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index 33d59d9a835ca..db3b9d1f91952 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -185,7 +185,7 @@ def test_optimization(tmpdir): @RunIf(ipu=True) -def test_mixed_precision(tmpdir): +def test_half_precision(tmpdir): class TestCallback(Callback): def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: assert trainer.strategy.model.precision == 16 diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index e9374f8ea4be1..f1ccf2a2726a2 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -21,8 +21,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset -from tests_pytorch.helpers.datasets import RandomIterableDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index 859cf2fa98c0c..7f1692e30a3f2 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -12,25 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import os +from pathlib import Path +from typing import ContextManager, Optional from unittest import mock import pytest import torch from torch import nn +from torch.optim.lr_scheduler import LambdaLR from torch.optim.swa_utils import SWALR from torch.utils.data import DataLoader from pytorch_lightning import Trainer from pytorch_lightning.callbacks import StochasticWeightAveraging -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from pytorch_lightning.strategies import DDPSpawnStrategy, Strategy from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomIterableDataset from tests_pytorch.helpers.runif import RunIf class SwaTestModel(BoringModel): - def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False): + def __init__( + self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None + ): super().__init__() layers = [nn.Linear(32, 32)] if batchnorm: @@ -39,17 +44,18 @@ def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dat self.layer = nn.Sequential(*layers) self.interval = interval self.iterable_dataset = iterable_dataset + self.crash_on_epoch = crash_on_epoch def training_step(self, batch, batch_idx): + if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch: + raise Exception("SWA crash test") output = self.forward(batch) loss = self.loss(batch, output) return {"loss": loss} def train_dataloader(self): - dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset dset = dset_cls(32, 64) - return DataLoader(dset, batch_size=2) def configure_optimizers(self): @@ -66,6 +72,8 @@ def configure_optimizers(self): class SwaTestCallback(StochasticWeightAveraging): update_parameters_calls: int = 0 transfer_weights_calls: int = 0 + # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0 + first_epoch: Optional[int] = None def update_parameters(self, *args, **kwargs): self.update_parameters_calls += 1 @@ -77,6 +85,11 @@ def transfer_weights(self, *args, **kwargs): def on_train_epoch_start(self, trainer, *args): super().on_train_epoch_start(trainer, *args) + if self.first_epoch is None and not trainer.fit_loop.restarting: + # since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will + # not update the model and just call the epoch-level hooks, for that reason, we check that we are not + # restarting before choosing the first epoch + self.first_epoch = trainer.current_epoch assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end) if self.swa_start <= trainer.current_epoch: assert isinstance(trainer.lr_scheduler_configs[0].scheduler, SWALR) @@ -88,6 +101,7 @@ def on_train_epoch_end(self, trainer, *args): if self.swa_start <= trainer.current_epoch <= self.swa_end: swa_epoch = trainer.current_epoch - self.swa_start assert self.n_averaged == swa_epoch + 1 + assert self._swa_scheduler is not None # Scheduler is stepped once on initialization and then at the end of each epoch assert self._swa_scheduler._step_count == swa_epoch + 2 elif trainer.current_epoch > self.swa_end: @@ -103,10 +117,13 @@ def on_train_end(self, trainer, pl_module): if not isinstance(trainer.strategy, DDPSpawnStrategy): # check backward call count. the batchnorm update epoch should not backward - assert trainer.strategy.backward.call_count == trainer.max_epochs * trainer.limit_train_batches + assert trainer.strategy.backward.call_count == ( + (trainer.max_epochs - self.first_epoch) * trainer.limit_train_batches + ) # check call counts - assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1) + first_swa_epoch = max(self.first_epoch, self.swa_start) + assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch assert self.transfer_weights_calls == 1 @@ -140,7 +157,7 @@ def train_with_swa( devices=devices, ) - with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward): + with _backward_patch(trainer): trainer.fit(model) # check the model is the expected @@ -226,9 +243,10 @@ def test_swa_multiple_lrs(tmpdir): class TestModel(BoringModel): def __init__(self): - super(BoringModel, self).__init__() + super().__init__() self.layer1 = torch.nn.Linear(32, 32) self.layer2 = torch.nn.Linear(32, 2) + self.on_train_epoch_start_called = False def forward(self, x): x = self.layer1(x) @@ -255,3 +273,98 @@ def on_train_epoch_start(self): ) trainer.fit(model) assert model.on_train_epoch_start_called + + +def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False): + swa_start = 3 + trainer_kwargs = { + "default_root_dir": tmpdir, + "max_epochs": 5, + "accelerator": "cpu", + "strategy": "ddp_spawn_find_unused_parameters_false" if ddp else None, + "devices": 2 if ddp else 1, + "limit_train_batches": 5, + "limit_val_batches": 0, + "accumulate_grad_batches": 2, + "enable_progress_bar": False, + } + trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs) + + with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"): + trainer.fit(model) + + checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints" + checkpoint_files = os.listdir(checkpoint_dir) + assert len(checkpoint_files) == 1 + ckpt_path = str(checkpoint_dir / checkpoint_files[0]) + + trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs) + + with _backward_patch(trainer): + trainer.fit(resume_model, ckpt_path=ckpt_path) + + +class CustomSchedulerModel(SwaTestModel): + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + + def lr_lambda(current_step: int): + return 0.1 + + scheduler = LambdaLR(optimizer, lr_lambda, -1) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": self.interval, + }, + } + + +@pytest.mark.parametrize("crash_on_epoch", [1, 3]) +def test_swa_resume_training_from_checkpoint(tmpdir, crash_on_epoch): + model = SwaTestModel(crash_on_epoch=crash_on_epoch) + resume_model = SwaTestModel() + _swa_resume_training_from_checkpoint(tmpdir, model, resume_model) + + +@pytest.mark.parametrize("crash_on_epoch", [1, 3]) +def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch): + # Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 + model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch) + resume_model = CustomSchedulerModel() + _swa_resume_training_from_checkpoint(tmpdir, model, resume_model) + + +@RunIf(skip_windows=True) +def test_swa_resume_training_from_checkpoint_ddp(tmpdir): + model = SwaTestModel(crash_on_epoch=3) + resume_model = SwaTestModel() + _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True) + + +@pytest.mark.parametrize( + "strategy", + [ + pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)), + pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)), + ], +) +def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str): + model = SwaTestModel() + swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1) + trainer = Trainer( + default_root_dir=tmpdir, + enable_progress_bar=False, + max_epochs=5, + callbacks=[swa_callback], + strategy=strategy, + accelerator="gpu", + devices=1, + ) + with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"): + trainer.fit(model) + + +def _backward_patch(trainer: Trainer) -> ContextManager: + return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py index aa6c1a615f9d2..91be34c55078f 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py @@ -36,7 +36,6 @@ from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.enums import DeviceType, DistributedType from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn from tests_pytorch.deprecated_api import no_deprecated_call @@ -44,18 +43,6 @@ from tests_pytorch.helpers.torchtext_utils import get_dummy_torchtext_data_iterator -def test_v1_8_0_deprecated_distributed_type_enum(): - - with pytest.deprecated_call(match="has been deprecated in v1.6 and will be removed in v1.8."): - _ = DistributedType.DDP - - -def test_v1_8_0_deprecated_device_type_enum(): - - with pytest.deprecated_call(match="has been deprecated in v1.6 and will be removed in v1.8."): - _ = DeviceType.CPU - - @pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.") def test_v1_8_0_deprecated_torchtext_batch(): diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py index 3443020d4528f..c9d185313e85e 100644 --- a/tests/tests_pytorch/helpers/datasets.py +++ b/tests/tests_pytorch/helpers/datasets.py @@ -19,7 +19,7 @@ from typing import Optional, Sequence, Tuple import torch -from torch.utils.data import Dataset, IterableDataset +from torch.utils.data import Dataset class MNIST(Dataset): @@ -212,40 +212,3 @@ def __getitem__(self, idx): def __len__(self): return len(self.y) - - -class RandomDictDataset(Dataset): - def __init__(self, size: int, length: int): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - a = self.data[index] - b = a + 2 - return {"a": a, "b": b} - - def __len__(self): - return self.len - - -class RandomIterableDataset(IterableDataset): - def __init__(self, size: int, count: int): - self.count = count - self.size = size - - def __iter__(self): - for _ in range(self.count): - yield torch.randn(self.size) - - -class RandomIterableDatasetWithLen(IterableDataset): - def __init__(self, size: int, count: int): - self.count = count - self.size = size - - def __iter__(self): - for _ in range(len(self)): - yield torch.randn(self.size) - - def __len__(self): - return self.count diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index 2215ab3129780..86a0a5a82195a 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import os from copy import deepcopy from unittest import mock @@ -30,7 +29,6 @@ from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy from pytorch_lightning.utilities import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.seed import pl_worker_init_function from tests_pytorch.helpers.runif import RunIf @@ -480,13 +478,4 @@ def run(self): assert self.broadcast(True) assert self.is_global_zero == (self.local_rank == 0) - if _RequirementAvailable("deepspeed>=0.6.5"): - # https://github.com/microsoft/DeepSpeed/issues/2139 - raise_if_deepspeed_incompatible = pytest.raises( - RuntimeError, match="DeepSpeed ZeRO-3 is not supported with this version of Lightning Lite" - ) - else: - raise_if_deepspeed_incompatible = contextlib.suppress() - - with raise_if_deepspeed_incompatible: - Lite(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run() + Lite(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run() diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index c130381c7832d..84311d6f780fb 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -29,6 +29,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.core.mixins import HyperparametersMixin from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable @@ -399,6 +400,24 @@ def _raw_checkpoint_path(trainer) -> str: return raw_checkpoint_path +@pytest.mark.parametrize("base_class", (HyperparametersMixin, LightningModule, LightningDataModule)) +def test_save_hyperparameters_under_composition(base_class): + """Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get + collected.""" + + class ChildInComposition(base_class): + def __init__(self, same_arg): + super().__init__() + self.save_hyperparameters() + + class NotPLSubclass: # intentionally not subclassing LightningModule/LightningDataModule + def __init__(self, same_arg="parent_default", other_arg="other"): + self.child = ChildInComposition(same_arg="cocofruit") + + parent = NotPLSubclass() + assert parent.child.hparams == dict(same_arg="cocofruit") + + class LocalVariableModelSuperLast(BoringModel): """This model has the super().__init__() call at the end.""" diff --git a/tests/tests_pytorch/run_standalone_tasks.sh b/tests/tests_pytorch/run_standalone_tasks.sh index 960bd867ceaa4..698ed7863ab96 100644 --- a/tests/tests_pytorch/run_standalone_tasks.sh +++ b/tests/tests_pytorch/run_standalone_tasks.sh @@ -34,6 +34,10 @@ fi # test that a user can manually launch individual processes echo "Running manual ddp launch test" export PYTHONPATH="${PYTHONPATH}:$(pwd)" -args="--trainer.accelerator gpu --trainer.devices 2 --trainer.strategy ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --trainer.limit_test_batches=1" -MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python ../../examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py ${args} & -MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python ../../examples/convert_from_pt_to_pl/image_classifier_5_lightning_datamodule.py ${args} +args="fit --trainer.accelerator gpu --trainer.devices 2 --trainer.strategy ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --trainer.limit_test_batches=1" +MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python strategies/scripts/cli_script.py ${args} & +MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python strategies/scripts/cli_script.py ${args} + +# test that ddp can launched as a module (-m option) +echo "Running ddp example as module" +python -m strategies.scripts.cli_script ${args} diff --git a/tests/tests_pytorch/serve/__init__.py b/tests/tests_pytorch/serve/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_pytorch/strategies/ddp_model.py b/tests/tests_pytorch/strategies/ddp_model.py deleted file mode 100644 index 76d1f3f2f6866..0000000000000 --- a/tests/tests_pytorch/strategies/ddp_model.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Runs either `.fit()` or `.test()` on a single node across multiple gpus.""" -import os -from argparse import ArgumentParser - -import torch - -from pytorch_lightning import seed_everything, Trainer -from tests_pytorch.helpers.datamodules import ClassifDataModule -from tests_pytorch.helpers.simple_models import ClassificationModel - - -def main(): - seed_everything(4321) - - parser = ArgumentParser(add_help=False) - parser = Trainer.add_argparse_args(parser) - parser.add_argument("--trainer_method", default="fit") - parser.add_argument("--tmpdir") - parser.add_argument("--workdir") - parser.set_defaults(accelerator="gpu", devices=2) - parser.set_defaults(strategy="ddp") - args = parser.parse_args() - - dm = ClassifDataModule() - model = ClassificationModel() - trainer = Trainer.from_argparse_args(args) - - if args.trainer_method == "fit": - trainer.fit(model, datamodule=dm) - result = None - elif args.trainer_method == "test": - result = trainer.test(model, datamodule=dm) - elif args.trainer_method == "fit_test": - trainer.fit(model, datamodule=dm) - result = trainer.test(model, datamodule=dm) - else: - raise ValueError(f"Unsupported: {args.trainer_method}") - - result_ext = {"status": "complete", "method": args.trainer_method, "result": result} - file_path = os.path.join(args.tmpdir, "ddp.result") - torch.save(result_ext, file_path) - - -if __name__ == "__main__": - main() diff --git a/tests/tests_pytorch/strategies/scripts/__init__.py b/tests/tests_pytorch/strategies/scripts/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_pytorch/strategies/scripts/cli_script.py b/tests/tests_pytorch/strategies/scripts/cli_script.py new file mode 100644 index 0000000000000..17f0d29392eb9 --- /dev/null +++ b/tests/tests_pytorch/strategies/scripts/cli_script.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A trivial script that wraps a LightningCLI around the BoringModel and BoringDataModule.""" +from pytorch_lightning.cli import LightningCLI +from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel + +if __name__ == "__main__": + LightningCLI( + BoringModel, + BoringDataModule, + seed_everything_default=42, + save_config_overwrite=True, + ) diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index 4610f6153386b..9b196f3e2a97f 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -21,60 +21,41 @@ from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning import Trainer +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.strategies import DDPStrategy +from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf -from tests_pytorch.strategies import ddp_model -from tests_pytorch.utilities.distributed import call_training_script +from tests_pytorch.helpers.simple_models import ClassificationModel -CLI_ARGS = "--max_epochs 1 --accelerator gpu --devices 2 --strategy ddp" +@RunIf(min_cuda_gpus=2, standalone=True) +def test_multi_gpu_model_ddp_fit_only(tmpdir): + dm = ClassifDataModule() + model = ClassificationModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp") + trainer.fit(model, datamodule=dm) -@RunIf(min_cuda_gpus=2) -@pytest.mark.parametrize("as_module", [True, False]) -def test_multi_gpu_model_ddp_fit_only(tmpdir, as_module): - # call the script - call_training_script(ddp_model, CLI_ARGS, "fit", tmpdir, timeout=120, as_module=as_module) - # load the results of the script - result_path = os.path.join(tmpdir, "ddp.result") - result = torch.load(result_path) +@RunIf(min_cuda_gpus=2, standalone=True) +def test_multi_gpu_model_ddp_test_only(tmpdir): + dm = ClassifDataModule() + model = ClassificationModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp") + trainer.test(model, datamodule=dm) - # verify the file wrote the expected outputs - assert result["status"] == "complete" +@RunIf(min_cuda_gpus=2, standalone=True) +def test_multi_gpu_model_ddp_fit_test(tmpdir): + seed_everything(4321) + dm = ClassifDataModule() + model = ClassificationModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp") + trainer.fit(model, datamodule=dm) + result = trainer.test(model, datamodule=dm) -@RunIf(min_cuda_gpus=2) -@pytest.mark.parametrize("as_module", [True, False]) -def test_multi_gpu_model_ddp_test_only(tmpdir, as_module): - # call the script - call_training_script(ddp_model, CLI_ARGS, "test", tmpdir, as_module=as_module) - - # load the results of the script - result_path = os.path.join(tmpdir, "ddp.result") - result = torch.load(result_path) - - # verify the file wrote the expected outputs - assert result["status"] == "complete" - - -@RunIf(min_cuda_gpus=2) -@pytest.mark.parametrize("as_module", [True, False]) -def test_multi_gpu_model_ddp_fit_test(tmpdir, as_module): - # call the script - call_training_script(ddp_model, CLI_ARGS, "fit_test", tmpdir, timeout=20, as_module=as_module) - - # load the results of the script - result_path = os.path.join(tmpdir, "ddp.result") - result = torch.load(result_path) - - # verify the file wrote the expected outputs - assert result["status"] == "complete" - - model_outs = result["result"] - for out in model_outs: + for out in result: assert out["test_acc"] > 0.7 @@ -194,3 +175,15 @@ def root_device(self): assert strategy._get_process_group_backend() == expected_process_group_backend else: assert strategy._get_process_group_backend() == expected_process_group_backend + + +@pytest.mark.parametrize( + "strategy_name,expected_ddp_kwargs", + [ + ("ddp", {}), + ("ddp_find_unused_parameters_false", {"find_unused_parameters": False}), + ], +) +def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs): + trainer = Trainer(strategy=strategy_name) + assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py index 52427c2c8cc3a..7fb22206c45c6 100644 --- a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py +++ b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py @@ -178,3 +178,25 @@ def test_ddp_spawn_strategy_set_timeout(mock_init_process_group): mock_init_process_group.assert_called_with( process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta ) + + +@pytest.mark.parametrize( + "strategy_name,expected_ddp_kwargs", + [ + ("ddp_spawn", {}), + pytest.param("ddp_fork", {}, marks=RunIf(skip_windows=True)), + pytest.param("ddp_notebook", {}, marks=RunIf(skip_windows=True)), + ("ddp_spawn_find_unused_parameters_false", {"find_unused_parameters": False}), + pytest.param( + "ddp_fork_find_unused_parameters_false", {"find_unused_parameters": False}, marks=RunIf(skip_windows=True) + ), + pytest.param( + "ddp_notebook_find_unused_parameters_false", + {"find_unused_parameters": False}, + marks=RunIf(skip_windows=True), + ), + ], +) +def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs): + trainer = Trainer(strategy=strategy_name) + assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index 4f2cc14b6c62d..e3c6f95f3ff47 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -28,13 +28,12 @@ from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin from pytorch_lightning.strategies import DeepSpeedStrategy from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.datamodules import ClassifDataModule -from tests_pytorch.helpers.datasets import RandomIterableDataset from tests_pytorch.helpers.runif import RunIf if _DEEPSPEED_AVAILABLE: @@ -171,12 +170,11 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config): @RunIf(deepspeed=True) @mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) -@pytest.mark.parametrize("precision", [16, "mixed"]) @pytest.mark.parametrize( "amp_backend", ["native", pytest.param("apex", marks=RunIf(amp_apex=True))], ) -def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir): +def test_deepspeed_precision_choice(_, amp_backend, tmpdir): """Test to ensure precision plugin is also correctly chosen. DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin @@ -188,16 +186,16 @@ def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir): accelerator="gpu", strategy="deepspeed", amp_backend=amp_backend, - precision=precision, + precision=16, ) assert isinstance(trainer.strategy, DeepSpeedStrategy) assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == precision + assert trainer.strategy.precision_plugin.precision == 16 @RunIf(deepspeed=True) -def test_deepspeed_with_invalid_config_path(tmpdir): +def test_deepspeed_with_invalid_config_path(): """Test to ensure if we pass an invalid config path we throw an exception.""" with pytest.raises( @@ -218,7 +216,7 @@ def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config): @RunIf(deepspeed=True) -def test_deepspeed_defaults(tmpdir): +def test_deepspeed_defaults(): """Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed.""" strategy = DeepSpeedStrategy() assert strategy.config is not None @@ -663,7 +661,7 @@ def training_step(self, batch, batch_idx): @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) -def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config): +def test_deepspeed_multigpu_stage_3(tmpdir): """Test to ensure ZeRO Stage 3 works with a parallel model.""" model = ModelParallelBoringModel() trainer = Trainer( diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py index a047a10df32e3..ad0673ed1a5fa 100644 --- a/tests/tests_pytorch/strategies/test_sharded_strategy.py +++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py @@ -300,3 +300,17 @@ def test_block_backward_sync(): with strategy.block_backward_sync(): pass model.no_sync.assert_called_once() + + +@pytest.mark.parametrize( + "strategy_name,expected_ddp_kwargs", + [ + ("ddp_sharded", {}), + ("ddp_sharded_find_unused_parameters_false", {"find_unused_parameters": False}), + ("ddp_sharded_spawn", {}), + ("ddp_sharded_spawn_find_unused_parameters_false", {"find_unused_parameters": False}), + ], +) +def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs): + trainer = Trainer(strategy=strategy_name) + assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index d6d5018aa1dd0..02e846425a2a0 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -30,7 +30,7 @@ ) from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector -from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 +from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 def test_checkpoint_callbacks_are_last(tmpdir): @@ -265,7 +265,10 @@ def _make_entry_point_query_mock(callback_factory): entry_point = Mock() entry_point.name = "mocked" entry_point.load.return_value = callback_factory - if _PYTHON_GREATER_EQUAL_3_8_0: + if _PYTHON_GREATER_EQUAL_3_10_0: + query_mock.return_value = [entry_point] + import_path = "importlib.metadata.entry_points" + elif _PYTHON_GREATER_EQUAL_3_8_0: query_mock().get.return_value = [entry_point] import_path = "importlib.metadata.entry_points" else: diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index f2a98daa9c5ad..7273d7719834e 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -445,7 +445,8 @@ def test_dataloader_source_direct_access(): def test_dataloader_source_request_from_module(): """Test requesting a dataloader from a module works.""" module = BoringModel() - module.trainer = Trainer() + trainer = Trainer() + module.trainer = trainer module.foo = Mock(return_value=module.train_dataloader()) source = _DataLoaderSource(module, "foo") diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py index 9414fd1c5096f..e5fd9b5dd2706 100644 --- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py +++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py @@ -16,10 +16,9 @@ import pytest from torch.utils.data import DataLoader -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomIterableDataset @pytest.mark.parametrize("max_epochs", [1, 2, 3]) diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index 5855eba4c86af..85ed3d8e3471d 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -28,9 +28,8 @@ from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from pytorch_lightning.core.module import LightningModule -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomDictDataset from tests_pytorch.helpers.runif import RunIf @@ -569,11 +568,12 @@ def on_train_epoch_end(self, trainer, pl_module): "accelerator", [ pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), + "cpu", ], ) def test_metric_are_properly_reduced(tmpdir, accelerator): class TestingModel(BoringModel): - def __init__(self, *args, **kwargs) -> None: + def __init__(self) -> None: super().__init__() self.val_acc = Accuracy() @@ -592,7 +592,6 @@ def validation_step(self, batch, batch_idx): return super().validation_step(batch, batch_idx) early_stop = EarlyStopping(monitor="val_acc", mode="max") - checkpoint = ModelCheckpoint(monitor="val_acc", save_last=True, save_top_k=2, mode="max") model = TestingModel() @@ -812,3 +811,28 @@ def training_step(self, batch, batch_idx): call(metrics={"foo_epoch": 0.0, "epoch": 1}, step=3), ] ) + + +@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") +def test_log_on_train_start(mock_log_metrics, tmpdir): + """Tests that logged metrics on_train_start get reset after the first epoch.""" + + class MyModel(BoringModel): + def on_train_start(self): + self.log("foo", 123) + + model = MyModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=0, + max_epochs=2, + log_every_n_steps=1, + enable_model_summary=False, + enable_checkpointing=False, + enable_progress_bar=False, + ) + trainer.fit(model) + + assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)] + assert trainer.max_epochs > 1 diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index 92a1126294dfc..846a39a748a60 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -22,11 +22,10 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler -from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.demos.boring_classes import BoringModel, RandomIterableDataset from pytorch_lightning.strategies.ipu import IPUStrategy from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomIterableDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index 5bea5a4cbbe1c..34504392dc0c1 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -25,12 +25,16 @@ from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import ( + BoringModel, + RandomDataset, + RandomIterableDataset, + RandomIterableDatasetWithLen, +) from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset, has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader -from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index e4be8929f9c7e..9506acee425d0 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -41,7 +41,12 @@ from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import ( + BoringModel, + RandomDataset, + RandomIterableDataset, + RandomIterableDatasetWithLen, +) from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler from pytorch_lightning.strategies import ( @@ -60,7 +65,6 @@ from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.seed import seed_everything from tests_pytorch.helpers.datamodules import ClassifDataModule -from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/utilities/distributed.py b/tests/tests_pytorch/utilities/distributed.py deleted file mode 100644 index 38a50edcc7177..0000000000000 --- a/tests/tests_pytorch/utilities/distributed.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import subprocess -import sys -from pathlib import Path -from subprocess import TimeoutExpired - -import pytorch_lightning - - -def call_training_script(module_file, cli_args, method, tmpdir, timeout=60, as_module=False): - file = Path(module_file.__file__).absolute() - cli_args = cli_args.split(" ") if cli_args else [] - cli_args += ["--tmpdir", str(tmpdir)] - cli_args += ["--trainer_method", method] - file_args = ["-m", module_file.__spec__.name] if as_module else [str(file)] - command = [sys.executable] + file_args + cli_args - - # need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment - env = os.environ.copy() - env["PYTHONPATH"] = env.get("PYTHONPATH", "") + f"{pytorch_lightning.__file__}:" - - # for running in ddp mode, we need to launch it's own process or pytest will get stuck - p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) - try: - std, err = p.communicate(timeout=timeout) - err = str(err.decode("utf-8")) - if "Exception" in err: - raise Exception(err) - except TimeoutExpired: - p.kill() - std, err = p.communicate() - return std, err diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index ffb898efaa815..cc70417988616 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -1,3 +1,4 @@ +import random from dataclasses import dataclass import pytest @@ -6,7 +7,7 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from pytorch_lightning import Trainer -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import ( @@ -23,7 +24,6 @@ warning_cache, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomIterableDataset from tests_pytorch.helpers.utils import no_warning_call @@ -173,6 +173,30 @@ def __init__(self, randomize, *args, **kwargs): assert isinstance(new_dataloader, GoodImpl) +def test_replace_init_method_multiple_loaders_without_init(): + """In case of a class, that inherits from a class that we are patching, but doesn't define its own `__init__` + method (the one we are wrapping), it can happen, that `hasattr(cls, "_old_init")` is True because of parent + class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error + occured only sometimes because it depends on the order in which we are iterating over a set of classes we are + patching. + + This test simulates the behavior by generating sufficient number of dummy classes, which do not define `__init__` + and are children of `DataLoader`. We are testing that a) context manager `_replace_init_method` exits cleanly, and + b) the mechanism checking for presence of `_old_init` works as expected. + """ + classes = [DataLoader] + for i in range(100): + classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {})) + + with _replace_init_method(DataLoader, "dataset"): + for cls in classes[1:]: # First one is `DataLoader` + assert "_old_init" not in cls.__dict__ + assert hasattr(cls, "_old_init") + + assert "_old_init" in DataLoader.__dict__ + assert hasattr(DataLoader, "_old_init") + + class DataLoaderSubclass1(DataLoader): def __init__(self, attribute1, *args, **kwargs): self.at1 = attribute1 diff --git a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py index 38f72b555d52d..7c17b3d9f7642 100644 --- a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py +++ b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py @@ -113,7 +113,7 @@ def test_submodules_multi_gpu_ddp_spawn(tmpdir): ], ) @RunIf(min_cuda_gpus=1) -def test_gpu_cuda_device(device): +def test_cuda_device(device): model = TopModule() model.cuda(device) @@ -122,3 +122,25 @@ def test_gpu_cuda_device(device): assert device.type == "cuda" assert device.index is not None assert device.index == torch.cuda.current_device() + + +@RunIf(min_cuda_gpus=2) +def test_cuda_current_device(): + """Test that calling .cuda() moves the model to the correct device and respects current cuda device setting.""" + + class CudaModule(DeviceDtypeModuleMixin): + def __init__(self): + super().__init__() + self.layer = nn.Linear(1, 1) + + model = CudaModule() + + torch.cuda.set_device(0) + model.cuda(1) + assert model.device == torch.device("cuda", 1) + assert model.layer.weight.device == torch.device("cuda", 1) + + torch.cuda.set_device(1) + model.cuda() # model is already on device 1, and calling .cuda() without device index should not move model + assert model.device == torch.device("cuda", 1) + assert model.layer.weight.device == torch.device("cuda", 1) diff --git a/tests/tests_pytorch/utilities/test_grads.py b/tests/tests_pytorch/utilities/test_grads.py index a548de66ab85d..49aab76403847 100644 --- a/tests/tests_pytorch/utilities/test_grads.py +++ b/tests/tests_pytorch/utilities/test_grads.py @@ -76,3 +76,17 @@ def __init__(self): def test_grad_norm_invalid_norm_type(norm_type): with pytest.raises(ValueError, match="`norm_type` must be a positive number or 'inf'"): grad_norm(Mock(), norm_type) + + +def test_grad_norm_with_double_dtype(): + class Model(nn.Module): + def __init__(self): + super().__init__() + dtype = torch.double + self.param = nn.Parameter(torch.tensor(1.0, dtype=dtype)) + # grad norm of this would become infinite + self.param.grad = torch.tensor(1e23, dtype=dtype) + + model = Model() + norms = grad_norm(model, 2) + assert all(torch.isfinite(torch.tensor(v)) for v in norms.values()), norms diff --git a/tests/tests_pytorch/utilities/test_seed.py b/tests/tests_pytorch/utilities/test_seed.py index 7f162bd605640..6908badf1a037 100644 --- a/tests/tests_pytorch/utilities/test_seed.py +++ b/tests/tests_pytorch/utilities/test_seed.py @@ -1,6 +1,8 @@ import os import random +from typing import Mapping from unittest import mock +from unittest.mock import MagicMock import numpy as np import pytest @@ -96,3 +98,19 @@ def test_isolate_rng(): with isolate_rng(): generated = [random.random() for _ in range(3)] assert random.random() == generated[0] + + +@mock.patch("pytorch_lightning.utilities.seed.log.info") +@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"RANK": "1"}, {"RANK": "4"}]) +def test_seed_everything_log_info(log_mock: MagicMock, env_vars: Mapping[str, str]): + """Test that log message prefix with correct rank info.""" + with mock.patch.dict(os.environ, env_vars, clear=True): + from pytorch_lightning.utilities.rank_zero import _get_rank + + rank = _get_rank() + + seed_utils.seed_everything(123) + + expected_log = f"[rank: {rank}] Global seed set to 123" + + log_mock.assert_called_once_with(expected_log)