Skip to content

Commit

Permalink
Support gr.Progress() in python client (#3924)
Browse files Browse the repository at this point in the history
* Add progress message

* CHANGELOG

* Dont use pydantic

* Docs + local test

* Add gr to requirements

* Remove editable install

* make a bit softer
  • Loading branch information
freddyaboulton authored Apr 24, 2023
1 parent 6b854ee commit d835c9a
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 5 deletions.
2 changes: 1 addition & 1 deletion client/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## New Features:

No changes to highlight.
- Progress Updates from `gr.Progress()` can be accessed via `job.status().progress_data` by @freddyaboulton](https://github.com/freddyaboulton) in [PR 3924](https://github.com/gradio-app/gradio/pull/3924)

## Bug Fixes:

Expand Down
10 changes: 9 additions & 1 deletion client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,11 @@ def outputs(self) -> List[Tuple | Any]:
def status(self) -> StatusUpdate:
"""
Returns the latest status update from the Job in the form of a StatusUpdate
object, which contains the following fields: code, rank, queue_size, success, time, eta.
object, which contains the following fields: code, rank, queue_size, success, time, eta, and progress_data.
progress_data is a list of updates emitted by the gr.Progress() tracker of the event handler. Each element
of the list has the following fields: index, length, unit, progress, desc. If the event handler does not have
a gr.Progress() tracker, the progress_data field will be None.
Example:
from gradio_client import Client
Expand All @@ -973,6 +977,7 @@ def status(self) -> StatusUpdate:
success=False,
time=time,
eta=None,
progress_data=None,
)
if self.done():
if not self.future._exception: # type: ignore
Expand All @@ -983,6 +988,7 @@ def status(self) -> StatusUpdate:
success=True,
time=time,
eta=None,
progress_data=None,
)
else:
return StatusUpdate(
Expand All @@ -992,6 +998,7 @@ def status(self) -> StatusUpdate:
success=False,
time=time,
eta=None,
progress_data=None,
)
else:
if not self.communicator:
Expand All @@ -1002,6 +1009,7 @@ def status(self) -> StatusUpdate:
success=None,
time=time,
eta=None,
progress_data=None,
)
else:
with self.communicator.lock:
Expand Down
33 changes: 32 additions & 1 deletion client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import fsspec.asyn
import httpx
Expand Down Expand Up @@ -79,6 +79,7 @@ class Status(Enum):
SENDING_DATA = "SENDING_DATA"
PROCESSING = "PROCESSING"
ITERATING = "ITERATING"
PROGRESS = "PROGRESS"
FINISHED = "FINISHED"
CANCELLED = "CANCELLED"

Expand All @@ -92,6 +93,7 @@ def ordering(status: "Status") -> int:
Status.IN_QUEUE,
Status.SENDING_DATA,
Status.PROCESSING,
Status.PROGRESS,
Status.ITERATING,
Status.FINISHED,
Status.CANCELLED,
Expand All @@ -112,9 +114,32 @@ def msg_to_status(msg: str) -> "Status":
"process_starts": Status.PROCESSING,
"process_generating": Status.ITERATING,
"process_completed": Status.FINISHED,
"progress": Status.PROGRESS,
}[msg]


@dataclass
class ProgressUnit:
index: Optional[int]
length: Optional[int]
unit: Optional[str]
progress: Optional[float]
desc: Optional[str]

@classmethod
def from_ws_msg(cls, data: List[Dict]) -> List["ProgressUnit"]:
return [
cls(
index=d.get("index"),
length=d.get("length"),
unit=d.get("unit"),
progress=d.get("progress"),
desc=d.get("desc"),
)
for d in data
]


@dataclass
class StatusUpdate:
"""Update message sent from the worker thread to the Job on the main thread."""
Expand All @@ -125,6 +150,7 @@ class StatusUpdate:
eta: float | None
success: bool | None
time: datetime | None
progress_data: List[ProgressUnit] | None


def create_initial_status_update():
Expand All @@ -135,6 +161,7 @@ def create_initial_status_update():
eta=None,
success=None,
time=datetime.now(),
progress_data=None,
)


Expand Down Expand Up @@ -209,13 +236,17 @@ async def get_pred_from_ws(
resp = json.loads(msg)
if helper:
with helper.lock:
has_progress = "progress_data" in resp
status_update = StatusUpdate(
code=Status.msg_to_status(resp["msg"]),
queue_size=resp.get("queue_size"),
rank=resp.get("rank", None),
success=resp.get("success"),
time=datetime.now(),
eta=resp.get("rank_eta"),
progress_data=ProgressUnit.from_ws_msg(resp["progress_data"])
if has_progress
else None,
)
output = resp.get("output", {}).get("data", [])
if output and status_update.code != Status.FINISHED:
Expand Down
1 change: 0 additions & 1 deletion client/python/scripts/ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@ black --check test gradio_client
pyright gradio_client/*.py

echo "Testing..."
python -m pip install -e ../../. # Install gradio from local source (as the latest version may not yet be published to PyPI)
python -m pytest test
1 change: 1 addition & 0 deletions client/python/test/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pytest-asyncio
pytest==7.1.2
ruff==0.0.260
pyright==1.1.298
gradio
48 changes: 47 additions & 1 deletion client/python/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

import gradio as gr
import pytest
from huggingface_hub.utils import RepositoryNotFoundError

from gradio_client import Client
from gradio_client.serializing import SimpleSerializable
from gradio_client.utils import Communicator, Status, StatusUpdate
from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate

os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

Expand Down Expand Up @@ -96,6 +97,7 @@ def test_job_status_queue_disabled(self):
statuses.append(job.status())
statuses.append(job.status())
assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
assert not any(s.progress_data for s in statuses)

@pytest.mark.flaky
def test_intermediate_outputs(
Expand Down Expand Up @@ -157,6 +159,40 @@ def test_job_output_video(self):
)
assert pathlib.Path(job.result()).exists()

def test_progress_updates(self):
def my_function(x, progress=gr.Progress()):
progress(0, desc="Starting...")
for i in progress.tqdm(range(20)):
time.sleep(0.1)
return x

demo = gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue(
concurrency_count=20
)
_, local_url, _ = demo.launch(prevent_thread_lock=True)

try:
client = Client(src=local_url)
job = client.submit("hello", api_name="/predict")
statuses = []
while not job.done():
statuses.append(job.status())
time.sleep(0.02)
assert any(s.code == Status.PROGRESS for s in statuses)
assert any(s.progress_data is not None for s in statuses)
all_progress_data = [
p for s in statuses if s.progress_data for p in s.progress_data
]
count = 0
for i in range(20):
unit = ProgressUnit(
index=i, length=20, unit="steps", progress=None, desc=None
)
count += unit in all_progress_data
assert count
finally:
demo.close()

@pytest.mark.flaky
def test_cancel_from_client_queued(self):
client = Client(src="gradio-tests/test-cancel-from-client")
Expand Down Expand Up @@ -284,6 +320,7 @@ def test_messages_passed_correctly(self, mock_make_end_to_end_fn):
success=None,
queue_size=None,
time=now,
progress_data=None,
),
StatusUpdate(
code=Status.SENDING_DATA,
Expand All @@ -292,6 +329,7 @@ def test_messages_passed_correctly(self, mock_make_end_to_end_fn):
success=None,
queue_size=None,
time=now + timedelta(seconds=1),
progress_data=None,
),
StatusUpdate(
code=Status.IN_QUEUE,
Expand All @@ -300,6 +338,7 @@ def test_messages_passed_correctly(self, mock_make_end_to_end_fn):
queue_size=2,
success=None,
time=now + timedelta(seconds=2),
progress_data=None,
),
StatusUpdate(
code=Status.IN_QUEUE,
Expand All @@ -308,6 +347,7 @@ def test_messages_passed_correctly(self, mock_make_end_to_end_fn):
queue_size=1,
success=None,
time=now + timedelta(seconds=3),
progress_data=None,
),
StatusUpdate(
code=Status.ITERATING,
Expand All @@ -316,6 +356,7 @@ def test_messages_passed_correctly(self, mock_make_end_to_end_fn):
queue_size=None,
success=None,
time=now + timedelta(seconds=3),
progress_data=None,
),
StatusUpdate(
code=Status.FINISHED,
Expand All @@ -324,6 +365,7 @@ def test_messages_passed_correctly(self, mock_make_end_to_end_fn):
queue_size=None,
success=True,
time=now + timedelta(seconds=4),
progress_data=None,
),
]

Expand Down Expand Up @@ -362,6 +404,7 @@ def test_messages_correct_two_concurrent(self, mock_make_end_to_end_fn):
success=None,
queue_size=None,
time=now,
progress_data=None,
),
StatusUpdate(
code=Status.FINISHED,
Expand All @@ -370,6 +413,7 @@ def test_messages_correct_two_concurrent(self, mock_make_end_to_end_fn):
queue_size=None,
success=True,
time=now + timedelta(seconds=4),
progress_data=None,
),
]

Expand All @@ -381,6 +425,7 @@ def test_messages_correct_two_concurrent(self, mock_make_end_to_end_fn):
queue_size=2,
success=None,
time=now + timedelta(seconds=2),
progress_data=None,
),
StatusUpdate(
code=Status.IN_QUEUE,
Expand All @@ -389,6 +434,7 @@ def test_messages_correct_two_concurrent(self, mock_make_end_to_end_fn):
queue_size=1,
success=None,
time=now + timedelta(seconds=3),
progress_data=None,
),
]

Expand Down

0 comments on commit d835c9a

Please sign in to comment.