diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 3b7a3a59a7a80e..f1d5881f161e25 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -67,6 +67,7 @@
         "ToolCollection",
         "launch_gradio_demo",
         "load_tool",
+        "stream_from_transformers_agent",
     ],
     "audio_utils": [],
     "benchmark": [],
@@ -4730,6 +4731,7 @@
         ToolCollection,
         launch_gradio_demo,
         load_tool,
+        stream_from_transformers_agent,
     )
     from .configuration_utils import PretrainedConfig
diff --git a/src/transformers/agents/__init__.py b/src/transformers/agents/__init__.py
index 672977f98812c5..1f2be3b3cf7bf7 100644
--- a/src/transformers/agents/__init__.py
+++ b/src/transformers/agents/__init__.py
@@ -27,6 +27,7 @@
     "agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
     "llm_engine": ["HfEngine"],
     "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
+    "monitoring": ["stream_from_transformers_agent"],
 }
 
 try:
@@ -46,6 +47,7 @@
     from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
     from .llm_engine import HfEngine
     from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
+    from .monitoring import stream_from_transformers_agent
 
 try:
     if not is_torch_available():
diff --git a/src/transformers/agents/monitoring.py b/src/transformers/agents/monitoring.py
index 6b78de184d79fe..b1ab66119b5a76 100644
--- a/src/transformers/agents/monitoring.py
+++ b/src/transformers/agents/monitoring.py
@@ -46,8 +46,8 @@ def pull_message(step_log: dict):
     )
 
 
-def stream_from_transformers_agent(agent: ReactAgent, prompt: str):
-    """Runs an agent with the given prompt and streams the messages from the agent as ChatMessages."""
+def stream_from_transformers_agent(agent: ReactAgent, task: str, **kwargs):
+    """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
     try:
         from gradio import ChatMessage
 
@@ -57,7 +57,7 @@ def stream_from_transformers_agent(agent: ReactAgent, prompt: str):
     class Output:
         output: AgentType | str = None
 
-    for step_log in agent.run(prompt, stream=True):
+    for step_log in agent.run(task, stream=True, **kwargs):
         if isinstance(step_log, dict):
             for message in pull_message(step_log):
                 print("message", message)
@@ -77,4 +77,4 @@ class Output:
                 content={"path": Output.output.to_string(), "mime_type": "audio/wav"},
             )
         else:
-            return ChatMessage(role="assistant", content=Output.output)
+            yield ChatMessage(role="assistant", content=Output.output)
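
For context, a minimal usage sketch (not part of this patch) of the streaming generator the diff touches. It assumes gradio is installed and the default HfEngine endpoint is reachable; the empty tool list and the component names are illustrative only. Because the final answer is now yielded instead of returned, a single loop collects every ChatMessage, including the last one.

# Hypothetical usage sketch: wiring stream_from_transformers_agent into a
# gradio chat UI. Assumes gradio is installed and HfEngine's default endpoint
# is reachable; agent setup and component names are illustrative.
import gradio as gr

from transformers.agents import HfEngine, ReactCodeAgent, stream_from_transformers_agent

agent = ReactCodeAgent(tools=[], llm_engine=HfEngine())


def interact_with_agent(prompt, messages):
    # Echo the user turn, then stream each agent message as it arrives.
    messages.append(gr.ChatMessage(role="user", content=prompt))
    yield messages
    for msg in stream_from_transformers_agent(agent, prompt):
        # With this patch the final answer is yielded too, so no separate
        # handling of a returned value is needed.
        messages.append(msg)
        yield messages
    yield messages


with gr.Blocks() as demo:
    stored_messages = gr.State([])
    chatbot = gr.Chatbot(label="Agent", type="messages")
    text_input = gr.Textbox(lines=1, label="Chat Message")
    text_input.submit(interact_with_agent, [text_input, stored_messages], [chatbot])

if __name__ == "__main__":
    demo.launch()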