Skip to content

Commit

Permalink
Fix gradio tool demos (#31230)
Browse files Browse the repository at this point in the history
* Fix gradio tool demos
  • Loading branch information
aymeric-roucher authored and itazap committed Jun 18, 2024
1 parent eddf052 commit 6631286
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 23 deletions.
29 changes: 24 additions & 5 deletions src/transformers/agents/agent_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class AgentImage(AgentType, ImageType):
"""

def __init__(self, value):
super().__init__(value)
AgentType.__init__(self, value)
ImageType.__init__(self)

if not is_vision_available():
raise ImportError("PIL must be installed in order to handle images.")
Expand All @@ -103,6 +104,8 @@ def __init__(self, value):
self._path = value
elif isinstance(value, torch.Tensor):
self._tensor = value
elif isinstance(value, np.ndarray):
self._tensor = torch.tensor(value)
else:
raise ValueError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")

Expand All @@ -125,6 +128,10 @@ def to_raw(self):
self._raw = Image.open(self._path)
return self._raw

if self._tensor is not None:
array = self._tensor.cpu().detach().numpy()
return Image.fromarray((255 - array * 255).astype(np.uint8))

def to_string(self):
"""
Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
Expand All @@ -137,14 +144,13 @@ def to_string(self):
directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
self._raw.save(self._path)

return self._path

if self._tensor is not None:
array = self._tensor.cpu().detach().numpy()

# There is likely a simpler way than loading into a PIL image only to save it again
img = Image.fromarray((array * 255).astype(np.uint8))
img = Image.fromarray((255 - array * 255).astype(np.uint8))

directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
Expand All @@ -153,8 +159,19 @@ def to_string(self):

return self._path

def save(self, output_bytes, format, **params):
    """
    Saves the image to a destination accepted by `PIL.Image.Image.save`.

    Args:
        output_bytes: The destination to save the image to. Despite the name, this is
            forwarded directly to `PIL.Image.Image.save`, which accepts a filename
            (str / path) or a writable file object — not raw bytes.
        format (str): The format to use for the output image. The format is the same
            as in `PIL.Image.save`.
        **params: Additional parameters to pass to `PIL.Image.save`.
    """
    # Materialize the PIL image from whichever backing store (_raw/_path/_tensor) is set.
    img = self.to_raw()
    img.save(output_bytes, format, **params)

class AgentAudio(AgentType):

class AgentAudio(AgentType, str):
"""
Audio type returned by the agent.
"""
Expand All @@ -169,11 +186,13 @@ def __init__(self, value, samplerate=16_000):
self._tensor = None

self.samplerate = samplerate

if isinstance(value, (str, pathlib.Path)):
self._path = value
elif isinstance(value, torch.Tensor):
self._tensor = value
elif isinstance(value, tuple):
self.samplerate = value[0]
self._tensor = torch.tensor(value[1])
else:
raise ValueError(f"Unsupported audio type: {type(value)}")

Expand Down
5 changes: 3 additions & 2 deletions src/transformers/agents/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
"""


DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You will be given a task to solve as best you can. To do so, you have been given access to the following tools: <<tool_names>>
DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using JSON tool calls. You will be given a task to solve as best you can.
To do so, you have been given access to the following tools: <<tool_names>>
The way you use the tools is by specifying a json blob, ending with '<end_action>'.
Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool).
Expand Down Expand Up @@ -261,7 +262,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
"""


DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You will be given a task to solve as best you can.
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
Expand Down
28 changes: 12 additions & 16 deletions src/transformers/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@
logger = logging.get_logger(__name__)


if is_vision_available():
import PIL.Image
import PIL.ImageOps

if is_torch_available():
import torch

Expand Down Expand Up @@ -623,20 +619,20 @@ def fn(*args, **kwargs):
return tool(*args, **kwargs)

gradio_inputs = []
for input_type in [tool_input["type"] for tool_input in tool_class.inputs.values()]:
if input_type in [str, int, float]:
gradio_inputs += "text"
elif is_vision_available() and input_type == PIL.Image.Image:
gradio_inputs += "image"
for input_name, input_details in tool_class.inputs.items():
input_type = input_details["type"]
if input_type == "text":
gradio_inputs.append(gr.Textbox(label=input_name))
elif input_type == "image":
gradio_inputs.append(gr.Image(label=input_name))
elif input_type == "audio":
gradio_inputs.append(gr.Audio(label=input_name))
else:
gradio_inputs += "audio"
error_message = f"Input type '{input_type}' not supported."
raise ValueError(error_message)

if tool_class.output_type in [str, int, float]:
gradio_output = "text"
elif is_vision_available() and tool_class.output_type == PIL.Image.Image:
gradio_output = "image"
else:
gradio_output = "audio"
gradio_output = tool_class.output_type
assert gradio_output in ["text", "image", "audio"], f"Output type '{gradio_output}' not supported."

gr.Interface(
fn=fn,
Expand Down

0 comments on commit 6631286

Please sign in to comment.