Skip to content

Commit

Permalink
Allow config to include non-pickle-able values (#7415)
Browse files Browse the repository at this point in the history
* fixes

* lint

* add changeset

* route utils

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Feb 14, 2024
1 parent c2dfc59 commit 4ab399f
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 15 deletions.
5 changes: 5 additions & 0 deletions .changeset/mean-bushes-hide.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Allow config to include non-pickle-able values
4 changes: 3 additions & 1 deletion gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,11 @@ def _move_to_cache(d: dict):
return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj)


def add_root_url(data, root_url) -> dict:
def add_root_url(data: dict, root_url: str, previous_root_url: str | None) -> dict:
def _add_root_url(file_dict: dict):
if not client_utils.is_http_url_like(file_dict["url"]):
if previous_root_url and file_dict["url"].startswith(previous_root_url):
file_dict["url"] = file_dict["url"][len(previous_root_url) :]
file_dict["url"] = f'{root_url}{file_dict["url"]}'
return file_dict

Expand Down
15 changes: 14 additions & 1 deletion gradio/route_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from starlette.datastructures import FormData, Headers, UploadFile
from starlette.formparsers import MultiPartException, MultipartPart

from gradio import utils
from gradio import processing_utils, utils
from gradio.data_classes import PredictBody
from gradio.exceptions import Error
from gradio.helpers import EventData
Expand Down Expand Up @@ -561,3 +561,16 @@ async def parse(self) -> FormData:
def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> None:
for file, dest in zip(files, destinations):
shutil.move(file, dest)


def update_root_in_config(config: dict, root: str) -> dict:
"""
Updates the root "key" in the config dictionary to the new root url. If the
root url has changed, all of the urls in the config that correspond to component
file urls are updated to use the new root url.
"""
previous_root = config.get("root", None)
if previous_root is None or previous_root != root:
config["root"] = root
config = processing_utils.add_root_url(config, root, previous_root)
return config
23 changes: 10 additions & 13 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import asyncio
import contextlib
import copy
import sys

if sys.version_info >= (3, 9):
Expand Down Expand Up @@ -311,19 +310,18 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
mimetypes.add_type("application/javascript", ".js")
blocks = app.get_blocks()
root_path = route_utils.get_root_url(
root = route_utils.get_root_url(
request=request, route_path="/", root_path=app.root_path
)
if app.auth is None or user is not None:
config = copy.deepcopy(app.get_blocks().config)
config["root"] = root_path
config = add_root_url(config, root_path)
config = app.get_blocks().config
config = route_utils.update_root_in_config(config, root)
else:
config = {
"auth_required": True,
"auth_message": blocks.auth_message,
"space_id": app.get_blocks().space_id,
"root": root_path,
"root": root,
}

try:
Expand Down Expand Up @@ -354,13 +352,12 @@ def api_info():
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config(request: fastapi.Request):
config = copy.deepcopy(app.get_blocks().config)
root_path = route_utils.get_root_url(
config = app.get_blocks().config
root = route_utils.get_root_url(
request=request, route_path="/config", root_path=app.root_path
)
config["root"] = root_path
config = add_root_url(config, root_path)
return config
config = route_utils.update_root_in_config(config, root)
return ORJSONResponse(content=config)

@app.get("/static/{path:path}")
def static_resource(path: str):
Expand Down Expand Up @@ -577,7 +574,7 @@ async def predict(
root_path = route_utils.get_root_url(
request=request, route_path=f"/api/{api_name}", root_path=app.root_path
)
output = add_root_url(output, root_path)
output = add_root_url(output, root_path, None)
return output

@app.get("/queue/data", dependencies=[Depends(login_check)])
Expand Down Expand Up @@ -634,7 +631,7 @@ async def sse_stream(request: fastapi.Request):
"success": False,
}
if message:
add_root_url(message, root_path)
add_root_url(message, root_path, None)
yield f"data: {json.dumps(message)}\n\n"
if message["msg"] == ServerMessage.process_completed:
blocks._queue.pending_event_ids_session[
Expand Down
39 changes: 39 additions & 0 deletions test/test_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,42 @@ def test_video_conversion_returns_original_video_if_fails(
)
# If the conversion succeeded it'd be .mp4
assert Path(playable_vid).suffix == ".avi"


def test_add_root_url():
data = {
"file": {
"path": "path",
"url": "/file=path",
},
"file2": {
"path": "path2",
"url": "https://www.gradio.app",
},
}
root_url = "http://localhost:7860"
expected = {
"file": {
"path": "path",
"url": f"{root_url}/file=path",
},
"file2": {
"path": "path2",
"url": "https://www.gradio.app",
},
}
assert processing_utils.add_root_url(data, root_url, None) == expected
new_root_url = "https://1234.gradio.live"
new_expected = {
"file": {
"path": "path",
"url": f"{root_url}/file=path",
},
"file2": {
"path": "path2",
"url": "https://www.gradio.app",
},
}
assert (
processing_utils.add_root_url(expected, root_url, new_root_url) == new_expected
)
12 changes: 12 additions & 0 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,18 @@ def test_proxy_does_not_leak_hf_token_externally(self):
r = app.build_proxy_request("https://google.com")
assert "authorization" not in dict(r.headers)

def test_can_get_config_that_includes_non_pickle_able_objects(self):
my_dict = {"a": 1, "b": 2, "c": 3}
with Blocks() as demo:
gr.JSON(my_dict.keys())

app, _, _ = demo.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.get("/")
assert response.is_success
response = client.get("/config/")
assert response.is_success


class TestApp:
def test_create_app(self):
Expand Down

0 comments on commit 4ab399f

Please sign in to comment.