Skip to content

Commit

Permalink
WIP: feat: redirect uri wildcards
Browse files Browse the repository at this point in the history
  • Loading branch information
dopry committed Oct 2, 2024
1 parent b48fd8b commit 3f9215c
Show file tree
Hide file tree
Showing 7 changed files with 473 additions and 22 deletions.
38 changes: 38 additions & 0 deletions docs/settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,44 @@ assigned ports.
Note that you may override ``Application.get_allowed_schemes()`` to set this on
a per-application basis.

ALLOW_REDIRECT_URI_WILDCARDS
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Default: ``False``

SECURITY WARNING: Enabling this setting can introduce security vulnerabilities. Only enable
this setting if you understand the risks. https://datatracker.ietf.org/doc/html/rfc6749#section-3.1.2
states "The redirection endpoint URI MUST be an absolute URI as defined by [RFC3986] Section 4.3." The
intent of the URI restrictions is to prevent open redirects and phishing attacks. If you do enable this
ensure that the wildcards restrict URIs to resources under your control. You are strongly encouragd not
to use this feature in production.

When set to ``True``, the server will allow wildcard characters in the domains
and paths for redirect_uris and post_logout_redirect_uris.

``*`` is the only wildcard character allowed.

``*`` can only be used as a prefix to a domain, must be the first character in
the domain, and cannot be in the top or second level domain. Matching is done using an
endsWith check.

For example,
``https://*.example.com`` is allowed,
``https://*-myproject.example.com`` is allowed,
``https://*.sub.example.com`` is not allowed,
``https://*.com`` is not allowed, and
``https://example.*.com`` is not allowed.

``*`` can also be used as a suffix to a path, must be the last character in the path.
Matching is done using a startsWith check.

For example,
``https://example.com/*`` is allowed, and
``https://example.com/path/*`` is allowed.

This feature is useful for working with CI service such as cloudflare, netlify, and vercel that offer branch
deployments for development previews and user acceptance testing.

ALLOWED_SCHEMES
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
65 changes: 45 additions & 20 deletions oauth2_provider/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,11 @@ def clean(self):

if redirect_uris:
validator = AllowedURIValidator(
allowed_schemes, name="redirect uri", allow_path=True, allow_query=True
allowed_schemes,
name="redirect uri",
allow_path=True,
allow_query=True,
allow_hostname_wildcard=oauth2_settings.ALLOW_REDIRECT_URI_WILDCARDS,
)
for uri in redirect_uris:
validator(uri)
Expand All @@ -227,7 +231,11 @@ def clean(self):
allowed_origins = self.allowed_origins.strip().split()
if allowed_origins:
# oauthlib allows only https scheme for CORS
validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "allowed origin")
validator = AllowedURIValidator(
oauth2_settings.ALLOWED_SCHEMES,
"allowed origin",
allow_hostname_wildcard=oauth2_settings.ALLOW_REDIRECT_URI_WILDCARDS,
)
for uri in allowed_origins:
validator(uri)

