Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hotfix import torch #15849

Merged
merged 13 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/ci-pkg-install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
- name: DocTests actions
working-directory: .actions/
run: |
pip install pytest -q
pip install -q pytest
python -m pytest setup_tools.py

- run: python -c "print('NB_DIRS=' + str(2 if '${{ matrix.pkg-name }}' == 'pytorch' else 1))" >> $GITHUB_ENV
Expand All @@ -67,7 +67,10 @@ jobs:

- name: DocTest package
env:
LIGHTING_TESTING: 1 # path for require wrapper
PY_IGNORE_IMPORTMISMATCH: 1
run: |
pip install -q "pytest-doctestplus>=0.9.0"
pip list
PKG_NAME=$(python -c "print({'app': 'lightning_app', 'lite': 'lightning_lite', 'pytorch': 'pytorch_lightning', 'lightning': 'lightning'}['${{matrix.pkg-name}}'])")
python -m pytest src/${PKG_NAME} --ignore-glob="**/cli/*-template/**"
2 changes: 1 addition & 1 deletion requirements/app/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ beautifulsoup4>=4.8.0, <4.11.2
inquirer>=2.10.0
psutil<5.9.4
click<=8.1.3
lightning_api_access>=0.0.1
lightning_api_access>=0.0.3
1 change: 1 addition & 0 deletions requirements/app/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ codecov==2.1.12
pytest==7.2.0
pytest-timeout==2.1.0
pytest-cov==4.0.0
pytest-doctestplus>=0.9.0
playwright==1.27.1
httpx
trio<0.22.0
Expand Down
18 changes: 13 additions & 5 deletions src/lightning_app/components/serve/python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
Expand All @@ -13,16 +12,21 @@
from lightning_app.core.queues import MultiProcessQueue
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.imports import _is_torch_available, requires
from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver

logger = Logger(__name__)

# Skip doctests if requirements aren't available
if not _is_torch_available():
__doctest_skip__ = ["PythonServer", "PythonServer.*"]


class _PyTorchSpawnRunExecutor(WorkRunExecutor):

"""This Executor enables to move PyTorch tensors on GPU.

Without this executor, it woud raise the following expection:
Without this executor, it would raise the following exception:
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
To use CUDA with multiprocessing, you must use the 'spawn' start method
"""
Expand Down Expand Up @@ -86,6 +90,7 @@ def _get_sample_data() -> Dict[Any, Any]:


class PythonServer(LightningWork, abc.ABC):
@requires("torch")
def __init__( # type: ignore
self,
host: str = "127.0.0.1",
Expand Down Expand Up @@ -127,15 +132,16 @@ def predict(self, request):
and this can be accessed as `response.json()["prediction"]` in the client if
you are using requests library

.. doctest::
Example:

>>> from lightning_app.components.serve.python_server import PythonServer
>>> from lightning_app import LightningApp
>>>
...
>>> class SimpleServer(PythonServer):
...
... def setup(self):
... self._model = lambda x: x + " " + x
...
... def predict(self, request):
... return {"prediction": self._model(request.image)}
...
Expand Down Expand Up @@ -199,11 +205,13 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
return out

def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
from torch import inference_mode

input_type: type = self.configure_input_type()
output_type: type = self.configure_output_type()

def predict_fn(request: input_type): # type: ignore
with torch.inference_mode():
with inference_mode():
return self.predict(request)

fastapi_app.post("/predict", response_model=output_type)(predict_fn)
Expand Down
13 changes: 9 additions & 4 deletions src/lightning_app/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities."""

import functools
import os
import warnings
from typing import List, Union

from lightning_utilities.core.imports import module_available
Expand Down Expand Up @@ -52,10 +54,13 @@ def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
unavailable_modules = [f"'{module}'" for module in module_paths if not module_available(module)]
if any(unavailable_modules) and not bool(int(os.getenv("LIGHTING_TESTING", "0"))):
raise ModuleNotFoundError(
f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}"
)
if any(unavailable_modules):
is_lit_testing = bool(int(os.getenv("LIGHTING_TESTING", "0")))
msg = f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}"
if is_lit_testing:
warnings.warn(msg)
else:
raise ModuleNotFoundError(msg)
return func(*args, **kwargs)

return wrapper
Expand Down
13 changes: 7 additions & 6 deletions src/lightning_app/utilities/name_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,12 +1332,13 @@ def get_unique_name():
Original source:
https://raw.githubusercontent.com/moby/moby/master/pkg/namesgenerator/names-generator.go

Examples
--------
>>> get_unique_name() # doctest: +SKIP
'focused-turing-23'
>>> get_unique_name() # doctest: +SKIP
'thirsty-allen-9200'
Examples:

>>> import random ; random.seed(42)
>>> get_unique_name()
'meek-ardinghelli-4506'
>>> get_unique_name()
'truthful-dijkstra-2286'
"""
adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999)
return f"{adjective}-{surname}-{i}"
4 changes: 2 additions & 2 deletions tests/tests_app/core/test_lightning_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def run(self):


# TODO: Find why this test is flaky.
@pytest.mark.skipif(True, reason="flaky test.")
@pytest.mark.skip(reason="flaky test.")
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime])
def test_app_state_api_with_flows(runtime_cls, tmpdir):
"""This test validates the AppState can properly broadcast changes from flows."""
Expand Down Expand Up @@ -180,7 +180,7 @@ def maybe_apply_changes(self):


# FIXME: This test doesn't assert anything
@pytest.mark.skipif(True, reason="TODO: Resolve flaky test.")
@pytest.mark.skip(reason="TODO: Resolve flaky test.")
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime])
def test_app_stage_from_frontend(runtime_cls):
"""This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_app/core/test_lightning_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def run(self):


# TODO (tchaton) Resolve this test.
@pytest.mark.skipif(True, reason="flaky test which never terminates")
@pytest.mark.skip(reason="flaky test which never terminates")
@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime])
@pytest.mark.parametrize("use_same_args", [False, True])
def test_state_wait_for_all_all_works(tmpdir, runtime_cls, use_same_args):
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_app/structures/test_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def run(self):
self.counter += 1


@pytest.mark.skipif(True, reason="tchaton: Resolve this test.")
@pytest.mark.skip(reason="tchaton: Resolve this test.")
@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime, SingleProcessRuntime])
@pytest.mark.parametrize("run_once_iterable", [False, True])
@pytest.mark.parametrize("cache_calls", [False, True])
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_app/utilities/packaging/test_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from lightning_app.utilities.redis import check_if_redis_running


@pytest.mark.skipif(True, reason="FIXME (tchaton)")
@pytest.mark.skip(reason="FIXME (tchaton)")
@pytest.mark.skipif(not _is_docker_available(), reason="docker is required for this test.")
@pytest.mark.skipif(not check_if_redis_running(), reason="redis is required for this test.")
@_RunIf(skip_windows=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_lite/test_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdi
_atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt"))


@pytest.mark.skipif(True, reason="Skipping as it takes 80 seconds.")
@pytest.mark.skip(reason="Skipping as it takes 80 seconds.")
@RunIf(min_cuda_gpus=2)
@pytest.mark.parametrize(
"precision, strategy, devices, accelerator",
Expand Down