diff --git a/src/sentry/backup/exports.py b/src/sentry/backup/exports.py index b9e16a812c0d64..ebfdc342a72452 100644 --- a/src/sentry/backup/exports.py +++ b/src/sentry/backup/exports.py @@ -1,14 +1,9 @@ from __future__ import annotations import io -import tarfile from typing import BinaryIO, Type import click -from cryptography.fernet import Fernet -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import padding from django.db.models.base import Model from sentry.backup.dependencies import ( @@ -17,7 +12,7 @@ get_model_name, sorted_dependencies, ) -from sentry.backup.helpers import Filter +from sentry.backup.helpers import Filter, create_encrypted_export_tarball from sentry.backup.scopes import ExportScope from sentry.services.hybrid_cloud.import_export.model import ( RpcExportError, @@ -47,7 +42,7 @@ def __init__(self, context: RpcExportError) -> None: def _export( - dest, + dest: BinaryIO, scope: ExportScope, *, encrypt_with: BinaryIO | None = None, @@ -151,47 +146,11 @@ def get_exporter_for_model(model: Type[Model]): dest_wrapper.detach() return - # Generate a new DEK (data encryption key), and use that DEK to encrypt the JSON being exported. - pem = encrypt_with.read() - data_encryption_key = Fernet.generate_key() - backup_encryptor = Fernet(data_encryption_key) - encrypted_json_export = backup_encryptor.encrypt(json.dumps(json_export).encode("utf-8")) - - # Encrypt the newly minted DEK using symmetric public key encryption. - dek_encryption_key = serialization.load_pem_public_key(pem, default_backend()) - sha256 = hashes.SHA256() - mgf = padding.MGF1(algorithm=sha256) - oaep_padding = padding.OAEP(mgf=mgf, algorithm=sha256, label=None) - encrypted_dek = dek_encryption_key.encrypt(data_encryption_key, oaep_padding) # type: ignore - - # Generate a tarball with 3 files: - # - # 1. The DEK we minted, name "data.key". - # 2. The public key we used to encrypt that DEK, named "key.pub". - # 3. The exported JSON data, encrypted with that DEK, named "export.json". - # - # The upshot: to decrypt the exported JSON data, you need the plaintext (decrypted) DEK. But to - # decrypt the DEK, you need the private key associated with the included public key, which - # you've hopefully kept in a safe, trusted location. - # - # Note that the supplied file names are load-bearing - ex, changing to `data.key` to `foo.key` - # risks breaking assumptions that the decryption side will make on the other end! 
-    tar_buffer = io.BytesIO()
-    with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
-        json_info = tarfile.TarInfo("export.json")
-        json_info.size = len(encrypted_json_export)
-        tar.addfile(json_info, fileobj=io.BytesIO(encrypted_json_export))
-        key_info = tarfile.TarInfo("data.key")
-        key_info.size = len(encrypted_dek)
-        tar.addfile(key_info, fileobj=io.BytesIO(encrypted_dek))
-        pub_info = tarfile.TarInfo("key.pub")
-        pub_info.size = len(pem)
-        tar.addfile(pub_info, fileobj=io.BytesIO(pem))
-    dest.write(tar_buffer.getvalue())
+    dest.write(create_encrypted_export_tarball(json_export, encrypt_with).getvalue())


 def export_in_user_scope(
-    dest,
+    dest: BinaryIO,
     *,
     encrypt_with: BinaryIO | None = None,
     user_filter: set[str] | None = None,
@@ -217,7 +176,7 @@ def export_in_user_scope(


 def export_in_organization_scope(
-    dest,
+    dest: BinaryIO,
     *,
     encrypt_with: BinaryIO | None = None,
     org_filter: set[str] | None = None,
@@ -244,7 +203,7 @@ def export_in_organization_scope(


 def export_in_config_scope(
-    dest,
+    dest: BinaryIO,
     *,
     encrypt_with: BinaryIO | None = None,
     indent: int = 2,
@@ -269,7 +228,7 @@ def export_in_config_scope(


 def export_in_global_scope(
-    dest,
+    dest: BinaryIO,
     *,
     encrypt_with: BinaryIO | None = None,
     indent: int = 2,
diff --git a/src/sentry/backup/helpers.py b/src/sentry/backup/helpers.py
index ae3c6f81db4724..3ed8e2d095b47e 100644
--- a/src/sentry/backup/helpers.py
+++ b/src/sentry/backup/helpers.py
@@ -1,14 +1,21 @@
 from __future__ import annotations

+import io
+import tarfile
 from datetime import datetime, timedelta, timezone
 from enum import Enum
 from functools import lru_cache
-from typing import Generic, NamedTuple, Type, TypeVar
+from typing import BinaryIO, Generic, NamedTuple, Type, TypeVar

+from cryptography.fernet import Fernet
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import padding
 from django.core.serializers.json import DjangoJSONEncoder
 from django.db import models

 from sentry.backup.scopes import RelocationScope
+from sentry.utils import json

 # Django apps we take care to never import or export from.
 EXCLUDED_APPS = frozenset(("auth", "contenttypes", "fixtures"))
@@ -27,6 +34,110 @@ def default(self, obj):
         return super().default(obj)


+def create_encrypted_export_tarball(
+    json_export: json.JSONData, encrypt_with: BinaryIO
+) -> io.BytesIO:
+    """
+    Generate a tarball with 3 files:
+
+    1. The DEK we minted, named "data.key".
+    2. The public key we used to encrypt that DEK, named "key.pub".
+    3. The exported JSON data, encrypted with that DEK, named "export.json".
+
+    The upshot: to decrypt the exported JSON data, you need the plaintext (decrypted) DEK. But to
+    decrypt the DEK, you need the private key associated with the included public key, which
+    you've hopefully kept in a safe, trusted location.
+
+    Note that the supplied file names are load-bearing - e.g., changing `data.key` to `foo.key`
+    risks breaking assumptions that the decryption side will make on the other end!
+    """

+    # Generate a new DEK (data encryption key), and use that DEK to encrypt the JSON being exported.
+    pem = encrypt_with.read()
+    data_encryption_key = Fernet.generate_key()
+    backup_encryptor = Fernet(data_encryption_key)
+    encrypted_json_export = backup_encryptor.encrypt(json.dumps(json_export).encode("utf-8"))
+
+    # Encrypt the newly minted DEK using asymmetric public key encryption.
+    dek_encryption_key = serialization.load_pem_public_key(pem, default_backend())
+    sha256 = hashes.SHA256()
+    mgf = padding.MGF1(algorithm=sha256)
+    oaep_padding = padding.OAEP(mgf=mgf, algorithm=sha256, label=None)
+    encrypted_dek = dek_encryption_key.encrypt(data_encryption_key, oaep_padding)  # type: ignore
+
+    # Generate the tarball and write it to a new output stream.
+    tar_buffer = io.BytesIO()
+    with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
+        json_info = tarfile.TarInfo("export.json")
+        json_info.size = len(encrypted_json_export)
+        tar.addfile(json_info, fileobj=io.BytesIO(encrypted_json_export))
+        key_info = tarfile.TarInfo("data.key")
+        key_info.size = len(encrypted_dek)
+        tar.addfile(key_info, fileobj=io.BytesIO(encrypted_dek))
+        pub_info = tarfile.TarInfo("key.pub")
+        pub_info.size = len(pem)
+        tar.addfile(pub_info, fileobj=io.BytesIO(pem))
+
+    return tar_buffer
+
+
+def decrypt_encrypted_tarball(tarball: BinaryIO, decrypt_with: BinaryIO) -> str:
+    """
+    A tarball encrypted by a call to `_export` with `encrypt_with` set has some specific
+    properties (file names, etc.). This method handles all of those, decrypting the tarball with
+    the provided private key into an in-memory JSON string.
+    """

+    export = None
+    encrypted_dek = None
+    public_key_pem = None
+    private_key_pem = decrypt_with.read()
+    with tarfile.open(fileobj=tarball, mode="r") as tar:
+        for member in tar.getmembers():
+            if member.isfile():
+                file = tar.extractfile(member)
+                if file is None:
+                    raise ValueError(f"Could not extract file for {member.name}")
+
+                content = file.read()
+                if member.name == "export.json":
+                    export = content.decode("utf-8")
+                elif member.name == "data.key":
+                    encrypted_dek = content
+                elif member.name == "key.pub":
+                    public_key_pem = content
+                else:
+                    raise ValueError(f"Unknown tarball entity {member.name}")
+
+    if export is None or encrypted_dek is None or public_key_pem is None:
+        raise ValueError("A required file was missing from the encrypted tarball")
+
+    # Compare the public and private keys, to ensure that they are a match.
+    private_key = serialization.load_pem_private_key(
+        private_key_pem,
+        password=None,
+        backend=default_backend(),
+    )
+    generated_public_key_pem = private_key.public_key().public_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PublicFormat.SubjectPublicKeyInfo,
+    )
+    if public_key_pem != generated_public_key_pem:
+        raise ValueError(
+            "The public key does not match that generated by the `decrypt_with` private key."
+        )
+
+    # Decrypt the DEK, then use it to decrypt the underlying JSON.
+    decrypted_dek = private_key.decrypt(  # type: ignore
+        encrypted_dek,
+        padding.OAEP(
+            mgf=padding.MGF1(algorithm=hashes.SHA256()),
+            algorithm=hashes.SHA256(),
+            label=None,
+        ),
+    )
+    decryptor = Fernet(decrypted_dek)
+    return decryptor.decrypt(export).decode("utf-8")
+
+
 def get_final_derivations_of(model: Type) -> set[Type]:
     """A "final" derivation of the given `model` base class is any non-abstract class for the
     "sentry" app with `BaseModel` as an ancestor.
Top-level calls to this class should pass in diff --git a/src/sentry/backup/imports.py b/src/sentry/backup/imports.py index 05267db8db7d01..2d9aef46dd1590 100644 --- a/src/sentry/backup/imports.py +++ b/src/sentry/backup/imports.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Iterator, Optional, Tuple, Type +from typing import BinaryIO, Iterator, Optional, Tuple, Type import click from django.conf import settings @@ -11,7 +11,7 @@ from rest_framework.serializers import ValidationError as DjangoRestFrameworkValidationError from sentry.backup.dependencies import NormalizedModelName, PrimaryKeyMap, get_model, get_model_name -from sentry.backup.helpers import EXCLUDED_APPS, Filter, ImportFlags +from sentry.backup.helpers import EXCLUDED_APPS, Filter, ImportFlags, decrypt_encrypted_tarball from sentry.backup.scopes import ImportScope from sentry.silo import unguarded_write from sentry.utils import json @@ -25,9 +25,10 @@ def _import( - src, + src: BinaryIO, scope: ImportScope, *, + decrypt_with: BinaryIO | None = None, flags: ImportFlags | None = None, filter_by: Filter | None = None, printer=click.echo, @@ -49,7 +50,14 @@ def _import( org_model_name = get_model_name(Organization) org_member_model_name = get_model_name(OrganizationMember) - start = src.tell() + # TODO(getsentry#team-ospo/190): Reading the entire export into memory as a string is quite + # wasteful - in the future, we should explore chunking strategies to enable a smaller memory + # footprint when processing super large (>100MB) exports. + content = ( + decrypt_encrypted_tarball(src, decrypt_with) + if decrypt_with is not None + else src.read().decode("utf-8") + ) filters = [] if filter_by is not None: filters.append(filter_by) @@ -73,7 +81,7 @@ def _import( # deserializer does no such thing, and actually loads the entire JSON into memory! If we # don't want to choke on large imports, we'll need use a truly "chunkable" JSON # importing library like ijson for this. - for obj in serializers.deserialize("json", src, stream=True): + for obj in serializers.deserialize("json", content): o = obj.object model_name = get_model_name(o) if model_name == user_model_name: @@ -98,7 +106,7 @@ def _import( break elif filter_by.model == User: seen_first_user_model = False - for obj in serializers.deserialize("json", src, stream=True): + for obj in serializers.deserialize("json", content): o = obj.object model_name = get_model_name(o) if model_name == user_model_name: @@ -121,13 +129,11 @@ def _import( filters.append(email_filter) - src.seek(start) - # The input JSON blob should already be ordered by model kind. We simply break up 1 JSON blob # with N model kinds into N json blobs with 1 model kind each. def yield_json_models(src) -> Iterator[Tuple[NormalizedModelName, str]]: # TODO(getsentry#team-ospo/190): Better error handling for unparsable JSON. 
-        models = json.load(src)
+        models = json.loads(content)
         last_seen_model_name: Optional[NormalizedModelName] = None
         batch: list[Type[Model]] = []
         for model in models:
@@ -228,8 +234,9 @@ def do_write():


 def import_in_user_scope(
-    src,
+    src: BinaryIO,
     *,
+    decrypt_with: BinaryIO | None = None,
     flags: ImportFlags | None = None,
     user_filter: set[str] | None = None,
     printer=click.echo,
@@ -246,6 +253,7 @@ def import_in_user_scope(
     return _import(
         src,
         ImportScope.User,
+        decrypt_with=decrypt_with,
         flags=flags,
         filter_by=Filter(User, "username", user_filter) if user_filter is not None else None,
         printer=printer,
@@ -253,8 +261,9 @@ def import_in_user_scope(


 def import_in_organization_scope(
-    src,
+    src: BinaryIO,
     *,
+    decrypt_with: BinaryIO | None = None,
     flags: ImportFlags | None = None,
     org_filter: set[str] | None = None,
     printer=click.echo,
@@ -275,6 +284,7 @@ def import_in_organization_scope(
     return _import(
         src,
         ImportScope.Organization,
+        decrypt_with=decrypt_with,
         flags=flags,
         filter_by=Filter(Organization, "slug", org_filter) if org_filter is not None else None,
         printer=printer,
@@ -282,8 +292,9 @@ def import_in_organization_scope(


 def import_in_config_scope(
-    src,
+    src: BinaryIO,
     *,
+    decrypt_with: BinaryIO | None = None,
     flags: ImportFlags | None = None,
     user_filter: set[str] | None = None,
     printer=click.echo,
@@ -305,13 +316,20 @@ def import_in_config_scope(
     return _import(
         src,
         ImportScope.Config,
+        decrypt_with=decrypt_with,
         flags=flags,
         filter_by=Filter(User, "username", user_filter) if user_filter is not None else None,
         printer=printer,
     )


-def import_in_global_scope(src, *, flags: ImportFlags | None = None, printer=click.echo):
+def import_in_global_scope(
+    src: BinaryIO,
+    *,
+    decrypt_with: BinaryIO | None = None,
+    flags: ImportFlags | None = None,
+    printer=click.echo,
+):
     """
     Perform an import in the `Global` scope, meaning that all models will be imported from the
     provided source file. Because a `Global` import is really only useful when restoring to a fresh
@@ -319,4 +337,10 @@ def import_in_global_scope(src, *, flags: ImportFlags | None = None, printer=click.echo):
     superuser privileges are not sanitized. This method can be thought of as a "pure"
     backup/restore, simply serializing and deserializing a (partial) snapshot of the database
     state.
     """
-    return _import(src, ImportScope.Global, flags=flags, printer=printer)
+    return _import(
+        src,
+        ImportScope.Global,
+        decrypt_with=decrypt_with,
+        flags=flags,
+        printer=printer,
+    )
diff --git a/src/sentry/runner/commands/backup.py b/src/sentry/runner/commands/backup.py
index 117e7ad7aa582a..eb914553683b7a 100644
--- a/src/sentry/runner/commands/backup.py
+++ b/src/sentry/runner/commands/backup.py
@@ -15,6 +15,12 @@
 from sentry.runner.decorators import configuration
 from sentry.utils import json

+DECRYPT_WITH_HELP = """A path to a file containing a private key with which to decrypt a tarball
+                    previously encrypted using an `export ... --encrypt_with=` command.
+                    The private key provided via this flag should be the complement of the public
+                    key used to encrypt the tarball (this public key is included in the tarball
+                    itself)."""
+
 ENCRYPT_WITH_HELP = """A path to a public key with which to encrypt this export. If this flag is
                     enabled and points to a valid key, the output file will be a tarball containing
                     3 constituent files: 1.
An encrypted JSON file called @@ -105,6 +111,11 @@ def import_(): @import_.command(name="users") @click.argument("src", type=click.File("rb")) +@click.option( + "--decrypt_with", + type=click.File("rb"), + help=DECRYPT_WITH_HELP, +) @click.option( "--filter_usernames", default="", @@ -120,13 +131,14 @@ def import_(): ) @click.option("--silent", "-q", default=False, is_flag=True, help="Silence all debug output.") @configuration -def import_users(src, filter_usernames, merge_users, silent): +def import_users(src, decrypt_with, filter_usernames, merge_users, silent): """ Import the Sentry users from an exported JSON file. """ import_in_user_scope( src, + decrypt_with=decrypt_with, flags=ImportFlags(merge_users=merge_users), user_filter=parse_filter_arg(filter_usernames), printer=(lambda *args, **kwargs: None) if silent else click.echo, @@ -135,6 +147,11 @@ def import_users(src, filter_usernames, merge_users, silent): @import_.command(name="organizations") @click.argument("src", type=click.File("rb")) +@click.option( + "--decrypt_with", + type=click.File("rb"), + help=DECRYPT_WITH_HELP, +) @click.option( "--filter_org_slugs", default="", @@ -151,13 +168,14 @@ def import_users(src, filter_usernames, merge_users, silent): ) @click.option("--silent", "-q", default=False, is_flag=True, help="Silence all debug output.") @configuration -def import_organizations(src, filter_org_slugs, merge_users, silent): +def import_organizations(src, decrypt_with, filter_org_slugs, merge_users, silent): """ Import the Sentry organizations, and all constituent Sentry users, from an exported JSON file. """ import_in_organization_scope( src, + decrypt_with=decrypt_with, flags=ImportFlags(merge_users=merge_users), org_filter=parse_filter_arg(filter_org_slugs), printer=(lambda *args, **kwargs: None) if silent else click.echo, @@ -166,6 +184,11 @@ def import_organizations(src, filter_org_slugs, merge_users, silent): @import_.command(name="config") @click.argument("src", type=click.File("rb")) +@click.option( + "--decrypt_with", + type=click.File("rb"), + help=DECRYPT_WITH_HELP, +) @click.option("--silent", "-q", default=False, is_flag=True, help="Silence all debug output.") @click.option( "--merge_users", @@ -180,13 +203,14 @@ def import_organizations(src, filter_org_slugs, merge_users, silent): help=OVERWRITE_CONFIGS_HELP, ) @configuration -def import_config(src, merge_users, overwrite_configs, silent): +def import_config(src, decrypt_with, merge_users, overwrite_configs, silent): """ Import all configuration and administrator accounts needed to set up this Sentry instance. """ import_in_config_scope( src, + decrypt_with=decrypt_with, flags=ImportFlags(merge_users=merge_users, overwrite_configs=overwrite_configs), printer=(lambda *args, **kwargs: None) if silent else click.echo, ) @@ -194,6 +218,11 @@ def import_config(src, merge_users, overwrite_configs, silent): @import_.command(name="global") @click.argument("src", type=click.File("rb")) +@click.option( + "--decrypt_with", + type=click.File("rb"), + help=DECRYPT_WITH_HELP, +) @click.option( "--overwrite_configs", default=False, @@ -202,13 +231,14 @@ def import_config(src, merge_users, overwrite_configs, silent): ) @click.option("--silent", "-q", default=False, is_flag=True, help="Silence all debug output.") @configuration -def import_global(src, silent, overwrite_configs): +def import_global(src, decrypt_with, silent, overwrite_configs): """ Import all Sentry data from an exported JSON file. 
""" import_in_global_scope( src, + decrypt_with=decrypt_with, flags=ImportFlags(overwrite_configs=overwrite_configs), printer=(lambda *args, **kwargs: None) if silent else click.echo, ) diff --git a/src/sentry/testutils/helpers/backups.py b/src/sentry/testutils/helpers/backups.py index 3e7d01fc42f849..7ccbe7736642ea 100644 --- a/src/sentry/testutils/helpers/backups.py +++ b/src/sentry/testutils/helpers/backups.py @@ -1,7 +1,6 @@ from __future__ import annotations import io -import tarfile import tempfile from copy import deepcopy from datetime import datetime, timedelta @@ -10,10 +9,9 @@ from typing import Tuple from uuid import uuid4 -from cryptography.fernet import Fernet from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa from django.apps import apps from django.conf import settings from django.core.management import call_command @@ -30,6 +28,7 @@ export_in_user_scope, ) from sentry.backup.findings import ComparatorFindings +from sentry.backup.helpers import decrypt_encrypted_tarball from sentry.backup.imports import import_in_global_scope from sentry.backup.scopes import ExportScope from sentry.backup.validate import validate @@ -183,45 +182,8 @@ def export_to_encrypted_tarball( # Read the files in the generated tarball. This bit of code assume the file names, but that is # part of the encrypt/decrypt tar-ing API, so we need to ensure that these exact names are # present and contain the data we expect. - export = None - encrypted_dek = None - pub_key = None - with tarfile.open(tar_file_path, "r") as tar: - for member in tar.getmembers(): - if member.isfile(): - file = tar.extractfile(member) - if file is None: - raise AssertionError(f"Could not extract file for {member.name}") - - content = file.read() - if member.name == "export.json": - export = content.decode("utf-8") - elif member.name == "data.key": - encrypted_dek = content - elif member.name == "key.pub": - pub_key = content - else: - raise AssertionError(f"Unknown tarball entity {member.name}") - - if export is None or encrypted_dek is None or pub_key is None: - raise AssertionError("A required file was missing from the temporary test tarball") - - # Decrypt the DEK, then use it to decrypt the underlying JSON. - private_key = serialization.load_pem_private_key( - private_key_pem, - password=None, # Use the password here if the PEM was encrypted - backend=default_backend(), - ) - decrypted_dek = private_key.decrypt( # type: ignore - encrypted_dek, - padding.OAEP( - mgf=padding.MGF1(algorithm=hashes.SHA256()), - algorithm=hashes.SHA256(), - label=None, - ), - ) - decryptor = Fernet(decrypted_dek) - return json.loads(decryptor.decrypt(export)) + with open(tar_file_path, "rb") as f: + return json.loads(decrypt_encrypted_tarball(f, io.BytesIO(private_key_pem))) # No arguments, so we lazily cache the result after the first calculation. @@ -278,7 +240,7 @@ def import_export_then_validate(method_name: str, *, reset_pks: bool = True) -> clear_database(reset_pks=reset_pks) # Write the contents of the "expected" JSON file into the now clean database. - with open(tmp_expect) as tmp_file: + with open(tmp_expect, "rb") as tmp_file: import_in_global_scope(tmp_file, printer=NOOP_PRINTER) # Validate that the "expected" and "actual" JSON matches. 
@@ -320,7 +282,7 @@ def import_export_from_fixture_then_validate( fixture_file_path = get_fixture_path("backup", fixture_file_name) with open(fixture_file_path) as backup_file: expect = json.load(backup_file) - with open(fixture_file_path) as fixture_file: + with open(fixture_file_path, "rb") as fixture_file: import_in_global_scope(fixture_file, printer=NOOP_PRINTER) res = validate( diff --git a/tests/sentry/backup/test_exhaustive.py b/tests/sentry/backup/test_exhaustive.py index 213cb0b3fc9564..2340b6693f3032 100644 --- a/tests/sentry/backup/test_exhaustive.py +++ b/tests/sentry/backup/test_exhaustive.py @@ -58,9 +58,9 @@ def test_uniqueness_clean_pks(self): # Now import twice, so that all random values in the export (UUIDs etc) are identical, # to test that these are properly replaced and handled. - with open(tmp_expect) as tmp_file: + with open(tmp_expect, "rb") as tmp_file: import_in_global_scope(tmp_file, printer=NOOP_PRINTER) - with open(tmp_expect) as tmp_file: + with open(tmp_expect, "rb") as tmp_file: # Back-to-back global scope imports are disallowed (global scope assume a clean # database), so use organization scope instead. # @@ -82,9 +82,9 @@ def test_uniqueness_dirty_pks(self): # Now import twice, so that all random values in the export (UUIDs etc) are identical, # to test that these are properly replaced and handled. - with open(tmp_expect) as tmp_file: + with open(tmp_expect, "rb") as tmp_file: import_in_global_scope(tmp_file, printer=NOOP_PRINTER) - with open(tmp_expect) as tmp_file: + with open(tmp_expect, "rb") as tmp_file: # Back-to-back global scope imports are disallowed (global scope assume a clean # database), so use organization scope followed by config scope instead. import_in_organization_scope(tmp_file, printer=NOOP_PRINTER) diff --git a/tests/sentry/backup/test_imports.py b/tests/sentry/backup/test_imports.py index 270bfabfcce058..fdee7577855f59 100644 --- a/tests/sentry/backup/test_imports.py +++ b/tests/sentry/backup/test_imports.py @@ -1,15 +1,22 @@ from __future__ import annotations +import io +import tarfile import tempfile from copy import deepcopy from datetime import date, datetime from functools import cached_property from os import environ from pathlib import Path +from typing import Tuple from unittest.mock import patch import pytest import urllib3.exceptions +from cryptography.fernet import Fernet +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding from django.utils import timezone from rest_framework.serializers import ValidationError @@ -48,6 +55,7 @@ BackupTestCase, clear_database, export_to_file, + generate_rsa_key_pair, ) from sentry.testutils.hybrid_cloud import use_split_dbs from sentry.utils import json @@ -119,7 +127,7 @@ def test_user_sanitized_in_user_scope(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir).joinpath(f"{self._testMethodName}.json") self.generate_tmp_json_file(tmp_path) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_user_scope(tmp_file, printer=NOOP_PRINTER) assert User.objects.count() == 4 @@ -152,7 +160,7 @@ def test_user_sanitized_in_organization_scope(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir).joinpath(f"{self._testMethodName}.json") self.generate_tmp_json_file(tmp_path) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope(tmp_file, 
printer=NOOP_PRINTER) assert User.objects.count() == 4 @@ -185,7 +193,7 @@ def test_users_unsanitized_in_config_scope(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir).joinpath(f"{self._testMethodName}.json") self.generate_tmp_json_file(tmp_path) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_config_scope(tmp_file, printer=NOOP_PRINTER) assert User.objects.count() == 4 @@ -225,7 +233,7 @@ def test_users_unsanitized_in_global_scope(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir).joinpath(f"{self._testMethodName}.json") self.generate_tmp_json_file(tmp_path) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_global_scope(tmp_file, printer=NOOP_PRINTER) assert User.objects.count() == 4 @@ -271,7 +279,7 @@ def test_generate_suffix_for_already_taken_organization(self): # Note that we have created an organization with the same name as one we are about to # import. self.create_organization(owner=self.user, name="some-org") - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope(tmp_file, printer=NOOP_PRINTER) assert Organization.objects.count() == 2 @@ -295,7 +303,7 @@ def test_generate_suffix_for_already_taken_username(self): ) json.dump(same_username_user + copy_of_same_username_user, tmp_file) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_user_scope(tmp_file, printer=NOOP_PRINTER) assert User.objects.count() == 3 @@ -315,7 +323,7 @@ def test_bad_invalid_user(self): model["fields"]["username"] = "x" * 129 json.dump(models, tmp_file) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: with pytest.raises(ValidationError): import_in_user_scope(tmp_file, printer=NOOP_PRINTER) @@ -338,7 +346,7 @@ def test_good_regional_user_ip_in_user_scope(self, mock_geo_by_addr): model["fields"]["ip_address"] = "8.8.8.8" json.dump(models, tmp_file) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_user_scope(tmp_file, printer=NOOP_PRINTER) assert UserIP.objects.count() == 1 @@ -369,7 +377,7 @@ def test_good_regional_user_ip_in_global_scope(self, mock_geo_by_addr): model["fields"]["ip_address"] = "8.8.8.8" json.dump(models, tmp_file) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_global_scope(tmp_file, printer=NOOP_PRINTER) assert UserIP.objects.count() == 1 @@ -393,7 +401,7 @@ def test_bad_invalid_user_ip(self): m["fields"]["ip_address"] = "0.1.2.3.4.5.6.7.8.9.abc.def" json.dump(list(models), tmp_file) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: with pytest.raises(ValidationError): import_in_user_scope(tmp_file, printer=NOOP_PRINTER) @@ -409,7 +417,7 @@ def test_bad_invalid_user_option(self): m["fields"]["value"] = '"MiddleEarth/Gondor"' json.dump(list(models), tmp_file) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: with pytest.raises(ValidationError): import_in_user_scope(tmp_file, printer=NOOP_PRINTER) @@ -430,7 +438,7 @@ def test_import_signaling_user(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_user_scope(tmp_file, printer=NOOP_PRINTER) assert User.objects.count() == 1 @@ -450,7 +458,7 @@ def test_import_signaling_organization(self): with tempfile.TemporaryDirectory() as tmp_dir: 
tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope(tmp_file, printer=NOOP_PRINTER) assert Organization.objects.count() == 1 @@ -502,7 +510,7 @@ def test_user_import_scoping(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_user_scope(tmp_file, printer=NOOP_PRINTER) self.verify_model_inclusion(ImportScope.User) @@ -511,7 +519,7 @@ def test_organization_import_scoping(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope(tmp_file, printer=NOOP_PRINTER) self.verify_model_inclusion(ImportScope.Organization) @@ -520,7 +528,7 @@ def test_config_import_scoping(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_config_scope(tmp_file, printer=NOOP_PRINTER) self.verify_model_inclusion(ImportScope.Config) @@ -529,11 +537,123 @@ def test_global_import_scoping(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_global_scope(tmp_file, printer=NOOP_PRINTER) self.verify_model_inclusion(ImportScope.Global) +class DecryptionTests(ImportTestCase): + """ + Ensures that decryption actually works. We only test one model for each scope, because it's + extremely unlikely that a failed decryption will leave only part of the data unmangled. 
+ """ + + @staticmethod + def encrypt_json_fixture(tmp_dir) -> Tuple[Path, Path]: + good_file_path = get_fixture_path("backup", "fresh-install.json") + (priv_key_pem, pub_key_pem) = generate_rsa_key_pair() + + tmp_priv_key_path = Path(tmp_dir).joinpath("key") + with open(tmp_priv_key_path, "wb") as f: + f.write(priv_key_pem) + + tmp_pub_key_path = Path(tmp_dir).joinpath("key.pub") + with open(tmp_pub_key_path, "wb") as f: + f.write(pub_key_pem) + + with open(good_file_path) as f: + json_data = json.load(f) + + tmp_tarball_path = Path(tmp_dir).joinpath("input.tar") + with open(tmp_tarball_path, "wb") as i, open(tmp_pub_key_path, "rb") as p: + pem = p.read() + data_encryption_key = Fernet.generate_key() + backup_encryptor = Fernet(data_encryption_key) + encrypted_json_export = backup_encryptor.encrypt(json.dumps(json_data).encode("utf-8")) + + dek_encryption_key = serialization.load_pem_public_key(pem, default_backend()) + sha256 = hashes.SHA256() + mgf = padding.MGF1(algorithm=sha256) + oaep_padding = padding.OAEP(mgf=mgf, algorithm=sha256, label=None) + encrypted_dek = dek_encryption_key.encrypt(data_encryption_key, oaep_padding) # type: ignore + + tar_buffer = io.BytesIO() + with tarfile.open(fileobj=tar_buffer, mode="w") as tar: + json_info = tarfile.TarInfo("export.json") + json_info.size = len(encrypted_json_export) + tar.addfile(json_info, fileobj=io.BytesIO(encrypted_json_export)) + key_info = tarfile.TarInfo("data.key") + key_info.size = len(encrypted_dek) + tar.addfile(key_info, fileobj=io.BytesIO(encrypted_dek)) + pub_info = tarfile.TarInfo("key.pub") + pub_info.size = len(pem) + tar.addfile(pub_info, fileobj=io.BytesIO(pem)) + + i.write(tar_buffer.getvalue()) + + return (tmp_tarball_path, tmp_priv_key_path) + + def test_user_import_decryption(self): + with tempfile.TemporaryDirectory() as tmp_dir: + (tmp_tarball_path, tmp_priv_key_path) = self.encrypt_json_fixture(tmp_dir) + assert User.objects.count() == 0 + + with open(tmp_tarball_path, "rb") as tmp_tarball_file, open( + tmp_priv_key_path, "rb" + ) as tmp_priv_key_file: + import_in_user_scope( + tmp_tarball_file, decrypt_with=tmp_priv_key_file, printer=NOOP_PRINTER + ) + + assert User.objects.count() > 0 + + def test_organization_import_decryption(self): + with tempfile.TemporaryDirectory() as tmp_dir: + (tmp_tarball_path, tmp_priv_key_path) = self.encrypt_json_fixture(tmp_dir) + assert Organization.objects.count() == 0 + + with open(tmp_tarball_path, "rb") as tmp_tarball_file, open( + tmp_priv_key_path, "rb" + ) as tmp_priv_key_file: + import_in_organization_scope( + tmp_tarball_file, decrypt_with=tmp_priv_key_file, printer=NOOP_PRINTER + ) + + assert Organization.objects.count() > 0 + + def test_config_import_decryption(self): + with tempfile.TemporaryDirectory() as tmp_dir: + (tmp_tarball_path, tmp_priv_key_path) = self.encrypt_json_fixture(tmp_dir) + assert UserRole.objects.count() == 0 + + with open(tmp_tarball_path, "rb") as tmp_tarball_file, open( + tmp_priv_key_path, "rb" + ) as tmp_priv_key_file: + import_in_config_scope( + tmp_tarball_file, decrypt_with=tmp_priv_key_file, printer=NOOP_PRINTER + ) + + assert UserRole.objects.count() > 0 + + def test_global_import_decryption(self): + with tempfile.TemporaryDirectory() as tmp_dir: + (tmp_tarball_path, tmp_priv_key_path) = self.encrypt_json_fixture(tmp_dir) + assert Organization.objects.count() == 0 + assert User.objects.count() == 0 + assert UserRole.objects.count() == 0 + + with open(tmp_tarball_path, "rb") as tmp_tarball_file, open( + tmp_priv_key_path, "rb" + ) as 
tmp_priv_key_file: + import_in_global_scope( + tmp_tarball_file, decrypt_with=tmp_priv_key_file, printer=NOOP_PRINTER + ) + + assert Organization.objects.count() > 0 + assert User.objects.count() > 0 + assert UserRole.objects.count() > 0 + + class FilterTests(ImportTestCase): """ Ensures that filtering operations include the correct models. @@ -545,7 +665,7 @@ def test_import_filter_users(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_user_scope(tmp_file, user_filter={"user_2"}, printer=NOOP_PRINTER) # Count users, but also count a random model naively derived from just `User` alone, like @@ -568,7 +688,7 @@ def test_export_filter_users_shared_email(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_user_scope( tmp_file, user_filter={"user_1", "user_2", "user_3"}, printer=NOOP_PRINTER ) @@ -589,7 +709,7 @@ def test_import_filter_users_empty(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_user_scope(tmp_file, user_filter=set(), printer=NOOP_PRINTER) assert User.objects.count() == 0 @@ -610,7 +730,7 @@ def test_import_filter_orgs_single(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope(tmp_file, org_filter={"org-b"}, printer=NOOP_PRINTER) assert Organization.objects.count() == 1 @@ -645,7 +765,7 @@ def test_import_filter_orgs_multiple(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope( tmp_file, org_filter={"org-a", "org-c"}, printer=NOOP_PRINTER ) @@ -682,7 +802,7 @@ def test_import_filter_orgs_empty(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope(tmp_file, org_filter=set(), printer=NOOP_PRINTER) assert Organization.objects.count() == 0 @@ -766,7 +886,7 @@ def test_colliding_api_token(self): == 1 ) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_global_scope(tmp_file, printer=NOOP_PRINTER) # Ensure that old tokens have not been mutated. 
@@ -786,7 +906,7 @@ def test_colliding_api_token(self): == 1 ) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, Monitor)) @@ -811,13 +931,13 @@ def test_colliding_monitor(self): assert Monitor.objects.count() == 1 assert Monitor.objects.filter(guid=colliding.guid).count() == 1 - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope(tmp_file, printer=NOOP_PRINTER) assert Monitor.objects.count() == 2 assert Monitor.objects.filter(guid=colliding.guid).count() == 1 - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, OrgAuthToken)) @@ -850,7 +970,7 @@ def test_colliding_org_auth_token(self): == 1 ) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope(tmp_file, printer=NOOP_PRINTER) assert OrgAuthToken.objects.count() == 2 @@ -862,7 +982,7 @@ def test_colliding_org_auth_token(self): == 1 ) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, ProjectKey)) @@ -888,14 +1008,14 @@ def test_colliding_project_key(self): assert ProjectKey.objects.filter(public_key=colliding.public_key).count() == 1 assert ProjectKey.objects.filter(secret_key=colliding.secret_key).count() == 1 - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope(tmp_file, printer=NOOP_PRINTER) assert ProjectKey.objects.count() == 4 assert ProjectKey.objects.filter(public_key=colliding.public_key).count() == 1 assert ProjectKey.objects.filter(secret_key=colliding.secret_key).count() == 1 - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @pytest.mark.xfail( @@ -940,7 +1060,7 @@ def test_colliding_query_subscription(self): == 1 ) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_organization_scope(tmp_file, printer=NOOP_PRINTER) assert SnubaQuery.objects.count() > 1 @@ -952,7 +1072,7 @@ def test_colliding_query_subscription(self): == 1 ) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, ControlOption, Option, Relay, RelayUsage, UserRole)) @@ -995,7 +1115,7 @@ def test_colliding_configs_overwrite_configs_enabled_in_config_scope(self): assert RelayUsage.objects.count() == 1 assert UserRole.objects.count() == 1 - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_config_scope( tmp_file, flags=ImportFlags(overwrite_configs=True), printer=NOOP_PRINTER ) @@ -1017,7 +1137,7 @@ def test_colliding_configs_overwrite_configs_enabled_in_config_scope(self): for i, actual_permission in enumerate(actual_user_role.permissions): assert actual_permission == old_user_role_permissions[i] - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, ControlOption, Option, Relay, RelayUsage, UserRole)) @@ -1056,7 +1176,7 @@ def test_colliding_configs_overwrite_configs_disabled_in_config_scope(self): assert RelayUsage.objects.count() == 1 assert UserRole.objects.count() == 1 - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_config_scope( tmp_file, flags=ImportFlags(overwrite_configs=False), printer=NOOP_PRINTER ) @@ -1078,7 +1198,7 @@ def 
test_colliding_configs_overwrite_configs_disabled_in_config_scope(self): assert len(actual_user_role.permissions) == 1 assert actual_user_role.permissions[0] == "other.admin" - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, ControlOption, Option, Relay, RelayUsage, UserRole)) @@ -1121,7 +1241,7 @@ def test_colliding_configs_overwrite_configs_enabled_in_global_scope(self): assert RelayUsage.objects.count() == 1 assert UserRole.objects.count() == 1 - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_global_scope( tmp_file, flags=ImportFlags(overwrite_configs=True), printer=NOOP_PRINTER ) @@ -1143,7 +1263,7 @@ def test_colliding_configs_overwrite_configs_enabled_in_global_scope(self): for i, actual_permission in enumerate(actual_user_role.permissions): assert actual_permission == old_user_role_permissions[i] - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, ControlOption, Option, Relay, RelayUsage, UserRole)) @@ -1182,7 +1302,7 @@ def test_colliding_configs_overwrite_configs_disabled_in_global_scope(self): assert RelayUsage.objects.count() == 1 assert UserRole.objects.count() == 1 - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: import_in_global_scope( tmp_file, flags=ImportFlags(overwrite_configs=False), printer=NOOP_PRINTER ) @@ -1204,7 +1324,7 @@ def test_colliding_configs_overwrite_configs_disabled_in_global_scope(self): assert len(actual_user_role.permissions) == 1 assert actual_user_role.permissions[0] == "other.admin" - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, Email, User, UserEmail, UserIP)) @@ -1213,7 +1333,7 @@ def test_colliding_user_with_merging_enabled_in_user_scope(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: self.create_exhaustive_user(username="owner", email="existing@example.com") import_in_user_scope( tmp_file, @@ -1236,7 +1356,7 @@ def test_colliding_user_with_merging_enabled_in_user_scope(self): assert UserEmail.objects.filter(email__icontains="existing@").exists() assert not UserEmail.objects.filter(email__icontains="importing@").exists() - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, Email, User, UserEmail, UserIP)) @@ -1245,7 +1365,7 @@ def test_colliding_user_with_merging_disabled_in_user_scope(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: self.create_exhaustive_user(username="owner", email="existing@example.com") import_in_user_scope( tmp_file, @@ -1268,7 +1388,7 @@ def test_colliding_user_with_merging_disabled_in_user_scope(self): assert UserEmail.objects.filter(email__icontains="existing@").exists() assert UserEmail.objects.filter(email__icontains="importing@").exists() - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets( @@ -1280,7 +1400,7 @@ def test_colliding_user_with_merging_enabled_in_organization_scope(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = 
self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: owner = self.create_exhaustive_user(username="owner", email="existing@example.com") self.create_organization("some-org", owner=owner) import_in_organization_scope( @@ -1333,7 +1453,7 @@ def test_colliding_user_with_merging_enabled_in_organization_scope(self): == 1 ) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets( @@ -1345,7 +1465,7 @@ def test_colliding_user_with_merging_disabled_in_organization_scope(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: owner = self.create_exhaustive_user(username="owner", email="existing@example.com") self.create_organization("some-org", owner=owner) import_in_organization_scope( @@ -1403,7 +1523,7 @@ def test_colliding_user_with_merging_disabled_in_organization_scope(self): == 1 ) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, Email, User, UserEmail, UserIP, UserPermission)) @@ -1412,7 +1532,7 @@ def test_colliding_user_with_merging_enabled_in_config_scope(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: self.create_exhaustive_user( username="owner", email="existing@example.com", is_admin=True ) @@ -1438,7 +1558,7 @@ def test_colliding_user_with_merging_enabled_in_config_scope(self): assert UserEmail.objects.filter(email__icontains="existing@").exists() assert not UserEmail.objects.filter(email__icontains="importing@").exists() - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) @targets(mark(COLLISION_TESTED, Email, User, UserEmail, UserIP, UserPermission)) @@ -1447,7 +1567,7 @@ def test_colliding_user_with_merging_disabled_in_config_scope(self): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = self.export_to_tmp_file_and_clear_database(tmp_dir) - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: self.create_exhaustive_user( username="owner", email="existing@example.com", is_admin=True ) @@ -1473,7 +1593,7 @@ def test_colliding_user_with_merging_disabled_in_config_scope(self): assert UserEmail.objects.filter(email__icontains="existing@").exists() assert UserEmail.objects.filter(email__icontains="importing@").exists() - with open(tmp_path) as tmp_file: + with open(tmp_path, "rb") as tmp_file: return json.load(tmp_file) diff --git a/tests/sentry/runner/commands/test_backup.py b/tests/sentry/runner/commands/test_backup.py index b96c796d776400..b0fe9128b38839 100644 --- a/tests/sentry/runner/commands/test_backup.py +++ b/tests/sentry/runner/commands/test_backup.py @@ -6,6 +6,7 @@ from click.testing import CliRunner from django.db import IntegrityError +from sentry.backup.helpers import create_encrypted_export_tarball from sentry.runner.commands.backup import compare, export, import_ from sentry.silo.base import SiloMode from sentry.testutils.cases import TestCase, TransactionTestCase @@ -115,36 +116,53 @@ def test_user_scope_export_filter_usernames(self): cli_import_then_export("users", export_args=["--filter_usernames", "testing@example.com"]) +def cli_encrypted_import_then_export(scope: str): + with 
tempfile.TemporaryDirectory() as tmp_dir: + (priv_key_pem, pub_key_pem) = generate_rsa_key_pair() + + tmp_priv_key_path = Path(tmp_dir).joinpath("key") + with open(tmp_priv_key_path, "wb") as f: + f.write(priv_key_pem) + + tmp_pub_key_path = Path(tmp_dir).joinpath("key.pub") + with open(tmp_pub_key_path, "wb") as f: + f.write(pub_key_pem) + + with open(GOOD_FILE_PATH) as f: + data = json.load(f) + + tmp_input_path = Path(tmp_dir).joinpath("input.tar") + with open(tmp_input_path, "wb") as i, open(tmp_pub_key_path, "rb") as p: + i.write(create_encrypted_export_tarball(data, p).getvalue()) + + rv = CliRunner().invoke( + import_, [scope, str(tmp_input_path), "--decrypt_with", str(tmp_priv_key_path)] + ) + assert rv.exit_code == 0, rv.output + + tmp_output_path = Path(tmp_dir).joinpath("output.tar") + rv = CliRunner().invoke( + export, [scope, str(tmp_output_path), "--encrypt_with", str(tmp_pub_key_path)] + ) + assert rv.exit_code == 0, rv.output + + class GoodImportExportCommandEncryptionTests(TransactionTestCase): """ Ensure that encryption using an `--encrypt_with` file works as expected. """ - def encryption_export_args(self, tmp_dir) -> list[str]: - tmp_pub_key_path = Path(tmp_dir).joinpath("key.pub") - (_, public_key_pem) = generate_rsa_key_pair() - public_key_str = public_key_pem.decode("utf-8") - with open(tmp_pub_key_path, "w") as f: - f.write(public_key_str) - return ["--encrypt_with", str(tmp_pub_key_path)] - def test_global_scope_encryption(self): - with tempfile.TemporaryDirectory() as tmp_dir: - cli_import_then_export("global", export_args=self.encryption_export_args(tmp_dir)) + cli_encrypted_import_then_export("global") def test_config_scope_encryption(self): - with tempfile.TemporaryDirectory() as tmp_dir: - cli_import_then_export("config", export_args=self.encryption_export_args(tmp_dir)) + cli_encrypted_import_then_export("config") def test_organization_scope_encryption(self): - with tempfile.TemporaryDirectory() as tmp_dir: - cli_import_then_export( - "organizations", export_args=self.encryption_export_args(tmp_dir) - ) + cli_encrypted_import_then_export("organizations") def test_user_scope_encryption(self): - with tempfile.TemporaryDirectory() as tmp_dir: - cli_import_then_export("users", export_args=self.encryption_export_args(tmp_dir)) + cli_encrypted_import_then_export("users") class BadImportExportDomainErrorTests(TransactionTestCase):
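
For orientation, here is a minimal round-trip sketch of the two helpers this diff introduces. It is not part of the change itself, and the inline RSA key generation and sample export payload are illustrative assumptions only (the tests above use `generate_rsa_key_pair` and the `fresh-install.json` fixture for the same purpose); it assumes a configured Sentry environment where `sentry.backup.helpers` and `sentry.utils.json` are importable.

import io

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

from sentry.backup.helpers import create_encrypted_export_tarball, decrypt_encrypted_tarball
from sentry.utils import json

# Mint a throwaway RSA key pair; real callers would load a pre-existing key from disk.
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
private_key_pem = private_key.private_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PrivateFormat.PKCS8,
    encryption_algorithm=serialization.NoEncryption(),
)
public_key_pem = private_key.public_key().public_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PublicFormat.SubjectPublicKeyInfo,
)

# A stand-in for real exported JSON data.
export = [{"model": "sentry.option", "pk": 1, "fields": {"key": "foo", "value": '"bar"'}}]

# Encrypt: the public key stream plays the role of `encrypt_with`.
tarball = create_encrypted_export_tarball(export, io.BytesIO(public_key_pem))

# Decrypt: the private key stream plays the role of `decrypt_with`.
decrypted = decrypt_encrypted_tarball(io.BytesIO(tarball.getvalue()), io.BytesIO(private_key_pem))
assert json.loads(decrypted) == export

At the CLI level, the same flow is exposed through the new flag, e.g. `sentry export global backup.tar --encrypt_with=key.pub` followed by `sentry import global backup.tar --decrypt_with=key`, mirroring the `cli_encrypted_import_then_export` helper above.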