From 0dc49b4c517706f572240f285313a881089ced79 Mon Sep 17 00:00:00 2001
From: Abubakar Abid
Date: Tue, 8 Aug 2023 15:57:55 -0400
Subject: [PATCH] Add support for async functions and async generators to
 `gr.ChatInterface` (#5116)

* add support for async functions and iterators to ChatInterface

* fix api

* added support to examples

* add tests

* chat

* add changeset

* typing

* chat interface

* anyio

* anyio generator

---------

Co-authored-by: gradio-pr-bot
---
 .changeset/moody-buttons-tell.md |  5 ++
 gradio/chat_interface.py         | 95 ++++++++++++++++++++++++--------
 scripts/format_backend.sh        |  1 +
 test/test_chat_interface.py      | 47 ++++++++++++++-
 4 files changed, 119 insertions(+), 29 deletions(-)
 create mode 100644 .changeset/moody-buttons-tell.md

diff --git a/.changeset/moody-buttons-tell.md b/.changeset/moody-buttons-tell.md
new file mode 100644
index 0000000000000..edaafc405b4a5
--- /dev/null
+++ b/.changeset/moody-buttons-tell.md
@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Add support for async functions and async generators to `gr.ChatInterface`

diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py
index aaee1c820e68e..5e2f22d3a9c05 100644
--- a/gradio/chat_interface.py
+++ b/gradio/chat_interface.py
@@ -6,8 +6,9 @@
 from __future__ import annotations

 import inspect
-from typing import Callable, Generator
+from typing import AsyncGenerator, Callable

+import anyio
 from gradio_client import utils as client_utils
 from gradio_client.documentation import document, set_documentation_group

@@ -25,6 +26,7 @@
 from gradio.helpers import create_examples as Examples  # noqa: N812
 from gradio.layouts import Accordion, Column, Group, Row
 from gradio.themes import ThemeClass as Theme
+from gradio.utils import SyncToAsyncIterator, async_iteration

 set_documentation_group("chatinterface")

@@ -100,7 +102,12 @@ def __init__(
             theme=theme,
         )
         self.fn = fn
-        self.is_generator = inspect.isgeneratorfunction(self.fn)
+        self.is_async = inspect.iscoroutinefunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
+        self.is_generator = inspect.isgeneratorfunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
         self.examples = examples
         if self.space_id and cache_examples is None:
             self.cache_examples = True
@@ -397,67 +404,99 @@ def _display_input(
         history.append([message, None])
         return history, history

-    def _submit_fn(
+    async def _submit_fn(
         self,
         message: str,
         history_with_input: list[list[str | None]],
         *args,
-        **kwargs,
     ) -> tuple[list[list[str | None]], list[list[str | None]]]:
         history = history_with_input[:-1]
-        response = self.fn(message, history, *args, **kwargs)
+        if self.is_async:
+            response = await self.fn(message, history, *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
         history.append([message, response])
         return history, history

-    def _stream_fn(
+    async def _stream_fn(
         self,
         message: str,
         history_with_input: list[list[str | None]],
         *args,
-        **kwargs,
-    ) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]:
+    ) -> AsyncGenerator:
         history = history_with_input[:-1]
-        generator = self.fn(message, history, *args, **kwargs)
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
         try:
-            first_response = next(generator)
+            first_response = await async_iteration(generator)
             update = history + [[message, first_response]]
             yield update, update
-        except StopIteration:
+        except StopAsyncIteration:
             update = history + [[message, None]]
             yield update, update
-        for response in generator:
+        async for response in generator:
             update = history + [[message, response]]
             yield update, update

-    def _api_submit_fn(
-        self, message: str, history: list[list[str | None]], *args, **kwargs
+    async def _api_submit_fn(
+        self, message: str, history: list[list[str | None]], *args
     ) -> tuple[str, list[list[str | None]]]:
-        response = self.fn(message, history)
+        if self.is_async:
+            response = await self.fn(message, history, *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
         history.append([message, response])
         return response, history

-    def _api_stream_fn(
-        self, message: str, history: list[list[str | None]], *args, **kwargs
-    ) -> Generator[tuple[str | None, list[list[str | None]]], None, None]:
-        generator = self.fn(message, history, *args, **kwargs)
+    async def _api_stream_fn(
+        self, message: str, history: list[list[str | None]], *args
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
         try:
-            first_response = next(generator)
+            first_response = await async_iteration(generator)
             yield first_response, history + [[message, first_response]]
-        except StopIteration:
+        except StopAsyncIteration:
             yield None, history + [[message, None]]
-        for response in generator:
+        async for response in generator:
             yield response, history + [[message, response]]

-    def _examples_fn(self, message: str, *args, **kwargs) -> list[list[str | None]]:
-        return [[message, self.fn(message, [], *args, **kwargs)]]
+    async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
+        if self.is_async:
+            response = await self.fn(message, [], *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+        return [[message, response]]

-    def _examples_stream_fn(
+    async def _examples_stream_fn(
         self,
         message: str,
         *args,
-        **kwargs,
-    ) -> Generator[list[list[str | None]], None, None]:
-        for response in self.fn(message, [], *args, **kwargs):
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, [], *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
+        async for response in generator:
             yield [[message, response]]

     def _delete_prev_fn(
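The pair of flags introduced in `__init__` above drives everything else in the patch: `is_async` decides whether the callback can be awaited directly on the event loop, and `is_generator` decides whether responses are streamed. A standalone sketch of how the four possible callback shapes classify under these `inspect` checks (plain Python, no Gradio required):

    import inspect

    def sync_fn(message, history):
        return "hi"

    async def async_fn(message, history):
        return "hi"

    def sync_gen(message, history):
        yield "hi"

    async def async_gen(message, history):
        yield "hi"

    for fn in (sync_fn, async_fn, sync_gen, async_gen):
        # Mirrors the checks in ChatInterface.__init__ above. Note that an async
        # generator function is *not* a coroutine function, hence the second term
        # in each expression.
        is_async = inspect.iscoroutinefunction(fn) or inspect.isasyncgenfunction(fn)
        is_generator = inspect.isgeneratorfunction(fn) or inspect.isasyncgenfunction(fn)
        print(f"{fn.__name__}: is_async={is_async}, is_generator={is_generator}")

Only `async_gen` sets both flags; `async_fn` is awaitable but not streamed, and `sync_gen` is streamed but must be driven from a worker thread.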
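Each rewritten handler (`_submit_fn`, `_stream_fn`, `_api_submit_fn`, `_api_stream_fn`, `_examples_fn`, `_examples_stream_fn`) repeats the same dispatch: await native coroutines directly, and push sync callbacks to a worker thread so they cannot block the event loop. Distilled into a sketch (the helper name `call_user_fn` is illustrative, not part of the patch):

    import anyio

    async def call_user_fn(fn, is_async, *args, limiter=None):
        if is_async:
            # Native coroutine: await it on the event loop.
            return await fn(*args)
        # Sync callback: run in a worker thread so it cannot block other
        # requests; ``limiter`` is an anyio.CapacityLimiter that caps how
        # many such threads run at once.
        return await anyio.to_thread.run_sync(fn, *args, limiter=limiter)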
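The streaming handlers also lean on `SyncToAsyncIterator` and `async_iteration` from `gradio.utils`, which the patch imports but does not define here. A plausible sketch of the behavior they need to provide, offered as an assumption rather than the actual implementation:

    import anyio

    def _next_or_stop(iterator):
        # Hypothetical bridge helper: a coroutine may not surface StopIteration
        # (PEP 479), so exhaustion is translated to StopAsyncIteration.
        try:
            return next(iterator)
        except StopIteration:
            raise StopAsyncIteration() from None

    class SyncToAsyncIterator:
        """Wraps a sync iterator so it can be consumed with ``async for``."""

        def __init__(self, iterator, limiter=None):
            self.iterator = iterator
            self.limiter = limiter

        def __aiter__(self):
            return self

        async def __anext__(self):
            # Each step of the sync generator runs in a worker thread,
            # bounded by the same limiter as the non-streaming path.
            return await anyio.to_thread.run_sync(
                _next_or_stop, self.iterator, limiter=self.limiter
            )

    async def async_iteration(iterator):
        # The ``anext()`` builtin only exists on Python 3.10+; the dunder
        # call works on all supported versions.
        return await iterator.__anext__()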
diff --git a/scripts/format_backend.sh b/scripts/format_backend.sh
index 69d5a3f9d85c2..26feae059ec31 100755
--- a/scripts/format_backend.sh
+++ b/scripts/format_backend.sh
@@ -5,5 +5,6 @@ cd "$(dirname ${0})/.."
 echo "Formatting the backend... Our style follows the Black code style."
 ruff --fix gradio test
 black gradio test
+bash scripts/type_check_backend.sh

 bash client/python/scripts/format.sh  # Call the client library's formatting script
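For users, the upshot is that coroutines and async generators can now be handed to `gr.ChatInterface` exactly like their sync counterparts. A minimal sketch mirroring the `async_greet` and `async_stream` callbacks added to the tests below:

    import gradio as gr

    async def greet(message, history):
        # Any awaitable work, e.g. an async HTTP call to a model API, fits here.
        return "hi, " + message

    async def stream(message, history):
        # Async generators stream partial responses; each yield updates the chat.
        for i in range(len(message)):
            yield message[: i + 1]

    demo = gr.ChatInterface(stream).queue()  # queueing is required for streaming

    if __name__ == "__main__":
        demo.launch()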
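The API tests below drive a launched app through `gradio_client`, and the same interaction works by hand. A sketch assuming the streaming app above is running locally (the URL is illustrative):

    from concurrent.futures import wait

    from gradio_client import Client

    client = Client("http://127.0.0.1:7860/")  # illustrative local URL
    job = client.submit("hello")
    wait([job])           # block until the stream is exhausted
    print(job.outputs())  # every partial response: ["h", "he", "hel", "hell", "hello"]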
diff --git a/test/test_chat_interface.py b/test/test_chat_interface.py
index c15ddc35856a3..fe96adfc2c932 100644
--- a/test/test_chat_interface.py
+++ b/test/test_chat_interface.py
@@ -15,11 +15,20 @@ def double(message, history):
     return message + " " + message


+async def async_greet(message, history):
+    return "hi, " + message
+
+
 def stream(message, history):
     for i in range(len(message)):
         yield message[: i + 1]


+async def async_stream(message, history):
+    for i in range(len(message)):
+        yield message[: i + 1]
+
+
 def count(message, history):
     return str(len(history))

@@ -68,7 +77,8 @@ def test_events_attached(self):
         )

     @pytest.mark.asyncio
-    async def test_example_caching(self):
+    async def test_example_caching(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
         chatbot = gr.ChatInterface(
             double, examples=["hello", "hi"], cache_examples=True
         )
@@ -77,6 +87,17 @@ async def test_example_caching(self):
         assert prediction_hello[0][0] == ["hello", "hello hello"]
         assert prediction_hi[0][0] == ["hi", "hi hi"]

+    @pytest.mark.asyncio
+    async def test_example_caching_async(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
+        chatbot = gr.ChatInterface(
+            async_greet, examples=["abubakar", "tom"], cache_examples=True
+        )
+        prediction_hello = await chatbot.examples_handler.load_from_cache(0)
+        prediction_hi = await chatbot.examples_handler.load_from_cache(1)
+        assert prediction_hello[0][0] == ["abubakar", "hi, abubakar"]
+        assert prediction_hi[0][0] == ["tom", "hi, tom"]
+
     @pytest.mark.asyncio
     async def test_example_caching_with_streaming(self, monkeypatch):
         monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
@@ -88,6 +109,17 @@ async def test_example_caching_with_streaming(self, monkeypatch):
         assert prediction_hello[0][0] == ["hello", "hello"]
         assert prediction_hi[0][0] == ["hi", "hi"]

+    @pytest.mark.asyncio
+    async def test_example_caching_with_streaming_async(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
+        chatbot = gr.ChatInterface(
+            async_stream, examples=["hello", "hi"], cache_examples=True
+        )
+        prediction_hello = await chatbot.examples_handler.load_from_cache(0)
+        prediction_hi = await chatbot.examples_handler.load_from_cache(1)
+        assert prediction_hello[0][0] == ["hello", "hello"]
+        assert prediction_hi[0][0] == ["hi", "hi"]
+
     @pytest.mark.asyncio
     async def test_example_caching_with_additional_inputs(self, monkeypatch):
         monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
@@ -138,12 +170,25 @@ def test_streaming_api(self, connect):
         wait([job])
         assert job.outputs() == ["h", "he", "hel", "hell", "hello"]

+    def test_streaming_api_async(self, connect):
+        chatbot = gr.ChatInterface(async_stream).queue()
+        with connect(chatbot) as client:
+            job = client.submit("hello")
+            wait([job])
+            assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
+
     def test_non_streaming_api(self, connect):
         chatbot = gr.ChatInterface(double)
         with connect(chatbot) as client:
             result = client.predict("hello")
             assert result == "hello hello"

+    def test_non_streaming_api_async(self, connect):
+        chatbot = gr.ChatInterface(async_greet)
+        with connect(chatbot) as client:
+            result = client.predict("gradio")
+            assert result == "hi, gradio"
+
     def test_streaming_api_with_additional_inputs(self, connect):
         chatbot = gr.ChatInterface(
             echo_system_prompt_plus_message,