From ad4bd66df631c97d4b431ef1a8b5705c10f11367 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 28 Nov 2022 14:58:29 +0100 Subject: [PATCH] hotfix import torch (#15849) * fix import torch * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * plugin * fix * skip * patch require * seed * warn * . * .. * skip True * 0.0.3 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/ci-pkg-install.yml | 5 ++++- requirements/app/base.txt | 2 +- requirements/app/test.txt | 1 + .../components/serve/python_server.py | 18 +++++++++++++----- src/lightning_app/utilities/imports.py | 13 +++++++++---- src/lightning_app/utilities/name_generator.py | 13 +++++++------ tests/tests_app/core/test_lightning_api.py | 4 ++-- tests/tests_app/core/test_lightning_app.py | 2 +- tests/tests_app/structures/test_structures.py | 2 +- .../utilities/packaging/test_docker.py | 2 +- tests/tests_lite/test_parity.py | 2 +- 11 files changed, 41 insertions(+), 23 deletions(-) diff --git a/.github/workflows/ci-pkg-install.yml b/.github/workflows/ci-pkg-install.yml index 4d2cc896453a9..d9474edb98f8f 100644 --- a/.github/workflows/ci-pkg-install.yml +++ b/.github/workflows/ci-pkg-install.yml @@ -46,7 +46,7 @@ jobs: - name: DocTests actions working-directory: .actions/ run: | - pip install pytest -q + pip install -q pytest python -m pytest setup_tools.py - run: python -c "print('NB_DIRS=' + str(2 if '${{ matrix.pkg-name }}' == 'pytorch' else 1))" >> $GITHUB_ENV @@ -67,7 +67,10 @@ jobs: - name: DocTest package env: + LIGHTING_TESTING: 1 # path for require wrapper PY_IGNORE_IMPORTMISMATCH: 1 run: | + pip install -q "pytest-doctestplus>=0.9.0" + pip list PKG_NAME=$(python -c "print({'app': 'lightning_app', 'lite': 'lightning_lite', 'pytorch': 'pytorch_lightning', 'lightning': 'lightning'}['${{matrix.pkg-name}}'])") python -m pytest src/${PKG_NAME} --ignore-glob="**/cli/*-template/**" diff --git a/requirements/app/base.txt b/requirements/app/base.txt index 83244527e285b..a3da6958b6688 100644 --- a/requirements/app/base.txt +++ b/requirements/app/base.txt @@ -12,4 +12,4 @@ beautifulsoup4>=4.8.0, <4.11.2 inquirer>=2.10.0 psutil<5.9.4 click<=8.1.3 -lightning_api_access>=0.0.1 +lightning_api_access>=0.0.3 diff --git a/requirements/app/test.txt b/requirements/app/test.txt index 3fc0e51c42215..1d1bcf271974f 100644 --- a/requirements/app/test.txt +++ b/requirements/app/test.txt @@ -3,6 +3,7 @@ codecov==2.1.12 pytest==7.2.0 pytest-timeout==2.1.0 pytest-cov==4.0.0 +pytest-doctestplus>=0.9.0 playwright==1.27.1 httpx trio<0.22.0 diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py index 9ce1b23701059..7f7a8eeea98f4 100644 --- a/src/lightning_app/components/serve/python_server.py +++ b/src/lightning_app/components/serve/python_server.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any, Dict, Optional -import torch import uvicorn from fastapi import FastAPI from pydantic import BaseModel @@ -13,16 +12,21 @@ from lightning_app.core.queues import MultiProcessQueue from lightning_app.core.work import LightningWork from lightning_app.utilities.app_helpers import Logger +from lightning_app.utilities.imports import _is_torch_available, requires from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver logger = Logger(__name__) +# Skip doctests if requirements aren't available +if not _is_torch_available(): + __doctest_skip__ = ["PythonServer", "PythonServer.*"] + class _PyTorchSpawnRunExecutor(WorkRunExecutor): """This Executor enables to move PyTorch tensors on GPU. - Without this executor, it woud raise the following expection: + Without this executor, it would raise the following exception: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method """ @@ -86,6 +90,7 @@ def _get_sample_data() -> Dict[Any, Any]: class PythonServer(LightningWork, abc.ABC): + @requires("torch") def __init__( # type: ignore self, host: str = "127.0.0.1", @@ -127,15 +132,16 @@ def predict(self, request): and this can be accessed as `response.json()["prediction"]` in the client if you are using requests library - .. doctest:: + Example: >>> from lightning_app.components.serve.python_server import PythonServer >>> from lightning_app import LightningApp - >>> ... >>> class SimpleServer(PythonServer): + ... ... def setup(self): ... self._model = lambda x: x + " " + x + ... ... def predict(self, request): ... return {"prediction": self._model(request.image)} ... @@ -199,11 +205,13 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict: return out def _attach_predict_fn(self, fastapi_app: FastAPI) -> None: + from torch import inference_mode + input_type: type = self.configure_input_type() output_type: type = self.configure_output_type() def predict_fn(request: input_type): # type: ignore - with torch.inference_mode(): + with inference_mode(): return self.predict(request) fastapi_app.post("/predict", response_model=output_type)(predict_fn) diff --git a/src/lightning_app/utilities/imports.py b/src/lightning_app/utilities/imports.py index 60747ff3624c0..b484110d3811e 100644 --- a/src/lightning_app/utilities/imports.py +++ b/src/lightning_app/utilities/imports.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities.""" + import functools import os +import warnings from typing import List, Union from lightning_utilities.core.imports import module_available @@ -52,10 +54,13 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): unavailable_modules = [f"'{module}'" for module in module_paths if not module_available(module)] - if any(unavailable_modules) and not bool(int(os.getenv("LIGHTING_TESTING", "0"))): - raise ModuleNotFoundError( - f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}" - ) + if any(unavailable_modules): + is_lit_testing = bool(int(os.getenv("LIGHTING_TESTING", "0"))) + msg = f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}" + if is_lit_testing: + warnings.warn(msg) + else: + raise ModuleNotFoundError(msg) return func(*args, **kwargs) return wrapper diff --git a/src/lightning_app/utilities/name_generator.py b/src/lightning_app/utilities/name_generator.py index 4cbb46069ea13..121ef7e9b0852 100644 --- a/src/lightning_app/utilities/name_generator.py +++ b/src/lightning_app/utilities/name_generator.py @@ -1332,12 +1332,13 @@ def get_unique_name(): Original source: https://raw.githubusercontent.com/moby/moby/master/pkg/namesgenerator/names-generator.go - Examples - -------- - >>> get_unique_name() # doctest: +SKIP - 'focused-turing-23' - >>> get_unique_name() # doctest: +SKIP - 'thirsty-allen-9200' + Examples: + + >>> import random ; random.seed(42) + >>> get_unique_name() + 'meek-ardinghelli-4506' + >>> get_unique_name() + 'truthful-dijkstra-2286' """ adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999) return f"{adjective}-{surname}-{i}" diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index d81c72c06f071..82b58cc36fac3 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -106,7 +106,7 @@ def run(self): # TODO: Find why this test is flaky. -@pytest.mark.skipif(True, reason="flaky test.") +@pytest.mark.skip(reason="flaky test.") @pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime]) def test_app_state_api_with_flows(runtime_cls, tmpdir): """This test validates the AppState can properly broadcast changes from flows.""" @@ -180,7 +180,7 @@ def maybe_apply_changes(self): # FIXME: This test doesn't assert anything -@pytest.mark.skipif(True, reason="TODO: Resolve flaky test.") +@pytest.mark.skip(reason="TODO: Resolve flaky test.") @pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime]) def test_app_stage_from_frontend(runtime_cls): """This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index a5fac6bf0da41..e5c265e2efde9 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -582,7 +582,7 @@ def run(self): # TODO (tchaton) Resolve this test. -@pytest.mark.skipif(True, reason="flaky test which never terminates") +@pytest.mark.skip(reason="flaky test which never terminates") @pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime]) @pytest.mark.parametrize("use_same_args", [False, True]) def test_state_wait_for_all_all_works(tmpdir, runtime_cls, use_same_args): diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py index 7b84e31402f36..05905c3421bec 100644 --- a/tests/tests_app/structures/test_structures.py +++ b/tests/tests_app/structures/test_structures.py @@ -308,7 +308,7 @@ def run(self): self.counter += 1 -@pytest.mark.skipif(True, reason="tchaton: Resolve this test.") +@pytest.mark.skip(reason="tchaton: Resolve this test.") @pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime, SingleProcessRuntime]) @pytest.mark.parametrize("run_once_iterable", [False, True]) @pytest.mark.parametrize("cache_calls", [False, True]) diff --git a/tests/tests_app/utilities/packaging/test_docker.py b/tests/tests_app/utilities/packaging/test_docker.py index 0c87db03875f2..75d6b875c6a92 100644 --- a/tests/tests_app/utilities/packaging/test_docker.py +++ b/tests/tests_app/utilities/packaging/test_docker.py @@ -12,7 +12,7 @@ from lightning_app.utilities.redis import check_if_redis_running -@pytest.mark.skipif(True, reason="FIXME (tchaton)") +@pytest.mark.skip(reason="FIXME (tchaton)") @pytest.mark.skipif(not _is_docker_available(), reason="docker is required for this test.") @pytest.mark.skipif(not check_if_redis_running(), reason="redis is required for this test.") @_RunIf(skip_windows=True) diff --git a/tests/tests_lite/test_parity.py b/tests/tests_lite/test_parity.py index c5687ee58a120..b74a23438d0d6 100644 --- a/tests/tests_lite/test_parity.py +++ b/tests/tests_lite/test_parity.py @@ -162,7 +162,7 @@ def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdi _atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt")) -@pytest.mark.skipif(True, reason="Skipping as it takes 80 seconds.") +@pytest.mark.skip(reason="Skipping as it takes 80 seconds.") @RunIf(min_cuda_gpus=2) @pytest.mark.parametrize( "precision, strategy, devices, accelerator",