Skip to content

Commit

Permalink
Merge pull request from GHSA-3v5m-qmm9-3c6c
Browse files Browse the repository at this point in the history
Mitigate CSRF vulnerability
  • Loading branch information
ericholscher authored Jun 15, 2021
2 parents d61b921 + 5610045 commit 43f3303
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 72 deletions.
98 changes: 51 additions & 47 deletions readthedocs/core/signals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

"""Signal handling for core app."""

import logging
Expand All @@ -16,14 +14,13 @@
from readthedocs.core.unresolver import unresolve
from readthedocs.projects.models import Domain, Project


log = logging.getLogger(__name__)

ALLOWED_URLS = [
'/api/v2/footer_html',
'/api/v2/search',
'/api/v2/docsearch',
'/api/v2/sustainability',
'/api/v2/embed',
]

webhook_github = Signal(providing_args=['project', 'data', 'event'])
Expand All @@ -34,33 +31,51 @@
post_collectstatic = Signal()


def _has_donate_app():
"""
Check if the current app has the sustainability API.
This is a separate function so it's easy to mock.
"""
return 'readthedocsext.donate' in settings.INSTALLED_APPS


def decide_if_cors(sender, request, **kwargs): # pylint: disable=unused-argument
"""
Decide whether a request should be given CORS access.
This checks that:
* The URL is whitelisted against our CORS-allowed domains
* The Domain exists in our database, and belongs to the project being queried.
Allow the request if:
Returns True when a request should be given CORS access.
* It's a safe HTTP method
* The origin is in ALLOWED_URLS
* The URL is owned by the project that they are requesting data from
* The version is public or the domain is linked to the project
(except for the embed API).
.. note::
Requests from the sustainability API are always allowed
if the donate app is installed.
:returns: `True` when a request should be given CORS access.
"""
if 'HTTP_ORIGIN' not in request.META:
if 'HTTP_ORIGIN' not in request.META or request.method not in SAFE_METHODS:
return False

host = urlparse(request.META['HTTP_ORIGIN']).netloc.split(':')[0]

# Don't do domain checking for this API for now
if request.path_info.startswith('/api/v2/sustainability'):
# Always allow the sustainability API,
# it's used only on .org to check for ad-free users.
if _has_donate_app() and request.path_info.startswith('/api/v2/sustainability'):
return True

# Don't do domain checking for APIv2 when the Domain is known
if request.path_info.startswith('/api/v2/') and request.method in SAFE_METHODS:
domain = Domain.objects.filter(domain__icontains=host)
if domain.exists():
return True
valid_url = None
for url in ALLOWED_URLS:
if request.path_info.startswith(url):
valid_url = url
break

