From d8e97911a14868745e892efc37eca4a94fdc29d6 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Mon, 4 Nov 2024 11:20:35 -0500 Subject: [PATCH 1/5] add code --- gradio/processing_utils.py | 6 ++---- test/test_routes.py | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 332bb0b8d2143..88c5c402ca7ba 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -513,9 +513,7 @@ def _move_to_cache(d: dict): if isinstance(data, (GradioRootModel, GradioModel)): data = data.model_dump() - return client_utils.traverse( - data, _move_to_cache, client_utils.is_file_obj_with_meta - ) + return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj) def _check_allowed(path: str | Path, check_in_upload_folder: bool): @@ -635,7 +633,7 @@ async def _move_to_cache(d: dict): if isinstance(data, (GradioRootModel, GradioModel)): data = data.model_dump() return await client_utils.async_traverse( - data, _move_to_cache, client_utils.is_file_obj_with_meta + data, _move_to_cache, client_utils.is_file_obj ) diff --git a/test/test_routes.py b/test/test_routes.py index d085405e193ec..0273380f6cb0d 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -1613,3 +1613,30 @@ def victim(url, results): t.join() assert not any(results), "attacker was able to modify a victim's config root url" + + +def test_file_without_meta_key_not_moved(): + demo = gr.Interface( + fn=lambda s: str(s), inputs=gr.File(type="binary"), outputs="textbox" + ) + + app, _, _ = demo.launch(prevent_thread_lock=True) + test_client = TestClient(app) + try: + with test_client: + req = test_client.post( + "gradio_api/run/predict", + json={ + "data": [ + { + "path": "test/test_files/alphabet.txt", + "orig_name": "test.txt", + "size": 4, + "mime_type": "text/plain", + } + ] + }, + ) + assert req.status_code == 500 + finally: + demo.close() From 2e9b9c2ab65279fcfe5d19c0044152f06207d99a Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Mon, 4 Nov 2024 13:39:40 -0500 Subject: [PATCH 2/5] add code --- gradio/blocks.py | 17 +++++++---------- gradio/data_classes.py | 23 +++++++++++++++++++---- gradio/processing_utils.py | 6 ++++-- test/test_routes.py | 1 + 4 files changed, 31 insertions(+), 16 deletions(-) diff --git a/gradio/blocks.py b/gradio/blocks.py index e464037eb036e..9b36eb4e39076 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -18,12 +18,7 @@ from collections.abc import AsyncIterator, Callable, Coroutine, Sequence, Set from pathlib import Path from types import ModuleType -from typing import ( - TYPE_CHECKING, - Any, - Literal, - cast, -) +from typing import TYPE_CHECKING, Any, Literal, Union, cast from urllib.parse import urlparse, urlunparse import anyio @@ -1702,10 +1697,12 @@ async def preprocess_data( check_in_upload_folder=not explicit_call, ) if getattr(block, "data_model", None) and inputs_cached is not None: - if issubclass(block.data_model, GradioModel): # type: ignore - inputs_cached = block.data_model(**inputs_cached) # type: ignore - elif issubclass(block.data_model, GradioRootModel): # type: ignore - inputs_cached = block.data_model(root=inputs_cached) # type: ignore + data_model = cast( + Union[GradioModel, GradioRootModel], block.data_model + ) + inputs_cached = data_model.model_validate( + inputs_cached, context={"validate_meta": True} + ) processed_input.append(block.preprocess(inputs_cached)) else: processed_input = inputs diff --git a/gradio/data_classes.py b/gradio/data_classes.py index 268509e857a66..ade4f24d7d1e6 100644 --- a/gradio/data_classes.py +++ b/gradio/data_classes.py @@ -28,6 +28,8 @@ GetJsonSchemaHandler, RootModel, ValidationError, + ValidationInfo, + model_validator, ) from pydantic.json_schema import JsonSchemaValue from pydantic_core import core_schema @@ -219,14 +221,27 @@ class FileData(GradioModel): meta: Additional metadata used internally (should not be changed). """ - path: str # server filepath - url: Optional[str] = None # normalised server url - size: Optional[int] = None # size in bytes - orig_name: Optional[str] = None # original filename + path: str + url: Optional[str] = None + size: Optional[int] = None + orig_name: Optional[str] = None mime_type: Optional[str] = None is_stream: bool = False meta: dict = {"_type": "gradio.FileData"} + @model_validator(mode="before") + @classmethod + def validate_python(cls, v, info: ValidationInfo): + if ( + info.context + and info.context.get("validate_meta") + and v.get("meta", {}) != {"_type": "gradio.FileData"} + ): + raise ValueError( + "The 'meta' field must be explicitly provided in the input data and be equal to {'_type': 'gradio.FileData'}." + ) + return v + @property def is_none(self) -> bool: """ diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 88c5c402ca7ba..332bb0b8d2143 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -513,7 +513,9 @@ def _move_to_cache(d: dict): if isinstance(data, (GradioRootModel, GradioModel)): data = data.model_dump() - return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj) + return client_utils.traverse( + data, _move_to_cache, client_utils.is_file_obj_with_meta + ) def _check_allowed(path: str | Path, check_in_upload_folder: bool): @@ -633,7 +635,7 @@ async def _move_to_cache(d: dict): if isinstance(data, (GradioRootModel, GradioModel)): data = data.model_dump() return await client_utils.async_traverse( - data, _move_to_cache, client_utils.is_file_obj + data, _move_to_cache, client_utils.is_file_obj_with_meta ) diff --git a/test/test_routes.py b/test/test_routes.py index 0273380f6cb0d..65c5985d76e1f 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -339,6 +339,7 @@ def test_get_file_created_by_app(self, test_client): { "path": file_response.json()[0], "size": os.path.getsize("test/test_files/alphabet.txt"), + "meta": {"_type": "gradio.FileData"}, } ], "fn_index": 0, From 824f7ba9b86e84698bbb28b1ef13ab0c672539ff Mon Sep 17 00:00:00 2001 From: gradio-pr-bot Date: Mon, 4 Nov 2024 18:44:55 +0000 Subject: [PATCH 3/5] add changeset --- .changeset/four-swans-throw.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/four-swans-throw.md diff --git a/.changeset/four-swans-throw.md b/.changeset/four-swans-throw.md new file mode 100644 index 0000000000000..d31cb76317992 --- /dev/null +++ b/.changeset/four-swans-throw.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Enforce `meta` key present during preprocess in FileData payloads From 190e5efb1a15796c93d38deba01873d0c663c995 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Mon, 4 Nov 2024 13:46:53 -0500 Subject: [PATCH 4/5] revert --- gradio/data_classes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gradio/data_classes.py b/gradio/data_classes.py index ade4f24d7d1e6..8569eeff1814c 100644 --- a/gradio/data_classes.py +++ b/gradio/data_classes.py @@ -221,17 +221,17 @@ class FileData(GradioModel): meta: Additional metadata used internally (should not be changed). """ - path: str - url: Optional[str] = None - size: Optional[int] = None - orig_name: Optional[str] = None + path: str # server filepath + url: Optional[str] = None # normalised server url + size: Optional[int] = None # size in bytes + orig_name: Optional[str] = None # original filename mime_type: Optional[str] = None is_stream: bool = False meta: dict = {"_type": "gradio.FileData"} @model_validator(mode="before") @classmethod - def validate_python(cls, v, info: ValidationInfo): + def validate_model(cls, v, info: ValidationInfo): if ( info.context and info.context.get("validate_meta") From fcf1bcb5b2224f094178ea49c73b4114f70a1a9e Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Mon, 4 Nov 2024 14:42:42 -0500 Subject: [PATCH 5/5] use is_file_obj_with_meta --- gradio/data_classes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gradio/data_classes.py b/gradio/data_classes.py index 8569eeff1814c..15c3f3b264753 100644 --- a/gradio/data_classes.py +++ b/gradio/data_classes.py @@ -21,7 +21,7 @@ from fastapi import Request from gradio_client.documentation import document -from gradio_client.utils import traverse +from gradio_client.utils import is_file_obj_with_meta, traverse from pydantic import ( BaseModel, GetCoreSchemaHandler, @@ -235,7 +235,7 @@ def validate_model(cls, v, info: ValidationInfo): if ( info.context and info.context.get("validate_meta") - and v.get("meta", {}) != {"_type": "gradio.FileData"} + and not is_file_obj_with_meta(v) ): raise ValueError( "The 'meta' field must be explicitly provided in the input data and be equal to {'_type': 'gradio.FileData'}."