Skip to content

Commit

Permalink
[App] Resolve PythonServer on M1 (#15949)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 8, 2022
1 parent 36aecde commit 904323b
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 56 deletions.
2 changes: 1 addition & 1 deletion requirements/app/ui.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
streamlit>=1.3.1, <=1.11.1
streamlit>=1.0.0, <=1.15.2
panel>=0.12.7, <=0.13.1
5 changes: 2 additions & 3 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed MPS error for multinode component (defaults to cpu on mps devices now as distributed operations are not supported by pytorch on mps) ([#15748](https://github.com/Ligtning-AI/lightning/pull/15748))



- Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))


- Fixed the `enable_spawn` method of the `WorkRunExecutor` ([#15812](https://github.com/Lightning-AI/lightning/pull/15812)

- Fixed Sigterm Handler causing thread lock which caused KeyboardInterrupt to hang ([#15881](https://github.com/Lightning-AI/lightning/pull/15881))

- Fixed a bug where using `L.app.structures` would cause multiple apps to be opened and fail with an error in the cloud ([#15911](https://github.com/Lightning-AI/lightning/pull/15911))

- Fixed PythonServer generating noise on M1 ([#15949](https://github.com/Lightning-AI/lightning/pull/15949))


## [1.8.3] - 2022-11-22

Expand Down
2 changes: 0 additions & 2 deletions src/lightning_app/components/auto_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ async def process_request(self, data: BaseModel):
return result

def run(self):

logger.info(f"servers: {self.servers}")
lock = asyncio.Lock()

Expand Down Expand Up @@ -271,7 +270,6 @@ async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint))
async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)):
async with lock:
self.servers = servers

self._iter = cycle(self.servers)

@fastapi_app.post(self.endpoint, response_model=self._output_type)
Expand Down
3 changes: 3 additions & 0 deletions src/lightning_app/components/python/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class Code(TypedDict):


class TracerPythonScript(LightningWork):

_start_method = "spawn"

def on_before_run(self):
"""Called before the python script is executed."""

Expand Down
8 changes: 2 additions & 6 deletions src/lightning_app/components/serve/gradio.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import abc
import os
from functools import partial
from types import ModuleType
from typing import Any, List, Optional

from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor
from lightning_app.core.work import LightningWork
from lightning_app.utilities.imports import _is_gradio_available, requires

Expand Down Expand Up @@ -36,15 +34,13 @@ class ServeGradio(LightningWork, abc.ABC):
title: Optional[str] = None
description: Optional[str] = None

_start_method = "spawn"

def __init__(self, *args, **kwargs):
requires("gradio")(super().__init__(*args, **kwargs))
assert self.inputs
assert self.outputs
self._model = None
# Note: Enable to run inference on GPUs.
self._run_executor_cls = (
WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
)

@property
def model(self):
Expand Down
63 changes: 19 additions & 44 deletions src/lightning_app/components/serve/python_server.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import abc
import base64
import os
import platform
from pathlib import Path
from typing import Any, Dict, Optional

import uvicorn
from fastapi import FastAPI
from lightning_utilities.core.imports import module_available
from lightning_utilities.core.imports import compare_version, module_available
from pydantic import BaseModel

from lightning_app.core.queues import MultiProcessQueue
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.imports import _is_torch_available, requires
from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver

logger = Logger(__name__)

Expand All @@ -27,44 +26,19 @@
__doctest_skip__ += ["PythonServer", "PythonServer.*"]


class _PyTorchSpawnRunExecutor(WorkRunExecutor):
def _get_device():
import operator

"""This Executor enables to move PyTorch tensors on GPU.
import torch

Without this executor, it would raise the following exception:
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
To use CUDA with multiprocessing, you must use the 'spawn' start method
"""
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")

enable_start_observer: bool = False
local_rank = int(os.getenv("LOCAL_RANK", "0"))

def __call__(self, *args: Any, **kwargs: Any):
import torch

with self.enable_spawn():
queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
torch.multiprocessing.spawn(
self.dispatch_run,
args=(self.__class__, self.work, queue, args, kwargs),
nprocs=1,
)

@staticmethod
def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
if local_rank == 0:
if isinstance(delta_queue, dict):
delta_queue = cls.process_queue(delta_queue)
work._request_queue = cls.process_queue(work._request_queue)
work._response_queue = cls.process_queue(work._response_queue)

state_observer = WorkStateObserver(work, delta_queue=delta_queue)
state_observer.start()
_proxy_setattr(work, delta_queue, state_observer)

unwrap(work.run)(*args, **kwargs)

if local_rank == 0:
state_observer.join(0)
if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
return torch.device("mps", local_rank)
else:
return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")


class _DefaultInputData(BaseModel):
Expand Down Expand Up @@ -95,6 +69,9 @@ def _get_sample_data() -> Dict[Any, Any]:


class PythonServer(LightningWork, abc.ABC):

_start_method = "spawn"

@requires(["torch", "lightning_api_access"])
def __init__( # type: ignore
self,
Expand Down Expand Up @@ -160,11 +137,6 @@ def predict(self, request):
self._input_type = input_type
self._output_type = output_type

# Note: Enable to run inference on GPUs.
self._run_executor_cls = (
WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
)

def setup(self, *args, **kwargs) -> None:
"""This method is called before the server starts. Override this if you need to download the model or
initialize the weights, setting up pipelines etc.
Expand Down Expand Up @@ -210,13 +182,16 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
return out

def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
from torch import inference_mode
from torch import inference_mode, no_grad

input_type: type = self.configure_input_type()
output_type: type = self.configure_output_type()

device = _get_device()
context = no_grad if device.type == "mps" else inference_mode

def predict_fn(request: input_type): # type: ignore
with inference_mode():
with context():
return self.predict(request)

fastapi_app.post("/predict", response_model=output_type)(predict_fn)
Expand Down

0 comments on commit 904323b

Please sign in to comment.