Skip to content

Commit

Permalink
[App] Enable Python Server and Gradio Serve to run on accelerated dev…
Browse files Browse the repository at this point in the history
…ice such as GPU CUDA / MPS (#15813)
  • Loading branch information
tchaton committed Nov 25, 2022
1 parent 655ade6 commit 23e0d1b
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed debugging with VSCode IDE ([#15747](https://github.com/Lightning-AI/lightning/pull/15747))
- Fixed setting property to the `LightningFlow` ([#15750](https://github.com/Lightning-AI/lightning/pull/15750))

- Fixed the PyTorch Inference locally on GPU ([#15813](https://github.com/Lightning-AI/lightning/pull/15813))


## [1.8.2] - 2022-11-17

Expand Down
6 changes: 6 additions & 0 deletions src/lightning_app/components/serve/gradio.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import abc
import os
from functools import partial
from types import ModuleType
from typing import Any, List, Optional

from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor
from lightning_app.core.work import LightningWork
from lightning_app.utilities.imports import _is_gradio_available, requires

Expand Down Expand Up @@ -39,6 +41,10 @@ def __init__(self, *args, **kwargs):
assert self.inputs
assert self.outputs
self._model = None
# Note: Enable to run inference on GPUs.
self._run_executor_cls = (
WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
)

@property
def model(self):
Expand Down
48 changes: 48 additions & 0 deletions src/lightning_app/components/serve/python_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import base64
import os
from pathlib import Path
from typing import Any, Dict, Optional

Expand All @@ -9,12 +10,54 @@
from pydantic import BaseModel
from starlette.staticfiles import StaticFiles

from lightning_app.core.queues import MultiProcessQueue
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver

logger = Logger(__name__)


class _PyTorchSpawnRunExecutor(WorkRunExecutor):

"""This Executor enables to move PyTorch tensors on GPU.
Without this executor, it woud raise the following expection:
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
To use CUDA with multiprocessing, you must use the 'spawn' start method
"""

enable_start_observer: bool = False

def __call__(self, *args: Any, **kwargs: Any):
import torch

with self.enable_spawn():
queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
torch.multiprocessing.spawn(
self.dispatch_run,
args=(self.__class__, self.work, queue, args, kwargs),
nprocs=1,
)

@staticmethod
def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
if local_rank == 0:
if isinstance(delta_queue, dict):
delta_queue = cls.process_queue(delta_queue)
work._request_queue = cls.process_queue(work._request_queue)
work._response_queue = cls.process_queue(work._response_queue)

state_observer = WorkStateObserver(work, delta_queue=delta_queue)
state_observer.start()
_proxy_setattr(work, delta_queue, state_observer)

unwrap(work.run)(*args, **kwargs)

if local_rank == 0:
state_observer.join(0)


class _DefaultInputData(BaseModel):
payload: str

Expand Down Expand Up @@ -106,6 +149,11 @@ def predict(self, request):
self._input_type = input_type
self._output_type = output_type

# Note: Enable to run inference on GPUs.
self._run_executor_cls = (
WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
)

def setup(self, *args, **kwargs) -> None:
"""This method is called before the server starts. Override this if you need to download the model or
initialize the weights, setting up pipelines etc.
Expand Down

0 comments on commit 23e0d1b

Please sign in to comment.