From 34430b934dbab3bc525f56b390dbc054f76cf56c Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 20 Jun 2024 19:25:42 -0400 Subject: [PATCH] Handle GIFs correct in `gr.Image` preprocessing (#8589) * handle gifs correct in image preprocessing * add changeset * fix * add test * add test * docstring * add docs * image * revert * change * add changeset --------- Co-authored-by: gradio-pr-bot --- .changeset/fluffy-crabs-sleep.md | 6 +++ gradio/components/image.py | 13 +++--- gradio/processing_utils.py | 43 +++++++----------- gradio/test_data/rectangles.gif | Bin 0 -> 1699 bytes .../templates/gradio/03_components/image.svx | 23 ++++++++++ test/test_processing_utils.py | 23 ++++------ 6 files changed, 61 insertions(+), 47 deletions(-) create mode 100644 .changeset/fluffy-crabs-sleep.md create mode 100644 gradio/test_data/rectangles.gif diff --git a/.changeset/fluffy-crabs-sleep.md b/.changeset/fluffy-crabs-sleep.md new file mode 100644 index 0000000000000..dc564413b3d59 --- /dev/null +++ b/.changeset/fluffy-crabs-sleep.md @@ -0,0 +1,6 @@ +--- +"gradio": patch +"website": patch +--- + +fix:Handle GIFs correct in `gr.Image` preprocessing diff --git a/gradio/components/image.py b/gradio/components/image.py index ab47c4f86b7ca..f0c3e898dcbab 100644 --- a/gradio/components/image.py +++ b/gradio/components/image.py @@ -71,12 +71,12 @@ def __init__( """ Parameters: value: A PIL Image, numpy array, path or URL for the default value that Image component is going to take. If callable, the function will be called whenever the app loads to set the initial value of the component. - format: Format to save image if it does not already have a valid format (e.g. if the image is being returned to the frontend as a numpy array or PIL Image). The format should be supported by the PIL library. This parameter has no effect on SVG files. + format: File format (e.g. "png" or "gif") to save image if it does not already have a valid format (e.g. if the image is being returned to the frontend as a numpy array or PIL Image). The format should be supported by the PIL library. This parameter has no effect on SVG files. height: The height of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed. width: The width of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed. - image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. + image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. This parameter has no effect on SVG or GIF files. sources: List of sources for the image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. If None, defaults to ["upload", "webcam", "clipboard"] if streaming is False, otherwise defaults to ["webcam"]. - type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. + type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. To support animated GIFs in input, the `type` should be set to "filepath" or "pil". label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. @@ -181,9 +181,10 @@ def preprocess( warnings.warn( f"Failed to transpose image {file_path} based on EXIF data." ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - im = im.convert(self.image_mode) + if suffix.lower() != "gif" and im is not None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + im = im.convert(self.image_mode) return image_utils.format_image( im, cast(Literal["numpy", "pil", "filepath"], self.type), diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 5b658842ee12c..70f958bbd778b 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -17,7 +17,7 @@ import httpx import numpy as np from gradio_client import utils as client_utils -from PIL import Image, ImageOps, PngImagePlugin +from PIL import Image, ImageOps, ImageSequence, PngImagePlugin from gradio import utils, wasm_utils from gradio.data_classes import FileData, GradioModel, GradioRootModel, JsonData @@ -138,7 +138,7 @@ def encode_plot_to_base64(plt, format: str = "png"): plt.savefig(output_bytes, format=fmt) bytes_data = output_bytes.getvalue() base64_str = str(base64.b64encode(bytes_data), "utf-8") - return output_base64(base64_str, fmt) + return f"data:image/{format or 'png'};base64,{base64_str}" def get_pil_exif_bytes(pil_image): @@ -158,34 +158,25 @@ def get_pil_metadata(pil_image): def encode_pil_to_bytes(pil_image, format="png"): with BytesIO() as output_bytes: - if format == "png": - params = {"pnginfo": get_pil_metadata(pil_image)} + if format.lower() == "gif": + frames = [frame.copy() for frame in ImageSequence.Iterator(pil_image)] + frames[0].save( + output_bytes, + format=format, + save_all=True, + append_images=frames[1:], + loop=0, + ) else: - exif = get_pil_exif_bytes(pil_image) - params = {"exif": exif} if exif else {} - pil_image.save(output_bytes, format, **params) + if format.lower() == "png": + params = {"pnginfo": get_pil_metadata(pil_image)} + else: + exif = get_pil_exif_bytes(pil_image) + params = {"exif": exif} if exif else {} + pil_image.save(output_bytes, format, **params) return output_bytes.getvalue() -def encode_pil_to_base64(pil_image, format="png"): - bytes_data = encode_pil_to_bytes(pil_image, format) - base64_str = str(base64.b64encode(bytes_data), "utf-8") - return output_base64(base64_str, format) - - -def encode_array_to_base64(image_array, format="png"): - with BytesIO() as output_bytes: - pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False)) - pil_image.save(output_bytes, format) - bytes_data = output_bytes.getvalue() - base64_str = str(base64.b64encode(bytes_data), "utf-8") - return output_base64(base64_str, format) - - -def output_base64(data, format=None) -> str: - return f"data:image/{format or 'png'};base64,{data}" - - def hash_file(file_path: str | Path, chunk_num_blocks: int = 128) -> str: sha1 = hashlib.sha1() with open(file_path, "rb") as f: diff --git a/gradio/test_data/rectangles.gif b/gradio/test_data/rectangles.gif new file mode 100644 index 0000000000000000000000000000000000000000..a6033b87ff21d603d29b0309c646b9de418f55ce GIT binary patch literal 1699 zcmZ?wbhEHboWQ8V_<;ck{sFP#PZkah5dBZkxhOTUBsE2$JhLQ2!QIn0fI;y;x1VcB zu(M-;tC5}oGb2dde{Rp*#Pn3(#PrPMY~J+5+}uQ413eQ{6H`+L9iRdRkns#m|6BT3 zo_@=}c+Qqv-J9?2`OV+<$Ya{G&SkGoZF|Rk{FBf1Z@tfdo%{ZepNFgc$fA!OK3dby zEc?va>$UvKs;^yVv$o$^_nn)s`}mVhKYRREUw^ahH*f#x=U;aH?K}Va`=5RP`6V{EXYRcD!y~I}>l>R}+dFsd-m`b#{^ga^v-69~ ztLvM$Zr{0k@BZ=0>)ZRs=hyeoU%&tS{rCU*4Hg_|WZ{;w=`dJ$sFg?Bs)oaG(UDFO z?WjE+hKrB&zL&AC`om$g^;&Ixb#+bj?mAxUZEtRF$-cd>*LwTAyL+mC|KqjU z@!{c-ZfW~Io1LGYo|$c3&u6>q%gZaPqxbjO?*8`n&hF~}e0F<&e0*|xwtc_d-d|te m+?LlSoREk}H4O~^K}l>h?Tpm411~=sKt np.ndarray | PIL.Image.Image | str | Path | None {/if} + +### `GIF` and `SVG` Image Formats + +The `gr.Image` component can process or display any image format that is [supported by the PIL library](https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html), including animated GIFs. In addition, it also supports the SVG image format. + +When the `gr.Image` component is used as an input component, the image is converted into a `str` filepath, a `PIL.Image` object, or a `numpy.array`, depending on the `type` parameter. However, animated GIF and SVG images are treated differently: + +* Animated `GIF` images can only be converted to `str` filepaths or `PIL.Image` objects. If they are converted to a `numpy.array` (which is the default behavior), only the first frame will be used. So if your demo expects an input `GIF` image, make sure to set the `type` parameter accordingly, e.g. + +```py +import gradio as gr + +demo = gr.Interface( + fn=lambda x:x, + inputs=gr.Image(type="filepath"), + outputs=gr.Image() +) + +demo.launch() +``` + +* For `SVG` images, the `type` parameter is ignored altogether and the image is always returned as an image filepath. This is because `SVG` images cannot be processed as `PIL.Image` or `numpy.array` objects. + {#if obj.demos && obj.demos.length > 0} ### Demos diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index 2369fd9a77b5f..730b62bd77948 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -1,7 +1,6 @@ import os import shutil import tempfile -from copy import deepcopy from pathlib import Path from unittest.mock import patch @@ -114,20 +113,6 @@ def test_encode_plot_to_base64(self): "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAo" ) - def test_encode_array_to_base64(self): - img = Image.open("gradio/test_data/test_image.png") - img = img.convert("RGB") - numpy_data = np.asarray(img, dtype=np.uint8) - output_base64 = processing_utils.encode_array_to_base64(numpy_data) - assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE) - - def test_encode_pil_to_base64(self): - img = Image.open("gradio/test_data/test_image.png") - img = img.convert("RGB") - img.info = {} # Strip metadata - output_base64 = processing_utils.encode_pil_to_base64(img) - assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE) - def test_save_pil_to_file_keeps_pnginfo(self, gradio_temp_dir): input_img = Image.open("gradio/test_data/test_image.png") input_img = input_img.convert("RGB") @@ -141,6 +126,14 @@ def test_save_pil_to_file_keeps_pnginfo(self, gradio_temp_dir): assert output_img.info == input_img.info + def test_save_pil_to_file_keeps_all_gif_frames(self, gradio_temp_dir): + input_img = Image.open("gradio/test_data/rectangles.gif") + file_obj = processing_utils.save_pil_to_cache( + input_img, cache_dir=gradio_temp_dir, format="gif" + ) + output_img = Image.open(file_obj) + assert output_img.n_frames == input_img.n_frames == 3 + def test_np_pil_encode_to_the_same(self, gradio_temp_dir): arr = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.uint8) pil = Image.fromarray(arr)