From 9a749552582c544d1f3f6b50f3f3851cdea11e12 Mon Sep 17 00:00:00 2001 From: Alex Zaslavsky Date: Fri, 10 Nov 2023 11:28:19 -0800 Subject: [PATCH] feat(backup): Support import chunking With this feature in place, we now atomically record which models we imported in a given `import_by_model` call. This will be useful in the short term for implementing the post-processing import step, and in the long term to support rollbacks and partial import recovery. Issue: getsentry/team-ospo#203 Issue: getsentry/team-ospo#213 --- src/sentry/backup/dependencies.py | 33 +- src/sentry/backup/helpers.py | 6 + src/sentry/backup/imports.py | 62 +++- .../hybrid_cloud/import_export/impl.py | 123 ++++++- .../hybrid_cloud/import_export/model.py | 29 +- src/sentry/tasks/relocation.py | 4 +- tests/sentry/backup/test_imports.py | 317 +++++++++++++++++- tests/sentry/backup/test_rpc.py | 169 +++++++++- tests/sentry/tasks/test_relocation.py | 122 ++++--- 9 files changed, 773 insertions(+), 92 deletions(-) diff --git a/src/sentry/backup/dependencies.py b/src/sentry/backup/dependencies.py index 7710e23e29d25e..5a12c468bf4384 100644 --- a/src/sentry/backup/dependencies.py +++ b/src/sentry/backup/dependencies.py @@ -17,7 +17,9 @@ class NormalizedModelName: """ - A wrapper type that ensures that the contained model name has been properly normalized. A "normalized" model name is one that is identical to the name as it appears in an exported JSON backup, so a string of the form `{app_label.lower()}.{model_name.lower()}`. + A wrapper type that ensures that the contained model name has been properly normalized. A + "normalized" model name is one that is identical to the name as it appears in an exported JSON + backup, so a string of the form `{app_label.lower()}.{model_name.lower()}`. """ __model_name: str @@ -252,6 +254,13 @@ class PrimaryKeyMap: def __init__(self): self.mapping = defaultdict(dict) + def __len__(self): + count = 0 + for model_name_str, mappings in self.mapping.items(): + count += len(mappings) + + return count + def get_pk(self, model_name: NormalizedModelName, old: int) -> Optional[int]: """ Get the new, post-mapping primary key from an old primary key. @@ -276,7 +285,8 @@ def get_pks(self, model_name: NormalizedModelName) -> set[int]: def get_kind(self, model_name: NormalizedModelName, old: int) -> Optional[ImportKind]: """ - Is the mapped entry a newly inserted model, or an already existing one that has been merged in? + Is the mapped entry a newly inserted model, or an already existing one that has been merged + in? """ pk_map = self.mapping.get(str(model_name)) @@ -313,7 +323,8 @@ def insert( slug: str | None = None, ) -> None: """ - Create a new OLD_PK -> NEW_PK mapping for the given model. Models that contain unique slugs (organizations, projects, etc) can optionally store that information as well. + Create a new OLD_PK -> NEW_PK mapping for the given model. Models that contain unique slugs + (organizations, projects, etc) can optionally store that information as well. """ self.mapping[str(model_name)][old] = (new, kind, slug) @@ -327,18 +338,25 @@ def extend(self, other: PrimaryKeyMap) -> None: for old_pk, new_entry in mappings.items(): self.mapping[model_name_str][old_pk] = new_entry - def partition(self, model_names: set[NormalizedModelName]) -> PrimaryKeyMap: + def partition( + self, model_names: set[NormalizedModelName], kinds: set[ImportKind] | None = None + ) -> PrimaryKeyMap: """ - Create a new map with only the specified model kinds retained. + Create a new map with only the specified models and kinds retained. """ building = PrimaryKeyMap() + import_kinds = {k for k in ImportKind} if kinds is None else kinds for model_name_str, mappings in self.mapping.items(): model_name = NormalizedModelName(model_name_str) if model_name not in model_names: continue for old_pk, new_entry in mappings.items(): + (_, import_kind, _) = new_entry + if import_kind not in import_kinds: + continue + building.mapping[model_name_str][old_pk] = new_entry return building @@ -347,7 +365,10 @@ def partition(self, model_names: set[NormalizedModelName]) -> PrimaryKeyMap: # No arguments, so we lazily cache the result after the first calculation. @lru_cache(maxsize=1) def dependencies() -> dict[NormalizedModelName, ModelRelations]: - """Produce a dictionary mapping model type definitions to a `ModelDeps` describing their dependencies.""" + """ + Produce a dictionary mapping model type definitions to a `ModelDeps` describing their + dependencies. + """ from django.apps import apps diff --git a/src/sentry/backup/helpers.py b/src/sentry/backup/helpers.py index e343d675f8e93f..3b4f37377a8e3f 100644 --- a/src/sentry/backup/helpers.py +++ b/src/sentry/backup/helpers.py @@ -432,3 +432,9 @@ class ImportFlags(NamedTuple): # `key`) or `Relay` (as identified by its unique `relay_id`) already exists, should we overwrite # it with the new value, or keep the existing one and discard the incoming value instead? overwrite_configs: bool = False + + # A UUID with which to identify this import's `*ImportChunk` database entries. Useful for + # passing the calling `Relocation` model's UUID to all of the imports it triggered. If this flag + # is not provided, the import was called in a non-relocation context, like from the `sentry + # import` CLI command. + import_uuid: str | None = None diff --git a/src/sentry/backup/imports.py b/src/sentry/backup/imports.py index 023f46de1b8913..1b92aecb2eea70 100644 --- a/src/sentry/backup/imports.py +++ b/src/sentry/backup/imports.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import BinaryIO, Iterator, Optional, Tuple, Type +from uuid import uuid4 import click from django.core import serializers @@ -8,6 +9,7 @@ from django.db.models.base import Model from sentry.backup.dependencies import ( + ImportKind, NormalizedModelName, PrimaryKeyMap, dependencies, @@ -15,6 +17,7 @@ ) from sentry.backup.helpers import Decryptor, Filter, ImportFlags, decrypt_encrypted_tarball from sentry.backup.scopes import ImportScope +from sentry.models.importchunk import ControlImportChunkReplica from sentry.models.orgauthtoken import OrgAuthToken from sentry.services.hybrid_cloud.import_export.model import ( RpcFilter, @@ -73,6 +76,10 @@ def _import( raise RuntimeError(errText) flags = flags if flags is not None else ImportFlags() + if flags.import_uuid is None: + flags = flags._replace(import_uuid=uuid4().hex) + + deps = dependencies() user_model_name = get_model_name(User) org_auth_token_model_name = get_model_name(OrgAuthToken) org_member_model_name = get_model_name(OrganizationMember) @@ -107,8 +114,8 @@ def _import( if filter_by is not None: filters.append(filter_by) - # `sentry.Email` models don't have any explicit dependencies on `User`, so we need to find - # and record them manually. + # `sentry.Email` models don't have any explicit dependencies on `sentry.User`, so we need to + # find and record them manually. user_to_email = dict() if filter_by.model == Organization: @@ -199,14 +206,17 @@ def yield_json_models(content) -> Iterator[Tuple[NormalizedModelName, str]]: def do_write( pk_map: PrimaryKeyMap, model_name: NormalizedModelName, json_data: json.JSONData ) -> None: - model_relations = dependencies().get(model_name) + nonlocal scope, flags, filters, deps + + model_relations = deps.get(model_name) if not model_relations: return dep_models = {get_model_name(d) for d in model_relations.get_dependencies_for_relocation()} import_by_model = ImportExportService.get_importer_for_model(model_relations.model) + model_name_str = str(model_name) result = import_by_model( - model_name=str(model_name), + model_name=model_name_str, scope=RpcImportScope.into_rpc(scope), flags=RpcImportFlags.into_rpc(flags), filter_by=[RpcFilter.into_rpc(f) for f in filters], @@ -220,15 +230,55 @@ def do_write( warningText = ">> Are you restoring from a backup of the same version of Sentry?\n>> Are you restoring onto a clean database?\n>> If so then this IntegrityError might be our fault, you can open an issue here:\n>> https://github.com/getsentry/sentry/issues/new/choose" printer(warningText, err=True) raise ImportingError(result) - pk_map.extend(result.mapped_pks) + + out_pk_map: PrimaryKeyMap = result.mapped_pks.from_rpc() + pk_map.extend(out_pk_map) + + # If the model we just imported lives in the control silo, that means the import took place + # over RPC. To ensure that we have an accurate view of the import result in both sides of + # the RPC divide, we create a replica of the `ControlImportChunk` that successful import + # would have generated in the calling region as well. + if result.min_ordinal is not None and SiloMode.CONTROL in deps[model_name].silos: + # If `min_ordinal` is not null, these values must not be either. + assert result.max_ordinal is not None + assert result.min_source_pk is not None + assert result.max_source_pk is not None + + inserted = out_pk_map.partition({model_name}, {ImportKind.Inserted}).mapping[ + model_name_str + ] + existing = out_pk_map.partition({model_name}, {ImportKind.Existing}).mapping[ + model_name_str + ] + overwrite = out_pk_map.partition({model_name}, {ImportKind.Overwrite}).mapping[ + model_name_str + ] + control_import_chunk_replica = ControlImportChunkReplica( + import_uuid=flags.import_uuid, + model=model_name_str, + # TODO(getsentry/team-ospo#190): The next two fields assume the entire model is + # being imported in a single call; we may change this in the future. + min_ordinal=result.min_ordinal, + max_ordinal=result.max_ordinal, + min_source_pk=result.min_source_pk, + max_source_pk=result.max_source_pk, + min_inserted_pk=result.min_inserted_pk, + max_inserted_pk=result.max_inserted_pk, + inserted_map={k: v[0] for k, v in inserted.items()}, + existing_map={k: v[0] for k, v in existing.items()}, + overwrite_map={k: v[0] for k, v in overwrite.items()}, + inserted_identifiers={k: v[2] for k, v in inserted.items() if v[2] is not None}, + ) + control_import_chunk_replica.save() # Extract some write logic into its own internal function, so that we may call it irrespective # of how we do atomicity: on a per-model (if using multiple dbs) or global (if using a single # db) basis. def do_writes(pk_map: PrimaryKeyMap) -> None: + nonlocal deferred_org_auth_tokens + for model_name, json_data in yield_json_models(content): if model_name == org_auth_token_model_name: - nonlocal deferred_org_auth_tokens deferred_org_auth_tokens = json_data continue diff --git a/src/sentry/services/hybrid_cloud/import_export/impl.py b/src/sentry/services/hybrid_cloud/import_export/impl.py index e999ca4b19cc82..12651a9ea0ebbf 100644 --- a/src/sentry/services/hybrid_cloud/import_export/impl.py +++ b/src/sentry/services/hybrid_cloud/import_export/impl.py @@ -11,6 +11,7 @@ from django.core.serializers.base import DeserializationError from django.db import DatabaseError, IntegrityError, connections, router, transaction from django.db.models import Q +from django.forms import model_to_dict from rest_framework.serializers import ValidationError as DjangoRestFrameworkValidationError from sentry.backup.dependencies import ( @@ -24,6 +25,7 @@ from sentry.backup.findings import InstanceID from sentry.backup.helpers import EXCLUDED_APPS, DatetimeSafeDjangoJSONEncoder, Filter from sentry.backup.scopes import ExportScope +from sentry.models.importchunk import ControlImportChunk, RegionImportChunk from sentry.models.user import User from sentry.models.userpermission import UserPermission from sentry.models.userrole import UserRoleUser @@ -71,7 +73,7 @@ def import_by_model( pk_map: RpcPrimaryKeyMap, json_data: str, ) -> RpcImportResult: - import_flags = flags.from_rpc() + deps = dependencies() batch_model_name = NormalizedModelName(model_name) model = get_model(batch_model_name) if model is None: @@ -97,6 +99,14 @@ def import_by_model( reason="The RPC was called incorrectly, please set an `ImportScope` parameter", ) + import_flags = flags.from_rpc() + if import_flags.import_uuid is None: + return RpcImportError( + kind=RpcImportErrorKind.MissingImportUUID, + on=InstanceID(model_name), + reason="Must specify `import_uuid` when importing", + ) + import_scope = scope.from_rpc() in_pk_map = pk_map.from_rpc() filters: List[Filter] = [] @@ -107,12 +117,52 @@ def import_by_model( try: using = router.db_for_write(model) with transaction.atomic(using=using): + # It's possible that this write has already occurred, and we are simply retrying + # because the response got lost in transit. If so, just re-use that reply. We do + # this in the transaction because, while `import_by_model` is generally called in a + # sequential manner, cases like timeouts or long queues may cause a previous call to + # still be active when the next one is made. Doing this check inside the transaction + # lock ensures that the data is globally accurate and thwarts data races. + found_chunk = ( + ( + ControlImportChunk + if SiloMode.CONTROL in deps[batch_model_name].silos + else RegionImportChunk + ) + .objects.filter(import_uuid=flags.import_uuid, model=model_name) + .first() + ) + if found_chunk is not None: + found_data = model_to_dict(found_chunk) + out_pk_map = PrimaryKeyMap() + for old_pk, new_pk in found_data["inserted_map"].items(): + identifier = found_data["inserted_identifiers"].get(new_pk, None) + out_pk_map.insert( + batch_model_name, old_pk, new_pk, ImportKind.Inserted, identifier + ) + for old_pk, new_pk in found_data["existing_map"].items(): + out_pk_map.insert(batch_model_name, old_pk, new_pk, ImportKind.Existing) + for old_pk, new_pk in found_data["overwrite_map"].items(): + out_pk_map.insert(batch_model_name, old_pk, new_pk, ImportKind.Overwrite) + + return RpcImportOk( + mapped_pks=RpcPrimaryKeyMap.into_rpc(out_pk_map), + min_ordinal=found_data["min_ordinal"], + max_ordinal=found_data["max_ordinal"], + min_source_pk=found_data["min_source_pk"], + max_source_pk=found_data["max_source_pk"], + min_inserted_pk=found_data["min_inserted_pk"], + max_inserted_pk=found_data["max_inserted_pk"], + ) + ok_relocation_scopes = import_scope.value out_pk_map = PrimaryKeyMap() - max_pk = 0 + min_old_pk = 0 + max_old_pk = 0 + min_inserted_pk: Optional[int] = None + max_inserted_pk: Optional[int] = None counter = 0 for deserialized_object in deserialize("json", json_data, use_natural_keys=False): - counter += 1 model_instance = deserialized_object.object if model_instance._meta.app_label not in EXCLUDED_APPS or model_instance: if model_instance.get_possible_relocation_scopes() & ok_relocation_scopes: @@ -159,6 +209,7 @@ def import_by_model( # For models that may have circular references to themselves # (unlikely), keep track of the new pk in the input map as well. + counter += 1 new_pk, import_kind = written slug = getattr(model_instance, "slug", None) in_pk_map.insert( @@ -167,8 +218,17 @@ def import_by_model( out_pk_map.insert( inst_model_name, old_pk, new_pk, import_kind, slug ) - if new_pk > max_pk: - max_pk = new_pk + + # Do a little bit of book-keeping for our future `ImportChunk`. + if min_old_pk == 0: + min_old_pk = old_pk + if old_pk > max_old_pk: + max_old_pk = old_pk + if import_kind == ImportKind.Inserted: + if min_inserted_pk is None: + min_inserted_pk = new_pk + if max_inserted_pk is None or new_pk > max_inserted_pk: + max_inserted_pk = new_pk except DjangoValidationError as e: errs = {field: error for field, error in e.message_dict.items()} @@ -187,17 +247,54 @@ def import_by_model( reason=str(e), ) - # If we wrote at least one model, make sure to update the sequences too. - if counter > 0: - table = model_instance._meta.db_table - seq = f"{table}_id_seq" - with connections[using].cursor() as cursor: - cursor.execute(f"SELECT setval(%s, (SELECT MAX(id) FROM {table}))", [seq]) + # If we wrote at least one model, make sure to write an appropriate `ImportChunk` + # and update the sequences too. + if counter > 0: + table = model_instance._meta.db_table + seq = f"{table}_id_seq" + with connections[using].cursor() as cursor: + cursor.execute(f"SELECT setval(%s, (SELECT MAX(id) FROM {table}))", [seq]) + + inserted = out_pk_map.partition( + {batch_model_name}, {ImportKind.Inserted} + ).mapping[model_name] + existing = out_pk_map.partition( + {batch_model_name}, {ImportKind.Existing} + ).mapping[model_name] + overwrite = out_pk_map.partition( + {batch_model_name}, {ImportKind.Overwrite} + ).mapping[model_name] + import_chunk_args = { + "import_uuid": flags.import_uuid, + "model": model_name, + # TODO(getsentry/team-ospo#190): The next two fields assume the entire model + # is being imported in a single call; we may change this in the future. + "min_ordinal": 1, + "max_ordinal": counter, + "min_source_pk": min_old_pk, + "max_source_pk": max_old_pk, + "min_inserted_pk": min_inserted_pk, + "max_inserted_pk": max_inserted_pk, + "inserted_map": {k: v[0] for k, v in inserted.items()}, + "existing_map": {k: v[0] for k, v in existing.items()}, + "overwrite_map": {k: v[0] for k, v in overwrite.items()}, + "inserted_identifiers": { + k: v[2] for k, v in inserted.items() if v[2] is not None + }, + } + if SiloMode.CONTROL in deps[batch_model_name].silos: + ControlImportChunk(**import_chunk_args).save() + else: + RegionImportChunk(**import_chunk_args).save() return RpcImportOk( mapped_pks=RpcPrimaryKeyMap.into_rpc(out_pk_map), - max_pk=max_pk, - num_imported=counter, + min_ordinal=1, + max_ordinal=counter, + min_source_pk=min_old_pk, + max_source_pk=max_old_pk, + min_inserted_pk=min_inserted_pk, + max_inserted_pk=max_inserted_pk, ) except DeserializationError: diff --git a/src/sentry/services/hybrid_cloud/import_export/model.py b/src/sentry/services/hybrid_cloud/import_export/model.py index 1c635cf151eb49..f38a204135d654 100644 --- a/src/sentry/services/hybrid_cloud/import_export/model.py +++ b/src/sentry/services/hybrid_cloud/import_export/model.py @@ -55,7 +55,11 @@ def into_rpc(cls, base_filter: Filter) -> "RpcFilter": class RpcPrimaryKeyMap(RpcModel): """ - Shadows `sentry.backup.dependencies.PrimaryKeyMap` for the purpose of passing it over an RPC boundary. The primary difference between this class and the one it shadows is that the original `PrimaryKeyMap` uses `defaultdict` for ergonomics purposes, whereas this one uses a regular dict but provides no mutation methods - it is only intended for data interchange, and should be converted to and from `PrimaryKeyMap` immediately on either side of the RPC call. + Shadows `sentry.backup.dependencies.PrimaryKeyMap` for the purpose of passing it over an RPC + boundary. The primary difference between this class and the one it shadows is that the original + `PrimaryKeyMap` uses `defaultdict` for ergonomics purposes, whereas this one uses a regular dict + but provides no mutation methods - it is only intended for data interchange, and should be + converted to and from `PrimaryKeyMap` immediately on either side of the RPC call. """ # Pydantic duplicates global default models on a per-instance basis, so using `{}` here is safe. @@ -98,14 +102,21 @@ class RpcImportFlags(RpcModel): merge_users: bool = False overwrite_configs: bool = False + import_uuid: Optional[str] = None def from_rpc(self) -> ImportFlags: - return ImportFlags(merge_users=self.merge_users, overwrite_configs=self.overwrite_configs) + return ImportFlags( + merge_users=self.merge_users, + overwrite_configs=self.overwrite_configs, + import_uuid=self.import_uuid, + ) @classmethod def into_rpc(cls, base_flags: ImportFlags) -> "RpcImportFlags": return cls( - merge_users=base_flags.merge_users, overwrite_configs=base_flags.overwrite_configs + merge_users=base_flags.merge_users, + overwrite_configs=base_flags.overwrite_configs, + import_uuid=base_flags.import_uuid, ) @@ -119,6 +130,7 @@ class RpcImportErrorKind(str, Enum): DeserializationFailed = "DeserializationFailed" IncorrectSiloModeForModel = "IncorrectSiloModeForModel" IntegrityError = "IntegrityError" + MissingImportUUID = "MissingImportUUID" UnknownModel = "UnknownModel" UnexpectedModel = "UnexpectedModel" UnspecifiedScope = "UnspecifiedScope" @@ -159,8 +171,12 @@ class RpcImportOk(RpcModel): is_err: Literal[False] = False mapped_pks: RpcPrimaryKeyMap - max_pk: int = 0 - num_imported: int = 0 + min_ordinal: Optional[int] = None + max_ordinal: Optional[int] = None + min_source_pk: Optional[int] = None + max_source_pk: Optional[int] = None + min_inserted_pk: Optional[int] = None + max_inserted_pk: Optional[int] = None RpcImportResult = Annotated[Union[RpcImportOk, RpcImportError], Field(discriminator="is_err")] @@ -168,7 +184,8 @@ class RpcImportOk(RpcModel): class RpcExportScope(str, Enum): """ - Scope values are rendered as strings for JSON interchange, but can easily be mapped back to their set-based values when necessary. + Scope values are rendered as strings for JSON interchange, but can easily be mapped back to + their set-based values when necessary. """ User = "User" diff --git a/src/sentry/tasks/relocation.py b/src/sentry/tasks/relocation.py index fcc23f9dcc0500..8e5f3892aff6f6 100644 --- a/src/sentry/tasks/relocation.py +++ b/src/sentry/tasks/relocation.py @@ -940,7 +940,9 @@ def printer(text: str, *, err: bool = False, **kwargs) -> None: import_in_organization_scope( relocation_data_fp, decryptor=GCPKMSDecryptor(kms_config_fp), - flags=ImportFlags(merge_users=False, overwrite_configs=False), + flags=ImportFlags( + merge_users=False, overwrite_configs=False, import_uuid=str(uuid) + ), org_filter=set(relocation.want_org_slugs), printer=printer, ) diff --git a/tests/sentry/backup/test_imports.py b/tests/sentry/backup/test_imports.py index c06326a99b8258..193a1ceae5c9f6 100644 --- a/tests/sentry/backup/test_imports.py +++ b/tests/sentry/backup/test_imports.py @@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import padding from django.utils import timezone -from sentry.backup.dependencies import NormalizedModelName +from sentry.backup.dependencies import NormalizedModelName, get_model_name from sentry.backup.helpers import ImportFlags, LocalFileDecryptor from sentry.backup.imports import ( ImportingError, @@ -29,6 +29,11 @@ from sentry.models.apitoken import DEFAULT_EXPIRATION, ApiToken, generate_token from sentry.models.authenticator import Authenticator from sentry.models.email import Email +from sentry.models.importchunk import ( + ControlImportChunk, + ControlImportChunkReplica, + RegionImportChunk, +) from sentry.models.lostpasswordhash import LostPasswordHash from sentry.models.options.option import ControlOption, Option from sentry.models.options.project_option import ProjectOption @@ -269,9 +274,18 @@ def test_generate_suffix_for_already_taken_organization(self): assert Organization.objects.count() == 2 assert Organization.objects.filter(slug__icontains="some-org").count() == 2 assert Organization.objects.filter(slug__iexact="some-org").count() == 1 + imported_organization = Organization.objects.get(slug__icontains="some-org-") assert imported_organization.id != existing_org.id + org_chunk = RegionImportChunk.objects.get( + model="sentry.organization", min_ordinal=1, max_ordinal=1 + ) + assert len(org_chunk.inserted_map) == 1 + assert len(org_chunk.inserted_identifiers) == 1 + for slug in org_chunk.inserted_identifiers.values(): + assert slug.startswith("some-org-") + with assume_test_silo_mode(SiloMode.CONTROL): assert ( OrganizationSlugReservation.objects.filter( @@ -607,6 +621,8 @@ class ScopingTests(ImportTestCase): def verify_model_inclusion(scope: ImportScope): """ Ensure all in-scope models are included, and that no out-of-scope models are included. + Additionally, we verify that each such model had an appropriate `*ImportChunk` written out + atomically alongside it. """ included_models = get_matching_exportable_models( lambda mr: len(mr.get_possible_relocation_scopes() & scope.value) > 0 @@ -617,11 +633,29 @@ def verify_model_inclusion(scope: ImportScope): ) for model in included_models: + model_name_str = str(get_model_name(model)) if is_control_model(model): + replica = ControlImportChunkReplica.objects.filter(model=model_name_str).first() + assert replica is not None + with assume_test_silo_mode(SiloMode.CONTROL): assert model.objects.count() > 0 + + control = ControlImportChunk.objects.filter(model=model_name_str).first() + assert control is not None + + # Ensure that the region-silo replica and the control-silo original are + # identical. + common_fields = {f.name for f in ControlImportChunk._meta.get_fields()} - { + "id", + "date_added", + "date_updated", + } + for field in common_fields: + assert getattr(replica, field, None) == getattr(control, field, None) else: assert model.objects.count() > 0 + assert RegionImportChunk.objects.filter(model=model_name_str).count() == 1 for model in excluded_models: if is_control_model(model): @@ -639,6 +673,12 @@ def test_user_import_scoping(self): import_in_user_scope(tmp_file, printer=NOOP_PRINTER) self.verify_model_inclusion(ImportScope.User) + # Test that the import UUID is auto-assigned properly. + with assume_test_silo_mode(SiloMode.CONTROL): + assert ControlImportChunk.objects.values("import_uuid").distinct().count() == 1 + + assert ControlImportChunkReplica.objects.values("import_uuid").distinct().count() == 1 + def test_organization_import_scoping(self): self.create_exhaustive_instance(is_superadmin=True) @@ -648,6 +688,17 @@ def test_organization_import_scoping(self): import_in_organization_scope(tmp_file, printer=NOOP_PRINTER) self.verify_model_inclusion(ImportScope.Organization) + # Test that the import UUID is auto-assigned properly. + with assume_test_silo_mode(SiloMode.CONTROL): + assert ControlImportChunk.objects.values("import_uuid").distinct().count() == 1 + + assert ControlImportChunkReplica.objects.values("import_uuid").distinct().count() == 1 + assert RegionImportChunk.objects.values("import_uuid").distinct().count() == 1 + assert ( + ControlImportChunkReplica.objects.values("import_uuid").first() + == RegionImportChunk.objects.values("import_uuid").first() + ) + def test_config_import_scoping(self): self.create_exhaustive_instance(is_superadmin=True) @@ -657,6 +708,17 @@ def test_config_import_scoping(self): import_in_config_scope(tmp_file, printer=NOOP_PRINTER) self.verify_model_inclusion(ImportScope.Config) + # Test that the import UUID is auto-assigned properly. + with assume_test_silo_mode(SiloMode.CONTROL): + assert ControlImportChunk.objects.values("import_uuid").distinct().count() == 1 + + assert ControlImportChunkReplica.objects.values("import_uuid").distinct().count() == 1 + assert RegionImportChunk.objects.values("import_uuid").distinct().count() == 1 + assert ( + ControlImportChunkReplica.objects.values("import_uuid").first() + == RegionImportChunk.objects.values("import_uuid").first() + ) + def test_global_import_scoping(self): self.create_exhaustive_instance(is_superadmin=True) @@ -666,6 +728,17 @@ def test_global_import_scoping(self): import_in_global_scope(tmp_file, printer=NOOP_PRINTER) self.verify_model_inclusion(ImportScope.Global) + # Test that the import UUID is auto-assigned properly. + with assume_test_silo_mode(SiloMode.CONTROL): + assert ControlImportChunk.objects.values("import_uuid").distinct().count() == 1 + + assert ControlImportChunkReplica.objects.values("import_uuid").distinct().count() == 1 + assert RegionImportChunk.objects.values("import_uuid").distinct().count() == 1 + assert ( + ControlImportChunkReplica.objects.values("import_uuid").first() + == RegionImportChunk.objects.values("import_uuid").first() + ) + # Filters should work identically in both silo and monolith modes, so no need to repeat the tests # here. @@ -825,6 +898,31 @@ def test_import_filter_users(self): assert UserEmail.objects.count() == 1 assert Email.objects.count() == 1 + assert ( + ControlImportChunk.objects.filter( + model="sentry.user", min_ordinal=1, max_ordinal=1 + ).count() + == 1 + ) + assert ( + ControlImportChunk.objects.filter( + model="sentry.userip", min_ordinal=1, max_ordinal=1 + ).count() + == 1 + ) + assert ( + ControlImportChunk.objects.filter( + model="sentry.useremail", min_ordinal=1, max_ordinal=1 + ).count() + == 1 + ) + assert ( + ControlImportChunk.objects.filter( + model="sentry.email", min_ordinal=1, max_ordinal=1 + ).count() + == 1 + ) + assert not User.objects.filter(username="user_1").exists() assert User.objects.filter(username="user_2").exists() @@ -847,6 +945,31 @@ def test_export_filter_users_shared_email(self): assert UserEmail.objects.count() == 3 assert Email.objects.count() == 2 # Lower due to shared emails + assert ( + ControlImportChunk.objects.filter( + model="sentry.user", min_ordinal=1, max_ordinal=3 + ).count() + == 1 + ) + assert ( + ControlImportChunk.objects.filter( + model="sentry.userip", min_ordinal=1, max_ordinal=3 + ).count() + == 1 + ) + assert ( + ControlImportChunk.objects.filter( + model="sentry.useremail", min_ordinal=1, max_ordinal=3 + ).count() + == 1 + ) + assert ( + ControlImportChunk.objects.filter( + model="sentry.email", min_ordinal=1, max_ordinal=2 + ).count() + == 1 + ) + assert User.objects.filter(username="user_1").exists() assert User.objects.filter(username="user_2").exists() assert User.objects.filter(username="user_3").exists() @@ -884,6 +1007,12 @@ def test_import_filter_orgs_single(self): import_in_organization_scope(tmp_file, org_filter={"org-b"}, printer=NOOP_PRINTER) assert Organization.objects.count() == 1 + assert ( + RegionImportChunk.objects.filter( + model="sentry.organization", min_ordinal=1, max_ordinal=1 + ).count() + == 1 + ) assert not Organization.objects.filter(slug="org-a").exists() assert Organization.objects.filter(slug="org-b").exists() @@ -923,6 +1052,12 @@ def test_import_filter_orgs_multiple(self): ) assert Organization.objects.count() == 2 + assert ( + RegionImportChunk.objects.filter( + model="sentry.organization", min_ordinal=1, max_ordinal=2 + ).count() + == 1 + ) assert Organization.objects.filter(slug="org-a").exists() assert not Organization.objects.filter(slug="org-b").exists() @@ -930,6 +1065,12 @@ def test_import_filter_orgs_multiple(self): with assume_test_silo_mode(SiloMode.CONTROL): assert OrgAuthToken.objects.count() == 2 + assert ( + ControlImportChunk.objects.filter( + model="sentry.orgauthtoken", min_ordinal=1, max_ordinal=2 + ).count() + == 1 + ) assert User.objects.count() == 5 assert UserIP.objects.count() == 5 @@ -1289,16 +1430,40 @@ def test_colliding_configs_overwrite_configs_enabled_in_config_scope(self): tmp_file, flags=ImportFlags(overwrite_configs=True), printer=NOOP_PRINTER ) + option_chunk = RegionImportChunk.objects.get( + model="sentry.option", min_ordinal=1, max_ordinal=1 + ) + assert len(option_chunk.inserted_map) == 0 + assert len(option_chunk.existing_map) == 0 + assert len(option_chunk.overwrite_map) == 1 assert Option.objects.count() == 1 assert Option.objects.filter(value__exact="a").exists() + relay_chunk = RegionImportChunk.objects.get( + model="sentry.relay", min_ordinal=1, max_ordinal=1 + ) + assert len(relay_chunk.inserted_map) == 0 + assert len(relay_chunk.existing_map) == 0 + assert len(relay_chunk.overwrite_map) == 1 assert Relay.objects.count() == 1 assert Relay.objects.filter(public_key__exact=old_relay_public_key).exists() + relay_usage_chunk = RegionImportChunk.objects.get( + model="sentry.relayusage", min_ordinal=1, max_ordinal=1 + ) + assert len(relay_usage_chunk.inserted_map) == 0 + assert len(relay_usage_chunk.existing_map) == 0 + assert len(relay_usage_chunk.overwrite_map) == 1 assert RelayUsage.objects.count() == 1 assert RelayUsage.objects.filter(public_key__exact=old_relay_usage_public_key).exists() with assume_test_silo_mode(SiloMode.CONTROL): + control_option_chunk = ControlImportChunk.objects.get( + model="sentry.controloption", min_ordinal=1, max_ordinal=1 + ) + assert len(control_option_chunk.inserted_map) == 0 + assert len(control_option_chunk.existing_map) == 0 + assert len(control_option_chunk.overwrite_map) == 1 assert ControlOption.objects.count() == 1 assert ControlOption.objects.filter(value__exact="b").exists() @@ -1355,16 +1520,40 @@ def test_colliding_configs_overwrite_configs_disabled_in_config_scope(self): tmp_file, flags=ImportFlags(overwrite_configs=False), printer=NOOP_PRINTER ) + option_chunk = RegionImportChunk.objects.get( + model="sentry.option", min_ordinal=1, max_ordinal=1 + ) + assert len(option_chunk.inserted_map) == 0 + assert len(option_chunk.existing_map) == 1 + assert len(option_chunk.overwrite_map) == 0 assert Option.objects.count() == 1 assert Option.objects.filter(value__exact="y").exists() + relay_chunk = RegionImportChunk.objects.get( + model="sentry.relay", min_ordinal=1, max_ordinal=1 + ) + assert len(relay_chunk.inserted_map) == 0 + assert len(relay_chunk.existing_map) == 1 + assert len(relay_chunk.overwrite_map) == 0 assert Relay.objects.count() == 1 assert Relay.objects.filter(public_key__exact="invalid").exists() + relay_usage_chunk = RegionImportChunk.objects.get( + model="sentry.relayusage", min_ordinal=1, max_ordinal=1 + ) + assert len(relay_usage_chunk.inserted_map) == 0 + assert len(relay_usage_chunk.existing_map) == 1 + assert len(relay_usage_chunk.overwrite_map) == 0 assert RelayUsage.objects.count() == 1 assert RelayUsage.objects.filter(public_key__exact="invalid").exists() with assume_test_silo_mode(SiloMode.CONTROL): + control_option_chunk = ControlImportChunk.objects.get( + model="sentry.controloption", min_ordinal=1, max_ordinal=1 + ) + assert len(control_option_chunk.inserted_map) == 0 + assert len(control_option_chunk.existing_map) == 1 + assert len(control_option_chunk.overwrite_map) == 0 assert ControlOption.objects.count() == 1 assert ControlOption.objects.filter(value__exact="z").exists() @@ -1425,16 +1614,40 @@ def test_colliding_configs_overwrite_configs_enabled_in_global_scope(self): tmp_file, flags=ImportFlags(overwrite_configs=True), printer=NOOP_PRINTER ) + option_chunk = RegionImportChunk.objects.get( + model="sentry.option", min_ordinal=1, max_ordinal=1 + ) + assert len(option_chunk.inserted_map) == 0 + assert len(option_chunk.existing_map) == 0 + assert len(option_chunk.overwrite_map) == 1 assert Option.objects.count() == 1 assert Option.objects.filter(value__exact="a").exists() + relay_chunk = RegionImportChunk.objects.get( + model="sentry.relay", min_ordinal=1, max_ordinal=1 + ) + assert len(relay_chunk.inserted_map) == 0 + assert len(relay_chunk.existing_map) == 0 + assert len(relay_chunk.overwrite_map) == 1 assert Relay.objects.count() == 1 assert Relay.objects.filter(public_key__exact=old_relay_public_key).exists() + relay_usage_chunk = RegionImportChunk.objects.get( + model="sentry.relayusage", min_ordinal=1, max_ordinal=1 + ) + assert len(relay_usage_chunk.inserted_map) == 0 + assert len(relay_usage_chunk.existing_map) == 0 + assert len(relay_usage_chunk.overwrite_map) == 1 assert RelayUsage.objects.count() == 1 assert RelayUsage.objects.filter(public_key__exact=old_relay_usage_public_key).exists() with assume_test_silo_mode(SiloMode.CONTROL): + control_option_chunk = ControlImportChunk.objects.get( + model="sentry.controloption", min_ordinal=1, max_ordinal=1 + ) + assert len(control_option_chunk.inserted_map) == 0 + assert len(control_option_chunk.existing_map) == 0 + assert len(control_option_chunk.overwrite_map) == 1 assert ControlOption.objects.count() == 1 assert ControlOption.objects.filter(value__exact="b").exists() @@ -1491,16 +1704,40 @@ def test_colliding_configs_overwrite_configs_disabled_in_global_scope(self): tmp_file, flags=ImportFlags(overwrite_configs=False), printer=NOOP_PRINTER ) + option_chunk = RegionImportChunk.objects.get( + model="sentry.option", min_ordinal=1, max_ordinal=1 + ) + assert len(option_chunk.inserted_map) == 0 + assert len(option_chunk.existing_map) == 1 + assert len(option_chunk.overwrite_map) == 0 assert Option.objects.count() == 1 assert Option.objects.filter(value__exact="y").exists() + relay_chunk = RegionImportChunk.objects.get( + model="sentry.relay", min_ordinal=1, max_ordinal=1 + ) + assert len(relay_chunk.inserted_map) == 0 + assert len(relay_chunk.existing_map) == 1 + assert len(relay_chunk.overwrite_map) == 0 assert Relay.objects.count() == 1 assert Relay.objects.filter(public_key__exact="invalid").exists() + relay_usage_chunk = RegionImportChunk.objects.get( + model="sentry.relayusage", min_ordinal=1, max_ordinal=1 + ) + assert len(relay_usage_chunk.inserted_map) == 0 + assert len(relay_usage_chunk.existing_map) == 1 + assert len(relay_usage_chunk.overwrite_map) == 0 assert RelayUsage.objects.count() == 1 assert RelayUsage.objects.filter(public_key__exact="invalid").exists() with assume_test_silo_mode(SiloMode.CONTROL): + control_option_chunk = ControlImportChunk.objects.get( + model="sentry.controloption", min_ordinal=1, max_ordinal=1 + ) + assert len(control_option_chunk.inserted_map) == 0 + assert len(control_option_chunk.existing_map) == 1 + assert len(control_option_chunk.overwrite_map) == 0 assert ControlOption.objects.count() == 1 assert ControlOption.objects.filter(value__exact="z").exists() @@ -1528,21 +1765,31 @@ def test_colliding_user_with_merging_enabled_in_user_scope(self): with assume_test_silo_mode(SiloMode.CONTROL): assert User.objects.count() == 1 - assert UserIP.objects.count() == 1 - assert UserEmail.objects.count() == 1 # UserEmail gets overwritten + assert UserEmail.objects.count() == 1 # Keep only original when merging. + assert UserIP.objects.count() == 1 # Keep only original when merging. assert Authenticator.objects.count() == 1 assert Email.objects.count() == 2 + user_chunk = ControlImportChunk.objects.get( + model="sentry.user", min_ordinal=1, max_ordinal=1 + ) + assert len(user_chunk.inserted_map) == 0 + assert len(user_chunk.existing_map) == 1 assert User.objects.filter(username__iexact="owner").exists() assert not User.objects.filter(username__iexact="owner-").exists() assert User.objects.filter(is_unclaimed=True).count() == 0 assert LostPasswordHash.objects.count() == 0 assert User.objects.filter(is_unclaimed=False).count() == 1 - assert UserEmail.objects.filter(email__icontains="existing@").exists() assert not UserEmail.objects.filter(email__icontains="importing@").exists() + # Incoming `UserEmail`s, `UserPermissions`, and `UserIP`s for imported users are + # completely scrubbed when merging is enabled. + assert not ControlImportChunk.objects.filter(model="sentry.useremail").exists() + assert not ControlImportChunk.objects.filter(model="sentry.userip").exists() + assert not ControlImportChunk.objects.filter(model="sentry.userpermission").exists() + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @@ -1567,6 +1814,11 @@ def test_colliding_user_with_merging_disabled_in_user_scope(self): assert Authenticator.objects.count() == 1 # Only imported in global scope assert Email.objects.count() == 2 + user_chunk = ControlImportChunk.objects.get( + model="sentry.user", min_ordinal=1, max_ordinal=1 + ) + assert len(user_chunk.inserted_map) == 1 + assert len(user_chunk.existing_map) == 0 assert User.objects.filter(username__iexact="owner").exists() assert User.objects.filter(username__icontains="owner-").exists() @@ -1574,6 +1826,11 @@ def test_colliding_user_with_merging_disabled_in_user_scope(self): assert LostPasswordHash.objects.count() == 1 assert User.objects.filter(is_unclaimed=False).count() == 1 + useremail_chunk = ControlImportChunk.objects.get( + model="sentry.useremail", min_ordinal=1, max_ordinal=1 + ) + assert len(useremail_chunk.inserted_map) == 1 + assert len(useremail_chunk.existing_map) == 0 assert UserEmail.objects.filter(email__icontains="existing@").exists() assert UserEmail.objects.filter(email__icontains="importing@").exists() @@ -1602,11 +1859,16 @@ def test_colliding_user_with_merging_enabled_in_organization_scope(self): user = User.objects.get(username="owner") assert User.objects.count() == 1 - assert UserIP.objects.count() == 1 - assert UserEmail.objects.count() == 1 # UserEmail gets overwritten + assert UserEmail.objects.count() == 1 # Keep only original when merging. + assert UserIP.objects.count() == 1 # Keep only original when merging. assert Authenticator.objects.count() == 1 # Only imported in global scope assert Email.objects.count() == 2 + user_chunk = ControlImportChunk.objects.get( + model="sentry.user", min_ordinal=1, max_ordinal=1 + ) + assert len(user_chunk.inserted_map) == 0 + assert len(user_chunk.existing_map) == 1 assert User.objects.filter(username__iexact="owner").exists() assert not User.objects.filter(username__icontains="owner-").exists() @@ -1617,6 +1879,12 @@ def test_colliding_user_with_merging_enabled_in_organization_scope(self): assert UserEmail.objects.filter(email__icontains="existing@").exists() assert not UserEmail.objects.filter(email__icontains="importing@").exists() + # Incoming `UserEmail`s, `UserPermissions`, and `UserIP`s for imported users are + # completely dropped when merging is enabled. + assert not ControlImportChunk.objects.filter(model="sentry.useremail").exists() + assert not ControlImportChunk.objects.filter(model="sentry.userip").exists() + assert not ControlImportChunk.objects.filter(model="sentry.userpermission").exists() + assert Organization.objects.count() == 2 assert OrganizationMember.objects.count() == 2 # Same user in both orgs @@ -1672,6 +1940,11 @@ def test_colliding_user_with_merging_disabled_in_organization_scope(self): assert Authenticator.objects.count() == 1 # Only imported in global scope assert Email.objects.count() == 2 + user_chunk = ControlImportChunk.objects.get( + model="sentry.user", min_ordinal=1, max_ordinal=1 + ) + assert len(user_chunk.inserted_map) == 1 + assert len(user_chunk.existing_map) == 0 assert User.objects.filter(username__iexact="owner").exists() assert User.objects.filter(username__icontains="owner-").exists() @@ -1679,6 +1952,11 @@ def test_colliding_user_with_merging_disabled_in_organization_scope(self): assert LostPasswordHash.objects.count() == 1 assert User.objects.filter(is_unclaimed=False).count() == 1 + useremail_chunk = ControlImportChunk.objects.get( + model="sentry.useremail", min_ordinal=1, max_ordinal=1 + ) + assert len(useremail_chunk.inserted_map) == 1 + assert len(useremail_chunk.existing_map) == 0 assert UserEmail.objects.filter(email__icontains="existing@").exists() assert UserEmail.objects.filter(email__icontains="importing@").exists() @@ -1731,12 +2009,17 @@ def test_colliding_user_with_merging_enabled_in_config_scope(self): with assume_test_silo_mode(SiloMode.CONTROL): assert User.objects.count() == 1 - assert UserIP.objects.count() == 1 - assert UserEmail.objects.count() == 1 # UserEmail gets overwritten - assert UserPermission.objects.count() == 1 + assert UserEmail.objects.count() == 1 # Keep only original when merging. + assert UserIP.objects.count() == 1 # Keep only original when merging. + assert UserPermission.objects.count() == 1 # Keep only original when merging. assert Authenticator.objects.count() == 1 assert Email.objects.count() == 2 + user_chunk = ControlImportChunk.objects.get( + model="sentry.user", min_ordinal=1, max_ordinal=1 + ) + assert len(user_chunk.inserted_map) == 0 + assert len(user_chunk.existing_map) == 1 assert User.objects.filter(username__iexact="owner").exists() assert not User.objects.filter(username__iexact="owner-").exists() @@ -1747,6 +2030,12 @@ def test_colliding_user_with_merging_enabled_in_config_scope(self): assert UserEmail.objects.filter(email__icontains="existing@").exists() assert not UserEmail.objects.filter(email__icontains="importing@").exists() + # Incoming `UserEmail`s, `UserPermissions`, and `UserIP`s for imported users are + # completely dropped when merging is enabled. + assert not ControlImportChunk.objects.filter(model="sentry.useremail").exists() + assert not ControlImportChunk.objects.filter(model="sentry.userip").exists() + assert not ControlImportChunk.objects.filter(model="sentry.userpermission").exists() + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @@ -1774,6 +2063,11 @@ def test_colliding_user_with_merging_disabled_in_config_scope(self): assert Authenticator.objects.count() == 1 # Only imported in global scope assert Email.objects.count() == 2 + user_chunk = ControlImportChunk.objects.get( + model="sentry.user", min_ordinal=1, max_ordinal=1 + ) + assert len(user_chunk.inserted_map) == 1 + assert len(user_chunk.existing_map) == 0 assert User.objects.filter(username__iexact="owner").exists() assert User.objects.filter(username__icontains="owner-").exists() @@ -1781,6 +2075,11 @@ def test_colliding_user_with_merging_disabled_in_config_scope(self): assert LostPasswordHash.objects.count() == 1 assert User.objects.filter(is_unclaimed=False).count() == 1 + useremail_chunk = ControlImportChunk.objects.get( + model="sentry.useremail", min_ordinal=1, max_ordinal=1 + ) + assert len(useremail_chunk.inserted_map) == 1 + assert len(useremail_chunk.existing_map) == 0 assert UserEmail.objects.filter(email__icontains="existing@").exists() assert UserEmail.objects.filter(email__icontains="importing@").exists() diff --git a/tests/sentry/backup/test_rpc.py b/tests/sentry/backup/test_rpc.py index 9fe9451beeb07d..54bc0821945f04 100644 --- a/tests/sentry/backup/test_rpc.py +++ b/tests/sentry/backup/test_rpc.py @@ -2,8 +2,11 @@ from copy import deepcopy from functools import cached_property +from uuid import uuid4 from sentry.backup.dependencies import NormalizedModelName, get_model_name +from sentry.models.importchunk import ControlImportChunk, RegionImportChunk +from sentry.models.options.option import ControlOption, Option from sentry.models.project import Project from sentry.models.user import MAX_USERNAME_LENGTH, User from sentry.services.hybrid_cloud.import_export import import_export_service @@ -14,6 +17,7 @@ RpcImportError, RpcImportErrorKind, RpcImportFlags, + RpcImportOk, RpcImportScope, RpcPrimaryKeyMap, ) @@ -23,12 +27,142 @@ from sentry.testutils.silo import assume_test_silo_mode from sentry.utils import json -USER_MODEL_NAME = get_model_name(User) +CONTROL_OPTION_MODEL_NAME = get_model_name(ControlOption) +OPTION_MODEL_NAME = get_model_name(Option) PROJECT_MODEL_NAME = get_model_name(Project) +USER_MODEL_NAME = get_model_name(User) + + +class RpcImportRetryTests(TestCase): + """ + Ensure that retries don't duplicate writes. + """ + + def test_good_local_retry_idempotent(self): + # If the response gets lost on the way to the caller, it will try again. Make sure it is + # clever enough to not try to write the data twice if its already been committed. + import_uuid = str(uuid4().hex) + + option_count = Option.objects.count() + import_chunk_count = RegionImportChunk.objects.count() + + def verify_option_write(): + nonlocal option_count, import_chunk_count, import_uuid + + result = import_export_service.import_by_model( + model_name="sentry.option", + scope=RpcImportScope.Global, + flags=RpcImportFlags(import_uuid=import_uuid), + filter_by=[], + pk_map=RpcPrimaryKeyMap(), + json_data=""" + [ + { + "model": "sentry.option", + "pk": 5, + "fields": { + "key": "foo", + "last_updated": "2023-06-22T00:00:00.000Z", + "last_updated_by": "unknown", + "value": "bar" + } + } + ] + """, + ) + + assert isinstance(result, RpcImportOk) + assert result.min_ordinal == 1 + assert result.max_ordinal == 1 + assert result.min_source_pk == 5 + assert result.max_source_pk == 5 + assert result.min_inserted_pk == result.max_inserted_pk + assert len(result.mapped_pks.from_rpc().mapping[str(OPTION_MODEL_NAME)]) == 1 + + assert Option.objects.count() == option_count + 1 + assert RegionImportChunk.objects.count() == import_chunk_count + 1 + + import_chunk = RegionImportChunk.objects.get(import_uuid=import_uuid) + assert import_chunk.min_ordinal == 1 + assert import_chunk.max_ordinal == 1 + assert import_chunk.min_source_pk == 5 + assert import_chunk.max_source_pk == 5 + assert import_chunk.min_inserted_pk == import_chunk.max_inserted_pk + assert len(import_chunk.inserted_map) == 1 + assert len(import_chunk.existing_map) == 0 + assert len(import_chunk.overwrite_map) == 0 + + # Doing the write twice should produce identical results from the sender's point of view, + # and should not result in multiple `RegionImportChunk`s being written. + verify_option_write() + verify_option_write() + + def test_good_remote_retry_idempotent(self): + # If the response gets lost on the way to the caller, it will try again. Make sure it is + # clever enough to not try to write the data twice if its already been committed. + import_uuid = str(uuid4().hex) + + with assume_test_silo_mode(SiloMode.CONTROL): + control_option_count = ControlOption.objects.count() + import_chunk_count = ControlImportChunk.objects.count() + + def verify_control_option_write(): + nonlocal control_option_count, import_chunk_count, import_uuid + + result = import_export_service.import_by_model( + model_name="sentry.controloption", + scope=RpcImportScope.Global, + flags=RpcImportFlags(import_uuid=import_uuid), + filter_by=[], + pk_map=RpcPrimaryKeyMap(), + json_data=""" + [ + { + "model": "sentry.controloption", + "pk": 7, + "fields": { + "key": "foo", + "last_updated": "2023-06-22T00:00:00.000Z", + "last_updated_by": "unknown", + "value": "bar" + } + } + ] + """, + ) + + assert isinstance(result, RpcImportOk) + assert result.min_ordinal == 1 + assert result.max_ordinal == 1 + assert result.min_source_pk == 7 + assert result.max_source_pk == 7 + assert result.min_inserted_pk == result.max_inserted_pk + assert len(result.mapped_pks.from_rpc().mapping[str(CONTROL_OPTION_MODEL_NAME)]) == 1 + + with assume_test_silo_mode(SiloMode.CONTROL): + assert ControlOption.objects.count() == control_option_count + 1 + assert ControlImportChunk.objects.count() == import_chunk_count + 1 + + import_chunk = ControlImportChunk.objects.get(import_uuid=import_uuid) + assert import_chunk.min_ordinal == 1 + assert import_chunk.max_ordinal == 1 + assert import_chunk.min_source_pk == 7 + assert import_chunk.max_source_pk == 7 + assert import_chunk.min_inserted_pk == import_chunk.max_inserted_pk + assert len(import_chunk.inserted_map) == 1 + assert len(import_chunk.existing_map) == 0 + assert len(import_chunk.overwrite_map) == 0 + + # Doing the write twice should produce identical results from the sender's point of view, + # and should not result in multiple `ControlImportChunk`s being written. + verify_control_option_write() + verify_control_option_write() class RpcImportErrorTests(TestCase): - """Validate errors related to the `import_by_model()` RPC method.""" + """ + Validate errors related to the `import_by_model()` RPC method. + """ @staticmethod def is_user_model(model: json.JSONData) -> bool: @@ -46,10 +180,10 @@ def test_bad_unknown_model(self): result = import_export_service.import_by_model( model_name="sentry.doesnotexist", scope=RpcImportScope.Global, - flags=RpcImportFlags(), + flags=RpcImportFlags(import_uuid=str(uuid4().hex)), filter_by=[], pk_map=RpcPrimaryKeyMap(), - json_data="", + json_data="[]", ) assert isinstance(result, RpcImportError) @@ -60,10 +194,10 @@ def test_bad_incorrect_silo_mode_for_model(self): result = import_export_service.import_by_model( model_name=str(PROJECT_MODEL_NAME), scope=RpcImportScope.Global, - flags=RpcImportFlags(), + flags=RpcImportFlags(import_uuid=str(uuid4().hex)), filter_by=[], pk_map=RpcPrimaryKeyMap(), - json_data="", + json_data="[]", ) assert isinstance(result, RpcImportError) @@ -72,22 +206,35 @@ def test_bad_incorrect_silo_mode_for_model(self): def test_bad_unspecified_scope(self): result = import_export_service.import_by_model( model_name=str(USER_MODEL_NAME), - flags=RpcImportFlags(), + flags=RpcImportFlags(import_uuid=str(uuid4().hex)), filter_by=[], pk_map=RpcPrimaryKeyMap(), - json_data="", + json_data="[]", ) assert isinstance(result, RpcImportError) assert result.get_kind() == RpcImportErrorKind.UnspecifiedScope - def test_bad_invalid_json(self): + def test_bad_missing_import_uuid(self): result = import_export_service.import_by_model( model_name=str(USER_MODEL_NAME), scope=RpcImportScope.Global, flags=RpcImportFlags(), filter_by=[], pk_map=RpcPrimaryKeyMap(), + json_data="[]", + ) + + assert isinstance(result, RpcImportError) + assert result.get_kind() == RpcImportErrorKind.MissingImportUUID + + def test_bad_invalid_json(self): + result = import_export_service.import_by_model( + model_name=str(USER_MODEL_NAME), + scope=RpcImportScope.Global, + flags=RpcImportFlags(import_uuid=str(uuid4().hex)), + filter_by=[], + pk_map=RpcPrimaryKeyMap(), json_data="_", ) @@ -106,7 +253,7 @@ def test_bad_validation(self): result = import_export_service.import_by_model( model_name=str(USER_MODEL_NAME), scope=RpcImportScope.Global, - flags=RpcImportFlags(), + flags=RpcImportFlags(import_uuid=str(uuid4().hex)), filter_by=[], pk_map=RpcPrimaryKeyMap(), json_data=json_data, @@ -121,7 +268,7 @@ def test_bad_unexpected_model(self): result = import_export_service.import_by_model( model_name="sentry.option", scope=RpcImportScope.Global, - flags=RpcImportFlags(), + flags=RpcImportFlags(import_uuid=str(uuid4().hex)), filter_by=[], pk_map=RpcPrimaryKeyMap(), json_data=json_data, diff --git a/tests/sentry/tasks/test_relocation.py b/tests/sentry/tasks/test_relocation.py index 9ece4d236a731b..2ca5ce8cb84ae5 100644 --- a/tests/sentry/tasks/test_relocation.py +++ b/tests/sentry/tasks/test_relocation.py @@ -22,6 +22,7 @@ ) from sentry.models.files.file import File from sentry.models.files.utils import get_storage +from sentry.models.importchunk import ControlImportChunk, RegionImportChunk from sentry.models.organization import Organization from sentry.models.relocation import ( Relocation, @@ -31,6 +32,7 @@ ValidationStatus, ) from sentry.models.user import User +from sentry.silo.base import SiloMode from sentry.tasks.relocation import ( ERR_PREPROCESSING_DECRYPTION, ERR_PREPROCESSING_INTERNAL, @@ -59,7 +61,7 @@ from sentry.testutils.cases import TestCase, TransactionTestCase from sentry.testutils.factories import get_fixture_path from sentry.testutils.helpers.backups import FakeKeyManagementServiceClient, generate_rsa_key_pair -from sentry.testutils.silo import region_silo_test +from sentry.testutils.silo import assume_test_silo_mode, region_silo_test from sentry.utils import json from sentry.utils.relocation import RELOCATION_BLOB_SIZE, RELOCATION_FILE_TYPE @@ -94,7 +96,7 @@ def setUp(self): file=self.file, kind=RelocationFile.Kind.RAW_USER_DATA.value, ) - self.uuid = self.relocation.uuid + self.uuid = str(self.relocation.uuid) @cached_property def file(self): @@ -178,7 +180,7 @@ def mock_cloudbuild_client( @region_silo_test class UploadingCompleteTest(RelocationTaskTestCase): def test_success(self, preprocessing_scan_mock: Mock): - uploading_complete(self.relocation.uuid) + uploading_complete(self.uuid) assert preprocessing_scan_mock.call_count == 1 @@ -187,7 +189,7 @@ def test_retry_if_attempts_left(self, preprocessing_scan_mock: Mock): # An exception being raised will trigger a retry in celery. with pytest.raises(Exception): - uploading_complete(self.relocation.uuid) + uploading_complete(self.uuid) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.IN_PROGRESS.value @@ -201,7 +203,7 @@ def test_fail_if_no_attempts_left(self, preprocessing_scan_mock: Mock): RelocationFile.objects.filter(relocation=self.relocation).delete() with pytest.raises(Exception): - uploading_complete(self.relocation.uuid) + uploading_complete(self.uuid) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.FAILURE.value @@ -428,7 +430,7 @@ def test_success( ): self.mock_kms_client(fake_kms_client) - preprocessing_baseline_config(self.relocation.uuid) + preprocessing_baseline_config(self.uuid) assert fake_kms_client.asymmetric_decrypt.call_count == 0 assert fake_kms_client.get_public_key.call_count == 1 @@ -523,7 +525,7 @@ def test_success( fake_kms_client: FakeKeyManagementServiceClient, ): self.mock_kms_client(fake_kms_client) - preprocessing_colliding_users(self.relocation.uuid) + preprocessing_colliding_users(self.uuid) assert preprocessing_complete_mock.call_count == 1 assert fake_kms_client.asymmetric_decrypt.call_count == 0 @@ -625,21 +627,19 @@ def setUp(self): assert file.blobs.count() > 1 # A bit bigger, so we get chunks. def test_success(self, validating_start_mock: Mock): - assert not self.storage.exists(f"relocations/runs/{self.relocation.uuid}") + assert not self.storage.exists(f"relocations/runs/{self.uuid}") - preprocessing_complete(self.relocation.uuid) + preprocessing_complete(self.uuid) self.relocation.refresh_from_db() assert validating_start_mock.call_count == 1 - (_, files) = self.storage.listdir(f"relocations/runs/{self.relocation.uuid}/conf") + (_, files) = self.storage.listdir(f"relocations/runs/{self.uuid}/conf") assert len(files) == 2 assert "cloudbuild.yaml" in files assert "cloudbuild.zip" in files - cb_yaml_file = self.storage.open( - f"relocations/runs/{self.relocation.uuid}/conf/cloudbuild.yaml" - ) + cb_yaml_file = self.storage.open(f"relocations/runs/{self.uuid}/conf/cloudbuild.yaml") with cb_yaml_file: cb_conf = yaml.safe_load(cb_yaml_file) assert cb_conf is not None @@ -648,8 +648,8 @@ def test_success(self, validating_start_mock: Mock): # separately then replace them for snapshotting. in_path = cb_conf["steps"][0]["args"][2] findings_path = cb_conf["artifacts"]["objects"]["location"] - assert in_path == f"gs://default/relocations/runs/{self.relocation.uuid}/in" - assert findings_path == f"gs://default/relocations/runs/{self.relocation.uuid}/findings/" + assert in_path == f"gs://default/relocations/runs/{self.uuid}/in" + assert findings_path == f"gs://default/relocations/runs/{self.uuid}/findings/" # Do a snapshot test of the cloudbuild config. cb_conf["steps"][0]["args"][2] = "gs:///relocations/runs//in" @@ -659,14 +659,14 @@ def test_success(self, validating_start_mock: Mock): cb_conf["steps"][12]["args"][3] = "gs:///relocations/runs//out" self.insta_snapshot(cb_conf) - (_, files) = self.storage.listdir(f"relocations/runs/{self.relocation.uuid}/in") + (_, files) = self.storage.listdir(f"relocations/runs/{self.uuid}/in") assert len(files) == 4 assert "kms-config.json" in files assert "raw-relocation-data.tar" in files assert "baseline-config.tar" in files assert "colliding-users.tar" in files - kms_file = self.storage.open(f"relocations/runs/{self.relocation.uuid}/in/kms-config.json") + kms_file = self.storage.open(f"relocations/runs/{self.uuid}/in/kms-config.json") with kms_file: json.load(kms_file) @@ -678,7 +678,7 @@ def test_retry_if_attempts_left(self, validating_start_mock: Mock): # An exception being raised will trigger a retry in celery. with pytest.raises(Exception): - preprocessing_complete(self.relocation.uuid) + preprocessing_complete(self.uuid) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.IN_PROGRESS.value @@ -692,7 +692,7 @@ def test_fail_if_no_attempts_left(self, validating_start_mock: Mock): RelocationFile.objects.filter(relocation=self.relocation).delete() with pytest.raises(Exception): - preprocessing_complete(self.relocation.uuid) + preprocessing_complete(self.uuid) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.FAILURE.value @@ -726,7 +726,7 @@ def test_success( ): self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(Build.Status.QUEUED)) - validating_start(self.relocation.uuid) + validating_start(self.uuid) self.relocation.refresh_from_db() self.relocation_validation.refresh_from_db() @@ -750,7 +750,7 @@ def test_retry_if_attempts_left( self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(Build.Status.QUEUED)) fake_cloudbuild_client.create_build.side_effect = Exception("Test") - validating_start(self.relocation.uuid) + validating_start(self.uuid) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.IN_PROGRESS.value @@ -769,7 +769,7 @@ def test_fail_if_no_attempts_left( fake_cloudbuild_client.create_build.side_effect = Exception("Test") with pytest.raises(Exception): - validating_start(self.relocation.uuid) + validating_start(self.uuid) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.FAILURE.value @@ -797,7 +797,7 @@ def test_fail_if_max_runs_attempted( self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(Build.Status.QUEUED)) fake_cloudbuild_client.create_build.side_effect = Exception("Test") - validating_start(self.relocation.uuid) + validating_start(self.uuid) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.FAILURE.value @@ -839,7 +839,7 @@ def test_success( ): self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(Build.Status.SUCCESS)) - validating_poll(self.relocation.uuid, self.relocation_validation_attempt.build_id) + validating_poll(self.uuid, self.relocation_validation_attempt.build_id) self.relocation.refresh_from_db() self.relocation_validation.refresh_from_db() self.relocation_validation_attempt.refresh_from_db() @@ -859,7 +859,7 @@ def test_timeout_starts_new_validation_attempt( self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(stat)) validating_start_mock.call_count = 0 - validating_poll(self.relocation.uuid, self.relocation_validation_attempt.build_id) + validating_poll(self.uuid, self.relocation_validation_attempt.build_id) self.relocation.refresh_from_db() self.relocation_validation.refresh_from_db() self.relocation_validation_attempt.refresh_from_db() @@ -884,7 +884,7 @@ def test_failure_starts_new_validation_attempt( self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(stat)) validating_start_mock.call_count = 0 - validating_poll(self.relocation.uuid, self.relocation_validation_attempt.build_id) + validating_poll(self.uuid, self.relocation_validation_attempt.build_id) self.relocation.refresh_from_db() self.relocation_validation.refresh_from_db() @@ -910,7 +910,7 @@ def test_in_progress_retries_poll( self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(stat)) validating_poll_mock.call_count = 0 - validating_poll(self.relocation.uuid, self.relocation_validation_attempt.build_id) + validating_poll(self.uuid, self.relocation_validation_attempt.build_id) self.relocation.refresh_from_db() self.relocation_validation.refresh_from_db() @@ -937,7 +937,7 @@ def test_retry_if_attempts_left( self.mock_cloudbuild_client(fake_cloudbuild_client, Build.Status(Build.Status.QUEUED)) fake_cloudbuild_client.get_build.side_effect = Exception("Test") - validating_poll(self.relocation.uuid, self.relocation_validation_attempt.build_id) + validating_poll(self.uuid, self.relocation_validation_attempt.build_id) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.IN_PROGRESS.value @@ -954,7 +954,7 @@ def test_fail_if_no_attempts_left( fake_cloudbuild_client.get_build.side_effect = Exception("Test") with pytest.raises(Exception): - validating_poll(self.relocation.uuid, self.relocation_validation_attempt.build_id) + validating_poll(self.uuid, self.relocation_validation_attempt.build_id) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.FAILURE.value @@ -1009,7 +1009,7 @@ def setUp(self): self.storage = get_storage() self.storage.save( - f"relocations/runs/{self.relocation.uuid}/findings/artifacts-prefixes-are-ignored.json", + f"relocations/runs/{self.uuid}/findings/artifacts-prefixes-are-ignored.json", BytesIO(b"invalid-json"), ) files = [ @@ -1024,12 +1024,10 @@ def setUp(self): "compare-colliding-users.json", ] for file in files: - self.storage.save( - f"relocations/runs/{self.relocation.uuid}/findings/{file}", BytesIO(b"[]") - ) + self.storage.save(f"relocations/runs/{self.uuid}/findings/{file}", BytesIO(b"[]")) def test_valid(self, importing_mock: Mock): - validating_complete(self.relocation.uuid, self.relocation_validation_attempt.build_id) + validating_complete(self.uuid, self.relocation_validation_attempt.build_id) self.relocation.refresh_from_db() self.relocation_validation.refresh_from_db() @@ -1044,7 +1042,7 @@ def test_valid(self, importing_mock: Mock): def test_invalid(self, importing_mock: Mock): mock_invalid_finding(self.storage, self.uuid) - validating_complete(self.relocation.uuid, self.relocation_validation_attempt.build_id) + validating_complete(self.uuid, self.relocation_validation_attempt.build_id) self.relocation.refresh_from_db() self.relocation_validation.refresh_from_db() @@ -1061,11 +1059,11 @@ def test_retry_if_attempts_left(self, _: Mock): # An exception being raised will trigger a retry in celery. with pytest.raises(Exception): self.storage.save( - f"relocations/runs/{self.relocation.uuid}/findings/null.json", + f"relocations/runs/{self.uuid}/findings/null.json", BytesIO(b"invalid-json"), ) - validating_complete(self.relocation.uuid, self.relocation_validation_attempt.build_id) + validating_complete(self.uuid, self.relocation_validation_attempt.build_id) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.IN_PROGRESS.value @@ -1076,11 +1074,11 @@ def test_fail_if_no_attempts_left(self, _: Mock): self.relocation.latest_task_attempts = MAX_FAST_TASK_RETRIES self.relocation.save() self.storage.save( - f"relocations/runs/{self.relocation.uuid}/findings/null.json", BytesIO(b"invalid-json") + f"relocations/runs/{self.uuid}/findings/null.json", BytesIO(b"invalid-json") ) with pytest.raises(Exception): - validating_complete(self.relocation.uuid, self.relocation_validation_attempt.build_id) + validating_complete(self.uuid, self.relocation_validation_attempt.build_id) relocation = Relocation.objects.get(uuid=self.uuid) assert relocation.status == Relocation.Status.FAILURE.value @@ -1105,12 +1103,32 @@ def test_success(self, completed_mock: Mock, fake_kms_client: FakeKeyManagementS self.mock_kms_client(fake_kms_client) org_count = Organization.objects.filter(slug__startswith="testing").count() - importing(self.relocation.uuid) + importing(self.uuid) # TODO(getsentry/team-ospo#203): Should notify users instead. assert completed_mock.call_count == 1 assert Organization.objects.filter(slug__startswith="testing").count() == org_count + 1 + assert RegionImportChunk.objects.filter(import_uuid=self.uuid).count() == 9 + assert sorted(RegionImportChunk.objects.values_list("model", flat=True)) == [ + "sentry.organization", + "sentry.organizationmember", + "sentry.organizationmemberteam", + "sentry.project", + "sentry.projectkey", + "sentry.projectoption", + "sentry.projectteam", + "sentry.rule", + "sentry.team", + ] + + with assume_test_silo_mode(SiloMode.CONTROL): + assert ControlImportChunk.objects.filter(import_uuid=self.uuid).count() == 2 + assert sorted(ControlImportChunk.objects.values_list("model", flat=True)) == [ + "sentry.user", + "sentry.useremail", + ] + @patch( "sentry.backup.helpers.KeyManagementServiceClient", @@ -1160,6 +1178,26 @@ def test_valid_no_retries( assert not relocation.failure_reason assert Organization.objects.filter(slug__startswith="testing").count() == org_count + 1 + assert RegionImportChunk.objects.filter(import_uuid=self.uuid).count() == 9 + assert sorted(RegionImportChunk.objects.values_list("model", flat=True)) == [ + "sentry.organization", + "sentry.organizationmember", + "sentry.organizationmemberteam", + "sentry.project", + "sentry.projectkey", + "sentry.projectoption", + "sentry.projectteam", + "sentry.rule", + "sentry.team", + ] + + with assume_test_silo_mode(SiloMode.CONTROL): + assert ControlImportChunk.objects.filter(import_uuid=self.uuid).count() == 2 + assert sorted(ControlImportChunk.objects.values_list("model", flat=True)) == [ + "sentry.user", + "sentry.useremail", + ] + def test_invalid_no_retries( self, fake_cloudbuild_client: FakeCloudBuildClient, @@ -1177,3 +1215,7 @@ def test_invalid_no_retries( assert relocation.status == Relocation.Status.FAILURE.value assert relocation.failure_reason assert Organization.objects.filter(slug__startswith="testing").count() == org_count + + # TODO(getsentry/team-ospo#190): We should add "max retry" tests as well, but these are quite + # hard to mock in celery at the moment. We may need to use the mock sync celery test scheduler, + # rather than the "self.tasks()" approach above, to accomplish this.