Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix output directory of files in client & when calling Blocks as function #4501

Merged
merged 14 commits into from
Jun 14, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ No changes to highlight.
- Fixes bug where `/proxy` route was being incorrectly constructed by the frontend by [@abidlabs](https://github.com/abidlabs) in [PR 4430](https://github.com/gradio-app/gradio/pull/4430).
- Fix z-index of status component by [@hannahblair](https://github.com/hannahblair) in [PR 4429](https://github.com/gradio-app/gradio/pull/4429)
- Fix video rendering in Safari by [@aliabid94](https://github.com/aliabid94) in [PR 4433](https://github.com/gradio-app/gradio/pull/4433).
- The output directory for files downloaded when calling Blocks as a function is now set to a temporary directory by default (instead of the working directory in some cases) by [@abidlabs](https://github.com/abidlabs) in [PR 4501](https://github.com/gradio-app/gradio/pull/4501)


## Other Changes:

Expand Down
4 changes: 2 additions & 2 deletions client/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

## New Features:

No changes to highlight.
- The output directory for files downloaded via the Client can now be set by the `output_dir` parameter in `Client` by [@abidlabs](https://github.com/abidlabs) in [PR 4501](https://github.com/gradio-app/gradio/pull/4501)

## Bug Fixes:

No changes to highlight.
- The output directory for files downloaded via the Client are now set to a temporary directory by default (instead of the working directory in some cases) by [@abidlabs](https://github.com/abidlabs) in [PR 4501](https://github.com/gradio-app/gradio/pull/4501)

## Breaking Changes:

Expand Down
17 changes: 16 additions & 1 deletion client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import concurrent.futures
import json
import os
import re
import tempfile
import threading
import time
import urllib.parse
Expand Down Expand Up @@ -40,6 +42,11 @@
set_documentation_group("py-client")


DEFAULT_TEMP_DIR = os.environ.get("GRADIO_TEMP_DIR") or str(
Path(tempfile.gettempdir()) / "gradio"
)


@document("predict", "submit", "view_api", "duplicate")
class Client:
"""
Expand All @@ -64,6 +71,7 @@ def __init__(
hf_token: str | None = None,
max_workers: int = 40,
serialize: bool = True,
output_dir: str | Path | None = DEFAULT_TEMP_DIR,
verbose: bool = True,
):
"""
Expand All @@ -72,6 +80,7 @@ def __init__(
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
serialize: Whether the client should serialize the inputs and deserialize the outputs of the remote API. If set to False, the client will pass the inputs and outputs as-is, without serializing/deserializing them. E.g. you if you set this to False, you'd submit an image in base64 format instead of a filepath, and you'd get back an image in base64 format from the remote API instead of a filepath.
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
verbose: Whether the client should print statements to the console.
"""
self.verbose = verbose
Expand All @@ -83,6 +92,7 @@ def __init__(
library_version=utils.__version__,
)
self.space_id = None
self.output_dir = output_dir

if src.startswith("http://") or src.startswith("https://"):
_src = src if src.endswith("/") else src + "/"
Expand Down Expand Up @@ -796,7 +806,12 @@ def deserialize(self, *data) -> tuple:
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
outputs = tuple(
[
s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
s.deserialize(
d,
save_dir=self.client.output_dir,
hf_token=self.client.hf_token,
root_url=self.root_url,
)
for s, d in zip(self.deserializers, data)
]
)
Expand Down
10 changes: 6 additions & 4 deletions client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,8 @@ def decode_base64_to_file(
dir: str | Path | None = None,
prefix: str | None = None,
):
if dir is not None:
os.makedirs(dir, exist_ok=True)
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
directory.mkdir(exist_ok=True, parents=True)
data, extension = decode_base64_to_binary(encoding)
if file_path is not None and prefix is None:
filename = Path(file_path).name
Expand All @@ -397,13 +397,15 @@ def decode_base64_to_file(
prefix = strip_invalid_filename_characters(prefix)

if extension is None:
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir)
file_obj = tempfile.NamedTemporaryFile(
delete=False, prefix=prefix, dir=directory
)
else:
file_obj = tempfile.NamedTemporaryFile(
delete=False,
prefix=prefix,
suffix="." + extension,
dir=dir,
dir=directory,
)
file_obj.write(data)
file_obj.flush()
Expand Down
14 changes: 13 additions & 1 deletion client/python/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from concurrent.futures import CancelledError, TimeoutError
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch

import gradio as gr
Expand All @@ -17,6 +18,7 @@
from huggingface_hub.utils import RepositoryNotFoundError

from gradio_client import Client
from gradio_client.client import DEFAULT_TEMP_DIR
from gradio_client.serializing import Serializable
from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate

Expand Down Expand Up @@ -172,7 +174,17 @@ def test_job_output_video(self):
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4",
fn_index=0,
)
assert pathlib.Path(job.result()).exists()
assert Path(job.result()).exists()
assert Path(DEFAULT_TEMP_DIR).resolve() in Path(job.result()).resolve().parents

temp_dir = tempfile.mkdtemp()
client = Client(src="gradio/video_component", output_dir=temp_dir)
job = client.submit(
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4",
fn_index=0,
)
assert Path(job.result()).exists()
assert Path(temp_dir).resolve() in Path(job.result()).resolve().parents

def test_progress_updates(self, progress_demo):
with connect(progress_demo) as client:
Expand Down
11 changes: 10 additions & 1 deletion gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import random
import secrets
import sys
import tempfile
import time
import warnings
import webbrowser
from abc import abstractmethod
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable

Expand Down Expand Up @@ -74,6 +76,10 @@
]
}

DEFAULT_TEMP_DIR = os.environ.get("GRADIO_TEMP_DIR") or str(
Path(tempfile.gettempdir()) / "gradio"
)


class Block:
def __init__(
Expand Down Expand Up @@ -1135,7 +1141,10 @@ def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]:
block, components.IOComponent
), f"{block.__class__} Component with id {output_id} not a valid output component."
deserialized = block.deserialize(
outputs[o], root_url=block.root_url, hf_token=Context.hf_token
outputs[o],
save_dir=DEFAULT_TEMP_DIR,
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
root_url=block.root_url,
hf_token=Context.hf_token,
)
predictions.append(deserialized)

Expand Down
6 changes: 3 additions & 3 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@
Submittable,
Uploadable,
)
from gradio.exceptions import Error
from gradio.interpretation import NeighborInterpretable, TokenInterpretable
from gradio.layouts import Column, Form, Row
from gradio.exceptions import Error

if TYPE_CHECKING:
from typing import TypedDict
Expand Down Expand Up @@ -797,9 +797,9 @@ def preprocess(self, x: float | None) -> float | None:
"""
if x is None:
return None
elif self.minimum != None and x < self.minimum:
elif self.minimum is not None and x < self.minimum:
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
raise Error(f"Value {x} is less than minimum value {self.minimum}.")
elif self.maximum != None and x > self.maximum:
elif self.maximum is not None and x > self.maximum:
raise Error(f"Value {x} is greater than maximum value {self.maximum}.")
return self._round_to_precision(x, self.precision)

Expand Down
11 changes: 7 additions & 4 deletions test/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import warnings
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from string import capwords
from unittest.mock import patch

Expand All @@ -25,6 +26,7 @@
from PIL import Image

import gradio as gr
from gradio.blocks import DEFAULT_TEMP_DIR
from gradio.events import SelectData
from gradio.exceptions import DuplicateBlockError
from gradio.networking import Server, get_first_available_port
Expand Down Expand Up @@ -148,7 +150,7 @@ def update(name):

inp.submit(fn=update, inputs=inp, outputs=out, api_name="greet")

gr.Image().style(height=54, width=240)
gr.Image(height=54, width=240)
abidlabs marked this conversation as resolved.
Show resolved Hide resolved

config1 = demo1.get_config_file()
demo2 = gr.Blocks.from_config(config1, [update], "https://fake.hf.space")
Expand Down Expand Up @@ -485,13 +487,14 @@ def create_images(n_images):
demo = gr.Interface(
create_images,
inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)],
outputs=[gr.Gallery().style(grid=2, preview=True)],
outputs=[gr.Gallery(columns=2, preview=True)],
)
with connect(demo) as client:
path = client.predict(3)
_ = client.predict(3)
_ = client.predict(3)
# only three files created
# only three files created and in correct directory
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
assert Path(DEFAULT_TEMP_DIR).resolve() in Path(path).resolve().parents

def test_no_empty_image_files(self, tmp_path, connect, monkeypatch):
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
Expand Down