Skip to content

Commit

Permalink
Adds additional_inputs to gr.ChatInterface (#4985)
Browse files Browse the repository at this point in the history
* adding additional inputs

* add param

* guide

* add is_rendered

* add demo

* fixing examples

* add test

* guide

* add changeset

* Fix typos

* Remove label

* Revert "Remove label"

This reverts commit 1004285.

* add changeset

---------

Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: freddyaboulton <[email protected]>
  • Loading branch information
3 people authored Jul 24, 2023
1 parent 4b0e98e commit b74f845
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 53 deletions.
5 changes: 5 additions & 0 deletions .changeset/witty-pets-rhyme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Adds `additional_inputs` to `gr.ChatInterface`
1 change: 1 addition & 0 deletions demo/chatinterface_system_prompt/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i+1]\n", "\n", "demo = gr.ChatInterface(echo, \n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"), \n", " gr.Slider(10, 100)\n", " ]\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
18 changes: 18 additions & 0 deletions demo/chatinterface_system_prompt/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import gradio as gr
import time

def echo(message, history, system_prompt, tokens):
response = f"System prompt: {system_prompt}\n Message: {message}."
for i in range(min(len(response), int(tokens))):
time.sleep(0.05)
yield response[: i+1]

demo = gr.ChatInterface(echo,
additional_inputs=[
gr.Textbox("You are helpful AI.", label="System Prompt"),
gr.Slider(10, 100)
]
)

if __name__ == "__main__":
demo.queue().launch()
3 changes: 3 additions & 0 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
self.share_token = secrets.token_urlsafe(32)
self._skip_init_processing = _skip_init_processing
self.parent: BlockContext | None = None
self.is_rendered: bool = False

if render:
self.render()
Expand All @@ -127,6 +128,7 @@ def render(self):
Context.block.add(self)
if Context.root_block is not None:
Context.root_block.blocks[self._id] = self
self.is_rendered = True
if isinstance(self, components.IOComponent):
Context.root_block.temp_file_sets.append(self.temp_files)
return self
Expand All @@ -144,6 +146,7 @@ def unrender(self):
if Context.root_block is not None:
try:
del Context.root_block.blocks[self._id]
self.is_rendered = False
except KeyError:
pass
return self
Expand Down
88 changes: 62 additions & 26 deletions gradio/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,24 @@
from __future__ import annotations

import inspect
import warnings
from typing import Callable, Generator

from gradio_client import utils as client_utils
from gradio_client.documentation import document, set_documentation_group

from gradio.blocks import Blocks
from gradio.components import (
Button,
Chatbot,
IOComponent,
Markdown,
State,
Textbox,
get_component_instance,
)
from gradio.events import Dependency, EventListenerMethod
from gradio.helpers import create_examples as Examples # noqa: N812
from gradio.layouts import Column, Group, Row
from gradio.layouts import Accordion, Column, Group, Row
from gradio.themes import ThemeClass as Theme

set_documentation_group("chatinterface")
Expand Down Expand Up @@ -53,6 +55,8 @@ def __init__(
*,
chatbot: Chatbot | None = None,
textbox: Textbox | None = None,
additional_inputs: str | IOComponent | list[str | IOComponent] | None = None,
additional_inputs_accordion_name: str = "Additional Inputs",
examples: list[str] | None = None,
cache_examples: bool | None = None,
title: str | None = None,
Expand All @@ -65,12 +69,15 @@ def __init__(
retry_btn: str | None | Button = "🔄 Retry",
undo_btn: str | None | Button = "↩️ Undo",
clear_btn: str | None | Button = "🗑️ Clear",
autofocus: bool = True,
):
"""
Parameters:
fn: the function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
textbox: an instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
additional_inputs_accordion_name: the label of the accordion to use for additional inputs, only used if additional_inputs is provided.
examples: sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
Expand All @@ -83,6 +90,7 @@ def __init__(
retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
autofocus: If True, autofocuses to the textbox when the page loads.
"""
super().__init__(
analytics_enabled=analytics_enabled,
Expand All @@ -91,12 +99,6 @@ def __init__(
title=title or "Gradio",
theme=theme,
)
if len(inspect.signature(fn).parameters) != 2:
warnings.warn(
"The function to ChatInterface should take two inputs (message, history) and return a single string response.",
UserWarning,
)

self.fn = fn
self.is_generator = inspect.isgeneratorfunction(self.fn)
self.examples = examples
Expand All @@ -106,6 +108,16 @@ def __init__(
self.cache_examples = cache_examples or False
self.buttons: list[Button] = []

if additional_inputs:
if not isinstance(additional_inputs, list):
additional_inputs = [additional_inputs]
self.additional_inputs = [
get_component_instance(i, render=False) for i in additional_inputs # type: ignore
]
else:
self.additional_inputs = []
self.additional_inputs_accordion_name = additional_inputs_accordion_name

with self:
if title:
Markdown(
Expand All @@ -130,9 +142,10 @@ def __init__(
self.textbox = Textbox(
container=False,
show_label=False,
label="Message",
placeholder="Type a message...",
scale=7,
autofocus=True,
autofocus=autofocus,
)
if submit_btn:
if isinstance(submit_btn, Button):
Expand Down Expand Up @@ -199,12 +212,24 @@ def __init__(

self.examples_handler = Examples(
examples=examples,
inputs=self.textbox,
inputs=[self.textbox] + self.additional_inputs,
outputs=self.chatbot,
fn=examples_fn,
cache_examples=self.cache_examples,
)

any_unrendered_inputs = any(
not inp.is_rendered for inp in self.additional_inputs
)
if self.additional_inputs and any_unrendered_inputs:
with Accordion(self.additional_inputs_accordion_name, open=False):
for input_component in self.additional_inputs:
if not input_component.is_rendered:
input_component.render()

# The example caching must happen after the input components have rendered
if cache_examples:
client_utils.synchronize_async(self.examples_handler.cache)

self.saved_input = State()
self.chatbot_state = State([])

Expand All @@ -230,7 +255,7 @@ def _setup_events(self) -> None:
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
api_name=False,
)
Expand All @@ -255,7 +280,7 @@ def _setup_events(self) -> None:
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
api_name=False,
)
Expand All @@ -280,7 +305,7 @@ def _setup_events(self) -> None:
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
api_name=False,
)
Expand Down Expand Up @@ -358,7 +383,7 @@ def _setup_api(self) -> None:

self.fake_api_btn.click(
api_fn,
[self.textbox, self.chatbot_state],
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.textbox, self.chatbot_state],
api_name="chat",
)
Expand All @@ -373,18 +398,26 @@ def _display_input(
return history, history

def _submit_fn(
self, message: str, history_with_input: list[list[str | None]]
self,
message: str,
history_with_input: list[list[str | None]],
*args,
**kwargs,
) -> tuple[list[list[str | None]], list[list[str | None]]]:
history = history_with_input[:-1]
response = self.fn(message, history)
response = self.fn(message, history, *args, **kwargs)
history.append([message, response])
return history, history

def _stream_fn(
self, message: str, history_with_input: list[list[str | None]]
self,
message: str,
history_with_input: list[list[str | None]],
*args,
**kwargs,
) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]:
history = history_with_input[:-1]
generator = self.fn(message, history)
generator = self.fn(message, history, *args, **kwargs)
try:
first_response = next(generator)
update = history + [[message, first_response]]
Expand All @@ -397,16 +430,16 @@ def _stream_fn(
yield update, update

def _api_submit_fn(
self, message: str, history: list[list[str | None]]
self, message: str, history: list[list[str | None]], *args, **kwargs
) -> tuple[str, list[list[str | None]]]:
response = self.fn(message, history)
history.append([message, response])
return response, history

def _api_stream_fn(
self, message: str, history: list[list[str | None]]
self, message: str, history: list[list[str | None]], *args, **kwargs
) -> Generator[tuple[str | None, list[list[str | None]]], None, None]:
generator = self.fn(message, history)
generator = self.fn(message, history, *args, **kwargs)
try:
first_response = next(generator)
yield first_response, history + [[message, first_response]]
Expand All @@ -415,13 +448,16 @@ def _api_stream_fn(
for response in generator:
yield response, history + [[message, response]]

def _examples_fn(self, message: str) -> list[list[str | None]]:
return [[message, self.fn(message, [])]]
def _examples_fn(self, message: str, *args, **kwargs) -> list[list[str | None]]:
return [[message, self.fn(message, [], *args, **kwargs)]]

def _examples_stream_fn(
self, message: str
self,
message: str,
*args,
**kwargs,
) -> Generator[list[list[str | None]], None, None]:
for response in self.fn(message, []):
for response in self.fn(message, [], *args, **kwargs):
yield [[message, response]]

def _delete_prev_fn(
Expand Down
52 changes: 33 additions & 19 deletions gradio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def __init__(
self.non_none_examples = non_none_examples
self.inputs = inputs
self.inputs_with_examples = inputs_with_examples
self.outputs = outputs
self.outputs = outputs or []
self.fn = fn
self.cache_examples = cache_examples
self._api_mode = _api_mode
Expand Down Expand Up @@ -250,23 +250,14 @@ async def create(self) -> None:
component to hold the examples"""

async def load_example(example_id):
if self.cache_examples:
processed_example = self.non_none_processed_examples[
example_id
] + await self.load_from_cache(example_id)
else:
processed_example = self.non_none_processed_examples[example_id]
processed_example = self.non_none_processed_examples[example_id]
return utils.resolve_singleton(processed_example)

if Context.root_block:
if self.cache_examples and self.outputs:
targets = self.inputs_with_examples + self.outputs
else:
targets = self.inputs_with_examples
load_input_event = self.dataset.click(
self.load_input_event = self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=targets, # type: ignore
outputs=self.inputs_with_examples, # type: ignore
show_progress="hidden",
postprocess=False,
queue=False,
Expand All @@ -275,7 +266,7 @@ async def load_example(example_id):
if self.run_on_click and not self.cache_examples:
if self.fn is None:
raise ValueError("Cannot run_on_click if no function is provided")
load_input_event.then(
self.load_input_event.then(
self.fn,
inputs=self.inputs, # type: ignore
outputs=self.outputs, # type: ignore
Expand All @@ -301,25 +292,24 @@ async def cache(self) -> None:

if inspect.isgeneratorfunction(self.fn):

def get_final_item(args): # type: ignore
def get_final_item(*args): # type: ignore
x = None
for x in self.fn(args): # noqa: B007 # type: ignore
for x in self.fn(*args): # noqa: B007 # type: ignore
pass
return x

fn = get_final_item
elif inspect.isasyncgenfunction(self.fn):

async def get_final_item(args):
async def get_final_item(*args):
x = None
async for x in self.fn(args): # noqa: B007 # type: ignore
async for x in self.fn(*args): # noqa: B007 # type: ignore
pass
return x

fn = get_final_item
else:
fn = self.fn

# create a fake dependency to process the examples and get the predictions
dependency, fn_index = Context.root_block.set_event_trigger(
event_name="fake_event",
Expand Down Expand Up @@ -352,6 +342,30 @@ async def get_final_item(args):
# Remove the "fake_event" to prevent bugs in loading interfaces from spaces
Context.root_block.dependencies.remove(dependency)
Context.root_block.fns.pop(fn_index)

# Remove the original load_input_event and replace it with one that
# also populates the input. We do it this way to to allow the cache()
# method to be called independently of the create() method
index = Context.root_block.dependencies.index(self.load_input_event)
Context.root_block.dependencies.pop(index)
Context.root_block.fns.pop(index)

async def load_example(example_id):
processed_example = self.non_none_processed_examples[
example_id
] + await self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)

self.load_input_event = self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=self.inputs_with_examples + self.outputs, # type: ignore
show_progress="hidden",
postprocess=False,
queue=False,
api_name=self.api_name, # type: ignore
)

print("Caching complete\n")

async def load_from_cache(self, example_id: int) -> list[Any]:
Expand Down
Loading

1 comment on commit b74f845

@vercel
Copy link

@vercel vercel bot commented on b74f845 Jul 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.