Expand Down Expand Up @@ -777,39 +785,55 @@ def redirect_to_uri_allowed(uri, allowed_uris):
:param allowed_uris: A list of URIs that are allowed
"""

if not isinstance(allowed_uris, list):
raise ValueError("allowed_uris must be a list")

parsed_uri = urlparse(uri)
uqs_set = set(parse_qsl(parsed_uri.query))
for allowed_uri in allowed_uris:
parsed_allowed_uri = urlparse(allowed_uri)

if parsed_allowed_uri.scheme != parsed_uri.scheme:
# match failed, continue
continue

""" check hostname """
if oauth2_settings.ALLOW_REDIRECT_URI_WILDCARDS and parsed_allowed_uri.hostname.startswith("*"):
""" wildcard hostname """
if not parsed_uri.hostname.endswith(parsed_allowed_uri.hostname[1:]):
continue
elif parsed_allowed_uri.hostname != parsed_uri.hostname:
continue

# From RFC 8252 (Section 7.3)
# https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
#
# Loopback redirect URIs use the "http" scheme
# [...]
# The authorization server MUST allow any port to be specified at the
# time of the request for loopback IP redirect URIs, to accommodate
# clients that obtain an available ephemeral port from the operating
# system at the time of the request.
allowed_uri_is_loopback = parsed_allowed_uri.scheme == "http" and parsed_allowed_uri.hostname in [
"127.0.0.1",
"::1",
]
""" check port """
if not allowed_uri_is_loopback and parsed_allowed_uri.port != parsed_uri.port:
continue

""" check path """
if parsed_allowed_uri.path != parsed_uri.path:
continue

""" check querystring """
aqs_set = set(parse_qsl(parsed_allowed_uri.query))
if not aqs_set.issubset(uqs_set):
continue # circuit break

allowed_uri_is_loopback = (
parsed_allowed_uri.scheme == "http"
and parsed_allowed_uri.hostname in ["127.0.0.1", "::1"]
and parsed_allowed_uri.port is None
)
if (
allowed_uri_is_loopback
and parsed_allowed_uri.scheme == parsed_uri.scheme
and parsed_allowed_uri.hostname == parsed_uri.hostname
and parsed_allowed_uri.path == parsed_uri.path
) or (
parsed_allowed_uri.scheme == parsed_uri.scheme
and parsed_allowed_uri.netloc == parsed_uri.netloc
and parsed_allowed_uri.path == parsed_uri.path
):
aqs_set = set(parse_qsl(parsed_allowed_uri.query))
if aqs_set.issubset(uqs_set):
return True
return True

# if uris matched then it's not allowed
return False


Expand All @@ -833,4 +857,5 @@ def is_origin_allowed(origin, allowed_origins):
and parsed_allowed_origin.netloc == parsed_origin.netloc
):
return True

return False
1 change: 1 addition & 0 deletions oauth2_provider/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"REQUEST_APPROVAL_PROMPT": "force",
"ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"],
"ALLOWED_SCHEMES": ["https"],
"ALLOW_REDIRECT_URI_WILDCARDS": False,
"OIDC_ENABLED": False,
"OIDC_ISS_ENDPOINT": "",
"OIDC_USERINFO_ENDPOINT": "",
Expand Down
62 changes: 60 additions & 2 deletions oauth2_provider/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@ class URIValidator(URLValidator):
class AllowedURIValidator(URIValidator):
# TODO: find a way to get these associated with their form fields in place of passing name
# TODO: submit PR to get `cause` included in the parent class ValidationError params`
def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False):
def __init__(
self,
schemes,
name,
allow_path=False,
allow_query=False,
allow_fragments=False,
allow_hostname_wildcard=False,
):
"""
:param schemes: List of allowed schemes. E.g.: ["https"]
:param name: Name of the validated URI. It is required for validation message. E.g.: "Origin"
Expand All @@ -34,6 +42,7 @@ def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fra
self.allow_path = allow_path
self.allow_query = allow_query
self.allow_fragments = allow_fragments
self.allow_hostname_wildcard = allow_hostname_wildcard

def __call__(self, value):
value = force_str(value)
Expand Down Expand Up @@ -68,8 +77,57 @@ def __call__(self, value):
params={"name": self.name, "value": value, "cause": "path not allowed"},
)

if self.allow_hostname_wildcard and "*" in netloc:
domain_parts = netloc.split(".")
if netloc.count("*") > 1:
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={
"name": self.name,
"value": value,
"cause": "only one wildcard is allowed in the hostname",
},
)
if not netloc.startswith("*"):
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={
"name": self.name,
"value": value,
"cause": "wildcards must be at the beginning of the hostname",
},
)
if len(domain_parts) < 3:
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={
"name": self.name,
"value": value,
"cause": "wildcards cannot be in the top level or second level domain",
},
)

# strip the wildcard from the netloc, we'll reassamble the value later to pass to URI Validator
if netloc.startswith("*."):
netloc = netloc[2:]
else:
netloc = netloc[1:]

# domains cannot start with a hyphen, but can have them in the middle, so we strip hyphens
# after the wildcard so the final domain is valid and will succeed in URIVAlidator
if netloc.startswith("-"):
netloc = netloc[1:]

# we stripped the wildcard from the netloc and path if they were allowed and present since they would
# fail validation we'll reassamble the URI to pass to the URIValidator
reassambled_uri = f"{scheme}://{netloc}{path}"
if query:
reassambled_uri += f"?{query}"
if fragment:
reassambled_uri += f"#{fragment}"

try:
super().__call__(value)
super().__call__(reassambled_uri)
except ValidationError as e:
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
Expand Down
145 changes: 145 additions & 0 deletions tests/test_application_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,151 @@ def test_application_registration_user(self):
self.assertEqual(app.algorithm, form_data["algorithm"])


@pytest.mark.usefixtures("oauth2_settings")
@pytest.mark.oauth2_settings({"ALLOW_REDIRECT_URI_WILDCARDS": True})
class TestApplicationRegistrationViewRedirectURIWithWildcardRedirectURIs(BaseTest):
def _test_valid(self, redirect_uri):
self.client.login(username="foo_user", password="123456")

form_data = {
"name": "Foo app",
"client_id": "client_id",
"client_secret": "client_secret",
"client_type": Application.CLIENT_CONFIDENTIAL,
"redirect_uris": redirect_uri,
"post_logout_redirect_uris": "http://example.com",
"authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE,
"algorithm": "",
}

response = self.client.post(reverse("oauth2_provider:register"), form_data)
self.assertEqual(response.status_code, 302)

app = get_application_model().objects.get(name="Foo app")
self.assertEqual(app.user.username, "foo_user")
app = Application.objects.get()
self.assertEqual(app.name, form_data["name"])
self.assertEqual(app.client_id, form_data["client_id"])
self.assertEqual(app.redirect_uris, form_data["redirect_uris"])
self.assertEqual(app.post_logout_redirect_uris, form_data["post_logout_redirect_uris"])
self.assertEqual(app.client_type, form_data["client_type"])
self.assertEqual(app.authorization_grant_type, form_data["authorization_grant_type"])
self.assertEqual(app.algorithm, form_data["algorithm"])