# Check for Embed API, allowing CORS on public projects
# since they are already public
if request.path_info.startswith('/api/v2/embed'):
if valid_url:
url = request.GET.get('url')
if url:
unresolved = unresolve(url)
Expand All @@ -74,7 +89,7 @@ def decide_if_cors(sender, request, **kwargs): # pylint: disable=unused-argumen
if project and version_slug:
# This is from IsAuthorizedToViewVersion,
# we should abstract is a bit perhaps?
has_access = (
is_public = (
Version.objects
.public(
project=project,
Expand All @@ -83,36 +98,25 @@ def decide_if_cors(sender, request, **kwargs): # pylint: disable=unused-argumen
.filter(slug=version_slug)
.exists()
)
if has_access:
# Allowing CORS on public versions,
# since they are already public.
if is_public:
return True

return False

valid_url = False
for url in ALLOWED_URLS:
if request.path_info.startswith(url):
valid_url = True
break

if valid_url:
project_slug = request.GET.get('project', None)
try:
project = Project.objects.get(slug=project_slug)
except Project.DoesNotExist:
log.warning(
'Invalid project passed to domain. [%s:%s]',
project_slug,
host,
# Don't check for known domains for the embed api.
# It gives a lot of information,
# we should use a list of trusted domains from the user.
if valid_url == '/api/v2/embed':
return False

# Or allow if they have a registered domain
# linked to that project.
domain = Domain.objects.filter(
Q(domain__iexact=host),
Q(project=project) | Q(project__subprojects__child=project),
)
return False

domain = Domain.objects.filter(
Q(domain__icontains=host),
Q(project=project) | Q(project__subprojects__child=project),
)
if domain.exists():
return True

if domain.exists():
return True
return False


Expand Down
159 changes: 139 additions & 20 deletions readthedocs/rtd_tests/tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from unittest import mock

from corsheaders.middleware import CorsMiddleware
from django.conf import settings
from django.contrib.auth.middleware import AuthenticationMiddleware
from django.contrib.sessions.middleware import SessionMiddleware
from django.http import Http404
from django.http import HttpResponse
from django.test import TestCase
from django.test import TestCase, override_settings
from django.test.client import RequestFactory
from django.test.utils import override_settings
from django.urls.base import get_urlconf, set_urlconf
from django_dynamic_fixture import get

from readthedocs.builds.constants import LATEST
from readthedocs.core.middleware import ReadTheDocsSessionMiddleware
from readthedocs.projects.constants import PRIVATE, PUBLIC
from readthedocs.projects.models import Domain, Project, ProjectRelationship
from readthedocs.rtd_tests.utils import create_user


@override_settings(
PUBLIC_DOMAIN='readthedocs.io',
)
class TestCORSMiddleware(TestCase):

def setUp(self):
Expand All @@ -24,36 +26,95 @@ def setUp(self):
self.owner = create_user(username='owner', password='test')
self.project = get(
Project, slug='pip',
users=[self.owner], privacy_level='public',
users=[self.owner],
privacy_level=PUBLIC,
main_language_project=None,
)
self.project.versions.update(privacy_level=PUBLIC)
self.version = self.project.versions.get(slug=LATEST)
self.subproject = get(
Project,
users=[self.owner],
privacy_level='public',
privacy_level=PUBLIC,
main_language_project=None,
)
self.subproject.versions.update(privacy_level=PUBLIC)
self.version_subproject = self.subproject.versions.get(slug=LATEST)
self.relationship = get(
ProjectRelationship,
parent=self.project,
child=self.subproject,
)
self.domain = get(Domain, domain='my.valid.domain', project=self.project)
self.domain = get(
Domain,
domain='my.valid.domain',
project=self.project,
)
self.another_project = get(
Project,
privacy_level=PUBLIC,
slug='another',
)
self.another_project.versions.update(privacy_level=PUBLIC)
self.another_version = self.another_project.versions.get(slug=LATEST)
self.another_domain = get(
Domain,
domain='another.valid.domain',
project=self.another_project,
)

def test_proper_domain(self):
def test_allow_linked_domain_from_public_version(self):
request = self.factory.get(
self.url,
{'project': self.project.slug},
{'project': self.project.slug, 'version': self.version.slug},
HTTP_ORIGIN='http://my.valid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertIn('Access-Control-Allow-Origin', resp)

def test_invalid_domain(self):
def test_allow_linked_domain_from_private_version(self):
self.version.privacy_level = PRIVATE
self.version.save()
request = self.factory.get(
self.url,
{'project': self.project.slug},
HTTP_ORIGIN='http://invalid.domain',
{'project': self.project.slug, 'version': self.version.slug},
HTTP_ORIGIN='http://my.valid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertIn('Access-Control-Allow-Origin', resp)

def test_allowed_api_public_version_from_another_domain(self):
request = self.factory.get(
self.url,
{'project': self.project.slug, 'version': self.version.slug},
HTTP_ORIGIN='http://docs.another.domain',
)
resp = self.middleware.process_response(request, {})
self.assertIn('Access-Control-Allow-Origin', resp)

request = self.factory.get(
self.url,
{'project': self.project.slug, 'version': self.version.slug},
HTTP_ORIGIN='http://another.valid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertIn('Access-Control-Allow-Origin', resp)

def test_not_allowed_api_private_version_from_another_domain(self):
self.version.privacy_level = PRIVATE
self.version.save()
request = self.factory.get(
self.url,
{'project': self.project.slug, 'version': self.version.slug},
HTTP_ORIGIN='http://docs.another.domain',
)
resp = self.middleware.process_response(request, {})
self.assertNotIn('Access-Control-Allow-Origin', resp)

request = self.factory.get(
self.url,
{'project': self.project.slug, 'version': self.version.slug},
HTTP_ORIGIN='http://another.valid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertNotIn('Access-Control-Allow-Origin', resp)
Expand All @@ -67,34 +128,92 @@ def test_valid_subproject(self):
)
request = self.factory.get(
self.url,
{'project': self.subproject.slug},
{'project': self.project.slug, 'version': self.version.slug},
HTTP_ORIGIN='http://my.valid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertIn('Access-Control-Allow-Origin', resp)

def test_apiv2_endpoint_allowed(self):
def test_embed_api_private_version_linked_domain(self):
self.version.privacy_level = PRIVATE
self.version.save()
request = self.factory.get(
'/api/v2/version/',
{'project__slug': self.project.slug, 'active': True},
'/api/v2/embed/',
{'project': self.project.slug, 'version': self.version.slug},
HTTP_ORIGIN='http://my.valid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertNotIn('Access-Control-Allow-Origin', resp)

@mock.patch('readthedocs.core.signals._has_donate_app')
def test_sustainability_endpoint_allways_allowed(self, has_donate_app):
has_donate_app.return_value = True
request = self.factory.get(
'/api/v2/sustainability/',
{'project': self.project.slug, 'active': True, 'version': self.version.slug},
HTTP_ORIGIN='http://invalid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertIn('Access-Control-Allow-Origin', resp)

request = self.factory.get(
'/api/v2/sustainability/',
{'project': self.project.slug, 'active': True, 'version': self.version.slug},
HTTP_ORIGIN='http://my.valid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertIn('Access-Control-Allow-Origin', resp)

@mock.patch('readthedocs.core.signals._has_donate_app')
def test_sustainability_endpoint_no_ext(self, has_donate_app):
has_donate_app.return_value = False
request = self.factory.get(
'/api/v2/sustainability/',
{'project': self.project.slug, 'active': True, 'version': self.version.slug},
HTTP_ORIGIN='http://invalid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertNotIn('Access-Control-Allow-Origin', resp)

request = self.factory.get(
'/api/v2/sustainability/',
{'project': self.project.slug, 'active': True, 'version': self.version.slug},
HTTP_ORIGIN='http://my.valid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertNotIn('Access-Control-Allow-Origin', resp)

def test_apiv2_endpoint_not_allowed(self):
request = self.factory.get(
'/api/v2/version/',
{'project__slug': self.project.slug, 'active': True},
{'project': self.project.slug, 'active': True, 'version': self.version.slug},
HTTP_ORIGIN='http://invalid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertNotIn('Access-Control-Allow-Origin', resp)

# This also doesn't work on registered domains.
request = self.factory.get(
'/api/v2/version/',
{'project': self.project.slug, 'active': True, 'version': self.version.slug},
HTTP_ORIGIN='http://my.valid.domain',
)
resp = self.middleware.process_response(request, {})
self.assertNotIn('Access-Control-Allow-Origin', resp)

# Or from our public domain.
request = self.factory.get(
'/api/v2/version/',
{'project': self.project.slug, 'active': True, 'version': self.version.slug},
HTTP_ORIGIN='http://docs.readthedocs.io/',
)
resp = self.middleware.process_response(request, {})
self.assertNotIn('Access-Control-Allow-Origin', resp)

# POST is not allowed
request = self.factory.post(
'/api/v2/version/',
{'project__slug': self.project.slug, 'active': True},
{'project': self.project.slug, 'active': True, 'version': self.version.slug},
HTTP_ORIGIN='http://my.valid.domain',
)
resp = self.middleware.process_response(request, {})
Expand Down
Loading

0 comments on commit 43f3303

Please sign in to comment.