Add support for async functions and async generators to `gr.ChatInterface` (#5116)

* add support for async functions and iterators to ChatInterface

* fix api

* added support to examples

* add tests

* chat

* add changeset

* typing

* chat interface

* anyio

* anyio generator

---------

Co-authored-by: gradio-pr-bot <[email protected]>
abidlabs and gradio-pr-bot authored Aug 8, 2023
1 parent eaa1ce1 commit 0dc49b4
Showing 4 changed files with 117 additions and 27 deletions.
5 changes: 5 additions & 0 deletions .changeset/moody-buttons-tell.md
@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Add support for async functions and async generators to `gr.ChatInterface`
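For context, a minimal usage sketch of what this change enables: `gr.ChatInterface` now accepts a coroutine function or an async generator function as its `fn`, alongside plain functions and generators. The handlers below are illustrative, mirroring the async handlers added to the tests in this commit (streaming still requires the queue to be enabled):

```python
import asyncio

import gradio as gr


async def async_greet(message, history):
    # Coroutine handler: ChatInterface awaits it directly.
    await asyncio.sleep(0.1)  # stand-in for an async API call
    return "hi, " + message


async def async_stream(message, history):
    # Async generator handler: each yield streams a partial reply.
    for i in range(len(message)):
        yield message[: i + 1]


demo = gr.ChatInterface(async_stream).queue()  # .queue() is needed for streaming
# demo.launch()
```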
91 changes: 65 additions & 26 deletions gradio/chat_interface.py
@@ -6,8 +6,9 @@
 from __future__ import annotations
 
 import inspect
-from typing import Callable, Generator
+from typing import AsyncGenerator, Callable
 
+import anyio
 from gradio_client import utils as client_utils
 from gradio_client.documentation import document, set_documentation_group
 
@@ -25,6 +26,7 @@
 from gradio.helpers import create_examples as Examples  # noqa: N812
 from gradio.layouts import Accordion, Column, Group, Row
 from gradio.themes import ThemeClass as Theme
+from gradio.utils import SyncToAsyncIterator, async_iteration
 
 set_documentation_group("chatinterface")
 
@@ -100,7 +102,12 @@ def __init__(
             theme=theme,
         )
         self.fn = fn
-        self.is_generator = inspect.isgeneratorfunction(self.fn)
+        self.is_async = inspect.iscoroutinefunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
+        self.is_generator = inspect.isgeneratorfunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
         self.examples = examples
         if self.space_id and cache_examples is None:
             self.cache_examples = True
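As a quick reference for the two checks above: `inspect.isasyncgenfunction` is deliberately part of both flags, so an async generator counts as both async and a generator. A small standalone illustration (not from the diff):

```python
import inspect


def sync_fn(message, history):
    return message


def sync_gen(message, history):
    yield message


async def async_fn(message, history):
    return message


async def async_gen(message, history):
    yield message


for fn in (sync_fn, sync_gen, async_fn, async_gen):
    is_async = inspect.iscoroutinefunction(fn) or inspect.isasyncgenfunction(fn)
    is_generator = inspect.isgeneratorfunction(fn) or inspect.isasyncgenfunction(fn)
    print(f"{fn.__name__}: is_async={is_async}, is_generator={is_generator}")
# sync_fn: is_async=False, is_generator=False
# sync_gen: is_async=False, is_generator=True
# async_fn: is_async=True, is_generator=False
# async_gen: is_async=True, is_generator=True
```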
@@ -397,67 +404,99 @@ def _display_input(
         history.append([message, None])
         return history, history
 
-    def _submit_fn(
+    async def _submit_fn(
         self,
         message: str,
         history_with_input: list[list[str | None]],
         *args,
-        **kwargs,
     ) -> tuple[list[list[str | None]], list[list[str | None]]]:
         history = history_with_input[:-1]
-        response = self.fn(message, history, *args, **kwargs)
+        if self.is_async:
+            response = await self.fn(message, history, *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
         history.append([message, response])
         return history, history
 
-    def _stream_fn(
+    async def _stream_fn(
         self,
         message: str,
         history_with_input: list[list[str | None]],
         *args,
-        **kwargs,
-    ) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]:
+    ) -> AsyncGenerator:
         history = history_with_input[:-1]
-        generator = self.fn(message, history, *args, **kwargs)
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
         try:
-            first_response = next(generator)
+            first_response = await async_iteration(generator)
             update = history + [[message, first_response]]
             yield update, update
         except StopIteration:
             update = history + [[message, None]]
             yield update, update
-        for response in generator:
+        async for response in generator:
             update = history + [[message, response]]
             yield update, update
 
-    def _api_submit_fn(
-        self, message: str, history: list[list[str | None]], *args, **kwargs
+    async def _api_submit_fn(
+        self, message: str, history: list[list[str | None]], *args
     ) -> tuple[str, list[list[str | None]]]:
-        response = self.fn(message, history)
+        if self.is_async:
+            response = await self.fn(message, history, *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
         history.append([message, response])
         return response, history
 
-    def _api_stream_fn(
-        self, message: str, history: list[list[str | None]], *args, **kwargs
-    ) -> Generator[tuple[str | None, list[list[str | None]]], None, None]:
-        generator = self.fn(message, history, *args, **kwargs)
+    async def _api_stream_fn(
+        self, message: str, history: list[list[str | None]], *args
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
         try:
-            first_response = next(generator)
+            first_response = await async_iteration(generator)
             yield first_response, history + [[message, first_response]]
         except StopIteration:
             yield None, history + [[message, None]]
-        for response in generator:
+        async for response in generator:
             yield response, history + [[message, response]]
 
-    def _examples_fn(self, message: str, *args, **kwargs) -> list[list[str | None]]:
-        return [[message, self.fn(message, [], *args, **kwargs)]]
+    async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
+        if self.is_async:
+            response = await self.fn(message, [], *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+        return [[message, response]]
 
-    def _examples_stream_fn(
+    async def _examples_stream_fn(
        self,
         message: str,
         *args,
-        **kwargs,
-    ) -> Generator[list[list[str | None]], None, None]:
-        for response in self.fn(message, [], *args, **kwargs):
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, [], *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
+        async for response in generator:
             yield [[message, response]]
 
     def _delete_prev_fn(
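All of the wrappers above follow one dispatch pattern: await async handlers directly, push sync handlers to a worker thread via `anyio.to_thread.run_sync`, and wrap sync generators so they can be consumed with `async for` (gradio's `SyncToAsyncIterator`/`async_iteration` helpers, with a `limiter` that caps concurrent worker threads). A condensed, standalone sketch of that pattern follows; the helper names and the simplified thread-hopping iterator are illustrative, not gradio's internals:

```python
import inspect

import anyio


async def call_handler(fn, message, history, *args):
    """Await coroutine handlers; run plain functions in a worker thread."""
    if inspect.iscoroutinefunction(fn):
        return await fn(message, history, *args)
    return await anyio.to_thread.run_sync(fn, message, history, *args)


async def iterate_handler(fn, message, history, *args):
    """Async-iterate async generators; adapt sync generators to `async for`."""
    if inspect.isasyncgenfunction(fn):
        async for chunk in fn(message, history, *args):
            yield chunk
    else:
        # Simplified stand-in for SyncToAsyncIterator: each next() call on the
        # sync generator runs in a worker thread so the event loop never blocks.
        gen = await anyio.to_thread.run_sync(fn, message, history, *args)
        done = object()
        while True:
            chunk = await anyio.to_thread.run_sync(next, gen, done)
            if chunk is done:
                break
            yield chunk
```

In the diff itself the same branching is written inline against the precomputed `self.is_async` flag, and `limiter=self.limiter` is passed to `anyio.to_thread.run_sync` to bound the number of worker threads.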
1 change: 1 addition & 0 deletions scripts/format_backend.sh
@@ -5,5 +5,6 @@ cd "$(dirname ${0})/.."
 echo "Formatting the backend... Our style follows the Black code style."
 ruff --fix gradio test
 black gradio test
+bash scripts/type_check_backend.sh
 
 bash client/python/scripts/format.sh # Call the client library's formatting script
47 changes: 46 additions & 1 deletion test/test_chat_interface.py
@@ -15,11 +15,20 @@ def double(message, history):
     return message + " " + message
 
 
+async def async_greet(message, history):
+    return "hi, " + message
+
+
 def stream(message, history):
     for i in range(len(message)):
         yield message[: i + 1]
 
 
+async def async_stream(message, history):
+    for i in range(len(message)):
+        yield message[: i + 1]
+
+
 def count(message, history):
     return str(len(history))
 
@@ -68,7 +77,8 @@ def test_events_attached(self):
         )
 
     @pytest.mark.asyncio
-    async def test_example_caching(self):
+    async def test_example_caching(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
         chatbot = gr.ChatInterface(
             double, examples=["hello", "hi"], cache_examples=True
         )
@@ -77,6 +87,17 @@ async def test_example_caching(self):
         assert prediction_hello[0][0] == ["hello", "hello hello"]
         assert prediction_hi[0][0] == ["hi", "hi hi"]
 
+    @pytest.mark.asyncio
+    async def test_example_caching_async(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
+        chatbot = gr.ChatInterface(
+            async_greet, examples=["abubakar", "tom"], cache_examples=True
+        )
+        prediction_hello = await chatbot.examples_handler.load_from_cache(0)
+        prediction_hi = await chatbot.examples_handler.load_from_cache(1)
+        assert prediction_hello[0][0] == ["abubakar", "hi, abubakar"]
+        assert prediction_hi[0][0] == ["tom", "hi, tom"]
+
     @pytest.mark.asyncio
     async def test_example_caching_with_streaming(self, monkeypatch):
         monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
@@ -88,6 +109,17 @@ async def test_example_caching_with_streaming(self, monkeypatch):
         assert prediction_hello[0][0] == ["hello", "hello"]
         assert prediction_hi[0][0] == ["hi", "hi"]
 
+    @pytest.mark.asyncio
+    async def test_example_caching_with_streaming_async(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
+        chatbot = gr.ChatInterface(
+            async_stream, examples=["hello", "hi"], cache_examples=True
+        )
+        prediction_hello = await chatbot.examples_handler.load_from_cache(0)
+        prediction_hi = await chatbot.examples_handler.load_from_cache(1)
+        assert prediction_hello[0][0] == ["hello", "hello"]
+        assert prediction_hi[0][0] == ["hi", "hi"]
+
     @pytest.mark.asyncio
     async def test_example_caching_with_additional_inputs(self, monkeypatch):
         monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
@@ -138,12 +170,25 @@ def test_streaming_api(self, connect):
             wait([job])
             assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
 
+    def test_streaming_api_async(self, connect):
+        chatbot = gr.ChatInterface(async_stream).queue()
+        with connect(chatbot) as client:
+            job = client.submit("hello")
+            wait([job])
+            assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
+
     def test_non_streaming_api(self, connect):
         chatbot = gr.ChatInterface(double)
         with connect(chatbot) as client:
             result = client.predict("hello")
             assert result == "hello hello"
 
+    def test_non_streaming_api_async(self, connect):
+        chatbot = gr.ChatInterface(async_greet)
+        with connect(chatbot) as client:
+            result = client.predict("gradio")
+            assert result == "hi, gradio"
+
     def test_streaming_api_with_additional_inputs(self, connect):
         chatbot = gr.ChatInterface(
             echo_system_prompt_plus_message,
