class Command(BaseCommand):
    """
    Resend verification emails for unverified OAuth applications.

    This command is meant to be used a single time in production to remediate
    failed email sending.

    It stores the set of successfully processed emails in Redis, so running it
    multiple times (in case of failure) should not be an issue.
    """

    help = "Resends verification emails for unverified Oauth applications."

    # Redis set holding every email that has already been processed
    # successfully by a non-dry run.
    processed_key = "resendoauthverification:processed"

    # We don't have access to ``request.build_absolute_uri`` so the absolute
    # link is built by hand for the production endpoint.
    api_base_url = "https://api.openverse.engineering"

    def add_arguments(self, parser):
        """Register ``--dry_run`` / ``--no-dry_run`` (dry run is the default)."""
        # ``BooleanOptionalAction`` consumes no value, so a ``type`` converter
        # is never applied (and passing one is deprecated in newer Pythons).
        parser.add_argument(
            "--dry_run",
            help="Count the records that will be removed but don't apply any changes.",
            default=True,
            action=argparse.BooleanOptionalAction,
        )

    @transaction.atomic
    def _handle_email(self, email, dry):
        """
        Clean up duplicate registration attempts for ``email`` and resend
        the verification email for the attempt that is kept.

        1. Get all application IDs for the email
        2. Use the one with the lowest ID as the "original" attempt
        3. Delete the rest
        4. Delete OAuth2Registrations for the email not associated
        with the "original" application
        5. Delete OAuth2Verifications for the email not associated
        with the "original" application

        This ignores the fact that someone could have tried to register
        multiple unverified but distinct applications under the same email.
        This is unlikely given that none of the requests would have worked
        and that the "feature" isn't explicitly documented anyway.

        The whole method runs in a single transaction: if sending the email
        raises, every deletion made for this email is rolled back.

        :param email: the email address whose attempts should be cleaned.
        :param dry: when truthy, only count what would be deleted; nothing is
            deleted and no email is sent.
        :return: a ``Result`` with the kept application's name and the counts.
        :raises ValueError: if no application is linked to ``email``.
        """
        # Applications are only linked to an email through the verification
        # records; registration rows carry no application reference, so their
        # primary keys must not be used as application IDs.
        application_ids = list(
            OAuth2Verification.objects.filter(
                email=email, associated_application__isnull=False
            )
            .order_by("associated_application__pk")
            .values_list("associated_application__pk", flat=True)
            .distinct()
        )
        if not application_ids:
            raise ValueError(f"No application is associated with {email}")

        application_to_verify = ThrottledApplication.objects.get(pk=application_ids[0])

        applications_to_delete_ids = application_ids[1:]
        deleted_applications = len(applications_to_delete_ids)
        if applications_to_delete_ids and not dry:
            ThrottledApplication.objects.filter(
                pk__in=applications_to_delete_ids
            ).delete()

        # Registrations are matched to the kept application by name.
        registrations_to_delete = OAuth2Registration.objects.filter(
            email=email
        ).exclude(name=application_to_verify.name)
        # Count before deleting so dry and wet runs report the same numbers.
        deleted_registrations = registrations_to_delete.count()
        if not dry:
            registrations_to_delete.delete()

        verifications_to_delete = OAuth2Verification.objects.filter(
            email=email
        ).exclude(associated_application=application_to_verify)
        deleted_verifications = verifications_to_delete.count()
        if not dry:
            verifications_to_delete.delete()

        if not dry:
            verification = OAuth2Verification.objects.get(
                associated_application=application_to_verify
            )
            # Strip any leading slash from the reversed path so the link
            # never contains a double slash after the host.
            path = reverse("verify-email", args=[verification.code]).lstrip("/")
            link = f"{self.api_base_url}/{path}"
            send_mail(
                subject="Verify your API credentials",
                message=verification_msg_template.format(link=link),
                from_email=settings.DEFAULT_FROM_EMAIL,
                recipient_list=[verification.email],
                fail_silently=False,
            )

        return Result(
            saved_application_name=application_to_verify.name,
            deleted_applications=deleted_applications,
            deleted_registrations=deleted_registrations,
            deleted_verifications=deleted_verifications,
        )

    def handle(self, *args, **options):
        """
        Process every email that has no verified application and has not
        already been processed, then print a summary.

        A non-dry run requires the operator to type ``YES``; anything else
        exits with status 1. Failures for one email are reported and do not
        stop the remaining emails from being processed.
        """
        dry = options["dry_run"]
        if not dry:
            self.info(
                self.style.WARNING(
                    "This is NOT a dry run. Are you sure you wish to proceed? "
                    "Respond 'yes' in all uppercase to proceed.\n"
                )
            )
            if get_input(": ") != "YES":
                self.error("Exiting.")
                raise SystemExit(1)

        redis = get_redis_connection("default")

        # Redis returns bytes; decode so they compare against DB values.
        already_processed_emails = [
            email.decode("utf-8") for email in redis.smembers(self.processed_key)
        ]

        emails_with_verified_applications = OAuth2Verification.objects.filter(
            Q(associated_application__verified=True)
            | Q(email__in=already_processed_emails)
        ).values_list("email", flat=True)

        emails_with_zero_verified_applications = list(
            OAuth2Verification.objects.exclude(
                email__in=emails_with_verified_applications
            )
            .values_list("email", flat=True)
            .distinct()
        )

        results = []
        errored_emails = []

        with self.tqdm(total=len(emails_with_zero_verified_applications)) as progress:
            for email in emails_with_zero_verified_applications:
                try:
                    results.append(self._handle_email(email, dry))
                    if not dry:
                        # Only mark as processed once the email was sent and
                        # the transaction committed.
                        redis.sadd(self.processed_key, email)
                except Exception as err:
                    # ``Exception`` rather than ``BaseException`` so that
                    # Ctrl-C and ``SystemExit`` still abort the run.
                    errored_emails.append(email)
                    self.error(f"Unable to process {email}: {err}")

                progress.update(1)

        if errored_emails:
            joined = "\n".join(errored_emails)
            self.info(
                self.style.WARNING(
                    "The following emails were unable to be processed.\n\n"
                    f"{joined}"
                    "\n\nPlease check the output above for the error related "
                    "to each email."
                )
            )

        formatted_results = "\n\n".join(
            (
                f"Application name: {result.saved_application_name}\n"
                f"Cleaned related application count: {result.deleted_applications}\n"
                f"Cleaned related verification count: {result.deleted_verifications}\n"
                f"Cleaned related registration count: {result.deleted_registrations}\n"
            )
            for result in results
        )

        self.info(
            self.style.SUCCESS(
                "The following applications had email verifications sent.\n\n"
                f"{formatted_results}"
            )
        )

        if dry:
            self.info(
                self.style.WARNING(
                    "The above was a dry run and no records were actually affected."
                )
            )