def _test_invalid(self, uri, error_message):
self.client.login(username="foo_user", password="123456")

form_data = {
"name": "Foo app",
"client_id": "client_id",
"client_secret": "client_secret",
"client_type": Application.CLIENT_CONFIDENTIAL,
"redirect_uris": uri,
"post_logout_redirect_uris": "http://example.com",
"authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE,
"algorithm": "",
}

response = self.client.post(reverse("oauth2_provider:register"), form_data)
self.assertEqual(response.status_code, 400)
self.assertContains(response, error_message)

def test_application_registration_valid_3ld_wildcard(self):
self._test_valid("http://*.example.com")

def test_application_registration_valid_3ld_partial_wildcard(self):
self._test_valid("http://*-partial.example.com")

def test_application_registration_invalid_tld_wildcard(self):
self._test_invalid("http://*", "Wildcard redirect URIs must be at least 3 levels deep")

def test_application_registration_invalid_tld_partial_wildcard(self):
self._test_invalid("http://*-partial", "Wildcard redirect URIs must be at least 3 levels deep")

def test_application_registration_invalid_tld_not_startswith_wildcard_tld(self):
self._test_invalid("http://example.*", "Wildcard redirect URIs must start with a wildcard character")

def test_application_registration_invalid_2ld_wildcard(self):
self._test_invalid("http://*.com", "Wildcard redirect URIs must be at least 3 levels deep")

def test_application_registration_invalid_2ld_partial_wildcard(self):
self._test_invalid("http://*-partial.com", "Wildcard redirect URIs must be at least 3 levels deep")

def test_application_registration_invalid_2ld_not_startswith_wildcard_tld(self):
self._test_invalid(
"http://example.*.com", "Wildcard redirect URIs must start with a wildcard character"
)

def test_application_registration_invalid_3ld_partial_not_startswith_wildcard_2ld(self):
self._test_invalid(
"http://invalid-*.example.com", "Wildcard redirect URIs must start with a wildcard character"
)

def test_application_registration_invalid_4ld_not_startswith_wildcard_3ld(self):
self._test_invalid(
"http://invalid/.*.invalid.example.com",
"Wildcard redirect URIs must start with a wildcard character",
)

def test_application_registration_invalid_4ld_partial_not_startswith_wildcard_2ld(self):
self._test_invalid(
"http://invalid-*.invalid.example.com",
"Wildcard redirect URIs must start with a wildcard character",
)


@pytest.mark.usefixtures("oauth2_settings")
@pytest.mark.oauth2_settings({"ALLOW_REDIRECT_URI_WILDCARDS": True})
class TestApplicationRegistrationViewPostLogoutRedirectURIWithWildcardRedirectURIs(
TestApplicationRegistrationViewRedirectURIWithWildcardRedirectURIs
):
def _test_valid(self, redirect_uri):
self.client.login(username="foo_user", password="123456")

form_data = {
"name": "Foo app",
"client_id": "client_id",
"client_secret": "client_secret",
"client_type": Application.CLIENT_CONFIDENTIAL,
"redirect_uris": "http://example.com",
"post_logout_redirect_uris": redirect_uri,
"authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE,
"algorithm": "",
}

response = self.client.post(reverse("oauth2_provider:register"), form_data)
self.assertEqual(response.status_code, 302)

app = get_application_model().objects.get(name="Foo app")
self.assertEqual(app.user.username, "foo_user")
app = Application.objects.get()
self.assertEqual(app.name, form_data["name"])
self.assertEqual(app.client_id, form_data["client_id"])
self.assertEqual(app.redirect_uris, form_data["redirect_uris"])
self.assertEqual(app.post_logout_redirect_uris, form_data["post_logout_redirect_uris"])
self.assertEqual(app.client_type, form_data["client_type"])
self.assertEqual(app.authorization_grant_type, form_data["authorization_grant_type"])
self.assertEqual(app.algorithm, form_data["algorithm"])

def _test_invalid(self, uri, error_message):
self.client.login(username="foo_user", password="123456")

form_data = {
"name": "Foo app",
"client_id": "client_id",
"client_secret": "client_secret",
"client_type": Application.CLIENT_CONFIDENTIAL,
"redirect_uris": "http://example.com",
"post_logout_redirect_uris": uri,
"authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE,
"algorithm": "",
}

response = self.client.post(reverse("oauth2_provider:register"), form_data)
self.assertEqual(response.status_code, 400)
self.assertContains(response, error_message)


class TestApplicationViews(BaseTest):
@classmethod
def _create_application(cls, name, user):
Expand Down
Loading

0 comments on commit 3f9215c

Please sign in to comment.