Skip to content

Commit

Permalink
Change inits to allow easy imports
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Jul 22, 2024
1 parent 8396a8c commit ac0d9d4
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"ToolCollection",
"launch_gradio_demo",
"load_tool",
"stream_from_transformers_agent",
],
"audio_utils": [],
"benchmark": [],
Expand Down Expand Up @@ -4730,6 +4731,7 @@
ToolCollection,
launch_gradio_demo,
load_tool,
stream_from_transformers_agent,
)
from .configuration_utils import PretrainedConfig

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
"llm_engine": ["HfEngine"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
"monitoring": ["stream_from_transformers_agent"],
}

try:
Expand All @@ -46,6 +47,7 @@
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from .llm_engine import HfEngine
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
from .monitoring import stream_from_transformers_agent

try:
if not is_torch_available():
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/agents/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def pull_message(step_log: dict):
)


def stream_from_transformers_agent(agent: ReactAgent, prompt: str):
"""Runs an agent with the given prompt and streams the messages from the agent as ChatMessages."""
def stream_from_transformers_agent(agent: ReactAgent, task: str, **kwargs):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""

try:
from gradio import ChatMessage
Expand All @@ -57,7 +57,7 @@ def stream_from_transformers_agent(agent: ReactAgent, prompt: str):
class Output:
output: AgentType | str = None

for step_log in agent.run(prompt, stream=True):
for step_log in agent.run(task, stream=True, **kwargs):
if isinstance(step_log, dict):
for message in pull_message(step_log):
print("message", message)
Expand All @@ -77,4 +77,4 @@ class Output:
content={"path": Output.output.to_string(), "mime_type": "audio/wav"},
)
else:
return ChatMessage(role="assistant", content=Output.output)
yield ChatMessage(role="assistant", content=Output.output)

0 comments on commit ac0d9d4

Please sign in to comment.