def cohesive_verification(email=None, verified=False) -> OAuthGroup:
    """
    Build one registration, one application and one verification that
    all belong together.

    The application reuses the registration's name and the verification
    reuses the registration's email. Pass ``email`` to pin the address
    instead of letting the factory generate one; ``verified`` is forwarded
    to the application.
    """
    # A falsy email means "let the factory pick one".
    registration_kwargs = {"email": email} if email else {}
    reg = OAuth2RegistrationFactory.create(**registration_kwargs)

    app = ThrottledApplicationFactory.create(name=reg.name, verified=verified)
    ver = OAuth2VerificationFactory.create(
        email=reg.email, associated_application=app
    )

    return OAuthGroup(registration=reg, application=app, verification=ver)
@pytest.mark.parametrize(
    "return_value",
    (
        None,
        "",
        # These must be separate cases: without the comma the two literals
        # are implicitly concatenated into the single value "noNO".
        "no",
        "NO",
        "yes",  # must be exactly YES
    ),
)
def test_should_exit_if_wet_unconfirmed(return_value):
    """A non-dry run must exit unless the operator answers exactly ``YES``."""
    with pytest.raises(SystemExit):
        call_resendoauthverification(input_response=return_value, dry_run=False)
@pytest.mark.django_db
def test_should_not_delete_or_send_if_dry_run(cleanable_email, captured_emails, redis):
    """A dry run sends no email, deletes no rows and records nothing in Redis."""
    call_resendoauthverification(dry_run=True)

    kept = cleanable_email.keep_group
    sent_for_kept = count_captured_emails_for_group(captured_emails, kept)
    assert sent_for_kept == 0

    # Every group must still exist: the one that would be kept first, then
    # the ones a real run would have cleaned up.
    for group in [kept, *cleanable_email.clean_groups]:
        registration_exists = OAuth2Registration.objects.filter(
            pk=group.registration.pk
        ).exists()
        assert registration_exists is True

        verification_exists = OAuth2Verification.objects.filter(
            pk=group.verification.pk
        ).exists()
        assert verification_exists is True

        application_exists = ThrottledApplication.objects.filter(
            pk=group.application.pk
        ).exists()
        assert application_exists is True

    marked_processed = redis.sismember(
        "resendoauthverification:processed", kept.registration.email
    )
    assert marked_processed is False
verification in verifications: + assert verification.associated_application is not None + + assert ( + ThrottledApplication.objects.filter( + name__contains=f"{email}'s sweet app" + ).count() + == 1 + )