Skip to content

Commit

Permalink
Make fix in #7444 (Block /file= filepaths that could expose credentia…
Browse files Browse the repository at this point in the history
…ls on Windows) more general (#7453)

* test routes

* chagne

* add changeset

* add changeset

* type fixes

* fix typing issues

* typed dict

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Feb 16, 2024
1 parent f52cab6 commit ba747ad
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 10 deletions.
5 changes: 5 additions & 0 deletions .changeset/bumpy-wasps-march.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

feat:Make fix in #7444 (Block /file= filepaths that could expose credentials on Windows) more general
19 changes: 13 additions & 6 deletions gradio/route_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hashlib
import hmac
import json
import re
import shutil
from collections import deque
from dataclasses import dataclass as python_dataclass
Expand Down Expand Up @@ -466,12 +467,10 @@ def on_header_end(self) -> None:
self._current_partial_header_value = b""

def on_headers_finished(self) -> None:
disposition, options = parse_options_header(
self._current_part.content_disposition
)
_, options = parse_options_header(self._current_part.content_disposition or b"")
try:
self._current_part.field_name = _user_safe_decode(
options[b"name"], self._charset
options[b"name"], str(self._charset)
)
except KeyError as e:
raise MultiPartException(
Expand All @@ -483,7 +482,7 @@ def on_headers_finished(self) -> None:
raise MultiPartException(
f"Too many files. Maximum number of files is {self.max_files}."
)
filename = _user_safe_decode(options[b"filename"], self._charset)
filename = _user_safe_decode(options[b"filename"], str(self._charset))
tempfile = NamedTemporaryFile(delete=False)
self._files_to_close_on_error.append(tempfile)
self._current_part.file = GradioUploadFile(
Expand Down Expand Up @@ -516,7 +515,7 @@ async def parse(self) -> FormData:
raise MultiPartException("Missing boundary in multipart.") from e

# Callbacks dictionary.
callbacks = {
callbacks: multipart.multipart.MultipartCallbacks = {
"on_part_begin": self.on_part_begin,
"on_part_data": self.on_part_data,
"on_part_end": self.on_part_end,
Expand Down Expand Up @@ -579,3 +578,11 @@ def update_root_in_config(config: dict, root: str) -> dict:

def compare_passwords_securely(input_password: str, correct_password: str) -> bool:
return hmac.compare_digest(input_password.encode(), correct_password.encode())


def starts_with_protocol(string: str) -> bool:
"""This regex matches strings that start with a scheme (one or more characters not including colon, slash, or space)
followed by ://
"""
pattern = r"^[a-zA-Z][a-zA-Z0-9+\-.]*://"
return re.match(pattern, string) is not None
5 changes: 2 additions & 3 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,7 @@ async def file(path_or_url: str, request: fastapi.Request):
url=path_or_url, status_code=status.HTTP_302_FOUND
)

invalid_prefixes = ["//", "file://", "ftp://", "sftp://", "smb://"]
if any(path_or_url.startswith(prefix) for prefix in invalid_prefixes):
if route_utils.starts_with_protocol(path_or_url):
raise HTTPException(403, f"File not allowed: {path_or_url}.")

abs_path = utils.abspath(path_or_url)
Expand Down Expand Up @@ -779,7 +778,7 @@ async def upload_file(
):
content_type_header = request.headers.get("Content-Type")
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
content_type, _ = parse_options_header(content_type_header or "")
if content_type != b"multipart/form-data":
raise HTTPException(status_code=400, detail="Invalid content type.")

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ packaging
pandas>=1.0,<3.0
pillow>=8.0,<11.0
pydantic>=2.0
python-multipart # required for fastapi forms
python-multipart>=0.0.9 # required for fastapi forms
pydub
pyyaml>=5.0,<7.0
semantic_version~=2.0
Expand Down
19 changes: 19 additions & 0 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
FnIndexInferError,
compare_passwords_securely,
get_root_url,
starts_with_protocol,
)


Expand Down Expand Up @@ -920,3 +921,21 @@ def test_compare_passwords_securely():
assert compare_passwords_securely(password1, password1)
assert not compare_passwords_securely(password1, password2)
assert compare_passwords_securely(password2, password2)


@pytest.mark.parametrize(
"string, expected",
[
("http://localhost:7860/", True),
("https://localhost:7860/", True),
("ftp://localhost:7860/", True),
("smb://example.com", True),
("ipfs://QmTzQ1Nj5R9BzF1djVQv8gvzZxVkJb1vhrLcXL1QyJzZE", True),
("usr/local/bin", False),
("localhost:7860", False),
("localhost", False),
("C:/Users/username", False),
],
)
def test_starts_with_protocol(string, expected):
assert starts_with_protocol(string) == expected

0 comments on commit ba747ad

Please sign in to comment.