From 4ab399f40a300f267231f1b2dbe2a07494322d4d Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 14 Feb 2024 09:43:41 -0800 Subject: [PATCH] Allow config to include non-pickle-able values (#7415) * fixes * lint * add changeset * route utils --------- Co-authored-by: gradio-pr-bot --- .changeset/mean-bushes-hide.md | 5 +++++ gradio/processing_utils.py | 4 +++- gradio/route_utils.py | 15 ++++++++++++- gradio/routes.py | 23 +++++++++----------- test/test_processing_utils.py | 39 ++++++++++++++++++++++++++++++++++ test/test_routes.py | 12 +++++++++++ 6 files changed, 83 insertions(+), 15 deletions(-) create mode 100644 .changeset/mean-bushes-hide.md diff --git a/.changeset/mean-bushes-hide.md b/.changeset/mean-bushes-hide.md new file mode 100644 index 0000000000000..c9e6e2d242372 --- /dev/null +++ b/.changeset/mean-bushes-hide.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Allow config to include non-pickle-able values diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index fb1794f2e7763..2d86d1d8216f6 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -289,9 +289,11 @@ def _move_to_cache(d: dict): return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj) -def add_root_url(data, root_url) -> dict: +def add_root_url(data: dict, root_url: str, previous_root_url: str | None) -> dict: def _add_root_url(file_dict: dict): if not client_utils.is_http_url_like(file_dict["url"]): + if previous_root_url and file_dict["url"].startswith(previous_root_url): + file_dict["url"] = file_dict["url"][len(previous_root_url) :] file_dict["url"] = f'{root_url}{file_dict["url"]}' return file_dict diff --git a/gradio/route_utils.py b/gradio/route_utils.py index bceda289f0a8a..097bd26c40283 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -16,7 +16,7 @@ from starlette.datastructures import FormData, Headers, UploadFile from starlette.formparsers import MultiPartException, MultipartPart -from gradio import utils +from gradio import processing_utils, utils from gradio.data_classes import PredictBody from gradio.exceptions import Error from gradio.helpers import EventData @@ -561,3 +561,16 @@ async def parse(self) -> FormData: def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> None: for file, dest in zip(files, destinations): shutil.move(file, dest) + + +def update_root_in_config(config: dict, root: str) -> dict: + """ + Updates the root "key" in the config dictionary to the new root url. If the + root url has changed, all of the urls in the config that correspond to component + file urls are updated to use the new root url. + """ + previous_root = config.get("root", None) + if previous_root is None or previous_root != root: + config["root"] = root + config = processing_utils.add_root_url(config, root, previous_root) + return config diff --git a/gradio/routes.py b/gradio/routes.py index 240aeb45117e1..ef8d2eabeff31 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -5,7 +5,6 @@ import asyncio import contextlib -import copy import sys if sys.version_info >= (3, 9): @@ -311,19 +310,18 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()): def main(request: fastapi.Request, user: str = Depends(get_current_user)): mimetypes.add_type("application/javascript", ".js") blocks = app.get_blocks() - root_path = route_utils.get_root_url( + root = route_utils.get_root_url( request=request, route_path="/", root_path=app.root_path ) if app.auth is None or user is not None: - config = copy.deepcopy(app.get_blocks().config) - config["root"] = root_path - config = add_root_url(config, root_path) + config = app.get_blocks().config + config = route_utils.update_root_in_config(config, root) else: config = { "auth_required": True, "auth_message": blocks.auth_message, "space_id": app.get_blocks().space_id, - "root": root_path, + "root": root, } try: @@ -354,13 +352,12 @@ def api_info(): @app.get("/config/", dependencies=[Depends(login_check)]) @app.get("/config", dependencies=[Depends(login_check)]) def get_config(request: fastapi.Request): - config = copy.deepcopy(app.get_blocks().config) - root_path = route_utils.get_root_url( + config = app.get_blocks().config + root = route_utils.get_root_url( request=request, route_path="/config", root_path=app.root_path ) - config["root"] = root_path - config = add_root_url(config, root_path) - return config + config = route_utils.update_root_in_config(config, root) + return ORJSONResponse(content=config) @app.get("/static/{path:path}") def static_resource(path: str): @@ -577,7 +574,7 @@ async def predict( root_path = route_utils.get_root_url( request=request, route_path=f"/api/{api_name}", root_path=app.root_path ) - output = add_root_url(output, root_path) + output = add_root_url(output, root_path, None) return output @app.get("/queue/data", dependencies=[Depends(login_check)]) @@ -634,7 +631,7 @@ async def sse_stream(request: fastapi.Request): "success": False, } if message: - add_root_url(message, root_path) + add_root_url(message, root_path, None) yield f"data: {json.dumps(message)}\n\n" if message["msg"] == ServerMessage.process_completed: blocks._queue.pending_event_ids_session[ diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index d5743da42c8f2..539b812f13e43 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -332,3 +332,42 @@ def test_video_conversion_returns_original_video_if_fails( ) # If the conversion succeeded it'd be .mp4 assert Path(playable_vid).suffix == ".avi" + + +def test_add_root_url(): + data = { + "file": { + "path": "path", + "url": "/file=path", + }, + "file2": { + "path": "path2", + "url": "https://www.gradio.app", + }, + } + root_url = "http://localhost:7860" + expected = { + "file": { + "path": "path", + "url": f"{root_url}/file=path", + }, + "file2": { + "path": "path2", + "url": "https://www.gradio.app", + }, + } + assert processing_utils.add_root_url(data, root_url, None) == expected + new_root_url = "https://1234.gradio.live" + new_expected = { + "file": { + "path": "path", + "url": f"{root_url}/file=path", + }, + "file2": { + "path": "path2", + "url": "https://www.gradio.app", + }, + } + assert ( + processing_utils.add_root_url(expected, root_url, new_root_url) == new_expected + ) diff --git a/test/test_routes.py b/test/test_routes.py index 84a3e1a5c907b..f54da98c6a5f9 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -439,6 +439,18 @@ def test_proxy_does_not_leak_hf_token_externally(self): r = app.build_proxy_request("https://google.com") assert "authorization" not in dict(r.headers) + def test_can_get_config_that_includes_non_pickle_able_objects(self): + my_dict = {"a": 1, "b": 2, "c": 3} + with Blocks() as demo: + gr.JSON(my_dict.keys()) + + app, _, _ = demo.launch(prevent_thread_lock=True) + client = TestClient(app) + response = client.get("/") + assert response.is_success + response = client.get("/config/") + assert response.is_success + class TestApp: def test_create_app(self):