Skip to content

Commit

Permalink
Add stream messages from agent run for gradio chatbot (#32142)
Browse files Browse the repository at this point in the history
* Add stream_to_gradio method for running agent in gradio demo
  • Loading branch information
aymeric-roucher authored Jul 29, 2024
1 parent 811a9ca commit a24a9a6
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 0 deletions.
51 changes: 51 additions & 0 deletions docs/source/en/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,54 @@ agent = ReactCodeAgent(tools=[search_tool])

agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
```

## Gradio interface

You can leverage `gradio.Chatbot`to display your agent's thoughts using `stream_to_gradio`, here is an example:

```py
import gradio as gr
from transformers import (
load_tool,
ReactCodeAgent,
HfEngine,
stream_to_gradio,
)

# Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image")

llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")

# Initialize the agent with the image generation tool
agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)


def interact_with_agent(task):
messages = []
messages.append(gr.ChatMessage(role="user", content=task))
yield messages
for msg in stream_to_gradio(agent, task):
messages.append(msg)
yield messages + [
gr.ChatMessage(role="assistant", content="⏳ Task not finished yet!")
]
yield messages


with gr.Blocks() as demo:
text_input = gr.Textbox(lines=1, label="Chat Message", value="Make me a picture of the Statue of Liberty.")
submit = gr.Button("Run illustrator agent!")
chatbot = gr.Chatbot(
label="Agent",
type="messages",
avatar_images=(
None,
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
),
)
submit.click(interact_with_agent, [text_input], [chatbot])

if __name__ == "__main__":
demo.launch()
```
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ We provide two types of agents, based on the main [`Agent`] class:

[[autodoc]] launch_gradio_demo

### stream_to_gradio

[[autodoc]] stream_to_gradio

### ToolCollection

[[autodoc]] ToolCollection
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"ToolCollection",
"launch_gradio_demo",
"load_tool",
"stream_to_gradio",
],
"audio_utils": [],
"benchmark": [],
Expand Down Expand Up @@ -4733,6 +4734,7 @@
ToolCollection,
launch_gradio_demo,
load_tool,
stream_to_gradio,
)
from .configuration_utils import PretrainedConfig

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_import_structure = {
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
"llm_engine": ["HfEngine"],
"monitoring": ["stream_to_gradio"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
}

Expand All @@ -45,6 +46,7 @@
if TYPE_CHECKING:
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from .llm_engine import HfEngine
from .monitoring import stream_to_gradio
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool

try:
Expand Down
75 changes: 75 additions & 0 deletions src/transformers/agents/monitoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .agent_types import AgentAudio, AgentImage, AgentText
from .agents import ReactAgent


def pull_message(step_log: dict):
try:
from gradio import ChatMessage
except ImportError:
raise ImportError("Gradio should be installed in order to launch a gradio demo.")

if step_log.get("rationale"):
yield ChatMessage(role="assistant", content=step_log["rationale"])
if step_log.get("tool_call"):
used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
content = step_log["tool_call"]["tool_arguments"]
if used_code:
content = f"```py\n{content}\n```"
yield ChatMessage(
role="assistant",
metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"},
content=content,
)
if step_log.get("observation"):
yield ChatMessage(role="assistant", content=f"```\n{step_log['observation']}\n```")
if step_log.get("error"):
yield ChatMessage(
role="assistant",
content=str(step_log["error"]),
metadata={"title": "💥 Error"},
)


def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""

try:
from gradio import ChatMessage
except ImportError:
raise ImportError("Gradio should be installed in order to launch a gradio demo.")

for step_log in agent.run(task, stream=True, **kwargs):
if isinstance(step_log, dict):
for message in pull_message(step_log):
yield message

if isinstance(step_log, AgentText):
yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```")
elif isinstance(step_log, AgentImage):
yield ChatMessage(
role="assistant",
content={"path": step_log.to_string(), "mime_type": "image/png"},
)
elif isinstance(step_log, AgentAudio):
yield ChatMessage(
role="assistant",
content={"path": step_log.to_string(), "mime_type": "audio/wav"},
)
else:
yield ChatMessage(role="assistant", content=step_log)

0 comments on commit a24a9a6

Please sign in to comment.