Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent components from working with non-uploaded files #7465

Merged
merged 12 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/moody-impalas-rule.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
---

feat:Prevent components from working with non-uploaded files
14 changes: 12 additions & 2 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,7 @@ def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
inputs=processed_inputs,
request=None,
state={},
explicit_call=True,
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
)
outputs = outputs["data"]

Expand Down Expand Up @@ -1298,7 +1299,11 @@ def validate_inputs(self, fn_index: int, inputs: list[Any]):
)

def preprocess_data(
self, fn_index: int, inputs: list[Any], state: SessionState | None
self,
fn_index: int,
inputs: list[Any],
state: SessionState | None,
explicit_call: bool = False,
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
):
state = state or SessionState(self)
block_fn = self.fns[fn_index]
Expand All @@ -1325,7 +1330,10 @@ def preprocess_data(
if input_id in state:
block = state[input_id]
inputs_cached = processing_utils.move_files_to_cache(
inputs[i], block, add_urls=True
inputs[i],
block,
add_urls=True,
check_in_upload_folder=not explicit_call,
)
if getattr(block, "data_model", None) and inputs_cached is not None:
if issubclass(block.data_model, GradioModel): # type: ignore
Expand Down Expand Up @@ -1535,6 +1543,7 @@ async def process_api(
event_id: str | None = None,
event_data: EventData | None = None,
in_event_listener: bool = True,
explicit_call: bool = False,
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
) -> dict[str, Any]:
"""
Processes API calls from the frontend. First preprocesses the data,
Expand Down Expand Up @@ -1574,6 +1583,7 @@ async def process_api(
inputs,
fn_index,
state,
explicit_call,
limiter=self.limiter,
)
result = await self.call_function(
Expand Down
8 changes: 7 additions & 1 deletion gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from gradio import wasm_utils
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.utils import abspath
from gradio.utils import abspath, get_upload_folder

with warnings.catch_warnings():
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
Expand Down Expand Up @@ -241,6 +241,7 @@ def move_files_to_cache(
block: Component,
postprocess: bool = False,
add_urls=False,
check_in_upload_folder=False,
) -> dict:
"""Move any files in `data` to cache and (optionally), adds URL prefixes (/file=...) needed to access the cached file.
Also handles the case where the file is on an external Gradio app (/proxy=...).
Expand All @@ -252,6 +253,8 @@ def move_files_to_cache(
block: The component whose data is being processed
postprocess: Whether its running from postprocessing
root_url: The root URL of the local server, if applicable
add_urls: Whether to add URLs to the payload
check_in_upload_folder: If True, instead of moving the file to cache, checks if the file is in upload folder.
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
"""

def _move_to_cache(d: dict):
Expand All @@ -264,6 +267,9 @@ def _move_to_cache(d: dict):
payload.path = payload.url
elif not block.proxy_url:
# If the file is on a remote server, do not move it to cache.
if check_in_upload_folder:
path = os.path.abspath(payload.path)
assert path.startswith(get_upload_folder()), "File not in upload folder"
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
temp_file_path = move_resource_to_block_cache(payload.path, block)
assert temp_file_path is not None
payload.path = temp_file_path
Expand Down
9 changes: 2 additions & 7 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import os
import posixpath
import secrets
import tempfile
import threading
import time
import traceback
Expand Down Expand Up @@ -67,9 +66,7 @@
move_uploaded_files_to_cache,
)
from gradio.state_holder import StateHolder
from gradio.utils import (
get_package_version,
)
from gradio.utils import get_package_version, get_upload_folder

if TYPE_CHECKING:
from gradio.blocks import Block
Expand Down Expand Up @@ -136,9 +133,7 @@ def __init__(self, **kwargs):
self.cookie_id = secrets.token_urlsafe(32)
self.queue_token = secrets.token_urlsafe(32)
self.startup_events_triggered = False
self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
(Path(tempfile.gettempdir()) / "gradio").resolve()
)
self.uploaded_file_dir = get_upload_folder()
self.change_event: None | threading.Event = None
self._asyncio_tasks: list[asyncio.Task] = []
# Allow user to manually set `docs_url` and `redoc_url`
Expand Down
7 changes: 7 additions & 0 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os
import pkgutil
import re
import tempfile
import threading
import time
import traceback
Expand Down Expand Up @@ -1082,3 +1083,9 @@ def compare_objects(obj1, obj2, path=None):
return edits

return compare_objects(old, new)


def get_upload_folder():
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
return os.environ.get("GRADIO_TEMP_DIR") or str(
(Path(tempfile.gettempdir()) / "gradio").resolve()
)