diff --git a/src/sentry/services/hybrid_cloud/import_export/impl.py b/src/sentry/services/hybrid_cloud/import_export/impl.py
index 60cda917244a21..9287436c34fd5b 100644
--- a/src/sentry/services/hybrid_cloud/import_export/impl.py
+++ b/src/sentry/services/hybrid_cloud/import_export/impl.py
@@ -263,55 +263,67 @@ def import_by_model(
                             reason=str(e),
                         )

-                # If we wrote at least one model, make sure to write an appropriate `ImportChunk`
+                # If the `counter` is at 0, no model instances were actually imported, so we can
+                # return early.
+                if counter == 0:
+                    return RpcImportOk(
+                        mapped_pks=RpcPrimaryKeyMap.into_rpc(out_pk_map),
+                        min_ordinal=None,
+                        max_ordinal=None,
+                        min_source_pk=None,
+                        max_source_pk=None,
+                        min_inserted_pk=None,
+                        max_inserted_pk=None,
+                    )
+
+                # We wrote at least one model, so make sure to write an appropriate `ImportChunk`
                 # and update the sequences too.
-                if counter > 0:
-                    table = model_instance._meta.db_table
-                    seq = f"{table}_id_seq"
-                    with connections[using].cursor() as cursor:
-                        cursor.execute(f"SELECT setval(%s, (SELECT MAX(id) FROM {table}))", [seq])
-
-                    inserted = out_pk_map.partition(
-                        {batch_model_name}, {ImportKind.Inserted}
-                    ).mapping[model_name]
-                    existing = out_pk_map.partition(
-                        {batch_model_name}, {ImportKind.Existing}
-                    ).mapping[model_name]
-                    overwrite = out_pk_map.partition(
-                        {batch_model_name}, {ImportKind.Overwrite}
-                    ).mapping[model_name]
-                    import_chunk_args = {
-                        "import_uuid": flags.import_uuid,
-                        "model": model_name,
-                        # TODO(getsentry/team-ospo#190): The next two fields assume the entire model
-                        # is being imported in a single call; we may change this in the future.
-                        "min_ordinal": 1,
-                        "max_ordinal": counter,
-                        "min_source_pk": min_old_pk,
-                        "max_source_pk": max_old_pk,
-                        "min_inserted_pk": min_inserted_pk,
-                        "max_inserted_pk": max_inserted_pk,
-                        "inserted_map": {k: v[0] for k, v in inserted.items()},
-                        "existing_map": {k: v[0] for k, v in existing.items()},
-                        "overwrite_map": {k: v[0] for k, v in overwrite.items()},
-                        "inserted_identifiers": {
-                            k: v[2] for k, v in inserted.items() if v[2] is not None
-                        },
-                    }
-                    if import_chunk_type == ControlImportChunk:
-                        ControlImportChunk(**import_chunk_args).save()
-                    else:
-                        RegionImportChunk(**import_chunk_args).save()
-
-                    return RpcImportOk(
-                        mapped_pks=RpcPrimaryKeyMap.into_rpc(out_pk_map),
-                        min_ordinal=1,
-                        max_ordinal=counter,
-                        min_source_pk=min_old_pk,
-                        max_source_pk=max_old_pk,
-                        min_inserted_pk=min_inserted_pk,
-                        max_inserted_pk=max_inserted_pk,
-                    )
+                table = model_instance._meta.db_table
+                seq = f"{table}_id_seq"
+                with connections[using].cursor() as cursor:
+                    cursor.execute(f"SELECT setval(%s, (SELECT MAX(id) FROM {table}))", [seq])
+
+                inserted = out_pk_map.partition({batch_model_name}, {ImportKind.Inserted}).mapping[
+                    model_name
+                ]
+                existing = out_pk_map.partition({batch_model_name}, {ImportKind.Existing}).mapping[
+                    model_name
+                ]
+                overwrite = out_pk_map.partition(
+                    {batch_model_name}, {ImportKind.Overwrite}
+                ).mapping[model_name]
+                import_chunk_args = {
+                    "import_uuid": flags.import_uuid,
+                    "model": model_name,
+                    # TODO(getsentry/team-ospo#190): The next two fields assume the entire model is
+                    # being imported in a single call; we may change this in the future.
+ "min_ordinal": 1, + "max_ordinal": counter, + "min_source_pk": min_old_pk, + "max_source_pk": max_old_pk, + "min_inserted_pk": min_inserted_pk, + "max_inserted_pk": max_inserted_pk, + "inserted_map": {k: v[0] for k, v in inserted.items()}, + "existing_map": {k: v[0] for k, v in existing.items()}, + "overwrite_map": {k: v[0] for k, v in overwrite.items()}, + "inserted_identifiers": { + k: v[2] for k, v in inserted.items() if v[2] is not None + }, + } + if import_chunk_type == ControlImportChunk: + ControlImportChunk(**import_chunk_args).save() + else: + RegionImportChunk(**import_chunk_args).save() + + return RpcImportOk( + mapped_pks=RpcPrimaryKeyMap.into_rpc(out_pk_map), + min_ordinal=1, + max_ordinal=counter, + min_source_pk=min_old_pk, + max_source_pk=max_old_pk, + min_inserted_pk=min_inserted_pk, + max_inserted_pk=max_inserted_pk, + ) except DeserializationError: return RpcImportError( diff --git a/src/sentry/tasks/relocation.py b/src/sentry/tasks/relocation.py index 60c8e213a68786..cf5160773e9502 100644 --- a/src/sentry/tasks/relocation.py +++ b/src/sentry/tasks/relocation.py @@ -240,19 +240,18 @@ def preprocessing_scan(uuid: str) -> None: # Decrypt the DEK using Google KMS, and use the decrypted DEK to decrypt the encoded # JSON. - try: + with retry_task_or_fail_relocation( + relocation, + OrderedTask.PREPROCESSING_SCAN, + attempts_left, + ERR_PREPROCESSING_DECRYPTION, + ): decryptor = GCPKMSDecryptor.from_bytes( json.dumps(get_default_crypto_key_version()).encode("utf-8") ) plaintext_data_encryption_key = decryptor.decrypt_data_encryption_key(unwrapped) fernet = Fernet(plaintext_data_encryption_key) json_data = fernet.decrypt(unwrapped.encrypted_json_blob).decode("utf-8") - except Exception: - return fail_relocation( - relocation, - OrderedTask.PREPROCESSING_SCAN, - ERR_PREPROCESSING_DECRYPTION, - ) # Grab usernames and org slugs from the JSON data. usernames = [] diff --git a/src/sentry/testutils/helpers/task_runner.py b/src/sentry/testutils/helpers/task_runner.py index 4f80425e78d526..820e356f7dc606 100644 --- a/src/sentry/testutils/helpers/task_runner.py +++ b/src/sentry/testutils/helpers/task_runner.py @@ -19,6 +19,16 @@ def TaskRunner(): settings.CELERY_ALWAYS_EAGER = prev +class BustTaskRunnerRetryError(Exception): + """ + An exception that mocks can throw, which will bubble to tasks run by the `BurstTaskRunner` and + cause them to be re-queued, rather than failed immediately. Useful for simulating the + `@instrument_task` decorator's retry semantics. + """ + + pass + + @contextmanager def BurstTaskRunner(): """ @@ -40,7 +50,10 @@ def work(max_jobs=None): self, args, kwargs = job_queue.pop(0) with patch("celery.app.task.Task.apply_async", apply_async): - self(*args, **kwargs) + try: + self(*args, **kwargs) + except BustTaskRunnerRetryError: + job_queue.append((self, args, kwargs)) jobs += 1 diff --git a/src/sentry/utils/relocation.py b/src/sentry/utils/relocation.py index d79dee249bcde0..fa4e2bb3e2101e 100644 --- a/src/sentry/utils/relocation.py +++ b/src/sentry/utils/relocation.py @@ -373,6 +373,11 @@ def fail_relocation(relocation: Relocation, task: OrderedTask, reason: str = "") instead. """ + # Another nested exception handler could have already failed this relocation - in this case, do + # nothing. 
+    if relocation.status == Relocation.Status.FAILURE.value:
+        return
+
     if reason:
         relocation.failure_reason = reason

diff --git a/tests/sentry/tasks/test_relocation.py b/tests/sentry/tasks/test_relocation.py
index c3a9ca9bd8b98c..a1229435a49f0c 100644
--- a/tests/sentry/tasks/test_relocation.py
+++ b/tests/sentry/tasks/test_relocation.py
@@ -55,7 +55,9 @@
     ERR_UPLOADING_FAILED,
     ERR_VALIDATING_INTERNAL,
     ERR_VALIDATING_MAX_RUNS,
+    MAX_FAST_TASK_ATTEMPTS,
     MAX_FAST_TASK_RETRIES,
+    MAX_VALIDATION_POLL_ATTEMPTS,
     MAX_VALIDATION_POLLS,
     LostPasswordHash,
     completed,
@@ -75,6 +77,7 @@
 from sentry.testutils.cases import TestCase, TransactionTestCase
 from sentry.testutils.factories import get_fixture_path
 from sentry.testutils.helpers.backups import FakeKeyManagementServiceClient, generate_rsa_key_pair
+from sentry.testutils.helpers.task_runner import BurstTaskRunner, BustTaskRunnerRetryError
 from sentry.testutils.silo import assume_test_silo_mode, region_silo_test
 from sentry.utils import json
 from sentry.utils.relocation import RELOCATION_BLOB_SIZE, RELOCATION_FILE_TYPE
@@ -414,7 +417,14 @@ def test_fail_decryption_failure(
         self.mock_kms_client(fake_kms_client)
         fake_kms_client.asymmetric_decrypt.return_value.plaintext += b"\xc3\x28"

-        preprocessing_scan(self.uuid)
+        # We retry on decryption failures, just to account for flakiness on the KMS server's side.
+        # Run this as the last attempt, so that the actual error surfaces instead of a retry.
+        self.relocation.latest_task = "PREPROCESSING_SCAN"
+        self.relocation.latest_task_attempts = MAX_FAST_TASK_RETRIES
+        self.relocation.save()
+
+        with pytest.raises(Exception):
+            preprocessing_scan(self.uuid)

         assert fake_message_builder.call_count == 1
         assert fake_message_builder.call_args.kwargs["type"] == "relocation.failed"
@@ -422,6 +432,7 @@ def test_fail_decryption_failure(
             to=[self.owner.email, self.superuser.email]
         )

+        assert fake_kms_client.asymmetric_decrypt.call_count == 1
         assert preprocessing_baseline_config_mock.call_count == 0

         relocation = Relocation.objects.get(uuid=self.uuid)
@@ -1745,8 +1756,9 @@ def test_fail_if_no_attempts_left(
         # Oh, the irony: sending the "relocation success" email failed, so we send a "relocation
         # failed" email instead...
         assert fake_message_builder.call_count == 2
-        assert fake_message_builder.call_args_list[0].kwargs["type"] == "relocation.succeeded"
-        assert fake_message_builder.call_args_list[1].kwargs["type"] == "relocation.failed"
+        email_types = [args.kwargs["type"] for args in fake_message_builder.call_args_list]
+        assert "relocation.failed" in email_types
+        assert "relocation.succeeded" in email_types

         assert completed_mock.call_count == 0

@@ -1804,36 +1816,37 @@ def setUp(self):
                 f"relocations/runs/{self.relocation.uuid}/findings/{file}", BytesIO(b"[]")
             )

-    def test_valid_no_retries(
+    def mock_max_retries(
         self,
-        fake_message_builder: Mock,
         fake_cloudbuild_client: FakeCloudBuildClient,
         fake_kms_client: FakeKeyManagementServiceClient,
     ):
-        self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(Build.Status.SUCCESS))
-        self.mock_kms_client(fake_kms_client)
-        self.mock_message_builder(fake_message_builder)
-        org_count = Organization.objects.filter(slug__startswith="testing").count()
-
-        with self.tasks(), patch.object(
-            LostPasswordHash, "send_relocate_account_email"
-        ) as mock_relocation_email:
-            uploading_complete(self.relocation.uuid)
-
-        assert mock_relocation_email.call_count == 2
-
-        assert fake_cloudbuild_client.create_build.call_count == 1
-        assert fake_cloudbuild_client.get_build.call_count == 1
-
-        assert fake_kms_client.asymmetric_decrypt.call_count == 2
+        fake_cloudbuild_client.create_build.side_effect = (
+            [BustTaskRunnerRetryError("Retry")] * MAX_FAST_TASK_RETRIES
+        ) + [fake_cloudbuild_client.create_build.return_value]
+
+        fake_cloudbuild_client.get_build.side_effect = (
+            [BustTaskRunnerRetryError("Retry")] * MAX_VALIDATION_POLLS
+        ) + [fake_cloudbuild_client.get_build.return_value]
+
+        fake_kms_client.asymmetric_decrypt.side_effect = (
+            [BustTaskRunnerRetryError("Retry")] * MAX_FAST_TASK_RETRIES
+        ) + [
+            fake_kms_client.asymmetric_decrypt.return_value,
+            # The second call to `asymmetric_decrypt` occurs from inside the `importing` task,
+            # which is not retried.
+            fake_kms_client.asymmetric_decrypt.return_value,
+        ]

-        assert fake_message_builder.call_count == 2
-        assert fake_message_builder.call_args_list[0].kwargs["type"] == "relocation.succeeded"
-        assert fake_message_builder.call_args_list[1].kwargs["type"] == "relocation.started"
+        fake_kms_client.get_public_key.side_effect = (
+            [BustTaskRunnerRetryError("Retry")] * MAX_FAST_TASK_RETRIES
+        ) + [fake_kms_client.get_public_key.return_value]
+        # Used by two tasks, so repeat the pattern (fail, fail, fail, succeed) twice.
+        fake_kms_client.get_public_key.side_effect = (
+            list(fake_kms_client.get_public_key.side_effect) * 2
+        )

-        relocation = Relocation.objects.get(uuid=self.uuid)
-        assert relocation.status == Relocation.Status.SUCCESS.value
-        assert not relocation.failure_reason
+    def assert_success_database_state(self, org_count: int):
         assert Organization.objects.filter(slug__startswith="testing").count() == org_count + 1

         assert RegionImportChunk.objects.filter(import_uuid=self.uuid).count() == 9
@@ -1849,6 +1862,7 @@ def test_valid_no_retries(
             "sentry.team",
         ]

+        assert ControlImportChunkReplica.objects.filter(import_uuid=self.uuid).count() == 2
         with assume_test_silo_mode(SiloMode.CONTROL):
             assert ControlImportChunk.objects.filter(import_uuid=self.uuid).count() == 2
             assert sorted(ControlImportChunk.objects.values_list("model", flat=True)) == [
@@ -1856,6 +1870,90 @@ def setUp(self):
             "sentry.useremail",
         ]

+    def assert_failure_database_state(self, org_count: int):
+        assert Organization.objects.filter(slug__startswith="testing").count() == org_count
+        assert RegionImportChunk.objects.filter(import_uuid=self.uuid).count() == 0
+
+        assert ControlImportChunkReplica.objects.filter(import_uuid=self.uuid).count() == 0
+        with assume_test_silo_mode(SiloMode.CONTROL):
+            assert ControlImportChunk.objects.filter(import_uuid=self.uuid).count() == 0
+
+    def test_valid_no_retries(
+        self,
+        fake_message_builder: Mock,
+        fake_cloudbuild_client: FakeCloudBuildClient,
+        fake_kms_client: FakeKeyManagementServiceClient,
+    ):
+        self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(Build.Status.SUCCESS))
+        self.mock_kms_client(fake_kms_client)
+        self.mock_message_builder(fake_message_builder)
+        org_count = Organization.objects.filter(slug__startswith="testing").count()
+
+        with BurstTaskRunner() as burst:
+            uploading_complete(self.relocation.uuid)
+
+        with patch.object(LostPasswordHash, "send_relocate_account_email") as mock_relocation_email:
+            burst()
+
+        assert mock_relocation_email.call_count == 2
+
+        assert fake_cloudbuild_client.create_build.call_count == 1
+        assert fake_cloudbuild_client.get_build.call_count == 1
+
+        assert fake_kms_client.asymmetric_decrypt.call_count == 2
+        assert fake_kms_client.get_public_key.call_count == 2
+
+        assert fake_message_builder.call_count == 2
+        email_types = [args.kwargs["type"] for args in fake_message_builder.call_args_list]
+        assert "relocation.started" in email_types
+        assert "relocation.succeeded" in email_types
+        assert "relocation.failed" not in email_types
+
+        relocation = Relocation.objects.get(uuid=self.uuid)
+        assert relocation.status == Relocation.Status.SUCCESS.value
+        assert not relocation.failure_reason
+
+        self.assert_success_database_state(org_count)
+
+    def test_valid_max_retries(
+        self,
+        fake_message_builder: Mock,
+        fake_cloudbuild_client: FakeCloudBuildClient,
+        fake_kms_client: FakeKeyManagementServiceClient,
+    ):
+        self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(Build.Status.SUCCESS))
+        self.mock_kms_client(fake_kms_client)
+        self.mock_max_retries(fake_cloudbuild_client, fake_kms_client)
+
+        self.mock_message_builder(fake_message_builder)
+        org_count = Organization.objects.filter(slug__startswith="testing").count()
+
+        with BurstTaskRunner() as burst:
+            uploading_complete(self.relocation.uuid)
+
+        with patch.object(LostPasswordHash, "send_relocate_account_email") as mock_relocation_email:
+            burst()
+
+        assert mock_relocation_email.call_count == 2
+
+        assert fake_cloudbuild_client.create_build.call_count == MAX_FAST_TASK_ATTEMPTS
+        assert fake_cloudbuild_client.get_build.call_count == MAX_VALIDATION_POLL_ATTEMPTS
+
+        assert fake_kms_client.asymmetric_decrypt.call_count == MAX_FAST_TASK_ATTEMPTS + 1
+        assert fake_kms_client.get_public_key.call_count == 2 * MAX_FAST_TASK_ATTEMPTS
+
+        assert fake_message_builder.call_count == 2
+        email_types = [args.kwargs["type"] for args in fake_message_builder.call_args_list]
+        assert "relocation.started" in email_types
+        assert "relocation.succeeded" in email_types
+        assert "relocation.failed" not in email_types
+
+        relocation = Relocation.objects.get(uuid=self.uuid)
+        assert relocation.status == Relocation.Status.SUCCESS.value
+        assert not relocation.failure_reason
+
+        self.assert_success_database_state(org_count)
+
     def test_invalid_no_retries(
         self,
         fake_message_builder: Mock,
@@ -1868,27 +1966,68 @@ def test_invalid_no_retries(
         mock_invalid_finding(self.storage, self.uuid)
         org_count = Organization.objects.filter(slug__startswith="testing").count()

-        with self.tasks(), patch.object(
-            LostPasswordHash, "send_relocate_account_email"
-        ) as mock_relocation_email:
+        with BurstTaskRunner() as burst:
             uploading_complete(self.relocation.uuid)

+        with patch.object(LostPasswordHash, "send_relocate_account_email") as mock_relocation_email:
+            burst()
+
         assert mock_relocation_email.call_count == 0

         assert fake_cloudbuild_client.create_build.call_count == 1
         assert fake_cloudbuild_client.get_build.call_count == 1

         assert fake_kms_client.asymmetric_decrypt.call_count == 1
+        assert fake_kms_client.get_public_key.call_count == 2

         assert fake_message_builder.call_count == 2
-        assert fake_message_builder.call_args_list[0].kwargs["type"] == "relocation.failed"
-        assert fake_message_builder.call_args_list[1].kwargs["type"] == "relocation.started"
+        email_types = [args.kwargs["type"] for args in fake_message_builder.call_args_list]
+        assert "relocation.started" in email_types
+        assert "relocation.failed" in email_types
+        assert "relocation.succeeded" not in email_types
+
+        relocation = Relocation.objects.get(uuid=self.uuid)
+        assert relocation.status == Relocation.Status.FAILURE.value
+        assert relocation.failure_reason
+
+        self.assert_failure_database_state(org_count)
+
+    def test_invalid_max_retries(
+        self,
+        fake_message_builder: Mock,
+        fake_cloudbuild_client: FakeCloudBuildClient,
+        fake_kms_client: FakeKeyManagementServiceClient,
+    ):
+        self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(Build.Status.SUCCESS))
+        self.mock_kms_client(fake_kms_client)
+        self.mock_max_retries(fake_cloudbuild_client, fake_kms_client)
+
+        self.mock_message_builder(fake_message_builder)
+        mock_invalid_finding(self.storage, self.uuid)
+        org_count = Organization.objects.filter(slug__startswith="testing").count()
+
+        with BurstTaskRunner() as burst:
+            uploading_complete(self.relocation.uuid)
+
+        with patch.object(LostPasswordHash, "send_relocate_account_email") as mock_relocation_email:
+            burst()
+
+        assert mock_relocation_email.call_count == 0
+
+        assert fake_cloudbuild_client.create_build.call_count == MAX_FAST_TASK_ATTEMPTS
+        assert fake_cloudbuild_client.get_build.call_count == MAX_VALIDATION_POLL_ATTEMPTS
+
+        assert fake_kms_client.asymmetric_decrypt.call_count == MAX_FAST_TASK_ATTEMPTS
+        assert fake_kms_client.get_public_key.call_count == 2 * MAX_FAST_TASK_ATTEMPTS
+
+        assert fake_message_builder.call_count == 2
+        email_types = [args.kwargs["type"] for args in fake_message_builder.call_args_list]
+        assert "relocation.started" in email_types
+        assert "relocation.failed" in email_types
email_types + assert "relocation.succeeded" not in email_types relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.FAILURE.value assert relocation.failure_reason - assert Organization.objects.filter(slug__startswith="testing").count() == org_count - # TODO(getsentry/team-ospo#190): We should add "max retry" tests as well, but these are quite - # hard to mock in celery at the moment. We may need to use the mock sync celery test scheduler, - # rather than the "self.tasks()" approach above, to accomplish this. + self.assert_failure_database_state(org_count)
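
The interplay above is easiest to see in isolation. Below is a minimal, self-contained sketch of the retry loop this patch wires together; `RetryError`, `fake_client`, and `MAX_RETRIES` are illustrative stand-ins for `BustTaskRunnerRetryError`, the fake CloudBuild/KMS clients, and `MAX_FAST_TASK_RETRIES`, so only the pattern (not the exact code) matches the patch:

# Standalone sketch of the retry flow exercised by the tests above; all names
# here are illustrative stand-ins, not part of the patch.
from unittest.mock import Mock


class RetryError(Exception):
    """Stand-in for BustTaskRunnerRetryError."""


MAX_RETRIES = 3

# A side_effect list makes the mock raise on the first MAX_RETRIES calls and
# return normally on the last one, mirroring how mock_max_retries() primes the
# fake clients.
fake_client = Mock()
fake_client.create_build.side_effect = [RetryError("retry")] * MAX_RETRIES + ["done"]

# A miniature BurstTaskRunner.work() loop: a retryable failure re-queues the
# job at the back of the queue instead of propagating out of the runner.
job_queue = [(fake_client.create_build, (), {})]
while job_queue:
    task, args, kwargs = job_queue.pop(0)
    try:
        task(*args, **kwargs)
    except RetryError:
        job_queue.append((task, args, kwargs))

# MAX_RETRIES failures plus one success; the tests express the same sum as
# MAX_FAST_TASK_ATTEMPTS == MAX_FAST_TASK_RETRIES + 1.
assert fake_client.create_build.call_count == MAX_RETRIES + 1

Because the final `side_effect` entry is a return value rather than an exception, the job eventually drains from the queue, which is what lets the `*_max_retries` tests assert both the attempt counts and the terminal `SUCCESS`/`FAILURE` state.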
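
On the impl.py side, the `setval` call is the piece worth a note: imported rows keep their source primary keys, so the table's id sequence must be advanced past MAX(id), or the next organically inserted row would collide with an imported one. A sketch of the query shape, using a hypothetical table name:

# Sketch of the sequence-advance trick from the impl.py hunk; the table name
# "sentry_organization" is hypothetical, chosen only for illustration.
table = "sentry_organization"
seq = f"{table}_id_seq"

# The sequence name travels as a bind parameter (%s), but the table name must
# be interpolated into the SQL text, because identifiers cannot be
# parametrized; that is why the patch mixes %s with an f-string.
query = f"SELECT setval(%s, (SELECT MAX(id) FROM {table}))"

# With a Django connection this would be executed as:
#     with connections[using].cursor() as cursor:
#         cursor.execute(query, [seq])
print(query, [seq])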